/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "squadcombat.h"

#include "commandtrackers.h"
#include "movefilters.h"
#include "player.h"
#include "state.h"
#include "upctocommand.h"
#include "utils.h"

#include "bwem/bwem.h"

#include <glog/logging.h>

namespace fairrsh {

// Reflection registration: exposes SquadCombatModule to RTTR under the name
// "SquadCombatModule", attaches its own rttr::type as "type" metadata, and
// registers a constructor taking no arguments (presumably so the module can
// be instantiated by name at runtime -- confirm against the module factory).
RTTR_REGISTRATION {
  rttr::registration::class_<SquadCombatModule>("SquadCombatModule")(
      metadata("type", rttr::type::get<SquadCombatModule>()))
      .constructor();
}

namespace {
/**
 * Behavior:
 *  - The ratio of Move : Delete in the source UPC determines how safe the
 *    attack should be.
 *  - Delete = 1 means engage no matter what, i.e. commit and sacrifice
 *    yourself for the greater cause.
 *  - Delete < 0.75 means don't engage unless you think you can win.
 *  - Targeting units:
 *    - Delete < 1 will also attack other units as necessary.
 *  - Targeting a location:
 *    - Delete < 1 will attack other units along the way as necessary.
 *    - Delete == 1 will force units to move to the location, avoiding all
 *      damage, and confront only units near the location.
 */
class SquadTask : public Task {
 public:
  // Targets I should hit (used when targettingLocation is false).
  std::vector<Unit*> targets;
  // Location I should defend / attack (used when targettingLocation is true).
  // Default-initialized because the target-list constructor does not set
  // them; previously they held indeterminate values for such tasks.
  int targetX = 0;
  int targetY = 0;
  // Whether I'm using the location or units
  bool targettingLocation = false;
  // Which kinds of units the squad contains; recomputed from units().
  bool hasAirUnits = false;
  bool hasGroundUnits = false;
  // Probability of Command::Delete in the source UPC (set by both
  // constructors); controls how aggressively the squad engages.
  double delProb = 0.0;
  // The UPC this task was created from.
  std::shared_ptr<UPCTuple> sourceUpc;

  // Creates a task that attacks an explicit list of units
  // (targettingLocation = false). delProb is read from the UPC's
  // Command::Delete probability and the UPC itself is kept in sourceUpc.
  // enemyStates / unitStates are non-owning pointers to per-unit bookkeeping
  // maps; presumably owned by the SquadCombatModule and required to outlive
  // this task -- TODO confirm.
  // NOTE(review): targetX / targetY are not set by this constructor.
  SquadTask(
      int upcId,
      std::shared_ptr<UPCTuple> upc,
      std::unordered_set<Unit*> units,
      std::vector<Unit*> targets,
      std::unordered_map<Unit const*, SquadCombatModule::EnemyState>*
          enemyStates,
      std::unordered_map<Unit const*, SquadCombatModule::UnitState>* unitStates)
      : Task(upcId, units),
        targets(std::move(targets)),
        targettingLocation(false),
        delProb(upc->command[Command::Delete]),
        sourceUpc(upc),
        enemyStates_(enemyStates),
        unitStates_(unitStates) {}

  // Creates a task that attacks/defends the location (x, y)
  // (targettingLocation = true); `targets` stays empty. delProb is read from
  // the UPC's Command::Delete probability and the UPC itself is kept in
  // sourceUpc. enemyStates / unitStates are non-owning pointers to per-unit
  // bookkeeping maps; presumably owned by the SquadCombatModule and required
  // to outlive this task -- TODO confirm.
  SquadTask(
      int upcId,
      std::shared_ptr<UPCTuple> upc,
      std::unordered_set<Unit*> units,
      int x,
      int y,
      std::unordered_map<Unit const*, SquadCombatModule::EnemyState>*
          enemyStates,
      std::unordered_map<Unit const*, SquadCombatModule::UnitState>* unitStates)
      : Task(upcId, units),
        targetX(x),
        targetY(y),
        targettingLocation(true),
        delProb(upc->command[Command::Delete]),
        sourceUpc(upc),
        enemyStates_(enemyStates),
        unitStates_(unitStates) {}

  // No-op: this task performs no extra per-unit initialization here.
  void initUnits() {}

  // Returns the units the group should attack this frame.
  //
  // delProb == 1 ("commit"): only the task's own objective -- enemies within
  // 75 of (targetX, targetY) for location tasks, or `targets` otherwise.
  // delProb > 0 (but not exactly 1): the objective plus enemies near the
  // grouped units (for unit-target tasks only those passing isThreat()).
  // Any other delProb (including NaN) throws std::runtime_error.
  std::vector<Unit*> getGroupTargets(State* state) const {
    bool const commit = (delProb == 1);
    if (!commit && !(delProb > 0)) {
      throw std::runtime_error("delProb <= 0 or > 1, that's not right");
    }

    if (!targettingLocation) {
      if (commit) {
        // Only attack targets
        return targets;
      }
      // Can attack nearby threatening enemy units too
      std::vector<Unit*> result = targets;
      for (auto enemy : utils::findNearbyEnemyUnits(state, grouped_)) {
        if (isThreat(enemy)) {
          result.push_back(enemy);
        }
      }
      return result;
    }

    if (commit) {
      // Only attack units near the location
      return utils::filterUnitsByDistance(
          state->unitsInfo().enemyUnits(), targetX, targetY, 75);
    }
    // Attack units near the location and near us
    auto combined = utils::findNearbyEnemyUnits(state, grouped_);
    auto nearLocation = utils::filterUnitsByDistance(
        state->unitsInfo().enemyUnits(), targetX, targetY, 75);
    for (auto enemy : nearLocation) {
      combined.insert(enemy);
    }
    return std::vector<Unit*>(combined.begin(), combined.end());
  }

  // Gets all the threats to the group: the enemies returned by
  // utils::findNearbyEnemyUnits for grouped_ that pass isThreat().
  std::vector<Unit*> getGroupThreats(State* state) const {
    std::vector<Unit*> threats;
    auto nearby = utils::findNearbyEnemyUnits(state, grouped_);
    for (auto enemy : nearby) {
      if (!isThreat(enemy)) {
        continue;
      }
      threats.push_back(enemy);
    }
    return threats;
  }

  // Traveling units need to go to the group, while trying to avoid or kill
  // enemy units along the way.
  //
  // Pass 1: every traveling unit picks a target among the enemies near
  // itself. Pass 2: units without a target path towards the group center
  // (center_); the rest are micro'd via microUnit(). Null commands are
  // dropped from the result.
  std::vector<std::shared_ptr<UPCTuple>> makeTravelingUPCs(State* state) {
    for (auto traveler : traveling_) {
      auto nearby =
          utils::findNearbyEnemyUnits(state, std::vector<Unit*>({traveler}));
      std::vector<Unit*> candidates(nearby.begin(), nearby.end());
      pickTarget(state, traveler, candidates);
    }

    std::vector<std::shared_ptr<UPCTuple>> commands;
    for (auto traveler : traveling_) {
      bool const hasTarget = unitStates_->at(traveler).target != nullptr;
      auto cmd = hasTarget ? microUnit(state, traveler)
                           : smartMove(state, traveler, center_);
      if (cmd != nullptr) {
        commands.push_back(std::move(cmd));
      }
    }
    return commands;
  }

  // Grouped units first each choose a target from the squad-wide list
  // (targets_); then each one is micro'd via microUnit(). Null commands are
  // dropped from the result.
  std::vector<std::shared_ptr<UPCTuple>> makeGroupedUPCs(State* state) {
    for (auto member : grouped_) {
      pickTarget(state, member, targets_);
    }
    VLOG(2) << "Group of " << utils::upcString(upcId()) << " has units "
            << utils::unitsString(grouped_) << " and targets "
            << utils::unitsString(targets_);

    std::vector<std::shared_ptr<UPCTuple>> commands;
    for (auto member : grouped_) {
      if (auto cmd = microUnit(state, member)) {
        commands.push_back(std::move(cmd));
      }
    }
    return commands;
  }

  // How getKitePosition() picks a destination relative to a unit's threats.
  // Passed around as a plain int (see getKitePosition's parameter).
  enum KiteBehavior {
    Hover = 0, // strafe sideways, on the side facing the squad
    Dodge = 1, // back off towards the squad / away from threats (default)
    Chase = 2, // move straight at the threats
  };

  // Picks a point 10 units (times a normalized direction) away from `unit`
  // while kiting. Directions are taken towards the centroid of the unit's
  // threateningEnemies and towards squadPosition_:
  //  - Hover: threat direction rotated by rand() % 20 + 80 (in Vec2::rotate's
  //    units, presumably degrees), negated if it points away from the squad.
  //  - Dodge (default): straight to the squad if the threat lies on the
  //    opposite side; otherwise away from the threat, blended with the squad
  //    direction.
  //  - Chase: straight at the threats.
  // Unknown behaviors are logged and return the unit's own position.
  Vec2 getKitePosition(Unit const* unit, int behavior = KiteBehavior::Dodge) {
    Vec2 threatCenter = utils::centerOfUnits(unit->threateningEnemies);
    Vec2 origin(unit);
    Vec2 rally(squadPosition_);
    auto towardThreat = (threatCenter - origin).normalize();
    auto towardRally = (rally - origin).normalize();

    switch (behavior) {
      case KiteBehavior::Hover: {
        towardThreat.rotate(rand() % 20 + 80);
        if (Vec2::dot(towardThreat, towardRally) < 0) {
          towardThreat *= -1.0f;
        }
        return origin + (towardThreat * 10.0f);
      }
      case KiteBehavior::Dodge: {
        if (Vec2::dot(towardThreat, towardRally) < 0) {
          return origin + (towardRally * 10.0f);
        }
        auto escape = ((towardThreat * -2.0f) + towardRally).normalize();
        return origin + (escape * 10.0f);
      }
      case KiteBehavior::Chase:
        return origin + (towardThreat * 10.0f);
      default:
        LOG(ERROR) << "Undefined kiting behavior";
        return origin;
    }
  }

  // Grouped units should attack targets_.
  //
  // Produces at most one command for `unit` this frame, based on the target
  // stored in unitStates_ (us.target):
  //  - microSpecificUnit() gets the first say; if it reports the unit as
  //    handled, its (possibly null) command is returned as-is;
  //  - without a target: move to the task location, or retreat from
  //    threats / regroup at center_;
  //  - ranged, kiting-capable or "dancing" (mutalisk / vulture) units use
  //    the chase / dance / kite logic below;
  //  - otherwise attack the target when close, or path towards it.
  // Returns nullptr when no new command should be sent.
  std::shared_ptr<UPCTuple> microUnit(State* state, Unit* unit) {
    auto& us = unitStates_->at(unit);

    auto specificUPC = microSpecificUnit(state, unit);
    if (specificUPC.first) {
      return specificUPC.second;
    }

    VLOG(2) << utils::unitString(unit) << " is attacking "
            << (us.attacking ? utils::unitString(us.attacking) : "nobody")
            << ", should be attacking "
            << (us.target ? utils::unitString(us.target) : "nobody");
    if (us.target == nullptr && targettingLocation) {
      return moveTo(state, unit, {targetX, targetY});
    } else if (us.target == nullptr) {
      if (!unit->threateningEnemies.empty()) {
        // Step 15 away from the centroid of the threats.
        auto center = utils::centerOfUnits(unit->threateningEnemies);
        auto away = (Vec2(unit) - center).normalize();
        return smartMove(state, unit, Vec2(unit) + away * 15);
      } else {
        return smartMove(state, unit, center_);
      }
    }

    // Ranged units chase when the UPC is aggressive or nothing threatens them.
    auto isRanged = std::max(unit->unit.groundRange, unit->unit.airRange) > 1;
    auto shouldChase = (isRanged && delProb > 0.75) ||
        (isRanged && unit->threateningEnemies.empty());
    auto shouldKite = unit->enemyUnitsInSightRange.empty()
        ? false
        : utils::countUnits(unit->enemyUnitsInSightRange, [&](Unit* o) {
            return !unit->canKite(o);
          }) == 0; // should kite if I can kite everyone around me
    auto shouldDance = unit->canKite(us.target) &&
        (unit->type == buildtypes::Zerg_Mutalisk ||
         unit->type == buildtypes::Terran_Vulture);

    if (shouldChase || shouldKite || shouldDance) {
      auto target = us.target;
      auto cd = target->flying() ? unit->unit.airCD : unit->unit.groundCD;
      auto inRange = target->inRangeOf(unit);
      // If I can attack the nearest unit, and my target isn't in range,
      // attack the nearest unit instead.
      if (!inRange && cd < 4) {
        for (auto u : unit->enemyUnitsInSightRange) {
          if (!u->inRangeOf(unit))
            break;
          if (isThreat(u)) {
            target = u;
            break;
          }
        }
      }
      // Recalculate cd in case target changed
      cd = target->flying() ? unit->unit.airCD : unit->unit.groundCD;

      if (shouldChase && target->topSpeed > 0) {
        // Not too close and not too far
        if (target->inRangeOf(unit, 7) && cd < 5) {
          return attack(state, unit, target);
        } else if (!target->inRangeOf(unit, -4)) {
          // Move only if we're not too close; aim at the target's position
          // extrapolated along its velocity.
          auto velo = Vec2(target->unit.velocityX, target->unit.velocityY) /
              tc::BW::XYPixelsPerWalktile;
          auto pos = Vec2(target) + (velo * 5);
          return smartMove(state, unit, pos);
        }
      } else if (shouldDance) {
        auto dist = utils::distanceBB(unit, target);
        auto wrange =
            target->flying() ? unit->unit.airRange : unit->unit.groundRange;
        if (cd < 3 && us.lastMove > 0 // Attack off cooldown
            && !target->gone && target->inRangeOf(unit) // Can target enemy
            && unit->atTopSpeed() // Moving at max speed, if we can accel
        ) {
          VLOG(2) << utils::unitString(unit)
                  << " is launching a kiting attack ";
          return attack(state, unit, target);
        }

        // NOTE: lastMove is recorded by moveTo() when a move is actually
        // issued. Setting it here beforehand (as this code used to) made
        // moveTo()'s 20-frame ground-move protection reject every dance move
        // of ground units such as vultures.
        // # of frames it takes for us to turn
        auto tr = 128 / tc::BW::data::TurnRadius[unit->type->unit];
        // True when we are not at top speed, or the time needed to close to
        // weapon range plus turning time is still below the remaining
        // cooldown, i.e. there is time left to keep moving away.
        auto cond = !unit->atTopSpeed() ||
            std::max(0.0, (double(dist) - wrange) / unit->topSpeed) + tr < cd;
        VLOG(2) << utils::unitString(unit) << " is kiting and moving";
        if (cond || us.target->gone) {
          if (unit->beingAttackedByEnemies.size() == 0) {
            // Hover around unit if not being threatened
            return moveTo(
                state, unit, getKitePosition(unit, KiteBehavior::Hover));
          } else {
            // Flee otherwise
            return moveTo(state, unit, getKitePosition(unit));
          }
        } else {
          auto fleePos = utils::getMovePos(state, unit, us.target, 0, false);
          return moveTo(state, unit, std::move(fleePos));
        }
      } else if (shouldKite) {
        if (unit->inRangeOf(target, 7)) {
          VLOG(3) << unit << " is kiting";
          return filterMove(state, unit, {movefilters::avoidThreatening()});
        } else {
          VLOG(3) << unit << " is kiting but can attack";
          return attack(state, unit, target);
        }
      }
    }

    // Can't kite
    VLOG(2) << utils::unitString(unit) << " is attacking now";
    if (us.target->inRangeOf(unit, 4)) {
      // Send attack command if our last command was a move or we aren't
      // already attacking the target
      if (us.lastMove > 0 || us.attacking != us.target) {
        return attack(state, unit, us.target);
      }
    } else {
      // Path towards the target; moveTo() throttles ground moves so we don't
      // disrupt the game's built-in pathfinding too much
      auto upc = smartMove(state, unit, us.target);
      return upc;
    }

    return nullptr;
  }

  // Issues an attack-move towards `pos`. Bookkeeping in unitStates_: the
  // unit's current target is recorded as what it is attacking, and lastMove
  // is reset to -1 (the last command was not a move).
  std::shared_ptr<UPCTuple>
  attack(State* state, Unit* unit, Position const& pos) {
    if (VLOG_IS_ON(3)) {
      VLOG(3) << "Sending attack move to " << pos;
      utils::drawLine(state, unit, pos, tc::BW::Color::Red);
    }
    auto& bookkeeping = (*unitStates_)[unit];
    bookkeeping.attacking = bookkeeping.target;
    bookkeeping.lastMove = -1;
    return utils::makeSharpUPC(unit, pos, Command::Delete);
  }

  // Issues a direct attack on `u`. Bookkeeping in unitStates_: `attacking`
  // is set to `u` and lastMove is reset to -1 (the last command was not a
  // move).
  std::shared_ptr<UPCTuple> attack(State* state, Unit* unit, Unit* u) {
    if (VLOG_IS_ON(3)) {
      VLOG(3) << utils::unitString(unit) << " is sending attack to "
              << utils::unitString(u);
      utils::drawLine(state, unit, u, tc::BW::Color::Red);
      utils::drawCircle(state, u, 10);
    }
    auto& us = (*unitStates_)[unit];
    us.lastMove = -1;
    us.attacking = u;
    return utils::makeSharpUPC(unit, u, Command::Delete);
  }

  // Issues a move to `pos` and records it in unitStates_ (lastMove = current
  // frame, attacking cleared).
  //
  // Returns nullptr without issuing anything for ground units that received
  // a move less than 20 frames ago, so we don't disrupt the game's built-in
  // pathfinding. Positions outside the map are clamped first.
  std::shared_ptr<UPCTuple> moveTo(State* state, Unit* unit, Position pos) {
    auto& us = (*unitStates_)[unit];
    if (!unit->flying() && us.lastMove >= 0 &&
        state->currentFrame() - us.lastMove < 20) {
      // Protect recent ground move commands (see above)
      return nullptr;
    }
    // Kiting / fleeing offsets (e.g. Vec2(unit) + dir * 10) can leave the
    // map on any side, so check both bounds. The upper bound is exclusive.
    // This assumes positions and mapWidth()/mapHeight() use the same units
    // (walktiles) -- NOTE(review): confirm.
    if (pos.x < 0 || pos.y < 0 || pos.x >= state->mapWidth() ||
        pos.y >= state->mapHeight()) {
      pos = utils::clampPositionToMap(state, pos);
      VLOG(2) << "Position out of bounds of map, clipping to " << pos;
    }
    if (VLOG_IS_ON(3)) {
      VLOG(3) << "Sending move to " << pos;
      utils::drawLine(state, unit, pos);
    }
    us.lastMove = state->currentFrame();
    us.attacking = nullptr;
    return utils::makeSharpUPC(unit, pos, Command::Move);
  }

  // Moves `unit` to the position movefilters::smartMove returns for the
  // given position filters; the move goes through moveTo().
  std::shared_ptr<UPCTuple> filterMove(
      State* state,
      Unit* unit,
      const movefilters::PositionFilters& pfs) {
    auto destination = movefilters::smartMove(state, unit, pfs);
    return moveTo(state, unit, destination);
  }

  // Moves `unit` towards `tgt`: the destination actually used is whatever
  // movefilters::smartMove returns for `tgt`, issued through moveTo().
  std::shared_ptr<UPCTuple>
  smartMove(State* state, Unit* unit, const Position& tgt) {
    auto waypoint = movefilters::smartMove(state, unit, tgt);
    return moveTo(state, unit, waypoint);
  }

  // Experimental position filter meant to keep units from clumping.
  // NOTE: per the original author this does not work yet and is unused.
  //
  // A candidate `pos` is rejected when extrapolating `agent` 6 steps towards
  // `pos` and `blocker` 6 steps along its velocity (both at top speed) puts
  // their bounding boxes within `dist`. Self, non-colliding blockers and
  // positions closer than 3 to the agent are always accepted.
  template <typename T>
  movefilters::PPositionFilter unclumpFilter(T getter, double dist) {
    auto keepsClear = [this, dist](
                          Unit* agent, Position const& pos, Unit* blocker) {
      bool const ignorable = agent == blocker || !blocker->hasCollision ||
          pos.dist(agent) < 3;
      if (ignorable) {
        return true;
      }
      auto myHeading = (pos - Vec2(agent)).normalize();
      auto myFuture = Vec2(agent) + myHeading * agent->topSpeed * 6;
      auto theirHeading =
          Vec2(blocker->unit.velocityX, blocker->unit.velocityY).normalize();
      auto theirFuture = Vec2(blocker) + theirHeading * blocker->topSpeed * 6;
      auto gap = utils::distanceBB(agent, myFuture, blocker, theirFuture);
      // Written as !(gap <= dist) to match the original for NaN gaps.
      return !(gap <= dist);
    };
    return movefilters::makePositionFilter<Unit*>(
        getter, keepsClear, movefilters::zeroScore);
  }

  // Unit-type specific micro, tried by microUnit() before its generic logic.
  //
  // Returns a pair {handled, upc}:
  //  - {true, upc}      handled; issue `upc` (may be nullptr if moveTo()
  //                     suppressed the move)
  //  - {true, nullptr}  handled; deliberately send no command this frame
  //  - {false, nullptr} not handled; the generic controller takes over
  //
  // Checked in order: irradiated units, nearby storms, lurkers, incoming
  // spider mines / scarabs, spider mines near ranged ground units,
  // mutalisks vs scourge, scourge, zerglings and overlords.
  std::pair<bool, std::shared_ptr<UPCTuple>> microSpecificUnit(
      State* state,
      Unit* unit) {
    auto doAction = [](std::shared_ptr<UPCTuple> x) {
      return std::make_pair(true, x);
    };
    // This sends no command
    auto doNothing = std::make_pair(true, nullptr);
    // Pass gives control to the non-unit specific micro controller
    auto pass = std::make_pair(false, nullptr);

    auto& us = unitStates_->at(unit);
    if (unit->irradiated()) {
      // Step away from the centroid of allies within distance 16
      // (presumably to keep the irradiate damage off them).
      std::vector<Unit*> units;
      for (auto u : unit->allyUnitsInSightRange) {
        if (utils::distance(u, unit) < 16) {
          units.push_back(u);
        }
      }
      if (!units.empty()) {
        auto centroid = Vec2(utils::centerOfUnits(units));
        auto pos = Vec2(unit) + (Vec2(unit) - centroid).normalize() * 10;
        return doAction(moveTo(state, unit, pos));
      }
    }
    // Step away from the first known storm within distance 16.
    for (auto stormLoc : storms_) {
      if (utils::distance(unit, stormLoc) > 16)
        continue;
      auto pos = Vec2(unit) + (Vec2(unit) - Vec2(stormLoc)).normalize() * 10;
      return doAction(moveTo(state, unit, pos));
    }
    // Lurker specific micro
    if (unit->type == buildtypes::Zerg_Lurker) {
      // us.burrowing: 1 = burrow requested, -1 = unburrow requested, 0 = idle.
      // Clear it once the new state is observed; while it's still pending,
      // send nothing.
      if ((us.burrowing == 1 && unit->burrowed()) ||
          (us.burrowing == -1 && !unit->burrowed())) {
        us.burrowing = 0;
      } else if (us.burrowing != 0) {
        // Burrowing or unburrowing, don't send a command
        return doNothing;
      }
      if (us.target == nullptr) {
        return pass;
      }
      // TODO if enemy units are coming towards us, don't unburrow.
      if (unit->burrowed()) {
        if (us.target->inRangeOf(unit)) {
          if (us.attacking != us.target) {
            VLOG(4) << "LURKER Attacking with lurker "
                    << utils::unitString(unit);
            return doAction(attack(state, unit, us.target));
          }
        } else {
          VLOG(4) << "LURKER Unburrowing " << utils::unitString(unit);
          auto board = state->board();
          board->postCommand(tc::Client::Command(
              tc::BW::Command::CommandUnit,
              unit->id,
              tc::BW::UnitCommandType::Unburrow));
          us.burrowing = -1;
          return doNothing;
        }
      } else {
        // Burrow once within a fraction of ground range; the fraction is
        // 1 - delProb passed through utils::clamp(..., 0.25, 0.9).
        auto mod = utils::clamp(1 - delProb, 0.25, 0.9);
        if (utils::distanceBB(us.target, unit) < unit->unit.groundRange * mod) {
          // Range is 6, but let's get closer to burrow
          VLOG(4) << "LURKER Burrowing " << utils::unitString(unit);
          auto board = state->board();
          board->postCommand(tc::Client::Command(
              tc::BW::Command::CommandUnit,
              unit->id,
              tc::BW::UnitCommandType::Burrow));
          us.burrowing = 1;
          us.lastMove = -1;
          return doNothing;
        } else {
          // Re-issue the approach move at most every 36 frames.
          if (us.lastMove < 0 || state->currentFrame() - us.lastMove > 36) {
            VLOG(4) << "LURKER Moving with lurker " << utils::unitString(unit);
            us.lastMove = state->currentFrame();
            return doAction(
                utils::makeSharpUPC(unit, us.target, Command::Move));
          }
        }
      }
    }
    for (auto u : unit->beingAttackedByEnemies) {
      if (u->type == buildtypes::Terran_Vulture_Spider_Mine) {
        // If I'm a zergling and targetted by mines, run into the first
        // ground, non-mine unit in targets_ (TODO closest?)
        if (unit->type == buildtypes::Zerg_Zergling) {
          for (auto t : targets_) {
            if (t->type != buildtypes::Terran_Vulture_Spider_Mine &&
                !t->flying()) {
              return doAction(utils::makeSharpUPC(unit, t, Command::Move));
            }
          }
        }
      }
      if (u->type == buildtypes::Protoss_Scarab) {
        // Run directly away from an incoming scarab.
        auto pos = Vec2(unit) + (Vec2(unit) - Vec2(u)).normalize() * 10;
        return doAction(moveTo(state, unit, pos));
      }
    }
    // Ranged ground units shoot spider mines that are close to being in range.
    if (unit->unit.groundRange > 1 && hasGroundUnits) {
      for (auto u : unit->enemyUnitsInSightRange) {
        if (u->type == buildtypes::Terran_Vulture_Spider_Mine &&
            u->inRangeOf(unit, 4)) {
          return doAction(attack(state, unit, u));
        }
      }
    }
    if (unit->type == buildtypes::Zerg_Mutalisk) {
      // Retarget the first scourge in sight if it is within air range + 30
      // (the break assumes enemies in sight range are ordered by distance).
      for (auto u : unit->enemyUnitsInSightRange) {
        if (u->type != buildtypes::Zerg_Scourge) {
          continue;
        }
        if (utils::distance(u, unit) > unit->unit.airRange + 30) {
          break;
        }
        us.target = u;
        break;
      }
      if (us.target == nullptr || us.target->type != buildtypes::Zerg_Scourge) {
        return pass;
      }
      auto u = us.target;
      auto cd = unit->unit.airCD;

      auto scourgeVelo = Vec2(u->unit.velocityX, u->unit.velocityY) /
          tc::BW::XYPixelsPerWalktile;
      auto myVelo = Vec2(unit->unit.velocityX, unit->unit.velocityY) /
          tc::BW::XYPixelsPerWalktile;
      auto dirToScourge = (Vec2(u) - Vec2(unit)).normalize();
      if (VLOG_IS_ON(2)) {
        utils::drawCircle(
            state, unit, unit->unit.airRange * tc::BW::XYPixelsPerWalktile);
      }

      scourgeVelo.normalize();
      myVelo.normalize();

      // Fire when a turn was already started, or the weapon is nearly ready,
      // the scourge isn't too close and we're flying towards it; start
      // turning towards it when the weapon is almost ready and it's still
      // far; move straight away if our headings differ or it isn't at top
      // speed; otherwise sidestep its heading.
      auto distBB = utils::distanceBB(u, unit);
      if (us.mutaliskTurning ||
          (cd < 3 && distBB > 3 && myVelo.dot(dirToScourge) > 0)) {
        VLOG(3) << utils::unitString(unit) << " is launching a scourge attack ";
        utils::drawCircle(state, u, 25, tc::BW::Color::Red);
        us.mutaliskTurning = false;
        return doAction(attack(state, unit, u));
      } else if (cd < 6 && distBB > 8) {
        VLOG(3) << utils::unitString(unit) << " is turning to face unit";
        utils::drawCircle(state, u, 25, tc::BW::Color::Red);
        us.mutaliskTurning = true;
        return doAction(moveTo(state, unit, Vec2(unit) + dirToScourge * 20));
      } else if (myVelo.dot(scourgeVelo) < 0.1 || !u->atTopSpeed()) {
        VLOG(3) << utils::unitString(unit)
                << " is moving away from the scourge";
        return doAction(moveTo(state, unit, Vec2(unit) + dirToScourge * -20));
      } else {
        // rotate() appears to modify scourgeVelo in place, so pos2 ends up
        // rotated by a net -100 relative to the scourge heading.
        auto pos1 = Vec2(unit) + scourgeVelo.rotate(100) * 20;
        auto pos2 = Vec2(unit) + scourgeVelo.rotate(-200) * 20;
        auto pos = pos1.dist(u) < pos2.dist(u) ? pos2 : pos1;
        utils::drawCircle(state, unit, 25, tc::BW::Color::Blue);
        VLOG(3) << utils::unitString(unit)
                << " is doing the chinese triangle and moving to dir "
                << scourgeVelo;
        return doAction(moveTo(state, unit, pos));
      }
    }
    if (unit->type == buildtypes::Zerg_Scourge) {
      if (us.target == nullptr) {
        if (!unit->threateningEnemies.empty()) {
          auto centroid = utils::centerOfUnits(unit->threateningEnemies);
          auto pos = Vec2(unit) + (Vec2(unit) - centroid).normalize() * 10;
          return doAction(moveTo(state, unit, pos));
        } else {
          return doAction(moveTo(state, unit, center_));
        }
      }
      // Scourges want to click past the target so they move at full speed,
      // and issue an attack command when they are right on top of it.
      auto invalidUnit = [&](Unit const* u) {
        if (u->type == buildtypes::Protoss_Interceptor ||
            u->type == buildtypes::Zerg_Overlord || u->type->isBuilding) {
          return true;
        }
        // Someone else's target that is already expected to die.
        if (u != us.target &&
            enemyStates_->at(u).damages >
                u->unit.health + u->unit.shield - 15) {
          return true;
        }
        return false;
      };
      if (invalidUnit(us.target)) {
        us.target = nullptr;
        for (auto u : unit->enemyUnitsInSightRange) {
          if (!invalidUnit(u)) {
            us.target = u;
            break;
          }
        }
      }
      if (us.target == nullptr) {
        return doNothing;
      }
      if (us.target->inRangeOf(unit, 3)) {
        return doAction(attack(state, unit, us.target));
      }
      auto dir = Vec2(us.target) - Vec2(unit);
      dir.normalize();
      return doAction(moveTo(state, unit, Vec2(unit) + dir * 25));
    }
    if (unit->type == buildtypes::Zerg_Zergling) {
      if (us.target == nullptr) {
        return pass;
      }

      Unit* u = unit;
      Unit* target = us.target;

      // Fast visible workers / vultures: try to cut them off instead of
      // trailing behind them.
      if (target->visible &&
          (target->type->isWorker ||
           target->type == buildtypes::Terran_Vulture)) {
        // Walk from the target towards newPos in steps of 4; reject newPos if
        // that path comes within 8 of us or crosses a non-walkable tile.
        auto shouldMoveTo = [&](Vec2& newPos) {
          int n =
              (int)(utils::distance(Position(target), Position(newPos)) / 4.0f);
          Vec2 step = (newPos - Vec2(target)).normalize() * 4;
          Vec2 pos = Vec2(target);
          for (int i = 0; i != n; ++i) {
            if (utils::distance(pos, Position(u)) < 8) {
              return false;
            }
            const Tile* tile =
                state->tilesInfo().tryGetTile((int)pos.x, (int)pos.y);
            if (!tile || !tile->entirelyWalkable) {
              return false;
            }
            pos += step;
          }
          return true;
        };

        if (target->topSpeed >= u->topSpeed * 0.66f && target->moving() &&
            !target->inRangeOf(u, 4)) {
          const int latency = 4;
          // Range of the weapon that can hit this target: air weapon for
          // flyers, ground weapon otherwise (this was previously inverted).
          float weaponRange =
              target->flying() ? u->unit.airRange : u->unit.groundRange;
          auto targetVelocity =
              Vec2(target->unit.velocityX, target->unit.velocityY);
          auto targetNextPos = Vec2(target) + targetVelocity * latency;
          auto myNextPos =
              Vec2(u) + Vec2(u->unit.velocityX, u->unit.velocityY) * latency;
          float distance = std::min(
              utils::distanceBB(u, myNextPos, target, targetNextPos),
              utils::distanceBB(u, Position(u), target, targetNextPos));
          if (distance > weaponRange) {
            float curDist = utils::distance(u->x, u->y, target->x, target->y);
            if (utils::distance(u->x, u->y, targetNextPos.x, targetNextPos.y) >
                curDist) {
              // Target is moving away from us: lead it along its heading.
              auto np = Vec2(u) + targetVelocity.normalize() * 16;
              if (shouldMoveTo(np)) {
                return doAction(moveTo(state, unit, np));
              }
            } else {
              // Target is coming closer: aim in front of it, with a shorter
              // lead as a fallback.
              auto np = Vec2(target) +
                  targetVelocity.normalize() *
                      std::min(std::max(curDist - 4.0f, 4.0f), 20.0f);
              if (shouldMoveTo(np)) {
                return doAction(moveTo(state, unit, np));
              } else {
                auto npShort = Vec2(target) +
                    targetVelocity.normalize() *
                        std::min(std::max(curDist - 4.0f, 4.0f), 12.0f);
                if (shouldMoveTo(npShort)) {
                  return doAction(moveTo(state, unit, npShort));
                }
              }
            }
          }
        }
      }

      for (auto nmy : unit->enemyUnitsInSightRange) {
        if (nmy->playerId == -1) {
          continue;
        }
        auto distBB = utils::distanceBB(unit, nmy);
        if (distBB > 10) {
          // Stop scanning beyond 10 (presumably enemies are sorted by
          // distance) and fall through to moving towards the target.
          break;
        }
        // Look for next target if I can't attack
        if (!unit->canAttack(nmy))
          continue;
        // If a target is super close, send an attack command
        if (distBB < 1.5 && us.attacking != nmy) {
          return doAction(attack(state, unit, nmy));
        }
        // If the last command was a move, attack-move; otherwise we are
        // already attacking, so don't send a command
        if (us.lastMove > 0) {
          return doAction(attack(state, unit, Position(nmy)));
        } else {
          return doNothing;
        }
      }
      // If no nearby targets, move
      return doAction(smartMove(state, unit, us.target));
    }
    if (unit->type == buildtypes::Zerg_Overlord) {
      // Closest cloaked/burrowed target (provides detection for allies).
      Unit* cloakedTarget = utils::getBestScoreCopy(
          utils::filterUnits(
              targets_, [](Unit* e) { return e->cloaked() || e->burrowed(); }),
          [&](Unit* e) { return utils::distance(unit, e); },
          std::numeric_limits<float>::infinity());
      if (cloakedTarget) {
        Unit* ally = utils::getBestScoreCopy(
            units(),
            [&](Unit* u) {
              if (u == unit || !u->canAttack(cloakedTarget)) {
                return std::numeric_limits<float>::infinity();
              }
              return utils::distance(u, cloakedTarget);
            },
            std::numeric_limits<float>::infinity());
        if (ally &&
            utils::distance(unit, cloakedTarget) < unit->sightRange - 4) {
          VLOG(3) << unit << " senses ally near cloaked target, moving to "
                  << ally << " near cloaked " << cloakedTarget;
          return doAction(smartMove(state, unit, Vec2(ally)));
        }
      }
      if (!unit->threateningEnemies.empty()) {
        auto threat = unit->threateningEnemies[0];
        auto dir = Vec2(unit) - Vec2(threat);
        dir.normalize();
        VLOG(3) << unit << " senses threat, moving away from " << threat;
        return doAction(smartMove(state, unit, Vec2(unit) + dir * 25));
      }
      if (cloakedTarget) {
        VLOG(3) << unit << " senses cloaked target, moving to "
                << cloakedTarget;
        return doAction(smartMove(state, unit, Vec2(cloakedTarget)));
      }
      VLOG(3) << unit << " has no purpose in life, following the group";
      return doAction(smartMove(state, unit, center_));
    }

    return pass;
  }

  // Value-for-effort score of attacking `o` with `me`: (o's gScore + dps)
  // divided by me->computeEHP(o); higher is better. Spider mines, scourge and
  // infested terrans get a flat 1e7 so they always come first. Only damage
  // against unit kinds present in this squad is counted, with special cases
  // for bunkers, missile turrets and repairing SCVs.
  double EHPScoreHeuristic(Unit const* me, Unit const* o) const {
    bool const mustDieFirst =
        o->type == buildtypes::Terran_Vulture_Spider_Mine ||
        o->type == buildtypes::Zerg_Scourge ||
        o->type == buildtypes::Zerg_Infested_Terran;
    if (mustDieFirst) {
      return 1e7; // Should target these guys first
    }

    auto vsGround = hasGroundUnits ? o->unit.groundATK : 0;
    auto vsAir = hasAirUnits ? o->unit.airATK : 0;
    if (o->type == buildtypes::Terran_Bunker) {
      // Assume 4 marines without upgrades inside
      vsAir = 6 * 4;
    } else if (o->type == buildtypes::Terran_Missile_Turret && hasAirUnits) {
      // Missile turrets are stronger than they appear due to a position adv
      vsAir *= 3;
    }
    bool const repairingScv = o->type == buildtypes::Terran_SCV &&
        enemyStates_->at(o).lastRepairing > 0;
    if (repairingScv) {
      vsGround = 12;
      vsAir = 12;
    }

    auto const attackPeriod = o->unit.maxCD + 0.042;
    auto const dps = std::max(vsGround, vsAir) / attackPeriod;
    return (o->type->gScore + dps) / me->computeEHP(o);
  }

  // An enemy counts as a threat if it has a weapon that can hit a kind of
  // unit this squad contains (ground and/or air). Workers always count,
  // since they are always nice to kill.
  inline bool isThreat(Unit const* u) const {
    bool const canHitOurGround = hasGroundUnits && u->type->hasGroundWeapon;
    bool const canHitOurAir = hasAirUnits && u->type->hasAirWeapon;
    return canHitOurGround || canHitOurAir || u->type->isWorker;
  }

  // Chooses (or keeps) the target of `unit` among `units` and books the
  // damage `unit` would deal to it into enemyStates_. Other units use this
  // booking to avoid overkill (see pickTarget_ and the scourge logic).
  //
  // The current target is sticky for 15 frames. It is re-picked earlier if
  // it is null, no longer tracked in enemyStates_, or dead.
  // Returns the (possibly null) target.
  Unit* pickTarget(State* state, Unit* unit, std::vector<Unit*> const& units) {
    auto& us = unitStates_->at(unit);
    if (us.target == nullptr
        // TODO Make this stickiness per unit type
        || state->currentFrame() - us.lastTarget > 15 ||
        enemyStates_->find(us.target) == enemyStates_->end() ||
        us.target->dead) {
      auto target = pickTarget_(state, unit, units);
      us.target = target;
      if (target != nullptr) {
        us.lastTarget = state->currentFrame();
      }
    }

    if (us.target != nullptr) {
      // Zero-initialized so a computeDamageTo() path that leaves the
      // out-params untouched can't add garbage to the booked damage.
      int hp = 0;
      int sh = 0;
      auto& es = enemyStates_->at(us.target);
      unit->computeDamageTo(us.target, &hp, &sh);
      es.damages += hp + sh;
    }

    return us.target;
  }

  // Picks the best target for `unit` among `targs`.
  //
  // Candidates that are gone, dead, or that `unit` cannot attack are dropped.
  // The rest are ranked lexicographically by the criteria below, where lower
  // scores are better. Each criterion keeps only the candidates with the
  // minimal score, so later criteria only break ties among earlier winners.
  //
  // Previously, a candidate that lost an earlier criterion could survive and
  // win a later one. This happened whenever a strictly better candidate
  // appeared after it in the list.
  //
  // Returns nullptr when there is no valid candidate, or, for scourges, when
  // the best candidate already has (almost) lethal damage booked on it.
  Unit*
  pickTarget_(State* state, Unit const* unit, std::vector<Unit*> const& targs) {
    if (targs.empty()) {
      VLOG(2) << "No targets for " << utils::unitString(unit) << " in task "
              << utils::upcString(upcId());
      return nullptr;
    }

    std::vector<Unit*> potTargs;
    potTargs.reserve(targs.size());
    for (auto t : targs) {
      if (t->gone || t->dead || !unit->canAttack(t)) {
        continue;
      }
      potTargs.push_back(t);
    }

    std::vector<std::function<double(Unit*)>> fns;
    // Threats (and workers) first.
    fns.push_back([&](Unit* t) -> double { return isThreat(t) ? 0 : 1; });
    // Units within sightRange + 3 all score 0; farther ones score by
    // distance, so the closest is picked if everything is far away.
    fns.push_back([&](Unit* t) -> double {
      auto dist = utils::distance(unit, t);
      if (dist > unit->sightRange + 3) {
        return dist;
      }
      return 0.f;
    });
    if (unit->type != buildtypes::Zerg_Scourge) {
      // Prefer targets that are (nearly) in range.
      fns.push_back(
          [&](Unit* t) -> double { return t->inRangeOf(unit, 3) ? 0 : 1; });
    }
    if (unit->type == buildtypes::Zerg_Zergling) {
      // Deprioritize enemies already surrounded by >= 3 ground attackers.
      fns.push_back([&](Unit* t) -> double {
        auto atk = 0;
        for (auto u : t->beingAttackedByEnemies) {
          if (!u->flying())
            atk++;
          if (atk == 3)
            break;
        }
        return atk >= 3 ? 1 : 0;
      });
      // Deprioritize vultures.
      fns.push_back([&](Unit* t) -> double {
        return t->type == buildtypes::Terran_Vulture ? 1 : 0;
      });
    }
    // Avoid overkill: deprioritize targets with enough damage already booked.
    fns.push_back([&](Unit* t) -> double {
      return enemyStates_->at(t).damages > t->unit.health + t->unit.shield + 10
          ? 1
          : 0;
    });
    if (!threats_.empty() || delProb < 0.85) {
      // Best value-for-effort first (higher heuristic is better).
      fns.push_back(
          [&](Unit* t) -> double { return -EHPScoreHeuristic(unit, t); });
    }
    // Final tie-breaker: closest first.
    fns.push_back([&](Unit* t) -> double { return utils::distanceBB(unit, t); });

    std::vector<double> scores;
    for (auto const& fn : fns) {
      if (potTargs.size() <= 1) {
        break; // Nothing left to rank.
      }
      scores.clear();
      auto minScore = std::numeric_limits<double>::infinity();
      for (auto t : potTargs) {
        auto score = fn(t);
        scores.push_back(score);
        if (score < minScore) {
          minScore = score;
        }
      }
      // Keep only minimal candidates, in order. `!(s > min)` rather than
      // `s <= min` keeps NaN scores instead of silently dropping them.
      size_t kept = 0;
      for (size_t i = 0; i < potTargs.size(); ++i) {
        if (!(scores[i] > minScore)) {
          potTargs[kept++] = potTargs[i];
        }
      }
      potTargs.resize(kept);
    }
    Unit* bestUnit = potTargs.empty() ? nullptr : potTargs.front();

    if (unit->type == buildtypes::Zerg_Scourge) {
      if (bestUnit &&
          enemyStates_->at(bestUnit).damages >
              bestUnit->unit.health + bestUnit->unit.shield - 20) {
        return nullptr;
      }
    }
    return bestUnit;
  }

  /**
   * Prunes the squad: drops dead or reassigned units, then dead targets.
   * Once no unit is left the task is marked as failed and the targets are
   * left untouched.
   */
  void update(State* state) override {
    removeDeadOrReassignedUnits(state);

    if (units().empty()) {
      VLOG(1) << "All units died or was reassigned, marking task "
              << utils::upcString(upcId()) << " as failed";
      setStatus(TaskStatus::Failure);
      return;
    }

    // Swap-and-pop removal of dead targets; the order of the surviving
    // targets is not preserved.
    size_t idx = 0;
    while (idx < targets.size()) {
      if (!targets[idx]->dead) {
        ++idx;
        continue;
      }
      targets[idx] = targets.back();
      targets.pop_back();
    }
  }

  /**
   * Refreshes the squad's per-step state before UPCs are generated.
   *
   * Calls update() (prunes dead/reassigned units and dead targets), then
   * recomputes the air/ground composition, the grouped units and their
   * centres, the current targets and threats, and the positions of psionic
   * storms present in this frame's bullets. Marks the task as succeeded and
   * returns early when it is not targetting a location and has no targets
   * left.
   */
  void update_(State* state) {
    // Actually this probably all belongs in Module...
    update(state);
    // NOTE(review): if update() has just failed the task, units() is empty
    // from here on; this relies on centerOfUnits() tolerating an empty set.

    // storms_ is rebuilt from this frame's bullets below; clear it so storm
    // positions from earlier frames do not pile up indefinitely.
    storms_.clear();

    // Updates has air and has ground
    hasAirUnits = hasGroundUnits = false;
    for (const auto unit : this->units()) {
      if (unit->flying()) {
        this->hasAirUnits = true;
      } else {
        this->hasGroundUnits = true;
      }
    }

    // Distance-based grouping (units within 75 walktiles of the centroid) is
    // disabled below, so every squad unit is treated as grouped. center_ is
    // computed from the grouping of the previous update_() call, or from all
    // units when that grouping is empty.
    // TODO Dynamic based on # of units in group
    center_ = grouped_.empty() ? utils::centerOfUnits(units())
                               : utils::centerOfUnits(grouped_);
    /*
    grouped_.clear();
    traveling_.clear();
    for (auto unit : units()) {
      if (center_.dist(Position(unit)) < 75) {
        grouped_.insert(unit);
      } else {
        traveling_.insert(unit);
      }
    }
    */
    grouped_ = units();

    targets_ = getGroupTargets(state);
    threats_ = getGroupThreats(state);
    squadPosition_ = utils::centerOfUnits(grouped_);

    // If no more targets and we're not targetting a location, declare victory
    if (!targettingLocation && targets_.empty()) {
      VLOG(1) << "Squad for " << utils::upcString(upcId())
              << " has no more targets, marking as succeeded";
      setStatus(TaskStatus::Success);
      return;
    }

    for (auto const& bullet : state->tcstate()->frame->bullets) {
      if (bullet.type == tc::BW::BulletType::Psionic_Storm) {
        storms_.emplace_back(bullet.x, bullet.y);
      }
    }
  }

  /**
   * Refreshes the squad state and returns all UPCs it wants to post this
   * step: the grouped-unit UPCs followed by the traveling-unit UPCs.
   */
  auto makeUPCs(State* state) {
    update_(state);
    // The traveling UPCs are built first (as before) in case either builder
    // has side effects; they are appended after the grouped ones.
    auto const travelingUpcs = makeTravelingUPCs(state);
    auto allUpcs = makeGroupedUPCs(state);
    for (auto const& upc : travelingUpcs) {
      allUpcs.push_back(upc);
    }
    return allUpcs;
  }

 private:
  // Centre of the grouped units, computed in update_() from the previous
  // grouped_ (or from all units when grouped_ is empty).
  Position center_;
  // grouped_: units treated as the main group (currently every squad unit,
  // assigned in update_()). traveling_: not populated by update_(), whose
  // distance-based grouping code is commented out.
  std::unordered_set<Unit*> grouped_, traveling_;
  // Non-owning pointers to SquadCombatModule's per-enemy and per-unit state
  // maps (formNewSquad passes &enemyStates_ / &unitStates_).
  std::unordered_map<Unit const*, SquadCombatModule::EnemyState>* enemyStates_;
  std::unordered_map<Unit const*, SquadCombatModule::UnitState>* unitStates_;
  // Centre of grouped_ after the latest update_().
  Position squadPosition_;
  // Positions of Psionic_Storm bullets collected in update_().
  std::vector<Position> storms_;

  // The actual targets my grouped units can fight, recomputed by
  // getGroupTargets() in update_() (presumably from the UPC probabilities)
  std::vector<Unit*> targets_;
  // The actual threats my grouped units encounter, recomputed by
  // getGroupThreats() in update_()
  std::vector<Unit*> threats_;
};
} // namespace

/**
 * Per-frame entry point.
 *
 * 1. Turns eligible Delete UPCs from other modules into new squads.
 * 2. Maintains the per-unit and per-enemy bookkeeping maps.
 * 3. Updates every existing squad task, then runs Module::step().
 */
void SquadCombatModule::step(State* state) {
  auto board = state->board();

  auto ownUpcs = board->upcsFrom(this);
  for (auto& entry : board->upcsWithCommand(Command::Delete, 0.1)) {
    auto upcId = entry.first;
    auto upc = entry.second;

    // Gather/Create UPCs are not ours to handle, our own UPCs must not be
    // re-consumed, and a UPC without units leaves nobody to command.
    if (upc->command[Command::Gather] > 0 ||
        upc->command[Command::Create] > 0 ||
        ownUpcs.find(upcId) != ownUpcs.end() || upc->unit.empty()) {
      continue;
    }

    // Skip UPCs targeting allied units (Builder might want to remove a
    // blocking building, for example)
    auto targetsOwnUnit = false;
    for (auto& target : upc->positionU) {
      targetsOwnUnit = target.second > 0 && target.first->isMine;
      if (targetsOwnUnit) {
        break;
      }
    }
    if (targetsOwnUnit) {
      continue;
    }

    if (formNewSquad(state, upc, upcId)) {
      board->consumeUPCs({upcId}, this);
    }
  }

  // Every unit of ours gets a state entry; entries of dead units are dropped.
  for (auto u : state->unitsInfo().myUnits()) {
    if (unitStates_.count(u) == 0) {
      unitStates_[u] = SquadCombatModule::UnitState();
    }
  }
  for (auto it = unitStates_.begin(); it != unitStates_.end();) {
    if (it->first->dead) {
      it = unitStates_.erase(it);
    } else {
      ++it;
    }
  }

  // Track enemy repairing: remember the frame it was last seen, and forget
  // it once more than 36 frames have passed without repairing.
  auto const frame = state->currentFrame();
  for (auto u : state->unitsInfo().enemyUnits()) {
    auto& es = enemyStates_[u];
    if (u->flag(tc::Unit::Flags::Repairing)) {
      es.lastRepairing = frame;
    } else if (es.lastRepairing != -1 && frame - es.lastRepairing > 36) {
      es.lastRepairing = -1;
    }
  }

  // Drop dead enemies and reset the accumulated damages of the survivors.
  for (auto it = enemyStates_.begin(); it != enemyStates_.end();) {
    if (it->first->dead) {
      it = enemyStates_.erase(it);
      continue;
    }
    it->second.damages = 0;
    ++it;
  }

  for (auto& task : board->tasksOfModule(this)) {
    updateTask(state, task);
  }
  Module::step(state);
}

/**
 * Creates a SquadTask for `sourceUpc` and posts it to the blackboard.
 *
 * All units with non-zero probability in the UPC join the squad. The squad
 * targets, in order of precedence:
 *  - the units with non-zero probability in positionU, if positionS is unset
 *    (x == -1) and positionU is non-empty;
 *  - positionS, if set (rejected when x < -1);
 *  - the argmax of the position distribution, if defined.
 *
 * Returns false (and posts nothing) when there are no units or no target.
 */
bool SquadCombatModule::formNewSquad(
    State* state,
    std::shared_ptr<UPCTuple> sourceUpc,
    int sourceUpcId) {
  std::unordered_set<Unit*> squadUnits;
  for (auto const& entry : sourceUpc->unit) {
    if (entry.second > 0) {
      squadUnits.insert(entry.first);
    }
  }
  if (squadUnits.empty()) {
    VLOG(1) << "No units to take care of in " << utils::upcString(sourceUpcId);
    return false;
  }

  std::shared_ptr<SquadTask> task;
  bool const hasPositionS = sourceUpc->positionS.x != -1;
  if (!hasPositionS && !sourceUpc->positionU.empty()) {
    std::vector<Unit*> targetList;
    for (auto const& entry : sourceUpc->positionU) {
      if (entry.second > 0) {
        targetList.push_back(entry.first);
      }
    }
    VLOG(2) << "Targetting " << targetList.size() << " units";
    task = std::make_shared<SquadTask>(
        sourceUpcId,
        sourceUpc,
        squadUnits,
        std::move(targetList),
        &enemyStates_,
        &unitStates_);
  } else if (hasPositionS) {
    if (sourceUpc->positionS.x < 0) {
      LOG(INFO) << "No targets to attack in "
                << utils::upcString(sourceUpcId);
      return false;
    }
    auto const pos = sourceUpc->positionS;
    task = std::make_shared<SquadTask>(
        sourceUpcId,
        sourceUpc,
        squadUnits,
        pos.x,
        pos.y,
        &enemyStates_,
        &unitStates_);
    VLOG(2) << "Targeting single position at " << pos.x << "," << pos.y;
  } else if (sourceUpc->position.defined()) {
    int x = 0;
    int y = 0;
    std::tie(x, y, std::ignore) =
        utils::argmax(sourceUpc->position, sourceUpc->scale);
    task = std::make_shared<SquadTask>(
        sourceUpcId,
        sourceUpc,
        squadUnits,
        x,
        y,
        &enemyStates_,
        &unitStates_);
    VLOG(2) << "Targeting position argmax at " << x << "," << y;
  } else {
    VLOG(1) << "No targets for " << utils::upcString(sourceUpcId);
    return false;
  }

  state->board()->postTask(task, this);
  task->setStatus(TaskStatus::Unknown);

  VLOG(1) << "Formed squad for " << utils::upcString(sourceUpcId) << " with "
          << squadUnits.size() << " units: " << utils::unitsString(squadUnits)
          << utils::unitsString(task->targets);
  return true;
}

/**
 * Advances one squad task.
 *
 * A squad that has succeeded, failed or been cancelled is marked for removal.
 * Otherwise the squad refreshes itself and every non-null UPC it produces is
 * posted under the squad's UPC id.
 */
void SquadCombatModule::updateTask(State* state, std::shared_ptr<Task> task) {
  auto squad = std::static_pointer_cast<SquadTask>(task);
  auto board = state->board();

  switch (squad->status()) {
    case TaskStatus::Success: {
      VLOG(2) << "Squad for " << utils::upcString(squad->upcId())
              << " has succeeded";
      board->markTaskForRemoval(squad);
      return;
    }
    case TaskStatus::Failure: {
      VLOG(2) << "Squad for " << utils::upcString(squad->upcId())
              << " has failed";
      board->markTaskForRemoval(squad);
      return;
    }
    case TaskStatus::Cancelled: {
      VLOG(2) << "Squad for UPC " << squad->upcId() << " has been cancelled";
      board->markTaskForRemoval(squad);
      return;
    }
    default:
      break;
  }

  for (auto& upc : squad->makeUPCs(state)) {
    if (!upc) {
      continue;
    }
    board->postUPC(std::move(upc), task->upcId(), this);
  }
}
} // namespace fairrsh
