/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "base.h"
#include "models/units_mixer.h"

namespace fairrsh {

/// Build for unit mix model
class ABBOumm : public ABBOBase {
  // RTTR_ENABLE(ABBOBase)
 public:
  using ABBOBase::ABBOBase;

  model::Action* currentAction = nullptr;

  bool saveLarva = false;

  virtual void preBuild2(State* state, Module* module) override {
    if (FLAGS_umm_path == "") {
      LOG(ERROR) << "Using UMM without -umm_path set";
      return;
    }
    using namespace buildtypes;
    using namespace autobuild;

    currentAction = state->board()->unitsMixer->forward(state, nullptr);

    // hack!
    auto input = state->board()->unitsMixer->model_->inputs_.back();
    state->board()->unitsMixer->model_->inputs_.pop_back();
    auto output = state->board()->unitsMixer->model_->outputs_.back();
    state->board()->unitsMixer->model_->outputs_.pop_back();

    state->board()->unitsMixerCurrentInput = std::move(input);
    state->board()->unitsMixerCurrentOutput = std::move(output);

    saveLarva = false;
    if (state->unitsInfo().myCompletedUnitsOfType(Zerg_Spire).empty()) {
      for (Unit* u : state->unitsInfo().myUnitsOfType(Zerg_Spire)) {
        if (u->remainingBuildTrainTime < 900) {
          saveLarva = true;
        }
      }
    }
  }

  virtual void buildStep2(autobuild::BuildState& st) override {
    using namespace buildtypes;
    using namespace autobuild;

    if (currentAction) {
      if (currentAction == &model::actionDrone) {
        build(Zerg_Drone);
      } else if (currentAction == &model::actionZergling) {
        build(Zerg_Zergling);
      } else if (currentAction == &model::actionHydralisk) {
        build(Zerg_Hydralisk);
      } else if (currentAction == &model::actionMutalisk) {
        build(Zerg_Mutalisk);
      } else if (currentAction == &model::actionMacroTech) {
        if (countProduction(st, Zerg_Drone) < 2) {
          buildN(Zerg_Drone, 66);
        }
        if (st.workers >= 30) {
          buildN(Zerg_Hydralisk_Den, 1);
          buildN(Zerg_Spire, 1);
        }
        if (st.workers >= 24 && bases < 3 && !st.isExpanding && canExpand) {
          st.isExpanding = true;
          build(Zerg_Hatchery, nextBase);
        }
        buildN(Zerg_Zergling, 2);
      } else if (currentAction == &model::actionSunkenDef) {
        if (countProduction(st, Zerg_Drone) < 2) {
          buildN(Zerg_Drone, 66);
        }
        if (myCompletedHatchCount >= 2 && (st.workers >= 10 || enemyArmySupplyInOurBase)) {
          if (hasOrInProduction(st, Zerg_Creep_Colony)) {
            build(Zerg_Sunken_Colony);
          } else {
            int droneCount = countPlusProduction(st, Zerg_Drone);
            int desiredSunkenCount = 1;
            if (droneCount >= 16) {
              desiredSunkenCount = 2;
            }
            if (droneCount >= 18) {
              desiredSunkenCount = 3;
            }
            if (droneCount >= 24) {
              desiredSunkenCount = 6;
            }
            if (countPlusProduction(st, Zerg_Sunken_Colony) < desiredSunkenCount &&
                !isInProduction(st, Zerg_Creep_Colony)) {
              build(Zerg_Creep_Colony, nextStaticDefencePos);
            }
          }
        }
      } else if (currentAction == &model::actionSafeMutaLing) {
        if (st.workers < 30 && armySupply < 20.0) {
          st.autoBuildRefineries = false;
          autoBuildHatcheries = false;
          autoExpand = false;
        }

        if (saveLarva) {
          buildN(Zerg_Overlord, 5);
          if (armySupply < enemySupplyInOurBase + 1.0) {
            build(Zerg_Zergling);
          }
          return;
        }

        build(Zerg_Zergling);
        buildN(Zerg_Drone, 26);
        build(Zerg_Mutalisk);
        buildN(Zerg_Extractor, 2);
        if (enemyArmySupply > 4.0) {
          buildN(Zerg_Zergling, 6);
        }
        upgrade(Metabolic_Boost);
        buildN(Zerg_Spire, 1);
        buildN(Zerg_Drone, 16);
        buildN(Zerg_Lair, 1);
        buildN(Zerg_Drone, 12);
        buildN(Zerg_Zergling, 2);

        if (myCompletedHatchCount >= 2 &&
            (st.workers >= 12 || enemyArmySupplyInOurBase) &&
            !has(st, Zerg_Spire)) {
          if (hasOrInProduction(st, Zerg_Creep_Colony)) {
            build(Zerg_Sunken_Colony);
          } else {
            int droneCount = countPlusProduction(st, Zerg_Drone);
            int desiredSunkenCount = 0;
            if (droneCount >= 16)
              desiredSunkenCount = 1;
            if (enemyArmySupply - enemyVultureCount >= 4.0)
              desiredSunkenCount = 2;
            if (enemyArmySupply - enemyVultureCount >= 10.0)
              desiredSunkenCount = 4;
            if (countPlusProduction(st, Zerg_Sunken_Colony) < desiredSunkenCount &&
                !isInProduction(st, Zerg_Creep_Colony)) {
              build(Zerg_Creep_Colony, nextStaticDefencePos);
            }
          }
        }
        buildN(Zerg_Extractor, 1);
        buildN(Zerg_Spawning_Pool, 1);
        if (countPlusProduction(st, Zerg_Hatchery) == 1) {
          build(Zerg_Hatchery, nextBase);
          buildN(Zerg_Drone, 12);
        }

        if (st.workers >= 13 && armySupply < enemySupplyInOurBase + 1.0) {
          build(Zerg_Zergling);
        }
      }
    }
  }
};

// RTTR_REGISTRATION {
//   rttr::registration::class_<ABBOumm>("ABBOumm")(
//       metadata("type", rttr::type::get<ABBOumm>()))
//       .constructor<UpcId>();
// }
}
