/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "modules/creategatherattack.h"

#include "player.h"
#include "state.h"
#include "utils.h"

#include <glog/logging.h>

namespace fairrsh {

// Register CreateGatherAttackModule with RTTR under the name
// "CreateGatherAttackModule", attach its own rttr type as the "type" metadata
// entry, and expose its default constructor to reflection.
// NOTE(review): presumably this lets modules be instantiated by name at
// runtime — confirm against the module factory.
RTTR_REGISTRATION {
  rttr::registration::class_<CreateGatherAttackModule>(
      "CreateGatherAttackModule")(
      metadata("type", rttr::type::get<CreateGatherAttackModule>()))
      .constructor();
}

// Runs one step of the module.
//
// The module owns three UPCs that it keeps on the blackboard:
//   - create_: command Create = 1
//   - gather_: command Gather = 1; its unit map lists my workers that are not
//              currently part of any task on the blackboard
//   - attack_: command Delete = 0.5 and Move = 0.5
// When built WITH_ATEN, all three share one uniform position tensor (scale
// 64); otherwise the position is left undefined, which is treated as uniform.
//
// If this module is not the player's top module, the first entry returned by
// the blackboard for the top module's UPCs is used as the source UPC for
// everything posted here. If there is no such entry, a warning is logged and
// the step does nothing.
//
// Each of the three UPCs that is no longer among this module's UPCs on the
// blackboard is posted again. The source UPC, if any, is consumed once, just
// before the first such repost, and only when something is reposted.
void CreateGatherAttackModule::step(State* state) {
  auto board = state->board();

  // -1 means "no source UPC", which is the case when we are the top module.
  int topUpcId = -1;
  const bool isTopModule = player_->getTopModule().get() == this;
  if (!isTopModule) {
    auto&& topUpcs = board->upcsFrom(player_->getTopModule());
    auto it = topUpcs.begin();
    if (it != topUpcs.end()) {
      topUpcId = it->first;
    }
    if (topUpcId < 0) {
      LOG(WARNING) << "Could not find UPC tuple from top module";
      return;
    }
  }

  // Build the three UPCs lazily (first step, or if any of them was reset).
  if (!(create_ && gather_ && attack_)) {
    create_ = std::make_shared<UPCTuple>();
    gather_ = std::make_shared<UPCTuple>();
    attack_ = std::make_shared<UPCTuple>();
    create_->command[Command::Create] = 1;
    gather_->command[Command::Gather] = 1;
    attack_->command[Command::Delete] = 0.5;
    attack_->command[Command::Move] = 0.5;

#ifdef WITH_ATEN
    // A single uniform position distribution, shared by all three UPCs.
    create_->scale = 64;
    create_->position = UPCTuple::zeroPosition(state, create_->scale);
    create_->position.fill_(1.0f / create_->position.numel());
    for (auto* other : {&gather_, &attack_}) {
      (*other)->scale = create_->scale;
      (*other)->position = create_->position;
    }
#endif // WITH_ATEN
    // Without ATen the position stays undefined, which is treated as uniform.
  }

  // Refresh the gather UPC's unit map with my workers. Workers already part
  // of a task are skipped to avoid spamming UPC filters. The create and attack
  // UPCs keep an empty unit map, meaning no unit is specified.
  // NOTE(review): gather_ is modified in place, so if it is still posted the
  // blackboard's copy (same object) sees the new unit set too.
  gather_->unit.clear();
  for (Unit* worker : state->unitsInfo().myWorkers()) {
    if (board->taskWithUnit(worker) != nullptr) {
      continue;
    }
    gather_->unit[worker] = 1;
  }

  // Snapshot which of our UPCs are still on the blackboard before posting.
  std::vector<std::shared_ptr<UPCTuple>> stillPosted;
  for (auto& entry : board->upcsFrom(this)) {
    stillPosted.emplace_back(entry.second);
  }

  bool topConsumed = false;
  for (auto* slot : {&create_, &gather_, &attack_}) {
    const auto& upc = *slot;
    if (std::find(stillPosted.begin(), stillPosted.end(), upc) !=
        stillPosted.end()) {
      continue;
    }
    // Consume the source UPC at most once, and only once we actually repost.
    if (topUpcId >= 0 && !topConsumed) {
      board->consumeUPC(topUpcId, this);
      topConsumed = true;
    }
    board->postUPC(upc, topUpcId, this);
  }
}

} // namespace fairrsh
