/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "player.h"

#include "modules/upctocommand.h"
#include "utils.h"

#include <algorithm>
#include <chrono>
#include <future>
#include <thread>
#include <tuple>

#include <glog/logging.h>

namespace fairrsh {

// Reflection registration: exposes Player to RTTR under the name "Player",
// with type metadata, a constructor taking a TorchCraft client, the
// read-only "topModule" property and the configuration/driver methods listed
// below, so they can be looked up by name at runtime.
RTTR_REGISTRATION {
  rttr::registration::class_<Player>("Player")(
      metadata("type", rttr::type::get<Player>()))
      .constructor<std::shared_ptr<tc::Client>>()
      .property_readonly("topModule", &Player::getTopModule)
      .method("addModule", &Player::addModule)
      .method("addModules", &Player::addModules)
      .method("setFrameskip", &Player::setFrameskip)
      .method("setCombineFrames", &Player::setCombineFrames)
      .method("setWarnIfSlow", &Player::setWarnIfSlow)
      .method("setNonBlocking", &Player::setNonBlocking)
      .method("setRealtimeFactor", &Player::setRealtimeFactor)
      .method("setCollectStats", &Player::setCollectStats)
      .method("setCollectTimers", &Player::setCollectTimers)
      .method("init", &Player::init)
      .method("step", &Player::step)
      .method("run", &Player::run);
}

namespace {
// Time budget for one step in non-blocking mode. It is also the default
// threshold above which step() logs detailed timings when warnIfSlow_ is set.
auto constexpr kMaxStepDuration = std::chrono::milliseconds(40);
// Budget used instead of kMaxStepDuration on frame 0 in non-blocking mode,
// where the bot does extra one-off work (e.g. map analysis).
auto constexpr kMaxInitialStepDuration = std::chrono::seconds(9);
// Wall-clock duration of one game frame, used for realtime pacing in step()
// (multiplied by combineFrames_). Presumably one frame at the game's
// "Fastest" speed setting, as the name suggests.
auto constexpr kStepDurationAtFastest = std::chrono::milliseconds(42);
} // namespace

/// Creates a player that drives the game through the given TorchCraft client.
///
/// The client must already hold an initialized game state (positive map
/// size). A State is built from a copy of that state, and the stepping thread
/// running doStepThread() is started.
///
/// @throws std::runtime_error if client is null or its state is uninitialized.
Player::Player(std::shared_ptr<tc::Client> client)
    : client_(std::move(client)), top_(nullptr) {
  if (client_ == nullptr) {
    LOG(ERROR) << "No TorchCraft client given";
    throw std::runtime_error("Null TorchCraft client");
  }

  // Make sure we start from an initialized client so that we have a valid
  // TorchCraft state from the start
  if (client_->state()->map_size[0] <= 0) {
    LOG(ERROR) << "TorchCraft state has not been initialized yet";
    throw std::runtime_error("Uninitialized TorchCraft state");
  }

  // Keep the State owned by a unique_ptr until nothing in the constructor
  // can fail anymore. If starting the thread throws, ~Player() does not run,
  // and a raw allocation would leak.
  auto state = std::make_unique<State>(*client_->state());
  state->setCollectTimers(collectTimers_);
  state->board()->setCollectTimers(collectTimers_);
  state_ = state.get();

  // Start stepping thread
  stepThread_ = std::make_unique<std::thread>(&Player::doStepThread, this);

  // Construction succeeded: state_ now owns the State (freed in ~Player()).
  state.release();
}

/// Stops the stepping thread and frees the State owned by this player.
Player::~Player() {
  // Terminate stepping thread: request exit, then re-send the start
  // notification every 10ms until the thread sets stepDone_. Retrying covers
  // notifications that arrive while the thread is not waiting on
  // stepStartCondition_.
  // NOTE(review): stepDone_ is also set when a regular step finishes, so this
  // loop can exit before the thread has seen exitStepThread_. join() then
  // depends on the thread noticing the flag without a further notification;
  // confirm doStepThread() handles this.
  std::unique_lock<std::mutex> stepLock(stepMutex_);
  stepDone_ = false;
  exitStepThread_ = true;
  while (!stepDone_) {
    stepStartCondition_.notify_one();
    stepDoneCondition_.wait_for(stepLock, std::chrono::milliseconds(10));
  }
  stepThread_->join();

  // state_ is created by the constructor and owned by the Player.
  delete state_;
}

std::shared_ptr<Module> Player::getTopModule() const {
  // The top module is the first module passed to addModule(); it stays
  // nullptr until then.
  std::shared_ptr<Module> topModule = top_;
  return topModule;
}

/// Appends a module to the list of modules stepped by this player.
///
/// The first module added becomes the top module. Module names must be
/// unique: a module whose name is already taken is skipped with an error
/// log. The module is given a back-pointer to this player via setPlayer().
///
/// @throws std::runtime_error if module is null.
void Player::addModule(std::shared_ptr<Module> module) {
  if (module == nullptr) {
    throw std::runtime_error("Attempting to add null module to player");
  }

  // Capture and take the shared_ptrs by reference. Copying them would cost
  // an atomic refcount increment/decrement per existing module.
  auto const hasSameName = [&module](std::shared_ptr<Module> const& target) {
    return module->name() == target->name();
  };
  if (std::any_of(modules_.begin(), modules_.end(), hasSameName)) {
    LOG(ERROR) << "Module named " << module->name()
               << " already added. Hence skipping";
    return;
  }

  modules_.push_back(module);
  if (!top_) {
    top_ = module;
    VLOG(1) << "Added module '" << module->name() << "' as top module";
  } else {
    VLOG(1) << "Added module '" << module->name() << "'";
  }
  module->setPlayer(this);
}

void Player::addModules(std::vector<std::shared_ptr<Module>> const& modules) {
  for (auto& m : modules) {
    addModule(m);
  }
}

void Player::setFrameskip(int n) {
  // The frameskip is only sent to TorchCraft in init(), so it can only be
  // changed before that.
  if (!initialized_) {
    frameskip_ = n;
    return;
  }
  throw std::runtime_error("Set frameskip before calling init()");
}

void Player::setCombineFrames(int n) {
  // Like the frameskip, this is only sent to TorchCraft in init(). It is
  // also used by step() for realtime pacing.
  if (!initialized_) {
    combineFrames_ = n;
    return;
  }
  throw std::runtime_error("Set combineFrames before calling init()");
}

void Player::setWarnIfSlow(bool warn) {
  warnIfSlow_ = warn;
}

void Player::setNonBlocking(bool nonBlocking) {
  nonBlocking_ = nonBlocking;
}

void Player::setRealtimeFactor(float factor) {
  realtimeFactor_ = factor;
}

void Player::setCollectStats(bool collect) {
  collectStats_ = collect;
}

void Player::setCollectTimers(bool collect) {
  collectTimers_ = collect;
}

void Player::stepModules() {
  // Call step on all modules
  for (auto& module : modules_) {
    stepModule(module);
  }
  steps_++;
}

void Player::stepModule(std::shared_ptr<Module> module) {
  // Sample the clock only when per-module timings are being collected.
  std::chrono::time_point<hires_clock> stepBegin;
  if (collectTimers_) {
    stepBegin = hires_clock::now();
  }

  module->step(state_);

  if (!collectTimers_) {
    return;
  }
  auto const spent = hires_clock::now() - stepBegin;
  moduleTimeSpent_[module] = spent;
  moduleTimeSpentAgg_[module] += spent;
}

// Updates the structure that stores statistics on what frame
// a command to a unit begins and when it ends.
//
// commandStats_ holds, per unit, (command, (startFrame, endFrame)) records.
// A record whose end frame is still -1 is open. commandStatsHelper_ points at
// the records opened by the previous call so that they can be closed here.
void Player::updateCommandStats(ClientCommands commands) {
  const FrameNum curFrame = state_->currentFrame();
  const FrameNum invalidFrame = -1;

  auto& allUnits = state_->unitsInfo();

  // 1. Close the records of commands that have finished.
  // 2. Open records for newly issued commands.
  // The order of operations is important: if the new commands were added
  // first, they would immediately go through the "has it finished?" test
  // and could get a wrong end frame.

  // Only the records tracked in commandStatsHelper_ need checking. If the
  // unit is no longer executing the command, the command must have finished,
  // so the current frame is recorded as its end frame.
  for (const auto& tracked : commandStatsHelper_) {
    const auto unitId = tracked.first;
    auto unit = allUnits.getUnit(unitId);
    if (unit == nullptr) {
      LOG(WARNING) << "Got null ptr when querying unit";
      continue;
    }
    const auto& tcUnit = unit->unit;
    auto& records = commandStats_[unitId];

    for (auto it : tracked.second) {
      if (it < records.begin() || it >= records.end()) {
        LOG(WARNING) << "Iterator out of range";
        continue;
      }
      if (!utils::isExecutingCommand(tcUnit, it->first)) {
        it->second.second = curFrame;
      }
    }
  }

  // Done with updates
  commandStatsHelper_.clear();

  // Iterate through all the commands to be posted. If the unit is already
  // executing the command, it is skipped. Otherwise a record is opened with
  // the current frame as the start frame and -1 as the end frame.
  //
  // New records are remembered as (unit id, index) and turned into iterators
  // only after all insertions are done. push_back() may reallocate a unit's
  // record vector, which would invalidate iterators taken for earlier
  // commands to the same unit in this call.
  std::vector<std::pair<int, size_t>> opened;
  for (const auto& command : commands) {
    if (command.code == tc::BW::Command::CommandUnit &&
        command.args.size() >= 2) {
      const int unitId = command.args[0];
      const auto cmd =
          tc::BW::UnitCommandType::_from_integral_nothrow(command.args[1]);
      if (!cmd) {
        LOG(INFO) << "Unknown command, not collecting stats: "
                  << command.args[1];
        continue;
      }

      auto unit = allUnits.getUnit(unitId);
      if (unit == nullptr) {
        LOG(WARNING) << "Got null ptr when querying unit";
        continue;
      }

      // A new element needs to be added to our record
      if (!utils::isExecutingCommand(unit->unit, *cmd)) {
        const commandStartEndFrame elem =
            std::make_pair(*cmd, std::make_pair(curFrame, invalidFrame));
        auto& records = commandStats_[unitId];
        opened.emplace_back(unitId, records.size());
        records.push_back(elem);
      }
    }
  }

  // Track where the new elements were added. All insertions are done, so
  // these iterators stay valid until commandStats_ is modified again.
  for (const auto& pos : opened) {
    commandStatsHelper_[pos.first].push_back(
        commandStats_[pos.first].begin() + pos.second);
  }
}

void Player::step() {
  if (state_->gameEnded()) {
    // Return here if the game is over. Otherwise, client_->receive() will
    // just wait and time out eventually.
    LOG(INFO) << "Game did end already";
    return;
  }

  // It's possible that the game ended after a receive() in done in the
  // while loop below. If this is the case, perform some cleanup only.
  if (client_->state()->game_ended) {
    logFailedCommands();
    preStep();
    if (state_->gameEnded()) {
      VLOG(1) << "Game has ended, not stepping through modules again";
    } else {
      LOG(WARNING) << "Game has ended, but State does not think so?";
    }
    for (auto& module : modules_) {
      module->onGameEnd(state_);
    }
    return;
  }

  std::vector<std::string> updates;
  if (!client_->receive(updates)) {
    throw std::runtime_error(
        std::string("Receive failure: ") + client_->error());
  }
  unitDeaths_.insert(
      unitDeaths_.end(),
      client_->state()->deaths.begin(),
      client_->state()->deaths.end());
  setLoggingFrame(client_->state()->frame_from_bwapi);

  // Similarly, the game could end now. If we don't check this here we'd
  // have to check this below while doStep() is possibly still running.
  if (client_->state()->game_ended) {
    logFailedCommands();
    preStep();
    if (state_->gameEnded()) {
      VLOG(1) << "Game has ended, not stepping through modules again";
    } else {
      LOG(WARNING) << "Game has ended, but State does not think so?";
    }
    for (auto& module : modules_) {
      module->onGameEnd(state_);
    }
    return;
  }

  auto maxDuration = kMaxStepDuration;

  // Start worker thread
  auto start = hires_clock::now();
  std::unique_lock<std::mutex> stepLock(stepMutex_);
  startStep_ = true;
  stepDone_ = false;
  stepStartCondition_.notify_one();

  while (true) {
    if (nonBlocking_) {
      if (client_->state()->frame_from_bwapi == 0) {
        // Provide a time window for bot stepping. We'll do some more work
        // during the first frame (e.g. BWEM map analysis) so let's use a higher
        // timeout in this case.
        maxDuration = kMaxInitialStepDuration;
      }

      stepDoneCondition_.wait_for(
          stepLock, maxDuration, [&] { return stepDone_; });
    } else {
      stepDoneCondition_.wait(stepLock, [&] { return stepDone_; });
    }

    std::lock_guard<std::mutex> lock(stepClientMutex_);
    if (stepDone_) {
      // We got some commands in time -- send them and return.
      if (!client_->send(commands_)) {
        throw std::runtime_error(
            std::string("Send failure: ") + client_->error());
      }
      break;
    }

    // No commands ready for sending yet, so don't send anything for this
    // frame and keep the game going.
    auto frame = client_->state()->frame_from_bwapi;
    LOG(WARNING) << "Timeout for frame " << frame
                 << ", trying again next frame";
    if (!client_->send(ClientCommands())) {
      throw std::runtime_error(
          std::string("Send failure: ") + client_->error());
    }

    // Receive new updates so that the game won't be blocked. We'll be
    // missing this frame and only act on the one aftwards, though.
    std::vector<std::string> updates;
    if (!client_->receive(updates)) {
      throw std::runtime_error(
          std::string("Receive failure: ") + client_->error());
    }
    unitDeaths_.insert(
        unitDeaths_.end(),
        client_->state()->deaths.begin(),
        client_->state()->deaths.end());

    if (client_->state()->game_ended) {
      break;
    }
  }

  auto duration = hires_clock::now() - start;
  if (warnIfSlow_ && duration > maxDuration) {
    auto taskTimeStats = state_->board()->getTaskTimeStats();
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(duration);
    LOG(WARNING) << "Maximum duration exceeded; step took " << ms.count()
                 << "ms";
    LOG(WARNING) << "Timings for this step:";
    ms = std::chrono::duration_cast<std::chrono::milliseconds>(
        stateUpdateTimeSpent_);
    LOG(WARNING) << "  State::update(): " << ms.count() << "ms";
    auto stateUpdateTimes = state_->getStateUpdateTimes();
    for (auto stTime : stateUpdateTimes) {
      LOG(WARNING) << "    " << stTime.first << ": " << stTime.second.count()
                   << "ms";
    }
    for (auto& module : modules_) {
      ms = std::chrono::duration_cast<std::chrono::milliseconds>(
          moduleTimeSpent_[module]);
      LOG(WARNING) << "  " << module->name() << ": " << ms.count() << "ms";
    }
    // Show Areainfo cache hit rate after done with all times
    auto cacheStats = state_->areaInfo().getCacheStats();
    if (std::get<1>(cacheStats) > 0) {
      LOG(WARNING) << "  AreaInfo cache hit rate: "
                   << (float)std::get<0>(cacheStats) /
              (float)std::get<1>(cacheStats);
      LOG(WARNING) << "  AreaInfo cache size: " << std::get<2>(cacheStats);
    }
    for (auto& stat : taskTimeStats) {
      LOG(WARNING) << "      Task: " << std::get<0>(stat) << " from "
                   << std::get<1>(stat) << ": " << std::get<2>(stat).count()
                   << "ms";
    }
  }

  if (realtimeFactor_ > 0) {
    auto target = combineFrames_ * kStepDurationAtFastest;
    auto timeSinceLastStep = hires_clock::now() - lastStep_;
    auto left = (target - timeSinceLastStep) / realtimeFactor_;
    if (left.count() > 0) {
      std::this_thread::sleep_for(left);
    }
  }

  size_t const logFreq = 100;
  if ((steps_ % logFreq == 0) && collectTimers_) {
    VLOG(1) << "Aggregate timings for previous " << logFreq << " steps:";
    auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(
        stateUpdateTimeSpentAgg_);
    VLOG(1) << "  State::update(): " << ms.count() << "ms";
    for (auto& module : modules_) {
      ms = std::chrono::duration_cast<std::chrono::milliseconds>(
          moduleTimeSpentAgg_[module]);
      VLOG(1) << "  " << module->name() << ": " << ms.count() << "ms";
    }

    moduleTimeSpentAgg_.clear();
    stateUpdateTimeSpentAgg_ = Duration();
  }

  lastStep_ = hires_clock::now();
  unsetLoggingFrame();
}

void Player::init() {
  steps_ = 0;

  // Initial setup
  ClientCommands comms;
  comms.emplace_back(tc::BW::Command::SetSpeed, 0);
  comms.emplace_back(tc::BW::Command::SetGui, 1);
  comms.emplace_back(tc::BW::Command::SetCombineFrames, combineFrames_);
  comms.emplace_back(tc::BW::Command::SetFrameskip, frameskip_);
  if (!client_->send(comms)) {
    throw std::runtime_error(std::string("Send failure: ") + client_->error());
  }
  lastStep_ = hires_clock::now();
  initialized_ = true;
}

void Player::run() {
  // Play a whole game: configure it, then step until State reports its end.
  // step() is always called at least once.
  init();
  for (;;) {
    step();
    if (state_->gameEnded()) {
      break;
    }
  }
}

void Player::preStep() {
  std::unique_lock<std::mutex> lock(stepClientMutex_);
  tc::State stateCopy = *client_->state();
  // Provide buffered list of unit deaths
  stateCopy.deaths = unitDeaths_;
  unitDeaths_.clear();
  lock.unlock();

  std::chrono::time_point<hires_clock> start;
  if (collectTimers_) {
    start = hires_clock::now();
  }
  state_->update(std::move(stateCopy));
  if (collectTimers_) {
    auto duration = hires_clock::now() - start;
    stateUpdateTimeSpent_ = duration;
    stateUpdateTimeSpentAgg_ += duration;
  }
}

void Player::postStep() {
  if (collectStats_) {
    updateCommandStats(state_->board()->commands());

    // Check and log errors if any
    state_->board()->checkPostStep();
  }

  // Visualize our base so that we immediately know where we are on the map
  if (VLOG_IS_ON(1)) {
    if (Unit* myBase = state_->areaInfo().myBase()) {
      utils::drawCircle(state_, myBase, 50, tc::BW::Color::Blue);
      utils::drawCircle(state_, myBase, 52, tc::BW::Color::Blue);
    }
  }

  VLOG(2) << state_->board()->upcs().size() << " UPC tuples in blackboard";
}

/// Run doStep() when instructed to do so via stepStartCondition_.
/// Meant to be run in a separate thread
void Player::doStepThread() {
  std::unique_lock<std::mutex> lock(stepMutex_);
  while (true) {
    stepStartCondition_.wait(lock);
    if (exitStepThread_) {
      break;
    }
    if (!startStep_) {
      continue;
    }
    startStep_ = false;
    lock.unlock();

    // Do work
    ClientCommands commands;
    try {
      commands = doStep();
    } catch (std::exception& e) {
      LOG(ERROR) << "Caught exception: " << e.what();
    }

    lock.lock();
    commands_ = commands;
    stepDone_ = true;
    stepDoneCondition_.notify_one();
  }

  stepDone_ = true;
  stepDoneCondition_.notify_one();
}

/// Performs the per-step work and returns the commands to send.
/// May run asynchronously (on the stepping thread) while step() waits.
/// Returns an empty command list once the game has ended.
Player::ClientCommands Player::doStep() {
  {
    // The TorchCraft client is also used by step(); guard access to it.
    std::lock_guard<std::mutex> clientGuard(stepClientMutex_);
    setLoggingFrame(client_->state()->frame_from_bwapi);
    logFailedCommands();
  }

  preStep();
  if (!state_->gameEnded()) {
    stepModules();
    postStep();
    return state_->board()->commands();
  }

  VLOG(1) << "Game has ended, not stepping through modules again";
  for (auto& module : modules_) {
    module->onGameEnd(state_);
  }
  return ClientCommands();
}

void Player::logFailedCommands() {
  auto lastCommands = client_->lastCommands();
  auto status = client_->lastCommandsStatus();
  for (size_t i = 0; i < status.size(); i++) {
    if (status[i] != 0) {
      auto& comm = lastCommands[i];
      int st = int(status[i]);
      if (st & 0x40) {
        // BWAPI error
        LOG(INFO) << "Command failed: "
                  << UPCToCommandModule::commandString(state_, comm) << " ("
                  << "code " << st << ", BWAPI code " << (st & ~0x40) << ")";
        // For unit commands with BWAPI "busy" error code, show some more
        // information to make debugging easier
        if (VLOG_IS_ON(1) && comm.code == +tc::BW::Command::CommandUnit &&
            (st & ~0x40) == 3) {
          std::ostringstream oss;
          auto unit = state_->unitsInfo().getUnit(comm.args[0]);
          for (auto order : unit->unit.orders) {
            oss << "(frame=" << order.first_frame << ", type=" << order.type
                << ", targetId=" << order.targetId
                << ", targetX=" << order.targetX
                << ", targetY=" << order.targetY << ") ";
          }
          VLOG(1) << "Current orders for " << utils::unitString(unit) << ": "
                  << oss.str();
          VLOG(1) << "Current flags for " << utils::unitString(unit) << ": "
                  << unit->unit.flags;
        }
      } else {
        LOG(INFO) << "Command failed: "
                  << UPCToCommandModule::commandString(state_, comm) << " ("
                  << st << ")";
      }
    }
  }
}

} // namespace fairrsh
