/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "blackboard.h"

#include "models/bandit.h"
#include "models/units_mixer.h"
#include "state.h"
#include "utils.h"

#include <glog/logging.h>

#include <chrono>
#include <iomanip>
#include <sstream>

// Model type handed to model::UnitsMixer::make() when Blackboard::init()
// cannot load an existing units mixer model.
DEFINE_string(umm_type, "zo", "UMM type");
// Defined elsewhere. A non-empty value makes Blackboard::init() set up the
// global units mixer.
DECLARE_string(umm_path);
// Defined elsewhere; not referenced in the code visible here.
DECLARE_string(opening);
// Defined elsewhere. NOTE(review): Blackboard::init() does not consult this
// flag when deciding whether to set up the opening bandit -- confirm intended.
DECLARE_bool(bandit);

namespace fairrsh {

// RTTR reflection registration: exposes UPCData and Blackboard (by the names
// given below) to rttr::type-based runtime inspection.
RTTR_REGISTRATION {
  // UPCData: read-only access to the posted tuple and to the posting module.
  // NOTE(review): the `upc` member is exposed under the name "upcs".
  rttr::registration::class_<UPCData>("UPCData")(
      metadata("type", rttr::type::get<UPCData>()))
      .property_readonly("upcs", &UPCData::upc)
      .property_readonly("origin", &UPCData::origin);

  rttr::registration::class_<Blackboard>("Blackboard")(
      metadata("type", rttr::type::get<Blackboard>()))
      // No constructor, please use via State
      .method("isTracked", &Blackboard::isTracked)
      .method("track", &Blackboard::track)
      // NOTE(review): registered as "untracked" although it binds
      // Blackboard::untrack. This looks like a typo, but renaming it would
      // break any reflection-based caller that uses the current name.
      .method("untracked", &Blackboard::untrack)
      .method("postUPC", &Blackboard::postUPC)
      .method("upcs", &Blackboard::upcs)
      // upcsFrom is overloaded (this file defines a Module* variant), so the
      // std::shared_ptr<Module> overload is selected explicitly.
      .method(
          "upcsFrom",
          rttr::select_overload<Blackboard::UPCMap(std::shared_ptr<Module>)
                                    const>(&Blackboard::upcsFrom))
      .method("upcsWithSharpCommand", &Blackboard::upcsWithSharpCommand)
      .method("upcsWithCommand", &Blackboard::upcsWithCommand);
}

namespace {

// Builds an id -> UPCTuple map from every entry of `upcs` accepted by `pred`.
// `pred` is invoked with each map element (std::pair<const int, UPCData>) in
// ascending id order; its result is used as a boolean.
template <typename UnaryPredicate>
Blackboard::UPCMap selectUpcs(
    std::map<int, UPCData> const& upcs,
    UnaryPredicate pred) {
  Blackboard::UPCMap selected;
  for (auto const& entry : upcs) {
    if (!pred(entry)) {
      continue;
    }
    selected.emplace(entry.first, entry.second.upc);
  }
  return selected;
}

} // namespace

// Well-known blackboard keys, i.e. the string keys passed to the blackboard's
// key/value accessors (init() for example reads kEnemyNameKey as a
// std::string and kEnemyRaceKey as an int via get<T>()).
char const* Blackboard::kMyLocationKey = "my_location";
char const* Blackboard::kEnemyLocationKey = "enemy_location";
char const* Blackboard::kEnemyRaceKey = "enemy_race";
char const* Blackboard::kEnemyNameKey = "enemy_name";
// NOTE(review): this value does not follow the snake_case pattern of the other
// keys. It is kept as-is because other code or stored data may rely on it.
char const* Blackboard::kBuilderScoutingPolicyKey = "builder_send_scoutUPC";

// Binds the blackboard to `state`, kept in state_ (not owned here as far as
// this file shows; presumably State owns the Blackboard -- confirm).
// upcCount_ starts at 0, so postUPC() hands out ids starting from 1.
// NOTE(review): commands_ is constructed with 16; what that number means
// depends on its type in the header. collectTimers_ is not set here -- confirm
// it has an in-class initializer, since update() reads it.
Blackboard::Blackboard(State* state)
    : state_(state), commands_(16), upcCount_(0) {
  // Note: state may not be fully constructed yet -- don't do anything fancy
  // with it.
}

// Defined out of line (presumably so that members whose types are incomplete
// in the header are destroyed here -- confirm); no custom cleanup is needed.
Blackboard::~Blackboard() = default;

void Blackboard::init() {
  // Initialize global units mixer instance
  if (FLAGS_umm_path != "") {
    try {
      unitsMixer = model::UnitsMixer::load();
    } catch (std::exception& ex) {
      LOG(INFO)
          << "No existing units mixer model found, initializing a new one";
      unitsMixer = model::UnitsMixer::make(FLAGS_umm_type);
    }
  }
  // Initialize global opening bandit instance
  //if (FLAGS_bandit) {
  if (true) { // not taking any risk
    std::string eName = get<std::string>(Blackboard::kEnemyNameKey);
    try {
      openingBandit = model::OpeningBandit::load(utils::stringToLower(eName));
    } catch (std::exception& ex) {
      tc::BW::Race eRace = tc::BW::Race::Unknown;
      auto eRaceBB = tc::BW::Race::_from_integral_nothrow(
          get<int>(Blackboard::kEnemyRaceKey));
      if (eRaceBB) {
        eRace = *eRaceBB;
      }
      LOG(WARNING) << "No existing openings bandit found for name: " << eName
                   << ", initializing a new one from race "
                   << eRace._to_string();
      openingBandit = model::OpeningBandit::make(
          eRace, state_->tcstate()->map_name, utils::stringToLower(eName));
    }
  }
}

// Returns whether `uid` is present in the tracked-unit container tracked_.
bool Blackboard::isTracked(UnitId uid) const {
  return tracked_.count(uid) != 0;
}
// Adds `uid` to tracked_.
void Blackboard::track(UnitId uid) {
  tracked_.emplace(uid);
}
// Removes `uid` from tracked_ (erase by key).
void Blackboard::untrack(UnitId uid) {
  tracked_.erase(uid);
}

void Blackboard::setCollectTimers(bool collect) {
  collectTimers_ = collect;
}

/**
 * Posts a UPC to the blackboard after running it through every registered
 * UPC filter, in the order the filters were added.
 *
 * Each filter may replace the UPC. If any filter returns nullptr, the UPC is
 * dropped and kFilteredUpcId is returned instead of a real id. An id is
 * consumed from upcCount_ in either case.
 *
 * @param upc      tuple to post
 * @param sourceId UPC id recorded as this UPC's source (in UPCData and posts_)
 * @param origin   posting module; must be non-null (its name is logged)
 * @param data     extra data moved into the posts_ record
 * @return the new UPC id, or kFilteredUpcId if a filter rejected the UPC
 */
UpcId Blackboard::postUPC(
    std::shared_ptr<UPCTuple> upc,
    UpcId sourceId,
    Module* origin,
    std::shared_ptr<UpcPostData> data) {
  auto id = ++upcCount_;
  for (auto const& filter : upcFilters_) {
    std::shared_ptr<UPCTuple> filtered = filter->filter(state_, upc, origin);
    if (filtered == nullptr) {
      // Callers get kFilteredUpcId rather than an error, so they need not
      // handle posting failures. Module code should be robust to UPCs that are
      // never consumed anyway. Log the UPC as it was handed to the rejecting
      // filter; the filter's result is null and has nothing to show.
      VLOG(1) << utils::upcString(upc, id) << " from " << origin->name()
              << " has been filtered out. Not posting.";
      return kFilteredUpcId;
    }
    upc = std::move(filtered);
  }

  upcs_[id] = UPCData(upc, sourceId, origin);
  posts_.emplace_back(
      state_->currentFrame(), sourceId, id, origin, std::move(data));
  VLOG(1) << "<- " << utils::upcString(upc, id) << " from " << origin->name()
          << " with source " << utils::upcString(sourceId);
  return id;
}

/**
 * Marks the given UPCs as consumed by `consumer`. Each id is removed from the
 * active UPC set, and `consumer` is recorded in consumedUPCs_ for it.
 *
 * Ids that are not on the blackboard are still recorded in consumedUPCs_,
 * matching previous behaviour. Unlike the old operator[]-based logging, no
 * placeholder entry is inserted into upcs_.
 */
void Blackboard::consumeUPCs(std::vector<int> const& ids, Module* consumer) {
  for (auto id : ids) {
    auto it = upcs_.find(id);
    if (it != upcs_.end()) {
      VLOG(1) << "-> " << utils::upcString(it->second.upc, id) << " to "
              << consumer->name();
      upcs_.erase(it);
    } else {
      VLOG(1) << "-> upc " << id << " (not on blackboard) to "
              << consumer->name();
    }
    consumedUPCs_[id] = consumer;
  }
}

// Drops the given UPC ids from the active set without recording a consumer.
// Only ids that were actually present produce a log line; others are skipped.
void Blackboard::removeUPCs(std::vector<int> const& ids) {
  for (auto idIt = ids.begin(); idIt != ids.end(); ++idIt) {
    auto const removedCount = upcs_.erase(*idIt);
    if (removedCount == 0) {
      continue;
    }
    VLOG(1) << "-> upc " << *idIt << " removed ";
  }
}

// Returns every UPC currently on the blackboard, keyed by UPC id.
Blackboard::UPCMap Blackboard::upcs() const {
  // `auto const&` binds directly to the map's value_type
  // (std::pair<const int, UPCData>). A std::pair<int, UPCData> parameter would
  // build a converting copy of every element, including a shared_ptr copy.
  return selectUpcs(upcs_, [](auto const&) { return true; });
}

// Returns the UPCs on the blackboard that were posted by `origin`, keyed by
// UPC id.
Blackboard::UPCMap Blackboard::upcsFrom(Module* origin) const {
  // `auto const&` binds to the map element without the converting copy that a
  // std::pair<int, UPCData> parameter would force.
  return selectUpcs(upcs_, [origin](auto const& d) {
    return d.second.origin == origin;
  });
}

// Returns the UPCs whose command entry for `cmd` is exactly 1.0f (a "sharp"
// command), keyed by UPC id.
Blackboard::UPCMap Blackboard::upcsWithSharpCommand(Command cmd) const {
  // `auto const&` binds to the map element without the converting copy that a
  // std::pair<int, UPCData> parameter would force.
  return selectUpcs(upcs_, [cmd](auto const& d) {
    return d.second.upc->command[cmd] == 1.0f;
  });
}

// Returns the UPCs whose command entry for `cmd` is at least `minProb`, keyed
// by UPC id.
Blackboard::UPCMap Blackboard::upcsWithCommand(Command cmd, float minProb)
    const {
  // `auto const&` binds to the map element without the converting copy that a
  // std::pair<int, UPCData> parameter would force.
  return selectUpcs(upcs_, [cmd, minProb](auto const& d) {
    return d.second.upc->command[cmd] >= minProb;
  });
}

// Registers `filter` to run on every subsequent postUPC() call. Filters run in
// the order they were added.
void Blackboard::addUPCFilter(std::shared_ptr<UPCFilter> filter) {
  // By-value sink parameter: move it into the list instead of copying, which
  // saves an atomic refcount increment/decrement pair.
  upcFilters_.push_back(std::move(filter));
}

// Unregisters `filter`: every entry of upcFilters_ equal to it (same pointer)
// is removed, so it no longer runs in postUPC().
void Blackboard::removeUPCFilter(std::shared_ptr<UPCFilter> filter) {
  upcFilters_.remove_if(
      [&filter](auto const& registered) { return registered == filter; });
}

void Blackboard::postTask(
    std::shared_ptr<Task> task,
    Module* owner,
    bool autoRemove) {
  if (tasksById_.find(task->upcId()) != tasksById_.end()) {
    throw std::runtime_error(
        "Existing task found for " + utils::upcString(task->upcId()));
  }
  auto it = tasks_.emplace(tasks_.end(), task, owner, autoRemove);
  tasksById_.emplace(task->upcId(), it);
  tasksByModule_.emplace(owner, it);
  for (Unit* u : const_cast<const Task*>(task.get())->units()) {
    tasksByUnit_[u] = it;
  }
}

// Returns the task registered for UPC id `id`, or nullptr if there is none.
std::shared_ptr<Task> Blackboard::taskForId(int id) const {
  auto const found = tasksById_.find(id);
  return found == tasksById_.end() ? nullptr : found->second->task;
}

// Returns all tasks currently indexed under `module` in tasksByModule_, in
// that container's order for the key (empty if there are none).
std::vector<std::shared_ptr<Task>> Blackboard::tasksOfModule(
    Module* module) const {
  auto range = tasksByModule_.equal_range(module);
  std::vector<std::shared_ptr<Task>> owned;
  while (range.first != range.second) {
    owned.push_back(range.first->second->task);
    ++range.first;
  }
  return owned;
}

// Returns the task that `unit` is currently mapped to, or nullptr if none.
std::shared_ptr<Task> Blackboard::taskWithUnit(Unit* unit) const {
  auto const found = tasksByUnit_.find(unit);
  return found == tasksByUnit_.end() ? nullptr : found->second->task;
}

// Returns a copy of the TaskData record for the task holding `unit`, or a
// default-constructed TaskData if `unit` is not mapped to any task.
TaskData Blackboard::taskDataWithUnit(Unit* unit) const {
  auto const found = tasksByUnit_.find(unit);
  if (found != tasksByUnit_.end()) {
    return *found->second;
  }
  return TaskData();
}

// Returns the task holding `unit` only if that task is owned by `module`;
// otherwise (unit unmapped, or owned by another module) returns nullptr.
std::shared_ptr<Task> Blackboard::taskWithUnitOfModule(
    Unit* unit,
    Module* module) const {
  auto const found = tasksByUnit_.find(unit);
  bool const ownedByModule =
      found != tasksByUnit_.end() && found->second->owner == module;
  return ownedByModule ? found->second->task : nullptr;
}

// Queues the task with UPC id `upcId` for removal. The task stays registered
// until the start of the next update(), which performs the actual removal.
void Blackboard::markTaskForRemoval(int upcId) {
  tasksToBeRemoved_.emplace_back(upcId);
}

/**
 * Bumps the per-unit access counter for a unit command.
 *
 * Only tc::BW::Command::CommandUnit commands with at least one argument are
 * counted; args[0] is taken as the unit id. Whether that unit exists is not
 * checked here.
 *
 * Counter semantics: the first command seen for a unit stores 0, and each
 * further command adds 1, i.e. the value is (number of commands - 1).
 * NOTE(review): this is what the previous expression evaluated to under C++17
 * sequencing rules. Confirm callers expect a "repeat count" rather than a
 * "total count".
 */
void Blackboard::updateUnitAccessCounts(tc::Client::Command const& command) {
  if (command.code == tc::BW::Command::CommandUnit && !command.args.empty()) {
    const int unitId = command.args[0];

    // Insert 0 for a new unit, otherwise increment -- with a single lookup.
    // Mixing operator[] and find() in one assignment expression gave an
    // evaluation-order-dependent result before C++17: the left-hand
    // operator[] could insert the key before find() ran, yielding 1 instead
    // of 0 on first access.
    auto const inserted = unitAccessCounts_.emplace(unitId, 0);
    if (!inserted.second) {
      ++inserted.first->second;
    }
  }
}

/**
 * Re-indexes `task`'s units in tasksByUnit_ after its unit set changed. It
 * drops every mapping that points at the task, then maps each of the task's
 * current units to it, overwriting any mapping to another task.
 *
 * If `task` is not registered on this blackboard, nothing is changed.
 * Previously its units were mapped to tasks_.end(), which taskWithUnit() and
 * friends would later dereference. postTask() indexes a task's units itself
 * once the task is posted.
 */
void Blackboard::updateTasksByUnit(Task* task) {
  auto it = std::find_if(tasks_.begin(), tasks_.end(), [&](auto& v) {
    return v.task.get() == task;
  });
  if (it == tasks_.end()) {
    VLOG(1) << "Blackboard: updateTasksByUnit() for a task that is not on "
               "the blackboard; ignoring";
    return;
  }

  // Drop stale mappings to this task, then map its current units to it.
  for (auto i = tasksByUnit_.begin(); i != tasksByUnit_.end();) {
    if (i->second == it)
      i = tasksByUnit_.erase(i);
    else
      ++i;
  }
  for (Unit* u : const_cast<const Task*>(task)->units()) {
    tasksByUnit_[u] = it;
  }
}

/**
 * Transfers `task` from `previousOwner` to `newOwner`. It updates the task's
 * TaskData::owner and moves its tasksByModule_ entry.
 *
 * If `task` is not registered on this blackboard, a warning is logged and
 * nothing changes. Previously this wrote through tasks_.end() (undefined
 * behaviour).
 *
 * NOTE(review): if `previousOwner` is not the task's actual owner, no old
 * tasksByModule_ entry is found, so the task ends up indexed under both its
 * real previous owner and `newOwner`.
 */
void Blackboard::changeTaskOwnership(
    Task* task,
    Module* previousOwner,
    Module* newOwner) {
  auto it = std::find_if(tasks_.begin(), tasks_.end(), [&](auto& v) {
    return v.task.get() == task;
  });
  if (it == tasks_.end()) {
    LOG(WARNING) << "Blackboard: changeTaskOwnership() for a task that is "
                    "not on the blackboard; ignoring";
    return;
  }

  it->owner = newOwner;
  auto range = tasksByModule_.equal_range(previousOwner);
  for (auto itM = range.first; itM != range.second; ++itM) {
    if (itM->second->task.get() == task) {
      tasksByModule_.erase(itM);
      break;
    }
  }
  tasksByModule_.emplace(newOwner, it);
}

/**
 * Per-frame blackboard maintenance, in this order:
 *  1. commands_.push() (presumably starts a new frame of command history --
 *     confirm against the commands_ type).
 *  2. Remove the tasks queued in tasksToBeRemoved_, together with their
 *     tasksByModule_/tasksByUnit_/tasksById_ entries.
 *  3. Drop destroyed units from tasksByUnit_.
 *  4. Update every non-cancelled task, highest UPC id first. Optionally time
 *     each update (see setCollectTimers()), re-index the task's units, and
 *     queue finished auto-remove tasks for removal on the next update().
 *  5. Log a warning for tasks (not queued for removal) that share units.
 */
void Blackboard::update() {
  commands_.push();

  // Remove tasks pending for removal
  for (auto id : tasksToBeRemoved_) {
    auto it = tasksById_.find(id);
    if (it == tasksById_.end()) {
      LOG(WARNING) << "Task " << id << " to be removed but does not exist";
      continue;
    }

    // Update mappings; this is quite painful for the multimap
    auto owner = it->second->owner;
    VLOG(2) << "Removing task with id " << id << " from " << owner->name();
    auto range = tasksByModule_.equal_range(owner);
    for (auto itM = range.first; itM != range.second; ++itM) {
      if (itM->second->task->upcId() == id) {
        tasksByModule_.erase(itM);
        break;
      }
    }
    // Remove all references in tasksByUnit_
    for (auto i = tasksByUnit_.begin(); i != tasksByUnit_.end();) {
      if (i->second == it->second)
        i = tasksByUnit_.erase(i);
      else
        ++i;
    }
    tasks_.erase(it->second);
    tasksById_.erase(it);
  }
  tasksToBeRemoved_.clear();

  // Remove any dead units from the tasksByUnit_ mapping. Do this before
  // updating the tasks so that they can handle deaths and re-assignment the
  // same way.
  for (auto* u : state_->unitsInfo().getDestroyUnits()) {
    auto uit = tasksByUnit_.find(u);
    if (uit != tasksByUnit_.end()) {
      // Erase by iterator: the entry was just found, no need to look it up
      // again by key.
      tasksByUnit_.erase(uit);
    }
  }

  // Clear before adding stats for this round of updates
  taskTimeStats_.clear();

  // Update tasks in reverse order of their UPC Id, effectively running more
  // recently created tasks first.
  for (auto it = tasksById_.rbegin(); it != tasksById_.rend(); ++it) {
    auto task = it->second->task;

    if (task->status() != TaskStatus::Cancelled) {
      std::chrono::time_point<hires_clock> start;
      if (collectTimers_) {
        start = hires_clock::now();
      }
      task->update(state_);
      if (collectTimers_) {
        auto duration = hires_clock::now() - start;
        auto ms =
            std::chrono::duration_cast<std::chrono::milliseconds>(duration);
        // Only record tasks whose update took at least 1 ms, i.e. above the
        // 0 ms threshold after truncation to whole milliseconds.
        if (ms.count() > 0) {
          taskTimeStats_.push_back(
              std::make_tuple(it->first, it->second->owner->name(), ms));
        }
      }
    }

    // Update tasksByUnit_. This is not efficient, and would be best done
    // in Task::setUnits
    for (auto i = tasksByUnit_.begin(); i != tasksByUnit_.end();) {
      if (i->second == it->second)
        i = tasksByUnit_.erase(i);
      else
        ++i;
    }
    for (Unit* u : const_cast<const Task*>(task.get())->units()) {
      tasksByUnit_[u] = it->second;
    }

    // If auto-removal is turned on, schedule finished tasks for removal
    if (it->second->autoRemove && task->finished()) {
      VLOG(3) << "Blackboard: removing task " << task->upcId()
              << " with status " << static_cast<int>(task->status());
      tasksToBeRemoved_.push_back(it->first);
    }
  }

  // Check if any of the tasks that are not marked for removal have units
  // in common
  std::set<UpcId> tasksToBeRemoved_t(
      tasksToBeRemoved_.begin(), tasksToBeRemoved_.end());
  std::map<Unit*, UpcId> duplicateUnits;

  // Check if two tasks have overlapping units
  for (auto it = tasksById_.rbegin(); it != tasksById_.rend(); ++it) {
    // Create a new shared pointer to get a const pointer to access
    // the public const function units() instead of the protected one
    auto task = std::const_pointer_cast<const Task>(it->second->task);

    // Only check tasks that are not scheduled to be removed
    if (tasksToBeRemoved_t.find(task->upcId()) == tasksToBeRemoved_t.end()) {
      for (auto const unit : task->units()) {
        auto unitIt = duplicateUnits.find(unit);
        if (unitIt != duplicateUnits.end()) {
          LOG(WARNING) << "Blackboard: Error: Task " << task->upcId()
                       << " has units in common with task " << unitIt->second;
        } else {
          duplicateUnits[unit] = task->upcId();
        }
      }
    }
  }
}

/**
 * Consistency check over the UPCs currently on the blackboard. Logs a warning
 * whenever a UPC assigns probability exactly 1.0 to a unit that an earlier
 * UPC (lower id) also assigns exactly 1.0.
 */
void Blackboard::checkPostStep() {
  std::map<Unit*, UpcId> duplicateUnits;

  // Iterate by const reference: copying each map pair would copy UPCData
  // (including its shared_ptr) on every iteration.
  for (auto const& upcTuple : upcs_) {
    auto const& upc = upcTuple.second.upc;

    for (auto const& unitProb : upc->unit) {
      if (unitProb.second == 1.0) {
        auto unitIt = duplicateUnits.find(unitProb.first);
        if (unitIt != duplicateUnits.end()) {
          LOG(WARNING) << "Blackboard: Error: Upc " << upcTuple.first
                       << " has units in common with Upc " << unitIt->second;
        } else {
          duplicateUnits[unitProb.first] = upcTuple.first;
        }
      }
    }
  }
}

} // namespace fairrsh
