/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#ifdef WITH_ATEN

#include "fonf.h"

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>

#include <torchcraft/constants.h>

// Before adding another 'project' to this dir, probably move this to
// another directory.

// Hyper-parameters for the training driver that is currently commented out at
// the bottom of this file.

// Number of passes over the training set (passed as `epochs` to train()).
#define EPOCH 20
// SGD step size (passed as `lr` to train()).
#define LR 0.1f
// L2 shrinkage factor (passed as `weight_decay` to train()).
#define WEIGHT_DECAY 0.0000001f
// NOTE(review): not referenced anywhere in the visible part of this file.
#define N_INPUT_THREADS 40
// Passed as `threads` to train() by the commented-out driver.
#define NTHREADS 10
// Appears only as a constructor argument in the commented-out driver.
#define N_HID 10000

// Scoring function used by the commented-out data loader.
// Can be gab_score, supply_score or simple_score.
#define SCORE_TYPE gab_score

namespace fairrsh {
namespace fonf {
namespace detail {

// Default constructor: every parameter tensor is created as a CPU float tensor
// with no sizes given. operator>> resizes them when a model is deserialized.
LogReg::LogReg() {
  auto& cpuFloat = at::CPU(at::kFloat);
  weights = cpuFloat.tensor();
  bias = cpuFloat.tensor();
  mean = cpuFloat.tensor();
  std = cpuFloat.tensor();
}

// Builds an untrained model mapping `nin` input features to `nout` outputs.
//
//   weights : nin x nout, all zeros
//   bias    : nout,       all zeros
//   mean    : 1 x nin,    all zeros
//   std     : 1 x nin,    all ones
//
// mean/std are the normalisation statistics applied by predict(); train()
// overwrites both. Starting `std` at one (rather than zero) keeps predict() on
// an untrained model from dividing by zero.
//
// The previous version divided the zero tensors by sqrt(nin * nout) and
// sqrt(nout). That scaling had no effect on zeros, and the int32 product
// nin * nout could overflow for large layers, so it has been dropped.
LogReg::LogReg(int32_t nin, int32_t nout) {
  auto& cpuFloat = at::CPU(at::kFloat);
  weights = cpuFloat.zeros({nin, nout});
  bias = cpuFloat.zeros({nout});
  mean = cpuFloat.zeros({1, nin});
  std = cpuFloat.ones({1, nin});
}

// Classifies a batch of samples.
//
// `input` is either a single sample (1-D, promoted to a batch of one) or a
// bsz x nin batch. The input is normalised with the stored mean/std, projected
// through weights and bias, and the sign of the result is returned. A raw score
// of exactly zero is nudged by 1e-2 before the second sign(), so it comes out
// as +1.
at::Tensor LogReg::predict(at::Tensor input) const {
  if (input.ndimension() == 1) {
    input = input.unsqueeze(0);
  }

  // Normalise with the statistics stored in the model.
  auto centered = input - mean.expand(input.sizes());
  auto normalized = centered / std.expand(input.sizes());

  // Affine projection: normalized * weights + bias (bias repeated per row).
  auto scores = normalized.mm(weights);
  auto out = scores + bias.unsqueeze(0).expand(scores.sizes());

  // sign -> shift zeros to the positive side -> sign again.
  out.sign_();
  out.add_(1e-2f);
  out.sign_();
  return out;
}

// Fits the model with per-sample stochastic gradient descent on the logistic
// loss log(1 + exp(-y * p)), where p = x * weights + bias.
//
// Parameters:
//   data         - one row per sample; data.size(1) must equal weights.size(0).
//   targ         - one row per sample; targ.size(1) must equal weights.size(1).
//                  NOTE(review): the loss assumes labels in {-1, +1}; confirm
//                  against the caller.
//   epochs       - number of full passes over the data.
//   lr           - SGD step size.
//   weight_decay - shrinkage applied to weights and bias on every step.
//   threads      - not consulted by this implementation.
//
// Side effects: overwrites mean/std with statistics computed from `data`,
// updates weights/bias in place, and prints the per-output average loss to
// stdout after every epoch.
//
// Fixes relative to the previous version:
//   * the `weight_decay` argument is honoured (the WEIGHT_DECAY macro used to be
//     applied regardless of what the caller passed);
//   * the random engine lives across epochs, so each epoch sees a different
//     permutation. It is still seeded with 0, so runs stay reproducible and the
//     first epoch visits samples in the same order as before.
void LogReg::train(at::Tensor data, at::Tensor targ, uint32_t epochs, float lr, float weight_decay, int threads) {
  (void)threads;

  assert(data.size(0) == targ.size(0));
  assert(data.size(1) == weights.size(0));
  assert(targ.size(1) == weights.size(1));

  // Centre every feature column.
  mean = data.mean(0).unsqueeze(0);
  data = data - (mean.expand(data.sizes()));
  // Despite its name, `std` holds the square root of the per-feature sum of
  // squares of the centred data (the sum is not divided by the sample count).
  // The small epsilon keeps a constant feature from yielding a zero divisor.
  std = ((data * data).sum(0) + 1e-10f).sqrt_().unsqueeze(0);
  data = data / (std.expand(data.sizes()));

  // Visiting order of the samples; reshuffled at the start of every epoch.
  std::vector<uint32_t> shuf_ind(data.size(0));
  for (auto i = 0U; i < shuf_ind.size(); i++) {
    shuf_ind[i] = i;
  }
  std::default_random_engine rng(0);

  for (auto e = 0U; e < epochs; e++) {
    auto sum_loss = at::CPU(at::kFloat).zeros({weights.size(1)});
    std::shuffle(shuf_ind.begin(), shuf_ind.end(), rng);

    for (const auto row : shuf_ind) {
      auto x = data[row];
      auto y = targ[row].unsqueeze(0);

      // Forward pass for a single sample: p = x * weights + bias.
      auto mm = x.unsqueeze(0).mm(weights);
      auto expanded = bias.unsqueeze(0).expand(mm.sizes());
      auto p = mm + expanded;

      // z = y * p is the margin. `loss` and `dl` both start out holding z and
      // are overwritten element by element below.
      auto loss = p * y;
      auto dl = loss.clone();
      auto A_loss = loss.accessor<float,2>();
      auto A_dl = dl.accessor<float,2>();
      for (auto k = 0; k < loss.size(1); k++) {
        auto z = A_dl[0][k];
        // log(1 + exp(-z)), with asymptotic forms for |z| > 6.
        A_loss[0][k] = z > 6 ? std::exp(-z)
          : (z < -6 ? -z : std::log(1 + std::exp(-z)));

        // Magnitude of d(loss)/dz = 1 / (1 + exp(z)), same asymptotic split.
        A_dl[0][k] = z > 6 ? std::exp(-z)
          : (z < -6 ? 1 : 1.f / (1 + std::exp(z)));
      }
      sum_loss += loss;
      // Chain rule: d(loss)/dp = -y * 1 / (1 + exp(z)).
      dl *= - y;

      // SGD step with weight decay.
      bias += - bias * weight_decay - lr * dl;
      weights += - weights * weight_decay - lr * x.unsqueeze(1).mm(dl);
    }
    std::cout << "loss is " << sum_loss / static_cast<float>(shuf_ind.size()) << '\n';
  }
}

// Serializes a model as raw bytes in host byte order:
//   int32 nin, int32 nout,
//   nin * nout floats (weights), nout floats (bias),
//   nin floats (mean), nin floats (std).
// The tensors are dumped straight from data_ptr(), so this relies on them being
// laid out contiguously.
std::ostream& operator<<(std::ostream& out, const fonf::detail::LogReg& o) {
  const auto nin = int32_t(o.weights.size(0));
  const auto nout = int32_t(o.weights.size(1));

  // Writes `nbytes` bytes starting at `src` to the stream.
  auto dump = [&out](const void* src, std::streamsize nbytes) {
    out.write(static_cast<const char*>(src), nbytes);
  };

  // Header: the two dimensions.
  dump(&nin, sizeof(nin));
  dump(&nout, sizeof(nout));

  // Payload: the four parameter tensors.
  dump(o.weights.data_ptr(), sizeof(float) * nin * nout);
  dump(o.bias.data_ptr(), sizeof(float) * nout);
  dump(o.mean.data_ptr(), sizeof(float) * nin);
  dump(o.std.data_ptr(), sizeof(float) * nin);

  return out;
}

std::istream& operator>>(std::istream& in, fonf::detail::LogReg& o) {
  char buffer[4];
  int32_t nin = 0, nout = 0;

  in.read(buffer, 4);
  memcpy(&nin, buffer, sizeof(nin));
  in.read(buffer, 4);
  memcpy(&nout, buffer, sizeof(nout));

  char buffer2[nin*nout*sizeof(float)];
  o.weights = o.weights.contiguous().resize_({nin, nout});
  o.bias = o.bias.contiguous().resize_({nout});
  o.mean = o.mean.contiguous().resize_({1, nin});
  o.std = o.std.contiguous().resize_({1, nin});

  in.read(buffer2, nin * nout * sizeof(float));
  memcpy(o.weights.data_ptr(), buffer2, nin * nout * sizeof(float));
  in.read(buffer2, nout * sizeof(float));
  memcpy(o.bias.data_ptr(), buffer2, nout * sizeof(float));
  in.read(buffer2, nin * sizeof(float));
  memcpy(o.mean.data_ptr(), buffer2, nin * sizeof(float));
  in.read(buffer2, nin * sizeof(float));
  memcpy(o.std.data_ptr(), buffer2, nin * sizeof(float));

  return in;
}

// Process-wide model instance. Starts out as nullptr; fonf::initialize() below
// is the only code in this file that assigns it.
LogReg* model = nullptr;
} // namespace detail

void initialize() {
  std::ifstream fin;
  fin.open("fonf.bin", std::ios::binary | std::ios::in);
  if (!fin.good()) {
    return;
  }
  detail::model = new detail::LogReg();
  fin >> (*detail::model);
  fin.close();
}

} // namespace fonf
} // namespace fairrsh

/*
float simple_score(uint32_t ut) {
  return 1;
}

float supply_score(uint32_t ut) {
  using namespace torchcraft::BW::data;
  return SupplyRequired[ut];
}

float gab_score(uint32_t ut) {
  using namespace torchcraft::BW::data;
  return MineralPrice[ut] + 4.0/3 * GasPrice[ut] + 50 * SupplyRequired[ut];
}

// Below is all training code
template<class T>
int winner(const rapidjson::Value& battle, size_t pers, T scorefn) {
  const auto& my_dead = battle["units_died"][pers];
  const auto& op_dead = battle["units_died"][1 - pers];
  auto my_score = 0., op_score = 0.;

  for (auto i = 0U; i < my_dead.Size(); i++) {
    my_score += scorefn(my_dead[i].GetInt());
  }
  for (auto i = 0U; i < op_dead.Size(); i++) {
    op_score += scorefn(op_dead[i].GetInt());
  }

  return my_score < op_score ? 1 : -1;
}

template<class T>
float heuristic_winner(const rapidjson::Value& battle, T scorefn) {
  const auto& my_dead = battle["units_used"][0];
  const auto& op_dead = battle["units_used"][1];
  auto my_score = 0., op_score = 0.;

  for (auto i = 0U; i < my_dead.Size(); i++) {
    my_score += scorefn(my_dead[i].GetInt());
  }
  for (auto i = 0U; i < op_dead.Size(); i++) {
    op_score += scorefn(op_dead[i].GetInt());
  }

  return my_score > op_score ? 1 : -1; // If I have more strength I probably won
}

template<class T>
float calculate_ratio(const rapidjson::Value& battle, T scorefn) {
  // TODO - units weighted? (probes, high tech units are weighted more)
  const auto& my_used = battle["units_used"][0];
  const auto& op_used = battle["units_used"][1];
  const auto& my_dead = battle["units_died"][0];
  const auto& op_dead = battle["units_died"][1];
  auto used_score = 0., dead_score = 0.;

  for (auto i = 0U; i < my_dead.Size(); i++) {
    dead_score += scorefn(my_dead[i].GetInt());
  }
  for (auto i = 0U; i < op_dead.Size(); i++) {
    dead_score += scorefn(op_dead[i].GetInt());
  }
  for (auto i = 0U; i < my_used.Size(); i++) {
    used_score += scorefn(my_used[i].GetInt());
  }
  for (auto i = 0U; i < op_used.Size(); i++) {
    used_score += scorefn(op_used[i].GetInt());
  }

  return used_score == 0 ? 0 : dead_score / used_score;
}

template<class T>
bool filter(const rapidjson::Value& battle, T scorefn) {
  const auto& my_used = battle["units_used"][0];
  const auto& op_used = battle["units_used"][1];
  const auto& my_dead = battle["units_died"][0];
  const auto& op_dead = battle["units_died"][1];
  auto my_score = 0., op_score = 0.;

  for (auto i = 0U; i < my_used.Size(); i++) {
    my_score += scorefn(my_used[i].GetInt());
  }
  for (auto i = 0U; i < op_used.Size(); i++) {
    op_score += scorefn(op_used[i].GetInt());
  }

  return ( //true
      my_dead.Size() + op_dead.Size() > 0 
      && std::min(my_score, op_score) / std::max(my_score, op_score) > (1/1.5)
      );
}

template<class T>
std::pair<std::vector<float>, float> featurize(
    const rapidjson::Value& battle,
    size_t perspective,
    T scorefn) {
  auto me = perspective;
  auto op = 1-me;
  auto inp = std::vector<float>(FEAT_SIZE, 0);

  const auto& units = battle["units_used"];

  for (rapidjson::SizeType i = 0; i < units[me].Size(); i++) {
    inp[units[me][i].GetInt()] += 1;
  }
  for (rapidjson::SizeType i = 0; i < units[op].Size(); i++) {
    inp[units[op][i].GetInt() + UT_MAX] += 1;
  }

  auto out = winner(battle, perspective, scorefn);

  return std::make_pair(inp, out);
}

auto read_files(std::string root, std::string set, size_t max = 0) {
  std::ifstream infile(root + '/' + set);
  std::vector<std::string> files;
  std::string line;
  while (std::getline(infile, line)) {
    files.push_back(line);
    if (max > 0 && files.size() > max) break;
  }

  std::vector<std::vector<float>> arr;
  std::vector<float> labels;
  auto ratio = 0.;
  auto nbattles = 0, n_heurwins = 0;

  for (auto i = 0U; i < files.size(); i++) {
    auto battlefn = files[i];

    std::ifstream ifs(root + '/' + battlefn);
    rapidjson::IStreamWrapper isw(ifs);
    auto document = rapidjson::Document();
    document.ParseStream(isw);

    auto scorefn = SCORE_TYPE;
    for (rapidjson::SizeType i = 0; i < document.Size(); i++) {
      nbattles++;
      const auto& battle = document[i];
      auto good = filter(battle, scorefn);
      if (!good) continue;
      auto featurized0 = featurize(battle, 0, scorefn);
      auto featurized1 = featurize(battle, 1, scorefn);
      auto battle_ratio = calculate_ratio(battle, scorefn);
      auto heuristic = heuristic_winner(battle, scorefn);
      {
        if (heuristic == featurized0.second) {
          n_heurwins++;
        }

        ratio += battle_ratio;
        arr.push_back(std::move(featurized0.first));
        labels.push_back(std::move(featurized0.second));
        arr.push_back(std::move(featurized1.first));
        labels.push_back(std::move(featurized1.second));
      }
    }
  }

  std::cout << "Average ratio of strengths lost in engagements: " << ratio / labels.size() << '\n';
  std::cout << "data size: " << labels.size() << " out of " << nbattles * 2 << " unfiltered total battles\n";
  std::cout << "heuristic correct: " << n_heurwins << " out of " << labels.size() / 2 << " total battles\n";
  auto data = at::CPU(at::kFloat).tensor({(long long) arr.size(), FEAT_SIZE});
  auto lab = at::CPU(at::kFloat).tensor({(long long) labels.size(), 1});
  auto A_data = data.accessor<float,2>();
  auto A_lab = lab.accessor<float,2>();

  for (auto i = 0U; i < arr.size(); i++) {
    for (auto j = 0; j < FEAT_SIZE; j++) {
      A_data[i][j] = arr[i][j];
    }
    A_lab[i][0] = labels[i];
  }

  return std::make_pair(data, lab);
}

int main(int argc, char** argv) {
  auto data_path = std::string(argv[1]);
  std::cout << "Reading files...\n";
  auto train = read_files(data_path, "train.list");
  std::cout << "Reading validation files...\n";
  auto valid = read_files(data_path, "valid.list");

  std::cout << "Number of games where I won: " << valid.second.ge(0).sum() << '\n';

  std::cout << "Training...\n";
  auto model = fonf::LogReg(FEAT_SIZE, N_HID, 1);
  model.train(train.first, train.second, EPOCH, LR, WEIGHT_DECAY, NTHREADS);

  std::cout << "Predicting...\n";
  auto pred = model.predict(valid.first);

  std::cout << "Saving...\n";
  std::ofstream fout;
  fout.open("save.bin", std::ios::binary | std::ios::out);
  fout << model;
  fout.close();

  std::ifstream fin;
  fin.open("save.bin", std::ios::binary | std::ios::in);
  auto model2 = fonf::LogReg();
  fin >> model2;
  fin.close();

  auto pred2 = model2.predict(valid.first);
  std::cout << "Saved model and actual model disagrees on " << pred.ne(pred2).sum() << " values\n";
  std::cout << pred.eq(valid.second).sum() << " correct / " << pred.numel() << " total, and " << pred.eq(valid.second).sum().to<float>() / pred.numel() <<"%\n";

  return 0;
}
*/

#endif // WITH_ATEN
