/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "state.h"

#include "upcfilter.h"
#include "utils.h"

#include <BWAPI.h>
#include <bwem/bwem.h>
#include <glog/logging.h>
#include <tcbwapi/tcgame.h>
#include <tcbwapi/tcunit.h>

#include <utility>

namespace fairrsh {

// Registers State with RTTR under the class name "State". The class carries a
// "type" metadata entry holding its rttr::type, and each State accessor listed
// below is exposed as a read-only property.
RTTR_REGISTRATION {
  rttr::registration::class_<State>("State")(
      metadata("type", rttr::type::get<State>()))
      .property_readonly("mapWidth", &State::mapWidth)
      .property_readonly("mapHeight", &State::mapHeight)
      .property_readonly("currentFrame", &State::currentFrame)
      .property_readonly("gameEnded", &State::gameEnded)
      .property_readonly("won", &State::won)
      .property_readonly("lost", &State::lost)
      .property_readonly("unitsInfo", &State::unitsInfo)
      .property_readonly("tilesInfo", &State::tilesInfo)
      .property_readonly("areaInfo", &State::areaInfo);
}

// Constructs the state wrapper from an initial TorchCraft state.
//
// The TorchCraft state is moved into tcstate_, a Blackboard is allocated on
// the heap (released in the destructor), and the map dimensions are read from
// tcstate_.map_size. The body then caches the frame number and player ids,
// installs the AssignedUnitsFilter on the blackboard and runs the one-time
// initialisation helpers before finally initialising the blackboard itself.
State::State(tc::State tcstate)
    : tcstate_(std::move(tcstate)),
      board_(new Blackboard(this)),
      mapWidth_(tcstate_.map_size[0]),
      mapHeight_(tcstate_.map_size[1]) {
  currentFrame_ = tcstate_.frame_from_bwapi;
  playerId_ = tcstate_.player_id;
  neutralId_ = tcstate_.neutral_id;

  board_->addUPCFilter(std::make_shared<AssignedUnitsFilter>());

  // Populate the tech/upgrade lookup tables with their default entries.
  initTechnologyStatus();
  initUpgradeStatus();
  // Both helpers below post their findings to the blackboard; findEnemyRace()
  // reads playerId_, which was set above.
  findEnemyName();
  findEnemyRace();
  board_->init();
}

// Releases the Blackboard that the constructor allocated with `new`.
State::~State() {
  delete board_;
}

// Returns a copy of resources_, i.e. the resource figures as maintained by
// update() and the add/removePlannedResources() helpers (planned usage is
// already accounted for in this value).
tc::Resources State::resources() const {
  return resources_;
}

// Books `res` as planned resource usage.
//
// The amounts are accumulated in plannedResourceUsage_ and applied to the
// current resources_ snapshot right away: ore and gas are deducted, while
// used_psi is added (planned supply counts as already used).
void State::addPlannedResources(tc::Resources res) {
  // Ore
  plannedResourceUsage_.ore += res.ore;
  resources_.ore -= res.ore;

  // Gas
  plannedResourceUsage_.gas += res.gas;
  resources_.gas -= res.gas;

  // Supply
  plannedResourceUsage_.used_psi += res.used_psi;
  resources_.used_psi += res.used_psi;
}

// Reverts a previous booking of `res` as planned resource usage.
//
// This is the exact inverse of addPlannedResources(): the amounts are removed
// from plannedResourceUsage_, ore and gas are credited back to resources_ and
// used_psi is reduced again.
void State::removePlannedResources(tc::Resources res) {
  // Ore
  plannedResourceUsage_.ore -= res.ore;
  resources_.ore += res.ore;

  // Gas
  plannedResourceUsage_.gas -= res.gas;
  resources_.gas += res.gas;

  // Supply
  plannedResourceUsage_.used_psi -= res.used_psi;
  resources_.used_psi -= res.used_psi;
}

// Looks up the research status recorded for `tech`.
//
// Returns the value stored in tech2StatusMap_ for tech->tech. A null pointer
// or a tech id that is not present in the map is logged as an error and
// reported as "not researched".
bool State::hasResearched(const BuildType* tech) const {
  if (tech == nullptr) {
    LOG(ERROR) << "Null pointer to tech type";
    return false;
  }

  const auto entry = tech2StatusMap_.find(tech->tech);
  if (entry == tech2StatusMap_.end()) {
    LOG(ERROR) << "Tech status requested for an unknown tech " << tech->tech;
    return false;
  }
  return entry->second;
}

// Looks up the level recorded for `upgrade`.
//
// Returns the value stored in upgrade2LevelMap_ for upgrade->upgrade. A null
// pointer or an upgrade id that is not present in the map is logged as an
// error and reported as level 0.
UpgradeLevel State::getUpgradeLevel(const BuildType* upgrade) const {
  if (!upgrade) {
    LOG(ERROR) << "Null pointer to upgrade type";
    // Return a level, not a boolean, so both fallbacks agree.
    return 0;
  }
  auto upg_it = upgrade2LevelMap_.find(upgrade->upgrade);
  if (upg_it != upgrade2LevelMap_.end()) {
    return upg_it->second;
  }
  LOG(ERROR) << "Upgrade level requested for an unknown upgrade "
             << upgrade->upgrade;
  return 0;
}

// Switches on or off the timing measurements that update() appends to
// stateUpdateTimeSpent_.
void State::setCollectTimers(bool collect) {
  collectTimers_ = collect;
}

// Replaces the stored TorchCraft state with `tcstate` and refreshes everything
// derived from it.
//
// In order: (re)initialises the BWEM map on first use, rebuilds the id -> unit
// table, updates tiles/units/area information, determines start locations
// once, posts the enemy race as soon as an enemy unit is seen, refreshes tech
// and upgrade status, updates trackers and the blackboard, and finally
// recomputes resources_ from the frame data minus planned usage.
//
// When collectTimers_ is set, the time spent in UnitsInfo::update(),
// AreaInfo::update() and Board::update() is appended to stateUpdateTimeSpent_.
void State::update(tc::State tcstate) {
  tcstate_ = std::move(tcstate);

  // Nuke before starting to collect timings
  stateUpdateTimeSpent_.clear();

  // Lazily run the BWEM map analysis the first time we get here.
  if (map_ == nullptr) {
    map_ = BWEM::Map::Make();
    VLOG(1) << "Running BWEM analysis...";
    tcbGame_ = std::make_unique<tcbwapi::TCGame>();
    tcbGame_->setState(&tcstate_);
    map_->Initialize(tcbGame_.get());
    map_->EnableAutomaticPathAnalysis();
    if (!map_->FindBasesForStartingLocations()) {
      LOG(INFO) << "Failed to find base locations with BWEM";
    }
    VLOG(1) << "Analysis done, found " << map_->Areas().size() << " areas and "
            << map_->ChokePointCount() << " choke points";
  }

  currentFrame_ = tcstate_.frame_from_bwapi;

  // Update id -> unit mapping
  units_.clear();
  for (auto& e : tcstate_.frame->units) {
    for (auto& unit : e.second) {
      // Ignore "unknown" unit types that will just lead to confusion regarding
      // state checks and module functionality.
      auto ut = tc::BW::UnitType::_from_integral_nothrow(unit.type);
      if (ut) {
        units_[unit.id] = &unit;
      }
    }
  }

  tilesInfo_.preUnitsUpdate();
  std::chrono::time_point<hires_clock> start;

  if (collectTimers_) {
    start = hires_clock::now();
  }
  unitsInfo_.update();
  if (collectTimers_) {
    auto timeTaken = hires_clock::now() - start;
    stateUpdateTimeSpent_.push_back(
        std::make_pair("UnitsInfo::update()", timeTaken));
  }
  tilesInfo_.postUnitsUpdate();

  // Determine our location
  if (!checkedForLocations_) {
    findMyLocation();
    findEnemyLocation();
    checkedForLocations_ = true;
  }
  // Post the race of the first enemy unit we ever see, exactly once.
  if (!sawFirstEnemyUnit_) {
    for (auto eunit : unitsInfo().enemyUnits()) {
      board_->post(Blackboard::kEnemyRaceKey, eunit->type->race);
      sawFirstEnemyUnit_ = true;
      break;
    }
  }

  updateBWEM();
  // Only query the clock when timings are actually collected, like the other
  // timed sections in this function.
  if (collectTimers_) {
    start = hires_clock::now();
  }
  areaInfo_.update();
  if (collectTimers_) {
    auto timeTaken = hires_clock::now() - start;
    stateUpdateTimeSpent_.push_back(
        std::make_pair("AreaInfo::update()", timeTaken));
  }

  updateTechnologyStatus();
  updateUpgradeStatus();
  updateTrackers();

  if (collectTimers_) {
    start = hires_clock::now();
  }
  board_->update();
  if (collectTimers_) {
    auto timeTaken = hires_clock::now() - start;
    stateUpdateTimeSpent_.push_back(
        std::make_pair("Board::update()", timeTaken));
  }

  // Start from the raw figures reported for our player and then apply the
  // resources that are already planned to be spent.
  auto pid = tcstate_.player_id;
  resources_ = tcstate_.frame->resources[pid];
  VLOG(4) << "Resources from TorchCraft:   "
          << utils::resourcesString(resources_);
  resources_.ore -= plannedResourceUsage_.ore;
  resources_.gas -= plannedResourceUsage_.gas;
  resources_.used_psi += plannedResourceUsage_.used_psi;
  VLOG(4) << "Resources including planned: "
          << utils::resourcesString(resources_);
}

// Returns the game_ended flag of the currently stored TorchCraft state.
bool State::gameEnded() const {
  return tcstate_.game_ended;
}

// Reports whether the game counts as won.
//
// This requires TorchCraft to flag the game as both ended and won, and
// additionally that we still own more than one building.
bool State::won() const {
  if (!tcstate_.game_ended || !tcstate_.game_won) {
    return false;
  }
  // unitsInfo() is reached through a non-const pointer here, hence the cast.
  auto* self = const_cast<State*>(this);
  return self->unitsInfo().myBuildings().size() > 1;
}

// Reports whether the game has ended without TorchCraft flagging it as won.
bool State::lost() const {
  if (!tcstate_.game_ended) {
    return false;
  }
  return !tcstate_.game_won;
}

// Seeds tech2StatusMap_ with a `false` entry for every tech listed in
// buildtypes::allTechTypes. Null entries and duplicate tech ids are logged as
// errors and skipped.
void State::initTechnologyStatus() {
  const auto& allTechs = buildtypes::allTechTypes;
  for (auto* techType : allTechs) {
    if (!techType) {
      LOG(ERROR) << "Null pointer encountered when querying all techs";
      continue;
    }

    const bool alreadyRegistered =
        tech2StatusMap_.find(techType->tech) != tech2StatusMap_.end();
    if (alreadyRegistered) {
      LOG(ERROR) << "Multiple techs with the same ID encountered "
                 << "when querying all techs (" << techType->tech << ")";
      continue;
    }

    tech2StatusMap_[techType->tech] = false;
  }
}

// Seeds upgrade2LevelMap_ with level 0 for every upgrade listed in
// buildtypes::allUpgradeTypes. Null entries are logged as errors and skipped.
void State::initUpgradeStatus() {
  const auto& allUpgrades = buildtypes::allUpgradeTypes;
  for (const auto* upgradeType : allUpgrades) {
    if (!upgradeType) {
      LOG(ERROR) << "Null pointer encountered when querying all upgrades";
      continue;
    }
    // No uniqueness check on the ID here: the individual levels of one
    // upgrade are separate build types that share the same ID, so seeing an
    // ID more than once is expected.
    upgrade2LevelMap_[upgradeType->upgrade] = 0;
  }
}

// Forwards destroyed minerals and special buildings to the BWEM map.
//
// Does nothing until both the BWEM map and the TC game wrapper exist. Units
// the wrapper does not know about are logged and skipped; exceptions thrown by
// BWEM while removing a unit are logged and otherwise ignored.
void State::updateBWEM() {
  if (map_ == nullptr || tcbGame_ == nullptr) {
    return;
  }

  for (auto unit : unitsInfo_.getDestroyUnits()) {
    // Only minerals and special buildings are relevant for BWEM.
    const bool isMineral = unit->type->isMinerals;
    if (!isMineral && !unit->type->isSpecialBuilding) {
      continue;
    }

    BWAPI::Unit bwu = tcbGame_->getUnit(unit->id);
    if (bwu == nullptr) {
      LOG(WARNING) << "Destroyed unit " << utils::unitString(unit)
                   << " is unknown to TC game wrapper";
      continue;
    }

    try {
      if (isMineral) {
        map_->OnMineralDestroyed(bwu);
      } else {
        map_->OnStaticBuildingDestroyed(bwu);
      }
    } catch (std::exception& e) {
      LOG(WARNING) << (isMineral
                           ? "Exception removing mineral from BWEM map: "
                           : "Exception removing special building from BWEM map: ")
                   << e.what();
    }
  }
}

void State::updateTechnologyStatus() {
  for (auto& tech2status : tech2StatusMap_) {
    auto tt = tc::BW::TechType::_from_integral_nothrow(tech2status.first);
    if (!tt) {
      continue;
    }
    if (!tech2status.second && tcstate_.hasResearched(*tt)) {
      tech2status.second = true;
    }
  }
}

void State::updateUpgradeStatus() {
  for (auto& upgrade2level : upgrade2LevelMap_) {
    auto ut = tc::BW::UpgradeType::_from_integral_nothrow(upgrade2level.first);
    if (!ut) {
      continue;
    }
    upgrade2level.second = tcstate_.getUpgradeLevel(*ut);
  }
}

// Updates every registered tracker and drops the ones that have finished.
//
// A tracker is removed from trackers_ once its status is Timeout, Success,
// Failure or Cancelled; trackers in any other status are kept.
void State::updateTrackers() {
  auto it = trackers_.begin();
  while (it != trackers_.end()) {
    auto tracker = *it;
    tracker->update(this);

    switch (tracker->status()) {
      case TrackerStatus::Timeout:
        VLOG(1) << "Timeout for tracker";
        break;
      case TrackerStatus::Success:
        VLOG(1) << "Tracker reported success";
        break;
      case TrackerStatus::Failure:
        VLOG(1) << "Tracker reported failure";
        break;
      case TrackerStatus::Cancelled:
        VLOG(1) << "Tracker was cancelled";
        break;
      default:
        // Keep tracker, advance in loop
        ++it;
        continue;
    }

    // Continue from the iterator returned by erase(): unlike `erase(it++)`
    // this does not rely on the remaining iterators staying valid, so it is
    // correct for any standard container.
    it = trackers_.erase(it);
  }
}

// Determines our start location and posts it under Blackboard::kMyLocationKey.
//
// The start location closest to our first resource depot is chosen. If no
// depot is found, or no usable start location results, Position(0, 0) or the
// unusable position is posted anyway and the situation is logged.
void State::findMyLocation() {
  // Pick the first resource depot among our units.
  Unit* depot = nullptr;
  for (auto candidate : unitsInfo_.myUnits()) {
    if (candidate->type->isResourceDepot) {
      depot = candidate;
      break;
    }
  }

  Position pos(0, 0);
  if (!depot) {
    LOG(INFO) << "No resource depot found, can't determine start location";
    board_->post(Blackboard::kMyLocationKey, pos);
    return;
  }

  // Select the start location with the smallest distance to the depot.
  float closest = std::numeric_limits<float>::max();
  for (auto& loc : tcstate_.start_locations) {
    auto dist = utils::distance(loc.x, loc.y, depot->x, depot->y);
    if (dist < closest) {
      closest = dist;
      pos = Position(loc.x, loc.y);
    }
  }

  if (pos.x <= 0 || pos.y <= 0) {
    LOG(INFO) << "Start location not available";
  }

  board_->post(Blackboard::kMyLocationKey, pos);
}

// Deduces the enemy start location on two-player maps and posts it under
// Blackboard::kEnemyLocationKey.
//
// With exactly two start locations, the one that is not ours (as stored under
// Blackboard::kMyLocationKey) must be the enemy's. Maps with any other number
// of start locations are left alone.
void State::findEnemyLocation() {
  // On two-player maps we know where the enemy is located
  if (tcstate_.start_locations.size() != 2) {
    return;
  }

  auto myLoc = board_->get<Position>(Blackboard::kMyLocationKey);
  for (auto& loc : tcstate_.start_locations) {
    // A location differs from ours if EITHER coordinate differs. Requiring
    // both to differ would miss an enemy start that shares our x or y
    // coordinate (e.g. horizontally or vertically aligned start positions).
    if (loc.x != myLoc.x || loc.y != myLoc.y) {
      board_->post(Blackboard::kEnemyLocationKey, Position(loc.x, loc.y));
    }
  }
}

// Posts the race of every player that is neither us nor the neutral player
// under Blackboard::kEnemyRaceKey.
//
// Zerg, Terran, Protoss and Random are posted as reported by TorchCraft; any
// other value is posted as the integral value of tc::BW::Race::Unknown.
void State::findEnemyRace() {
  for (size_t pid = 0; pid < tcstate_.player_races.size(); ++pid) {
    const bool isSelfOrNeutral =
        pid == playerId_ || pid == tcstate_.neutral_id;
    if (isSelfOrNeutral) {
      continue;
    }

    auto pr = tcstate_.player_races[pid];
    const bool isKnownRace = pr == tc::BW::Race::Zerg ||
        pr == tc::BW::Race::Terran || pr == tc::BW::Race::Protoss ||
        pr == tc::BW::Race::Random;
    if (!isKnownRace) {
      board_->post(Blackboard::kEnemyRaceKey,
          (+tc::BW::Race::Unknown)._to_integral());
      continue;
    }

    // The key is expected to be absent at this point. If it is already
    // present, either this function ran more than once (bug) or there is
    // strictly more than one enemy, in which case this logic has to change.
    assert(!board_->hasKey(Blackboard::kEnemyRaceKey));
    board_->post(Blackboard::kEnemyRaceKey, pr);
  }
}

// Posts the enemy's player name under Blackboard::kEnemyNameKey.
//
// Every player other than us and the neutral player is considered; the last
// one whose name differs from "NONAME" wins. If there is none, "NONAME" is
// posted.
void State::findEnemyName() {
  std::string enemyName = "NONAME";
  for (size_t pid = 0; pid < tcstate_.player_names.size(); ++pid) {
    const bool isSelfOrNeutral =
        pid == playerId_ || pid == tcstate_.neutral_id;
    if (isSelfOrNeutral) {
      continue;
    }
    auto candidate = tcstate_.player_names[pid];
    if (candidate != "NONAME") {
      enemyName = candidate;
    }
  }
  board_->post(Blackboard::kEnemyNameKey, enemyName);
}

// Returns the race the client's state reports for `playerId`, or
// tc::BW::Race::Unknown if the reported value is not a valid tc::BW::Race.
tc::BW::Race State::getRaceFromClient(PlayerId playerId) {
  auto reported = tc::BW::Race::_from_integral_nothrow(
      client_->state()->player_races[playerId]);
  if (!reported) {
    return tc::BW::Race::Unknown;
  }
  return *reported;
}

} // namespace fairrsh
