/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#pragma once

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cfloat>
#include <cmath>
#include <fstream>
#include <sstream>
#include <vector>
#include <math.h>

#ifdef WITH_ATEN
#include <ATen/ATen.h>
#endif // WITH_ATEN
#include <torchcraft/client.h>

#include "fairrsh.h"
#include "state.h"
#include "unitsinfo.h"
#include "upc.h"

namespace fairrsh {
namespace utils {

// Forward declaration: used by prerequisitesReady() for logging before the
// inline definition further down in this header. Declared inline here as well
// so both declarations agree.
inline std::string buildTypeString(BuildType const* buildType);

// Returns, in iteration order, a copy of every element of `units` for which
// `pred` returns true. The result is a std::vector of the container's
// value_type.
template <typename Units, typename UnaryPredicate>
inline auto filterUnits(Units&& units, UnaryPredicate pred) {
  using Element = typename std::decay<Units>::type::value_type;
  std::vector<Element> kept;
  for (auto&& candidate : units) {
    if (pred(candidate)) {
      kept.push_back(candidate);
    }
  }
  return kept;
}

// Counts the elements of `units` for which `pred` returns true. The result has
// the container iterator's difference type, as with std::count_if.
template <typename Units, typename UnaryPredicate>
inline auto countUnits(Units&& units, UnaryPredicate pred) {
  auto const first = units.begin();
  auto const last = units.end();
  return std::count_if(first, last, pred);
}

// Returns the units whose raw `type` field equals `type`.
inline std::vector<tc::Unit> filterUnitsByType(
    std::vector<tc::Unit> const& units,
    tc::BW::UnitType type) {
  auto const hasType = [type](tc::Unit const& unit) {
    return unit.type == type;
  };
  return filterUnits(units, hasType);
}

// Returns the units whose raw `type` field equals any entry of `types`.
inline std::vector<tc::Unit> filterUnitsByTypes(
    std::vector<tc::Unit> const& units,
    std::vector<tc::BW::UnitType> const& types) {
  // filterUnits runs synchronously, so `types` outlives the predicate. Capture
  // it by reference rather than copying the whole vector into the closure.
  return filterUnits(units, [&types](tc::Unit const& u) {
    return std::any_of(
        types.begin(), types.end(), [&u](tc::BW::UnitType type) {
          return u.type == type;
        });
  });
}

// Returns the units whose type is a valid tc::BW::UnitType for which `pred`
// returns true. Units with an unrecognized type id are dropped.
inline std::vector<tc::Unit> filterUnitsByType(
    std::vector<tc::Unit> const& units,
    std::function<bool(tc::BW::UnitType)> const& pred) {
  // `pred` outlives the synchronous filterUnits call. Capture it by reference
  // instead of copying the std::function, since a copy may heap-allocate.
  return filterUnits(units, [&pred](tc::Unit const& u) {
    auto ut = tc::BW::UnitType::_from_integral_nothrow(u.type);
    if (ut) {
      return pred(*ut);
    }
    return false;
  });
}

// Returns true if every prerequisite of `buildType` is currently satisfied for
// our side, and false as soon as one is not:
//  - unit prerequisites need at least one completed unit of that type, with
//    Zerg fallbacks (a Greater Spire stands in for a Spire, a completed Lair or
//    any Hive for a Hatchery, any Hive for a Lair);
//  - upgrade prerequisites need our upgrade level to be >= prereq->level;
//  - tech prerequisites need the tech to have been researched.
// A prerequisite of any other kind is logged at VLOG(2) and treated as unmet.
inline bool prerequisitesReady(State* state, const BuildType* buildType) {
  auto& unitsInfo = state->unitsInfo();
  for (auto* prereq : buildType->prerequisites) {
    if (prereq->isUnit()) {
      bool hasPrereq = !unitsInfo.myCompletedUnitsOfType(prereq).empty();
      if (!hasPrereq) {
        // Zerg buildings that can stand in for the required one.
        // NOTE(review): the Greater Spire / Hive fallbacks use myUnitsOfType
        // (presumably also counting unfinished units) while the Lair fallback
        // uses myCompletedUnitsOfType — confirm the asymmetry is intended.
        if (prereq == buildtypes::Zerg_Spire) {
          hasPrereq =
              !unitsInfo.myUnitsOfType(buildtypes::Zerg_Greater_Spire).empty();
        } else if (prereq == buildtypes::Zerg_Hatchery) {
          hasPrereq =
              !unitsInfo.myCompletedUnitsOfType(buildtypes::Zerg_Lair).empty();
          if (!hasPrereq) {
            hasPrereq = !unitsInfo.myUnitsOfType(buildtypes::Zerg_Hive).empty();
          }
        } else if (prereq == buildtypes::Zerg_Lair) {
          hasPrereq = !unitsInfo.myUnitsOfType(buildtypes::Zerg_Hive).empty();
        }
      }
      if (!hasPrereq) {
        return false;
      }
    } else if (prereq->isUpgrade()) {
      // Upgrades are levelled; the prerequisite names the level required.
      if (state->getUpgradeLevel(prereq) < prereq->level) {
        return false;
      }
    } else if (prereq->isTech()) {
      if (!state->hasResearched(prereq)) {
        return false;
      }
    } else {
      VLOG(2) << "Unknown prerequisite " << buildTypeString(prereq) << " for "
              << buildTypeString(buildType);
      return false;
    }
  }
  return true;
}

// Returns the iterator in [first, last) whose element's (x, y) is closest to
// (x, y) by squared Euclidean distance, computed in float. On ties the earliest
// element wins. Returns `last` for an empty range.
template <typename It>
inline It getClosest(int x, int y, It first, It last) {
  It best = last;
  float bestSq = FLT_MAX;
  for (It it = first; it != last; ++it) {
    float const dx = float(x - it->x);
    float const dy = float(y - it->y);
    float const sq = dx * dx + dy * dy;
    if (sq < bestSq) {
      bestSq = sq;
      best = it;
    }
  }
  return best;
}

// Returns true if any of the unit's current orders is one of the orders that
// tc::BW::commandToOrders() associates with `command`.
inline bool isExecutingCommand(
    tc::Unit const& unit,
    tc::BW::UnitCommandType command) {
  auto const wanted = tc::BW::commandToOrders(command);
  return std::any_of(
      unit.orders.begin(),
      unit.orders.end(),
      [&wanted](tc::Order const& current) {
        return std::any_of(
            wanted.begin(), wanted.end(), [&current](tc::BW::Order order) {
              return current.type == order;
            });
      });
}

// Convenience overload for a bot-side Unit: checks its underlying TorchCraft
// unit's orders against `command`.
inline bool isExecutingCommand(
    Unit const* unit,
    tc::BW::UnitCommandType command) {
  // `command` is a small enum value; std::move on it was a no-op, so pass it
  // by plain copy.
  return isExecutingCommand(unit->unit, command);
}

// Integer approximation of Euclidean distance (as used by StarCraft).
// Takes the absolute axis deltas in pixels and returns pixels. When the
// smaller delta is at most a quarter of the larger, the larger delta is
// returned unchanged. Otherwise a fixed-point correction is applied.
inline unsigned int disthelper(unsigned int dx, unsigned int dy) {
  unsigned int const major = std::max(dx, dy);
  unsigned int const minor = std::min(dx, dy);
  if (minor <= major / 4u) {
    return major;
  }
  return major - major / 16u + minor * 3u / 8u - major / 64u +
      minor * 3u / 256u;
}

// Approximate pixel distance between (px1, py1) and (px2, py2), using the
// StarCraft-style approximation in disthelper().
inline unsigned int pxdistance(int px1, int py1, int px2, int py2) {
  return disthelper(std::abs(px1 - px2), std::abs(py1 - py2));
}

// Approximate distance, in walktiles, between two walktile coordinates. The
// deltas are scaled to pixels so disthelper() works at the same resolution as
// the game, and the result is converted back to walktiles.
inline float distance(int x1, int y1, int x2, int y2) {
  auto const toPixels = unsigned(tc::BW::XYPixelsPerWalktile);
  unsigned int const pxDx = std::abs(x1 - x2) * toPixels;
  unsigned int const pxDy = std::abs(y1 - y2) * toPixels;
  unsigned int const px = disthelper(pxDx, pxDy);
  return float(px) / tc::BW::XYPixelsPerWalktile;
}

// Approximate walktile distance between the positions of two units.
inline float distance(Unit const* a, Unit const* b) {
  auto const& from = *a;
  auto const& to = *b;
  return distance(from.x, from.y, to.x, to.y);
}

// Approximate walktile distance between two walktile positions.
inline float distance(Position const& a, Position const& b) {
  auto const ax = a.x;
  auto const ay = a.y;
  return distance(ax, ay, b.x, b.y);
}

inline int pxDistanceBB(
    int xminA,
    int yminA,
    int xmaxA,
    int ymaxA,
    int xminB,
    int yminB,
    int xmaxB,
    int ymaxB) {
  if (xmaxB < xminA) { // To the left
    if (ymaxB < yminA) { // Fully above
      return pxdistance(xmaxB, ymaxB, xminA, yminA);
    } else if (yminB > ymaxA) { // Fully below
      return pxdistance(xmaxB, yminB, xminA, ymaxA);
    } else { // Adjecent
      return xminA - xmaxB;
    }
  } else if (xminB > xmaxA) { // To the right
    if (ymaxB < yminA) { // Fully above
      return pxdistance(xminB, ymaxB, xmaxA, yminA);
    } else if (yminB > ymaxA) { // Fully below
      return pxdistance(xminB, yminB, xmaxA, ymaxA);
    } else { // Adjecent
      return xminB - xmaxA;
    }
  } else if (ymaxB < yminA) { // Above
    return yminA - ymaxB;
  } else if (yminB > ymaxA) { // Below
    return yminB - ymaxA;
  }

  return 0;
}

// Approximate pixel distance between the bounding boxes of two units. Each box
// is built from the unit's pixel position and its type's dimension offsets.
inline int pxDistanceBB(Unit const* a, Unit const* b) {
  auto const& typeA = a->type;
  auto const& typeB = b->type;
  auto const ax = a->unit.pixel_x;
  auto const ay = a->unit.pixel_y;
  auto const bx = b->unit.pixel_x;
  auto const by = b->unit.pixel_y;
  return pxDistanceBB(
      ax - typeA->dimensionLeft,
      ay - typeA->dimensionUp,
      ax + typeA->dimensionRight,
      ay + typeA->dimensionDown,
      bx - typeB->dimensionLeft,
      by - typeB->dimensionUp,
      bx + typeB->dimensionRight,
      by + typeB->dimensionDown);
}

// True if the unit's walktile position is at most `radius` walktiles from
// (x, y), using the approximate distance().
inline bool isWithinRadius(Unit* unit, int32_t x, int32_t y, float radius) {
  float const separation = distance(unit->x, unit->y, x, y);
  return radius >= separation;
}

// Bounding-box distance between two units, converted from pixels to walktiles.
inline float distanceBB(Unit const* a, Unit const* b) {
  int const px = pxDistanceBB(a, b);
  return static_cast<float>(px) / tc::BW::XYPixelsPerWalktile;
}

// Bounding-box distance, in walktiles, between units a and b as if they stood
// at walktile positions pa and pb. Only their types' dimensions are used, not
// their current positions.
inline float
distanceBB(Unit const* a, Position pa, Unit const* b, Position pb) {
  auto const scale = tc::BW::XYPixelsPerWalktile;
  auto const ax = pa.x * scale;
  auto const ay = pa.y * scale;
  auto const bx = pb.x * scale;
  auto const by = pb.y * scale;
  int const px = pxDistanceBB(
      ax - a->type->dimensionLeft,
      ay - a->type->dimensionUp,
      ax + a->type->dimensionRight,
      ay + a->type->dimensionDown,
      bx - b->type->dimensionLeft,
      by - b->type->dimensionUp,
      bx + b->type->dimensionRight,
      by + b->type->dimensionDown);
  return float(px) / scale;
}

// Returns the units within `radius` walktiles of (x, y), per isWithinRadius().
template <typename Units>
inline std::vector<Unit*>
filterUnitsByDistance(Units&& units, int32_t x, int32_t y, float radius) {
  auto const inRange = [x, y, radius](Unit* u) {
    return isWithinRadius(u, x, y, radius);
  };
  return filterUnits(units, inRange);
}

// Clamps `v` to the closed range [lo, hi] under the strict weak ordering
// `comp`, mirroring C++17 std::clamp. Requires !comp(hi, lo), which is
// asserted. Returns a reference to one of the arguments, so do not keep the
// result beyond the lifetime of temporaries passed in.
template <class T, class Compare>
constexpr const T& clamp(const T& v, const T& lo, const T& hi, Compare comp) {
  return assert(!comp(hi, lo)), comp(v, lo) ? lo : comp(hi, v) ? hi : v;
}

// Clamps `v` to [lo, hi] using operator<.
template <class T>
constexpr const T& clamp(const T& v, const T& lo, const T& hi) {
  // The parenthesized name disables argument-dependent lookup. std::less<>
  // lives in namespace std, so an unqualified call would also find C++17's
  // std::clamp and be ambiguous with the overload above.
  return (clamp)(v, lo, hi, std::less<>());
}

// Get movement from (ux, uy) towards (px, py), rotated by `angle` in degrees.
// Positive angles rotate from the top right to the bottom left corner, since
// the y axis points down.
// If not exact, short moves are extended ("clicking past" the target) so that
// flyers keep their acceleration.
// The result is clamped to [0, mx - 1] x [0, my - 1]. If the unit is already
// at the target, the target is returned unchanged.
inline Position getMovePosHelper(
    int ux,
    int uy,
    int px,
    int py,
    int mx,
    int my,
    double angle,
    bool exact) {
  auto fdirX = px - ux;
  auto fdirY = py - uy;
  if (fdirX == 0 && fdirY == 0) {
    return Position(px, py);
  }
  // `angle` is in degrees; std::cos/std::sin expect radians.
  constexpr double kPi = 3.14159265358979323846;
  double const rad = angle * (kPi / 180.0);
  double const c = std::cos(rad);
  double const s = std::sin(rad);
  double dirX = fdirX * c - fdirY * s;
  double dirY = fdirX * s + fdirY * c;
  if (!exact && dirX * dirX + dirY * dirY < 10) {
    // Cheap normalization by the largest component (no sqrt): afterwards the
    // dominant axis is 10 walktiles long. The rotation preserves the nonzero
    // length of (fdirX, fdirY), so `div` > 0. Dividing by |dirX| alone could
    // blow up when dirX is tiny but nonzero.
    // NOTE(review): the threshold compares a squared length against 10
    // (length < ~3.2); confirm whether 100 (length < 10) was intended.
    double const div = std::max(std::abs(dirX), std::abs(dirY));
    dirX = dirX / div * 10;
    dirY = dirY / div * 10;
  }
  return Position(
      clamp(ux + static_cast<int>(dirX), 0, mx - 1),
      clamp(uy + static_cast<int>(dirY), 0, my - 1));
}
// Move target for unit `u` towards position `p`, rotated by `angle` degrees
// and clamped to the map. See getMovePosHelper().
inline Position getMovePos(
    State* state,
    Unit* u,
    Position p,
    double angle = 0,
    bool exact = true) {
  auto const width = state->mapWidth();
  auto const height = state->mapHeight();
  return getMovePosHelper(u->x, u->y, p.x, p.y, width, height, angle, exact);
}
// Move target for unit `u` towards unit `p`'s position, rotated by `angle`
// degrees and clamped to the map. See getMovePosHelper().
inline Position getMovePos(
    State* state,
    Unit* u,
    Unit* p,
    double angle = 0,
    bool exact = true) {
  auto const width = state->mapWidth();
  auto const height = state->mapHeight();
  return getMovePosHelper(u->x, u->y, p->x, p->y, width, height, angle, exact);
}

// Clamps (x, y) to [1, mapWidth - 1] x [1, mapHeight - 1]. In strict mode a
// position that needed clamping is rejected and (-1, -1) is returned instead.
// NOTE(review): the lower bound is 1 rather than 0 — confirm that excluding
// row/column 0 is intended.
inline Position clampPositionToMap(
    State* state,
    int const x,
    int const y,
    bool strict = false) {
  auto const clampedX = utils::clamp(x, 1, state->mapWidth() - 1);
  auto const clampedY = utils::clamp(y, 1, state->mapHeight() - 1);
  bool const wasOutside = clampedX != x || clampedY != y;
  if (wasOutside && strict) {
    return Position(-1, -1);
  }
  return Position(clampedX, clampedY);
}

// Position overload of clampPositionToMap(state, x, y, strict).
inline Position
clampPositionToMap(State* state, Position const& pos, bool strict = false) {
  auto const x = pos.x;
  auto const y = pos.y;
  return clampPositionToMap(state, x, y, strict);
}

// True if the unit's type id is a known UnitType classified as a worker.
inline bool isWorker(tc::Unit const& unit) {
  auto const ut = tc::BW::UnitType::_from_integral_nothrow(unit.type);
  if (!ut) {
    return false;
  }
  return tc::BW::isWorker(*ut);
}

// True if the unit's type id is a known UnitType classified as a building.
inline bool isBuilding(tc::Unit const& unit) {
  auto const ut = tc::BW::UnitType::_from_integral_nothrow(unit.type);
  if (!ut) {
    return false;
  }
  return tc::BW::isBuilding(*ut);
}

// Returns our own worker units from the raw TorchCraft state.
inline std::vector<tc::Unit> getWorkers(tc::State* state) {
  // tc::BW::isWorker is overloaded; select the UnitType overload explicitly.
  using WorkerPredicate = bool (*)(tc::BW::UnitType);
  WorkerPredicate const isWorkerType = &tc::BW::isWorker;
  auto const& ownUnits = state->units[state->player_id];
  return utils::filterUnitsByType(ownUnits, isWorkerType);
}

// Returns the neutral mineral-field units from the raw TorchCraft state.
inline std::vector<tc::Unit> getMineralFields(tc::State* state) {
  auto const& neutralUnits = state->units[state->neutral_id];
  return utils::filterUnitsByType(neutralUnits, tc::BW::isMineralField);
}

// x,y in walktiles (per the original note). Returns false outside
// [0, map_size[0]) x [0, map_size[1]); otherwise returns the row-major entry
// of buildable_data.
inline bool isBuildable(tc::State* state, int x, int y) {
  auto const width = state->map_size[0];
  auto const height = state->map_size[1];
  bool const inside = x >= 0 && y >= 0 && x < width && y < height;
  if (!inside) {
    return false;
  }
  return state->buildable_data[y * width + x];
}

#ifdef WITH_ATEN
// Finds the largest element of a 2-D float tensor. Returns (x * scale,
// y * scale, value), where x indexes dimension 0 and y dimension 1 (documented
// as walktiles). Ties keep the first maximum in iteration order. Throws
// std::runtime_error if `pos` is undefined or not 2-D.
// NOTE(review): assumes dim 0 is x and dim 1 is y — confirm against callers.
inline std::tuple<int, int, float> argmax(at::Tensor const& pos, int scale) {
  bool const isMatrix = pos.defined() && pos.dim() == 2;
  if (!isMatrix) {
    throw std::runtime_error("Two-dimensional tensor expected");
  }
  // ATen only offers a non-const accessor, hence the const_cast.
  auto acc = const_cast<at::Tensor&>(pos).accessor<float, 2>();
  float bestValue = std::numeric_limits<float>::lowest();
  int bestX = 0;
  int bestY = 0;
  for (int x = 0; x < acc.size(0); ++x) {
    for (int y = 0; y < acc.size(1); ++y) {
      float const value = acc[x][y];
      if (!(value > bestValue)) {
        continue;
      }
      bestValue = value;
      bestX = x;
      bestY = y;
    }
  }
  return std::make_tuple(bestX * scale, bestY * scale, bestValue);
}
#else // WITH_ATEN
// Built without ATen there is no tensor to search, so any call is an error.
inline std::tuple<int, int, float> argmax(UPCTuple::UndefTensor const& /*pos*/, int /*scale*/) {
  throw std::runtime_error("Argmax of undefined tensor requested");
}
#endif // WITH_ATEN

// For a (protected) unit command, returns the commanded unit id (args[0]).
// Returns -1 for any other command or when args is empty.
inline UnitId commandUnitId(tc::Client::Command const& cmd) {
  bool const isUnitCommand = cmd.code == tc::BW::Command::CommandUnit ||
      cmd.code == tc::BW::Command::CommandUnitProtected;
  if (!isUnitCommand || cmd.args.size() == 0) {
    return -1;
  }
  return cmd.args[0];
}

// For a (protected) unit command, decodes args[1] as a UnitCommandType.
// Returns UnitCommandType::MAX for other commands, missing args or unknown ids.
inline tc::BW::UnitCommandType commandUnitType(tc::Client::Command const& cmd) {
  bool const isUnitCommand = cmd.code == tc::BW::Command::CommandUnit ||
      cmd.code == tc::BW::Command::CommandUnitProtected;
  if (!isUnitCommand || cmd.args.size() <= 1) {
    return tc::BW::UnitCommandType::MAX;
  }
  if (auto uct = tc::BW::UnitCommandType::_from_integral_nothrow(cmd.args[1])) {
    return *uct;
  }
  return tc::BW::UnitCommandType::MAX;
}

// For a (protected) unit Build command, decodes the building type from
// args[5]. Returns UnitType::MAX for anything else or an unknown type id.
inline tc::BW::UnitType buildCommandUnitType(tc::Client::Command const& cmd) {
  bool const isUnitCommand = cmd.code == tc::BW::Command::CommandUnit ||
      cmd.code == tc::BW::Command::CommandUnitProtected;
  bool const isBuild = isUnitCommand && cmd.args.size() > 5 &&
      cmd.args[1] == tc::BW::UnitCommandType::Build;
  if (!isBuild) {
    return tc::BW::UnitType::MAX;
  }
  if (auto ut = tc::BW::UnitType::_from_integral_nothrow(cmd.args[5])) {
    return *ut;
  }
  return tc::BW::UnitType::MAX;
}

// For a (protected) unit Train command, decodes the trained unit type. It is
// read from args[2], or from the 'extra' field args[5] when args[2] is
// negative. Returns UnitType::MAX for anything else, missing args or an unknown
// type id.
inline tc::BW::UnitType trainCommandUnitType(tc::Client::Command const& cmd) {
  bool const isUnitCommand = cmd.code == tc::BW::Command::CommandUnit ||
      cmd.code == tc::BW::Command::CommandUnitProtected;
  if (!isUnitCommand || cmd.args.size() <= 2 ||
      !(cmd.args[1] == tc::BW::UnitCommandType::Train)) {
    return tc::BW::UnitType::MAX;
  }
  auto const typeIndex = cmd.args[2] < 0 ? 5u : 2u;
  if (cmd.args.size() <= typeIndex) {
    return tc::BW::UnitType::MAX;
  }
  if (auto ut = tc::BW::UnitType::_from_integral_nothrow(cmd.args[typeIndex])) {
    return *ut;
  }
  return tc::BW::UnitType::MAX;
}

// For a (protected) unit Build command, returns the target position
// (args[3], args[4]). Returns (-1, -1) for anything else.
inline Position buildCommandPosition(tc::Client::Command const& cmd) {
  bool const isUnitCommand = cmd.code == tc::BW::Command::CommandUnit ||
      cmd.code == tc::BW::Command::CommandUnitProtected;
  bool const isBuild = isUnitCommand && cmd.args.size() > 4 &&
      cmd.args[1] == tc::BW::UnitCommandType::Build;
  if (!isBuild) {
    return Position(-1, -1);
  }
  return Position(cmd.args[3], cmd.args[4]);
}

// Average walktile position of the given units. An empty container yields
// (0, 0) and logs at VLOG(2).
template <typename Units>
inline Position centerOfUnits(Units&& units) {
  if (units.size() == 0) {
    VLOG(2) << "Center of no units is (0, 0)";
    return Position(0, 0);
  }
  // Start explicitly from (0, 0), as the iterator overload does, instead of
  // relying on Position's default constructor.
  Position p(0, 0);
  for (Unit const* unit : units) {
    p += Position(unit);
  }
  p /= units.size();
  return p;
}

// Average walktile position of the units in [start, end). An empty range yields
// (0, 0) and logs at VLOG(2).
template <typename InputIterator>
inline Position centerOfUnits(InputIterator start, InputIterator end) {
  if (start == end) {
    VLOG(2) << "Center of no units is (0, 0)";
    return Position(0, 0);
  }
  Position sum(0, 0);
  auto count = 0U;
  while (start != end) {
    sum += Position(*start);
    ++count;
    ++start;
  }
  sum /= count;
  return sum;
}

// Converts an enumeration value to its underlying integer type.
template <typename Enumeration>
auto enumAsInt(Enumeration const value) ->
    typename std::underlying_type<Enumeration>::type {
  using Underlying = typename std::underlying_type<Enumeration>::type;
  return static_cast<Underlying>(value);
}

// Marker type meaning "no value supplied" for the optional sentinel arguments
// of getBestScore and friends.
struct NoValue {};
static constexpr NoValue noValue{};
// Returns a == b, except that any comparison involving NoValue is false, so an
// omitted sentinel never matches a real score.
template <typename A, typename B>
constexpr bool isEqualButNotNoValue(A&& a, B&& b) {
  return a == b;
}
template <typename A>
constexpr bool isEqualButNotNoValue(A&& a, NoValue) {
  return false;
}
template <typename B>
constexpr bool isEqualButNotNoValue(NoValue, B&& b) {
  return false;
}

/// This function iterates from begin to end, passing each value to the
/// provided score function (each element is scored exactly once).
/// It returns the corresponding iterator for the value whose score function
/// returned the lowest value (using operator <). On ties the earliest value
/// wins.
/// If invalidScore is provided, then any score which compares equal to it will
/// never be returned.
/// If bestPossibleScore is provided, then any score which compares equal to it
/// (and also is the best score thus far) will cause an immediate return,
/// without iterating to the end.
/// If the range is empty or no value can be returned (due to invalidScore),
/// then the end iterator is returned.
template <typename Iterator,
          typename Score,
          typename InvalidScore = NoValue,
          typename BestPossibleScore = NoValue>
auto getBestScore(
    Iterator begin,
    Iterator end,
    Score&& score,
    InvalidScore&& invalidScore = InvalidScore(),
    BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
  if (begin == end) {
    return end;
  }
  auto i = begin;
  auto bestScore = score(*i);
  // Skip the leading run of invalid elements. Each is scored once, and the
  // first valid one becomes the initial best.
  while (isEqualButNotNoValue(bestScore, invalidScore)) {
    ++i;
    if (i == end) {
      return end;
    }
    bestScore = score(*i);
  }
  auto best = i;
  // The first valid element may already be optimal; honour the early exit for
  // it as for any later element.
  if (isEqualButNotNoValue(bestScore, bestPossibleScore)) {
    return best;
  }
  for (++i; i != end; ++i) {
    auto s = score(*i);
    if (isEqualButNotNoValue(s, invalidScore)) {
      continue;
    }
    if (s < bestScore) {
      best = i;
      bestScore = s;
      if (isEqualButNotNoValue(s, bestPossibleScore)) {
        break;
      }
    }
  }
  return best;
}

/// Range/container form of the iterator-pair getBestScore: scores every
/// element of `range` and returns an iterator to the best one (or
/// range.end()), with the same invalidScore / bestPossibleScore semantics.
template <typename Range,
          typename Score,
          typename InvalidScore = NoValue,
          typename BestPossibleScore = NoValue>
auto getBestScore(
    Range&& range,
    Score&& score,
    InvalidScore&& invalidScore = InvalidScore(),
    BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
  auto first = range.begin();
  auto last = range.end();
  return getBestScore(
      first,
      last,
      std::forward<Score>(score),
      std::forward<InvalidScore>(invalidScore),
      std::forward<BestPossibleScore>(bestPossibleScore));
}

/// Like getBestScore over a range, but returns a copy of the best element
/// (auto semantics: a value, not a reference). When no element qualifies, a
/// value-initialized object (as if by T{}) is returned.
template <typename Range,
          typename Score,
          typename InvalidScore = NoValue,
          typename BestPossibleScore = NoValue>
auto getBestScoreCopy(
    Range&& range,
    Score&& score,
    InvalidScore&& invalidScore = InvalidScore(),
    BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
  auto it = getBestScore(
      range.begin(),
      range.end(),
      std::forward<Score>(score),
      std::forward<InvalidScore>(invalidScore),
      std::forward<BestPossibleScore>(bestPossibleScore));
  using Value = typename std::remove_reference<decltype(*it)>::type;
  if (it != range.end()) {
    return *it;
  }
  return Value{};
}

/// Like getBestScore over a range, but returns a pointer to the best element,
/// or a null pointer of the same type when no element qualifies.
template <typename Range,
          typename Score,
          typename InvalidScore = NoValue,
          typename BestPossibleScore = NoValue>
auto getBestScorePointer(
    Range&& range,
    Score&& score,
    InvalidScore&& invalidScore = InvalidScore(),
    BestPossibleScore&& bestPossibleScore = BestPossibleScore()) {
  auto it = getBestScore(
      range.begin(),
      range.end(),
      std::forward<Score>(score),
      std::forward<InvalidScore>(invalidScore),
      std::forward<BestPossibleScore>(bestPossibleScore));
  using Pointer = decltype(&*it);
  if (it == range.end()) {
    return static_cast<Pointer>(nullptr);
  }
  return &*it;
}

// Returns a lower-cased copy of `str` (any type with size()/begin()/end()
// over chars), converted byte by byte with std::tolower in the current C
// locale.
template <typename T>
std::string stringToLower(T&& str) {
  std::string lowered;
  lowered.resize(str.size());
  // std::tolower requires a value representable as unsigned char (or EOF).
  // Passing a plain char that is negative (non-ASCII bytes on signed-char
  // platforms) is undefined behaviour, so convert first.
  std::transform(
      str.begin(), str.end(), lowered.begin(), [](unsigned char c) {
        return static_cast<char>(std::tolower(c));
      });
  return lowered;
}

// Splits `str` on every occurrence of `sep`. Always returns (number of
// separators + 1) fields, keeping empty ones: "" -> {""}, "a," -> {"a", ""}.
template <typename T>
std::vector<std::string> stringSplit(T&& str, char sep) {
  std::string const text(str);
  std::vector<std::string> pieces;
  std::string::size_type start = 0;
  while (true) {
    auto const stop = text.find(sep, start);
    if (stop == std::string::npos) {
      pieces.emplace_back(text, start);
      break;
    }
    pieces.emplace_back(text, start, stop - start);
    start = stop + 1;
  }
  return pieces;
}

// TODO just implement operator<< for Unit to be honest...
// Short log label for a unit, "i<id> (<type name>)", or "nullptr" for null.
inline std::string unitString(Unit const* unit) {
  if (unit == nullptr) {
    return "nullptr";
  }
  std::ostringstream out;
  // The 'i' prefix keeps 'u' free for UPC tuple ids in logs.
  out << 'i' << unit->id << " (" << unit->type->name << ')';
  return out.str();
}

// Log label for a collection of units: "[<unitString>,<unitString>,...,]".
// Every entry, including the last, is followed by a comma.
template <typename Units>
inline std::string unitsString(Units&& units) {
  std::ostringstream out;
  out << '[';
  for (auto unit : units) {
    out << unitString(unit) << ',';
  }
  out << ']';
  return out.str();
}

// Log label for a build type: its name, or "null" for a null pointer.
inline std::string buildTypeString(BuildType const* buildType) {
  if (buildType == nullptr) {
    return "null";
  }
  return buildType->name;
}

// Log label for a resources record: "<ore> ore, <gas> gas, <used>/<total>
// psi/2". Supply values are halved with integer division, as the label says.
template <typename T>
inline std::string resourcesString(T&& resources) {
  auto const halfUsed = resources.used_psi / 2;
  auto const halfTotal = resources.total_psi / 2;
  std::ostringstream out;
  out << resources.ore << " ore, " << resources.gas << " gas, " << halfUsed
      << '/' << halfTotal << " psi/2";
  return out.str();
}

// Log label for a UPC id: "u<id>".
inline std::string upcString(UpcId id) {
  std::ostringstream out;
  out << 'u' << id;
  return out.str();
}

// Log label for a UPC: "u<id> (<tag>)". The tag is the first of Create (C),
// Move (M), Delete (D) and Gather (G), in that order, whose command entry
// equals exactly 1, or '?' if none does.
// NOTE(review): if `command` is a map, operator[] inserts missing keys as a
// side effect of logging — confirm this is acceptable.
inline std::string upcString(std::shared_ptr<UPCTuple> const& upc, UpcId id) {
  auto const isSharp = [&upc](Command c) { return upc->command[c] == 1; };
  char tag = '?';
  if (isSharp(Command::Create)) {
    tag = 'C';
  } else if (isSharp(Command::Move)) {
    tag = 'M';
  } else if (isSharp(Command::Delete)) {
    tag = 'D';
  } else if (isSharp(Command::Gather)) {
    tag = 'G';
  }

  // TODO Add more relevant information for specific UPcs?

  std::ostringstream oss;
  oss << "u" << id << " (" << tag << ")";
  return oss.str();
}


// New UPC with probability 1 on unit `u` and command `c`.
inline auto makeSharpUPC(Unit* u, Command c) {
  auto tuple = std::make_shared<UPCTuple>();
  tuple->unit[u] = 1;
  tuple->command[c] = 1;
  return tuple;
}
// New UPC with probability 1 on unit `u` and command `c`, targeting position
// `p` (stored in positionS).
inline auto makeSharpUPC(Unit* u, Position p, Command c) {
  auto tuple = std::make_shared<UPCTuple>();
  tuple->unit[u] = 1;
  tuple->positionS = p;
  tuple->command[c] = 1;
  return tuple;
}
// New UPC with probability 1 on unit `u` and command `c`, targeting unit `p`
// (stored in positionU).
inline auto makeSharpUPC(Unit* u, Unit* p, Command c) {
  auto tuple = std::make_shared<UPCTuple>();
  tuple->unit[u] = 1;
  tuple->positionU[p] = 1;
  tuple->command[c] = 1;
  return tuple;
}
// Copy of `other_upc` with unit `u` and command `c` set to probability 1.
// The source is only copied, so it is taken by const reference; non-const
// arguments still bind, so existing callers are unaffected.
// NOTE(review): entries already present in the copied unit/command
// distributions are kept, so the result may not be fully "sharp".
inline auto makeSharpUPC(UPCTuple const& other_upc, Unit* u, Command c) {
  auto upc = std::make_shared<UPCTuple>(other_upc);
  upc->unit[u] = 1;
  upc->command[c] = 1;
  return upc;
}

// Returns the set of enemy units that are not marked `gone` and lie within 75
// walktiles of at least one unit in `units`.
template <typename Units>
std::unordered_set<Unit*> findNearbyEnemyUnits(State* state, Units&& units) {
  auto& enemies = state->unitsInfo().enemyUnits();
  // Search radius (walktiles), value taken from UAlbertaBot.
  constexpr auto kSearchRadius = 75;
  std::unordered_set<Unit*> nearby;
  for (auto unit : units) {
    auto const close =
        filterUnitsByDistance(enemies, unit->x, unit->y, kSearchRadius);
    for (auto enemy : close) {
      // XXX What if it's gone?? For now gone units are simply skipped.
      if (enemy->gone) {
        continue;
      }
      nearby.insert(enemy);
    }
  }
  return nearby;
}

// Posts a DrawLine command between walktile positions a and b, converted to
// pixels.
inline void
drawLine(State* state, Position const& a, Position const& b, int color = 255) {
  auto const scale = tc::BW::XYPixelsPerWalktile;
  auto const ax = a.x * scale;
  auto const ay = a.y * scale;
  auto const bx = b.x * scale;
  auto const by = b.y * scale;
  state->board()->postCommand({tc::BW::Command::DrawLine, ax, ay, bx, by, color});
}

// Posts a DrawCircle command centred on walktile position `a` (converted to
// pixels). `radius` is passed through as-is and is in pixels.
inline void
drawCircle(State* state, Position const& a, int radius, int color = 255) {
  auto const scale = tc::BW::XYPixelsPerWalktile;
  auto const centerX = a.x * scale;
  auto const centerY = a.y * scale;
  state->board()->postCommand(
      {tc::BW::Command::DrawCircle, centerX, centerY, radius, color});
}

// From https://gist.github.com/mrts/5890888, which is based on Alex Andrescu's
// implementation at http://bit.ly/2wfJnWn.
// Invokes the stored callable when the guard is destroyed, unless dismiss()
// was called first. Move-only: moving hands the pending call to the new guard
// and dismisses the source.
template <class Function>
class ScopeGuard {
 public:
  ScopeGuard(Function f) : callback_(std::move(f)) {}

  ScopeGuard(ScopeGuard&& other)
      : callback_(std::move(other.callback_)), engaged_(other.engaged_) {
    other.dismiss();
  }

  ~ScopeGuard() {
    if (!engaged_) {
      return;
    }
    callback_();
  }

  void dismiss() {
    engaged_ = false;
  }

 private:
  Function callback_;
  bool engaged_ = true;

  ScopeGuard() = delete;
  ScopeGuard(const ScopeGuard&) = delete;
  ScopeGuard& operator=(const ScopeGuard&) = delete;
};

// Creates a ScopeGuard that runs `f` at scope exit; the template argument is
// deduced from `f`.
template <class Function>
ScopeGuard<Function> makeGuard(Function f) {
  return ScopeGuard<Function>{std::move(f)};
}

// Returns true if the raw TorchCraft order id maps to one of the tc::BW::Order
// values listed as live `case` labels below. Besides plain attacks, these
// include guard, hold-position, patrol and several offensive spell orders.
// The commented-out `case` lines are kept for reference only; they fall through
// to `default` and yield false. Ids that do not map to a known tc::BW::Order
// also return false.
inline bool tcOrderIsAttack(int orderId) {
  auto order = tc::BW::Order::_from_integral_nothrow(orderId);
  if (!order) {
    return false;
  }
  switch (*order) {
    // case tc::BW::Order::Die:
    // case tc::BW::Order::Stop:
    case tc::BW::Order::Guard:
    case tc::BW::Order::PlayerGuard:
    case tc::BW::Order::TurretGuard:
    case tc::BW::Order::BunkerGuard:
    // case tc::BW::Order::Move:
    // case tc::BW::Order::ReaverStop:
    case tc::BW::Order::Attack1:
    case tc::BW::Order::Attack2:
    case tc::BW::Order::AttackUnit:
    case tc::BW::Order::AttackFixedRange:
    case tc::BW::Order::AttackTile:
    // case tc::BW::Order::Hover:
    case tc::BW::Order::AttackMove:
    // case tc::BW::Order::InfestedCommandCenter:
    // case tc::BW::Order::UnusedNothing:
    // case tc::BW::Order::UnusedPowerup:
    case tc::BW::Order::TowerGuard:
    case tc::BW::Order::TowerAttack:
    case tc::BW::Order::VultureMine:
    case tc::BW::Order::StayInRange:
    case tc::BW::Order::TurretAttack:
    // case tc::BW::Order::Nothing:
    // case tc::BW::Order::Unused_24:
    // case tc::BW::Order::DroneStartBuild:
    // case tc::BW::Order::DroneBuild:
    case tc::BW::Order::CastInfestation:
    case tc::BW::Order::MoveToInfest:
    case tc::BW::Order::InfestingCommandCenter:
    // case tc::BW::Order::PlaceBuilding:
    // case tc::BW::Order::PlaceProtossBuilding:
    // case tc::BW::Order::CreateProtossBuilding:
    // case tc::BW::Order::ConstructingBuilding:
    // case tc::BW::Order::Repair:
    // case tc::BW::Order::MoveToRepair:
    // case tc::BW::Order::PlaceAddon:
    // case tc::BW::Order::BuildAddon:
    // case tc::BW::Order::Train:
    // case tc::BW::Order::RallyPointUnit:
    // case tc::BW::Order::RallyPointTile:
    // case tc::BW::Order::ZergBirth:
    // case tc::BW::Order::ZergUnitMorph:
    // case tc::BW::Order::ZergBuildingMorph:
    // case tc::BW::Order::IncompleteBuilding:
    // case tc::BW::Order::IncompleteMorphing:
    // case tc::BW::Order::BuildNydusExit:
    // case tc::BW::Order::EnterNydusCanal:
    // case tc::BW::Order::IncompleteWarping:
    // case tc::BW::Order::Follow:
    // case tc::BW::Order::Carrier:
    // case tc::BW::Order::ReaverCarrierMove:
    // case tc::BW::Order::CarrierStop:
    case tc::BW::Order::CarrierAttack:
    case tc::BW::Order::CarrierMoveToAttack:
    // case tc::BW::Order::CarrierIgnore2:
    case tc::BW::Order::CarrierFight:
    case tc::BW::Order::CarrierHoldPosition:
    // case tc::BW::Order::Reaver:
    case tc::BW::Order::ReaverAttack:
    case tc::BW::Order::ReaverMoveToAttack:
    case tc::BW::Order::ReaverFight:
    case tc::BW::Order::ReaverHoldPosition:
    // case tc::BW::Order::TrainFighter:
    case tc::BW::Order::InterceptorAttack:
    case tc::BW::Order::ScarabAttack:
    // case tc::BW::Order::RechargeShieldsUnit:
    // case tc::BW::Order::RechargeShieldsBattery:
    // case tc::BW::Order::ShieldBattery:
    // case tc::BW::Order::InterceptorReturn:
    // case tc::BW::Order::DroneLand:
    // case tc::BW::Order::BuildingLand:
    // case tc::BW::Order::BuildingLiftOff:
    // case tc::BW::Order::DroneLiftOff:
    // case tc::BW::Order::LiftingOff:
    // case tc::BW::Order::ResearchTech:
    // case tc::BW::Order::Upgrade:
    // case tc::BW::Order::Larva:
    // case tc::BW::Order::SpawningLarva:
    // case tc::BW::Order::Harvest1:
    // case tc::BW::Order::Harvest2:
    // case tc::BW::Order::MoveToGas:
    // case tc::BW::Order::WaitForGas:
    // case tc::BW::Order::HarvestGas:
    // case tc::BW::Order::ReturnGas:
    // case tc::BW::Order::MoveToMinerals:
    // case tc::BW::Order::WaitForMinerals:
    // case tc::BW::Order::MiningMinerals:
    // case tc::BW::Order::Harvest3:
    // case tc::BW::Order::Harvest4:
    // case tc::BW::Order::ReturnMinerals:
    // case tc::BW::Order::Interrupted:
    // case tc::BW::Order::EnterTransport:
    // case tc::BW::Order::PickupIdle:
    // case tc::BW::Order::PickupTransport:
    // case tc::BW::Order::PickupBunker:
    // case tc::BW::Order::Pickup4:
    // case tc::BW::Order::PowerupIdle:
    // case tc::BW::Order::Sieging:
    // case tc::BW::Order::Unsieging:
    // case tc::BW::Order::WatchTarget:
    // case tc::BW::Order::InitCreepGrowth:
    // case tc::BW::Order::SpreadCreep:
    // case tc::BW::Order::StoppingCreepGrowth:
    // case tc::BW::Order::GuardianAspect:
    // case tc::BW::Order::ArchonWarp:
    // case tc::BW::Order::CompletingArchonSummon:
    case tc::BW::Order::HoldPosition:
    // case tc::BW::Order::QueenHoldPosition:
    // case tc::BW::Order::Cloak:
    // case tc::BW::Order::Decloak:
    // case tc::BW::Order::Unload:
    // case tc::BW::Order::MoveUnload:
    case tc::BW::Order::FireYamatoGun:
    case tc::BW::Order::MoveToFireYamatoGun:
    case tc::BW::Order::CastLockdown:
    // case tc::BW::Order::Burrowing:
    // case tc::BW::Order::Burrowed:
    // case tc::BW::Order::Unburrowing:
    // case tc::BW::Order::CastDarkSwarm:
    case tc::BW::Order::CastParasite:
    case tc::BW::Order::CastSpawnBroodlings:
    case tc::BW::Order::CastEMPShockwave:
    // case tc::BW::Order::NukeWait:
    // case tc::BW::Order::NukeTrain:
    // case tc::BW::Order::NukeLaunch:
    // case tc::BW::Order::NukePaint:
    case tc::BW::Order::NukeUnit:
    case tc::BW::Order::CastNuclearStrike:
    // case tc::BW::Order::NukeTrack:
    // case tc::BW::Order::InitializeArbiter:
    // case tc::BW::Order::CloakNearbyUnits:
    // case tc::BW::Order::PlaceMine:
    // case tc::BW::Order::RightClickAction:
    case tc::BW::Order::SuicideUnit:
    // case tc::BW::Order::SuicideLocation:
    case tc::BW::Order::SuicideHoldPosition:
    // case tc::BW::Order::CastRecall:
    // case tc::BW::Order::Teleport:
    // case tc::BW::Order::CastScannerSweep:
    // case tc::BW::Order::Scanner:
    // case tc::BW::Order::CastDefensiveMatrix:
    // case tc::BW::Order::CastPsionicStorm:
    case tc::BW::Order::CastIrradiate:
    // case tc::BW::Order::CastPlague:
    // case tc::BW::Order::CastConsume:
    // case tc::BW::Order::CastEnsnare:
    // case tc::BW::Order::CastStasisField:
    // case tc::BW::Order::CastHallucination:
    // case tc::BW::Order::Hallucination2:
    // case tc::BW::Order::ResetCollision:
    // case tc::BW::Order::ResetHarvestCollision:
    case tc::BW::Order::Patrol:
    // case tc::BW::Order::CTFCOPInit:
    // case tc::BW::Order::CTFCOPStarted:
    // case tc::BW::Order::CTFCOP2:
    // case tc::BW::Order::ComputerAI:
    case tc::BW::Order::AtkMoveEP:
    case tc::BW::Order::HarassMove:
    case tc::BW::Order::AIPatrol:
    // case tc::BW::Order::GuardPost:
    // case tc::BW::Order::RescuePassive:
    // case tc::BW::Order::Neutral:
    // case tc::BW::Order::ComputerReturn:
    // case tc::BW::Order::InitializePsiProvider:
    // case tc::BW::Order::SelfDestructing:
    // case tc::BW::Order::Critter:
    // case tc::BW::Order::HiddenGun:
    // case tc::BW::Order::OpenDoor:
    // case tc::BW::Order::CloseDoor:
    // case tc::BW::Order::HideTrap:
    // case tc::BW::Order::RevealTrap:
    // case tc::BW::Order::EnableDoodad:
    // case tc::BW::Order::DisableDoodad:
    // case tc::BW::Order::WarpIn:
    // case tc::BW::Order::Medic:
    // case tc::BW::Order::MedicHeal:
    // case tc::BW::Order::HealMove:
    // case tc::BW::Order::MedicHoldPosition:
    // case tc::BW::Order::MedicHealToIdle:
    // case tc::BW::Order::CastRestoration:
    // case tc::BW::Order::CastDisruptionWeb:
    case tc::BW::Order::CastMindControl:
    // case tc::BW::Order::DarkArchonMeld:
    case tc::BW::Order::CastFeedback:
    case tc::BW::Order::CastOpticalFlare:
      // case tc::BW::Order::CastMaelstrom:
      // case tc::BW::Order::JunkYardDog:
      // case tc::BW::Order::Fatal:
      // case tc::BW::Order::None:
      return true;
    default:
      return false;
  }
}

// Returns true if `name` can be opened for reading as a file stream. An
// existing but unreadable path therefore reports false.
inline bool file_exists(const std::string& name) {
  std::ifstream probe(name);
  return probe.good();
}

// Element-wise in[i] += add[i] for every index of `in`.
// Precondition: add.size() >= in.size(). Extra elements of `add` are ignored.
template <typename T>
inline void inplace_flat_vector_add(std::vector<T>& in, const std::vector<T>& add) {
  // A shorter `add` would be read out of bounds below; catch it in debug builds.
  assert(add.size() >= in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    in[i] += add[i];
  }
}

// Element-wise in[i] += mul1[i] * mul2[i] for every index of `in`.
// Precondition: mul1 and mul2 are at least as long as `in`.
template <typename T>
inline void inplace_flat_vector_addcmul(std::vector<T>& in,
                                        const std::vector<T>& mul1,
                                        const std::vector<T>& mul2) {
  // Shorter factors would be read out of bounds below; catch it in debug builds.
  assert(mul1.size() >= in.size() && mul2.size() >= in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    in[i] += mul1[i] * mul2[i];
  }
}
// Element-wise in[i] += mul1[i] * mul2 for every index of `in`, with a scalar
// factor `mul2`. `mul1` must be at least as long as `in`.
template <typename T>
inline void inplace_flat_vector_addcmul(std::vector<T>& in,
                                        const std::vector<T>& mul1,
                                        T mul2) {
  auto factor = mul1.begin();
  for (auto& value : in) {
    value += *factor * mul2;
    ++factor;
  }
}

// Divides every element of `in` by `div`, in place.
template <typename T>
inline void inplace_flat_vector_div(std::vector<T>& in, T div) {
  for (auto& value : in) {
    value /= div;
  }
}

// Euclidean (L2) norm of `v`; 0 for an empty vector. The result is converted
// back to T, so integer T truncates.
template <typename T>
inline T l2_norm_vector(const std::vector<T>& v) {
  T s2 = 0;
  for (auto const& e : v) {
    // e * e instead of pow(e, 2): no transcendental call per element.
    s2 += e * e;
  }
  return static_cast<T>(std::sqrt(s2));
}

// Index of the first largest element of `v`. NOTE: an empty vector also
// yields 0, which is not a valid index — callers must check emptiness.
template <typename T>
inline size_t argmax(const std::vector<T>& v) {
  auto const best = std::max_element(v.cbegin(), v.cend());
  return static_cast<size_t>(best - v.cbegin());
}

} // namespace utils
} // namespace fairrsh
