/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "areainfo.h"

#include "state.h"
#include "utils.h"

#include <bwem/map.h>

#include <tuple>

namespace fairrsh {

// Runtime reflection (RTTR) registration: exposes Area's summary fields as
// read-only properties and a few AreaInfo query methods by name.
RTTR_REGISTRATION {
  // Area: per-area data that AreaInfo fills in during initialization and its
  // per-frame update steps.
  rttr::registration::class_<Area>("Area")(
      metadata("type", rttr::type::get<Area>()))
      .property_readonly("id", &Area::id)
      .property_readonly("x", &Area::x)
      .property_readonly("y", &Area::y)
      .property_readonly("liveUnits", &Area::liveUnits)
      .property_readonly("visibleUnits", &Area::visibleUnits)
      .property_readonly("lastExplored", &Area::lastExplored)
      .property_readonly("isMyBase", &Area::isMyBase)
      .property_readonly("isEnemyBase", &Area::isEnemyBase)
      .property_readonly("isMyExpand", &Area::isMyExpand)
      .property_readonly("isEnemyExpand", &Area::isEnemyExpand)
      .property_readonly("hasMyBuildings", &Area::hasMyBuildings)
      .property_readonly("hasEnemyBuildings", &Area::hasEnemyBuildings)
      .property_readonly("myGndStrength", &Area::myGndStrength)
      .property_readonly("myAirStrength", &Area::myAirStrength)
      .property_readonly("myDetStrength", &Area::myDetStrength)
      .property_readonly("enemyGndStrength", &Area::enemyGndStrength)
      .property_readonly("enemyAirStrength", &Area::enemyAirStrength)
      .property_readonly("enemyDetStrength", &Area::enemyDetStrength);

  // AreaInfo: the reflected "getArea" and "getAreaAt" methods are bound to the
  // const tryGetArea() overloads, which report a missing area by returning
  // nullptr rather than throwing.
  rttr::registration::class_<AreaInfo>("AreaInfo")(
      metadata("type", rttr::type::get<AreaInfo>()))
      .method("areas", &AreaInfo::areas)
      .method(
          "getArea",
          rttr::select_overload<Area const*(int id) const>(
              &AreaInfo::tryGetArea))
      .method(
          "getAreaAt",
          rttr::select_overload<Area const*(Position) const>(
              &AreaInfo::tryGetArea));
}

AreaInfo::AreaInfo(State* state) : state_(state) {
  // The map may still be null at this point; update() retries the lookup for
  // as long as map_ remains null.
  map_ = state->map();
}

// Per-frame refresh of all area data. Binds to the map lazily and throws
// std::runtime_error if the map object changes after it has been bound.
void AreaInfo::update() {
  auto currentMap = state_->map();
  if (map_ == nullptr) {
    map_ = currentMap;
  } else if (map_ != currentMap) {
    throw std::runtime_error("Map data has changed in-game");
  }

  // Static per-area data is built once, on the first update that has areas.
  if (areas_.empty()) {
    initialize();
  }

  // updateUnits() goes first: the next two steps read the unit lists and
  // building flags it fills in.
  updateUnits();
  updateEnemyStartingLocations();
  updateStrengths();
  updateNeighbors();
}

// Statistics for the nearest-area cache used by position lookups:
// (cache hits, cache-eligible lookups, number of cached positions).
std::tuple<size_t, size_t, size_t> AreaInfo::getCacheStats() const {
  auto const cachedPositions = neighborAreaCache_.size();
  return std::make_tuple(cacheAccess_, totalAccess_, cachedPositions);
}

// Zeroes the hit/lookup counters reported by getCacheStats(); cached entries
// themselves are kept.
void AreaInfo::resetCacheStats() {
  totalAccess_ = 0;
  cacheAccess_ = 0;
}

// All areas, ordered by ID: the area with ID i is stored at index i - 1.
std::vector<Area> const& AreaInfo::areas() const {
  auto const& all = areas_;
  return all;
}

// Returns the area with the given 1-based ID; throws std::runtime_error if no
// such area exists.
Area& AreaInfo::getArea(int id) {
  Area* found = tryGetArea(id);
  if (found == nullptr) {
    throw std::runtime_error("Attempt to get invalid area");
  }
  return *found;
}

Area const& AreaInfo::getArea(int id) const {
  // Delegate to the non-const overload and hand the result back as const.
  AreaInfo& self = const_cast<AreaInfo&>(*this);
  return self.getArea(id);
}

// Maps a walk-tile position to one of our areas. Positions inside a BWEM area
// are resolved directly. Other positions fall back to BWEM's nearest-area
// search, whose result is memoized in neighborAreaCache_ keyed by the
// linearized position. Returns nullptr if BWEM finds no area at all (misses
// are not cached).
Area* AreaInfo::getCachedArea(Position p) {
  BWAPI::WalkPosition const wpos(p.x, p.y);

  // Fast path: the position lies inside an area, so neither the nearest-area
  // search nor the cache is needed.
  if (auto* curArea = map_->GetArea(wpos)) {
    return &getArea(curArea->Id());
  }

  // Only count lookups that can actually hit the cache, so that the ratio
  // reported by getCacheStats() is a meaningful hit rate.
  ++totalAccess_;

  // Row-major linearized walk-tile index, mirroring BWEM's indexing.
  // NOTE(review): positions outside the map would produce keys that alias
  // in-map cells; callers presumably pass in-bounds positions.
  auto const key = map_->WalkSize().x * p.y + p.x;

  // Single lookup on a hit; the original did find() followed by operator[].
  auto const cached = neighborAreaCache_.find(key);
  if (cached != neighborAreaCache_.end()) {
    ++cacheAccess_;
    return &getArea(cached->second);
  }

  auto* nearest = map_->GetNearestArea(wpos);
  if (nearest == nullptr) {
    // No BWEM area, hence no fairrsh::Area to point at.
    return nullptr;
  }
  int const nearestId = nearest->Id();
  neighborAreaCache_[key] = nearestId;
  return &getArea(nearestId);
}

// Returns the area at, or nearest to, walk-tile position p. Throws
// std::runtime_error if no area can be found at all.
Area& AreaInfo::getArea(Position p) {
  if (Area* area = getCachedArea(p)) {
    return *area;
  }
  throw std::runtime_error("No area at or near this position");
}

Area const& AreaInfo::getArea(Position p) const {
  // The lookup may populate the nearest-area cache, hence the const_cast.
  AreaInfo* mutableSelf = const_cast<AreaInfo*>(this);
  return mutableSelf->getArea(p);
}

// Returns the area with the given 1-based ID, or nullptr if it is out of range.
Area* AreaInfo::tryGetArea(int id) {
  bool const inRange = id >= 1 && static_cast<size_t>(id) <= areas_.size();
  return inRange ? &areas_[id - 1] : nullptr;
}

Area const* AreaInfo::tryGetArea(int id) const {
  Area* area = const_cast<AreaInfo*>(this)->tryGetArea(id);
  return area;
}

// Non-throwing counterpart of getArea(Position): nullptr when no area lies at
// or near p.
Area* AreaInfo::tryGetArea(Position p) {
  Area* const nearest = getCachedArea(p);
  return nearest;
}

Area const* AreaInfo::tryGetArea(Position p) const {
  // May populate the nearest-area cache, hence the const_cast.
  AreaInfo& self = const_cast<AreaInfo&>(*this);
  return self.tryGetArea(p);
}

// My main base: expand index 0, i.e. the first base registered.
Unit* AreaInfo::myBase() const {
  constexpr int kMainBase = 0;
  return myExpand(kMainBase);
}

// Area containing my main base (expand 0), or nullptr if there is none.
Area const* AreaInfo::myBaseArea() const {
  constexpr int kMainBase = 0;
  return myExpandArea(kMainBase);
}

// Position of my main base (expand 0), or (-1, -1) if there is none.
Position AreaInfo::myBasePosition() const {
  constexpr int kMainBase = 0;
  return myExpandPosition(kMainBase);
}

// Number of registered bases beyond the main base (expand 0). Returns -1 when
// no base has been registered yet.
int AreaInfo::numExpands() const {
  // Convert before subtracting: size() - 1 on an empty vector wraps around in
  // unsigned arithmetic, and narrowing that value to int is
  // implementation-defined before C++20.
  return static_cast<int>(myBases_.size()) - 1;
}

// Returns my n-th registered base (0 = main base, in registration order), or
// nullptr if n is out of range.
Unit* AreaInfo::myExpand(int n) const {
  bool const valid = n >= 0 && static_cast<size_t>(n) < myBases_.size();
  return valid ? myBases_[n] : nullptr;
}

// Area containing my n-th base, or nullptr if the base does not exist or no
// area can be found for it.
Area const* AreaInfo::myExpandArea(int n) const {
  if (Unit* base = myExpand(n)) {
    return tryGetArea(base);
  }
  return nullptr;
}

// Position of my n-th base, or (-1, -1) if it does not exist.
Position AreaInfo::myExpandPosition(int n) const {
  Unit* base = myExpand(n);
  return base != nullptr ? Position(base->x, base->y) : Position(-1, -1);
}

// Index of my registered base closest to p, or -1 if I have none. On equal
// distances the lower index wins (strict comparison).
int AreaInfo::myClosestExpand(Position const& p) const {
  int bestIdx = -1;
  float bestDist = std::numeric_limits<float>::infinity();
  int idx = 0;
  for (auto* base : myBases_) {
    auto const dist = utils::distance(p, base);
    if (dist < bestDist) {
      bestDist = dist;
      bestIdx = idx;
    }
    ++idx;
  }
  return bestIdx;
}

std::vector<Unit*> AreaInfo::myExpandResources(int n) const {
  std::vector<Unit*> resources;
  Unit* expand = myExpand(n);
  if (expand == nullptr) {
    return resources;
  }
  auto baseInfoIt = myBasesInfo_.find(expand);
  if (baseInfoIt == myBasesInfo_.end()) {
    LOG(WARNING) << "Missing entry in myBasesInfo_";
    return resources;
  }

  // We could cache per-base resources but would need to update them
  // periodically as mineral patches disappear once they've been fully mined.
  auto& baseInfo = baseInfoIt->second;
  auto* bwemArea = map_->GetArea(baseInfo.areaId);
  if (bwemArea == nullptr) {
    LOG(WARNING) << "Invalid area ID: " << baseInfo.areaId;
    return resources;
  }
  if (baseInfo.baseId < 0 ||
      (size_t)baseInfo.baseId >= bwemArea->Bases().size()) {
    LOG(WARNING) << "Invalid base ID: " << baseInfo.baseId;
    return resources;
  }
  auto& bwemBase = bwemArea->Bases()[baseInfo.baseId];

  auto& unitsInfo = state_->unitsInfo();
  for (auto* mineral : bwemBase.Minerals()) {
    Unit* unit = unitsInfo.getUnit(mineral->Unit()->getID());
    if (unit == nullptr) {
      LOG(ERROR) << "Null unit from BWEM u" << mineral->Unit()->getID();
    } else if (!unit->type->isMinerals) {
      LOG(ERROR) << "BWEM mineral is not actually a mineral: " << utils::unitString(unit);
    } else {
      resources.push_back(unit);
    }
  }

  for (auto* geyser : bwemBase.Geysers()) {
    Unit* unit = unitsInfo.getUnit(geyser->Unit()->getID());
    if (unit == nullptr) {
      LOG(ERROR) << "Null unit from BWEM u" << geyser->Unit()->getID();
    } else if (!unit->type->isGas) {
      LOG(ERROR) << "BWEM geyser is not actually gas: " << utils::unitString(unit);
    } else {
      resources.push_back(unit);
    }
  }

  return resources;
}

// Returns the mineral saturation stored for my n-th base (it is written by
// setMineralSaturationAtMyExpand()). Returns 0 and logs if the base does not
// exist or has no bookkeeping entry.
float AreaInfo::mineralSaturationAtMyExpand(int n) const {
  Unit* expand = myExpand(n);
  if (expand == nullptr) {
    LOG(INFO) << "No such expand: " << n;
    return 0.0f;
  }
  auto baseInfoIt = myBasesInfo_.find(expand);
  if (baseInfoIt == myBasesInfo_.end()) {
    LOG(WARNING) << "Missing entry in myBasesInfo_";
    return 0.0f;
  }

  return baseInfoIt->second.mineralSaturation;
}

// Stores the mineral saturation for my n-th base. Logs and does nothing if the
// base does not exist or has no bookkeeping entry.
void AreaInfo::setMineralSaturationAtMyExpand(int n, float sat) {
  // XXX This should ideally be computed in update() but we don't have access to
  // the exact data. Hence, let GathererModule do it...
  Unit* expand = myExpand(n);
  if (expand == nullptr) {
    LOG(INFO) << "No such expand: " << n;
    return;
  }
  auto baseInfoIt = myBasesInfo_.find(expand);
  if (baseInfoIt == myBasesInfo_.end()) {
    LOG(WARNING) << "Missing entry in myBasesInfo_";
    return;
  }

  baseInfoIt->second.mineralSaturation = sat;
}

// The enemy base counts as found once exactly one candidate start location
// remains.
bool AreaInfo::foundEnemyBase() const {
  auto const remaining = candidateEnemyStartLoc_.size();
  return remaining == 1;
}

// Start locations that may still hold the enemy base; a single entry means the
// enemy base has been located.
std::vector<Position> const& AreaInfo::candidateEnemyStartLoc() const {
  auto const& candidates = candidateEnemyStartLoc_;
  return candidates;
}

// First enemy resource depot seen in the area marked as the enemy base, or
// nullptr if none has been seen yet.
Unit* AreaInfo::enemyBase() const {
  Unit* depot = enemyBase_;
  return depot;
}

// Area of the enemy base depot; throws std::runtime_error if it has not been
// seen yet.
Area const& AreaInfo::enemyBaseArea() const {
  if (enemyBase_ != nullptr) {
    return getArea(enemyBase_);
  }
  throw std::runtime_error("Have not found enemy base yet");
}

// Position of the enemy base depot; throws std::runtime_error if it has not
// been seen yet.
Position AreaInfo::enemyBasePosition() const {
  if (enemyBase_ != nullptr) {
    return Position(enemyBase_->x, enemyBase_->y);
  }
  throw std::runtime_error("Have not found enemy base yet");
}

// Builds the per-area static data from BWEM (IDs, bounding-box centers, sizes,
// base locations). It then marks my start area and seeds the candidate enemy
// start locations, either from the blackboard or from the map's start
// locations minus my own.
// Throws std::runtime_error if BWEM area IDs are not 1..N in vector order.
void AreaInfo::initialize() {
  areas_.clear();
  // Reset the candidates along with the areas: update() re-runs initialize()
  // while areas_ is empty, and that must not accumulate duplicate candidates.
  candidateEnemyStartLoc_.clear();

  auto& mapAreas = map_->Areas();
  areas_.resize(mapAreas.size());
  for (size_t i = 0; i < mapAreas.size(); i++) {
    areas_[i].areaInfo = this;
    areas_[i].id = mapAreas[i].Id();
    areas_[i].area = &(mapAreas[i]);
    // We rely on area IDs being equal to the index of the area in the vector
    // (plus one). This is an implementation detail of BWEM, so better check
    // for it.
    if (areas_[i].id != (int)(i + 1)) {
      throw std::runtime_error("Unexpected Area ID");
    }

    // Compute center of bounding box (build tiles), converted to walk tiles
    auto topLeft = mapAreas[i].TopLeft();
    auto bottomRight = mapAreas[i].BottomRight();
    areas_[i].x = (topLeft.x + (bottomRight.x - topLeft.x) / 2) *
        unsigned(tc::BW::XYWalktilesPerBuildtile);
    areas_[i].y = (topLeft.y + (bottomRight.y - topLeft.y) / 2) *
        unsigned(tc::BW::XYWalktilesPerBuildtile);
    areas_[i].size = mapAreas[i].MiniTiles();

    for (auto& base : mapAreas[i].Bases()) {
      BWAPI::WalkPosition pos(base.Center());
      areas_[i].baseLocations.emplace_back(pos.x, pos.y);
    }
  }

  // My start location, if known from the blackboard; (-1, -1) otherwise.
  Position myLoc = Position(-1, -1);
  if (state_->board()->hasKey(Blackboard::kMyLocationKey)) {
    myLoc = state_->board()->get<Position>(Blackboard::kMyLocationKey);
    getArea(myLoc).isMyBase = true;
  }

  // initialize possible enemy locations and areas
  if (state_->board()->hasKey(Blackboard::kEnemyLocationKey)) {
    auto nmyPos = state_->board()->get<Position>(Blackboard::kEnemyLocationKey);
    candidateEnemyStartLoc_.push_back(nmyPos);
    auto& nmyArea = getArea(nmyPos);
    nmyArea.isEnemyBase = true;
    VLOG(1) << "scouting info: enemy base known from the start by blackboard";
  } else {
    // Every start location other than mine is a candidate enemy start.
    for (auto& loc : state_->tcstate()->start_locations) {
      if (loc.x != myLoc.x || loc.y != myLoc.y) {
        candidateEnemyStartLoc_.push_back(Position(loc.x, loc.y));
        auto& nmyArea = getArea(loc);
        nmyArea.isPossibleEnemyBase = true;
      }
    }
    if (candidateEnemyStartLoc_.size() == 0) {
      LOG(WARNING) << "no possible enemy starting locations";
    }
    // this set is redundant with the one in updateEnemyStartingLocations
    // but left here to make sure the initialization is correct by itself
    if (candidateEnemyStartLoc_.size() == 1) {
      VLOG(1)
          << "scouting info: enemy base known from the start by elimination";
      auto nmyLoc = candidateEnemyStartLoc_[0];
      auto& nmyArea = getArea(nmyLoc);
      nmyArea.isEnemyBase = true;
      state_->board()->post(Blackboard::kEnemyLocationKey, nmyLoc);
    }
  }
}

// Rebuilds the per-area unit lists and ownership flags from the current live
// units (special buildings are skipped). Newly completed resource depots of
// mine are registered either as bases, when placed within two build tiles of
// a BWEM base location in their area, or otherwise as macro depots. Enemy
// depots mark their area as an enemy expand, except in the enemy base area,
// where the first one seen becomes enemyBase_.
void AreaInfo::updateUnits() {
  for (auto& area : areas_) {
    area.liveUnits.clear();
    area.visibleUnits.clear();
    area.isMyExpand = false;
    area.isEnemyExpand = false;
    area.hasMyBuildings = false;
    area.hasEnemyBuildings = false;
    // We don't clear the is*Base flags
  }
  // NOTE(review): isMyExpand is reset above but never set to true anywhere in
  // this function (only wasMyExpand is); confirm whether that is intended.

  // Resets the stats before we use GetArea()
  resetCacheStats();

  // The result is not used. The lookup is kept because getArea() throws when
  // my start location is not at or near any area. It is no longer copied by
  // value, which used to duplicate a whole Area every frame.
  getArea(state_->board()->get<Position>(Blackboard::kMyLocationKey));
  auto frame = state_->currentFrame();

  // Collect info about live units
  // TODO: Some of the code below only needs to run for getShowUnits() really
  for (auto& unit : state_->unitsInfo().liveUnits()) {
    if (unit->type->isSpecialBuilding) {
      continue;
    }

    Area& area = getArea(unit);
    area.liveUnits.push_back(unit);
    if (unit->visible) {
      area.visibleUnits.push_back(unit);
    }
    if (unit->isMine) {
      area.lastExplored = frame;
    }

    if (unit->type->isResourceDepot) {
      if (unit->isMine) {
        if (unit->completed() &&
            (macroDepots_.find(unit) == macroDepots_.end() &&
             myBasesInfo_.find(unit) == myBasesInfo_.end())) {
          // Is unit placed at possible base location?
          int correspondingBaseLoc = -1;
          auto upos = Position(unit);
          for (size_t i = 0; i < area.baseLocations.size(); i++) {
            if (utils::distance(upos, area.baseLocations[i]) <=
                tc::BW::XYWalktilesPerBuildtile * 2) {
              // Yup
              correspondingBaseLoc = (int)i;
              break;
            }
          }

          if (correspondingBaseLoc >= 0) {
            VLOG(1) << "Registered new base (expand " << myBases_.size()
                    << "): " << utils::unitString(unit) << " at " << unit->x
                    << "," << unit->y;
            myBases_.push_back(unit);
            myBasesInfo_[unit] = {area.id, correspondingBaseLoc};
            area.wasMyExpand = true;
          } else {
            macroDepots_.insert(unit);
          }
        }
      } else {
        // check for enemybase is done afterwards, enemyBase_ may be one update
        // late
        if (enemyBase_ == nullptr && area.isEnemyBase) {
          enemyBase_ = unit;
        } else if (!area.isEnemyBase) {
          area.isEnemyExpand = true;
          area.wasEnemyExpand = true;
        }
      }
    }

    if (unit->type->isBuilding) {
      if (unit->isMine) {
        area.hasMyBuildings = true;
      } else {
        area.hasEnemyBuildings = true;
      }
    }
  }
}

// Recomputes, for every area, the summed value of my units and of enemy units
// that have a ground weapon, an air weapon, or are non-building detectors.
void AreaInfo::updateStrengths() {
  // Unit value heuristic from Gab's thesis:
  // http://emotion.inrialpes.fr/people/synnaeve/phdthesis/phdthesis.html#x1-131002r2
  auto unitValue = [](Unit* u) {
    return u->type->mineralCost + 4.0 / 3.0 * u->type->gasCost +
        50 * u->type->supplyRequired;
  };

  for (auto& area : areas_) {
    area.myGndStrength = area.myAirStrength = area.myDetStrength = 0;
    area.enemyGndStrength = area.enemyAirStrength = area.enemyDetStrength = 0;

    for (Unit* unit : area.liveUnits) {
      auto const type = unit->type;
      bool const mine = unit->isMine;
      // Pick the accumulators belonging to the unit's owner.
      auto& gnd = mine ? area.myGndStrength : area.enemyGndStrength;
      auto& air = mine ? area.myAirStrength : area.enemyAirStrength;
      auto& det = mine ? area.myDetStrength : area.enemyDetStrength;

      if (type->hasGroundWeapon) {
        gnd += unitValue(unit);
      }
      if (type->hasAirWeapon) {
        air += unitValue(unit);
      }
      // TODO Include detector buildings based on area size?
      if (type->isDetector && !type->isBuilding) {
        det += unitValue(unit);
      }
    }
  }
}

// Mirrors BWEM's accessible-neighbour lists into Area::neighbors. A list is
// rebuilt only when its length differs from BWEM's.
void AreaInfo::updateNeighbors() {
  auto& bwemAreas = map_->Areas();
  assert(areas_.size() == bwemAreas.size());
  for (size_t idx = 0; idx < bwemAreas.size(); ++idx) {
    auto const& theirs = bwemAreas[idx].AccessibleNeighbours();
    auto& ours = areas_[idx].neighbors;
    // XXX This check may not be super robust
    if (theirs.size() == ours.size()) {
      continue;
    }
    ours.resize(theirs.size());
    for (size_t k = 0; k < theirs.size(); ++k) {
      ours[k] = tryGetArea(theirs[k]->Id());
      assert(ours[k] != nullptr);
    }
  }
}

// Narrows down the candidate enemy start locations. A candidate is confirmed
// when enemy buildings are present in its area. It is discarded once its tile
// is visible and no enemy buildings are present in that area. When the enemy
// base is pinned down (confirmed, or the last remaining candidate), its area is
// marked as the enemy base, all other areas lose the possible-enemy-base flag,
// and the location is posted to the blackboard.
void AreaInfo::updateEnemyStartingLocations() {
  // we found the enemy base by elimination
  if (foundEnemyBase()) {
    // nothing to do at this stage
    return;
  }
  Position nmyPos;
  int nmyAreaId = -1;
  for (auto it = candidateEnemyStartLoc_.begin();
       it != candidateEnemyStartLoc_.end();) {
    auto pos = *it;
    auto& checkArea = getArea(pos);
    // we found the base
    if (checkArea.hasEnemyBuildings) {
      nmyPos = pos;
      nmyAreaId = checkArea.id;
      break;
    }
    if (state_->tilesInfo().getTile(pos.x, pos.y).visible) {
      checkArea.isPossibleEnemyBase = false;
      it = candidateEnemyStartLoc_.erase(it);
    } else {
      ++it;
    }
  }
  // base found by elimination
  if (candidateEnemyStartLoc_.size() == 1) {
    auto pos = candidateEnemyStartLoc_[0];
    auto& checkArea = getArea(pos);
    nmyPos = pos;
    nmyAreaId = checkArea.id;
  }
  if (nmyAreaId >= 0) { // cleanup if we found
    for (auto& area : areas_) {
      if (area.id != nmyAreaId) {
        area.isPossibleEnemyBase = false;
      } else {
        area.isEnemyBase = true;
        area.isEnemyExpand = false;
      }
    }
    candidateEnemyStartLoc_.clear();
    candidateEnemyStartLoc_.push_back(nmyPos);
    state_->board()->post(Blackboard::kEnemyLocationKey, nmyPos);
    VLOG(1) << "Enemy location found at " << nmyPos.x << ", " << nmyPos.y;
    if (VLOG_IS_ON(3)) {
      // Iterate by reference: copying each Area would copy its vectors too.
      for (auto const& area : areas()) {
        if (area.id != nmyAreaId && area.isEnemyBase) {
          LOG(ERROR) << "more than one enemy area";
        }
      }
    }
  }

  // debug: consistency checks of the area flags, verbose logging only
  if (VLOG_IS_ON(3) && foundEnemyBase()) {
    auto const enemyAreaId = getArea(candidateEnemyStartLoc_[0]).id;
    for (auto const& area : areas()) {
      if (area.id != enemyAreaId) {
        if (area.isEnemyBase) {
          LOG(ERROR) << "area improperly marked as enemyBase";
        }
        if (area.isPossibleEnemyBase) {
          LOG(ERROR) << "area improperly marked as possible enemyBase";
        }
      } else {
        if (area.isMyBase) {
          LOG(ERROR) << "enemy area marked as my base";
        }
        if (!area.isEnemyBase) {
          LOG(ERROR) << "enemy base not marked as such";
        }
      }
    }
  }
}

} // namespace fairrsh
