/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#pragma once

#include "circularbuffer.h"
#include "fairrsh.h"
#include "module.h"
#include "task.h"
#include "unitsinfo.h"
#include "upc.h"
#include "upcfilter.h"

#include <algorithm>
#include <atomic>
#include <chrono>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include <gflags/gflags.h>
#include <glog/logging.h>
#include <mapbox/variant.hpp>
#include <torchcraft/client.h>

// Declares the gflags string flag `umm_type`; its definition (DEFINE_string)
// lives in another translation unit. NOTE(review): presumably selects the
// units-mixer model type (cf. model::UnitsMixer below) — confirm at the
// DEFINE_string site.
DECLARE_string(umm_type);

namespace fairrsh {

namespace model {
// Forward declarations of model types. Blackboard only holds UnitsMixer and
// OpeningBandit through std::shared_ptr, so their full definitions are not
// needed in this header.
class OpeningBandit;
class PGUnitsMixer;
class UnitsMixer;
class ZOUnitsMixer;
} // namespace model

/// UPCTuple and associated origin.
///
/// Bundles a posted UPCTuple with the identifier of its source UPC and the
/// module that posted it.
struct UPCData {
  RTTR_ENABLE()

 public:
  /// The UPC itself (nullptr when default-constructed).
  std::shared_ptr<UPCTuple> upc = nullptr;
  /// Identifier of the source UPC (kInvalidUpcId when default-constructed).
  UpcId source = kInvalidUpcId;
  /// Module that posted the UPC (non-owning).
  Module* origin = nullptr;

  UPCData() = default;
  // `upc` is a sink parameter: move it into place instead of copying, which
  // would cost an extra atomic refcount increment/decrement (matches TaskData).
  UPCData(std::shared_ptr<UPCTuple> upc, UpcId source, Module* origin)
      : upc(std::move(upc)), source(source), origin(origin) {}
};

/// Task and associated owner
struct TaskData {
  RTTR_ENABLE()

 public:
  std::shared_ptr<Task> task = nullptr;
  Module* owner = nullptr;
  bool autoRemove = true;

  TaskData() {}
  TaskData(std::shared_ptr<Task> task, Module* owner, bool autoRemove)
      : task(std::move(task)), owner(owner), autoRemove(autoRemove) {}
};

/**
 * Base class for data attached to the posting of an UPC.
 *
 * When backpropagating through the bot, this data will be provided for
 * computing gradients for the respective posted UPC. A common use case
 * would be to store the output of a featurizer here.
 *
 * Modules derive from this type and attach instances through
 * std::shared_ptr<UpcPostData> (see UpcPost::data). The virtual destructor
 * makes the type polymorphic, so the derived data can be recovered with
 * std::dynamic_pointer_cast and is always destroyed through the most-derived
 * destructor, even when deleted via a base pointer.
 */
struct UpcPostData {
  virtual ~UpcPostData() = default;
};

/**
 * Stores information about UPCs that have been posted to the board.
 */
struct UpcPost {
  RTTR_ENABLE()

 public:
  /// Game
  FrameNum frame = -1;
  /// Identifier of source UPC
  UpcId sourceId = kInvalidUpcId;
  /// Identifier of posted UPC
  UpcId upcId = kInvalidUpcId;
  /// The module performing the transaction
  Module* module = nullptr;
  /// Data attached to this transaction
  std::shared_ptr<UpcPostData> data = nullptr;

  UpcPost() {}
  UpcPost(
      FrameNum frame,
      UpcId sourceId,
      UpcId upcId,
      Module* module,
      std::shared_ptr<UpcPostData> data = nullptr)
      : frame(frame),
        sourceId(sourceId),
        upcId(upcId),
        module(module),
        data(std::move(data)) {}
};

/**
 * An access-aware blackboard.
 *
 * The blackboard provides a means for modules to exchange UPCTuples while
 * keeping track of producers and consumers.
 *
 * Furthermore, there is some rudimentary functionality for holding global state
 * via post(), hasKey() and get().
 */
class Blackboard {
  RTTR_ENABLE()

 public:
  typedef mapbox::util::variant<bool, int, std::string, Position> Data;
  typedef std::map<UpcId, std::shared_ptr<UPCTuple>> UPCMap;

  using TaskTimeStats =
      std::tuple<UpcId, std::string, std::chrono::milliseconds>;

  // A few commonly used keys for post() and get()
  static char const* kMyLocationKey;
  static char const* kEnemyLocationKey;
  static char const* kEnemyRaceKey;
  static char const* kEnemyNameKey;
  static char const* kBuilderScoutingPolicyKey;

  Blackboard(State* state);

  virtual ~Blackboard();

  void init();

  // Updates internal mappings after the torchcraft state has been updated.
  void update();

  void post(std::string const& key, Data const& data) {
    map_[key] = data;
  }
  bool hasKey(std::string const& key) {
    return map_.find(key) != map_.end();
  }
  Data const& get(std::string const& key) const {
    return map_.at(key);
  }
  template <typename T>
  T const& get(std::string const& key) const {
    return map_.at(key).get<T>();
  }
  template <typename T>
  T const& get(std::string const& key, T const& defaultValue) const {
    auto it = map_.find(key);
    if (it == map_.end()) {
      return defaultValue;
    }
    return it->second.get<T>();
  }
  void remove(std::string const& key) {
    map_.erase(key);
  }

  bool isTracked(UnitId uid) const;
  void track(UnitId uid);
  void untrack(UnitId uid);

  // UPC post/query/consume
  UpcId postUPC(
      std::shared_ptr<UPCTuple> upc,
      UpcId sourceId,
      Module* origin,
      std::shared_ptr<UpcPostData> data = nullptr);
  void consumeUPCs(std::vector<UpcId> const& ids, Module* consumer);
  void consumeUPC(UpcId id, Module* consumer) {
    consumeUPCs({id}, consumer);
  }
  void removeUPCs(std::vector<UpcId> const& ids);
  UPCMap upcs() const;
  UPCMap upcsFrom(Module* origin) const;
  UPCMap upcsFrom(std::shared_ptr<Module> origin) const {
    return upcsFrom(origin.get());
  }
  UPCMap upcsWithSharpCommand(Command cmd) const;
  UPCMap upcsWithCommand(Command cmd, float minProb) const;

  // UPC filters
  void addUPCFilter(std::shared_ptr<UPCFilter> filter);
  void removeUPCFilter(std::shared_ptr<UPCFilter> filter);

  // Task post/query
  void
  postTask(std::shared_ptr<Task> task, Module* owner, bool autoRemove = false);
  std::shared_ptr<Task> taskForId(UpcId id) const;
  std::vector<std::shared_ptr<Task>> tasksOfModule(Module* module) const;
  std::shared_ptr<Task> taskWithUnit(Unit* unit) const;
  TaskData taskDataWithUnit(Unit* unit) const;
  std::shared_ptr<Task> taskWithUnitOfModule(Unit* unit, Module* module) const;
  void markTaskForRemoval(UpcId upcId);
  void markTaskForRemoval(std::shared_ptr<Task> task) {
    markTaskForRemoval(task->upcId());
  }
  void updateUnitAccessCounts(tc::Client::Command const& command);

  // Game commands
  void postCommands(std::vector<tc::Client::Command> const& commands) {
    // Add stats for each command
    for (auto command : commands) {
      updateUnitAccessCounts(command);
    }

    auto& current = commands_.at(0);
    current.insert(current.end(), commands.begin(), commands.end());
  }
  void postCommand(tc::Client::Command const& command) {

    updateUnitAccessCounts(command);
    commands_.at(0).push_back(command);
  }
  std::vector<tc::Client::Command> const& commands(int stepsBack = 0) const {
    return commands_.at(-std::min(stepsBack, int(commands_.size()) - 1));
  }
  size_t pastCommandsAvailable() const {
    return commands_.size();
  }

  // Updates the taskByUnit mapping, should be called after setUnits on a task
  void updateTasksByUnit(Task* task);

  // Updates the tasks_ and tasksByModule_ mapping
  void changeTaskOwnership(Task* task, Module* previousOwner, Module* newOwner);

  // UPC consistency checks
  // Calling this only makes sense in the player's postStep function, once
  // all the UPCs have been converted into commands to be posted to the game.
  void checkPostStep();

  std::vector<TaskTimeStats> getTaskTimeStats() const {
    return taskTimeStats_;
  }

  std::shared_ptr<model::UnitsMixer> unitsMixer;
  // ^ TODO remove, see #321 [this sucks, but better is more work]
  // same below
  std::vector<std::pair<size_t, float>> unitsMixerCurrentInput;
  std::vector<float> unitsMixerCurrentOutput;

  std::shared_ptr<model::OpeningBandit> openingBandit;
  void setCollectTimers(bool collect);

 private:
  State* state_;
  std::unordered_map<std::string, Data> map_;
  CircularBuffer<std::vector<tc::Client::Command>> commands_;
  std::map<UpcId, UPCData> upcs_;
  std::atomic_long upcCount_;
  std::unordered_map<UpcId, Module*> consumedUPCs_;
  std::list<UpcPost> posts_;
  std::list<std::shared_ptr<UPCFilter>> upcFilters_;
  std::set<UnitId> tracked_;

  std::list<TaskData> tasks_;
  std::map<UpcId, std::list<TaskData>::iterator> tasksById_;
  std::unordered_multimap<Module*, std::list<TaskData>::iterator>
      tasksByModule_;
  std::unordered_map<Unit*, std::list<TaskData>::iterator> tasksByUnit_;
  std::vector<UpcId> tasksToBeRemoved_;
  std::map<UnitId, size_t> unitAccessCounts_;
  std::vector<TaskTimeStats> taskTimeStats_;

  bool collectTimers_ = false;
};

} // namespace fairrsh
