/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "upc.h"

#include "state.h"

namespace fairrsh {

// Reflection metadata for the types in this file, registered with RTTR.
// NOTE(review): presumably consumed by reflection-based serialization or
// introspection elsewhere in the project -- confirm against callers.
RTTR_REGISTRATION {
  // Expose every Command enumerator (including MAX) under its source name.
  rttr::registration::enumeration<Command>("Command")(
      rttr::value("Create", Command::Create),
      rttr::value("Move", Command::Move),
      rttr::value("Delete", Command::Delete),
      rttr::value("Gather", Command::Gather),
      rttr::value("MAX", Command::MAX));

  // Register UPCTuple with a default constructor and its data members.
  // The "type" metadata entry stores UPCTuple's own rttr::type.
  rttr::registration::class_<UPCTuple>("UPCTuple")(
      metadata("type", rttr::type::get<UPCTuple>()))
      .constructor()
      .property("unit", &UPCTuple::unit)
#ifdef WITH_ATEN
      // ATen-dependent members are only registered in WITH_ATEN builds.
      .property("position", &UPCTuple::position)
      .property("state", &UPCTuple::state)
#endif // WITH_ATEN
      .property("command", &UPCTuple::command)
      .property("createType", &UPCTuple::createType)
      .property("scale", &UPCTuple::scale);
}

namespace {

/// Integer division rounding toward negative infinity.
///
/// Plain `/` truncates toward zero, so e.g. `-1 / 2 == 0` would map an
/// off-map (negative) coordinate onto the first cell of a downscaled grid.
/// Assumes `divisor != 0`.
int floorDiv(int numerator, int divisor) {
  int quotient = numerator / divisor;
  if (numerator % divisor != 0 && ((numerator < 0) != (divisor < 0))) {
    --quotient;
  }
  return quotient;
}

} // namespace

/// Returns the probability this UPC assigns to the position (x, y).
///
/// The first populated position representation wins, checked in order:
///  1. (WITH_ATEN only) a defined 2-D `position` tensor, stored at reduced
///     resolution: (x, y) is mapped to cell (x / scale, y / scale) using
///     floor division; out-of-range cells yield 0.
///  2. `positionU`: the probability paired with the first entry whose key's
///     x/y equal (x, y); 0 if none matches.
///  3. `positionS` (when both components are > -1), in scaled coordinates:
///     1 if (x, y) falls in that scaled cell, else 0.
///  4. `positionA`: 1 if `areaInfo->tryGetArea({x, y})` is `positionA`,
///     else 0.
/// If none of these is set, every position is considered valid (returns 1).
///
/// Assumes `scale > 0`.
float UPCTuple::positionProb(int x, int y) const {
#ifdef WITH_ATEN
  if (position.defined() && position.dim() == 2) {
    // The tensor has the reduced-resolution layout produced by
    // zeroPosition() (mapWidth/scale x mapHeight/scale), so coordinates are
    // divided -- not multiplied -- by scale to index it.
    int const sx = floorDiv(x, scale);
    int const sy = floorDiv(y, scale);
    if (sx < 0 || sy < 0 || sx >= position.size(0) || sy >= position.size(1)) {
      return 0.0f;
    }
    // ATen's indexing operators are non-const here, hence the const_cast;
    // the tensor is only read.
    auto& t = const_cast<at::Tensor&>(position);
    return *(t[sx][sy].data<float>());
  }
#endif // WITH_ATEN

  if (!positionU.empty()) {
    for (auto const& entry : positionU) {
      if (entry.first->x == x && entry.first->y == y) {
        return entry.second;
      }
    }
    return 0.0f;
  }

  if (positionS.x > -1 && positionS.y > -1) {
    // positionS is in scaled coordinates. Floor division keeps negative
    // inputs from matching cell 0 when scale > 1; for scale == 1 this is a
    // plain equality test.
    bool const inCell =
        positionS.x == floorDiv(x, scale) && positionS.y == floorDiv(y, scale);
    return inCell ? 1.0f : 0.0f;
  }

  if (positionA != nullptr) {
    auto area = positionA->areaInfo->tryGetArea({x, y});
    return area == positionA ? 1.0f : 0.0f;
  }

  // Unspecified position: we assume that any position is good.
  return 1.0f;
}

#ifdef WITH_ATEN
/// Creates an all-zero CPU float tensor shaped
/// (mapWidth() / scale, mapHeight() / scale), i.e. the map extent reduced by
/// the integer factor `scale` (integer division, so partial cells are
/// dropped).
at::Tensor UPCTuple::zeroPosition(State* state, int scale) {
  auto const scaledWidth = state->mapWidth() / scale;
  auto const scaledHeight = state->mapHeight() / scale;
  return at::CPU(at::kFloat).zeros({scaledWidth, scaledHeight});
}
#endif // WITH_ATEN

} // namespace fairrsh
