/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include <glog/logging.h>

#include "state.h"
#include "utils.h"

#include "modules/gatherer.h"

#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace fairrsh {

// Registers GathererModule with RTTR reflection: the class name, a "type"
// metadata entry holding its rttr::type, and its default constructor
// (presumably so the module can be instantiated by name — confirm with the
// module factory).
RTTR_REGISTRATION {
  rttr::registration::class_<GathererModule>("GathererModule")(
      metadata("type", rttr::type::get<GathererModule>()))
      .constructor();
}

namespace {

/**
 * Formats a collection of units as a comma-separated list (for logging).
 *
 * @param units any iterable whose elements can be passed to
 *              utils::unitString()
 * @return the unit strings joined by ", "; empty for an empty collection
 */
template <typename T>
std::string unitsToString(T&& units) {
  std::ostringstream joined;
  const char* separator = "";
  for (auto const& unit : units) {
    joined << separator << utils::unitString(unit);
    separator = ", ";
  }
  return joined.str();
}

/// A UPC produced by the gatherer that still has to be posted, together with
/// the ID of the UPC (i.e. the gatherer task) it is posted on behalf of.
struct UpcWithSource {
  std::shared_ptr<UPCTuple> upc;
  int sourceId;

  UpcWithSource(std::shared_ptr<UPCTuple> tuple, int source)
      : upc(std::move(tuple)), sourceId(source) {}
};

/**
 * Shared data for gatherer tasks.
 *
 * The gatherer module stores all necessary data in a single GathererData
 * object. For each Gather UPCTuple it receives, a GathererTask is spawned with
 * the UPC's worker units that have a non-zero probability.
 * However, there's only one GathererData object for all tasks, and it is used
 * to assign workers to resources and resources to depots. Alternatively, we
 * could go with a single GathererTask object, but this would break the
 * one UPCTuple -> one Task relationship.
 * GathererTask objects never succeed; they only fail once all their units
 * have died or have been reallocated.
 */
struct GathererData {
  struct Gatherer;

  /// A resource unit (mineral field or gas) that workers can be assigned to.
  struct Resource {
    int lastUpdate = 0; // frame of the last updateResource() call
    Unit* unit = nullptr; // the resource unit itself
    std::vector<Gatherer*> gatherers; // gatherers targeting this resource
    Unit* depot = nullptr; // closest usable depot, nullptr if there is none
    float depotDistance = 0.0f; // distance to `depot`; stale if depot is null
  };

  /// A worker managed by one of the gatherer tasks.
  struct Gatherer {
    int lastUpdate = 0; // frame of the last updateGatherer() call
    Unit* unit = nullptr; // the worker unit
    Resource* target = nullptr; // resource it is assigned to, if any
    int lastIssueGather = 0; // frame of the last queued gather UPC, 0 = never
    UpcId upcId = -1; // UPC ID of the task that manages this unit
  };

  int lastUpdate = 0; // frame of the last stepUpdate() call

  // Resource and Gatherer objects point at each other by raw pointer. This
  // relies on std::unordered_map keeping element addresses stable across
  // insertions and rehashes; entries must be unlinked before being erased.
  std::unordered_map<Unit*, Resource> resourceMap;
  std::unordered_map<Unit*, Gatherer> gathererMap;
  int gasGatherers = 0; // number of gatherers whose target is gas
  int maxGasGatherers = 0; // current cap on gasGatherers

  /// Re-selects g's target resource and queues a gather UPC into newUpcs if
  /// the worker's current order does not match it.
  void updateGatherer(
      State* state,
      Gatherer& g,
      std::vector<UpcWithSource>& newUpcs);
  /// Moves g to resource r, unassigning it from any previous target first.
  void assignGatherer(Gatherer& g, Resource& r);
  /// Detaches g from its current target (no-op without a target).
  void unassignGatherer(Gatherer& g);
  /// Recomputes the closest usable depot of resource r.
  void updateResource(State* state, Resource& r);

  /// Periodic refresh of gas limits, stale gatherers and known resources.
  void stepUpdate(State* state, std::vector<UpcWithSource>& newUpcs);
};

/**
 * Task created for one consumed Gather UPC. It owns the UPC's workers while
 * all gathering decisions are made through the GathererData object shared by
 * every GathererTask. It never succeeds; it fails once it has no units left.
 */
class GathererTask : public Task {
 public:
  using Task::Task;

  void update(State* state) override;
  /// Drops units this task should no longer manage.
  void updateUnits(State* state);

  /// Bookkeeping shared with all other gatherer tasks.
  std::shared_ptr<GathererData> data;
};

// Unqualified access to the utils helpers used throughout this file.
using utils::distance;
using utils::getBestScore;
using utils::getBestScoreCopy;

/**
 * Removes units this task should no longer manage from units(). These are
 * units that died, that the blackboard now attributes to a different task,
 * that are no longer ours, or that are no longer workers.
 *
 * The unit's entry in the shared GathererData::gathererMap is released only
 * if this task is still recorded as its manager (Gatherer::upcId). All
 * gatherer tasks share one map keyed by Unit*. If the unit has meanwhile been
 * handed to a newer GathererTask, the entry belongs to that task and must
 * survive. A missing entry is tolerated instead of dereferencing end().
 */
void GathererTask::updateUnits(State* state) {
  auto board = state->board();

  // Remove dead and re-assigned units
  for (auto it = units().begin(); it != units().end();) {
    Unit* unit = *it;
    if (unit->dead || board->taskWithUnit(unit).get() != this ||
        !unit->isMine || !unit->type->isWorker) {
      VLOG(1) << "Gatherer " << utils::unitString(unit)
              << (unit->dead ? " dead" : " re-assigned")
              << ", removing it from task " << utils::upcString(upcId());
      auto mit = data->gathererMap.find(unit);
      if (mit != data->gathererMap.end() && mit->second.upcId == upcId()) {
        data->unassignGatherer(mit->second);
        data->gathererMap.erase(mit);
      }
      units().erase(it++);
    } else {
      ++it;
    }
  }
}

/// Prunes the task's units; a task left without units has failed.
void GathererTask::update(State* state) {
  updateUnits(state);
  if (!units().empty()) {
    return;
  }
  setStatus(TaskStatus::Failure);
}

/**
 * Re-evaluates a single gatherer. It (re)selects the resource the worker
 * should mine. If the worker's current order does not match that resource,
 * it queues a UPC that moves it there, tells it to gather, or returns its
 * cargo to the target's depot.
 *
 * @param state   current game state (frame number, unit orders)
 * @param g       gatherer to update; its target, lastUpdate and
 *                lastIssueGather may change
 * @param newUpcs receives at most one UPC for g.unit, tagged with g.upcId as
 *                its source; the caller is responsible for posting it
 */
void GathererData::updateGatherer(
    State* state,
    Gatherer& g,
    std::vector<UpcWithSource>& newUpcs) {
  g.lastUpdate = state->currentFrame();

  // Drop a target that currently has no depot (updateResource found none);
  // scoreFunc below would rate it as unusable anyway.
  if (g.target) {
    if (!g.target->depot) {
      unassignGatherer(g);
    }
  }

  // Pick a resource to mine from. Gas is preferred, resources closer to their
  // depot with less workers already mining from them are preferred. Switching
  // targets if we already have one is discouraged.
  // With normal bases, it should distribute the workers evenly to all mineral
  // patches and not assign more than 3 workers per patch. If there are too
  // many workers then it might long distance mine.
  // TODO: Workers may be assigned to unreachable resources!

  // Lower scores are better; infinity marks a resource as unusable for g.
  auto scoreFunc = [&](Resource& r) {
    if (!r.depot || !r.unit->completed()) {
      return std::numeric_limits<double>::infinity();
    }

    // Don't count g itself when scoring its current target.
    int n = r.gatherers.size();
    if (&r == g.target) {
      --n;
    }
    // Crowding penalty: n for depot distances up to 16, and a smaller penalty
    // per gatherer for resources farther from their depot.
    double score = 1 + (n * 3 /
                        (2 + std::max(r.depotDistance, 4.0f * 4) / (4.0f * 4)));
    // Gas must be one of our refineries. Joining a new gas resource requires
    // fewer than 3 gatherers on it and a free gas slot. Staying on the current
    // one is allowed unless we are above maxGasGatherers. Usable gas gets a
    // large bonus.
    if (r.unit->type->isGas) {
      if (!r.unit->isMine || (&r != g.target && (r.gatherers.size() >= 3 || gasGatherers >= maxGasGatherers)) || (&r == g.target && gasGatherers > maxGasGatherers)) {
        return std::numeric_limits<double>::infinity();
      }
      score -= 300;
    }
    // If the current target has more than 2 gatherers, add a distance-
    // proportional cost. Otherwise add a flat cost for resources more than 32
    // distance units away from the worker.
    if (g.target && g.target->gatherers.size() > 2) {
      score += distance(g.unit, r.unit) / (4 * 4 * 15);
    } else {
      if (distance(g.unit, r.unit) > 4.0f * 8) {
        score += 10;
      }
    }
    // Prefer resources close to their depot; those more than 40 units away
    // (long-distance mining) stay finite but are picked only as a last resort.
    score += std::max(r.depotDistance, 4.0f * 4) / (4 * 2 * 15);
    if (r.depotDistance > 4.0f * 10) {
      score += 10000.0;
    }
    // Stickiness: bonus for keeping the current target while it has at most
    // 2 gatherers (3 for gas).
    if (&r == g.target && r.gatherers.size() <= (r.unit->type->isGas ? 3 : 2)) {
      score -= 100;
    }
    return score;
  };

  // Presumably returns the lowest-scoring entry, or end() when every resource
  // scored the invalid value (infinity) — see utils::getBestScore.
  auto newTarget = getBestScore(
      resourceMap,
      [&](auto& v) { return scoreFunc(v.second); },
      std::numeric_limits<double>::infinity());

  // If no resource is usable, g keeps whatever target it still has.
  if (newTarget != resourceMap.end() && &newTarget->second != g.target) {
    assignGatherer(g, newTarget->second);
  }

  if (g.target) {
    // Decide whether the worker's current order already serves g.target:
    // returning cargo to the target's depot, or moving to / waiting for /
    // mining the target itself.
    // NOTE(review): with an empty order list issueGather stays false, so such
    // a worker is never sent to gather — confirm orders is never empty here.
    bool issueGather = false;

    if (!g.unit->unit.orders.empty()) {
      auto& o = g.unit->unit.orders.front();
      if (g.unit->carryingMinerals() || g.unit->carryingGas()) {
        issueGather = (o.type != tc::BW::Order::ReturnMinerals &&
                       o.type != tc::BW::Order::ReturnGas) ||
            o.targetId != g.target->depot->id;
      } else {
        if (o.type != tc::BW::Order::MiningMinerals &&
            o.type != tc::BW::Order::HarvestGas) {
          if (g.target->unit->type->isGas) {
            issueGather = (o.type != tc::BW::Order::MoveToGas &&
                           o.type != tc::BW::Order::WaitForGas) ||
                o.targetId != g.target->unit->id;
          } else {
            issueGather = (o.type != tc::BW::Order::MoveToMinerals &&
                           o.type != tc::BW::Order::WaitForMinerals) ||
                o.targetId != g.target->unit->id;
          }
        }
      }
    }

    // Re-issue at most once every 15 frames per gatherer.
    if (issueGather && (g.lastIssueGather == 0 ||
                        state->currentFrame() - g.lastIssueGather >= 15)) {
      g.lastIssueGather = state->currentFrame();

      // Post UPC for mineral/gas gathering: carrying workers gather on the
      // depot (i.e. return cargo); others head for the resource and only
      // gather once it is visible.
      auto upc = std::make_shared<UPCTuple>();
      upc->unit[g.unit] = 1.0f;
      if (g.unit->carryingMinerals() || g.unit->carryingGas()) {
        upc->positionU[g.target->depot] = 1;
        upc->command[Command::Gather] = 1;
      } else {
        upc->positionU[g.target->unit] = 1;
        if (!g.target->unit->visible) {
          upc->command[Command::Move] = 1;
        } else {
          upc->command[Command::Gather] = 1;
        }
      }
      newUpcs.emplace_back(std::move(upc), g.upcId);
    }
  }
}

/**
 * Makes r the target of g, detaching g from any previous target first, and
 * keeps gasGatherers in sync.
 */
void GathererData::assignGatherer(Gatherer& g, Resource& r) {
  // unassignGatherer() does nothing when g has no target.
  unassignGatherer(g);

  r.gatherers.push_back(&g);
  g.target = &r;
  VLOG(4) << "Assign gatherer " << utils::unitString(g.unit) << " to "
          << utils::unitString(r.unit);

  if (r.unit->type->isGas) {
    ++gasGatherers;
  }
}

/**
 * Detaches g from its current target resource, if any. It updates
 * gasGatherers, removes g from the resource's gatherer list (logging an error
 * if it is unexpectedly missing), and clears g.target.
 */
void GathererData::unassignGatherer(Gatherer& g) {
  Resource* current = g.target;
  if (current == nullptr) {
    return;
  }

  if (current->unit->type->isGas) {
    --gasGatherers;
  }
  VLOG(4) << "Unassign gatherer " << utils::unitString(g.unit) << " from "
          << utils::unitString(current->unit);

  auto& assigned = current->gatherers;
  auto pos = std::find(assigned.begin(), assigned.end(), &g);
  if (pos == assigned.end()) {
    LOG(ERROR) << "Gatherer not present in expected list of target resource";
  } else {
    assigned.erase(pos);
  }
  g.target = nullptr;
}

/**
 * Re-selects the closest usable resource depot for r and caches its distance.
 * If no depot is usable, r.depot becomes nullptr and r.depotDistance keeps its
 * previous value.
 */
void GathererData::updateResource(State* state, Resource& r) {
  r.lastUpdate = state->currentFrame();

  // Unfinished depots are ignored, except Lairs and Hives (presumably because
  // they morph from an already working Hatchery).
  auto depotScore = [&](Unit* depot) {
    bool usable = depot->completed() || depot->type == buildtypes::Zerg_Lair ||
        depot->type == buildtypes::Zerg_Hive;
    if (usable) {
      return distance(depot, r.unit);
    }
    return std::numeric_limits<float>::infinity();
  };

  r.depot = getBestScoreCopy(
      state->unitsInfo().myResourceDepots(),
      depotScore,
      std::numeric_limits<float>::infinity());
  if (r.depot != nullptr) {
    r.depotDistance = distance(r.unit, r.depot);
  }
}

/**
 * Periodic refresh of the shared gathering state (GathererModule::step calls
 * it every 8 frames). In order, it:
 *   1. computes maxGasGatherers from the refinery count, the worker count,
 *      the current gas stockpile and the optional blackboard overrides
 *      "GathererMinGasGatherers" / "GathererMaxGasGatherers";
 *   2. re-evaluates every gatherer not updated during the last 8 frames, which
 *      may append gather UPCs to `newUpcs`;
 *   3. registers new resource units and refreshes the depot of resources not
 *      updated during the last 30 frames;
 *   4. forgets dead or gone resources, clearing their gatherers' targets;
 *   5. recounts gasGatherers from scratch.
 *
 * Step 5 must come after step 4. The purge clears targets directly rather than
 * through unassignGatherer(). Counting before it left gasGatherers too high,
 * which blocked new gas assignments until the next stepUpdate.
 *
 * @param state   current game state
 * @param newUpcs receives the UPCs produced by gatherer updates
 */
void GathererData::stepUpdate(
    State* state,
    std::vector<UpcWithSource>& newUpcs) {
  int frame = state->currentFrame();
  lastUpdate = frame;

  // When gas is abundant (more than 900 and more than minerals) or we have
  // fewer than 30 workers, stop mining gas once we can afford the most
  // gas-expensive type that one of our existing units is able to produce.
  bool hasEnoughGas = false;
  auto tcs = state->tcstate();
  auto res = tcs->frame->resources[tcs->player_id];
  if ((res.gas > 900 && res.gas > res.ore) ||
      state->unitsInfo().myWorkers().size() < 30) {
    double maxGasCost = 0.0;
    // Raises maxGasCost to the gas cost of each type in `types` whose builder
    // we currently own.
    auto considerTypes = [&](const auto& types) {
      for (const BuildType* t : types) {
        if (t->builder &&
            !state->unitsInfo().myUnitsOfType(t->builder).empty() &&
            t->gasCost > maxGasCost) {
          maxGasCost = t->gasCost;
        }
      }
    };
    considerTypes(buildtypes::allUnitTypes);
    considerTypes(buildtypes::allUpgradeTypes);
    considerTypes(buildtypes::allTechTypes);
    hasEnoughGas = res.gas >= maxGasCost;
  }

  // Three gas workers per refinery (none once we have enough gas), but never
  // more than a quarter of all workers.
  maxGasGatherers = 0;
  if (!hasEnoughGas) {
    for (Unit* u : state->unitsInfo().myBuildings()) {
      if (u->type->isRefinery) {
        maxGasGatherers += 3;
      }
    }
  }
  maxGasGatherers = std::min(
      maxGasGatherers,
      static_cast<int>(state->unitsInfo().myWorkers().size()) / 4);

  // Optional blackboard overrides. The maximum is applied last, so it wins if
  // both are set and conflict.
  auto board = state->board();
  if (board->hasKey("GathererMinGasGatherers")) {
    maxGasGatherers =
        std::max(maxGasGatherers, board->get<int>("GathererMinGasGatherers"));
  }
  if (board->hasKey("GathererMaxGasGatherers")) {
    maxGasGatherers =
        std::min(maxGasGatherers, board->get<int>("GathererMaxGasGatherers"));
  }

  // Re-evaluate gatherers that have not been updated for 8 frames.
  for (auto& it : gathererMap) {
    auto& g = it.second;
    if (frame - g.lastUpdate >= 8) {
      updateGatherer(state, g, newUpcs);
    }
  }

  // Register newly seen resources and refresh stale depot assignments.
  for (Unit* u : state->unitsInfo().resourceUnits()) {
    auto i = resourceMap.emplace(
        std::piecewise_construct, std::make_tuple(u), std::make_tuple());
    auto& r = i.first->second;
    if (i.second) {
      r.unit = u;
      updateResource(state, r);
    } else if (frame - r.lastUpdate >= 30) {
      updateResource(state, r);
    }
  }

  // Forget dead or gone resources. Their gatherers lose their target here
  // without going through unassignGatherer().
  for (auto i = resourceMap.begin(); i != resourceMap.end();) {
    auto& r = i->second;
    if (r.unit->dead || r.unit->gone) {
      for (auto* g : r.gatherers) {
        g->target = nullptr;
      }
      i = resourceMap.erase(i);
    } else {
      ++i;
    }
  }

  // Recount only after the purge so that detached gatherers are not counted.
  gasGatherers = 0;
  for (auto& it : gathererMap) {
    auto& g = it.second;
    if (g.target && g.target->unit->type->isGas) {
      ++gasGatherers;
    }
  }
}

} // namespace

/**
 * Per-frame entry point of the gatherer module.
 *
 * It turns every Gather UPC on the blackboard into a GathererTask; all tasks
 * share one GathererData. It then registers newly created resource containers
 * and gives newly acquired workers an immediate assignment. On the first
 * update and every 8 frames after that, it prunes task units and runs
 * GathererData::stepUpdate. Finally it posts all UPCs produced along the way
 * and defers to Module::step.
 */
void GathererModule::step(State* state) {
  auto board = state->board();

  // Everything is managed via a single GathererData object. Find it.
  std::shared_ptr<GathererData> gdata = nullptr;
  auto tasks = board->tasksOfModule(this);
  if (!tasks.empty()) {
    gdata = std::static_pointer_cast<GathererTask>(tasks.front())->data;
  }

  // Consume all Gather UPCs on the blackboard. For each UPC, create a new
  // GathererTask with the units in question and pointing to the central Data
  // object.
  // Only units with a positive probability and no owning task are taken. A UPC
  // without any such unit is neither consumed nor turned into a task.
  std::unordered_map<Unit*, UpcId> newUnits;
  for (auto& v : state->board()->upcsWithSharpCommand(Command::Gather)) {
    UpcId upcId = v.first;
    auto& upc = v.second;
    std::unordered_set<Unit*> units;

    for (auto uit : upc->unit) {
      // XXX We check for owner again since it could happen that the unit was
      // assigned to another task in the same frame...
      auto owner = board->taskDataWithUnit(uit.first).owner;
      if (uit.second > 0 && owner == nullptr) {
        units.insert(uit.first);
        newUnits[uit.first] = upcId;
      }
    }

    if (!units.empty()) {
      if (gdata == nullptr) {
        gdata = std::make_shared<GathererData>();
      }
      auto task = std::make_shared<GathererTask>(upcId, units);
      task->data = gdata;

      VLOG(1) << "New Gatherer task " << upcId << " with units "
              << unitsToString(units);
      board->consumeUPC(upcId, this);
      board->postTask(task, this, true);
      tasks.push_back(task);
    }
  }

  if (gdata == nullptr) {
    // Nothing to do
    return;
  }

  // Add new resource containers
  for (Unit* u : state->unitsInfo().getNewUnits()) {
    if (u->type->isResourceContainer) {
      auto& r = gdata->resourceMap[u];
      r.unit = u;
      gdata->updateResource(state, r);
    }
  }

  // Create UPCs for new units. operator[] reuses an existing entry, so a
  // worker that is already known keeps its current target but is now managed
  // by the new task (upcId).
  std::vector<UpcWithSource> newUpcs;
  for (auto& it : newUnits) {
    auto& g = gdata->gathererMap[it.first];
    g.unit = it.first;
    g.upcId = it.second;
    gdata->updateGatherer(state, g, newUpcs);
  }

  if (gdata->lastUpdate == 0 ||
      state->currentFrame() - gdata->lastUpdate >= 8) {
    // In the current step, other tasks might have been created and gatherers
    // might have been re-assigned.
    for (auto task : tasks) {
      std::static_pointer_cast<GathererTask>(task)->updateUnits(state);
    }
    gdata->stepUpdate(state, newUpcs);
  }

  // Post all gather UPCs on behalf of the task that manages each worker.
  for (auto& upc : newUpcs) {
    board->postUPC(upc.upc, upc.sourceId, this);
  }
  Module::step(state);
}

} // namespace fairrsh
