/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#pragma once

#include "fairrsh.h"
#include "model.h"
#include "utils.h"

#include <cereal/types/polymorphic.hpp>
#include <cereal/types/memory.hpp>
#include <cereal/archives/binary.hpp>

#include <random>

// Command-line flags read by this module. DECLARE_* (gflags-style) only
// makes the FLAGS_<name> variables visible here; the matching DEFINE_*
// (default value and help text) presumably lives in a .cpp file.
// FLAGS_umm_path is the root directory used by UnitsMixer::load()/save().
DECLARE_string(umm_path);
DECLARE_double(lr);
// NOTE(review): the temperature flag is disabled; delete or restore it.
//DECLARE_double(temperature);
DECLARE_double(noise_scale);
DECLARE_double(l2_reg);
DECLARE_double(pow_anneal_noise);
DECLARE_double(lin_anneal_noise);
DECLARE_double(adagrad_epsilon);
DECLARE_double(sqrt_lr);
DECLARE_bool(umm_train);

namespace fairrsh {

namespace model {

/// A production choice; UnitsMixer::forward() hands back a pointer to one.
struct Action {
  std::string name; // identifier of this choice
};

// Predefined actions. They are only declared here (extern); their single
// definition lives in another translation unit.
extern Action actionDrone, actionZergling, actionHydralisk, actionMutalisk,
    actionMacroTech, actionSunkenDef, actionSafeMutaLing;

// Forward declaration: UnitsMixerPtr appears in UnitsMixer's own static
// factory signatures, so the alias (and hence the class name) must be
// known before the class body below.
class UnitsMixer;

/// Shared-ownership handle returned by UnitsMixer::make()/load() and
/// accepted by UnitsMixer::save().
using UnitsMixerPtr = std::shared_ptr<UnitsMixer>;

/**
 * A Units Mix Model that decides what to produce with each larva.
 * See https://www.overleaf.com/10674832mcjzvbpkwtdj#/40006644/
 *
 * Abstract base: subclasses implement forward(), which returns the Action to
 * take for a given State/BuildState, and computeUpdate(), which returns a
 * vector of floats (presumably the update later handed to updateWeights();
 * confirm in the .cpp).
 *
 * Instances own their Model through std::unique_ptr and are therefore
 * non-copyable. Create and persist them via make()/load()/save(), which all
 * work on UnitsMixerPtr.
 */
class UnitsMixer {
  // This represents a SINGLE GAME
  RTTR_ENABLE()
  friend class cereal::access;

  public:
    UnitsMixer();
    virtual ~UnitsMixer();

    // The std::unique_ptr<Model> member already makes the implicit copy
    // operations deleted; say so explicitly. No move operations are
    // declared either, because of the user-declared destructor.
    UnitsMixer(const UnitsMixer&) = delete;
    UnitsMixer& operator=(const UnitsMixer&) = delete;

    // These will throw if create/load/save fails.
    static UnitsMixerPtr make(std::string name);
    // load will load by default from FLAGS_umm_path/read/vanilla_weights
    static UnitsMixerPtr load(std::string path = std::string());
    // save will write by default to FLAGS_umm_path/write/<timestamp>_<random-suffix>
    static void save(UnitsMixerPtr, std::string path = std::string());

    virtual Action* forward(State* s, autobuild::BuildState* bst) = 0;
    virtual const std::vector<float>& computeUpdate() = 0;
    virtual void updateWeights(std::vector<float> const&);

    // Read-only view of nbActions_. It is const so it can be called on
    // const mixers too.
    std::vector<int> const& nbActions() const {
      return nbActions_;
    }

    // Game lifecycle hooks; the base implementations do nothing.
    virtual void onGameStart(State* s) {}
    virtual void onGameStep(State* s) {}
    virtual void onGameEnd(State* s) {}

    std::vector<float> update_;
    std::vector<float> sqGrad_;
    std::vector<float> tmpGrad_;
    std::unique_ptr<Model> model_;
    // The default member initializers ensure these are never read while
    // indeterminate, even if a constructor leaves them unset.
    float reward_ = 0.0f;
    float batch_n_ = 0.0f;
    // logging actions
    std::vector<int> nbActions_;

    // cereal hook. The argument order below IS the serialized layout, and
    // previously saved weights depend on it. Do not reorder or insert fields
    // without versioning the change (the version argument is currently
    // unused).
    template<class Archive>
    void serialize(Archive& ar, const std::uint32_t /*version*/) {
      // TODO refactor, e.g. http://uscilab.github.io/cereal/pimpl.html
      LOG(INFO) << "serializing UnitsMixer";
      ar(reward_, batch_n_, model_, sqGrad_, nbActions_);
    }

  protected:
    // Defined out of line; the name suggests it validates model_ against
    // actions_ (confirm in the .cpp).
    void checkModelAndActions();

    std::vector<Action*> actions_;
    std::mt19937 rdm_;
};

class ZOUnitsMixer : public UnitsMixer {  // ZOUM ZOUM!
  RTTR_ENABLE(UnitsMixer);
  friend class cereal::access;
 protected:
  std::vector<float> weights_noise_;
  std::normal_distribution<float> normal_0_1_;
  
  public:
    ZOUnitsMixer();

    virtual void onGameStart(State* s) override;
    virtual void onGameEnd(State* s) override;

    Action* forward(State* s, autobuild::BuildState* bst) override;
    virtual const std::vector<float>& computeUpdate() override;

    template<class Archive>
    void serialize(Archive& ar, const std::uint32_t version) {
      LOG(INFO) << "serializing ZOUnitsMixer";
      ar & ::cereal::base_class<UnitsMixer>(this);
      ar & weights_noise_;
    }
};

class PGUnitsMixer : public UnitsMixer {
  RTTR_ENABLE(UnitsMixer);
  friend class cereal::access;

  std::vector<size_t> chosen_actions_;
  std::uniform_real_distribution<float> uni_0_1_;

  public:
    PGUnitsMixer();

    Action* forward(State* s, autobuild::BuildState* bst) override;
    virtual const std::vector<float>& computeUpdate() override;

    virtual void onGameEnd(State* s) override;

    template<class Archive>
    void serialize(Archive& ar) {
      LOG(INFO) << "serializing PGUnitsMixer";
      ar & ::cereal::base_class<UnitsMixer>(this);
      ar & chosen_actions_;
    }
};

class ZOSampleMixer : public ZOUnitsMixer {
  RTTR_ENABLE(ZOUnitsMixer);
  friend class cereal::access;

  std::uniform_real_distribution<float> uni_0_1_;
 public:
  ZOSampleMixer();

  Action* forward(State* s, autobuild::BuildState* bst) override;
  
  virtual void onGameStart(State* s) override;
  virtual void onGameEnd(State* s) override;

  template<class Archive>
  void serialize(Archive& ar) {
    LOG(INFO) << "serializing ZOSampleMixer";
    ar & ::cereal::base_class<ZOUnitsMixer>(this);
  }
};

} // namespace model

} // namespace fairrsh

// XXX Use rttr for all this stuff, including serialization
// Register the concrete mixers with cereal's polymorphic machinery so they
// can be (de)serialized through a pointer to a base class. The abstract
// UnitsMixer itself is not registered. Per cereal's docs, every archive
// used with these types must be included before these macros; this header
// includes cereal/archives/binary.hpp above.
CEREAL_REGISTER_TYPE(fairrsh::model::ZOUnitsMixer)
CEREAL_REGISTER_TYPE(fairrsh::model::PGUnitsMixer)
CEREAL_REGISTER_TYPE(fairrsh::model::ZOSampleMixer)
