/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#include "upcfilter.h"

#include "state.h"

namespace fairrsh {

/**
 * Filters gather UPCs so that they do not claim units which are already
 * assigned to a task owned by a different module.
 *
 * If `upc` is not a gather UPC (its Command::Gather entry is not exactly 1),
 * or if no unit has to be removed, the very same pointer is returned.
 * Otherwise a copy of the tuple is returned in which every offending unit has
 * its probability set to 0; unit probabilities of the tuple passed in are
 * left untouched.
 *
 * @param state   state whose board is queried for unit/task assignments
 * @param upc     the UPC tuple to filter; dereferenced unconditionally, so it
 *                must not be null
 * @param origin  origin of this UPC; units assigned to tasks owned by this
 *                module are kept
 * @return `upc` itself if nothing was filtered, otherwise a filtered copy
 */
std::shared_ptr<UPCTuple> AssignedUnitsFilter::filter(
    State* state,
    std::shared_ptr<UPCTuple> upc,
    Module* origin) {
  auto board = state->board();
  // Starts out aliasing the input; replaced by a private copy on first change.
  auto filtered = upc;

  // NOTE(review): if `command` is a map type, operator[] default-inserts an
  // entry when Gather is absent -- confirm that this is acceptable.
  if (upc->command[Command::Gather] == 1) {
    // This is a gather UPC. We'll set any units that are currently assigned to
    // a task to zero probability, provided that the origin of this UPC is not
    // the origin of that task.
    //
    // Iterating by const reference is safe: all writes below go to `filtered`,
    // which is a distinct object by the time it is modified.
    for (const auto& entry : upc->unit) {
      if (entry.second <= 0) {
        // Unit is not selected by this UPC anyway.
        continue;
      }
      auto td = board->taskDataWithUnit(entry.first);
      if (td.task == nullptr || td.owner == origin) {
        // Unit is unassigned, or assigned to a task of the UPC's own origin.
        continue;
      }

      // Remove unit (copy the tuple lazily so the input stays unchanged)
      if (filtered == upc) {
        filtered = std::make_shared<UPCTuple>(*upc);
      }
      filtered->unit[entry.first] = 0;
      VLOG(1) << "Removed unit " << entry.first->id
              << " from gather UPC since it is already assigned";
    }
  }

  return filtered;
}

} // namespace fairrsh
