/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "modules/scouting.h"

#include "areainfo.h"
#include "commandtrackers.h"
#include "state.h"
#include "task.h"
#include "unitsinfo.h"
#include "utils.h"
#include "movefilters.h"

#include <bwem/map.h>

#include <deque>

namespace fairrsh {

// RTTR reflection registration for ScoutingModule: registers the class under
// the name "ScoutingModule", attaches a "type" metadata entry holding its
// rttr::type, and exposes its default constructor.
// NOTE(review): presumably this lets modules be created by name at runtime —
// confirm against the module factory.
RTTR_REGISTRATION {
  rttr::registration::class_<ScoutingModule>("ScoutingModule")(
      metadata("type", rttr::type::get<ScoutingModule>()))
      .constructor();
}

namespace {

/**
 * Task owning a single scout sent towards `location_` for a given goal.
 *
 * update() refreshes two flags read by satisfiesGoal():
 *  - targetVisited_: the target area is the enemy base, is not a possible
 *    enemy base, or one of its chokepoints looks blocked
 *    (see foundBlockedChoke());
 *  - targetScouted_: a visible enemy building is within the scout's sight
 *    range (sticky until resetLocation()).
 * The task fails once its scout is dead or has been reassigned.
 */
class ScoutingTask : public Task {
 public:
  ScoutingTask(int upcId, Unit* unit, Position location, ScoutingGoal goal)
      : Task(upcId, {unit}), location_(location), goal_(goal) {
    setStatus(TaskStatus::Ongoing);
  }

  void update(State* state) override {
    auto loc = location();
    auto& tgtArea = state->areaInfo().getArea(loc);
    targetVisited_ =
        (tgtArea.isEnemyBase || !tgtArea.isPossibleEnemyBase ||
         foundBlockedChoke(state));
    if (!proxiedUnits().empty()) { // debug: keep track of reallocations
      auto unit = *proxiedUnits().begin();
      auto taskData = state->board()->taskDataWithUnit(unit);
      if (!taskData.task && !unit->dead) {
        VLOG(3) << "scout " << utils::unitString(unit)
                << " reassigned to no task";
      }
      if (taskData.task && taskData.task.get() != this) {
        VLOG(2) << "scout " << utils::unitString(unit)
                << " reassigned to module " << taskData.owner->name();
      }
      if (unit->dead) {
        VLOG(3) << "scout " << utils::unitString(unit) << " dead ";
      }
    }

    // Now check the failure case. If all my units died
    // then this task failed
    removeDeadOrReassignedUnits(state);
    if (units().empty()) {
      setStatus(TaskStatus::Failure);
      return;
    }
    auto unit = *units().begin();
    for (auto bldg : state->unitsInfo().visibleEnemyUnits()) {
      if (bldg->type->isBuilding &&
          utils::distance(bldg->x, bldg->y, unit->x, unit->y) <=
              unit->sightRange) {
        targetScouted_ = true;
        break;
      }
    }
  }

  /// Returns the scout of this task, or nullptr (after a DFATAL log) when
  /// the task has no unit left -- never dereferences an empty set.
  Unit* getUnit() {
    if (units().empty()) {
      LOG(DFATAL) << "getting the unit of a task without unit";
      return nullptr;
    }
    return *units().begin();
  }

  Position location() const {
    return location_;
  }

  ScoutingGoal goal() const {
    return goal_;
  }

  /// Whether the flags computed by update() satisfy this task's goal.
  /// An Automatic goal is invalid here: the task is marked as failed and
  /// true is returned.
  bool satisfiesGoal() {
    switch (goal_) {
      case ScoutingGoal::ExploreEnemyBase:
        return targetScouted_;
      case ScoutingGoal::FindEnemyBase:
      case ScoutingGoal::FindEnemyExpand:
        return targetVisited_;
      case ScoutingGoal::Automatic:
        LOG(ERROR) << "invalid goal specification "
                   << "in check that the goal is satisfied for scouting task "
                   << upcId();
        setStatus(TaskStatus::Failure);
        return true;
    }
    // cannot be reached -- avoid warning
    return true;
  }

  /// Retargets the task to `pos` and clears both progress flags.
  void resetLocation(Position const& pos) {
    if (pos == location_) {
      LOG(WARNING) << "reseting a scouting task with the same location";
    }
    location_ = pos;
    targetVisited_ = false;
    targetScouted_ = false;
  }

  /// True when the scout is being attacked by at least one enemy standing
  /// within `distanceFromChokePoint` of a chokepoint center of the area
  /// nearest to the target location.
  bool foundBlockedChoke(State* state) {
    // the chokepoint is considered "blocked" if there is
    // a chokepoint of the target area for which some units attack us
    // what to do if attacked at some other location ? we need a bit of micro
    // for that
    if (units().empty()) {
      LOG(ERROR) << "updating a finished scouting task";
      // Without a scout nothing can be attacking us; also avoids
      // dereferencing an empty unit set below.
      return false;
    }
    auto unit = getUnit();
    if (unit->beingAttackedByEnemies.empty()) {
      // if we're not attacked, it's not blocked
      return false;
    }

    // Heuristic value
    const int distanceFromChokePoint = 42;

    auto tgt = location();
    auto targetArea =
        state->map()->GetNearestArea(BWAPI::WalkPosition(tgt.x, tgt.y));
    if (!targetArea) {
      return false;
    }
    for (auto other : unit->beingAttackedByEnemies) {
      for (auto& choke : targetArea->ChokePoints()) {
        auto cwtp = choke->Center();
        if (utils::distance(cwtp.x, cwtp.y, other->x, other->y) <
            distanceFromChokePoint) {
          return true;
        }
      }
    }
    return false;
  }

 protected:
  Position location_;
  ScoutingGoal goal_;
  bool targetVisited_ = false;
  bool targetScouted_ = false;
};

} // namespace

void ScoutingModule::setScoutingGoal(ScoutingGoal goal) {
  // Any value other than ScoutingGoal::Automatic is returned verbatim by
  // goal(); Automatic lets goal() derive it from the known enemy state.
  this->scoutingGoal_ = goal;
}

bool ScoutingModule::automaticScoutingPolicy(State* state) {
  // The builder's scouting policy flag on the blackboard; absent means off.
  auto board = state->board();
  return board->hasKey(Blackboard::kBuilderScoutingPolicyKey) &&
      board->get<bool>(Blackboard::kBuilderScoutingPolicyKey);
}

ScoutingGoal ScoutingModule::goal(State* state) const {
  // An explicitly configured goal always takes precedence.
  if (scoutingGoal_ != ScoutingGoal::Automatic) {
    return scoutingGoal_;
  }
  // Automatic: first locate the enemy base...
  if (!state->areaInfo().foundEnemyBase()) {
    return ScoutingGoal::FindEnemyBase;
  }
  // ...then explore it until areaInfo knows the base, then look for expands.
  return state->areaInfo().enemyBase() ? ScoutingGoal::FindEnemyExpand
                                       : ScoutingGoal::ExploreEnemyBase;
}

/*
 * One frame of the scouting module:
 *  1. sets the worker/explorer budget from the blackboard policy and posts
 *     the main scouting UPCs (CreateMainUPCs);
 *  2. refreshes the candidate enemy start locations;
 *  3. updates ongoing scouting tasks: re-targets FindEnemyBase scouts whose
 *     location has been checked, marks other satisfied tasks as successful,
 *     and keeps issuing move UPCs otherwise;
 *  4. consumes Command::Scout UPCs and turns them into scouting tasks;
 *  5. sends scouts of finished tasks to kMyLocationKey and removes the tasks.
 */
void ScoutingModule::step(State* state) {
  // set policy based on blackboard
  if (automaticScoutingPolicy(state)) {
    maxNbWorkers_ = 0;
    maxNbExplorers_ = 0;
  } else {
    maxNbWorkers_ = vScoutingMaxNbWorkers;
    maxNbExplorers_ = vScoutingMaxNbExplorers;
  }
  // do the higher-level job
  CreateMainUPCs(state);

  updateLocations(
      state, startingLocations_, state->areaInfo().candidateEnemyStartLoc());

  auto board = state->board();

  // Stamps `tgt` with the current frame so that nextScoutingLocation()
  // prefers other locations next time. Only existing entries are updated:
  // nextScoutingLocation() returns Position(-1, -1) when there is no
  // candidate, and that sentinel must not become a scouting location.
  auto markTargeted = [&](Position const& tgt) {
    auto it = startingLocations_.find(tgt);
    if (it != startingLocations_.end()) {
      it->second = state->currentFrame();
    }
  };

  // clean up tasks
  for (auto task : board->tasksOfModule(this)) {
    if (!task->finished()) {
      task->update(state); // check re-assignment at this step
    }
    if (task->finished()) {
      continue;
    }
    auto stask = std::static_pointer_cast<ScoutingTask>(task);
    auto unit = stask->getUnit();
    bool const satisfied = stask->satisfiesGoal();
    if (stask->finished()) {
      // satisfiesGoal() fails tasks with an invalid (Automatic) goal; keep
      // that status instead of overwriting it with Success below.
      continue;
    }
    if (!satisfied) {
      postMoveUPC(state, stask->upcId(), unit, stask->location());
      continue;
    }
    if (stask->goal() != ScoutingGoal::FindEnemyBase ||
        goal(state) != ScoutingGoal::FindEnemyBase) {
      // no need to update on explore
      stask->setStatus(TaskStatus::Success);
      VLOG(3) << "scouting task " << stask->upcId()
              << " marked as succeedded";
      continue;
    }

    // The current location has been checked while we are still looking for
    // the enemy base: move the scout on to the next candidate location.
    auto pos = stask->location();
    auto tgt = nextScoutingLocation(state, unit, startingLocations_);
    if (tgt == pos) {
      LOG(ERROR) << "reseting scouting task with same location with "
                 << startingLocations_.size() << " candidate locations."
                 << " Do we know the enemyBase (check areaInfo)? "
                 << state->areaInfo().foundEnemyBase()
                 << ". Do we know the enemy location (check state)? "
                 << board->hasKey(Blackboard::kEnemyLocationKey)
                 << ". according to areaInfo: "
                 << state->areaInfo().foundEnemyBase()
                 << " current scouting goal " << (int)goal(state);
    }
    stask->resetLocation(tgt);
    if (postMoveUPC(state, stask->upcId(), unit, tgt)) {
      VLOG(3) << "starting location " << pos.x << ", " << pos.y
              << " visited"
              << " sending scout " << utils::unitString(unit)
              << " to next location: " << tgt.x << ", " << tgt.y;
      markTargeted(tgt);
    } else {
      // what to do here
      LOG(WARNING) << "move to location " << tgt.x << ", " << tgt.y
                   << " for scout " << utils::unitString(unit)
                   << " filtered by the blackboard, canceling task "
                   << stask->upcId();
      stask->cancel(state);
    }
  }

  // consume UPCs
  // all UPCs at a given time will be set using the current module's goal
  // since the UPC does not directly allow for goal specification
  for (auto upcPair : board->upcsWithSharpCommand(Command::Scout)) {
    if (upcPair.second->unit.empty()) {
      LOG(ERROR) << "main scouting UPC without unit specification -- consuming "
                    "but ignoring";
      board->consumeUPC(upcPair.first, this);
      continue;
    }
    auto const curGoal = goal(state);
    Unit* unit = nullptr;
    switch (curGoal) {
      case ScoutingGoal::FindEnemyBase:
        unit = findUnit(state, upcPair.second->unit, Position(-1, -1));
        break;
      case ScoutingGoal::FindEnemyExpand:
      case ScoutingGoal::ExploreEnemyBase:
        if (startingLocations_.size() != 1) {
          LOG(ERROR)
              << "invalid scouting goal (ExploreEnemyBase/FindEnemyExpand) "
              << " because no enemy location";
          break;
        }
        unit = findUnit(
            state, upcPair.second->unit, startingLocations_.begin()->first);
        break;
      case ScoutingGoal::Automatic:
        LOG(ERROR) << "invalid goal";
        break;
    }
    if (!unit) {
      VLOG(3) << "could not find scout for upc " << upcPair.first
              << " -- skipping for now"
              << "number of units of required type: "
              << state->unitsInfo()
                     .myCompletedUnitsOfType(buildtypes::Zerg_Drone)
                     .size();
      continue;
    }
    board->consumeUPC(upcPair.first, this);
    auto tgt = nextScoutingLocation(state, unit, startingLocations_);
    if (postTask(state, upcPair.first, unit, tgt, curGoal)) {
      markTargeted(tgt);
    }
  }

  // clean up finished tasks: send the scouts back to base
  auto myLoc = board->get<Position>(Blackboard::kMyLocationKey);
  for (auto task : board->tasksOfModule(this)) {
    if (!task->finished()) {
      continue;
    }
    if (!task->proxiedUnits().empty()) {
      auto stask = std::static_pointer_cast<ScoutingTask>(task);
      auto unit = stask->getUnit();
      auto chkTask = board->taskWithUnit(unit);
      // scout not reallocated, send it back
      if (chkTask == task) {
        VLOG(3) << "sending scout " << utils::unitString(unit)
                << " back to base";
        postMoveUPC(state, stask->upcId(), unit, myLoc);
      }
    }
    // manual removal because the status might have changed during the step
    board->markTaskForRemoval(task->upcId());
  }
}

/*
 * Picks the unit to scout with among `candidates` (unit -> UPC probability).
 * Units absent from `candidates`, with a non-positive probability, or still
 * used by one of our unfinished tasks get an infinite score and are never
 * picked. Lower scores win:
 *  - our own scouts from finished tasks first, ordered by BWEM path length
 *    to `pos` when pos.x > 0 && pos.y > 0 (Position(-1, -1) skips this);
 *  - then inactive units (-200), then units whose task succeeded (-100);
 *  - then units currently ordered to MoveToMinerals (15) or MoveToGas (50),
 *    and finally any other unit (100).
 * NOTE(review): presumably returns nullptr when no unit has a finite score
 * (getBestScoreCopy's default) -- confirm.
 */
Unit* ScoutingModule::findUnit(
    State* state,
    std::unordered_map<Unit*, float> const& candidates,
    Position const& pos) {
  auto board = state->board();

  // Find some units to scout with, preferring faster units and flying units,
  // and ignoring workers if possible to let them keep on working
  auto map = state->map();
  auto mapSize = state->mapWidth() * state->mapHeight();
  auto unitScore = [&](Unit* u) -> double {
    auto it = candidates.find(u);
    if (it == candidates.end() || it->second <= 0) {
      return std::numeric_limits<double>::infinity();
    }
    auto tdata = board->taskDataWithUnit(u);
    if (tdata.owner == this) {
      // scout is free and previously assigned to us
      if (tdata.task->finished()) {
        int pLength = 0;
        if (pos.x > 0 && pos.y > 0) {
          map->GetPath(
              BWAPI::Position(BWAPI::WalkPosition(u->x, u->y)),
              BWAPI::Position(BWAPI::WalkPosition(pos.x, pos.y)),
              &pLength);
        }
        // -2 * mapSize keeps these below every other finite score.
        return -2 * mapSize + pLength;
      }
      // We're using this unit already
      return std::numeric_limits<double>::infinity();
    }
    if (!u->active()) {
      return -200;
    }
    if (tdata.task && tdata.task->status() == TaskStatus::Success) {
      // The unit just finished a task, it should be free now
      return -100;
    }

    // wait for an available worker if all are currently busy bringing resources
    if (!u->idle() && !u->unit.orders.empty()) {
      if (u->unit.orders.front().type == tc::BW::Order::MoveToMinerals) {
        return 15.0;
      } else if (u->unit.orders.front().type == tc::BW::Order::MoveToGas) {
        return 50.0;
      }
    }
    return 100;
  };

  // Query the unit list directly instead of copying the whole UnitsInfo
  // object (the original `auto uinfo = state->unitsInfo();` made a copy).
  return utils::getBestScoreCopy(
      state->unitsInfo().myUnits(),
      unitScore,
      std::numeric_limits<double>::infinity());
}

bool ScoutingModule::postTask(
    State* state,
    UpcId baseUpcId,
    Unit* unit,
    Position loc,
    ScoutingGoal goal) {
  // A scouting task is only created once its first move UPC went through.
  bool const moveSent = postMoveUPC(state, baseUpcId, unit, loc);
  if (!moveSent) {
    VLOG(1) << "task for unit " << utils::unitString(unit) << " not created";
    return false;
  }
  // Registered without auto-removal: step() removes finished tasks itself.
  state->board()->postTask(
      std::make_shared<ScoutingTask>(baseUpcId, unit, loc, goal), this, false);
  VLOG(1) << "new scouting task " << baseUpcId << " with unit "
          << utils::unitString(unit) << "for location " << loc.x << ", "
          << loc.y;
  return true;
}

bool ScoutingModule::postMoveUPC(
    State* state,
    UpcId baseUpcId,
    Unit* unit,
    const Position& loc) {
  // Let the move filters choose a safe waypoint towards `loc`.
  auto waypoint = movefilters::safeMoveTo(state, unit, loc);
  if (waypoint.x <= 0 || waypoint.y <= 0) {
    LOG(WARNING) << "scout stuck";
  }
  // The unit is already heading (almost) there: nothing new to post.
  bool const alreadyThere = waypoint.dist(unit->getMovingTarget()) <= 4;
  if (alreadyThere) {
    return true;
  }

  auto move = std::make_shared<UPCTuple>();
  move->unit[unit] = 1;
  move->command[Command::Move] = 1;
  move->positionS = waypoint;
  auto postedId = state->board()->postUPC(std::move(move), baseUpcId, this);
  if (postedId >= 0) {
    return true;
  }
  VLOG(1) << "MoveUPC for unit " << utils::unitString(unit)
          << " filtered by blackboard";
  return false;
}

/*
 * Picks the next location to send `unit` to among `locations`.
 *
 * `locations` maps each candidate position to the frame at which it was last
 * targeted (-1 when never targeted). The location with the smallest frame,
 * i.e. the least recently targeted one, is chosen; ties are broken by
 * straight-line distance from the unit. Returns Position(-1, -1) when
 * `locations` is empty.
 */
Position ScoutingModule::nextScoutingLocation(
    State* state,
    Unit* unit,
    std::unordered_map<Position, int> const& locations) {
  // Only straight-line distance is used, so no BWEM path query is needed
  // (a per-candidate GetPath whose result was never read has been removed).
  (void)state;
  auto curPos = Position(unit);
  auto minDist = std::numeric_limits<double>::infinity();
  auto lastFrame = std::numeric_limits<int>::max();
  auto bestPos = Position(-1, -1);

  for (auto const& entry : locations) {
    auto const& pos = entry.first;
    auto const frame = entry.second;
    auto d = curPos.dist(pos);
    if (frame < lastFrame || (frame == lastFrame && d < minDist)) {
      minDist = d;
      bestPos = pos;
      lastFrame = frame;
    }
  }
  return bestPos;
}

void ScoutingModule::updateLocations(
    State* state,
    std::unordered_map<Position, int>& locations,
    std::vector<Position> const& candidates) {
  // First call: seed every candidate as never targeted (-1).
  if (locations.empty()) {
    for (auto const& candidate : candidates) {
      locations.emplace(candidate, -1);
    }
  }
  // With zero or one location there is nothing to choose between.
  if (locations.size() <= 1) {
    return;
  }

  // Locations targeted by any of our tasks are stamped with the current frame.
  auto const now = state->currentFrame();
  for (auto const& task : state->board()->tasksOfModule(this)) {
    auto target = std::static_pointer_cast<ScoutingTask>(task)->location();
    auto found = locations.find(target);
    if (found == locations.end()) {
      LOG(ERROR) << "Scouting task for a non-starting location";
      continue;
    }
    found->second = now;
  }

  // Drop locations that are no longer candidate enemy start locations.
  auto isCandidate = [&candidates](Position const& p) {
    return std::find(candidates.begin(), candidates.end(), p) !=
        candidates.end();
  };
  auto cur = locations.begin();
  while (cur != locations.end()) {
    if (isCandidate(cur->first)) {
      ++cur;
    } else {
      cur = locations.erase(cur);
    }
  }
}

/*
 * Posts a "main" scouting UPC (Command::Scout) listing every unit of ours of
 * the given `type` as an equally likely (0.5) candidate scout. step()
 * consumes these UPCs and picks the actual scout.
 * Returns the posted UPC id, or kInvalidUpcId when we have no unit of that
 * type or the blackboard rejected the UPC.
 * `locations` is currently unused.
 * TODO: have an external "what to scout with whom" policy here, and merge
 * with updateLocations.
 */
UpcId ScoutingModule::createUPC(
    State* state,
    BuildType const* type,
    std::vector<Position> const& locations) {
  auto scoutUpc = std::make_shared<UPCTuple>();
  for (auto candidate : state->unitsInfo().myUnits()) {
    if (candidate->type != type) {
      continue;
    }
    scoutUpc->unit[candidate] = .5;
  }
  if (scoutUpc->unit.empty()) {
    return kInvalidUpcId;
  }
  // Dummy UPC, distinct from the move UPCs this module also posts.
  scoutUpc->command[Command::Scout] = 1;
  auto posted = state->board()->postUPC(scoutUpc, -1, this);
  if (posted >= 0) {
    return posted;
  }
  VLOG(2) << "main scouting UPC not sent to the blackboard";
  return kInvalidUpcId;
}

// main logic for sending scouts, and what type of scouts
//
// Posts "main" scouting UPCs (Command::Scout) that step() later consumes:
//  - once the enemy base is found and the minimum scouting frame is reached,
//    drone UPCs (up to maxNbExplorers_) to go and see it, even if it was
//    found by elimination;
//  - while the enemy base is not found, overlord UPCs (up to
//    maxNbOverLords_);
//  - while the enemy base is not found and the minimum scouting frame is
//    reached, drone UPCs (up to maxNbWorkers_).
// The minimum scouting frame is read from the blackboard key "kMinScoutFrame"
// (default 1560); a non-positive value disables the frame-gated scouts.
void ScoutingModule::CreateMainUPCs(State* state) {
  auto board = state->board();
  int minScoutFrame = board->hasKey("kMinScoutFrame")
      ? board->get<int>("kMinScoutFrame")
      : 1560;
  if (minScoutFrame <= 0) {
    minScoutFrame = std::numeric_limits<int>::max();
  }
  auto const frame = state->currentFrame();
  bool const enemyBaseFound = state->areaInfo().foundEnemyBase();
  auto locations = state->areaInfo().candidateEnemyStartLoc();

  // Posts scouting UPCs for `type` while `count` is below `maxCount` and we
  // have more completed units of that type than UPCs already posted.
  // Stops at the first UPC that could not be posted: `count` would not grow
  // and the loop would otherwise never terminate.
  auto postScoutUPCs =
      [&](BuildType const* type, auto& count, auto maxCount, char const* who) {
        while (count < maxCount &&
               state->unitsInfo().myCompletedUnitsOfType(type).size() >
                   count) {
          auto upcId = createUPC(state, type, locations);
          if (upcId <= 0) {
            break;
          }
          count++;
          VLOG(3) << "creating the " << count << "th scouting UPC for " << who
                  << " u" << upcId;
        }
      };

  if (enemyBaseFound && frame >= minScoutFrame) {
    postScoutUPCs(
        buildtypes::Zerg_Drone,
        nbExplorers_,
        maxNbExplorers_,
        "explorer workers");
  }
  if (frame >= 0 && !enemyBaseFound) {
    postScoutUPCs(
        buildtypes::Zerg_Overlord, nbOverlords_, maxNbOverLords_, "overlords");
  }
  if (frame >= minScoutFrame && !enemyBaseFound) {
    postScoutUPCs(
        buildtypes::Zerg_Drone, nbWorkers_, maxNbWorkers_, "workers");
  }
}

} // namespace fairrsh
