/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#pragma once

#ifdef WITH_ATEN

#include "fairrsh.h"

#include <ATen/ATen.h>
#include <glog/logging.h>

// Layout of the fonf model input vector (see fonf() below):
//   [0, UT_MAX)              per-unit-type counts for our side
//   [UT_MAX, 2 * UT_MAX)     per-unit-type counts for the opponent
//   [2 * UT_MAX, FEAT_SIZE)  OPP_EMB extra slots, not written by fonf()
//                            (NOTE(review): presumably an opponent embedding)
constexpr int UT_MAX = 234;
constexpr int OPP_EMB = 100;
constexpr int FEAT_SIZE = 2 * UT_MAX + OPP_EMB;

namespace fairrsh {
namespace fonf {
namespace detail {

/**
 * Logistic-regression model used by fonf().
 *
 * Only the interface lives in this header; the constructors, predict() and
 * train() are defined out of line.
 */
class LogReg {
  public:
    LogReg();
    LogReg(int32_t nin, int32_t nout);

    /// Runs the model on `input` and returns its output tensor.
    /// NOTE(review): fonf() reads the result as out[0][0] and compares it
    /// against 0, which suggests raw logits rather than probabilities —
    /// confirm against the definition.
    at::Tensor predict(at::Tensor input) const;

    /// Fits the model on `data` against targets `targ` for `epochs` passes
    /// with learning rate `lr`.
    /// NOTE(review): `weight_decay` (presumably L2 regularization) and
    /// `threads` are interpreted from their names — confirm in the definition.
    void train(at::Tensor data, at::Tensor targ, uint32_t epochs, float lr, float weight_decay=0, int threads=1);

    at::Tensor weights, bias;
    // NOTE(review): presumably per-feature input normalization statistics.
    at::Tensor mean, std;

  private:
    // Elementwise logistic function 1 / (1 + exp(-x)).
    // It reads no member state, so it is static. As a non-const member it
    // could not be called from the const predict().
    static at::Tensor sigmoid(at::Tensor x) {
      return 1 / (1 + (-x).exp());
    }
};

std::ostream& operator<<(std::ostream& out, const fonf::detail::LogReg& o);
std::istream& operator>>(std::istream& in, fonf::detail::LogReg& o);

extern LogReg* model;
} // namespace detail

// Declared here and defined out of line (the definition is not visible in
// this header).
// NOTE(review): presumably populates detail::model, which fonf() needs in
// order to return a prediction — confirm against the definition.
void initialize();

template <typename U1, typename U2>
bool fonf(U1&& my_units, U2&& their_units) {
  if (!detail::model) {
    LOG(ERROR) << "No model found for fonf, returning true";
    return true;
  }
  auto inp = at::CPU(at::kFloat).zeros({FEAT_SIZE});
  for (auto unit : my_units) {
    inp[unit->type->unit] += 1.0f;
  }
  for (auto unit : their_units) {
    inp[unit->type->unit+ UT_MAX] += 1.0f;
  }

  auto out = detail::model->predict(inp);
  return at::Scalar(out[0][0]).to<float>() > 0;
}

} // namespace fonf
} // namespace fairrsh

#else // WITH_ATEN

namespace fairrsh {
namespace fonf {

// Fallback for builds without ATen. No fonf model is declared in that
// configuration, so there is nothing to initialize.
inline void initialize() {
  // Intentionally empty.
}

} // namespace fonf
} // namespace fairrsh

#endif // WITH_ATEN
