/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "units_mixer.h"
#include <fstream>
#include <glog/logging.h>

// Command-line flags controlling where UnitsMixer weights are read/written
// and how the mixers in this file are trained.
DEFINE_string(umm_path, "", "path to look at for objects to train UMM on");
DEFINE_double(lr, 0.1, "learning rate");
// DEFINE_double(temperature, 1.0, "temperature of the Gibbs policy");
DEFINE_double(noise_scale, 0.01, "noise scaling factor in ZO");
DEFINE_double(l2_reg, 0.00001, "L2 regularization coefficient");
// NOTE(review): pow_anneal_noise is not read in the visible part of this
// file; confirm it is consumed elsewhere.
DEFINE_double(pow_anneal_noise, 1.0, "multiply noise by this each time");
DEFINE_double(lin_anneal_noise, 0.0, "divide noise by batch# times this");
// A value <= 0 makes updateWeights() apply the raw update instead.
DEFINE_double(adagrad_epsilon, -1, "set to > 0 value to use adagrad");
// Only applied by updateWeights() when > 0.
DEFINE_double(
    sqrt_lr,
    -1,
    "lr in the form of lr*sqrt(sqrt_lr/(sqrt_lr+batch_n_))");
DEFINE_bool(umm_train, false, "set true to train");

namespace fairrsh {
namespace model {

// Registers the mixer classes with RTTR so that UnitsMixer::make() can find
// them by (case-insensitive) name among the classes derived from UnitsMixer.
RTTR_REGISTRATION {
  // The base class is registered without a constructor; make() only
  // instantiates the derived classes registered below.
  rttr::registration::class_<UnitsMixer>("UnitsMixer")(
      metadata("type", rttr::type::get<UnitsMixer>()));

  // Each concrete mixer exposes its default constructor, which make() invokes
  // through typ.get_constructor().invoke().
  rttr::registration::class_<ZOUnitsMixer>("ZOUnitsMixer")(
      metadata("type", rttr::type::get<ZOUnitsMixer>()))
      .constructor();

  rttr::registration::class_<PGUnitsMixer>("PGUnitsMixer")(
      metadata("type", rttr::type::get<PGUnitsMixer>()))
      .constructor();
  rttr::registration::class_<ZOSampleMixer>("ZOSampleMixer")(
      metadata("type", rttr::type::get<ZOSampleMixer>()))
      .constructor();

}

// Named actions a mixer can return from forward(). Mixers store pointers to
// these objects in actions_; in this file only actionDrone and actionZergling
// are currently added by the UnitsMixer constructor.
Action actionDrone{"Drone"};
Action actionZergling{"Zergling"};
Action actionHydralisk{"Hydralisk"};
Action actionMutalisk{"Mutalisk"};
Action actionMacroTech{"MacroTech"};
Action actionSunkenDef{"SunkenDef"};
Action actionSafeMutaLing{"SafeMutaLing"};

namespace {

/**
 * Replaces the scores in `x` with their softmax, in place.
 *
 * The largest score is subtracted before exponentiating, so the biggest
 * exponent evaluated is exp(0) = 1 and large scores do not overflow.
 * An empty vector is left untouched: std::max_element would return end() and
 * dereferencing it is undefined behaviour.
 */
void softmax_inplace(std::vector<float>& x) {
  // feel free to replace / call lib
  if (x.empty()) {
    return;
  }
  const float maxScore = *std::max_element(x.begin(), x.end());
  float total = 0;
  for (auto& e : x) {
    e = std::exp(e - maxScore);
    total += e;
  }
  for (auto& e : x) {
    e /= total;
  }
}

/**
 * Draws an index from the discrete distribution `p` by walking its running
 * sum until it reaches a uniform sample.
 *
 * @param p        weights of each index; expected to be non-empty and to sum
 *                 to (about) 1, e.g. the output of softmax_inplace()
 * @param uni_0_1  uniform distribution used to draw the threshold
 * @param rdm      random engine; taken by reference so that its state
 *                 advances. Taking it by value would sample from a copy and
 *                 every call would reuse the same random number.
 * @return the first index whose cumulative weight reaches the threshold, or
 *         the last index if rounding keeps the running sum below it. For an
 *         empty `p` this is p.size() - 1, which wraps around and is not a
 *         valid index.
 */
size_t sampleIndex(
    const std::vector<float>& p,
    std::uniform_real_distribution<float>& uni_0_1,
    std::mt19937& rdm) {
  const float threshold = uni_0_1(rdm);
  float cumulative = 0;
  for (size_t i = 0; i < p.size(); ++i) {
    if (threshold <= cumulative + p[i]) {
      return i;
    }
    cumulative += p[i];
  }
  return p.size() - 1;
}

} // namespace

/**
 * Deserializes a mixer from a cereal binary archive.
 *
 * @param path file to read; when empty, defaults to
 *             FLAGS_umm_path + "/read/vanilla_weights"
 * @return the loaded mixer, after checkModelAndActions() has been run on it
 * @throws std::runtime_error if the file cannot be opened, or if the archive
 *         yields a null mixer pointer
 */
UnitsMixerPtr UnitsMixer::load(std::string path) {
  if (path.empty()) {
    path = FLAGS_umm_path + "/read/vanilla_weights";
  }

  std::ifstream ifs(path, std::ios::binary);
  if (!ifs.good()) {
    throw std::runtime_error("Cannot read from " + path);
  }
  cereal::BinaryInputArchive ia(ifs);
  UnitsMixerPtr mixer;
  ia(mixer);

  // The archive may legitimately contain a null pointer (e.g. save() was
  // given one); fail loudly instead of dereferencing it below.
  if (mixer == nullptr) {
    throw std::runtime_error("No UnitsMixer stored in " + path);
  }

  mixer->checkModelAndActions();
  return mixer;
}

/**
 * Serializes `mixer` into a cereal binary archive.
 *
 * @param mixer mixer to write
 * @param path  destination file; when empty, a name is generated under
 *              FLAGS_umm_path + "/write/" from the current time and a
 *              std::random_device value
 *
 * If the file cannot be opened, a warning is logged and nothing is written.
 */
void UnitsMixer::save(UnitsMixerPtr mixer, std::string path) {
  if (path.empty()) {
    using Clock = std::chrono::system_clock;
    auto const stamp = Clock::to_time_t(Clock::now());
    std::random_device entropy;
    path = FLAGS_umm_path + "/write/" + std::to_string(stamp) + "_" +
        std::to_string(entropy());
  }

  std::ofstream out(path, std::ios::binary);
  if (!out.good()) {
    LOG(WARNING) << "Cannot write to " << path;
    return;
  }

  cereal::BinaryOutputArchive archive(out);
  archive(mixer);
}

/**
 * Instantiates a registered mixer by name.
 *
 * The comparison is case-insensitive and `name` may either be the full
 * registered type name or that name without a trailing "UnitsMixer"
 * (e.g. "zo" matches "ZOUnitsMixer").
 *
 * @param name requested mixer type
 * @return a new mixer, or nullptr (after logging at DFATAL) when no derived
 *         type matches or the matching type could not be constructed
 */
UnitsMixerPtr UnitsMixer::make(std::string name) {
  auto const lowerName = utils::stringToLower(name);
  // Loop-invariant: the short form with the common suffix appended.
  auto const lowerNameWithSuffix = lowerName + "unitsmixer";

  for (auto typ : rttr::type::get<UnitsMixer>().get_derived_classes()) {
    auto lowerType = utils::stringToLower(typ.get_name());
    if (lowerType == lowerName || lowerType == lowerNameWithSuffix) {
      auto var = typ.get_constructor().invoke();
      if (!var.is_valid()) {
        LOG(DFATAL) << "Failed to instantiate UnitsMixer: " << name;
        // DFATAL does not abort in every build; never extract a value from
        // an invalid variant.
        return nullptr;
      }
      return var.get_value<std::shared_ptr<UnitsMixer>>();
    }
  }

  LOG(DFATAL) << "No such UnitsMixer type: " << name;
  return nullptr;
}

/**
 * Builds a mixer with a freshly seeded random engine, the current action set
 * and a new Model sized to that action set. Adagrad buffers are allocated
 * only when FLAGS_adagrad_epsilon is positive.
 */
UnitsMixer::UnitsMixer() : batch_n_(0.) {
  // Seed the engine from the system entropy source (or init only once?).
  std::random_device seedSource;
  rdm_.seed(seedSource());

  // TODO change actions here!!!
  // Currently disabled (TODO put back actions?): actionHydralisk,
  // actionMutalisk, actionMacroTech, actionSunkenDef.
  for (Action* enabled : {&actionDrone, &actionZergling}) {
    actions_.push_back(enabled);
  }
  model_ = std::make_unique<Model>(FeaturizerSqrt(), actions_.size());

  if (FLAGS_adagrad_epsilon > 0) {
    auto const dim = model_->dim();
    // Accumulated squared gradients start at epsilon.
    sqGrad_.resize(dim);
    tmpGrad_.resize(dim);
    std::fill(sqGrad_.begin(), sqGrad_.end(), FLAGS_adagrad_epsilon);
  }

  // One counter per model output.
  nbActions_.resize(model_->output_size_);
}

// Out-of-line destructor; nothing beyond member destruction is required.
UnitsMixer::~UnitsMixer() = default;

/**
 * Sanity-checks the mixer's state: logs at DFATAL when the model pointer is
 * null and when the action list is empty. Both checks are always evaluated.
 */
void UnitsMixer::checkModelAndActions() {
  if (!model_) {
    LOG(DFATAL) << "UnitsMixer's model was not recovered or created properly";
  }
  if (actions_.empty()) {
    LOG(DFATAL) << "UnitsMixer's actions were not recovered or created properly";
  }
}

/**
 * Adds `update` to the model weights, scaled by the learning rate.
 *
 * The step size is FLAGS_lr, multiplied by
 * sqrt(sqrt_lr / (sqrt_lr + batch_n_)) when FLAGS_sqrt_lr > 0.
 * When FLAGS_adagrad_epsilon > 0, the squared update is first accumulated
 * into sqGrad_ and each component is divided by the square root of its
 * accumulator before being applied.
 *
 * @param update flat update vector; in the Adagrad path it is indexed
 *               alongside sqGrad_ and tmpGrad_, which are assumed to be at
 *               least as long
 */
void UnitsMixer::updateWeights(std::vector<float> const& update) {
  float stepSize = FLAGS_lr;
  if (FLAGS_sqrt_lr > 0) {
    stepSize *= std::sqrt(FLAGS_sqrt_lr / (FLAGS_sqrt_lr + batch_n_));
  }

  // Plain update when Adagrad is disabled.
  if (FLAGS_adagrad_epsilon <= 0) {
    model_->addToWeights(update, stepSize);
    return;
  }

  // sqGrad_ += update * update (element-wise).
  utils::inplace_flat_vector_addcmul(sqGrad_, update, update);
  size_t const count = update.size();
  for (size_t k = 0; k < count; ++k) {
    tmpGrad_[k] = update[k] / std::sqrt(sqGrad_[k]);
  }
  model_->addToWeights(tmpGrad_, stepSize);
}

/**
 * Sets up the noise buffer (one entry per model weight, i.e.
 * input_size_ * output_size_) and the standard normal distribution the noise
 * is drawn from.
 */
ZOUnitsMixer::ZOUnitsMixer() : UnitsMixer() {
  auto const weightCount = model_->input_size_ * model_->output_size_;
  weights_noise_.resize(weightCount);
  normal_0_1_ = std::normal_distribution<float>(0, 1);
}

/**
 * Runs the model and returns the action with the highest score, counting the
 * choice in nbActions_.
 *
 * A mismatch between the number of scores and the number of actions is only
 * logged; the selection still proceeds.
 */
Action* ZOUnitsMixer::forward(State* s, autobuild::BuildState* bst) {
  auto& scores = model_->forward(s, bst);
  if (scores.size() != actions_.size()) {
    LOG(ERROR) << "predictions size != actions size";
  }
  auto const best = utils::argmax<float>(scores);
  ++nbActions_[best];
  return actions_[best];
}

/**
 * Computes the weight update from the last game's reward and the noise that
 * was added at game start, stores it in update_ and returns it.
 *
 * Without regularization each component is reward_ * noise. With
 * FLAGS_l2_reg > 0, l2_reg * (weight - noise)^2 is subtracted from it.
 * NOTE(review): the penalty term is a squared value rather than a gradient of
 * one; confirm this is intended.
 *
 * @throws std::runtime_error when FLAGS_umm_train is false
 */
const std::vector<float>& ZOUnitsMixer::computeUpdate() {
  if (!FLAGS_umm_train) {
    throw std::runtime_error("calling update without training");
  }
  // update += FLAGS_lr * reward_ * weights_noise_;
  // Start from a copy of the noise so update_ has the right size.
  update_ = weights_noise_;
  if (FLAGS_l2_reg > 0) {
    auto const rows = model_->output_size_;
    auto const cols = model_->input_size_;
    for (size_t row = 0; row < rows; ++row) {
      for (size_t col = 0; col < cols; ++col) {
        // Weights are indexed [output][input]; the noise is stored flat.
        size_t const flat = row * cols + col;
        update_[flat] = reward_ * weights_noise_[flat] -
            FLAGS_l2_reg *
                pow(model_->weights()[row][col] - weights_noise_[flat], 2);
      }
    }
  } else {
    size_t const count = update_.size();
    for (size_t flat = 0; flat < count; ++flat) {
      update_[flat] = reward_ * weights_noise_[flat];
    }
  }
  return update_;
}

/**
 * Prepares the mixer for a new game.
 *
 * In training mode, fresh Gaussian noise (scaled by FLAGS_noise_scale and,
 * when FLAGS_lin_anneal_noise > 0, divided by
 * 1 + lin_anneal_noise * batch_n_) is drawn for every weight and added to the
 * model. The per-action counters are reset in every mode.
 */
void ZOUnitsMixer::onGameStart(State* s) {
  if (FLAGS_umm_train) {
    // Linear annealing factor; 1 when annealing is disabled.
    float const anneal = FLAGS_lin_anneal_noise > 0
        ? static_cast<float>(1.0f / (1.0 + FLAGS_lin_anneal_noise * batch_n_))
        : 1.0f;
    auto const weightCount = model_->input_size_ * model_->output_size_;
    for (size_t k = 0; k < weightCount; ++k) {
      // TODO divive normal by d?
      weights_noise_[k] = normal_0_1_(rdm_) * FLAGS_noise_scale * anneal;
    }
    model_->addToWeights(weights_noise_);
  }
  // Start the game with clean action statistics.
  std::fill(nbActions_.begin(), nbActions_.end(), 0);
}

/** Records the game outcome as the reward: +1 for a win, -1 otherwise. */
void ZOUnitsMixer::onGameEnd(State* s) {
  if (s->won()) {
    reward_ = 1.0;
  } else {
    reward_ = -1.0;
  }
}

/** Initializes the uniform distribution used to sample actions in forward(). */
PGUnitsMixer::PGUnitsMixer() : UnitsMixer() {
  uni_0_1_ = std::uniform_real_distribution<float>(0.0f, 1.0f);
}

/**
 * Runs the model, turns its scores into probabilities and samples an action
 * from them. The sampled index is appended to chosen_actions_.
 *
 * A mismatch between the number of scores and the number of actions is only
 * logged; sampling still proceeds.
 */
Action* PGUnitsMixer::forward(State* s, autobuild::BuildState* bst) {
  auto& probs = model_->forward(s, bst);
  if (probs.size() != actions_.size()) {
    LOG(ERROR) << "predictions size != actions size";
  }
  // tricky! This normalizes the vector returned by reference from the model,
  // not a local copy.
  softmax_inplace(probs);
  auto const choice = sampleIndex(probs, uni_0_1_, rdm_);
  chosen_actions_.push_back(choice);
  return actions_[choice];
}

/**
 * Placeholder update: resets update_ to a zero vector with one entry per
 * model weight (input_size_ * output_size_) and returns it.
 */
const std::vector<float>& PGUnitsMixer::computeUpdate() {
  // TODO replace reward by advantage/baseline or proper critic
  // TODO REINFORCE
  // weights_ += avg_{k actions} [FLAGS_lr * reward_ * <input_, output_> /
  // output_[k]]
  auto const weightCount = model_->input_size_ * model_->output_size_;
  update_.assign(weightCount, 0.0f);
  return update_;
}

/** Records the game outcome as the reward: +1 for a win, -1 otherwise. */
void PGUnitsMixer::onGameEnd(State* s) {
  if (s->won()) {
    reward_ = 1.0;
  } else {
    reward_ = -1.0;
  }
}


/**
 * Prepares the mixer for a new game.
 *
 * In training mode, fresh Gaussian noise is drawn for every weight and added
 * to the model. Compared to ZOUnitsMixer::onGameStart, the noise is
 * additionally divided by sqrt(number of weights).
 *
 * The per-action counters are reset in every mode, as in
 * ZOUnitsMixer::onGameStart: ZOSampleMixer::forward increments nbActions_, so
 * without the reset the counts would carry over from one game to the next.
 */
void ZOSampleMixer::onGameStart(State* s) {
  if (FLAGS_umm_train) {
    // Sample noise and add it to model
    float anneal = 1.0;
    if (FLAGS_lin_anneal_noise > 0) {
      anneal /= (1.0 + FLAGS_lin_anneal_noise * batch_n_);
    }
    float norm = std::sqrt(weights_noise_.size());
    for (size_t ij = 0; ij < model_->input_size_ * model_->output_size_; ++ij) {
      weights_noise_[ij] = normal_0_1_(rdm_) * FLAGS_noise_scale * anneal / norm;
    }
    model_->addToWeights(weights_noise_);
  }
  // re-initialise logs
  std::fill(nbActions_.begin(), nbActions_.end(), 0);
}

/** Initializes the uniform distribution used to sample actions in forward(). */
ZOSampleMixer::ZOSampleMixer() : ZOUnitsMixer() {
  uni_0_1_ = std::uniform_real_distribution<float>(0.0f, 1.0f);
}

/**
 * Runs the model and picks an action, counting the choice in nbActions_.
 *
 * Outside training the highest-scoring action is taken. In training mode the
 * scores are turned into probabilities and the action is sampled from them.
 * A mismatch between the number of scores and the number of actions is only
 * logged; the selection still proceeds.
 */
Action* ZOSampleMixer::forward(State* s, autobuild::BuildState* bst) {
  auto& scores = model_->forward(s, bst);
  if (scores.size() != actions_.size()) {
    LOG(ERROR) << "predictions size != actions size";
  }

  size_t choice = 0;
  if (!FLAGS_umm_train) {
    choice = utils::argmax<float>(scores);
  } else {
    // tricky! This normalizes the vector returned by reference from the
    // model, not a local copy.
    softmax_inplace(scores);
    choice = sampleIndex(scores, uni_0_1_, rdm_);
  }

  ++nbActions_[choice];
  return actions_[choice];
}

/** Records the game outcome as the reward: 1 for a win, 0 otherwise. */
void ZOSampleMixer::onGameEnd(State* s) {
  if (s->won()) {
    reward_ = 1.0;
  } else {
    reward_ = 0.0;
  }
}

} // namespace model
} // namespace fairrsh
