#include "base.h"
#include "fairrsh.h"
#include "utils.h"

#include "modules/tactics.h"

#include <bwem/map.h>

namespace fairrsh {

// Registers ABBOBase with the RTTR reflection system under the name
// "ABBOBase", attaching its own rttr::type as the "type" metadata entry.
// NOTE(review): presumably this lets build orders be looked up / created by
// name at runtime; confirm against the code that reads the "type" metadata.
RTTR_REGISTRATION {
  rttr::registration::class_<ABBOBase>("ABBOBase")(
      metadata("type", rttr::type::get<ABBOBase>()));
}

/**
 * Updates naturalDefencePos, the point from which to defend our natural
 * expansion.
 *
 * Queries the BWEM choke-point path from naturalPos to enemyBasePos and uses
 * the centre of the second choke point on it. When the path has at most one
 * choke point (including an empty path), the natural itself is used.
 *
 * Does nothing while the natural is unknown, or when enemyBasePos equals the
 * enemy base position of the previous computation, so repeated calls are
 * cheap.
 */
void ABBOBase::findNaturalDefencePos(State* state) {
  if (lastFindNaturalDefencePosEnemyPos == enemyBasePos ||
      naturalPos == Position()) {
    return;
  }
  // Remember which enemy base position this result belongs to; without this
  // assignment the check above could never skip a repeated computation.
  lastFindNaturalDefencePosEnemyPos = enemyBasePos;

  // naturalPos / enemyBasePos are walk-tile coordinates; convert them to
  // BWAPI positions for the BWEM path query.
  const auto& path = state->map()->GetPath(
      BWAPI::Position(BWAPI::WalkPosition(naturalPos.x, naturalPos.y)),
      BWAPI::Position(BWAPI::WalkPosition(enemyBasePos.x, enemyBasePos.y)));

  if (path.size() <= 1) {
    naturalDefencePos = naturalPos;
  } else {
    auto pos = BWAPI::WalkPosition(path[1]->Center());
    naturalDefencePos = Position(pos.x, pos.y);
  }
}

/**
 * Chooses where to place static defence of the given type at our natural.
 *
 * Searches build locations for `type` seeded at naturalPos. Each tile is
 * scored by its distance to naturalDefencePos, falling back to naturalPos
 * while no defence point has been computed. Tiles more than 7 build tiles
 * (28 walk tiles) from the natural are rejected. Tile reservations are
 * ignored during the search and the tile array is restored afterwards.
 *
 * @param state Bot state; its tiles are temporarily modified and restored.
 * @param type  Building type to place.
 * @return The chosen position, or Position() when the natural is not known
 *         yet or no acceptable location was found.
 */
Position ABBOBase::getStaticDefencePos(State* state, const BuildType* type) {
  // Maximum distance from the natural, in walk tiles (7 build tiles).
  constexpr float kMaxDistanceFromNatural = 4 * 7;

  // Without a natural the search would be seeded at the unset Position();
  // there is nothing to defend yet.
  if (naturalPos == Position()) {
    return Position();
  }
  // Until a defence point is known, pull towards the natural itself rather
  // than towards the unset Position().
  const Position target =
      naturalDefencePos != Position() ? naturalDefencePos : naturalPos;

  // ugly hack. temporarily unset all reserved tiles so findBuildLocation may
  // use them; the full tile array is restored right after the search.
  auto& tilesInfo = state->tilesInfo();
  auto copy = tilesInfo.tiles;
  size_t stride = TilesInfo::tilesWidth - tilesInfo.mapTileWidth();
  Tile* ptr = tilesInfo.tiles.data();
  for (unsigned tileY = 0; tileY != tilesInfo.mapTileHeight();
       ++tileY, ptr += stride) {
    for (unsigned tileX = 0; tileX != tilesInfo.mapTileWidth();
         ++tileX, ++ptr) {
      ptr->reservedAsUnbuildable = false;
    }
  }

  Position r = builderhelpers::findBuildLocation(
      state,
      {naturalPos},
      type,
      {},
      [&](State*, const BuildType*, const Tile* tile) {
        if (utils::distance(tile->x, tile->y, naturalPos.x, naturalPos.y) >
            kMaxDistanceFromNatural) {
          // Reject tiles too far away from the natural.
          return std::numeric_limits<float>::infinity();
        }
        return utils::distance(tile->x, tile->y, target.x, target.y);
      });
  tilesInfo.tiles = copy;

  // The search may still return a far-away (or default) position.
  if (utils::distance(r.x, r.y, naturalPos.x, naturalPos.y) >
      kMaxDistanceFromNatural) {
    return Position();
  }
  return r;
}

/**
 * Queues sunken colonies in the build state.
 *
 * If a creep colony exists or is being produced, a sunken colony is queued
 * (morphing it). Otherwise, once at least two hatcheries are complete and
 * nextStaticDefencePos is set, a creep colony is queued there while fewer
 * than `n` sunken colonies exist or are in production and no creep colony is
 * in production.
 */
void ABBOBase::buildSunkens(autobuild::BuildState& st, int n) {
  using namespace buildtypes;

  if (hasOrInProduction(st, Zerg_Creep_Colony)) {
    build(Zerg_Sunken_Colony);
    return;
  }

  const bool defenceSpotReady =
      myCompletedHatchCount >= 2 && nextStaticDefencePos != Position();
  if (!defenceSpotReady) {
    return;
  }

  const bool wantAnotherColony =
      countPlusProduction(st, Zerg_Sunken_Colony) < n &&
      !isInProduction(st, Zerg_Creep_Colony);
  if (wantAnotherColony) {
    build(Zerg_Creep_Colony, nextStaticDefencePos);
  }
}

/**
 * Recomputes armySupply, airArmySupply and groundArmySupply from a build
 * state.
 *
 * Counts every non-worker unit owned in `st` plus every non-worker unit in
 * its production queue; flyers go to the air total, everything else to the
 * ground total.
 */
void ABBOBase::calculateArmySupply(const autobuild::BuildState& st) {
  armySupply = 0.0;
  airArmySupply = 0.0;
  groundArmySupply = 0.0;

  // Adds `supply` belonging to units of type `kind` to the totals.
  // Workers are not part of the army and are skipped.
  auto addSupply = [&](const BuildType* kind, auto supply) {
    if (kind->isWorker) {
      return;
    }
    armySupply += supply;
    if (kind->isFlyer) {
      airArmySupply += supply;
    } else {
      groundArmySupply += supply;
    }
  };

  // Owned units: st.units maps a type to its units.
  for (auto& entry : st.units) {
    addSupply(entry.first, entry.first->supplyRequired * entry.second.size());
  }
  // Queued production: one unit of entry.second each.
  for (auto& entry : st.production) {
    addSupply(entry.second, entry.second->supplyRequired);
  }
}

/**
 * Picks a creep colony location that extends static-defence coverage over
 * our bases.
 *
 * Every building we own (unless coverMineralsOnly) and every resource at our
 * expansions is a coverage target. Each target records how many existing
 * static defence units (units of `type` plus creep colonies) are within
 * coverageRange of it, with 0 for uncovered targets. Candidate tiles around
 * our resource depots are scored:
 * - each uncovered target in range lowers the score strongly;
 * - each covered target raises it according to its coverage count;
 * - tiles right next to existing static defence are penalised.
 * Lower scores are assumed to be preferred by
 * builderhelpers::findBuildLocation (getStaticDefencePos rejects tiles with
 * +infinity).
 *
 * @param state             Bot state; reserved tiles are temporarily cleared
 *                          and the tile array is restored before returning.
 * @param type              Static defence type whose existing units provide
 *                          coverage (creep colonies always do as well).
 * @param mainBaseOnly      Only search around the first expansion (by index)
 *                          that has one of our depots.
 * @param coverMineralsOnly Ignore buildings; only resources are targets.
 * @return The position returned by builderhelpers::findBuildLocation.
 */
Position ABBOBase::findSunkenPos(State* state, const BuildType* type, bool mainBaseOnly, bool coverMineralsOnly) {
  using namespace buildtypes;

  // ugly hack. temporarily unset all reserved tiles so findBuildLocation may
  // use them; the full tile array is restored after the search.
  auto& tilesInfo = state->tilesInfo();
  auto copy = tilesInfo.tiles;
  size_t stride = TilesInfo::tilesWidth - tilesInfo.mapTileWidth();
  Tile* ptr = tilesInfo.tiles.data();
  for (unsigned tileY = 0; tileY != tilesInfo.mapTileHeight();
       ++tileY, ptr += stride) {
    for (unsigned tileX = 0; tileX != tilesInfo.mapTileWidth();
         ++tileX, ++ptr) {
      ptr->reservedAsUnbuildable = false;
    }
  }

  const float coverageRange = 4 * 5.5;

  std::vector<Unit*> existingStaticDefence;
  for (Unit* u : state->unitsInfo().myUnitsOfType(type)) {
    existingStaticDefence.push_back(u);
  }
  for (Unit* u : state->unitsInfo().myUnitsOfType(Zerg_Creep_Colony)) {
    existingStaticDefence.push_back(u);
  }

  // Target -> number of existing static defence units covering it. Every
  // target gets an entry, including uncovered ones (count 0): the scoring
  // below rewards covering those, so they must be present in the map.
  std::unordered_map<Unit*, int> coverage;
  auto addCoverageTarget = [&](Unit* target) {
    int& count = coverage[target]; // value-initialised to 0 on first use
    for (Unit* u : existingStaticDefence) {
      if (utils::distance(u, target) <= coverageRange) {
        ++count;
      }
    }
  };

  if (!coverMineralsOnly) {
    for (Unit* building : state->unitsInfo().myBuildings()) {
      addCoverageTarget(building);
    }
  }

  std::vector<Position> basePositions;
  for (int i = 0; i != state->areaInfo().numExpands() + 1; ++i) {
    Unit* depot = state->areaInfo().myExpand(i);
    if (!depot) {
      continue;
    }

    for (Unit* resource : state->areaInfo().myExpandResources(i)) {
      addCoverageTarget(resource);
    }

    basePositions.push_back(Position(depot));

    if (mainBaseOnly) {
      break;
    }
  }

  Position r = builderhelpers::findBuildLocation(
      state,
      basePositions,
      Zerg_Creep_Colony,
      {},
      [&](State*, const BuildType*, const Tile* tile) {
        // Offset from the tile position towards the colony's centre.
        Position pos = Position(tile) + Position(4, 4);
        float score = 0.0f;
        for (auto& v : coverage) {
          if (utils::distance(pos, Position(v.first)) <= coverageRange) {
            // Uncovered targets (count 0) lower the score by 13.25; covered
            // ones raise it by (count - 1.25).
            score -= 1.25f - (v.second ? v.second : -12.0f);
          }
        }
        for (Unit* u : existingStaticDefence) {
          if (utils::distance(pos, Position(u)) < 12) {
            // Avoid stacking directly next to existing static defence.
            score += 24.0f;
          }
        }
        return score;
      });

  state->tilesInfo().tiles = copy;

  return r;
}

/**
 * Per-frame update run before the concrete build order's preBuild2().
 *
 * Recomputes the shared knowledge ABBO build orders read:
 * - army supply of our current state, and homePosition;
 * - number of bases we own, their mineral fields, and the next expansion
 *   (nextBase / canExpand);
 * - the enemy base position (found, or guessed from start locations and
 *   sightings), our natural, its defence point and nextStaticDefencePos;
 * - the in-base area (refreshed every 90 frames);
 * - expansion flags and enemy statistics (unit counts, supplies, rush
 *   detection);
 * - our zergling / completed hatchery counts, overlord losses and the
 *   enemy race.
 */
void ABBOBase::preBuild(State* state, Module* module) {
  using namespace buildtypes;
  using namespace autobuild;

  calculateArmySupply(autobuild::getMyState(state));

  currentFrame = state->currentFrame();

  // Home is our first resource depot, else our first building, else our
  // first unit. It keeps its previous value if we have none of these.
  if (!state->unitsInfo().myResourceDepots().empty()) {
    homePosition.x = state->unitsInfo().myResourceDepots().front()->x;
    homePosition.y = state->unitsInfo().myResourceDepots().front()->y;
  } else if (!state->unitsInfo().myBuildings().empty()) {
    homePosition.x = state->unitsInfo().myBuildings().front()->x;
    homePosition.y = state->unitsInfo().myBuildings().front()->y;
  } else if (!state->unitsInfo().myUnits().empty()) {
    homePosition.x = state->unitsInfo().myUnits().front()->x;
    homePosition.y = state->unitsInfo().myUnits().front()->y;
  }

  // Count the bases we own (one of our buildings sits on the base location
  // tile) and their mineral fields. Bases with blocking minerals are ignored.
  int mineralFields = 0;
  bases = 0;
  for (auto& area : state->map()->Areas()) {
    for (auto& base : area.Bases()) {
      if (!base.BlockingMinerals().empty()) {
        continue;
      }
      Position pos(
          base.Location().x * tc::BW::XYWalktilesPerBuildtile,
          base.Location().y * tc::BW::XYWalktilesPerBuildtile);
      const Tile& tile = state->tilesInfo().getTile(pos.x, pos.y);
      Unit* building = tile.building;
      if (building) {
        if (building->isMine) {
          ++bases;
          mineralFields += base.Minerals().size();
        }
      }
    }
  }

  // Score every base location where a hatchery can be built; lower is
  // better. More minerals and a geyser lower the score and ground path
  // distance from home raises it. With fewer than two bases, geyser-less
  // bases are heavily penalised. From two bases on, locations far from the
  // enemy base and from the map centre are preferred.
  std::vector<std::tuple<Position, double>> allBases;
  for (auto& area : state->map()->Areas()) {
    for (auto& base : area.Bases()) {
      if (!base.BlockingMinerals().empty()) {
        continue;
      }
      Position pos(
          base.Location().x * tc::BW::XYWalktilesPerBuildtile,
          base.Location().y * tc::BW::XYWalktilesPerBuildtile);
      if (!builderhelpers::canBuildAt(state, Zerg_Hatchery, pos, true)) {
        continue;
      }

      double score = -1.0 * std::min((int)base.Minerals().size(), 8) +
          -8.0 * std::min((int)base.Geysers().size(), 1);

      if (bases < 2 && base.Geysers().empty()) {
        score += 1000.0;
      }

      int length = 0;
      state->map()->GetPath(
          BWAPI::Position(BWAPI::WalkPosition(homePosition.x, homePosition.y)),
          BWAPI::Position(BWAPI::WalkPosition(pos.x, pos.y)),
          &length);
      // A non-positive path length makes the base unattractive.
      double distance =
          (length > 0 ? length : std::numeric_limits<double>::infinity());
      score += distance / (4 * 4);

      if (enemyBasePos != Position() && bases >= 2) {
        int enemyLength = 0;
        state->map()->GetPath(
            BWAPI::Position(
                BWAPI::WalkPosition(enemyBasePos.x, enemyBasePos.y)),
            BWAPI::Position(BWAPI::WalkPosition(pos.x, pos.y)),
            &enemyLength);
        double enemyDistance = (enemyLength > 0 ? enemyLength : 0);
        score -= enemyDistance / (4 * 4);

        score -= utils::distance(pos, Position(state->mapWidth() / 2, state->mapHeight() / 2)) / (4 * 4);
      }

      allBases.emplace_back(pos, score);
    }
  }

  auto* bestBase = utils::getBestScorePointer(
      allBases, [&](auto& v) { return std::get<1>(v); });
  if (bestBase) {
    canExpand = true;
    nextBase = std::get<0>(*bestBase);
  } else {
    canExpand = false;
    nextBase = Position();
  }

  if (VLOG_IS_ON(2)) {
    utils::drawCircle(state, enemyBasePos, 70);
    utils::drawCircle(state, nextBase, 55);
  }

  if (!hasFoundEnemyBase) {
    // A start location with a known enemy building is the enemy base.
    // Otherwise guess the (last) start location we have never seen.
    for (auto tilePos : state->map()->StartingLocations()) {
      Position pos(
          tilePos.x * tc::BW::XYWalktilesPerBuildtile,
          tilePos.y * tc::BW::XYWalktilesPerBuildtile);
      auto& tile = state->tilesInfo().getTile(pos.x, pos.y);
      if (tile.building && tile.building->isEnemy) {
        enemyBasePos = pos;
        hasFoundEnemyBase = true;
        break;
      } else if (tile.lastSeen == 0) {
        enemyBasePos = pos;
      }
    }
    if (!hasFoundEnemyBase) {
      // Refine the guess from enemy sightings. Pass 0 looks at buildings,
      // pass 1 at armed non-worker units, pass 2 at any unit. The first
      // matching unit moves the guess to the start location nearest to it
      // that has no known building.
      for (int i = 0; i != 3; ++i) {
        bool found = false;
        for (Unit* u : state->unitsInfo().enemyUnits()) {
          if (i == 0 ? u->type->isBuilding : i == 1
                      ? (u->type->hasGroundWeapon || u->type->hasAirWeapon) &&
                          !u->type->isWorker
                      : true) {
            Position nearestPos;
            float nearestDistance = std::numeric_limits<float>::infinity();
            for (auto tilePos : state->map()->StartingLocations()) {
              Position pos(
                  tilePos.x * tc::BW::XYWalktilesPerBuildtile,
                  tilePos.y * tc::BW::XYWalktilesPerBuildtile);
              auto& tile = state->tilesInfo().getTile(pos.x, pos.y);
              if (!tile.building) {
                float d = utils::distance(u->x, u->y, pos.x, pos.y);
                if (d < nearestDistance) {
                  nearestDistance = d;
                  nearestPos = pos;
                }
              }
            }
            if (nearestPos != Position()) {
              enemyBasePos = nearestPos;
              found = true;
              break;
            }
          }
        }
        if (found) {
          break;
        }
      }
    }
  }

  if (naturalPos == Position() && nextBase != Position()) {
    naturalPos = nextBase;
  }
  // Run every frame (not only when the natural is first chosen) so that the
  // defence point follows later changes to enemyBasePos, e.g. once the real
  // enemy start location is found. It returns early while the natural is
  // unknown.
  findNaturalDefencePos(state);

  // Computed after the natural and its defence point are up to date for
  // this frame.
  nextStaticDefencePos = getStaticDefencePos(state, Zerg_Creep_Colony);

  if (VLOG_IS_ON(2)) {
    utils::drawLine(state, naturalPos, naturalDefencePos);
    utils::drawLine(state, naturalDefencePos, nextStaticDefencePos);
    utils::drawCircle(state, nextStaticDefencePos, 32);
  }

  if (state->currentFrame() - lastUpdateInBaseArea >= 90) {
    lastUpdateInBaseArea = state->currentFrame();
    updateInBaseArea(state, inBaseArea);
  }

  shouldExpand = canExpand &&
      bases <
          std::max(
              ((int)state->unitsInfo().myResourceDepots().size() + 1) / 2 + 1,
              2);
  // Saturation: expand once we have at least 1.8 workers per mineral field.
  forceExpand =
      canExpand && state->unitsInfo().myWorkers().size() >= mineralFields * 1.8;
  if (forceExpand) {
    shouldExpand = true;
  }

  auto& tilesInfo = state->tilesInfo();
  auto* tilesData = tilesInfo.tiles.data();

  enemyZealotCount = 0;
  enemyVultureCount = 0;
  enemyGoliathCount = 0;
  enemyTankCount = 0;
  enemyMissileTurretCount = 0;
  enemyCorsairCount = 0;
  enemyWraithCount = 0;
  enemyStaticDefenceCount = 0;
  enemyReaverCount = 0;
  enemyBarracksCount = 0;
  enemyMutaliskCount = 0;
  enemyMarineCount = 0;
  enemyFactoryCount = 0;

  enemySupplyInOurBase = 0.0;
  enemyArmySupplyInOurBase = 0.0;
  enemyArmySupply = 0.0;
  enemyGroundArmySupply = 0.0;
  enemyAirArmySupply = 0.0;
  enemyAntiAirArmySupply = 0.0;
  enemyAttackingArmySupply = 0.0;
  enemyAttackingWorkerCount = 0;
  enemyLargeArmySupply = 0.0;
  enemySmallArmySupply = 0.0;
  for (Unit* u : state->unitsInfo().enemyUnits()) {
    if (u->type == Protoss_Zealot) {
      ++enemyZealotCount;
    } else if (u->type == Terran_Vulture) {
      ++enemyVultureCount;
    } else if (u->type == Terran_Goliath) {
      ++enemyGoliathCount;
    } else if (
        u->type == Terran_Siege_Tank_Tank_Mode ||
        u->type == Terran_Siege_Tank_Siege_Mode) {
      ++enemyTankCount;
    } else if (u->type == Terran_Missile_Turret) {
      ++enemyMissileTurretCount;
    } else if (u->type == Protoss_Corsair) {
      ++enemyCorsairCount;
    } else if (u->type == Terran_Wraith) {
      ++enemyWraithCount;
    } else if (u->type == Protoss_Reaver) {
      ++enemyReaverCount;
    } else if (u->type == Terran_Barracks) {
      ++enemyBarracksCount;
    } else if (u->type == Zerg_Mutalisk) {
      ++enemyMutaliskCount;
    } else if (u->type == Terran_Marine) {
      ++enemyMarineCount;
    } else if (u->type == Terran_Factory) {
      ++enemyFactoryCount;
    }
    if (u->type->isBuilding &&
        (u->type == Terran_Bunker || u->type->hasGroundWeapon ||
         u->type->hasAirWeapon)) {
      ++enemyStaticDefenceCount;
    }
    if (u->type->isResourceDepot && !enemyHasExpanded) {
      if (utils::distance(u->x, u->y, enemyBasePos.x, enemyBasePos.y) > 12) {
        enemyHasExpanded = true;
      }
    }
    // A unit counts as "attacking" when it is closer to our home than
    // 1.25x its distance to the enemy base.
    if (!u->type->isWorker) {
      enemyArmySupply += u->type->supplyRequired;
      if (u->flying()) {
        enemyAirArmySupply += u->type->supplyRequired;
      } else {
        enemyGroundArmySupply += u->type->supplyRequired;
      }
      if (u->type->hasAirWeapon || u->type == Protoss_Carrier) {
        enemyAntiAirArmySupply += u->type->supplyRequired;
      }
      if (utils::distance(u->x, u->y, homePosition.x, homePosition.y) <
          utils::distance(u->x, u->y, enemyBasePos.x, enemyBasePos.y) * 1.25) {
        enemyAttackingArmySupply += u->type->supplyRequired;
      }
      if (u->type->size == 1) {
        enemySmallArmySupply += u->type->supplyRequired;
      }
      if (u->type->size == 3) {
        enemyLargeArmySupply += u->type->supplyRequired;
      }
    } else {
      if (utils::distance(u->x, u->y, homePosition.x, homePosition.y) <
          utils::distance(u->x, u->y, enemyBasePos.x, enemyBasePos.y) * 1.25) {
        // A count of workers, not their supply.
        ++enemyAttackingWorkerCount;
      }
    }
    const Tile* tile = tilesInfo.tryGetTile(u->x, u->y);
    if (tile) {
      size_t index = tile - tilesData;
      if (inBaseArea[index]) {
        enemySupplyInOurBase += u->type->supplyRequired;
        if (!u->type->isWorker) {
          enemyArmySupplyInOurBase += u->type->supplyRequired;
        }
      }
    }
  }

  // Factory units imply at least one factory even if we have not seen it.
  if (enemyFactoryCount == 0) {
    if (enemyVultureCount + enemyGoliathCount + enemyTankCount) {
      enemyFactoryCount = 1;
    }
  }

  if (state->currentFrame() < 3 * 60 * 24 && enemyBarracksCount >= 2) {
    enemyIsRushing = true;
  }
  if (state->currentFrame() < 4 * 60 * 24 && enemyArmySupply > 4) {
    enemyIsRushing = true;
  }
  if (state->currentFrame() > 6 * 60 * 24) {
    enemyIsRushing = false;
  }

  myZerglingCount = state->unitsInfo().myUnitsOfType(Zerg_Zergling).size();
  myCompletedHatchCount =
      state->unitsInfo().myCompletedUnitsOfType(Zerg_Hatchery).size() +
      state->unitsInfo().myUnitsOfType(Zerg_Lair).size() +
      state->unitsInfo().myUnitsOfType(Zerg_Hive).size();

  isLosingAnOverlord = false;
  for (Unit* u : state->unitsInfo().myCompletedUnitsOfType(Zerg_Overlord)) {
    if (u->unit.health <= u->type->maxHp / 2) {
      isLosingAnOverlord = true;
    }
  }

  auto enemyRaceBB = tc::BW::Race::_from_integral_nothrow(
      state->board()->get<int>(Blackboard::kEnemyRaceKey));
  if (enemyRaceBB) {
    enemyRace = *enemyRaceBB;
  }

  if (currentFrame < 15 * 60 * 5) {
    if (enemyAttackingWorkerCount >= 3) {
      // NOTE(review): this posts the literal key "kMinScoutFrame", unlike the
      // Blackboard::k...Key constant used above; confirm it matches the key
      // the scouting code reads.
      state->board()->post("kMinScoutFrame", 0);
    }
  }

  preBuild2(state, module);

  weArePlanningExpansion = false;
}

// Post-build hook: ABBOBase adds no logic of its own here and delegates
// entirely to postBuild2().
void ABBOBase::postBuild(State* state) {
  postBuild2(state);
}

/**
 * One planning step of the autobuild simulation for ABBO build orders.
 *
 * Steps:
 * - Refreshes army supply for the simulated state `st`.
 * - Records whether an expansion is planned soon.
 * - Optionally queues a hatchery when no larva is available.
 * - Runs the concrete build order (buildStep2).
 * - Layers generic reactions on top: defence against attacking enemy
 *   workers, late-game upgrades, spare overlords when overlords are being
 *   lost, and forced expansion.
 */
void ABBOBase::buildStep(autobuild::BuildState& st) {
  using namespace buildtypes;
  using namespace autobuild;

  calculateArmySupply(st);

  if (st.frame - currentFrame <= 15 * 30 && st.isExpanding) {
    weArePlanningExpansion = true;
  }

  if (autoBuildHatcheries) {
    int larva = 0;
    for (auto& v : st.units) {
      const BuildType* t = v.first;
      if (t == Zerg_Hatchery || t == Zerg_Lair || t == Zerg_Hive) {
        for (auto& u : v.second) {
          larva += larvaCount(st, u);
        }
      }
    }
    if (larva == 0) {
      // nextBase is only a valid expansion location when canExpand is set
      // (preBuild resets it to Position() otherwise), so fall back to a
      // plain macro hatchery in that case.
      if (autoExpandWithMacroHatcheries && canExpand && !st.isExpanding) {
        build(Zerg_Hatchery, nextBase);
      } else {
        build(Zerg_Hatchery);
      }
    }
  }

  buildStep2(st);

  // Early worker attacks: get a pool and enough zerglings to answer them.
  if (st.frame < 15 * 60 * 5) {
    if (enemyAttackingWorkerCount >= 3) {
      if (!hasOrInProduction(st, Zerg_Spawning_Pool)) {
        buildN(Zerg_Spawning_Pool, 1);
        buildN(Zerg_Drone, 8);
      } else {
        buildN(Zerg_Zergling, std::max(enemyAttackingWorkerCount, 4));
      }
    }
  }

  if (st.frame < 15 * 60 * 4 && enemyAttackingWorkerCount >= 2 && st.workers < 13) {
    buildN(Zerg_Zergling, enemyAttackingWorkerCount);
  }

  // Late game with a strong economy and army: max out relevant upgrades.
  if (st.workers >= 50 && ((armySupply > enemyArmySupply && armySupply >= 40.0) || armySupply >= 70.0)) {
    if (countPlusProduction(st, Zerg_Mutalisk) >= 10) {
      upgrade(Zerg_Flyer_Attacks_3);
      upgrade(Zerg_Flyer_Carapace_3);
    }
    if (countPlusProduction(st, Zerg_Hydralisk) + countPlusProduction(st, Zerg_Lurker) >= 15) {
      upgrade(Zerg_Missile_Attacks_3);
    }
    if (countPlusProduction(st, Zerg_Zergling) >= 20) {
      upgrade(Zerg_Melee_Attacks_3);
    }
    if (has(st, Zerg_Hive) || countPlusProduction(st, Zerg_Zergling) >= 40) {
      upgrade(Zerg_Carapace_3);
      upgrade(Adrenal_Glands);
    }
    // Short-circuit: Antennae is only requested when
    // upgrade(Pneumatized_Carapace) returns true.
    upgrade(Pneumatized_Carapace) && upgrade(Antennae);
  }

  if (buildExtraOverlordsIfLosingThem && isLosingAnOverlord) {
    // Keep two spare overlords against corsairs/wraiths, otherwise one.
    int n = (enemyCorsairCount + enemyWraithCount) > 0 ? 2 : 1;
    if (countProduction(st, Zerg_Overlord) < n &&
        st.usedSupply[tc::BW::Race::Zerg] >=
            st.maxSupply[tc::BW::Race::Zerg] - 8 * n) {
      build(Zerg_Overlord);
    }
  }

  if (autoExpand) {
    if (forceExpand && !st.isExpanding) {
      build(Zerg_Hatchery, nextBase);
    }
  }
}
}
