/*	-----------------------------------------------------------------------------
	M A A S C R A F T

	StarCraft: Brood War - Bot

	Author: Dennis Soemers
	Maastricht University
	-----------------------------------------------------------------------------
*/

#pragma once

/*
	Implementation of UCT.h
*/

#include "CommonIncludes.h"

#ifdef ARMY_MANAGER_UCT

#include "BaseManager.h"
#include "Distances.h"
#include "MathConstants.h"
#include "OpponentTracker.h"
#include "RNG.h"
#include "Squad.h"
#include "Timer.hpp"
#include "UCT.h"
#include "UnitTracker.h"

using namespace BWAPI;
using namespace std;

#pragma warning( push )
#pragma warning( disable:4018 )

/*
	Default constructor.

	Starts without a search tree (rootNode is null until init() is called)
	and with a start frame of -1000.
*/
UCT::UCT()
	: trueGameState()
	, gameStateCopy()
	, mastMap()
	, rootNode(nullptr)
	, startFrame(-1000)
#ifdef MAASCRAFT_DEBUG
	, verbose(false)
#endif
{
}

/*
	Destructor.

	Frees the root of the search tree. init() allocates the root with new and
	deletes the previous one, so this object is its owner; without this the
	last tree built would never be released.

	NOTE(review): this assumes UCT objects are never copied by value (a copy
	would share rootNode and free it twice) - confirm against UCT.h.
*/
UCT::~UCT()
{
	delete rootNode;		// deleting a null pointer is a no-op
	rootNode = nullptr;
}

const int UCT::getStartFrame() const
{
	return startFrame;
}

/*
	(Re)initializes the search for a new set of squads.

	Records the current frame as start frame, resets the true game state from
	the given air and ground squads, clears the MAST statistics and replaces
	the search tree by a fresh root node.
*/
void UCT::init(const vector<Unitset>& airSquads, const vector<Unitset>& groundSquads)
{
	startFrame = Broodwar->getFrameCount();
	trueGameState.reset(airSquads, groundSquads);
	mastMap.clear();

	// Drop the previous tree. The pointer is nulled immediately so that it never
	// dangles if constructing the new root below throws.
	// (delete on a null pointer is a no-op, so no explicit check is needed)
	delete rootNode;
	rootNode = nullptr;

	rootNode = new UCTNode(trueGameState, Action::NULL_ACTION, nullptr);
}

/*
	Converts the current search result into real squads.

	Starting at the root, follows the most-visited child for as long as it is
	our own player's turn to move. For every step whose acting sim-squad is
	not a base, a new Squad is allocated, all still-valid units of that
	sim-squad are transferred to it, and the selected action becomes its task.

	airSquads:		receives the squads consisting of flyers only
	groundSquads:	receives all other squads

	The Squad objects are allocated with new and are not freed here;
	presumably the caller takes ownership - verify against the caller.
*/
void UCT::fillSquads(std::vector<Squad*>& airSquads, std::vector<Squad*>& groundSquads)
{
#ifdef MAASCRAFT_DEBUG
	if(!(airSquads.empty() && groundSquads.empty()))
	{
		LOG_WARNING("UCT::fillSquads() called with vectors that were not empty")
		airSquads.clear();
		groundSquads.clear();
	}
#endif

	UnitTracker* unitTracker = UnitTracker::Instance();
	const vector<SimSquad>& ownSquads = trueGameState.getSelfSquads();

	UCTNode* current = rootNode;
	while(current && current->getPlayerToMove() == Broodwar->self())
	{
		UCTNode* next = current->bestChild_maxVisitCount();

		if(!next)
			return;			// nothing below this node: the plan ends here

		Action chosenAction = next->getParentAction();
		const SimSquad& simSquad = ownSquads.at(chosenAction.squadIndex());

		if(simSquad.isBase())
		{
			// base sim-squads do not get a real squad; just descend
			current = next;
			continue;
		}

		Squad* squad = new Squad;
		const Unitset& members = simSquad.getUnits();

		for(auto unitIt = members.begin(); unitIt != members.end(); ++unitIt)
		{
			Unit unit = *unitIt;

			if(!UnitUtils::isUnitValid(unit))
				continue;

			unitTracker->getUnitOwner(unit)->transferOwnershipTo(unit, squad);
		}

		squad->setTask(chosenAction);

		(simSquad.flyersOnly() ? airSquads : groundSquads).push_back(squad);

		current = next;
	}
}

/*
	Runs search iterations on the current tree until the time budget is used up.

	Every iteration copies the true game state, selects/expands a node with
	treePolicy(), plays the copy out with defaultPolicy() and back-propagates
	the resulting reward.

	allowedMilliSeconds:	time budget for this call. The budget is only checked
							between iterations, so a call can overrun it by up to
							one iteration.

	Requires init() to have been called: rootNode is null before that and
	is used without a check by the iteration below.

	In MAASCRAFT_DEBUG builds with the verbose flag set, the line of play along
	the most-visited children is additionally dumped to the log; the flag is
	cleared again afterwards.
*/
void UCT::onFrame(const int allowedMilliSeconds)
{
	if(trueGameState.getSelfSquads().size() == 0)
		return;		// no point in running UCT if our only squads are dummy squads for bases

	Timer timer;
	timer.start();

	while(timer.getElapsedTimeInMilliSec() < allowedMilliSeconds)
	{
		gameStateCopy = GameState(trueGameState);		// make new copy of root state

		UCTNode* node = treePolicy(rootNode);

#ifdef MAASCRAFT_DEBUG
		if(node == rootNode)
		{
			LOG_MESSAGE("Node from treePolicy equals rootNode!")
		}
#endif

		const double reward = defaultPolicy(gameStateCopy);
		backPropagate(node, reward);
	}

#ifdef MAASCRAFT_DEBUG
	/*
		Visualize rewards of moves

		Disabled. The block below is kept for reference and declares the
		locals it needs itself, so it can be re-enabled as is.
	*/
	/*
	{
		UCTNode* node = rootNode;
		const vector<SimSquad>& selfSquads = trueGameState.getSelfSquads();
		const vector<SimSquad>& enemySquads = trueGameState.getEnemySquads();

		while(node && node->getPlayerToMove() == Broodwar->self())
		{
			const vector<UCTNode*>& children = node->getChildren();
			for(auto it = children.begin(); it != children.end(); ++it)
			{
				UCTNode* child = *it;

				int squadIndex = child->getParentAction().squadIndex();
				bool allied = true;
				if(squadIndex >= selfSquads.size())
				{
					squadIndex -= selfSquads.size();
					allied = false;
				}

				const SimSquad& simSquad = allied ? selfSquads.at(squadIndex) : enemySquads.at(squadIndex);

				Action action = child->getParentAction();
				double score = child->getRewardSum() / child->getNumVisits();
				Color color = Colors::Blue;

				if(score <= -2.0)
					color = Colors::Black;
				else if(score >= 2.0)
					color = Colors::White;
				else if(score >= 1.0)
					color = Colors::Green;
				else if(score <= -1.0)
					color = Colors::Red;
				else if(score > 0.0)
					color = Colors::Teal;
				else if(score < 0.0)
					color = Colors::Orange;
				else if(score == 0.0)
					color = Colors::Yellow;

				if(action.destination() == simSquad.getGraphNode())
				{
					Broodwar->drawCircleMap(simSquad.getUnits().getPosition(), 10, color, true);
					Broodwar->drawTextMap(simSquad.getUnits().getPosition() - Position(0, 15), "%0.5f", score);
				}
				else
				{
					Broodwar->drawLineMap(simSquad.getUnits().getPosition(), action.destination()->getPosition(), color);
					Broodwar->drawTextMap(action.destination()->getPosition(), "%0.5f", score);
				}
			}

			node = node->bestChild_maxVisitCount("UCT::onFrame() visualization");
		}
	}
	*/
	if(verbose)
	{
		LOG_MESSAGE("STARTING UCT DUMP!")
		UCTNode* node = rootNode;
		GameState state(trueGameState);

		state.dumpInfo();

		// replay the most-visited line on a scratch copy of the true state
		while(node)
		{
			node = node->bestChild_maxVisitCount("UCT::onFrame() verbose");

			if(!node)
			{
				LOG_MESSAGE("No more child")
				break;
			}

			LOG_MESSAGE(StringBuilder() << "Avg node reward = " << node->getRewardSum() / node->getNumVisits())
			Action action = node->getParentAction();
			state.applyAction(action);

			if(action.type() == RESOLVE_ACTIONS)
			{
				state.dumpInfo();
			}

			if(state.isTerminal())
			{
				LOG_MESSAGE("Terminal state!")
				break;
			}
		}
	}

	verbose = false;
#endif
}

/*
	Propagates the result of one playout from the given node up to the root.

	Every node on the path gets its visit count incremented. A non-root node
	receives the reward negated if the move leading to it was made by the
	enemy (i.e. the enemy was to move in its parent), and unchanged otherwise.
	The same signed reward is accumulated, together with a play count, in the
	MAST statistics of the action that led to the node. The root itself always
	receives the unmodified reward.

	node:	node at which the playout started; must not be null
	reward:	playout result from our own point of view
*/
void UCT::backPropagate(UCTNode* node, const double reward)
{
	UCTNode* parent = node->getParentNode();

	while(parent)
	{
		node->incrementVisitCount();

		Action playedAction = node->getParentAction();

		// the reward is expressed from our perspective, so flip it for enemy moves
		const bool enemyMoved = (parent->getPlayerToMove() == Broodwar->enemy());
		const double signedReward = enemyMoved ? -reward : reward;

		node->addReward(signedReward);

		if(mastMap.count(playedAction) == 0)
		{
			// first time this action is seen: start its statistics
			mastMap.insert(std::pair<Action, Pair<double, int>>(playedAction, Pair<double, int>(signedReward, 1)));
		}
		else
		{
			Pair<double, int>& stats = mastMap.at(playedAction);
			stats.first += signedReward;
			stats.second += 1;
		}

		node = parent;
		parent = node->getParentNode();
	}

	// node is now the root of the tree
	node->addReward(reward);
	node->incrementVisitCount();
}

/*
	Plays the given game state out until it is terminal and returns the reward
	of the final state, computed relative to the true game state and its hit
	points at the time of the call.

	Moves are chosen using the MAST statistics gathered in backPropagate():
	either a random action, or the available action with the highest average
	reward so far. An available action without any statistics yet is preferred
	over all others.

	gameState:	state to play out; modified in place
*/
const double UCT::defaultPolicy(GameState& gameState)
{
	static const double MAST_EPSILON = 0.7;

	GameStateHitPoints startStateHP = trueGameState.computeHitPoints();

	while(!gameState.isTerminal())
	{
		/*	NOT USING MAST
		gameState.applyAction(gameState.getRandomAction());
		*/

		/*
			USING MAST
		*/
		// NOTE(review): intended as "epsilon probability of playing random", but the
		// draw is an integer in 0..10 (or 0..9, depending on whether RNG::randomInt's
		// upper bound is inclusive) compared with <=, so the real probability is
		// 8/11 or 8/10 rather than 0.7 - confirm before changing, the bot may be tuned to it.
		if(RNG::randomInt(0, 10) / 10.0 <= MAST_EPSILON)
		{
			gameState.applyAction(gameState.getRandomAction());
			continue;
		}

		// otherwise play the best action according to the statistics
		const vector<Action>& actions = gameState.getActions();
		Action bestAction = Action::NULL_ACTION;
		double highestAverage = 0.0;
		bool haveCandidate = false;		// explicit flag instead of a "lowest double" sentinel, so that
										// negative averages can never be rejected by the initial value

		for(auto it = actions.begin(); it != actions.end(); ++it)
		{
			const Action& action = *it;

			if(mastMap.count(action) == 0)
			{
				// never played before: try it right away
				bestAction = action;
				break;
			}

			const Pair<double, int>& stats = mastMap.at(action);
			const double average = stats.first / stats.second;

			if(!haveCandidate || average > highestAverage)
			{
				haveCandidate = true;
				highestAverage = average;
				bestAction = action;
			}
		}

		// stays Action::NULL_ACTION if the state offered no actions at all
		gameState.applyAction(bestAction);
	}

	return gameState.getReward(trueGameState, startStateHP);
}

/*
	Expands the given node by one child.

	Takes the node's next untried action and lets the node generate the child
	for it, passing gameStateCopy along. Returns the new child.
*/
UCTNode* UCT::expand(UCTNode* node)
{
	Action untriedAction = node->nextUntriedAction();
	UCTNode* newChild = node->generateChild(gameStateCopy, untriedAction);
	return newChild;
}

/*
	Selection / expansion step of one search iteration.

	Descends from the given node for as long as gameStateCopy is not terminal.
	The first node on the way that still has untried actions is expanded and
	the new child is returned. Otherwise the child picked by bestChild_UCB1()
	is followed and its action is applied to gameStateCopy.

	Returns the node from which the playout should start; this is the last
	node reached if gameStateCopy became terminal.

	NOTE(review): bestChild_UCB1() is assumed to never return null for a node
	without untried actions in a non-terminal state - confirm in UCTNode.
*/
UCTNode* UCT::treePolicy(UCTNode* node)
{
	while(!gameStateCopy.isTerminal())
	{
		if(node->hasUntriedActions())
			return expand(node);

		// fully expanded: move down and keep the state copy in sync with the tree
		node = node->bestChild_UCB1();
		gameStateCopy.applyAction(node->getParentAction());
	}

	return node;
}

#ifdef MAASCRAFT_DEBUG
void UCT::logRandomTraversal()
{
	LOG_MESSAGE("STARTING RANDOM TRAVERSAL!")
	UCTNode* node = rootNode;
	GameState state(trueGameState);

	state.dumpInfo();

	while(node->getChildren().size())
	{
		node = node->getRandomChild();

		Action action = node->getParentAction();
		state.applyAction(node->getParentAction());

		if(action.type() == RESOLVE_ACTIONS)
			state.dumpInfo();

		if(state.isTerminal())
		{
			LOG_MESSAGE("Terminal state!")
			break;
		}
	}

	while(!state.isTerminal())
	{
		Action action = state.getRandomAction();
		state.applyAction(action);

		if(action.type() == RESOLVE_ACTIONS)
			state.dumpInfo();
	}
	LOG_MESSAGE("Terminal state!")
}

void UCT::setVerbose()
{
	verbose = true;
}
#endif

#pragma warning( pop )

#endif // ARMY_MANAGER_UCT