/*
 * Copyright (c) 2017-present, Facebook, Inc.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "buildingplacer.h"
#include "builderhelper.h"

#include "state.h"
#include "task.h"
#include "utils.h"

#include <glog/logging.h>

//DEFINE_string(bp_model, "", "Path to building placer model");

namespace cherrypi {

// Register BuildingPlacerModule as a subclass of Module with the project's
// subclass registry (presumably so it can be instantiated by name; confirm
// against the REGISTER_SUBCLASS_0 macro definition).
REGISTER_SUBCLASS_0(Module, BuildingPlacerModule);

namespace {

/// Proxies the task to create a building and re-tries at different locations if
/// necessary.
class BuildingPlacerTask : public ProxyTask {
 private:
  bool reserved_ = false;

 public:
  std::shared_ptr<UPCTuple> sourceUpc;
  BuildType const* type;
  Position pos;
  /// Need to send another UPC for this building?
  bool respawn = false;

  BuildingPlacerTask(
      UpcId targetUpcId,
      UpcId upcId,
      std::shared_ptr<UPCTuple> sourceUpc,
      BuildType const* type,
      Position pos)
      : ProxyTask(targetUpcId, upcId),
        sourceUpc(std::move(sourceUpc)),
        type(type),
        pos(std::move(pos)) {}
  virtual ~BuildingPlacerTask() = default;

  UpcId targetUpcId() {
    return targetUpcId_;
  }

  void setTarget(UpcId targetUpcId) {
    targetUpcId_ = targetUpcId;
  }

  void setPosition(Position p) {
    pos = std::move(p);
  }

  void reserveLocation(State* state) {
    if (!reserved_) {
      VLOG(0) << "Reserve for " << utils::upcString(upcId()) << " ("
              << utils::buildTypeString(type) << " at " << pos << ")";
      try {
        builderhelpers::fullReserve(state->tilesInfo(), type, pos);
      } catch (const std::exception&) {
        cancel(state);
        return;
      }

    }
    reserved_ = true;
  }

  void unreserveLocation(State* state) {
    if (reserved_) {
      VLOG(0) << "Unreserve for " << utils::upcString(upcId()) << " ("
              << utils::buildTypeString(type) << " at " << pos << ")";
      try {
        builderhelpers::fullUnreserve(state->tilesInfo(), type, pos);
      } catch (const std::exception&) {}
    }
    reserved_ = false;
  }

  virtual void update(State* state) override {
    ProxyTask::update(state);

    if (finished()) {
      VLOG(2) << "Proxied building task for " << utils::upcString(upcId())
              << " (" << utils::buildTypeString(type) << " at " << pos
              << ") finished";
      unreserveLocation(state);
//      if (status() == TaskStatus::Failure) {
//        VLOG(2) << "Proxied building task for " << utils::upcString(upcId())
//                << " (" << utils::buildTypeString(type) << " at " << pos
//                << ") failed; scheduling retry";
//        respawn = true;
//        setStatus(TaskStatus::Unknown);
//        target_ = nullptr;
//        targetUpcId_ = kInvalidUpcId;
//      } else {
//        unreserveLocation(state);
//      }
    }
  }

  virtual void cancel(State* state) override {
    ProxyTask::cancel(state);
    if (reserved_) {
      unreserveLocation(state);
    }
  }
};

} // namespace

/// Per-frame update.
///
/// 1. Picks up sharp Create UPCs for worker-built buildings and determines a
///    location for each. It then posts a positioned UPC in its place, together
///    with a BuildingPlacerTask that proxies it and reserves the location.
/// 2. Re-targets SetCreatePriority UPCs that refer to a proxied Create UPC so
///    that they point to the UPC actually being executed.
/// 3. Re-posts UPCs for any tasks flagged with `respawn`.
void BuildingPlacerModule::step(State* state) {
  auto board = state->board();

  // Cache BWEM base locations
  if (baseLocations_.empty()) {
    for (auto& area : state->areaInfo().areas()) {
      baseLocations_.insert(
          area.baseLocations.begin(), area.baseLocations.end());
    }
  }

  // On the first frame, run a single placement query for a Zerg hatchery and
  // discard the result. This originally warmed up a learned placement model,
  // which is currently disabled (as is the static map feature
  // initialization). NOTE(review): confirm whether this warm-up is still
  // useful for the rule-based placer.
  if (firstStep_) {
    firstStep_ = false;

    auto upc = UPCTuple();
    upc.command[Command::Create] = 1;
    upc.state = UPCTuple::BuildTypeMap{{buildtypes::Zerg_Hatchery, 1}};
    upcWithPositionForBuilding(state, upc, buildtypes::Zerg_Hatchery);
  }

  for (auto const& upct : board->upcsWithSharpCommand(Command::Create)) {
    auto upcId = upct.first;
    auto& upc = *upct.second;

    // Do we know what we want?
    auto ctMax = upc.createTypeArgMax();
    if (ctMax.second < 0.99f) {
      VLOG(4) << "Not sure what we want? argmax over build types ="
              << ctMax.second;
      continue;
    }
    BuildType const* type = ctMax.first;

    std::shared_ptr<UPCTuple> newUpc;
    if (type->isBuilding && type->builder->isWorker) {
      newUpc = upcWithPositionForBuilding(state, upc, type);
      // TODO: for buildings that require psi, consider falling back to
      // placing a Protoss_Pylon when no position can be found.
    }

    // Ignore the UPC if we can't determine a position
    if (newUpc == nullptr) {
      continue;
    }

    // Post new UPC along with a ProxyTask
    auto pos = newUpc->position.get<Position>();
    auto newUpcId = board->postUPC(std::move(newUpc), upcId, this);
    if (newUpcId >= 0) {
      board->consumeUPC(upcId, this);
      if (builderhelpers::canBuildAt(state, type, pos)) {
        auto task = std::make_shared<BuildingPlacerTask>(
            newUpcId, upcId, upct.second, type, pos);
        task->reserveLocation(state);
        board->postTask(std::move(task), this, true);
      } else {
        // The proposed location is not buildable. The source UPC has already
        // been consumed above, so this request is dropped; make that visible.
        VLOG(1) << "Cannot build " << utils::buildTypeString(type) << " at "
                << pos << " for " << utils::upcString(upcId)
                << "; dropping request";
        board->consumeUPC(newUpcId, this);
      }
    }
  }

  // We need to update the upc id of any SetCreatePriority commands whose
  // Create task we are proxying.
  for (auto const& upct :
       board->upcsWithSharpCommand(Command::SetCreatePriority)) {
    auto upcId = upct.first;
    auto& upc = *upct.second;
    if (upc.state.is<UPCTuple::SetCreatePriorityState>()) {
      auto st = upc.state.get_unchecked<UPCTuple::SetCreatePriorityState>();
      for (auto& task : board->tasksOfModule(this)) {
        if (task->upcId() == std::get<0>(st)) {
          // This module only posts BuildingPlacerTasks, so the cast is safe.
          auto bptask = std::static_pointer_cast<BuildingPlacerTask>(task);
          std::shared_ptr<UPCTuple> newUpc = std::make_shared<UPCTuple>(upc);
          std::get<0>(st) = bptask->targetUpcId();
          newUpc->state = st;
          auto newUpcId = board->postUPC(std::move(newUpc), upcId, this);
          if (newUpcId >= 0) {
            board->consumeUPC(upcId, this);
          }
          break;
        }
      }
    }
  }

  // Any scheduled retries?
  for (auto& task : board->tasksOfModule(this)) {
    auto bptask = std::static_pointer_cast<BuildingPlacerTask>(task);
    if (!bptask->respawn) {
      continue;
    }

    bptask->unreserveLocation(state);

    std::shared_ptr<UPCTuple> newUpc;
    if (bptask->type->isBuilding && bptask->type->builder->isWorker) {
      newUpc =
          upcWithPositionForBuilding(state, *bptask->sourceUpc, bptask->type);
    }

    // No position found: keep `respawn` set and try again next frame.
    if (newUpc == nullptr) {
      continue;
    }

    auto pos = newUpc->position.get<Position>();
    auto newUpcId = board->postUPC(std::move(newUpc), bptask->upcId(), this);
    if (newUpcId >= 0) {
      bptask->respawn = false;
      bptask->setTarget(newUpcId);
      bptask->setPosition(pos);
      bptask->reserveLocation(state);
    }
  }
}

void BuildingPlacerModule::onGameStart(State* state) {
  firstStep_ = true;
  baseLocations_.clear();
}

/// Computes a positioned UPC for building `type`, derived from `upc`, using
/// the rule-based placer in builderhelpers. The result may be nullptr; callers
/// in step() treat that as "no position found".
std::shared_ptr<UPCTuple> BuildingPlacerModule::upcWithPositionForBuilding(
    State* state,
    UPCTuple const& upc,
    BuildType const* type) {
  return builderhelpers::upcWithPositionForBuilding(state, upc, type);
}

} // namespace cherrypi
