/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "modules/scouting.h"

#include "areainfo.h"
#include "builderhelper.h"
#include "commandtrackers.h"
#include "fogofwar.h"
#include "movefilters.h"
#include "state.h"
#include "task.h"
#include "unitsinfo.h"
#include "utils.h"

#include <bwem/map.h>

#include <deque>
#include <queue>

namespace cherrypi {

REGISTER_SUBCLASS_0(Module, ScoutingModule);

namespace {

class SneakyOverlordImpl {
 public:
  SneakyOverlordImpl(State* state) {
    // Precompute the range at which we will try to keep overlords.
    int range = 4 * 8;
    std::unordered_set<Position> set;
    for (int y = -range; y <= range; y += 4) {
      for (int x = -range; x <= range; x += 4) {
        if (utils::distance(Position(0, 0), Position(x, y)) <= range) {
          relativePositions.emplace_back(x, y);
          set.emplace(x, y);
        }
      }
    }
    for (int y = -range - 4; y <= range + 4; y += 4) {
      for (int x = -range - 4; x <= range + 4; x += 4) {
        Position pos(x, y);
        if (set.find(pos) == set.end()) {
          if (set.find(pos + Position(4, 0)) != set.end() ||
              set.find(pos + Position(-4, 0)) != set.end() ||
              set.find(pos + Position(0, 4)) != set.end() ||
              set.find(pos + Position(0, -4)) != set.end()) {
            edgeRelativePositions.push_back(pos);
          }
        }
      }
    }
  }

  size_t posIndex(Position pos) {
    return (size_t)pos.y / (size_t)tc::BW::XYWalktilesPerBuildtile *
        TilesInfo::tilesWidth +
        (size_t)pos.x / (size_t)tc::BW::XYWalktilesPerBuildtile;
  }

  // Update the score map, which is mostly the distance/cost to move from each
  // build tile to some goal.
  void updateScoreMap(State* state, Unit* unit, Position targetPos) {
    std::vector<Position> path;
    Position enemyPos = targetPos;
    Position enemyExpoPos = targetPos;
    Position enemyChokePos = targetPos;
    // If we haven't found the enemy, then just go to the location we are
    // initially assigned.
    if (!state->areaInfo().foundEnemyStartLocation()) {
      if (scoutPos == Position()) {
        scoutPos = targetPos;
      }
      enemyPos = targetPos = enemyExpoPos = enemyChokePos = scoutPos;
      auto* tile = state->tilesInfo().tryGetTile(scoutPos.x, scoutPos.y);
      if (!tile || tile->visible) {
        auto& locs = state->areaInfo().candidateEnemyStartLocations();
        if (!locs.empty()) {
          scoutPos = utils::getBestScoreCopy(locs, [&](Position pos) {
            return utils::distance(pos, Position(unit));
          });
        }
      }
    } else {
      // Find the enemy natural and choke between their main and natural.
      enemyPos = state->areaInfo().enemyStartLocation();
      targetPos = enemyPos;

      Position myPos = state->areaInfo().myStartLocation();
      enemyPos = state->areaInfo().enemyStartLocation();
      path = state->areaInfo().walkPath(enemyPos, myPos);
      path.resize(path.size() / 3);
      if (!path.empty()) {
        float bestBasePathIndexScore = kdInfty;
        size_t bestBasePathIndex = path.size() - 1;
        Position bestBasePos;
        for (auto& area : state->map()->Areas()) {
          for (auto& base : area.Bases()) {
            if (!base.BlockingMinerals().empty()) {
              continue;
            }
            Position pos(
                base.Location().x * tc::BW::XYWalktilesPerBuildtile,
                base.Location().y * tc::BW::XYWalktilesPerBuildtile);
            if (utils::distance(pos, enemyPos) <= 4 * 15) {
              continue;
            }
            if (!builderhelpers::canBuildAt(
                    state, buildtypes::Zerg_Hatchery, pos, true)) {
              bool skip = true;
              if (utils::distance(pos, enemyPos) > 4 * 10) {
                Tile* tile = state->tilesInfo().tryGetTile(pos.x, pos.y);
                if (tile && tile->building && tile->building->isEnemy) {
                  skip = false;
                }
              }
              if (skip) {
                continue;
              }
            }
            auto* area = state->areaInfo().tryGetArea(pos);
            if (!area) {
              continue;
            }
            size_t bestPathPosIndex = 0;
            float bestPathPosScore = kdInfty;
            for (auto* cp : area->area->ChokePoints()) {
              BWAPI::WalkPosition cpPos = cp->Center();
              for (size_t i = 0; i != path.size(); ++i) {
                auto pathPos = path[i];
                float score =
                    utils::distance(cpPos.x, cpPos.y, pathPos.x, pathPos.y);
                if (score < bestPathPosScore) {
                  bestPathPosScore = score;
                  bestPathPosIndex = i;
                }
              }
            }
            float s = bestPathPosIndex * 4 * 12.0f;
            s = s * s + bestPathPosScore * bestPathPosScore;
            if (s < bestBasePathIndexScore) {
              bestBasePathIndexScore = s;
              bestBasePathIndex = bestPathPosIndex;
              bestBasePos = pos;
            }
          }
          enemyExpoPos = bestBasePos;
          enemyChokePos = path[bestBasePathIndex];
        }
      }
    }

    uint8_t inRangeValue = nextInRangeValue++;
    if (inRangeValue == 0) {
      inRange.fill(0);
      inRangeValue = nextInRangeValue++;
    }

    for (unsigned y = 0; y != state->tilesInfo().mapTileHeight(); ++y) {
      auto from = scoreMap.begin() + y * TilesInfo::tilesWidth;
      auto to = from + state->tilesInfo().mapTileWidth();
      std::fill(from, to, 0.0f);
    }

    struct OpenNode {
      Position pos;
      float score = 0.0f;
    };
    struct OpenNodeCmp {
      bool operator()(const OpenNode& a, const OpenNode& b) const {
        return a.score > b.score;
      }
    };

    std::priority_queue<OpenNode, std::vector<OpenNode>, OpenNodeCmp> open;

    // Find the area around each unit that is "in range", or too close for our
    // overlord to move there. Only consider units that can shoot up.
    for (Unit* e : state->unitsInfo().enemyUnits()) {
      if (e->gone) {
        continue;
      }
      if ((!e->type->isBuilding || e->flying()) && !e->type->hasAirWeapon) {
        continue;
      }

      for (Position relPos : relativePositions) {
        Position pos = utils::clampPositionToMap(state, Position(e) + relPos);
        size_t index = posIndex(pos);
        inRange[index] = inRangeValue;
      }
    }

    // How desireable some location is based on when we saw it last.
    auto frameScore = [&](FrameNum frame) {
      FrameNum age = state->currentFrame() - frame;
      FrameNum maxAge = 24 * 60 * 2;
      if (age > maxAge) {
        age = maxAge;
      }
      return (float)(maxAge - age);
    };

    // Add some position as a desirable scout target.
    auto addOpen = [&](Position sourcePos, float score) {
      if (sourcePos == Position()) {
        return;
      }
      if (score == -1) {
        score = frameScore(
            state->tilesInfo().getTile(sourcePos.x, sourcePos.y).lastSeen);
      }
      for (Position relPos : edgeRelativePositions) {
        Position pos = utils::clampPositionToMap(state, sourcePos + relPos);
        size_t index = posIndex(pos);
        if (!inRange[index] && scoreMap[index] == 0.0f) {
          scoreMap[index] = score;
          open.push({pos, score});
        }
      }
    };

    bool isNearestEnemyExpo = true;
    float enemyExpoDistance = utils::distance(unit, enemyExpoPos);
    for (Unit* u : state->unitsInfo().myUnitsOfType(unit->type)) {
      if (utils::distance(u, enemyExpoPos) < enemyExpoDistance) {
        isNearestEnemyExpo = false;
      }
    }

    // One overlord checks our the natural/choke, and one/the rest checks out
    // the main.
    if (isNearestEnemyExpo) {
      addOpen(enemyExpoPos, -1);
      addOpen(enemyChokePos, -1);
    } else {
      addOpen(enemyPos, -1);
    }

    // Try to keep a tab on all enemy buildings.
    for (Unit* e : state->unitsInfo().enemyUnits()) {
      if (e->gone)
        continue;
      if (!e->type->isBuilding || e->flying())
        continue;

      addOpen(Position(e), frameScore(e->lastSeen));
    }

    while (!open.empty()) {
      OpenNode cur = open.top();
      open.pop();

      auto add = [&](Position pos, float dist) {
        pos = utils::clampPositionToMap(state, pos);
        size_t index = posIndex(pos);
        if (scoreMap[index] != 0.0f || inRange[index] == inRangeValue) {
          return;
        }
        float newScore = cur.score + dist;
        scoreMap[index] = newScore;
        open.push({pos, newScore});
      };

      add(cur.pos + Position(4, 0), 4.0f);
      add(cur.pos + Position(-4, 0), 4.0f);
      add(cur.pos + Position(0, 4), 4.0f);
      add(cur.pos + Position(0, -4), 4.0f);
      add(cur.pos + Position(4, 4), 5.656854249f);
      add(cur.pos + Position(-4, 4), 5.656854249f);
      add(cur.pos + Position(-4, -4), 5.656854249f);
      add(cur.pos + Position(4, -4), 5.656854249f);
    }
  }

  bool update(State* state, Unit* unit, Position& location) {
    Position targetPos = location;

    if (state->currentFrame() - lastUpdateScoreMap >= 6) {
      lastUpdateScoreMap = state->currentFrame();
      updateScoreMap(state, unit, targetPos);
    }

    // Move to some nearby position with a low score
    float bestScore = kfInfty;
    int range = 4 * 6;
    bool escape = inRange[posIndex(unit->pos())];
    if (escape) {
      range = 4 * 12;
    }
    Position beginPos = Position(unit) - Position(range, range);
    Position endPos = beginPos + Position(range * 2, range * 2);
    beginPos = utils::clampPositionToMap(state, beginPos);
    endPos = utils::clampPositionToMap(state, endPos);
    int beginTileX = beginPos.x / tc::BW::XYWalktilesPerBuildtile;
    int beginTileY = beginPos.y / tc::BW::XYWalktilesPerBuildtile;
    int endTileX = endPos.x / tc::BW::XYWalktilesPerBuildtile;
    int endTileY = endPos.y / tc::BW::XYWalktilesPerBuildtile;
    for (int y = beginTileY; y != endTileY; ++y) {
      for (int x = beginTileX; x != endTileX; ++x) {
        auto& tile =
            state->tilesInfo()
                .tiles[(unsigned)y * TilesInfo::tilesWidth + (unsigned)x];
        Position pos(tile);
        size_t index = posIndex(pos);
        if (index != posIndex(unit->pos())) {
          if (escape) {
            if (!inRange[index]) {
              float d = utils::distance(pos, unit);
              float s = scoreMap[index];
              s = s * s + d * d;
              if (s < bestScore) {
                bestScore = s;
                targetPos = pos;
              }
            }
          } else {
            float s = scoreMap[index];
            if (s != 0.0f && s < bestScore) {
              bestScore = s;
              targetPos = pos;
            }
          }
        }
      }
    }

    // If there's something that can attack us, just flee from it.
    Vec2 fleeSum;
    int fleeN = 0;
    for (Unit* e : unit->unitsInSightRange) {
      if (e->isEnemy && e->type->hasAirWeapon &&
          utils::distance(unit, e) <= 4 * 9) {
        fleeSum += Vec2(e);
        ++fleeN;
      }
    }
    if (fleeN) {
      fleeSum /= fleeN;
      location = utils::clampPositionToMap(
          state,
          Position(unit) +
              Position((Vec2(unit) - fleeSum).normalize() * 4 * 8));
      return true;
    }

    if (utils::distance(unit, targetPos) < 12) {
      targetPos = Position(unit) +
          Position((Vec2(targetPos) - Vec2(unit)).normalize() * 12);
    }

    location = utils::clampPositionToMap(state, targetPos);
    return true;
  }

 private:
  Position scoutPos;
  std::vector<Position> relativePositions;
  std::vector<Position> edgeRelativePositions;
  std::array<uint8_t, TilesInfo::tilesWidth * TilesInfo::tilesHeight> inRange{};
  uint8_t nextInRangeValue = 1;
  std::array<float, TilesInfo::tilesWidth * TilesInfo::tilesHeight> scoreMap{};
  FrameNum lastUpdateScoreMap = 0;
};

class WorkerScoutImpl {
 public:
  WorkerScoutImpl(State* state) {

  }

  size_t posIndex(Position pos) {
    return (size_t)pos.y / (size_t)tc::BW::XYWalktilesPerBuildtile *
        TilesInfo::tilesWidth +
        (size_t)pos.x / (size_t)tc::BW::XYWalktilesPerBuildtile;
  }

  uint8_t visitNumber_ = 0;
  std::vector<uint8_t> tileVisitTracker_ =
      std::vector<uint8_t>(TilesInfo::tilesWidth * TilesInfo::tilesHeight);
  
  Position findBaseScoutPos(State* state, Unit* u) {
    auto mapWidth = state->mapWidth();
    auto mapHeight = state->mapHeight();
    bool flying = u->flying();

    uint8_t visitedValue = ++visitNumber_;

    auto* tilesData = state->tilesInfo().tiles.data();

    Position targetPos = state->areaInfo().enemyStartLocation();
    Position startPos = u->pos();

    std::deque<const Tile*> open;
    auto* startTile = &state->tilesInfo().getTile(u->x, u->y);
    open.push_back(startTile);
    while (!open.empty()) {
      const Tile* tile = open.front();
      open.pop_front();

      if (state->currentFrame() - tile->lastSeen >=
          15 * utils::distance(targetPos, Position(tile))) {
        return Position(tile);
      }
      
      auto add = [&](const Tile* ntile) {
        if (!flying && !tile->entirelyWalkable && tile != startTile) {
          return;
        }
        auto& v = tileVisitTracker_[ntile - tilesData];
        if (v == visitedValue) {
          return;
        }
        v = visitedValue;
        if (utils::distance(targetPos, Position(ntile)) <= 4.0f * 12) {
          open.push_back(ntile);
        }
      };
      if (tile->x > 0) {
        add(tile - 1);
      }
      if (tile->y > 0) {
        add(tile - TilesInfo::tilesWidth);
      }
      if (tile->x < mapWidth - tc::BW::XYWalktilesPerBuildtile) {
        add(tile + 1);
      }
      if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
        add(tile + TilesInfo::tilesHeight);
      }
    }

    return kInvalidPosition;
  }

  bool update(State* state, Unit* unit, Position& location, Unit*& target) {
    Position targetPos = location;
    target = nullptr;

    std::vector<Position> path;
    Position enemyPos = targetPos;
    Position enemyExpoPos = targetPos;
    Position enemyChokePos = targetPos;
    // If we haven't found the enemy, then just go to the location we are
    // initially assigned.
    if (!state->areaInfo().foundEnemyStartLocation()) {
      if (scoutPos == Position()) {
        scoutPos = targetPos;
      }
      enemyPos = targetPos = enemyExpoPos = enemyChokePos = scoutPos;
      auto* tile = state->tilesInfo().tryGetTile(scoutPos.x, scoutPos.y);
      if (!tile || tile->visible) {
        auto& locs = state->areaInfo().candidateEnemyStartLocations();
        if (!locs.empty()) {
          scoutPos = utils::getBestScoreCopy(locs, [&](Position pos) {
            return utils::distance(pos, Position(unit));
          });
        }
      }
    } else {
      // Find the enemy natural and choke between their main and natural.
      enemyPos = state->areaInfo().enemyStartLocation();
      targetPos = enemyPos;

      Position myPos = state->areaInfo().myStartLocation();
      enemyPos = state->areaInfo().enemyStartLocation();
      path = state->areaInfo().walkPath(enemyPos, myPos);
      path.resize(path.size() / 3);
      if (!path.empty()) {
        float bestBasePathIndexScore = kdInfty;
        size_t bestBasePathIndex = path.size() - 1;
        Position bestBasePos;
        for (auto& area : state->map()->Areas()) {
          for (auto& base : area.Bases()) {
            if (!base.BlockingMinerals().empty()) {
              continue;
            }
            Position pos(
                base.Location().x * tc::BW::XYWalktilesPerBuildtile,
                base.Location().y * tc::BW::XYWalktilesPerBuildtile);
            if (utils::distance(pos, enemyPos) <= 4 * 15) {
              continue;
            }
            if (!builderhelpers::canBuildAt(
                    state, buildtypes::Zerg_Hatchery, pos, true)) {
              bool skip = true;
              if (utils::distance(pos, enemyPos) > 4 * 10) {
                Tile* tile = state->tilesInfo().tryGetTile(pos.x, pos.y);
                if (tile && tile->building && tile->building->isEnemy) {
                  skip = false;
                }
              }
              if (skip) {
                continue;
              }
            }
            auto* area = state->areaInfo().tryGetArea(pos);
            if (!area) {
              continue;
            }
            size_t bestPathPosIndex = 0;
            float bestPathPosScore = kdInfty;
            for (auto* cp : area->area->ChokePoints()) {
              BWAPI::WalkPosition cpPos = cp->Center();
              for (size_t i = 0; i != path.size(); ++i) {
                auto pathPos = path[i];
                float score =
                    utils::distance(cpPos.x, cpPos.y, pathPos.x, pathPos.y);
                if (score < bestPathPosScore) {
                  bestPathPosScore = score;
                  bestPathPosIndex = i;
                }
              }
            }
            float s = bestPathPosIndex * 4 * 12.0f;
            s = s * s + bestPathPosScore * bestPathPosScore;
            if (s < bestBasePathIndexScore) {
              bestBasePathIndexScore = s;
              bestBasePathIndex = bestPathPosIndex;
              bestBasePos = pos;
            }
          }
          enemyExpoPos = bestBasePos;
          enemyChokePos = path[bestBasePathIndex];
        }
      }
    }

    targetPos = enemyPos;

    int depots = 0;
    int barracks = 0;
    int pools = 0;
    int gateways = 0;
    int cybercores = 0;
    int rangedUnits = 0;
    Unit* nearestBuilding = nullptr;
    float nearestBuildingDistance = kfInfty;
    Unit* leastRecentlySeenBuilding = nullptr;
    float leastRecentlySeenBuildingAge = 0;
    Unit* nearestGatewayOrBarracks = nullptr;
    float nearestGatewayOrBarracksDistance = kfInfty;
    for (Unit* e : state->unitsInfo().enemyUnits()) {
      if (e->type->isBuilding) {
        if (e->type->isResourceDepot) ++depots;
        if (e->type == buildtypes::Terran_Barracks) ++barracks;
        else if (e->type == buildtypes::Zerg_Spawning_Pool) ++pools;
        else if (e->type == buildtypes::Protoss_Gateway) ++gateways;
        else if (e->type == buildtypes::Protoss_Cybernetics_Core) ++cybercores;

        if (e->unit.groundRange >= 4 * 3) {
          ++rangedUnits;
        }

        float d = utils::distance(unit->pos(), e->pos());
        if (!nearestBuilding || d < nearestBuildingDistance) {
          nearestBuildingDistance = d;
          nearestBuilding = e;
        }

        if (e->type == buildtypes::Terran_Barracks || e->type == buildtypes::Protoss_Gateway) {
          if (d < nearestGatewayOrBarracksDistance) {
            nearestGatewayOrBarracksDistance = d;
            nearestGatewayOrBarracks = e;
          }
        }

        int age = state->currentFrame() - e->lastSeen;
        if (age > leastRecentlySeenBuildingAge) {
          leastRecentlySeenBuildingAge = age;
          leastRecentlySeenBuilding = e;
        }
      }
    }

    Unit* targetBuilding = nullptr;

    if (nearestBuildingDistance < 4.0f * 12) {
      targetBuilding = nearestBuilding;
    }

    if (state->currentFrame() < 24 * 60 * 2 + 24 * 45) {
      if (nearestGatewayOrBarracks) {
        targetBuilding = nearestGatewayOrBarracks;
      }
    }

    if (targetBuilding) {
      Vec2 offset = (targetBuilding->posf() - unit->posf()).normalize() * 12.0f;
      if (utils::distance(targetBuilding->pos(), unit->pos()) < 4.0f * 2) {
        offset.rotateDegrees(direction > 0 ? 110 : -110);
      } else {
        offset.rotateDegrees(direction > 0 ? 60 : -60);
      }

      if (!targetBuilding->visible) targetPos = targetBuilding->pos();
      else targetPos = unit->pos() + Position(offset);
    }

    if (utils::distance(unit->pos(), state->areaInfo().enemyStartLocation()) <=
        4.0f * 12) {
      Position scoutPos = findBaseScoutPos(state, unit);
      if (scoutPos != kInvalidPosition) {
        targetPos = scoutPos;
      }
    }

    Unit* mineralTarget = utils::getBestScoreCopy(state->unitsInfo().resourceUnits(), [&](Unit* u) {
      if (!u->visible) return kfInfty;
      if (!u->type->isMinerals) return kfInfty;
      float d = utils::distance(state->areaInfo().enemyStartLocation(), u->pos());
      if (d > 4.0f * 12) return kfInfty;
      return d;
    }, kfInfty);
    if (!mineralTarget) {
      mineralTarget = utils::getBestScoreCopy(state->unitsInfo().resourceUnits(), [&](Unit* u) {
        if (!u->type->isMinerals) return kfInfty;
        float d = utils::distance(state->areaInfo().enemyStartLocation(), u->pos());
        if (d > 4.0f * 12) return kfInfty;
        return d;
      }, kfInfty);
      if (mineralTarget) {
        targetPos = mineralTarget->pos();
      }
    }

    auto tileAge = [&](Position pos) {
      auto* t = state->tilesInfo().tryGetTile(pos.x, pos.y);
      if (t) return state->currentFrame() - t->lastSeen;
      return 0;
    };

    if (barracks + pools + gateways == 0 || state->currentFrame() >= 24 * 60 * 3) {
      if (depots) {
        if (tileAge(enemyExpoPos) > 24 * 15) {
          targetPos = enemyExpoPos;
          location = targetPos;
          return true;
        }
        else if (tileAge(enemyChokePos) > 24 * 15) targetPos = enemyChokePos;
        else if (tileAge(enemyPos) > 24 * 15) targetPos = enemyPos;
      } else {
        if (tileAge(enemyPos) > 24 * 15) targetPos = enemyPos;
      }
    }

    Unit* nearestEnemyWorker = utils::getBestScoreCopy(state->unitsInfo().enemyUnits(), [&](Unit* e) {
      if (!e->type->isWorker) return kfInfty;
      return float(utils::distanceBB(unit, e) - e->rangeAgainst(unit));
    }, kfInfty);

    if (true) {
      targetPos = utils::clampPositionToMap(state, targetPos);
      auto tarea = state->map()->GetArea(BWAPI::WalkPosition(targetPos.x, targetPos.y));
      auto uarea = state->map()->GetArea(BWAPI::WalkPosition(unit->x, unit->y));
      if (tarea && uarea && tarea != uarea) {
        int pLength;
        auto path = state->map()->GetPath(
            BWAPI::Position(BWAPI::WalkPosition(unit->x, unit->y)),
            BWAPI::Position(BWAPI::WalkPosition(targetPos.x, targetPos.y)),
            &pLength);
        if (!path.empty()) {
          auto ch1 = path[0]->Center();
          auto chkPos = Position(ch1.x, ch1.y);
          if (nearestEnemyWorker && utils::distance(unit->pos(), nearestEnemyWorker->pos()) <= 4.0f * 8) {
            if (utils::distance(chkPos, nearestEnemyWorker->pos()) <= 8.0f) {
              target = nearestEnemyWorker;
            }
          }
          if (chkPos.distanceTo(unit) > 20) {
            targetPos = chkPos;
          } else {
            if (path.size() >= 2) {
              auto ch2 = path[1]->Center();
              targetPos = Position(ch2.x, ch2.y);
            }
          }
        }
      }
    }

    auto canMoveInDirection = [&](Vec2 odir,
                                  float distance = DFOASG(4.0f * 2, 4.0f)) {
      Vec2 dir = odir.normalize();
      for (float d = 4.0f; ; d += 4.0f) {
        Position pos = d < distance ? Position(unit->posf() + dir * d) : Position(unit->posf() + odir);
        auto* tile = state->tilesInfo().tryGetTile(pos.x, pos.y);
        if (!tile || !tile->entirelyWalkable || tile->building) {
          return false;
        }
        if (d >= distance) break;
      }
      return true;
    };
    auto adjust = [&](Vec2 targetPos) {
      if (unit->flying()) return targetPos;
      Vec2 dir = targetPos - unit->posf();
      float d = dir.length();
      if (canMoveInDirection(dir, d)) return targetPos;
      for (float dg = 22.5f; dg <= 180.0f; dg += 22.5f) {
        if (dg > 90.0f) dg += 22.5f;
        Vec2 l = dir.rotateDegrees(dg);
        if (canMoveInDirection(l, d)) return unit->posf() + l;
        Vec2 r = dir.rotateDegrees(-dg);
        if (canMoveInDirection(r, d)) return unit->posf() + r;
      }
      return targetPos;
    };

    int nWorkers = 0;
    Unit* nearestThreat = utils::getBestScoreCopy(state->unitsInfo().enemyUnits(), [&](Unit* e) {
      if (!e->completed()) return kfInfty;
      if (!e->canAttack(unit) && e->type != buildtypes::Terran_Bunker) return kfInfty;
      if (e->type->isWorker && unit->unit.health >= e->unit.health) {
        ++nWorkers;
        if (nWorkers < 2) return kfInfty;
      }// else if (e->unit.health == e->unit.max_health) return kfInfty;
      return float(utils::distanceBB(unit, e) - e->rangeAgainst(unit));
    }, kfInfty);

    if (nearestThreat) {
      target = nullptr;
      float d = utils::distanceBB(unit, nearestThreat) - nearestThreat->rangeAgainst(unit);
      if (d <= 8.0f) {
        Vec2 offset = (nearestThreat->posf() - unit->posf()).normalize() * 12.0f;
        targetPos = unit->pos() - Position(offset);
      } else if (d <= 9.0f) {
        Vec2 offset = (nearestThreat->posf() - unit->posf()).normalize() * 12.0f;
        if (d < 7.0f) {
          offset.rotateDegrees(direction > 0 ? 100 : -100);
        } else {
          offset.rotateDegrees(direction > 0 ? 80 : -80);
        }

        targetPos = unit->pos() + Position(offset);
      }
    }

    if (utils::distance(unit, targetPos) <= 15.0f) {
      targetPos = Position(adjust(Vec2(targetPos)));
    }

    auto* tile = state->tilesInfo().tryGetTile(targetPos.x, targetPos.y);
    //if (!tile || !tile->entirelyWalkable || tile->building) {
    if (!tile || !tile->entirelyWalkable || tile->building) {
      direction = -direction;
    }

    //VLOG(0) << " scout heading to " << targetPos << " seen " << depots << " depots";

    location = utils::clampPositionToMap(state, targetPos);
    return true;
  }

 private:
  Position scoutPos;
  int direction = 1;
};

/**
 * Task owning a single scouting unit, its destination and its goal.
 *
 * update() refreshes the "visited"/"scouted" flags used by satisfiesGoal(),
 * fails the task once the unit is gone, and delegates movement to
 * SneakyOverlordImpl / WorkerScoutImpl for those goals (falling back to
 * FindEnemyBase if they report failure).
 */
class ScoutingTask : public Task {
 public:
  ScoutingTask(int upcId, Unit* unit, Position location, ScoutingGoal goal)
      : Task(upcId, {unit}), location_(location), goal_(goal) {
    setStatus(TaskStatus::Ongoing);
  }

  void update(State* state) override {
    if (finished()) {
      return;
    }
    auto loc = location();
    auto& tgtArea = state->areaInfo().getArea(loc);
    targetVisited_ =
        (tgtArea.isEnemyBase || !tgtArea.isPossibleEnemyStartLocation ||
         foundBlockedChoke(state));
    if (!proxiedUnits().empty()) { // debug: keep track of reallocations
      auto unit = *proxiedUnits().begin();
      auto taskData = state->board()->taskDataWithUnit(unit);
      if (!taskData.task && !unit->dead) {
        VLOG(3) << "scout " << utils::unitString(unit)
                << " reassigned to no task";
      }
      if (taskData.task && taskData.task.get() != this) {
        VLOG(2) << "scout " << utils::unitString(unit)
                << " reassigned to module " << taskData.owner->name();
      }
      if (unit->dead) {
        VLOG(3) << "scout " << utils::unitString(unit) << " dead ";
      }
    }

    // Now check the failure case. If all my units died
    // then this task failed
    removeDeadOrReassignedUnits(state);
    if (units().empty()) {
      setStatus(TaskStatus::Failure);
      return;
    }
    auto unit = *units().begin();
    for (auto bldg : state->unitsInfo().visibleEnemyUnits()) {
      if (bldg->type->isBuilding &&
          utils::distance(bldg->x, bldg->y, unit->x, unit->y) <=
              unit->sightRange) {
        targetScouted_ = true;
        break;
      }
    }
    if (goal_ == ScoutingGoal::FindEnemyBase &&
        state->areaInfo().foundEnemyStartLocation()) {
      targetScouted_ = true;
    }

    if (goal_ == ScoutingGoal::SneakyOverlord) {
      if (!sneakyOverlordImpl_) {
        sneakyOverlordImpl_ = std::make_unique<SneakyOverlordImpl>(state);
      }
      if (!sneakyOverlordImpl_->update(state, unit, location_)) {
        goal_ = ScoutingGoal::FindEnemyBase;
      }
    }

    if (goal_ == ScoutingGoal::WorkerScout) {
      if (!workerScoutImpl_) {
        workerScoutImpl_ = std::make_unique<WorkerScoutImpl>(state);
      }
      if (!workerScoutImpl_->update(state, unit, location_, target_)) {
        goal_ = ScoutingGoal::FindEnemyBase;
      }
    }
  }

  /// Returns the scouting unit, or nullptr (after logging) if the task has
  /// no unit left. Previously the empty case dereferenced an empty set.
  Unit* getUnit() {
    if (units().empty()) {
      LOG(DFATAL) << "getting the unit of a task without unit";
      return nullptr;
    }
    return *units().begin();
  }

  Position location() const {
    return location_;
  }

  Unit* target() const {
    return target_;
  }

  ScoutingGoal goal() const {
    return goal_;
  }

  bool satisfiesGoal() {
    switch (goal_) {
      case ScoutingGoal::ExploreEnemyBase:
        return targetScouted_;
      case ScoutingGoal::FindEnemyBase:
      case ScoutingGoal::FindEnemyExpand:
        return targetVisited_;
      case ScoutingGoal::SneakyOverlord:
      case ScoutingGoal::WorkerScout:
        return false;
      case ScoutingGoal::Automatic:
        LOG(ERROR) << "invalid goal specification "
                   << "in check that the goal is satisfied for scouting task "
                   << upcId();
        setStatus(TaskStatus::Failure);
        return true;
    }
    // cannot be reached -- avoid warning
    return true;
  }

  /// Points the task at a new destination and clears the visited/scouted
  /// flags.
  void resetLocation(Position const& pos) {
    if (pos == location_) {
      VLOG(0) << "reseting a scouting task with the same location";
    }
    location_ = pos;
    targetVisited_ = false;
    targetScouted_ = false;
  }

  bool foundBlockedChoke(State* state) {
    // the chokepoint is considered "blocked" if there is
    // a chokepoint of the target area for which some units attack us
    // what to do if attacked at some other location ? we need a bit of micro
    // for that
    if (units().empty()) {
      LOG(ERROR) << "updating a finished scouting task";
      return false;
    }
    auto unit = *units().begin();
    if (unit->beingAttackedByEnemies.empty()) {
      // if we're not attacked, it's not blocked
      return false;
    }

    // Heuristic value
    const int distanceFromChokePoint = 42;

    auto tgt = location();
    auto targetArea =
        state->map()->GetNearestArea(BWAPI::WalkPosition(tgt.x, tgt.y));
    if (!targetArea) {
      return false;
    }
    for (auto other : unit->beingAttackedByEnemies) {
      for (auto& choke : targetArea->ChokePoints()) {
        auto cwtp = choke->Center();
        if (utils::distance(cwtp.x, cwtp.y, other->x, other->y) <
            distanceFromChokePoint) {
          return true;
        }
      }
    }
    return false;
  }

 protected:
  Position location_;
  Unit* target_ = nullptr;
  ScoutingGoal goal_;
  bool targetVisited_ = false;
  bool targetScouted_ = false;
  std::unique_ptr<SneakyOverlordImpl> sneakyOverlordImpl_;
  std::unique_ptr<WorkerScoutImpl> workerScoutImpl_;
};

} // namespace

void ScoutingModule::setScoutingGoal(ScoutingGoal goal) {
  // Any value other than ScoutingGoal::Automatic is returned verbatim by
  // goal(); Automatic re-enables the state-based choice there.
  this->scoutingGoal_ = goal;
}

ScoutingGoal ScoutingModule::goal(State* state) const {
  // An explicitly configured goal always takes precedence.
  if (scoutingGoal_ != ScoutingGoal::Automatic) {
    return scoutingGoal_;
  }
  // Automatic: find the enemy start location first...
  if (!state->areaInfo().foundEnemyStartLocation()) {
    return ScoutingGoal::FindEnemyBase;
  }
  // ...then explore it while areaInfo reports no enemy bases, and look for
  // expansions afterwards.
  return state->areaInfo().numEnemyBases() == 0
      ? ScoutingGoal::ExploreEnemyBase
      : ScoutingGoal::FindEnemyExpand;
}

// One frame of the scouting module:
//  1. refresh the tracked candidate enemy start locations;
//  2. update our running scouting tasks:
//     - FindEnemyBase scouts whose location has been scouted are re-targeted
//       to the next candidate;
//     - other satisfied tasks are marked as successful;
//     - all remaining tasks keep re-issuing their Delete/Move command;
//  3. turn incoming sharp Scout UPCs into new scouting tasks;
//  4. send the scouts of finished tasks back to our start location and
//     remove those tasks from the board.
void ScoutingModule::step(State* state) {
  updateLocations(
      state,
      startingLocations_,
      state->areaInfo().candidateEnemyStartLocations());

  auto board = state->board();
  // nextScoutingLocation() yields Position(-1, -1) when it has no candidate.
  // That sentinel must never be recorded in startingLocations_. Once there,
  // updateLocations() would neither re-initialize the map (it only does so
  // when the map is empty) nor clean it up (it skips maps with fewer than two
  // entries).
  auto isValidLocation = [](Position const& p) {
    return p.x >= 0 && p.y >= 0;
  };

  // clean up tasks
  for (auto task : board->tasksOfModule(this)) {
    if (!task->finished()) {
      task->update(state); // check re-assignment at this step
    }
    if (task->finished()) {
      continue;
    }
    auto stask = std::static_pointer_cast<ScoutingTask>(task);
    auto unit = stask->getUnit();
    if (stask->satisfiesGoal()) {
      VLOG(0) << " task satisfies goal";
      if (stask->goal() == ScoutingGoal::FindEnemyBase &&
          goal(state) == ScoutingGoal::FindEnemyBase) {
        // This location has been scouted but the enemy base is still unknown:
        // move on to the next candidate start location.
        auto pos = stask->location();
        auto tgt = nextScoutingLocation(state, unit, startingLocations_);
        if (tgt == pos) {
          VLOG(0) << "reseting scouting task with same location with "
                  << startingLocations_.size() << " candidate locations."
                  << " Do we know the enemy start location (check areaInfo)? "
                  << state->areaInfo().foundEnemyStartLocation()
                  << " current scouting goal " << (int)goal(state);
        }
        stask->resetLocation(tgt);
        if (postMoveUPC(state, stask->upcId(), unit, tgt)) {
          VLOG(3) << "starting location " << pos.x << ", " << pos.y
                  << " visited"
                  << " sending scout " << utils::unitString(unit)
                  << " to next location: " << tgt.x << ", " << tgt.y;
          if (isValidLocation(tgt)) {
            startingLocations_[tgt] = state->currentFrame();
          }
        } else {
          // what to do here
          VLOG(0) << "move to location " << tgt.x << ", " << tgt.y
                  << " for scout " << utils::unitString(unit)
                  << " filtered by the blackboard, canceling task "
                  << stask->upcId();
          stask->cancel(state);
        }
      } else { // no need to update on explore
        stask->setStatus(TaskStatus::Success);
        VLOG(3) << "scouting task " << stask->upcId()
                << " marked as succeedded";
      }
    } else {
      // Goal not reached yet: keep issuing the task's command. Sneaky
      // overlords and worker scouts do not use the safe-move filter.
      if (stask->target()) {
        postDeleteUPC(state, stask->upcId(), unit, stask->target());
      } else {
        postMoveUPC(
            state,
            stask->upcId(),
            unit,
            stask->location(),
            stask->goal() != ScoutingGoal::SneakyOverlord &&
                stask->goal() != ScoutingGoal::WorkerScout);
      }
    }
  }

  // consume UPCs
  // all UPCs at a given time will be set using the current module's goal
  // since the UPC does not directly allow for goal specification
  for (auto upcPair : board->upcsWithSharpCommand(Command::Scout)) {
    if (upcPair.second->unit.empty()) {
      LOG(ERROR) << "main scouting UPC without unit specification -- consuming "
                    "but ignoring";
      board->consumeUPC(upcPair.first, this);
      continue;
    }
    Unit* unit = nullptr;
    switch (goal(state)) {
      case ScoutingGoal::FindEnemyBase:
        unit = findUnit(state, upcPair.second->unit, Position(-1, -1));
        break;
      case ScoutingGoal::FindEnemyExpand:
      case ScoutingGoal::ExploreEnemyBase:
        if (startingLocations_.size() != 1) {
          LOG(ERROR)
              << "invalid scouting goal (ExploreEnemyBase/FindEnemyExpand) "
              << " because no enemy location";
          break;
        }
        unit = findUnit(
            state, upcPair.second->unit, startingLocations_.begin()->first);
        break;
      case ScoutingGoal::SneakyOverlord:
      case ScoutingGoal::WorkerScout:
      case ScoutingGoal::Automatic:
        LOG(ERROR) << "invalid goal";
        break;
    }
    if (!unit) {
      VLOG(3) << "could not find scout for upc " << upcPair.first
              << " -- skipping for now; "
              << "number of units of required type: "
              << state->unitsInfo()
                     .myCompletedUnitsOfType(buildtypes::Zerg_Drone)
                     .size();
      continue;
    }
    board->consumeUPC(upcPair.first, this);
    // The unit type can override the module-wide goal.
    auto taskGoal = goal(state);
    if (unit->type == buildtypes::Zerg_Overlord) {
      taskGoal = ScoutingGoal::SneakyOverlord;
    }
    if (unit->type->isWorker) {
      taskGoal = ScoutingGoal::WorkerScout;
    }
    auto tgt = nextScoutingLocation(state, unit, startingLocations_);
    if (postTask(state, upcPair.first, unit, tgt, taskGoal) &&
        isValidLocation(tgt)) {
      startingLocations_[tgt] = state->currentFrame();
    }
  }

  // clean up finished tasks: send the scouts back to base
  auto myLoc = state->areaInfo().myStartLocation();
  for (auto task : board->tasksOfModule(this)) {
    if (!task->finished()) {
      continue;
    }
    if (!task->proxiedUnits().empty()) {
      auto stask = std::static_pointer_cast<ScoutingTask>(task);
      auto unit = stask->getUnit();
      auto chkTask = board->taskWithUnit(unit);
      // scout not reallocated, send it back
      if (chkTask == task) {
        VLOG(3) << "sending scout " << utils::unitString(unit)
                << " back to base";
        postMoveUPC(state, stask->upcId(), unit, myLoc);
      }
    }
    // manual removal because the status might have changed during the step
    board->markTaskForRemoval(task->upcId());
  }
}

// Picks the unit that should perform a new scouting task.
//
// Every unit of ours is scored by `scoreOf` below, and
// utils::getBestScoreCopy() returns the best one. kdInfty marks a unit that
// must not be picked. Judging from the tiers, lower scores are preferred
// (NOTE(review): confirm that getBestScoreCopy minimizes).
//
// Tiers:
//  - not in `candidates`, or with a non-positive UPC weight: excluded;
//  - a unit owned by one of our tasks: excluded while that task runs.
//    Once the task is finished, the score is -2 * (map width * map height)
//    plus the BWEM path length to `pos` (only computed when `pos` has
//    positive coordinates);
//  - inactive units: -200;
//  - units whose current task succeeded: -100;
//  - busy units whose first order is MoveToMinerals / MoveToGas: 15 / 50;
//  - any other unit: 100.
Unit* ScoutingModule::findUnit(
    State* state,
    std::unordered_map<Unit*, float> const& candidates,
    Position const& pos) {
  auto board = state->board();
  auto map = state->map();
  auto mapArea = state->mapWidth() * state->mapHeight();
  bool const havePos = pos.x > 0 && pos.y > 0;

  auto scoreOf = [&](Unit* u) -> double {
    auto candIt = candidates.find(u);
    bool const excluded = candIt == candidates.end() || candIt->second <= 0;
    if (excluded) {
      return kdInfty;
    }

    auto taskData = board->taskDataWithUnit(u);
    if (taskData.owner == this) {
      if (!taskData.task->finished()) {
        // Already scouting for us.
        return kdInfty;
      }
      // Former scout of ours that is free again: strongly preferred, closer
      // ones first.
      int pathLength = 0;
      if (havePos) {
        map->GetPath(
            BWAPI::Position(BWAPI::WalkPosition(u->x, u->y)),
            BWAPI::Position(BWAPI::WalkPosition(pos.x, pos.y)),
            &pathLength);
      }
      return -2 * mapArea + pathLength;
    }

    if (!u->active()) {
      return -200;
    }
    bool const justSucceeded =
        taskData.task && taskData.task->status() == TaskStatus::Success;
    if (justSucceeded) {
      return -100;
    }

    // Rather wait for a worker than pull one that is heading to resources.
    if (!u->idle() && !u->unit.orders.empty()) {
      auto firstOrder = u->unit.orders.front().type;
      if (firstOrder == tc::BW::Order::MoveToMinerals) {
        return 15.0;
      }
      if (firstOrder == tc::BW::Order::MoveToGas) {
        return 50.0;
      }
    }
    return 100;
  };

  return utils::getBestScoreCopy(
      state->unitsInfo().myUnits(), scoreOf, kdInfty);
}

// Creates a scouting task for `unit` heading to `loc` with goal `goal`, as a
// child of the UPC `baseUpcId`.
// The initial Move UPC is posted first. If the blackboard rejects it, no task
// is created and false is returned.
// The task is posted without auto-removal: step() marks finished tasks for
// removal itself.
bool ScoutingModule::postTask(
    State* state,
    UpcId baseUpcId,
    Unit* unit,
    Position loc,
    ScoutingGoal goal) {
  if (!postMoveUPC(state, baseUpcId, unit, loc)) {
    VLOG(1) << "task for unit " << utils::unitString(unit) << " not created";
    return false;
  }
  auto newTask = std::make_shared<ScoutingTask>(baseUpcId, unit, loc, goal);
  state->board()->postTask(newTask, this, false); // no auto-removal
  VLOG(1) << "new scouting task " << baseUpcId << " with unit "
          << utils::unitString(unit) << " for location " << loc.x << ", "
          << loc.y;
  return true;
}


// Posts a sharp Command::Delete UPC from `unit` against `target`, as a child
// of `baseUpcId`.
// Returns false if the blackboard filtered the UPC (negative id) and true
// otherwise. This is the same convention as postMoveUPC().
bool ScoutingModule::postDeleteUPC(
    State* state,
    UpcId baseUpcId,
    Unit* unit,
    Unit* target) {
  auto upc = utils::makeSharpUPC(unit, target, Command::Delete);
  auto upcId = state->board()->postUPC(std::move(upc), baseUpcId, this);
  if (upcId < 0) {
    VLOG(1) << "DeleteUPC for unit " << utils::unitString(unit)
            << " filtered by blackboard";
    return false;
  }
  return true;
}

// Posts a Move UPC sending `unit` to `loc`, as a child of `baseUpcId`.
// When `useSafeMove` is set, the destination is first adjusted by
// movefilters::safeMoveTo(). A destination with a non-positive coordinate
// only triggers a warning.
// Returns true without posting anything if the unit's current moving target
// is already within distance 4 of the destination. Otherwise returns whether
// the blackboard accepted the UPC (a negative id means it was filtered).
bool ScoutingModule::postMoveUPC(
    State* state,
    UpcId baseUpcId,
    Unit* unit,
    const Position& loc,
    bool useSafeMove) {
  auto dest = useSafeMove ? movefilters::safeMoveTo(state, unit, loc) : loc;
  bool const offMap = dest.x <= 0 || dest.y <= 0;
  if (offMap) {
    LOG(WARNING) << "scout stuck";
  }

  bool const alreadyHeadingThere =
      dest.distanceTo(unit->getMovingTarget()) <= 4;
  if (alreadyHeadingThere) {
    return true;
  }

  auto moveUpc = std::make_shared<UPCTuple>();
  moveUpc->unit[unit] = 1;
  moveUpc->command[Command::Move] = 1;
  moveUpc->position = dest;

  auto postedId = state->board()->postUPC(std::move(moveUpc), baseUpcId, this);
  bool const accepted = postedId >= 0;
  if (!accepted) {
    VLOG(1) << "MoveUPC for unit " << utils::unitString(unit)
            << " filtered by blackboard";
  }
  return accepted;
}

// Chooses the location `unit` should scout next among `locations`. The map
// associates each candidate position with the frame at which it was last
// assigned to a scout (-1 for entries initialized by updateLocations()).
//
// The least recently assigned location wins. Ties on the frame go to the
// smallest unit-dependent cost; on an exact tie the first entry seen is kept.
// Cost by unit type:
//  - flying units: straight-line distance (overlords go to nearby bases);
//  - drones: negated walk-path length, so they head for far-away bases;
//  - other ground units: walk-path length.
// Returns Position(-1, -1) if `locations` is empty.
Position ScoutingModule::nextScoutingLocation(
    State* state,
    Unit* unit,
    std::unordered_map<Position, int> const& locations) {
  Position origin = unit->pos();

  auto costTo = [&](Position dest) -> float {
    if (unit->flying()) {
      return utils::distance(origin, dest);
    }
    if (unit->type == buildtypes::Zerg_Drone) {
      float pathLength = 0.0f;
      state->areaInfo().walkPath(origin, dest, &pathLength);
      return -pathLength;
    }
    return state->areaInfo().walkPathLength(origin, dest);
  };

  Position best(-1, -1);
  int bestFrame = std::numeric_limits<int>::max();
  float bestCost = kfInfty;
  for (auto const& entry : locations) {
    Position candidate = entry.first;
    int const assignedAt = entry.second;
    float const cost = costTo(candidate);
    bool const lessRecent = assignedAt < bestFrame;
    bool const closerOnTie = assignedAt == bestFrame && cost < bestCost;
    if (lessRecent || closerOnTie) {
      best = candidate;
      bestFrame = assignedAt;
      bestCost = cost;
    }
  }
  return best;
}

// Keeps `locations` (candidate position -> frame of its last assignment to a
// scout) in sync with the current `candidates`:
//  - an empty map is (re)initialized from `candidates` with frame -1;
//  - with fewer than two entries nothing else is done, so a single remaining
//    location is kept as-is;
//  - otherwise, locations targeted by one of this module's tasks are stamped
//    with the current frame, and entries no longer in `candidates` are
//    dropped.
void ScoutingModule::updateLocations(
    State* state,
    std::unordered_map<Position, int>& locations,
    std::vector<Position> const& candidates) {
  if (locations.empty()) { // initialization
    for (auto const& candidate : candidates) {
      locations.emplace(candidate, -1);
    }
  }
  if (locations.size() < 2) {
    return;
  }

  auto const now = state->currentFrame();
  for (auto const& task : state->board()->tasksOfModule(this)) {
    auto scoutTask = std::static_pointer_cast<ScoutingTask>(task);
    auto found = locations.find(scoutTask->location());
    if (found != locations.end()) {
      found->second = now;
    }
  }

  // clean up startingLocations
  auto const isCandidate = [&candidates](Position const& p) {
    return std::find(candidates.begin(), candidates.end(), p) !=
        candidates.end();
  };
  auto it = locations.begin();
  while (it != locations.end()) {
    if (isCandidate(it->first)) {
      ++it;
    } else {
      it = locations.erase(it);
    }
  }
}

} // namespace cherrypi
