/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "modules/workerdefence.h"

#include <algorithm>

#include "areainfo.h"
#include "buildtype.h"
#include "commandtrackers.h"
#include "movefilters.h"
#include "state.h"
#include "task.h"
#include "unitsinfo.h"
#include "utils.h"

namespace fairrsh {

// Registers WorkerDefenceModule with RTTR under the name
// "WorkerDefenceModule", exposing its default constructor and attaching its
// own rttr type as "type" metadata.
RTTR_REGISTRATION {
  rttr::registration::class_<WorkerDefenceModule>("WorkerDefenceModule")(
      metadata("type", rttr::type::get<WorkerDefenceModule>()))
      .constructor();
}

// float const vInfty = std::numeric_limits<float>::infinity();
// (kept for reference only: the code below uses movefilters::vInfty instead)

// Sentinel position (-1, -1). Note: the functions in this file spell out
// Position(-1, -1) directly to mean "no position" rather than using this.
Position const invalidPosition = Position(-1, -1);

// keep track of workers
class WorkerDefenceTask : public Task {
 public:
  WorkerDefenceTask(UpcId upcId, Unit* unit)
      : Task(upcId, {unit}), location_(Position(unit)) {
    setStatus(TaskStatus::Unknown);
  }

  void update(State* state) override {
    // defend base tasks do not really succeed
    removeDeadOrReassignedUnits(state);
    if (units().empty()) {
      VLOG(1) << "defending units dead or reassigned";
      setStatus(TaskStatus::Failure);
      return;
    }
  }

  Unit* getUnit() const {
    if (units().empty()) {
      LOG(DFATAL) << "trying to take the unit of an empty task";
    }
    return *units().begin();
  }

  void attack(State* state, Unit* unit) {
    initiateAction(state);
    auto cmd = tc::Client::Command(
        tc::BW::Command::CommandUnit,
        getUnit()->id,
        tc::BW::UnitCommandType::Attack_Unit,
        unit->id);
    state->board()->postCommand(cmd);
  }

  void move(State* state, Position tgtPos) {
    initiateAction(state);
    auto me = getUnit();
    auto cmd = tc::Client::Command(
        tc::BW::Command::CommandUnit,
        me->id,
        tc::BW::UnitCommandType::Move,
        -1,
        tgtPos.x,
        tgtPos.y);
    shouldMove_ = state->currentFrame();
    state->board()->postCommand(cmd);
  }

  bool shouldMove(State* state) {
    // to avoid sending too many commands
    return shouldMove_ + movefilters::kTimeUpdateMove < state->currentFrame();
  }

 protected:
  void initiateAction(State* state) {
    shouldMove_ = -1;
    fleeing_ = false; // unused
  }
  bool fleeing_ = false;
  int shouldMove_ = -1;
  Position location_; // unused
};

// Per-frame entry point: refreshes the module state, posts and consumes the
// module's own UPCs (creating tasks), then issues commands for every
// unfinished task and finally marks finished tasks for removal.
void WorkerDefenceModule::step(State* state) {
  auto board = state->board();

  checkState(state);

  createUPCs(state);

  // The enemy building (if any) that defenders should attack.
  Unit* attacker = nullptr;
  if (!attackers_.empty()) {
    attacker = *attackers_.begin();
    VLOG(3) << "considering main attacker " << utils::unitString(attacker);
  }
  // Consume our own UPCs; consumeUPC creates tasks for units that are not
  // already controlled by an unfinished task of this module.
  for (auto& pairIdUPC : board->upcsFrom(this)) {
    consumeUPC(state, pairIdUPC.first, pairIdUPC.second);
  }

  // post commands regarding all tasks, including new ones
  for (auto task : board->tasksOfModule(this)) {
    if (task->finished()) {
      continue;
    }
    auto htask = std::dynamic_pointer_cast<WorkerDefenceTask>(task);
    if (!htask) {
      // this module only posts WorkerDefenceTasks
      LOG(ERROR) << "unexpected task type in worker defence module";
      task->setStatus(TaskStatus::Failure);
      continue;
    }
    // The unit may have died or been reassigned since the task was last
    // updated; getUnit() requires a non-empty task.
    if (task->units().empty()) {
      task->setStatus(TaskStatus::Failure);
      continue;
    }
    auto unit = htask->getUnit();
    VLOG(3) << utils::unitString(unit);
    // check reassigned
    if (task != board->taskWithUnit(unit)) {
      // log culprit
      auto data = board->taskDataWithUnit(unit);
      if (data.task && data.owner) {
        VLOG(3) << "worker reassigned from defence to "
                << "module " << data.owner->name();
      }
      task->setStatus(TaskStatus::Failure);
      continue;
    }
    // Overlords patrol blind spots until threatened, at which point they are
    // released to other modules.
    if (unit->type == buildtypes::Zerg_Overlord) {
      VLOG(3) << "ordering to an overlord";
      if (!unit->threateningEnemies.empty()) {
        // don't do anything -> leave to tactics
        htask->setStatus(TaskStatus::Success);
        VLOG(3) << "leaving the control of the overlord";
        continue;
      }
      auto tgt = nextBlindSpot(state);
      VLOG(3) << "next blindSpot " << tgt;
      if (tgt.x > 0 &&
          (tgt != unit->getMovingTarget() || htask->shouldMove(state))) {
        VLOG(3) << utils::unitString(unit) << " moving to " << tgt;
        htask->move(state, tgt);
      }
      VLOG(3) << "overlord going on living his life";
      continue;
    }
    // Defenders attack the main attacker, unless already doing so.
    if (defenders_.find(unit) != defenders_.end()) {
      if (!(utils::isExecutingCommand(
                unit, tc::BW::UnitCommandType::Attack_Unit) &&
            unit->attackingTarget == attacker)) {
        if (!attacker) {
          LOG(ERROR) << "trying to attack nullptr";
          continue;
        }
        htask->attack(state, attacker);
        VLOG(3) << "attacking " << utils::unitString(attacker)
                << " with worker " << utils::unitString(unit);
      }
      continue;
    }
    // Flee when our weapon is on cooldown, or when attacked by a non-worker
    // for which unit->inRangeOf(nmy, 12) holds.
    if (!unit->beingAttackedByEnemies.empty()) {
      // attack if reasonable
      auto shouldFlee = (unit->unit.groundCD != 0);
      for (auto nmy : unit->beingAttackedByEnemies) {
        if (!nmy->type->isWorker && unit->inRangeOf(nmy, 12)) {
          shouldFlee = true;
          break;
        }
      }
      if (shouldFlee) {
        flee(state, htask);
        continue;
      }
    }
    // Pick the closest enemy in sight that passes the per-enemy timeout
    // checks on our cooldown and on its range to us.
    auto minDist = movefilters::vInfty;
    Unit* tgt = nullptr;
    for (auto nmy : unit->enemyUnitsInSightRange) {
      if (unit->unit.groundCD <= 1 + timeout_[nmy] &&
          nmy->inRangeOf(unit, 1 + timeout_[nmy])) {
        auto d = utils::distance(unit, nmy);
        if (d < minDist) {
          minDist = d;
          tgt = nmy;
        }
      }
    }
    // unit is not threatened -> release
    if (!tgt) {
      task->setStatus(TaskStatus::Success);
      VLOG(3) << "worker no more threatened";
      continue;
    }
    // While attacking, only switch targets right after a shot (high CD);
    // otherwise start attacking the chosen target.
    if (utils::isExecutingCommand(unit, tc::BW::UnitCommandType::Attack_Unit)) {
      if (unit->attackingTarget && !unit->attackingTarget->dead &&
          unit->unit.groundCD > unit->unit.maxCD - 5) {
        htask->attack(state, tgt);
        VLOG(3) << "worker " << utils::unitString(unit) << " set to attack "
                << utils::unitString(tgt);
        continue;
      }
    } else {
      htask->attack(state, tgt);
      VLOG(3) << "worker " << utils::unitString(unit) << " set to attack "
              << utils::unitString(tgt);
      continue;
    }
  }

  // delete finished tasks
  for (auto task : board->tasksOfModule(this)) {
    if (task->finished()) {
      board->markTaskForRemoval(task);
    }
  }
}

// Moves the task's unit away from its attackers. The destination is either
// a safe direction towards allies that are not under attack (non-workers
// preferred over workers) or, with no such ally, our closest expansion.
// When no valid destination exists and the unit is still being attacked,
// it fights back against the closest attacker instead.
void WorkerDefenceModule::flee(
    State* state,
    std::shared_ptr<WorkerDefenceTask> task) {
  auto unit = task->getUnit();
  // Collect allies that are not under attack themselves. Workers are only a
  // fallback: the first non-worker ("stronger") ally found discards the
  // workers gathered so far, and later workers are ignored.
  auto myFriends = std::vector<Position>();
  auto strongerFriends = false;
  for (auto frnd : unit->allyUnitsInSightRange) {
    if (!frnd->beingAttackedByEnemies.empty()) {
      continue;
    }
    if (frnd->type->isWorker) {
      if (strongerFriends) {
        continue;
      }
    } else if (!strongerFriends) {
      strongerFriends = true;
      myFriends.clear();
    }
    myFriends.push_back(Position(frnd));
  }
  auto tgt = Position(-1, -1);
  if (!myFriends.empty()) {
    tgt = movefilters::safeDirectionTo<std::vector<Position>>(
        state, unit, myFriends);
  } else {
    auto& ainfo = state->areaInfo();
    tgt = ainfo.myExpandPosition(ainfo.myClosestExpand(Position(unit)));
  }
  if (tgt.x > 0 && tgt.y > 0) {
    if (task->shouldMove(state)) {
      task->move(state, tgt);
      VLOG(3) << "worker " << utils::unitString(unit) << " fleeing to " << tgt
              << " enemy attackers are "
              << utils::unitsString(unit->beingAttackedByEnemies)
              << " enemy threats are "
              << utils::unitsString(unit->threateningEnemies);
    } else {
      // move commands are throttled, see WorkerDefenceTask::shouldMove
      VLOG(3) << "stop at should fee - no update";
    }
    return;
  }
  if (unit->beingAttackedByEnemies.empty()) {
    LOG(WARNING) << "fleeing nothing ? set status to Sucess";
    task->setStatus(TaskStatus::Success);
    return;
  }
  // Commit a bit to the current fight: keep going while we have a target
  // and currentFrame - lastAttacked is still below our ground cooldown.
  if (unit->attackingTarget &&
      state->currentFrame() - unit->lastAttacked < unit->unit.groundCD) {
    return;
  }
  // Nowhere to run to: attack the closest attacker.
  // TODO: use a proper util for basic target prioritization.
  Unit* nmy = nullptr;
  auto pos = Position(unit);
  auto minDist = movefilters::vInfty;
  for (auto attacker : unit->beingAttackedByEnemies) {
    auto d = pos.dist(attacker);
    if (d < minDist) {
      nmy = attacker;
      minDist = d;
    }
  }
  if (!nmy) {
    return;
  }
  VLOG(3) << "nowhere to run to for worker " << utils::unitString(unit)
          << ", attacking " << utils::unitString(nmy);
  task->attack(state, nmy);
}

// target/unit reassignment between tasks should probably be done here
//
// Consumes one of this module's UPCs. Unless the unit it designates is
// already controlled by an unfinished task of this module, a new
// WorkerDefenceTask is posted for that unit. The UPC is consumed in every
// case, including when it carries no unit at all.
void WorkerDefenceModule::consumeUPC(
    State* state,
    UpcId upcId,
    std::shared_ptr<UPCTuple> upc) {
  auto board = state->board();
  board->consumeUPCs({upcId}, this);

  auto const unitCount = upc->unit.size();
  if (unitCount != 1) {
    LOG(ERROR) << "invalid defence UPC";
  }
  if (unitCount == 0) {
    LOG(WARNING) << "empty upc u" << upcId << " consuming but skipping";
    return;
  }

  // With several candidates, the first one in iteration order is used.
  auto unit = upc->unit.begin()->first;
  auto existing = board->taskWithUnitOfModule(unit, this);
  bool const stillRunning = existing && !existing->finished();
  if (stillRunning) {
    VLOG(3) << "existing task for " << utils::unitString(unit);
    return;
  }

  std::shared_ptr<Task> fresh =
      std::make_shared<WorkerDefenceTask>(upcId, unit);
  VLOG(1) << "creating task u" << upcId << " with worker "
          << utils::unitString(unit);
  board->postTask(fresh, this, false);
}

// Per-frame bookkeeping:
//  - refreshes attackers/defenders and blind spots;
//  - posts a UPC for every defender;
//  - rebuilds myWorkers_ with overlords that are in no task (when blind
//    spots exist) and with threatened workers that are neither controlled
//    by this module nor by a scouting/harass/builder task;
//  - tracks enemy weapon cooldowns and grows the per-enemy timeout_ each
//    time an enemy's cooldown goes up, i.e. it actually fired.
void WorkerDefenceModule::checkState(State* state) {
  checkEnemyBuildings(state);
  findBlindSpots(state);
  VLOG(3) << blindSpots_.size() << " blindSpots";
  myWorkers_.clear();

  auto board = state->board();

  for (auto unit : state->unitsInfo().myUnits()) {
    auto isDefender = defenders_.find(unit) != defenders_.end();
    if (unit->type->isWorker && !isDefender) {
      if (VLOG_IS_ON(2)) {
        utils::drawCircle(
            state, Position(unit), unit->sightRange * 8, tc::BW::Color::Red);
      }
    } else if (isDefender) {
      if (VLOG_IS_ON(2)) {
        utils::drawCircle(
            state, Position(unit), unit->sightRange * 8, tc::BW::Color::Blue);
      }
      // create a UPC for the defender even if already under the control of the
      // module; let the module decide if a new task needs be created.
      createUPC(state, unit);
      continue;
    }
    if (!unit->dead && unit->type == buildtypes::Zerg_Overlord) {
      auto taskData = board->taskDataWithUnit(unit);
      VLOG(3) << utils::unitString(unit) << "doing stg ?"
              << (taskData.task && taskData.owner ? taskData.owner->name()
                                                  : "nobody");
      if (!blindSpots_.empty() && !board->taskWithUnit(unit)) {
        myWorkers_.emplace(unit);
        VLOG(3) << "adding overlord " << utils::unitString(unit);
      }
    }
    // only buildings, workers and eggs are protected by this module
    if (!(unit->type->isBuilding || unit->type->isWorker ||
          unit->type == buildtypes::Zerg_Egg)) {
      continue;
    }
    if (!unit->dead && !unit->threateningEnemies.empty() &&
        !board->taskWithUnitOfModule(unit, this)) {
      // check for scout or harasser that have their own micro
      auto taskData = board->taskDataWithUnit(unit);
      if (taskData.task && taskData.owner) {
        auto const& ownerName = taskData.owner->name();
        if (ownerName.find("Scouting") != std::string::npos ||
            ownerName.find("Harass") != std::string::npos ||
            ownerName.find("Builder") != std::string::npos) {
          continue;
        }
      }
      if (unit->type->isWorker) {
        myWorkers_.emplace(unit);
      }
      // threatening or attacking ?
      // I guess what we want here is a check that the enemy is not fleeing
      for (auto nmy : unit->beingAttackedByEnemies) {
        // The relevant cooldown is the one of the weapon the enemy uses
        // against `unit`, which depends on whether `unit` flies -- not on
        // whether the enemy itself flies.
        auto curCD =
            (unit->type->isFlyer ? nmy->unit.airCD : nmy->unit.groundCD);
        auto prev = prevCD_.find(nmy);
        if (prev != prevCD_.end()) {
          if (prev->second < curCD) {
            // enemy attacked us for real, increase defence timeout
            // constant 3...
            timeout_[nmy] += 3;
          }
        } else {
          timeout_[nmy] = 0;
        }
        prevCD_[nmy] = curCD;
      }
    }
  }
}

// Detects enemy buildings (refineries excluded) closer than 100 to our base
// position, i.e. a probable cannon rush. The best-ranked one becomes the
// single entry of attackers_, and up to 4 of our workers close to it become
// defenders_. Deactivated once we have more than one expansion.
// The workerdefence module does not implement full micro: only simple
// anti-harassment / worker rush strategies are implemented here. A full
// allocation algorithm would be needed to handle several attackers.
void WorkerDefenceModule::checkEnemyBuildings(State* state) {
  // Recomputed from scratch on every call. Clear first so that deactivation
  // below does not leave stale attacker/defender assignments behind.
  attackers_.clear();
  defenders_.clear();
  // deactivate after first expand
  if (state->areaInfo().numExpands() > 1) {
    // NOTE(review): checkState() calls findBlindSpots() right after this,
    // which may repopulate blindSpots_ when it is empty, so this clear does
    // not by itself stop the overlord patrol -- confirm intent.
    blindSpots_.clear();
    return;
  }

  // Ordering shared by attackers and defenders, on (distance, unit) pairs:
  //  - incomplete units first, lower health first among them;
  //  - then shorter distance, then higher health, then lower id.
  // Defender distances get a bonus when the worker already attacks the
  // attacker (see below), which keeps the same workers assigned and limits
  // reassignments. This must be a strict weak ordering (std::sort).
  auto compareUnits = [](auto const& p1, auto const& p2) -> bool {
    auto u1 = p1.second;
    auto u2 = p2.second;
    bool const done1 = u1->completed();
    bool const done2 = u2->completed();
    if (done1 != done2) {
      return !done1;
    }
    if (!done1) {
      return u1->unit.health < u2->unit.health;
    }
    if (p1.first < p2.first) {
      return true;
    }
    if (p1.first > p2.first) {
      return false;
    }
    if (u1->unit.health != u2->unit.health) {
      return u1->unit.health > u2->unit.health;
    }
    return u1->id < u2->id;
  };
  auto myPos = state->areaInfo().myBasePosition();

  auto attackers = std::vector<std::pair<double, Unit*>>();
  for (auto unit : state->unitsInfo().enemyUnits()) {
    if (VLOG_IS_ON(2)) {
      utils::drawCircle(
          state,
          Position(unit),
          unit->unit.groundRange * 8,
          tc::BW::Color::Green);
    }
    if (!unit->type->isBuilding || unit->type->isRefinery) {
      continue;
    }
    // hardcoded constant of 100 to consider an enemy building as
    // a rush against our base
    auto dist = myPos.dist(unit);
    if (dist < 100) {
      VLOG(3) << "attacker found " << utils::unitString(unit);
      if (VLOG_IS_ON(2)) {
        utils::drawCircle(
            state,
            Position(unit),
            unit->unit.groundRange * 8,
            tc::BW::Color::Red);
      }
      LOG_FIRST_N(WARNING, 1)
          << "probable cannon rush "
          << " we have " << state->areaInfo().numExpands() << " expansions";
      attackers.emplace_back(dist, unit);
    }
  }
  if (!attackers.empty()) {
    // only the best-ranked attacker is kept; no need to sort them all
    attackers_.emplace(
        std::min_element(attackers.begin(), attackers.end(), compareUnits)
            ->second);
  }

  // Assign up to 4 workers within twice their sight range of the attacker.
  auto defenders = std::vector<std::pair<double, Unit*>>();
  if (!attackers_.empty()) {
    auto attacker = *attackers_.begin();
    auto pAttacker = Position(attacker);
    for (auto unit : state->unitsInfo().myWorkers()) {
      auto dist = pAttacker.dist(unit);
      if (dist <= 2 * unit->sightRange) {
        // keep the workers that are already attacking
        defenders.emplace_back(
            dist -
                (unit->attackingTarget == attacker ? 2 * unit->sightRange : 0),
            unit);
      }
    }
    std::sort(defenders.begin(), defenders.end(), compareUnits);
  }

  for (size_t i = 0; i < std::min(size_t(4), defenders.size()); i++) {
    defenders_.emplace(defenders[i].second);
  }
}

// Computes the blind spots of our main base, once.
// Candidates lie on an 11x11 grid with a spacing of 12 centred on our base
// (clamped to the map). A candidate is kept when its tile is entirely
// walkable and it lies in our base area. Each kept spot starts with a
// last-seen frame of -1. The spots are cleared when our base is unknown or
// dead.
void WorkerDefenceModule::findBlindSpots(State* state) {
  auto& areas = state->areaInfo();
  auto base = areas.myBase();
  if (base == nullptr || base->dead) {
    VLOG(1) << "we don't know our base?";
    blindSpots_.clear();
    return;
  }
  // already computed
  if (!blindSpots_.empty()) {
    return;
  }
  auto& tiles = state->tilesInfo();
  auto baseArea = areas.myBaseArea();
  if (!baseArea) {
    LOG(ERROR) << "don't know my area";
    return;
  }
  auto const basePos = Position(base);
  VLOG(3) << "my base position " << basePos;

  constexpr int kHalfGrid = 5;
  constexpr int kStep = 12;
  for (int dx = -kHalfGrid; dx <= kHalfGrid; ++dx) {
    for (int dy = -kHalfGrid; dy <= kHalfGrid; ++dy) {
      auto candidate = utils::clampPositionToMap(
          state, basePos.x + kStep * dx, basePos.y + kStep * dy, false);
      auto tile = tiles.tryGetTile(candidate.x, candidate.y);
      if (!tile) {
        continue;
      }
      if (!tile->entirelyWalkable) {
        VLOG(3) << "tile at position " << candidate << " non-walkable "
                << " mypos = " << basePos;
        continue;
      }
      if (candidate.x > 0 && areas.tryGetArea(candidate) == baseArea) {
        blindSpots_.emplace(candidate, -1);
      }
    }
  }
}

// Returns the blind spot that was seen least recently, or (-1, -1) when
// there are none. Before choosing, every spot whose tile is currently
// visible gets its last-seen frame set to the current frame.
//
// Previously, the chosen spot was written back with its own value, which
// was a no-op. That meant last-seen frames never changed and the same spot
// was returned forever.
Position WorkerDefenceModule::nextBlindSpot(State* state) {
  auto& tinfo = state->tilesInfo();
  auto const now = state->currentFrame();
  auto bestPos = Position(-1, -1);
  auto lastFrame = std::numeric_limits<int>::max();
  for (auto& posFramePair : blindSpots_) {
    auto const& pos = posFramePair.first;
    auto& lastSeen = posFramePair.second;
    // a spot that is visible right now is not blind any more
    auto tile = tinfo.tryGetTile(pos.x, pos.y);
    if (tile && tile->visible) {
      lastSeen = now;
    }
    if (lastSeen < lastFrame) {
      bestPos = pos;
      lastFrame = lastSeen;
    }
  }
  return bestPos;
}

/* ***************
 * functions that should belong to higher-level modules
 *****************/

/*
 * Posts, under the root UPC, a UPC that designates `unit` alone (weight
 * 1.0). Logs a warning when postUPC returns a negative id.
 */
void WorkerDefenceModule::createUPC(State* state, Unit* unit) {
  auto upc = std::make_shared<UPCTuple>();
  upc->unit[unit] = 1.0f;
  auto board = state->board();
  auto const postedId = board->postUPC(std::move(upc), kRootUpcId, this);
  if (postedId >= 0) {
    return;
  }
  LOG(WARNING) << "base upc could not be posted";
}

// Posts UPCs for the units collected in myWorkers_ by checkState().
// An overlord always gets one. A live, threatened worker that is not yet
// in a task of this module gets one as soon as some enemy in sight either
// - is a worker with nmy->inRangeOf(worker, 1 + timeout) and our ground
//   cooldown below 1 + timeout, or
// - satisfies worker->inRangeOf(nmy, 12).
// The default policy is to defend the base.
void WorkerDefenceModule::createUPCs(State* state) {
  auto board = state->board();

  for (auto worker : myWorkers_) {
    if (worker->type == buildtypes::Zerg_Overlord) {
      createUPC(state, worker);
      continue;
    }
    bool const candidate = !worker->dead &&
        !worker->threateningEnemies.empty() &&
        !board->taskWithUnitOfModule(worker, this);
    if (!candidate) {
      continue;
    }
    // go through all possible enemies and check their timeout
    for (auto nmy : worker->enemyUnitsInSightRange) {
      // worker enemy -> defend while mining as much as possible
      bool const duel = nmy->type->isWorker &&
          nmy->inRangeOf(worker, 1 + timeout_[nmy]) &&
          worker->unit.groundCD < 1 + timeout_[nmy];
      if (duel || worker->inRangeOf(nmy, 12)) {
        createUPC(state, worker);
        break;
      }
    }
  }
}

} // namespace fairrsh
