/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "model.h"

#include <cmath>

namespace fairrsh {
namespace model {

using namespace buildtypes;
using namespace autobuild;

// static const float max_float = std::log(std::numeric_limits<float>::max());

/// Sets up the sparse feature layout used by forward().
///
/// Let M = indices_.myUnitTypes.size() and E = indices_.enemyUnitTypes.size().
/// Feature indices are assigned consecutively:
///   [0, M)          completed own units, one slot per unit type
///   [M, 2M)         own units in production, same type order
///   [2M, 2M + E)    inferred enemy units
///   then 12 scalar slots: ore, gas, used/max/available supply, six
///   BuildState-based counters and a constant bias.
/// output_size_ is the total number of slots.
Featurizer::Featurizer() {
  // -1 marks the cached build state as stale; forward() recomputes it when
  // buildState_.frame differs from the current frame.
  buildState_.frame = -1;

  // Zerg unit types are tracked both for us and for the enemy, in the same
  // order in both lists.
  for (const BuildType* t :
       {Zerg_Zergling,
        Zerg_Hydralisk,
        Zerg_Ultralisk,
        Zerg_Drone,
        Zerg_Overlord,
        Zerg_Mutalisk,
        Zerg_Guardian,
        Zerg_Queen,
        Zerg_Defiler,
        Zerg_Scourge,
        Zerg_Infested_Terran,
        Zerg_Cocoon,
        Zerg_Devourer,
        Zerg_Lurker_Egg,
        Zerg_Lurker,
        Zerg_Infested_Command_Center,
        Zerg_Hatchery,
        Zerg_Lair,
        Zerg_Hive,
        Zerg_Nydus_Canal,
        Zerg_Hydralisk_Den,
        Zerg_Defiler_Mound,
        Zerg_Greater_Spire,
        Zerg_Queens_Nest,
        Zerg_Evolution_Chamber,
        Zerg_Ultralisk_Cavern,
        Zerg_Spire,
        Zerg_Spawning_Pool,
        Zerg_Creep_Colony,
        Zerg_Spore_Colony,
        Zerg_Sunken_Colony,
        Zerg_Extractor}) {
    indices_.myUnitTypes.push_back(t->unit);
    indices_.enemyUnitTypes.push_back(t->unit);
  }
  /* Not featurized yet: Zerg upgrades and tech
  Zerg_Carapace_1;
  Zerg_Carapace_2;
  Zerg_Carapace_3;
  Zerg_Flyer_Carapace_1;
  Zerg_Flyer_Carapace_2;
  Zerg_Flyer_Carapace_3;
  Zerg_Melee_Attacks_1;
  Zerg_Melee_Attacks_2;
  Zerg_Melee_Attacks_3;
  Zerg_Missile_Attacks_1;
  Zerg_Missile_Attacks_2;
  Zerg_Missile_Attacks_3;
  Zerg_Flyer_Attacks_1;
  Zerg_Flyer_Attacks_2;
  Zerg_Flyer_Attacks_3;
  Ventral_Sacs;
  Antennae;
  Pneumatized_Carapace;
  Metabolic_Boost;
  Adrenal_Glands;
  Muscular_Augments;
  Grooved_Spines;
  Gamete_Meiosis;
  Metasynaptic_Node;
  Chitinous_Plating;
  Anabolic_Synthesis;
  Burrowing;
  Infestation;
  Spawn_Broodlings;
  Dark_Swarm;
  Plague;
  Consume;
  Ensnare;
  Parasite;
  Lurker_Aspect;
  */

  // Terran and Protoss unit types are only ever seen on the enemy side.
  for (const BuildType* t :
       {Terran_Marine,
        Terran_Ghost,
        Terran_Vulture,
        Terran_Goliath,
        Terran_Siege_Tank_Tank_Mode,
        Terran_SCV,
        Terran_Wraith,
        Terran_Science_Vessel,
        Terran_Dropship,
        Terran_Battlecruiser,
        Terran_Siege_Tank_Siege_Mode,
        Terran_Firebat,
        Terran_Medic,
        Terran_Valkyrie,
        Terran_Command_Center,
        Terran_Comsat_Station,
        Terran_Nuclear_Silo,
        Terran_Supply_Depot,
        Terran_Refinery,
        Terran_Barracks,
        Terran_Academy,
        Terran_Factory,
        Terran_Starport,
        Terran_Control_Tower,
        Terran_Science_Facility,
        Terran_Covert_Ops,
        Terran_Physics_Lab,
        Terran_Machine_Shop,
        Terran_Engineering_Bay,
        Terran_Armory,
        Terran_Missile_Turret,
        Terran_Bunker,
        Protoss_Corsair,
        Protoss_Dark_Templar,
        Protoss_Dark_Archon,
        Protoss_Probe,
        Protoss_Zealot,
        Protoss_Dragoon,
        Protoss_High_Templar,
        Protoss_Archon,
        Protoss_Shuttle,
        Protoss_Scout,
        Protoss_Arbiter,
        Protoss_Carrier,
        Protoss_Reaver,
        Protoss_Observer,
        Protoss_Nexus,
        Protoss_Robotics_Facility,
        Protoss_Pylon,
        Protoss_Assimilator,
        Protoss_Observatory,
        Protoss_Gateway,
        Protoss_Photon_Cannon,
        Protoss_Citadel_of_Adun,
        Protoss_Cybernetics_Core,
        Protoss_Templar_Archives,
        Protoss_Forge,
        Protoss_Stargate,
        Protoss_Fleet_Beacon,
        Protoss_Arbiter_Tribunal,
        Protoss_Robotics_Support_Bay,
        Protoss_Shield_Battery}) {
    indices_.enemyUnitTypes.push_back(t->unit);
  }

  // Hand out feature slots in layout order with a running counter.
  size_t next = 0;
  for (auto type : indices_.myUnitTypes) {
    indices_.myUnitTypesToIndex[type] = next++;
  }
  for (auto type : indices_.myUnitTypes) {
    indices_.myInProdUnitTypesToIndex[type] = next++;
  }
  for (auto type : indices_.enemyUnitTypes) {
    indices_.enemyUnitTypesToIndex[type] = next++;
  }

  // Scalar counters read from the game state.
  indices_.oreIndex = next++;
  indices_.gasIndex = next++;
  indices_.usedSupplyIndex = next++;
  indices_.maxSupplyIndex = next++;
  indices_.availSupplyIndex = next++;
  // The *Bst slots are for if/when this is used within AutoBuild
  // (TODO: ask Vegard).
  indices_.oreBstIndex = next++;
  indices_.gasBstIndex = next++;
  indices_.usedSupplyBstIndex = next++;
  indices_.maxSupplyBstIndex = next++;
  indices_.inprodSupplyBstIndex = next++;
  indices_.availSupplyBstIndex = next++;
  // Constant bias feature; always the last slot.
  indices_.BIAS = next++;
  output_size_ = next;
}

/// Builds the sparse feature vector for state `s`.
///
/// Each entry is a (feature index, value) pair using the index layout set up
/// by the constructor, in this order:
///   - counts of completed own units;
///   - counts of own units in production;
///   - counts of inferred enemy units;
///   - ore, gas and used/max/available supply from `s`;
///   - the BuildState-based slots;
///   - a constant bias of 1.
/// Unit types that have no slot in the layout are skipped.
///
/// @param s    Game state to featurize.
/// @param bst  Build state to read supply numbers from. When null, the result
///             of getMyState(s) is used instead. It is cached in buildState_
///             and only recomputed when the frame changes.
/// @return     Reference to the internal features_ buffer, which is
///             overwritten by the next call.
const std::vector<std::pair<size_t, float>>& Featurizer::forward(
    State* s,
    BuildState* bst) {
  // Count own units, split into completed and in-production.
  std::unordered_map<int, int> myUnitTypesNumbers;
  std::unordered_map<int, int> myInProdUnitTypesNumbers;
  for (auto& u : s->unitsInfo().myUnits()) {
    if (u->completed()) {
      myUnitTypesNumbers[u->type->unit] += 1;
    } else {
      const BuildType* type = u->type;
      int n = 1;
      // An egg counts as the unit it is morphing into. Types flagged
      // isTwoUnitsInOneEgg count twice.
      if (type == buildtypes::Zerg_Egg && u->constructingType) {
        type = u->constructingType;
        if (type->isTwoUnitsInOneEgg) {
          n = 2;
        }
      }
      myInProdUnitTypesNumbers[type->unit] += n;
    }
  }

  features_.clear();
  // Unit types that have no slot are skipped. Examples are Zerg_Larva, or
  // Zerg_Egg when its constructingType is unknown. Looking them up with
  // operator[] would insert them with index 0, silently adding their count
  // to the first feature (and growing the index maps on every call).
  for (const auto& ut : myUnitTypesNumbers) {
    auto it = indices_.myUnitTypesToIndex.find(ut.first);
    if (it != indices_.myUnitTypesToIndex.end()) {
      features_.push_back(std::make_pair(it->second, ut.second));
    }
  }
  for (const auto& ut : myInProdUnitTypesNumbers) {
    auto it = indices_.myInProdUnitTypesToIndex.find(ut.first);
    if (it != indices_.myInProdUnitTypesToIndex.end()) {
      features_.push_back(std::make_pair(it->second, ut.second));
    }
  }
  // Bound by const reference to avoid copying the map.
  const auto& enemyUnitTypes = s->unitsInfo().inferredEnemyUnitTypes();
  for (const auto& ut : enemyUnitTypes) {
    auto it = indices_.enemyUnitTypesToIndex.find(ut.first->unit);
    if (it != indices_.enemyUnitTypesToIndex.end()) {
      features_.push_back(std::make_pair(it->second, ut.second));
    }
  }

  // Resource and supply counters from the game state. The psi values are
  // halved (integer division), as in the original code.
  const auto& res = s->resources();
  features_.push_back(std::make_pair(indices_.oreIndex, res.ore));
  features_.push_back(std::make_pair(indices_.gasIndex, res.gas));
  features_.push_back(
      std::make_pair(indices_.usedSupplyIndex, res.used_psi / 2));
  features_.push_back(
      std::make_pair(indices_.maxSupplyIndex, res.total_psi / 2));
  features_.push_back(
      std::make_pair(
          indices_.availSupplyIndex, (res.total_psi - res.used_psi) / 2));

  // Pick the build state to read from: the caller's, or our cached one,
  // refreshed once per frame.
  // (Alternatively, all BuildState-based numbers could be set to 0 here.)
  BuildState* b = bst;
  if (b == nullptr) {
    if (buildState_.frame != s->currentFrame()) {
      buildState_ = getMyState(s);
    }
    b = &buildState_;
  }
  // NOTE(review): the two *Bst resource slots read s->resources(), not *b,
  // so they duplicate oreIndex/gasIndex. Confirm whether that is intended.
  features_.push_back(std::make_pair(indices_.oreBstIndex, res.ore));
  features_.push_back(std::make_pair(indices_.gasBstIndex, res.gas));
  features_.push_back(
      std::make_pair(
          indices_.usedSupplyBstIndex, b->usedSupply[tc::BW::Race::Zerg]));
  features_.push_back(
      std::make_pair(
          indices_.maxSupplyBstIndex, b->maxSupply[tc::BW::Race::Zerg]));
  features_.push_back(
      std::make_pair(
          indices_.inprodSupplyBstIndex, b->inprodSupply[tc::BW::Race::Zerg]));
  features_.push_back(
      std::make_pair(
          indices_.availSupplyBstIndex,
          b->maxSupply[tc::BW::Race::Zerg] -
              b->usedSupply[tc::BW::Race::Zerg]));
  // Constant bias feature.
  features_.push_back(std::make_pair(indices_.BIAS, 1.0));
  return features_;
}

/// Same features as Featurizer::forward, each passed through a signed square
/// root: sign(x) * sqrt(|x|).
///
/// A plain sqrt turns negative features into NaN, which then propagates into
/// every model output. Both "available supply" features (total minus used
/// supply, from the game state and from the BuildState) are negative
/// whenever used supply exceeds the maximum. The sign is therefore kept.
/// Non-negative features get exactly the plain square root.
///
/// @return Reference to the internal features_ buffer, overwritten by the
///         next call.
const std::vector<std::pair<size_t, float>>& FeaturizerSqrt::forward(
    State* s,
    BuildState* bst) {
  Featurizer::forward(s, bst);
  for (auto& f : features_) {
    f.second = std::copysign(std::sqrt(std::fabs(f.second)), f.second);
  }
  return features_;
}

/// Linear model on top of a featurizer.
///
/// @param f            Featurizer to take over (moved from).
/// @param output_size  Number of output classes (rows of the weight matrix).
/// The weight matrix has output_size rows of featurizer_.output_size_
/// columns. resize() value-initialises the new entries, so all weights start
/// at zero.
/// NOTE(review): if featurizer_ is stored by value as a Featurizer, moving a
/// FeaturizerSqrt in here slices it and loses the sqrt transform. Confirm
/// against model.h.
Model::Model(Featurizer&& f, size_t output_size)
    : output_size_(output_size), featurizer_(std::move(f)) {
  input_size_ = featurizer_.output_size_;
  weights_.resize(output_size_);
  for (size_t row = 0; row < weights_.size(); ++row) {
    weights_[row].resize(input_size_);
  }
}

/// Featurizes (s, bst) and computes one linear score per output class:
/// scores[k] = sum over sparse features (i, v) of v * weights_[k][i].
///
/// A copy of the features is appended to inputs_ and the scores to outputs_.
/// @return Reference to the newly appended scores (the last element of
///         outputs_). It may be invalidated by later calls if outputs_
///         reallocates.
std::vector<float>& Model::forward(State* s, BuildState* bst) {
  const auto& features = featurizer_.forward(s, bst);
  std::vector<float> scores(output_size_);
  for (size_t k = 0; k < scores.size(); ++k) {
    const auto& row = weights_[k];
    for (const auto& feat : features) {
      scores[k] += feat.second * row[feat.first];
    }
  }
  inputs_.push_back(features);
  outputs_.push_back(std::move(scores));
  return outputs_.back();
}

/// Adds lr * a to the weights. `a` is read as a flattened matrix in row-major
/// order with stride input_size_: row r starts at a[r * input_size_].
/// Each row's own length is used as the inner bound. There is no bounds
/// check on `a`, so it must be long enough for every row.
void Model::addToWeights(const std::vector<float>& a, float lr) {
  // Plain loops; relying on the optimizer rather than an AVX/BLAS call.
  size_t rowIdx = 0;
  for (auto& row : weights_) {
    const size_t base = rowIdx * input_size_;
    for (size_t col = 0; col < row.size(); ++col) {
      row[col] += lr * a[base + col];
    }
    ++rowIdx;
  }
}

/// Sets every weight to zero; the shape of the weight matrix is unchanged.
void Model::zeroWeights() {
  for (auto& row : weights_) {
    for (auto& v : row) {
      v = 0;
    }
  }
}

} // namespace model
} // namespace fairrsh
