/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#pragma once

#include "fairrsh.h"
#include "buildtype.h"
#include "modules/autobuild.h"
#include <cereal/types/vector.hpp>
#include <cereal/types/unordered_map.hpp>
#include <cereal/types/utility.hpp>
#include <cereal/archives/binary.hpp>


namespace fairrsh {

namespace model {

class Featurizer {
  // Produces the feature vector consumed by Model from a game State (and an
  // optional autobuild::BuildState). Feel free to subclass to add
  // cross-features, or to include a kernel.
  // NOTE(review): forward() is not virtual, so a subclass's forward() only
  // hides this one. Calls made through a Featurizer& / Featurizer*, or on a
  // Featurizer held by value (as Model does), always run this implementation.
  RTTR_ENABLE()
  friend class cereal::access;

  public:
    virtual ~Featurizer() = default;
    Featurizer();
    // The user-declared destructor above suppresses the implicit move
    // operations (and deprecates the implicit copies). Without these lines,
    // "moving" a Featurizer would silently copy features_, indices_ and
    // buildState_. Defaulting all four keeps both copy and move available.
    Featurizer(const Featurizer&) = default;
    Featurizer(Featurizer&&) = default;
    Featurizer& operator=(const Featurizer&) = default;
    Featurizer& operator=(Featurizer&&) = default;

    // Featurizes `s` (and `bst`). Returns a const reference to featurizer-owned
    // storage, presumably features_ (confirm in the .cpp).
    const std::vector<std::pair<size_t, float>>& forward(
        State* s,
        autobuild::BuildState* bst);
    // Looks like sparse (feature index, value) pairs; verify against forward().
    std::vector<std::pair<size_t, float>> features_;
    // Zero until set by the constructor, so it is never serialized
    // uninitialized.
    size_t output_size_ = 0;
    // Feature-slot layout. The maps presumably go from unit-type id to a
    // feature index (confirm in the .cpp).
    struct Indices {
      std::vector<int> myUnitTypes;
      std::vector<int> enemyUnitTypes;
      std::unordered_map<int, size_t> myUnitTypesToIndex;
      std::unordered_map<int, size_t> myInProdUnitTypesToIndex;
      std::unordered_map<int, size_t> enemyUnitTypesToIndex;
      // Every scalar index is zero-initialized because serialize() reads all
      // of them, including ones that may never be assigned (see the
      // AutoBuild note below). Reading an indeterminate size_t is UB.
      size_t oreIndex = 0;
      size_t gasIndex = 0;
      size_t usedSupplyIndex = 0;
      size_t maxSupplyIndex = 0;
      size_t availSupplyIndex = 0;
      // Below indices for if/when we use this within AutoBuild
      // TODO, ask Vegard
      size_t oreBstIndex = 0;
      size_t gasBstIndex = 0;
      size_t usedSupplyBstIndex = 0;
      size_t maxSupplyBstIndex = 0;
      size_t inprodSupplyBstIndex = 0;
      size_t availSupplyBstIndex = 0;
      // Presumably the slot of the constant bias feature (confirm in the .cpp).
      size_t BIAS = 0;
      // cereal hook. The field order defines the on-disk layout, so do not
      // reorder these lines. The version argument is currently unused.
      template<class Archive>
      void serialize(Archive& ar, const std::uint32_t /* version */) {
        ar & myUnitTypes;
        ar & enemyUnitTypes;
        ar & myUnitTypesToIndex;
        ar & myInProdUnitTypesToIndex;
        ar & enemyUnitTypesToIndex;
        ar & oreIndex;
        ar & gasIndex;
        ar & usedSupplyIndex;
        ar & maxSupplyIndex;
        ar & availSupplyIndex;
        ar & oreBstIndex;
        ar & gasBstIndex;
        ar & usedSupplyBstIndex;
        ar & maxSupplyBstIndex;
        ar & inprodSupplyBstIndex;
        ar & availSupplyBstIndex;
        ar & BIAS;
      }
    } indices_ ;
    // cereal hook. It serializes features_, output_size_ and indices_ in that
    // order. buildState_ is not serialized. The version argument is unused.
    template<class Archive>
    void serialize(Archive& ar, const std::uint32_t /* version */) {
      ar & features_;
      ar & output_size_;
      ar & indices_;
    }

  private:
    autobuild::BuildState buildState_;
};

class FeaturizerSqrt : public Featurizer {
  RTTR_ENABLE()
  friend class cereal::access;

  public:
    const std::vector<std::pair<size_t, float>>& forward(
        State* s,
        autobuild::BuildState* bst);
};

class Model {
  // Model parameterized by weights_, holding dim() == input_size_ *
  // output_size_ scalars in total. It is presumably a linear map from the
  // featurizer's sparse inputs to output_size_ outputs (confirm in the .cpp).
  RTTR_ENABLE()
  friend class cereal::access;

  public:
    Model() = default;
    // NOTE(review): `f` is stored by value as a plain Featurizer. Passing a
    // subclass such as FeaturizerSqrt slices it down to the base class.
    Model(Featurizer&& f, size_t output_size);
    virtual ~Model() = default;
    // The user-declared destructor above suppresses the implicit move
    // operations. Without these lines, moving a Model would copy every
    // weight, input and output buffer. Defaulting all four keeps both copy
    // and move available.
    Model(const Model&) = default;
    Model(Model&&) = default;
    Model& operator=(const Model&) = default;
    Model& operator=(Model&&) = default;

    // Evaluates the model on `s` (and `bst`). Returns a mutable reference to
    // model-owned storage, presumably inside outputs_ (confirm in the .cpp).
    std::vector<float>& forward(State* s, autobuild::BuildState* bst);
    // Presumably weights += lr * a, with `a` holding dim() entries
    // (confirm in the .cpp).
    void addToWeights(const std::vector<float>& a, float lr = 1.0f);
    // Presumably sets every weight to zero (confirm in the .cpp).
    void zeroWeights();
    const std::vector<std::vector<float>>& weights() const {
      return weights_;
    }
    // Total number of scalar parameters.
    size_t dim() const {
      return input_size_ * output_size_;
    }
    // Zero-initialized so that dim() is well-defined on a default-constructed
    // Model.
    size_t input_size_ = 0;
    size_t output_size_ = 0;
    // cereal hook. The field order defines the on-disk layout, so do not
    // reorder these lines. The version argument is currently unused.
    template<class Archive>
    void serialize(Archive& ar, const std::uint32_t /* version */) {
      ar & inputs_;
      ar & outputs_;
      ar & input_size_;
      ar & output_size_;
      ar & weights_;
      ar & featurizer_;
    }

  // NOTE(review): the access specifier below is commented out, so these
  // members are public. Callers may depend on that.
  //private:
    std::vector<std::vector<std::pair<size_t, float>>> inputs_;
    std::vector<std::vector<float>> outputs_;
    std::vector<std::vector<float>> weights_;
    Featurizer featurizer_;
};

} // namespace model

} // namespace fairrsh
