/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "modules/defendbase.h"

#include <algorithm>
#include <limits>
#include <memory>
#include <unordered_set>
#include <utility>

#include "state.h"
#include "utils.h"

namespace fairrsh {

// Reflection registration for DefendBaseModule: exposes the class to RTTR
// under the name "DefendBaseModule", attaches a "type" metadata entry holding
// the class's own rttr::type, and registers its default constructor.
// NOTE(review): presumably this lets modules be instantiated by name (e.g.
// from a configuration) — confirm against the module factory.
RTTR_REGISTRATION {
  rttr::registration::class_<DefendBaseModule>("DefendBaseModule")(
      metadata("type", rttr::type::get<DefendBaseModule>()))
      .constructor();
}

namespace {

// Starting value for the minimum-distance searches in this file.
constexpr float vInfty = std::numeric_limits<float>::infinity();
// Marker for an unset position.
// NOTE(review): not referenced anywhere in this file.
Position const invalidPosition(-1, -1);

// Task in which a set of our workers fights a single enemy unit that intrudes
// on our base.
//
// update() marks the task as failed once all assigned units are dead or have
// been reassigned, and as successful once the target is dead or gone. The
// task issues no commands on its own: DefendBaseModule::step() uses
// attacking() and attack() to keep every proxied unit on the target.
class DefendBaseTask : public Task {
public:
  // upcId:      the UPC this task was created from.
  // us:         the workers assigned to the defense.
  // target:     the enemy unit to fight.
  // myLocation: the location (positionS) of the originating UPC; used to
  //             recognize tasks that already handle a target there.
  DefendBaseTask(
      UpcId upcId,
      std::unordered_set<Unit*> us,
      Unit* target,
      Position myLocation)
      : Task(upcId, std::move(us)), target_(target), location_(myLocation) {
    setStatus(TaskStatus::Ongoing);
  }

  void update(State* state) override {
    removeDeadOrReassignedUnits(state);
    if (units().empty()) {
      VLOG(1) << "DefendBaseModule: defending units dead: failure";
      setStatus(TaskStatus::Failure);
      return;
    }
    // TODO: add tracker for the target
    if (target()->dead) {
      VLOG(1) << "DefendBaseModule: target unit dead: success";
      setStatus(TaskStatus::Success);
      return;
    }
    if (target()->gone) {
      VLOG(1) << "DefendBaseModule: target unit gone: success";
      setStatus(TaskStatus::Success);
      return;
    }
  }

  // The enemy unit this task fights.
  Unit* target() const {
    return target_;
  }

  // Location of the UPC this task was created from.
  Position const& location() const {
    return location_;
  }

  // Returns true if `unit` currently has an order that corresponds to an
  // Attack_Unit command and whose target is target().
  bool attacking(State* state, Unit const* unit) const {
    // Bind by reference: copying would duplicate the whole unit state,
    // including its orders vector, on every call.
    auto const& tcu = unit->unit;
    auto const acceptableOrders = tc::BW::commandToOrders(
        tc::BW::UnitCommandType::Attack_Unit);
    for (auto const& order : tcu.orders) {
      auto ot = tc::BW::Order::_from_integral_nothrow(order.type);
      if (!ot) {
        // unknown order type
        continue;
      }
      if (std::find(acceptableOrders.begin(), acceptableOrders.end(), *ot) ==
          acceptableOrders.end()) {
        continue;
      }
      auto tgt = state->unitsInfo().getUnit(order.targetId);
      if (tgt == target()) {
        VLOG(3) << "DefendBaseModule: " << utils::unitString(unit)
                << " vs " << utils::unitString(tgt)
                << " cooldown " << tcu.groundCD
                << " hit points " << tcu.health
                << " tgt hp " << tgt->unit.health
                << " tgt shield " << tgt->unit.shield;
        return true;
      }
    }
    return false;
  }

  // Posts an Attack_Unit command ordering `unit` to attack target().
  void attack(State* state, Unit const* unit) const {
    auto cmd = tc::Client::Command(
        tc::BW::Command::CommandUnit,
        unit->id,
        tc::BW::UnitCommandType::Attack_Unit,
        target()->id);
    state->board()->postCommand(cmd);
  }

protected:
  Unit* target_;
  Position location_;
};

} // namespace

// Per-frame entry point.
// 1. Creates the defense UPC for our base, if needed.
// 2. Turns this module's UPCs into tasks.
// 3. For every live task, cancels it if its target moved too far away, and
//    otherwise orders each idle defender to attack the target.
void DefendBaseModule::step(State* state) {
  auto board = state->board();

  // The module creates its own UPCs for now, see the associated docstring.
  createUPCs(state);

  // There is no dedicated "defend" command in UPCs, so the defense UPCs are
  // recognized as the ones posted by this module itself.
  // Note: this assumes the module does not otherwise send UPCs according to
  // the intended UPC communication protocol.
  // Iterate over a local copy: consumeUPC() consumes UPCs on the board, which
  // must not invalidate the range being iterated.
  auto upcs = board->upcsFrom(this);
  for (auto const& pairIdUPC : upcs) {
    // consumeUPC() creates tasks for targets that are not yet handled
    consumeUPC(state, pairIdUPC.first, pairIdUPC.second);
  }

  // TODO: how to properly deal with unit assignment? What is the role of
  // tasks here?
  // Issue commands for all tasks, including the ones just created.
  for (auto const& task : board->tasksOfModule(this)) {
    if (task->finished()) {
      VLOG(1) << "DefendBaseModule: task " << task->upcId() << " finished"
              << " with status: " << static_cast<int>(task->status());
      continue;
    }
    auto htask = std::dynamic_pointer_cast<DefendBaseTask>(task);
    if (!htask) {
      // Should not happen: this module only posts DefendBaseTasks.
      LOG(ERROR) << "DefendBaseModule: task " << task->upcId()
                 << " is not a defense task";
      continue;
    }
    if (tooFar(state, htask->target())) {
      VLOG(1) << "DefendBaseModule: target "
              << utils::unitString(htask->target())
              << " too far -> cancelling task " << htask->upcId();
      htask->cancel(state);
      continue;
    }
    VLOG(3) << "DefendBaseModule: distance of target: "
            << minDist(state, htask->target());
    for (auto& unit : htask->proxiedUnits()) {
      if (htask->attacking(state, unit)) {
        continue;
      }
      VLOG(2) << "DefendBaseModule: sending attack command to "
              << utils::unitString(unit) << " in task " << htask->upcId()
              << " against " << utils::unitString(htask->target());
      htask->attack(state, unit);
    }
  }
}

// Configures when and how the module reacts to intruders.
//   minDistToBuildings: a non-building enemy must come closer than this to
//     one of our buildings near the base to be defended against (tooFar()).
//   maxNbWorkers:       maximum number of workers assigned per target.
//   onAttackBuildings:  if set, only respond while one of our buildings near
//     the base is being attacked.
//   onAttackWorkers:    if set, only respond while one of our workers near
//     the base is being attacked.
void DefendBaseModule::setDefenseParams(
    float minDistToBuildings,
    float maxNbWorkers,
    bool onAttackBuildings,
    bool onAttackWorkers) {
  onAttackWorkers_ = onAttackWorkers;
  onAttackBuildings_ = onAttackBuildings;
  maxNbWorkers_ = maxNbWorkers;
  minDistToBuildings_ = minDistToBuildings;
}

// Returns true when `unit` is not close to any of our buildings near the base
// and hence not worth defending against.
// - Refineries are never considered too far.
// - Buildings must come within 50 of one of our buildings.
// - Other units must come within minDistToBuildings_.
// Only our buildings within 100 of the blackboard's base location count.
bool DefendBaseModule::tooFar(State* state, Unit* unit) {
  if (unit->type->isRefinery) {
    return false;
  }
  auto threshold = minDistToBuildings_;
  if (unit->type->isBuilding) {
    threshold = 50; // what is the correct policy here?
  }
  Position base = state->board()->get<Position>(Blackboard::kMyLocationKey);
  // TODO: memoize that, update only if new buildings
  // TODO: take geysers into account to prevent gas steal
  // TODO: how do we know a building is attached to a particular base
  for (auto building : state->unitsInfo().myBuildings()) {
    // is 100 a viable constant ?
    bool const nearBase = !(base.dist(building) > 100);
    // what is the best policy here ?
    if (nearBase && utils::distance(unit, building) < threshold) {
      return false;
    }
  }
  return true;
}

// Returns the smallest distance between `unit` and our buildings within 100
// of the blackboard's base location, or infinity if there is no such building.
float DefendBaseModule::minDist(State* state, Unit* unit) {
  Position base = state->board()->get<Position>(Blackboard::kMyLocationKey);
  auto best = vInfty;
  for (auto building : state->unitsInfo().myBuildings()) {
    // is 100 a viable constant ?
    if (!(base.dist(building) > 100)) {
      auto const d = utils::distance(unit, building);
      if (d < best) {
        best = d;
      }
    }
  }
  return best;
}

// Returns true if any of our buildings within 100 of the blackboard's base
// location has a non-empty beingAttackedByEnemies list.
bool DefendBaseModule::buildingAttacked(State* state) {
  Position base = state->board()->get<Position>(Blackboard::kMyLocationKey);
  for (auto building : state->unitsInfo().myBuildings()) {
    // is 100 a viable constant ?
    bool const nearBase = !(base.dist(building) > 100);
    if (nearBase && !building->beingAttackedByEnemies.empty()) {
      return true;
    }
  }
  return false;
}

// Returns true if any of our workers within 100 of the blackboard's base
// location has a non-empty beingAttackedByEnemies list.
bool DefendBaseModule::workerAttacked(State* state) {
  Position base = state->board()->get<Position>(Blackboard::kMyLocationKey);
  for (auto worker : state->unitsInfo().myWorkers()) {
    // is 100 a viable constant ?
    bool const nearBase = !(base.dist(worker) > 100);
    if (nearBase && !worker->beingAttackedByEnemies.empty()) {
      return true;
    }
  }
  return false;
}


// Returns the worker closest to `target` that can join a new defense task, or
// nullptr if there is none. A worker qualifies if it is not in
// `already_assigned` and is not part of an unfinished task of this module.
// Workers whose defense task has finished may be picked again.
// Takes the closest workers for now.
Unit* DefendBaseModule::findWorker(
    State* state, Unit* target, std::unordered_set<Unit*>& already_assigned) {
  VLOG(3) << "DefendBaseModule: finding worker to fight "
          << utils::unitString(target);
  auto board = state->board();
  auto dist = vInfty;
  Unit* best_worker = nullptr;
  for (auto worker : state->unitsInfo().myWorkers()) {
    auto task = board->taskWithUnitOfModule(worker, this);
    if (task) {
      if (!task->finished()) {
        // already busy with an ongoing defense task
        continue;
      }
      VLOG(4) << "DefendBaseModule [findWorker]: " << utils::unitString(worker)
              << " assigned to finished defense task " << task->upcId();
    }
    if (already_assigned.count(worker) > 0) {
      continue;
    }
    auto const d = utils::distance(worker, target);
    if (d < dist) {
      best_worker = worker;
      dist = d;
    }
  }
  if (best_worker) {
    VLOG(4) << "DefendBaseModule [findWorker]: found "
            << utils::unitString(best_worker);
  } else {
    VLOG(4) << "DefendBaseModule [findWorker]: worker not found";
  }
  return best_worker;
}

// Consumes a defense UPC and turns it into tasks.
// It creates one DefendBaseTask per target in upc->positionU that is:
//  - not already handled by a task of this module for the same location, and
//  - not too far from our buildings.
// Each task receives up to maxNbWorkers_ workers, chosen by findWorker().
// No worker is handed to two tasks created by the same call.
// TODO: target/unit reassignment between tasks should probably be done here.
// NOTE(review): upc->unit (the claimable workers filtered by distance in
// createUPC) is not consulted; findWorker() draws from all of our workers.
void DefendBaseModule::consumeUPC(
    State* state, UpcId upcId, std::shared_ptr<UPCTuple> upc) {
  auto board = state->board();
  board->consumeUPCs({upcId}, this);
  auto loc = upc->positionS;
  // Copy on purpose: targets already covered by a task are erased below.
  auto targets = upc->positionU;
  for (auto const& task : board->tasksOfModule(this)) {
    auto htask = std::dynamic_pointer_cast<DefendBaseTask>(task);
    if (!htask) {
      continue;
    }
    if (htask->location() == loc) {
      // assumes target is properly dealt with by the current task
      // (erasing a key that is not present is a no-op)
      targets.erase(htask->target());
    }
  }
  VLOG(3) << "DefendBaseModule: number of remaining targets for UPC: "
          << targets.size();

  // Workers handed out by this call so far, across all targets.
  std::unordered_set<Unit*> assigned;
  for (auto const& tgt : targets) {
    if (tooFar(state, tgt.first)) {
      VLOG(3) << "DefendBaseModule: target " << tgt.first->id
              << " found but too far to defend against";
      continue;
    }
    if (tgt.second <= 0) {
      LOG(WARNING) << "DefendBaseModule: non > 0 probability on upc.positionU";
    }
    // Fresh set per target so that each task only receives its own workers
    // and never more than maxNbWorkers_ of them.
    std::unordered_set<Unit*> defenders;
    for (int i = 1; i <= maxNbWorkers_; i++) {
      auto worker = findWorker(state, tgt.first, assigned);
      if (!worker) {
        LOG(WARNING) << "DefendBaseModule: not enough workers to defend,"
                     << " required " << maxNbWorkers_
                     << " obtained " << defenders.size();
        break;
      }
      assigned.insert(worker);
      defenders.insert(worker);
    }
    if (!defenders.empty()) {
      auto task =
          std::make_shared<DefendBaseTask>(upcId, defenders, tgt.first, loc);
      VLOG(3) << "DefendBaseModule: creating new task " << task->upcId()
              << " to defend against"
              << " enemy " << utils::unitString(tgt.first)
              << " with " << defenders.size() << " workers,"
              << " nb total workers " << state->unitsInfo().myWorkers().size();
      board->postTask(task, this, true);
    }
  }
}

/* ***************
 * functions that should belong to higher-level modules
 *****************/

/*
 * Builds a defense UPC for our base located at `loc` and posts it on the
 * blackboard. Nothing is posted when any of the following holds:
 *  - onAttackBuildings_ is set and none of our buildings near the base is
 *    under attack;
 *  - onAttackWorkers_ is set and none of our workers near the base is under
 *    attack;
 *  - no visible enemy worker or building is within 200 of `loc`;
 *  - the only intruder is already the target of one of our tasks;
 *  - none of our live workers is within 100 of `loc`.
 * In the UPC, positionU holds the intruders (probability 1) and unit holds
 * the claimable workers (probability 0.5).
 */
void DefendBaseModule::createUPC(State* state, const Position& loc) {
  if (onAttackBuildings_ && !buildingAttacked(state)) {
    return;
  }
  if (onAttackWorkers_ && !workerAttacked(state)) {
    return;
  }
  auto baseUpc = std::make_shared<UPCTuple>();
  baseUpc->positionS = loc;
  // Are there enemy workers trying to harass us or steal our gas?
  for (auto unit : state->unitsInfo().visibleEnemyUnits()) {
    if ((unit->type->isWorker || unit->type->isBuilding) &&
        loc.dist(unit) < 200) {
      baseUpc->positionU[unit] = 1;
    }
  }
  if (baseUpc->positionU.empty()) {
    return;
  }

  // Shortcut: a single intruder that is already targeted needs no new UPC.
  if (baseUpc->positionU.size() == 1) {
    auto const intruder = baseUpc->positionU.begin()->first;
    for (auto const& task : state->board()->tasksOfModule(this)) {
      auto htask = std::dynamic_pointer_cast<DefendBaseTask>(task);
      if (htask && htask->target() == intruder) {
        return;
      }
    }
  }

  // TODO: should be replaced with friendly units in area info
  for (auto unit : state->unitsInfo().myWorkers()) {
    if (unit->dead) {
      continue;
    }
    if (loc.dist(unit) > 100) {
      VLOG(3) << "DefendBaseModule: discarding "
              << utils::unitString(unit)
              << " because too far; distance to loc is "
              << loc.dist(unit);
      continue;
    }
    baseUpc->unit[unit] = 0.5;
  }
  if (baseUpc->unit.empty()) {
    VLOG(4) << "DefendBaseModule: enemy found but no claimable worker";
    return;
  }
  auto upcId = state->board()->postUPC(std::move(baseUpc), kRootUpcId, this);
  if (upcId < 0) {
    LOG(ERROR) << "DefendBaseModule: base upc could not be posted";
  }
}

// Creates the defense UPC for our base. The base location is read from the
// blackboard (Blackboard::kMyLocationKey). Does nothing while that location
// is unknown, or when it is invalid (non-positive coordinate).
void DefendBaseModule::createUPCs(State* state) {
  auto board = state->board();
  if (!board->hasKey(Blackboard::kMyLocationKey)) {
    return;
  }

  // TODO: should go through all our base locations; for now only the one in
  // the blackboard.
  auto myLoc = board->get<Position>(Blackboard::kMyLocationKey);
  if (myLoc.x <= 0 || myLoc.y <= 0) {
    // Distances to an invalid base location are meaningless, so do not build
    // a UPC from it.
    LOG(ERROR) << "DefendBaseModule: invalid base location in the BlackBoard";
    return;
  }
  createUPC(state, myLoc);
}

} // namespace fairrsh
