/*	-----------------------------------------------------------------------------
	M A A S C R A F T

	StarCraft: Brood War - Bot

	Author: Dennis Soemers
	Maastricht University
	-----------------------------------------------------------------------------
*/

/*
	Implementation of UCTNode.h
*/

#include "CommonIncludes.h"

#ifdef ARMY_MANAGER_UCT

using namespace BWAPI;
using namespace std;

#include <cmath>

#include "MathConstants.h"
#include "RNG.h"
#include "UCTNode.h"

/*
	Creates a tree node for the given game state.

	The list of untried actions and the player to move are copied out of
	gameState at construction time; later changes to gameState do not affect
	this node. Reward and visit statistics start at zero, and the node starts
	without children.

	parentAction	the action that led from parentNode to this node
	parentNode		the node this one was expanded from (presumably nullptr
					for the root — confirm with callers)
*/
UCTNode::UCTNode(GameState& gameState, Action parentAction, UCTNode* parentNode)
	: untriedActions(gameState.getActions())
	, parentAction(parentAction)
	, playerToMove(gameState.getPlayerToMove())
	, rewardSum(0.0)
	, numVisits(0)
	, parentNode(parentNode)
{
}

/*
	Destroys this node and, recursively, the whole subtree below it.

	Children are added by generateChild(), which allocates them with new, so
	this node deletes every pointer stored in children.

	NOTE(review): children are owned through raw pointers; copying a UCTNode
	would lead to a double delete. Confirm that copying is disabled in UCTNode.h.
*/
UCTNode::~UCTNode()
{
	for(UCTNode* ownedChild : children)
	{
		// delete on a nullptr is a no-op, so no explicit null check is needed
		delete ownedChild;
	}

	children.clear();
}

/*
	Returns the child with the highest average reward (rewardSum / numVisits).

	Returns nullptr when this node has no children. With exactly one child, that
	child is returned unconditionally. Otherwise, children that have never been
	visited are skipped, because their average reward is undefined (0 / 0). If
	no child has been visited, nullptr is returned. Ties are broken in favour of
	the child that appears first in children.
*/
UCTNode* UCTNode::bestChild_maxAvgReward() const
{
#ifdef MAASCRAFT_DEBUG
	if(children.size() == 0)
	{
		LOG_WARNING("UCTNode::bestChild_maxAvgReward() children.size() == 0")
		LOG_WARNING(StringBuilder() << "	Num visits = " << numVisits)
		LOG_WARNING(StringBuilder() << "	Score = " << rewardSum)
		return nullptr;
	}
#endif

	if(children.size() == 1)
		return children.front();

	UCTNode* bestChild = nullptr;
	double maxAvgReward = 0.0;

	for(auto it = children.begin(); it != children.end(); ++it)
	{
		UCTNode* child = *it;
		const int childVisits = child->getNumVisits();

		// 0 / 0 would produce NaN, which compares false against everything and
		// would corrupt the running maximum once it became the seed below
		if(childVisits <= 0)
			continue;

		const double avgReward = child->getRewardSum() / childVisits;

		// Seed from the first visited child instead of a sentinel constant, so
		// negative average rewards can never leave bestChild == nullptr
		if(bestChild == nullptr || avgReward > maxAvgReward)
		{
			maxAvgReward = avgReward;
			bestChild = child;
		}
	}

	return bestChild;
}

/*
	Returns the child that has been visited the most often.

	Returns nullptr when this node has no children. With exactly one child,
	that child is returned directly. Ties are broken in favour of the child that
	appears first in children, because only a strictly larger visit count
	replaces the current best.
*/
UCTNode* UCTNode::bestChild_maxVisitCount() const
{
#ifdef MAASCRAFT_DEBUG
	if(children.empty())
	{
		LOG_WARNING("UCTNode::bestChild_maxVisitCount() children.size() == 0")
		LOG_WARNING(StringBuilder() << "	Num visits = " << numVisits)
		LOG_WARNING(StringBuilder() << "	Score = " << rewardSum)
		return nullptr;
	}
#endif

	if(children.size() == 1)
	{
		return children.front();
	}

	UCTNode* mostVisited = nullptr;
	int highestCount = MathConstants::MIN_INT;

	for(UCTNode* candidate : children)
	{
		const int count = candidate->getNumVisits();

		if(count <= highestCount)
		{
			continue;
		}

		highestCount = count;
		mostVisited = candidate;
	}

	return mostVisited;
}

/*
	Same as bestChild_maxVisitCount(), but in debug builds it also logs the name
	of the caller when this node unexpectedly has no children.

	caller	free-form identifier of the calling code; used only for debug logging

	Returns the most-visited child, or nullptr when there are no children.
*/
UCTNode* UCTNode::bestChild_maxVisitCount(string caller) const
{
#ifdef MAASCRAFT_DEBUG
	if(children.size() == 0)
	{
		LOG_WARNING("UCTNode::bestChild_maxVisitCount() children.size() == 0")
		LOG_WARNING(StringBuilder() << "Caller = " << caller)
		LOG_WARNING(StringBuilder() << "	Num visits = " << numVisits)
		LOG_WARNING(StringBuilder() << "	Score = " << rewardSum)
		return nullptr;
	}
#endif

	// The selection logic is identical to the parameterless overload; delegate
	// to it instead of maintaining a second copy of the loop
	return bestChild_maxVisitCount();
}

/*
	Selects the child that maximises the UCB1 value

		avgReward(child) + C * sqrt( ln(numVisits) / numVisits(child) )

	with exploration constant C = 3 * sqrt(2).

	Returns nullptr when this node has no children. With exactly one child,
	that child is returned directly. An unvisited child is returned immediately,
	following the UCT convention that its UCB1 value is infinite. Ties are broken
	in favour of the child that appears first in children.
*/
UCTNode* UCTNode::bestChild_UCB1() const
{
	static const double EXPLORATION_CONSTANT = 3.0 * sqrt(2.0);

#ifdef MAASCRAFT_DEBUG
	if(children.size() == 0)
	{
		LOG_WARNING("UCTNode::bestChild_UCB1() children.size() == 0")
		return nullptr;
	}
#endif

	if(children.size() == 1)
		return children.front();

	// log(0) is -infinity, which would make every exploration term NaN; a parent
	// without recorded visits therefore contributes no exploration bonus
	const double logTerm = (numVisits > 0) ? log(double(numVisits)) : 0.0;

	UCTNode* bestChild = nullptr;
	double bestUCB = 0.0;

	for(auto it = children.begin(); it != children.end(); ++it)
	{
		UCTNode* child = *it;
		const int childVisits = child->getNumVisits();

		// 1 / 0 would produce infinity and 0 * infinity = NaN, which can never win
		// a comparison; pick the unvisited child explicitly instead
		if(childVisits <= 0)
			return child;

		const double one_over_visits = 1.0 / childVisits;
		const double ucb = (child->getRewardSum() * one_over_visits) + (EXPLORATION_CONSTANT * sqrt(logTerm * one_over_visits));

		// Seed from the first child instead of a sentinel constant, so negative
		// UCB values (e.g. from negative rewards) can never leave bestChild == nullptr
		if(bestChild == nullptr || ucb > bestUCB)
		{
			bestUCB = ucb;
			bestChild = child;
		}
	}

	return bestChild;
}

void UCTNode::addReward(const double reward)
{
	rewardSum += reward;
}

/*
	Expands this node by applying actionToApply and creating a child for the
	resulting state.

	Side effect: gameState is modified in place (the action is applied to it
	before the child is constructed, so the child snapshots the post-action
	state). The new child is heap-allocated and owned by this node, which
	deletes it in its destructor.

	Returns the newly created child.
*/
UCTNode* UCTNode::generateChild(GameState& gameState, Action actionToApply)
{
	gameState.applyAction(actionToApply);

	children.push_back(new UCTNode(gameState, actionToApply, this));
	return children.back();
}

/*
	Read-only view of this node's children (still owned by this node).
*/
const vector<UCTNode*>& UCTNode::getChildren() const
{
	return this->children;
}

const int UCTNode::getNumVisits() const
{
	return numVisits;
}

/*
	The action that was passed in at construction as leading to this node.
*/
Action UCTNode::getParentAction() const
{
	return this->parentAction;
}

/*
	The node this one was created from (non-owning pointer). Presumably
	nullptr for the root — confirm with the code that constructs the root.
*/
UCTNode* UCTNode::getParentNode() const
{
	return this->parentNode;
}

/*
	The player to move in this node's state, captured from the game state
	when the node was constructed.
*/
Player UCTNode::getPlayerToMove() const
{
	return this->playerToMove;
}

/*
	Returns a uniformly chosen child, or nullptr when this node has no children
	(consistent with the bestChild_* selectors).

	Without the empty check, children.size() - 1 wraps around to the largest
	size_t value on an empty vector, and children.at() then throws
	std::out_of_range.
*/
UCTNode* UCTNode::getRandomChild() const
{
	if(children.empty())
	{
#ifdef MAASCRAFT_DEBUG
		LOG_WARNING("UCTNode::getRandomChild() children.size() == 0")
#endif
		return nullptr;
	}

	return children.at(RNG::randomInt(0, children.size() - 1));
}

const double UCTNode::getRewardSum() const
{
	return rewardSum;
}

/*
	True while at least one action taken from the game state at construction
	has not yet been handed out by nextUntriedAction().
*/
const bool UCTNode::hasUntriedActions() const
{
	return untriedActions.size() > 0;
}

/*
	Records one more visit to this node.
*/
void UCTNode::incrementVisitCount()
{
	++numVisits;
}

/*
	Removes and returns the last remaining untried action.

	Returns Action::NULL_ACTION when no untried actions remain. The empty check
	is performed in every build (only the warning is debug-only), because
	calling back()/pop_back() on an empty std::vector is undefined behaviour.
*/
Action UCTNode::nextUntriedAction()
{
	if(untriedActions.empty())
	{
#ifdef MAASCRAFT_DEBUG
		LOG_WARNING("UCTNode::nextUntriedAction() called for empty untriedActions vector")
#endif
		return Action::NULL_ACTION;
	}

	Action action(untriedActions.back());
	untriedActions.pop_back();

	return action;
}

#endif // ARMY_MANAGER_UCT