/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "meanlingrush.h"

#include "combatsim.h"
#include "movefilters.h"
#include "player.h"
#include "state.h"
#include "task.h"
#include "utils.h"

#include <bwem/map.h>

#include <deque>
#include <glog/logging.h>
#include <memory>
#include <vector>

namespace fairrsh {

// Registers MeanLingRushModule with RTTR under the name "MeanLingRushModule".
// The registration attaches the class's rttr::type as "type" metadata and
// exposes a no-argument constructor.
RTTR_REGISTRATION {
  rttr::registration::class_<MeanLingRushModule>("MeanLingRushModule")(
      metadata("type", rttr::type::get<MeanLingRushModule>()))
      .constructor();
}

namespace {
class MeanLingRushTask : public Task {
 public:
  MeanLingRushTask(int upcId) : Task(upcId) {}

  virtual void update(State* state) override {
    removeDeadOrReassignedUnits(state);
  }
};
}

/**
 * Issues this frame's orders for every combat unit controlled by the module.
 *
 * Steps, in order:
 *  1. Refresh the fallback attack position (stored in `lastTarget`) from
 *     starting locations and known enemy buildings.
 *  2. Find or create the module's task and assign it every own unit that is
 *     not a worker, an Overlord, a building or a non-usable type.
 *  3. Flood-fill "nearest target" maps around enemy units and buildings.
 *  4. For each unit (closest to an enemy first): attack an in-range target,
 *     else go after the nearest mapped target, else take a waiting spot near
 *     an enemy building, else move to the fallback position.
 *
 * @param state    game state; read for map/unit information and used to post
 *                 commands and UPCs.
 * @param srcUpcId id of the UPC that triggered this call; used as the task's
 *                 UPC id and as the source of every UPC posted here.
 */
void MeanLingRushModule::process(State* state, int srcUpcId) {
  // ---------------------------------------------------------------------
  // 1. Fallback attack position
  // ---------------------------------------------------------------------

  // Start from the position chosen by the previous call.
  Position targetPos = lastTarget;

  // If that position is visible and holds no enemy building it is stale:
  // pick a replacement among the map's starting locations.
  const Tile* targetTile =
      &state->tilesInfo().getTile(targetPos.x, targetPos.y);
  if (targetTile->visible &&
      (!targetTile->building || !targetTile->building->isEnemy)) {
    targetPos = Position();
    // First choice: a starting location with an enemy building on its tile.
    for (auto tilePos : state->map()->StartingLocations()) {
      Position pos(
          tilePos.x * tc::BW::XYWalktilesPerBuildtile,
          tilePos.y * tc::BW::XYWalktilesPerBuildtile);
      auto& tile = state->tilesInfo().getTile(pos.x, pos.y);
      if (tile.building && tile.building->isEnemy) {
        targetPos = pos;
        break;
      }
    }
    // Second choice: the first starting location whose tile has
    // lastSeen == 0 (presumably never scouted -- confirm against TilesInfo).
    if (targetPos == Position()) {
      for (auto tilePos : state->map()->StartingLocations()) {
        Position pos(
            tilePos.x * tc::BW::XYWalktilesPerBuildtile,
            tilePos.y * tc::BW::XYWalktilesPerBuildtile);
        auto& tile = state->tilesInfo().getTile(pos.x, pos.y);
        if (tile.lastSeen == 0) {
          targetPos = pos;
          break;
        }
      }
    }
  }

  // Any known, grounded enemy building overrides the choice above.
  for (Unit* u : state->unitsInfo().enemyUnits()) {
    if (!u->gone && u->type->isBuilding && !u->flying()) {
      targetPos = Position(u);
      break;
    }
  }

  // Among starting locations with lastSeen == 0, prefer the one closest to
  // any known grounded enemy building. This overrides the building position
  // picked just above whenever such a location exists.
  {
    float bestDistance = std::numeric_limits<float>::infinity();
    for (auto tilePos : state->map()->StartingLocations()) {
      Position pos(
          tilePos.x * tc::BW::XYWalktilesPerBuildtile,
          tilePos.y * tc::BW::XYWalktilesPerBuildtile);
      auto& tile = state->tilesInfo().getTile(pos.x, pos.y);
      if (tile.lastSeen == 0) {
        for (Unit* u : state->unitsInfo().enemyUnits()) {
          if (!u->gone && u->type->isBuilding && !u->flying()) {
            float distance = utils::distance(Position(u), pos);
            if (distance < bestDistance) {
              bestDistance = distance;
              targetPos = pos;
            }
          }
        }
      }
    }
  }

  // A known grounded enemy resource depot has the final say. There is no
  // break, so the last matching depot in iteration order wins.
  for (Unit* u : state->unitsInfo().enemyUnits()) {
    if (!u->gone && !u->flying() && u->type->isResourceDepot) {
      targetPos = Position(u);
    }
  }

  lastTarget = targetPos;

  // ---------------------------------------------------------------------
  // 2. Task and unit assignment
  // ---------------------------------------------------------------------

  std::vector<Unit*> units;

  // Reuse the first task already posted by this module, or post a new one.
  std::shared_ptr<MeanLingRushTask> task;
  for (auto& t : state->board()->tasksOfModule(this)) {
    task = std::static_pointer_cast<MeanLingRushTask>(t);
    break;
  }
  if (!task) {
    task = std::make_shared<MeanLingRushTask>(srcUpcId);
    state->board()->postTask(task, this, true);
  }

  for (Unit* u : state->unitsInfo().myUnits()) {
    if (!u->type->isWorker && u->type != buildtypes::Zerg_Overlord &&
        !u->type->isBuilding && !u->type->isNonUsable) {
      units.push_back(u);
    }
  }
  {
    std::unordered_set<Unit*> unitSet(units.begin(), units.end());
    task->setUnits(std::move(unitSet));
    state->board()->updateTasksByUnit(task.get());
  }

  // ---------------------------------------------------------------------
  // Order helpers
  // ---------------------------------------------------------------------

  // Sends `u` towards `target` through a sharp Move UPC. Does nothing while
  // the unit is inside its order cooldown or already has the same move order.
  // A burrowed unit is only told to unburrow, never to move.
  auto move = [&](Unit* u, Position target) {
    if (noOrderUntil[u] > state->currentFrame()) {
      return;
    }
    if (u->burrowed()) {
      // Unburrow once no target has been in range for more than 30 frames.
      // lastTargetInRange holds the frame of the last in-range target, so
      // the elapsed time is currentFrame - lastTargetInRange.
      if (state->currentFrame() - lastTargetInRange[u] > 30) {
        state->board()->postCommand(tc::Client::Command(
            tc::BW::Command::CommandUnit,
            u->id,
            tc::BW::UnitCommandType::Unburrow));
        noOrderUntil[u] = state->currentFrame() + 15;
      }
      return;
    }
    if (!u->unit.orders.empty()) {
      auto& o = u->unit.orders.front();
      if (o.type == tc::BW::Order::Move && o.targetX == target.x &&
          o.targetY == target.y) {
        return;
      }
    }
    if (VLOG_IS_ON(2)) {
      utils::drawLine(state, Position(u), target, tc::BW::Color::Green);
    }
    state->board()->postUPC(
        utils::makeSharpUPC(u, target, Command::Move), srcUpcId, this);
  };

  // Makes `u` engage `target`. Lurkers burrow/unburrow instead of attacking
  // directly; fast fleeing workers and Vultures are intercepted with a move
  // order; everything else receives an Attack_Unit command.
  auto attack = [&](Unit* u, Unit* target) {
    if (noOrderUntil[u] > state->currentFrame()) {
      return;
    }

    // Returns true when the straight line from the target to `newPos`,
    // sampled every 4 units, stays on entirely walkable tiles and never
    // comes within 8 of `u`.
    auto enemyCanMoveTo = [&](Vec2 newPos) {
      int n = static_cast<int>(
          utils::distance(Position(target), Position(newPos)) / 4.0f);
      Vec2 step = (newPos - Vec2(target)).normalize() * 4;
      Vec2 pos = Vec2(target);
      for (int i = 0; i != n; ++i) {
        if (utils::distance(pos, Position(u)) < 8) {
          return false;
        }
        const Tile* tile = state->tilesInfo().tryGetTile(
            static_cast<int>(pos.x), static_cast<int>(pos.y));
        if (!tile || !tile->entirelyWalkable) {
          return false;
        }
        pos += step;
      }
      return true;
    };

    if (u->type == buildtypes::Zerg_Lurker) {
      if (target->inRangeOf(u)) {
        // Target in range: make sure we are burrowed, then issue no further
        // order from here.
        if (!u->burrowed()) {
          state->board()->postCommand(tc::Client::Command(
              tc::BW::Command::CommandUnit,
              u->id,
              tc::BW::UnitCommandType::Burrow));
          noOrderUntil[u] = state->currentFrame() + 15;
        }
        return;
      }
      if (u->burrowed()) {
        // Target out of range while burrowed: unburrow once nothing has been
        // in range for more than 30 frames.
        if (state->currentFrame() - lastTargetInRange[u] > 30) {
          state->board()->postCommand(tc::Client::Command(
              tc::BW::Command::CommandUnit,
              u->id,
              tc::BW::UnitCommandType::Unburrow));
          noOrderUntil[u] = state->currentFrame() + 15;
        }
        return;
      }
    }

    // Interception: a visible, moving worker or Vulture whose top speed is at
    // least 66% of ours and which is not within range (+3) of us.
    if (target->visible &&
        (target->type->isWorker ||
         target->type == buildtypes::Terran_Vulture)) {
      if (target->topSpeed >= u->topSpeed * 0.66f && target->moving() &&
          !target->inRangeOf(u, 3)) {
        const int latency = 3;
        // Air weapon range against flyers, ground weapon range otherwise.
        float weaponRange =
            target->flying() ? u->unit.airRange : u->unit.groundRange;
        auto targetVelocity =
            Vec2(target->unit.velocityX, target->unit.velocityY);
        // Extrapolate both units `latency` frames ahead along their velocity.
        auto targetNextPos = Vec2(target) + targetVelocity * latency;
        auto myNextPos =
            Vec2(u) + Vec2(u->unit.velocityX, u->unit.velocityY) * latency;
        float predictedDistance = std::min(
            utils::distanceBB(u, myNextPos, target, targetNextPos),
            utils::distanceBB(u, Position(u), target, targetNextPos));
        if (predictedDistance > weaponRange) {
          float currentDistance =
              utils::distance(u->x, u->y, target->x, target->y);
          if (utils::distance(u->x, u->y, targetNextPos.x, targetNextPos.y) >
              currentDistance) {
            // Target is moving away from us: run in its direction of travel.
            auto np = Vec2(u) + targetVelocity.normalize() * 16;
            if (enemyCanMoveTo(np)) {
              move(u, Position(np));
              return;
            }
          } else {
            // Target is getting closer: aim at a point ahead of it, first
            // with a lead of up to 20, then with a shorter lead of up to 12.
            auto np = Vec2(target) +
                targetVelocity.normalize() *
                    std::min(std::max(currentDistance - 4.0f, 4.0f), 20.0f);
            if (enemyCanMoveTo(np)) {
              move(u, Position(np));
              return;
            }
            auto npShort = Vec2(target) +
                targetVelocity.normalize() *
                    std::min(std::max(currentDistance - 4.0f, 4.0f), 12.0f);
            if (enemyCanMoveTo(npShort)) {
              move(u, Position(npShort));
              return;
            }
          }
        }
      }
    }

    // Do not re-issue an attack order the unit is already executing.
    if (!u->unit.orders.empty()) {
      auto& o = u->unit.orders.front();
      if (o.type == tc::BW::Order::AttackUnit && o.targetId == target->id) {
        return;
      }
    }
    if (VLOG_IS_ON(2)) {
      utils::drawLine(state, Position(u), Position(target), tc::BW::Color::Red);
    }
    auto cmd = tc::Client::Command(
        tc::BW::Command::CommandUnit,
        u->id,
        tc::BW::UnitCommandType::Attack_Unit,
        target->id);
    state->board()->postCommand(cmd);
    noOrderUntil[u] = state->currentFrame() + 6;
  };

  // ---------------------------------------------------------------------
  // 3. Enemy classification and nearest-target maps
  // ---------------------------------------------------------------------

  // targetUnits: grounded enemies that are not buildings, plus buildings
  //   with a ground weapon and Bunkers.
  // targetUnitsNoSingleTargets: the subset of targetUnits that has another
  //   grounded non-building enemy within 4 * 12.
  // targetBuildings: every other grounded enemy building.
  std::vector<Unit*> targetUnits;
  std::vector<Unit*> targetUnitsNoSingleTargets;
  std::vector<Unit*> targetBuildings;

  for (Unit* u : state->unitsInfo().enemyUnits()) {
    if (!u->gone && !u->flying() && !u->type->isNonUsable) {
      if (!u->type->isBuilding || u->type->hasGroundWeapon ||
          u->type == buildtypes::Terran_Bunker) {
        targetUnits.push_back(u);
        for (Unit* u2 : state->unitsInfo().enemyUnits()) {
          if (u != u2 && !u2->gone && !u2->type->isBuilding && !u2->flying() &&
              utils::distance(u, u2) <= 4 * 12) {
            targetUnitsNoSingleTargets.push_back(u);
            break;
          }
        }
      } else {
        targetBuildings.push_back(u);
      }
    }
  }

  // Per-tile result of findTargets: the target whose flood fill claimed the
  // tile first, and a path distance (see the note where it is written).
  struct MapNode {
    Unit* target = nullptr;
    // Unit* target2 = nullptr;
    float distance = 0.0f;
  };

  // Breadth-first flood fill seeded at every unit in `targets`. Each reached
  // tile in `map` records the target that claimed it first. Tiles that are
  // not entirely walkable are never entered; tiles holding a building are
  // skipped unless `ignoreBuildings` is set; tiles whose straight-line
  // distance to the claiming target is >= maxDistance are not entered.
  auto findTargets = [&](
      std::vector<MapNode>& map,
      const std::vector<Unit*>& targets,
      float maxDistance,
      bool ignoreBuildings) {
    map.clear();
    map.resize(TilesInfo::tilesWidth * TilesInfo::tilesHeight);

    auto& tilesInfo = state->tilesInfo();
    auto* tilesData = tilesInfo.tiles.data();

    const int mapWidth = state->mapWidth();
    const int mapHeight = state->mapHeight();

    struct OpenNode {
      const Tile* tile;
      Unit* source;
      float distance;
    };

    std::deque<OpenNode> open;
    auto addTile = [&](Position pos, Unit* source) {
      auto* tile = tilesInfo.tryGetTile(pos.x, pos.y);
      if (tile) {
        open.push_back({tile, source, 0.0f});
        // map.at(tile - tilesData) = {source, nullptr, 0.0f};
        map.at(tile - tilesData) = {source, 0.0f};
      }
    };
    for (Unit* u : targets) {
      if (u->type->isBuilding && !u->lifted()) {
        // Landed buildings are seeded from tiles around their footprint
        // rather than from their centre.
        // NOTE(review): the near edges are offset by -1 tile while the far
        // edges use (size + 1) tiles -- confirm the asymmetry is intended.
        for (int x = u->buildX - tc::BW::XYWalktilesPerBuildtile; x !=
             u->buildX +
                 (u->type->tileWidth + 1) * tc::BW::XYWalktilesPerBuildtile;
             x += tc::BW::XYWalktilesPerBuildtile) {
          addTile(Position(x, u->buildY - tc::BW::XYWalktilesPerBuildtile), u);
          addTile(
              Position(
                  x,
                  u->buildY +
                      (u->type->tileHeight + 1) *
                          tc::BW::XYWalktilesPerBuildtile),
              u);
        }
        for (int y = u->buildY; y !=
             u->buildY + u->type->tileHeight * tc::BW::XYWalktilesPerBuildtile;
             y += tc::BW::XYWalktilesPerBuildtile) {
          addTile(Position(u->buildX - tc::BW::XYWalktilesPerBuildtile, y), u);
          addTile(
              Position(
                  u->buildX +
                      (u->type->tileWidth + 1) *
                          tc::BW::XYWalktilesPerBuildtile,
                  y),
              u);
        }
      } else {
        addTile(Position(u), u);
      }
    }
    while (!open.empty()) {
      OpenNode curNode = open.front();
      open.pop_front();

      auto add = [&](const Tile* ntile) {
        if (!ntile->entirelyWalkable || (!ignoreBuildings && ntile->building)) {
          return;
        }

        float sourceDistance = utils::distance(
            ntile->x, ntile->y, curNode.source->x, curNode.source->y);
        if (sourceDistance >= maxDistance) {
          return;
        }

        auto& v = map[ntile - tilesData];
        if (v.target) {
          // Already claimed by an earlier (closer in BFS order) target.
          return;
        }
        v.target = curNode.source;
        // NOTE(review): this stores the path distance of the tile we came
        // from (curNode), not including the last step to ntile. Consumers
        // compare against tuned thresholds, so it is deliberately unchanged.
        v.distance = curNode.distance;
        open.push_back(
            {ntile,
             curNode.source,
             curNode.distance +
                 utils::distance(
                     curNode.tile->x, curNode.tile->y, ntile->x, ntile->y)});
      };

      const Tile* tile = curNode.tile;

      // Visit the 8 neighbours. Tiles are stored row-major, so one row up or
      // down is an offset of TilesInfo::tilesWidth elements.
      if (tile->x > 0) {
        add(tile - 1);
        if (tile->y > 0) {
          add(tile - 1 - TilesInfo::tilesWidth);
          add(tile - TilesInfo::tilesWidth);
        }
        if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
          add(tile - 1 + TilesInfo::tilesWidth);
          add(tile + TilesInfo::tilesWidth);
        }
      } else {
        if (tile->y > 0) {
          add(tile - TilesInfo::tilesWidth);
        }
        if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
          add(tile + TilesInfo::tilesWidth);
        }
      }
      if (tile->x < mapWidth - tc::BW::XYWalktilesPerBuildtile) {
        add(tile + 1);
        if (tile->y > 0) {
          add(tile + 1 - TilesInfo::tilesWidth);
        }
        if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
          add(tile + 1 + TilesInfo::tilesWidth);
        }
      }
    }
  };

  // Visit stamps shared by all findNearbyTile calls in this invocation: a
  // tile counts as visited in the current search when its stamp equals the
  // search's stamp, which avoids clearing the array between searches.
  std::vector<uint8_t> visited(TilesInfo::tilesWidth * TilesInfo::tilesHeight);
  uint8_t visitedN = 0;

  // Breadth-first search from `source` over entirely walkable, building-free
  // tiles closer than `maxDistance` (straight line) to `source`. Returns the
  // first tile for which `callback(tile)` is true, or nullptr when none is
  // found or `source` is not on a valid tile.
  auto findNearbyTile = [&](
      Position source, float maxDistance, auto&& callback) -> const Tile* {
    const Tile* sourceTile = state->tilesInfo().tryGetTile(source.x, source.y);
    if (!sourceTile) {
      return nullptr;
    }
    if (++visitedN == 0) {
      // The 8-bit stamp wrapped around. Stamp 0 is the value of untouched
      // tiles, so using it would make every tile look visited already.
      // Reset the array and restart the stamps at 1.
      visited.assign(visited.size(), 0);
      visitedN = 1;
    }
    const uint8_t visitedValue = visitedN;

    auto& tilesInfo = state->tilesInfo();
    auto* tilesData = tilesInfo.tiles.data();

    const int mapWidth = state->mapWidth();
    const int mapHeight = state->mapHeight();

    std::deque<const Tile*> open;
    open.push_back(sourceTile);
    visited[sourceTile - tilesData] = visitedValue;
    while (!open.empty()) {
      const Tile* tile = open.front();
      open.pop_front();

      if (tile->entirelyWalkable && !tile->building && callback(tile)) {
        return tile;
      }

      auto add = [&](const Tile* ntile) {
        if (!ntile->entirelyWalkable || ntile->building) {
          return;
        }

        float sourceDistance =
            utils::distance(ntile->x, ntile->y, source.x, source.y);
        if (sourceDistance >= maxDistance) {
          return;
        }

        auto& v = visited[ntile - tilesData];
        if (v == visitedValue) {
          return;
        }
        v = visitedValue;
        open.push_back(ntile);
      };

      // Same 8-neighbour expansion as in findTargets; the row stride is
      // TilesInfo::tilesWidth.
      if (tile->x > 0) {
        add(tile - 1);
        if (tile->y > 0) {
          add(tile - 1 - TilesInfo::tilesWidth);
          add(tile - TilesInfo::tilesWidth);
        }
        if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
          add(tile - 1 + TilesInfo::tilesWidth);
          add(tile + TilesInfo::tilesWidth);
        }
      } else {
        if (tile->y > 0) {
          add(tile - TilesInfo::tilesWidth);
        }
        if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
          add(tile + TilesInfo::tilesWidth);
        }
      }
      if (tile->x < mapWidth - tc::BW::XYWalktilesPerBuildtile) {
        add(tile + 1);
        if (tile->y > 0) {
          add(tile + 1 - TilesInfo::tilesWidth);
        }
        if (tile->y < mapHeight - tc::BW::XYWalktilesPerBuildtile) {
          add(tile + 1 + TilesInfo::tilesWidth);
        }
      }
    }
    return nullptr;
  };

  std::vector<MapNode> nearestTarget;
  std::vector<MapNode> nearestTargetNoSingleTargets;
  std::vector<MapNode> nearestTargetIgnoreBuildings;
  std::vector<MapNode> nearestBuilding;

  findTargets(nearestTarget, targetUnits, 4 * 20, false);
  findTargets(
      nearestTargetNoSingleTargets, targetUnitsNoSingleTargets, 4 * 20, false);
  findTargets(nearestTargetIgnoreBuildings, targetUnits, 4 * 20, true);
  findTargets(nearestBuilding, targetBuildings, 4 * 20, false);

  auto* tilesData = state->tilesInfo().tiles.data();

  // ---------------------------------------------------------------------
  // 4. Per-unit decisions
  // ---------------------------------------------------------------------

  // Process units in order of increasing distance to the nearest enemy they
  // have a weapon against, so the closest units pick their targets first.
  std::vector<std::pair<float, Unit*>> sortedUnits;

  for (Unit* u : units) {
    float nearestDistance = std::numeric_limits<float>::infinity();
    for (Unit* e : targetUnits) {
      if (e->flying() ? u->type->hasAirWeapon : u->type->hasGroundWeapon) {
        float d = utils::distance(u->x, u->y, e->x, e->y);
        if (d < nearestDistance) {
          nearestDistance = d;
        }
      }
    }
    sortedUnits.emplace_back(nearestDistance, u);
  }
  std::sort(sortedUnits.begin(), sortedUnits.end());

  // Number of our units already sent after each target this frame.
  std::unordered_map<Unit*, int> targetCount;

  // Number of our units already sent to each waiting spot this frame.
  std::vector<uint8_t> spotTaken(
      TilesInfo::tilesWidth * TilesInfo::tilesHeight);

  // (attacker, target) pairs resolved after the loop.
  std::vector<std::pair<Unit*, Unit*>> attacks;

  for (auto& v : sortedUnits) {
    Unit* u = v.second;

    // Lowest health + shield target currently in range of this unit, if any.
    Unit* inRangeTarget = utils::getBestScoreCopy(
        targetUnits,
        [&](Unit* e) {
          if (!e->inRangeOf(u, 0)) {
            return std::numeric_limits<double>::infinity();
          }
          return static_cast<double>(e->unit.health + e->unit.shield);
        },
        std::numeric_limits<double>::infinity());
    if (inRangeTarget) {
      lastTargetInRange[u] = state->currentFrame();
      attack(u, inRangeTarget);
      continue;
    }

    const Tile* tile = state->tilesInfo().tryGetTile(u->x, u->y);
    if (tile) {
      Unit* target = nearestTarget[tile - tilesData].target;
      // Lurkers only keep a mapped target when one of the two is within
      // range of the other.
      if (u->type == buildtypes::Zerg_Lurker && target) {
        if (!u->inRangeOf(target) && !target->inRangeOf(u)) {
          target = nullptr;
        }
      }
      if (target) {
        // At most 6 units per target; later units fall back to the map that
        // excludes isolated targets (which may yield no target at all).
        int& c = targetCount[target];
        if (c >= 6) {
          target = nearestTargetNoSingleTargets[tile - tilesData].target;
        } else {
          ++c;
        }
      }
      if (target) {
        attacks.emplace_back(u, target);
        continue;
      } else {
        target = nearestTargetIgnoreBuildings[tile - tilesData].target;
        if (target) {
          // A target is only reachable when buildings are ignored. Non-Lurker
          // units with a mapped enemy building take a waiting spot whose
          // building-map distance lies in [4 * 7, 4 * 9), preferring the
          // least-used of the first 16 candidates found.
          auto& b = nearestBuilding[tile - tilesData];
          if (b.target && u->type != buildtypes::Zerg_Lurker) {
            int nSpots = 0;
            const Tile* bestTile = nullptr;
            int bestN = std::numeric_limits<int>::max();
            findNearbyTile(Position(u), 4 * 20, [&](const Tile* ntile) {
              if (nearestBuilding[ntile - tilesData].distance >= 4 * 7 &&
                  nearestBuilding[ntile - tilesData].distance < 4 * 9) {
                ++nSpots;
                int n = spotTaken[ntile - tilesData];
                if (n < bestN) {
                  bestN = n;
                  bestTile = ntile;
                }
              }
              // Stop the search after 16 candidate spots.
              return nSpots >= 16;
            });
            if (bestTile) {
              move(u, Position(bestTile));
              ++spotTaken[bestTile - tilesData];
              continue;
            }
          }
        } else {
          // No unit target at all: go for the nearest mapped building.
          target = nearestBuilding[tile - tilesData].target;
          if (target) {
            attacks.emplace_back(u, target);
            continue;
          }
        }
      }
    }

    // Nothing to do locally: head for the fallback position, if there is one.
    if (targetPos != Position()) {
      move(u, targetPos);
    }
  }

  // Visible targets are attacked; targets we cannot currently see are
  // approached by moving to their last known position.
  for (auto& v : attacks) {
    Unit* u = v.first;
    Unit* target = v.second;
    if (target->visible) {
      attack(u, target);
    } else {
      move(u, Position(target));
    }
  }
}

// Per-frame entry point. While the "MeanLingRushEnabled" blackboard flag is
// missing or false, the module empties all of its tasks and does nothing
// else. Otherwise it consumes a matching source UPC (if any) and hands over
// to process().
void MeanLingRushModule::step(State* state) {
  auto board = state->board();

  const bool enabled = board->hasKey("MeanLingRushEnabled") &&
      board->get<bool>("MeanLingRushEnabled");
  if (!enabled) {
    // Empty every task we own and refresh the unit-to-task bookkeeping.
    for (auto& ownedTask : board->tasksOfModule(this)) {
      ownedTask->setUnits({});
      board->updateTasksByUnit(ownedTask.get());
    }
    return;
  }

  const auto sourceUpc = findSourceUpc(state);
  if (sourceUpc >= 0) {
    board->consumeUPC(sourceUpc, this);
    process(state, sourceUpc);
    return;
  }

  VLOG(4) << "No suitable source UPC";
}

// Looks through the blackboard's UPCs returned for Command::Delete (queried
// with 0.5) and returns the id of the first one whose unit set is empty, or
// -1 when there is none.
UpcId MeanLingRushModule::findSourceUpc(State* state) {
  UpcId found = -1;
  for (auto& entry : state->board()->upcsWithCommand(Command::Delete, 0.5)) {
    if (!entry.second->unit.empty()) {
      continue;
    }
    found = entry.first;
    break;
  }
  return found;
}

} // namespace fairrsh
