/**
 * Copyright (c) 2017-present, Facebook, Inc.
 * All rights reserved.
 */

#pragma once

#include <condition_variable>
#include <mutex>
#include <thread>
#include <vector>

#include "fairrsh.h"
#include "module.h"
#include "state.h"

namespace fairrsh {

/**
 * The main bot object.
 *
 * This class is used to play StarCraft Broodwar (TM) via the TorchCraft bridge.
 * The behavior and actions of the player are determined by a user-supplied list
 * of bot modules.
 */
class Player {
  RTTR_ENABLE()

  using ClientCommands = std::vector<tc::Client::Command>;

 public:
  Player(std::shared_ptr<tc::Client> client);
  virtual ~Player();
  Player(const Player&) = delete;
  Player& operator=(const Player&) = delete;

  State* state() {
    return state_;
  }

  std::shared_ptr<Module> getTopModule() const;
  void addModule(std::shared_ptr<Module> module);
  void addModules(std::vector<std::shared_ptr<Module>> const& modules);

  /// UI update frequency of Broodwar instance. Set this before calling init().
  void setFrameskip(int n);
  /// Combine n server-side frames before taking any action.
  /// Set this before calling init().
  void setCombineFrames(int n);

  /// Log a warning if step() exceeds a maximum duration.
  /// Defaults to true.
  void setWarnIfSlow(bool warn);
  /// Run bot step in separate thread to prevent blocking game execution.
  /// Defaults to false.
  void setNonBlocking(bool nonBlocking);
  /// Delay step() to make the game run in approx. factor*fastest speed.
  void setRealtimeFactor(float factor);

  /// Set whether to gather bot statistics during the game.
  void setCollectStats(bool collect);

  /// Set whether to gather timing statistics during the game.
  void setCollectTimers(bool collect);

  virtual void stepModule(std::shared_ptr<Module> module);
  void stepModules();
  void step();
  void init();
  void run();
  size_t steps() const {
    return steps_;
  }

 protected:
  using commandStartEndFrame =
      std::pair<tc::BW::UnitCommandType, std::pair<FrameNum, FrameNum>>;
  virtual void preStep();
  virtual void postStep();
  void logFailedCommands();

  std::shared_ptr<tc::Client> client_;
  int frameskip_ = 1;
  int combineFrames_ = 3;
  bool warnIfSlow_ = true;
  bool nonBlocking_ = false;
  bool collectStats_ = true;
  bool collectTimers_ = false;
  float realtimeFactor_ = -1.0f;
  std::vector<std::shared_ptr<Module>> modules_;
  State* state_;
  std::shared_ptr<Module> top_;
  std::unordered_map<std::shared_ptr<Module>, Duration> moduleTimeSpent_;
  std::unordered_map<std::shared_ptr<Module>, Duration> moduleTimeSpentAgg_;
  Duration stateUpdateTimeSpent_;
  Duration stateUpdateTimeSpentAgg_;
  size_t steps_ = 0;
  bool initialized_ = false;
  hires_clock::time_point lastStep_;
  std::map<int, std::vector<commandStartEndFrame>> commandStats_;
  std::map<int, std::vector<std::vector<commandStartEndFrame>::iterator>>
      commandStatsHelper_;
  // We collect unit deaths ourselves so that we don't miss any deaths that
  // occured during frames that we skipped.
  std::vector<int> unitDeaths_;

 private:
  void updateCommandStats(ClientCommands commands);

  void doStepThread();
  ClientCommands doStep();

  /// Lock client_ for step() and doStep().
  std::mutex stepClientMutex_;
  ClientCommands commands_;
  bool startStep_ = false;
  bool stepDone_ = false;
  bool exitStepThread_ = false;
  /// Mutex for stepCondition_. Also ensures you can read/write commands_,
  /// startStep_, exitStepThread_
  std::mutex stepMutex_;
  std::condition_variable stepStartCondition_;
  std::condition_variable stepDoneCondition_;
  std::unique_ptr<std::thread> stepThread_;
};

} // namespace fairrsh
