#include "UCT.h"

#define UCB_C 5   // this is the constant that regulates exploration vs exploitation, it must be tuned for each domain
#define EPSILON 0.2  // e-greedy strategy: probability of selecting a random child instead of the best-scoring one

/// Builds a searcher: stores the search limits and the evaluation function,
/// and resets every statistics counter to zero.
///
/// @param maxDepth           depth at which the tree policy stops descending
/// @param ef                 evaluation function applied after each simulation (pointer is stored, not copied)
/// @param maxSimulations     number of search iterations run per call to start()
/// @param maxSimulationTime  amount added to the current game time to obtain the simulation cut-off time
UCT::UCT(int maxDepth, EvaluationFunction* ef, int maxSimulations, int maxSimulationTime)
{
    // --- search configuration ---
    this->_ef                = ef;
    this->_maxDepth          = maxDepth;
    this->_maxSimulations    = maxSimulations;
    this->_maxSimulationTime = maxSimulationTime;

    // --- statistics, all starting from zero ---
    this->_maxDepthReached        = 0;
    this->_maxDepthRolloutReached = 0;
    this->_totalBranching         = 0;
    this->_numBranching           = 0;
    this->_maxMissplacedUnits     = 0;
}

/// Entry point of the search: runs UCT from the given state and returns the
/// actions chosen for `player`.
///
/// The opponent is taken to be `!player`. The search runs `_maxSimulations`
/// iterations and simulations are cut off at game time
/// `gs._time + _maxSimulationTime`. The elapsed time is accumulated into
/// `stats`, indexed by the total number of unit groups in the state, and a
/// summary line is logged.
///
/// @param player  side to search for
/// @param gs      game state to search from (taken by value)
/// @return        actions selected by startUCT(); empty if the root has no children
playerActions_t UCT::start(bool player, GameState gs)
{
    timerUCT.start();
    //LOG("Starting UCT... " << player);
    playerActions_t bestAction = startUCT(gs, player, !player, _maxSimulations, gs._time+_maxSimulationTime);
    //LOG("UCT: " << bestMove.evaluation);

    // save stats
    unsigned int numberOfGroups = gs.friendlyUnits.size() + gs.enemyUnits.size();
    double timeUCT = timerUCT.stopAndGetTime();
    stats.groupTime[numberOfGroups] += timeUCT;
    stats.groupFrequency[numberOfGroups]++;

    // _numBranching stays at zero when no action generator was ever built
    // (for instance when the root state is already game over), so the average
    // must not be computed blindly: dividing by zero here would crash or log garbage.
    LOG("Groups: " << numberOfGroups << " seconds: " << timeUCT << " depth: " << _maxDepthReached << " maxRolloutDepth: " << _maxDepthRolloutReached 
        << " avgBranching: " << (_numBranching != 0 ? _totalBranching/_numBranching : 0));
//     LOG("   TreePolicy: " << timerTreePolicy.getElapsedTime());
//     LOG("   RolloutPolicy: " << timerRolloutPolicy.getElapsedTime());
//     LOG("      ActionGenerator: " << timerRolloutPolicyActionGenerator.getElapsedTime());
//     LOG("          constructor: " << timerAGconstructor.getElapsedTime());
//     LOG("          conditions: " << timerAGconditions.getElapsedTime());
//     LOG("      RandomAction: " << timerRolloutPolicyRandomAction.getElapsedTime());
//     LOG("      Execute: " << timerRolloutPolicyExecute.getElapsedTime());
//     LOG("      MoveForward: " << timerRolloutPolicyMoveForward.getElapsedTime());
//     LOG("      Evaluate: " << timerRolloutPolicyEvaluate.getElapsedTime());
//     LOG("   BackupPolicy: " << timerBackupPolicy.getElapsedTime());

    return bestAction;
}

/// Runs the search loop and returns the actions leading to the most visited
/// child of the root.
///
/// Each of the T iterations: selects a leaf with the tree policy, runs a
/// simulation from it, evaluates the resulting state (discounted by 0.99 per
/// 10 units of elapsed game time) and adds that value to every node on the
/// path from the leaf up to the root.
///
/// @param gs          state placed at the root of the tree
/// @param maxplayer   player the evaluation is computed for
/// @param minplayer   the opposing player
/// @param T           number of iterations to run
/// @param cutOffTime  game time limit handed to simulate()
/// @return            actions of the most visited root child, or a default-constructed
///                    playerActions_t when the root has no children
playerActions_t UCT::startUCT(GameState gs, bool maxplayer, bool minplayer, int T, int cutOffTime)
{
    gameNode_t* const root = newGameNode(maxplayer, gs);

    // spend the whole computational budget
    for (int iteration = 0; iteration < T; ++iteration) {
        // tree policy: descend to the node to simulate from
        gameNode_t* selected = bestChild(root, maxplayer);
        if (selected->isEmpty) continue;

        // default policy: play out from a copy of the selected state
        // NOTE(review): simulate() as defined in this file receives its GameState by
        // value, so rolloutState may come back unchanged unless GameState copies share
        // their underlying data -- confirm that the evaluation below really sees the
        // simulated state.
        GameState rolloutState = selected->gs;
        simulate(rolloutState, cutOffTime, selected->nextPlayerInSimultaneousNode);

        // evaluation, discounted by how far in game time the state is from the root
        const int elapsed = rolloutState._time - gs._time;
        const double reward = _ef->evaluate(maxplayer, minplayer, rolloutState, 0)*pow(0.99,elapsed/10.0);

        // backup: credit every node from the selected one up to the root
        for (gameNode_t* visited = selected; visited != NULL; visited = visited->parent) {
            visited->totalEvaluation += reward;
            visited->totalVisits++;
        }
    }

    // pick the root child with the highest visit count (first one wins ties)
    int winnerIdx = -1;
    for (unsigned int idx = 0; idx < root->children.size(); ++idx) {
        if (winnerIdx == -1 || root->children[idx]->totalVisits > root->children[winnerIdx]->totalVisits) {
            winnerIdx = idx;
        }
    }

    playerActions_t chosen;
    if (winnerIdx != -1) {
        chosen = root->actions[winnerIdx];
    }

    // the whole tree is released before returning
    deleteNode(root);

    return chosen;
}

/// Plays random actions on `gs` until the game is over or the game time
/// reaches `time`.
///
/// On every step the player to move is chosen as follows: if both players can
/// act they alternate, starting with `nextSimultaneous`; if only one can act,
/// that one moves. A random action from that player's ActionGenerator is
/// executed and the state is moved forward. Branching statistics and the
/// deepest rollout seen so far are updated along the way.
///
/// NOTE(review): `gs` is received by value, so the caller's GameState object
/// itself is not the one being advanced here -- confirm this is intended.
///
/// @param gs                state to simulate on
/// @param time              game time at which the simulation stops
/// @param nextSimultaneous  player (1 or 0) that moves first when both can act
void UCT::simulate(GameState gs, int time, __int8 nextSimultaneous)
{
    __int8 alternatingPlayer = nextSimultaneous;
    int steps = 0;

    while (!gs.gameover() && gs._time < time) {
        // stays default-constructed when nobody can act
        ActionGenerator generator;

        const bool playerOneCanAct  = gs.canExecuteAnyAction(true);
        const bool playerZeroCanAct = gs.canExecuteAnyAction(false);

        // decide who moves: 1, 0, or -1 for nobody
        int mover = -1;
        if (playerOneCanAct && playerZeroCanAct) {
            // both can move: alternate
            mover = (int)alternatingPlayer;
            alternatingPlayer = 1 - alternatingPlayer;
        } else if (playerOneCanAct) {
            mover = 1;
        } else if (playerZeroCanAct) {
            mover = 0;
        }

        if (mover == -1) {
            DEBUG("New game node: This should not have happened...");
        } else {
            generator = ActionGenerator(gs, mover==1);
            _totalBranching += (long)generator._size;
            _numBranching++;
        }

        // random action, applied to the state, then advance the state
        playerActions_t chosen = generator.getRandomAction();
        gs.execute(chosen, generator._player);
        gs.moveForward();

        ++steps;
    }

    // remember the longest rollout
    if (_maxDepthRolloutReached < steps) _maxDepthRolloutReached = steps;
}

/// Allocates and initialises a tree node for the state `gs`.
///
/// The node starts with zero visits and zero accumulated evaluation. Its depth
/// is 0 for a root (`parent == NULL`) and parent depth + 1 otherwise; the
/// "next player in a simultaneous node" marker is 0 for a root and inherited
/// from the parent otherwise. `player` is left at -1 when the state has a
/// winner or is game over, or when neither player can act; otherwise it is set
/// to the player to move and the node's action generator is created.
///
/// @param maxplayer  player the search is run for
/// @param gs         game state stored in the node
/// @param parent     parent node, or NULL for the root
/// @return           the newly allocated node; the caller owns it
UCT::gameNode_t* UCT::newGameNode(bool maxplayer, GameState gs, gameNode_t* parent)
{
    gameNode_t* node = new gameNode_t;
    node->gs = gs;
    node->parent = parent;
    node->isEmpty = false;
    node->totalVisits = 0;
    node->totalEvaluation = 0;
    node->player = -1;

    if (parent != NULL) {
        node->depth = parent->depth+1;
        node->nextPlayerInSimultaneousNode = parent->nextPlayerInSimultaneousNode;
    } else {
        node->depth = 0;
        node->nextPlayerInSimultaneousNode = 0;
    }

    // track the deepest node ever created
    if (_maxDepthReached < node->depth) _maxDepthReached = (int)node->depth;

    // finished game: the node is a leaf and needs no action generator
    if (gs.winner()!=-1 || gs.gameover()) {
        return node;
    }

    const bool maxCanAct = gs.canExecuteAnyAction(maxplayer);
    const bool minCanAct = gs.canExecuteAnyAction(!maxplayer);

    if (maxCanAct && minCanAct) {
        // both can move: take the pending turn and flip it for the next node
        node->player = node->nextPlayerInSimultaneousNode;
        node->nextPlayerInSimultaneousNode = 1 - node->nextPlayerInSimultaneousNode;
    } else if (maxCanAct) {
        node->player = (int)maxplayer;
    } else if (minCanAct) {
        node->player = (int)!maxplayer;
    }

    if (node->player == -1) {
        DEBUG("New game node: This should not have happened...");
        return node;
    }

    node->moveGenerator = ActionGenerator(gs, node->player==1);
    _totalBranching += (long)node->moveGenerator._size;
    _numBranching++;

    return node;
}

/// Frees `node` together with its whole subtree.
/// Children are released first (in order), then the node itself.
///
/// @param node  root of the subtree to delete; it is dereferenced, so it must not be NULL
void UCT::deleteNode(gameNode_t* node)
{
    unsigned int idx = 0;
    while (idx < node->children.size()) {
        deleteNode(node->children[idx]);
        ++idx;
    }
    delete node;
}

/// Tree policy: walks down from `currentNode` and returns the node the next
/// simulation should start from.
///
/// The descent stops and returns the current node when:
///  - the node's depth has reached `_maxDepth`;
///  - the node has no player to move (`player == -1`);
///  - the node has no children yet: its first child is created and returned
///    (or the node itself if no action could be generated);
///  - the bandit policy selects nothing (it returns a placeholder with
///    `isEmpty` set, which is deleted here);
///  - the bandit policy returns the very node it was given, which is how it
///    reports a failure to create a child.
///
/// The last case used to re-enter the same node recursively, with nothing
/// bounding the recursion depth. The descent is now an explicit loop that
/// stops as soon as the policy makes no progress.
///
/// @param currentNode  node to start the descent from
/// @param maxplayer    player the search is run for (forwarded to new nodes)
/// @return             selected node; never a placeholder
UCT::gameNode_t* UCT::bestChild(gameNode_t* currentNode, bool maxplayer)
{
    for (;;) {
        // Cut the tree policy at a predefined depth
        if ( currentNode->depth >= _maxDepth ) return currentNode;

        // if gameover return this node
        if (currentNode->player == -1) return currentNode;

        // if no children yet, create one
        if (currentNode->children.empty()) {
            if ( !currentNode->moveGenerator.hasMoreActions() ) {
                DEBUG("Error crating first child");
                LOG(currentNode->gs.toString());
                LOG("Player: " << currentNode->player);
                return currentNode;
            }
            playerActions_t action = currentNode->moveGenerator.getNextAction();
            if (action.empty()) {
                DEBUG("Error generating action for first child");
                LOG(currentNode->gs.toString());
                LOG("Player: " << currentNode->player);
                return currentNode;
            }
            currentNode->actions.push_back(action);
            GameState gs2 = currentNode->gs.cloneIssue(action, currentNode->moveGenerator._player);
            gameNode_t* node = newGameNode(maxplayer, gs2, currentNode);
            currentNode->children.push_back(node);
            return node;
        }

        // Bandit policy
        //gameNode_t* best = UCB(currentNode, maxplayer);
        gameNode_t* best = eGreedy(currentNode, maxplayer);

        // The policy could not move below this node: simulate from here
        // instead of asking again for the same node.
        if (best == currentNode) return currentNode;

        if (best->isEmpty) {
            // No more leafs because this node has no children!
            delete best;
            return currentNode;
        }

        // keep descending from the selected child
        currentNode = best;
    }
}

/// Epsilon-greedy bandit policy over the children of `currentNode`.
///
/// With probability EPSILON an action index is drawn uniformly among all the
/// actions the node's generator reports (`moveGenerator._size`, clamped to the
/// unsigned int range). An index that falls on an already created child
/// selects that child; any other index expands a new child from the
/// generator's next action. Otherwise the child with the best average
/// evaluation is selected, the average being negated when the node's player
/// is 0.
///
/// Fixes with respect to the previous version:
///  - the random index was computed as `rand() % (totalChildren-1)`, which is a
///    modulo by zero when the generator reports exactly one action, wraps
///    around when it reports none, and can never draw the last index;
///  - a new child is only requested when the generator still has actions
///    (same check the other expansion sites perform);
///  - when no new child can be created, an existing child is picked at random
///    instead of returning `currentNode`, which made the caller re-enter the
///    same node.
///
/// @param currentNode  node whose children are considered
/// @param maxplayer    player the search is run for (forwarded to new nodes)
/// @return             the selected or newly created child. When nothing could
///                     be selected, a heap-allocated placeholder with `isEmpty`
///                     set to true is returned and the caller must delete it.
UCT::gameNode_t* UCT::eGreedy(gameNode_t* currentNode, bool maxplayer)
{
    const unsigned int createdChildren = (unsigned int)currentNode->children.size();
    gameNode_t* best = NULL;

    const float randomNumber = (float)(std::rand() % 100) / (float)100;
    //LOG("Random number: " << randomNumber);

    if (randomNumber < EPSILON) { // select random
        // number of actions reported by the generator, clamped to unsigned int
        unsigned int totalChildren;
        const double maxUnsignedInt = std::numeric_limits<unsigned int>::max();
        if ( currentNode->moveGenerator._size > maxUnsignedInt ) totalChildren = (unsigned int)maxUnsignedInt;
        else totalChildren = (unsigned int)currentNode->moveGenerator._size;
        // never fewer than the children that already exist
        if (totalChildren < createdChildren) totalChildren = createdChildren;

        if (totalChildren > 0) {
            // uniform index over [0, totalChildren)
            const unsigned int randomChoice = std::rand() % totalChildren;

            if (randomChoice < createdChildren) {
                // pick one of the created children
                best = currentNode->children[randomChoice];
            } else if (currentNode->moveGenerator.hasMoreActions()) {
                // create a new child
                playerActions_t action = currentNode->moveGenerator.getNextAction();
                if (!action.empty()) {
                    currentNode->actions.push_back(action);
                    GameState gs2 = currentNode->gs.cloneIssue(action, currentNode->moveGenerator._player);
                    best = newGameNode(maxplayer, gs2, currentNode);
                    currentNode->children.push_back(best);
                } else {
                    DEBUG("Error generating action for a child");
                }
            }

            // expansion was not possible: fall back to an existing child
            if (best == NULL && createdChildren > 0) {
                best = currentNode->children[std::rand() % createdChildren];
            }
        }
    } else { // select max reward
        double bestScore = 0;
        for (unsigned int i = 0; i < createdChildren; i++) {
            gameNode_t* child = currentNode->children[i];
            // average evaluation; children are expected to have been visited at least once
            double tmpScore = child->totalEvaluation / child->totalVisits;
            if (currentNode->player == 0) { // min node
                tmpScore = -tmpScore;
            }
            if (best == NULL || tmpScore > bestScore) {
                best = child;
                bestScore = tmpScore;
            }
        }
    }

    if (best == NULL) {
        // nothing to select: hand back the "empty" placeholder the caller checks for
        best = new gameNode_t;
        best->isEmpty = true;
    }
    return best;
}

/// UCB bandit policy over the children of `currentNode`.
///
/// While the node's generator still has actions, the next action is expanded
/// into a new child and that child is returned (so every action is tried
/// before any child is revisited). Once the generator is exhausted, the child
/// with the highest nodeValue() is returned.
///
/// @param currentNode  node whose children are considered
/// @param maxplayer    player the search is run for (forwarded to new nodes)
/// @return             the new or best child; `currentNode` itself if the generator
///                     produced an empty action; or, when there are no children at
///                     all, a heap-allocated placeholder with `isEmpty` set to true
UCT::gameNode_t* UCT::UCB(gameNode_t* currentNode, bool maxplayer)
{
    // WARNING: with a very high branching factor the search never gets past this depth,
    // because unexpanded actions always take precedence
    if ( currentNode->moveGenerator.hasMoreActions() ) {
        playerActions_t action = currentNode->moveGenerator.getNextAction();
        if (action.empty()) {
            DEBUG("Error generating action for a child");
            return currentNode;
        }
        currentNode->actions.push_back(action);
        GameState gs2 = currentNode->gs.cloneIssue(action, currentNode->moveGenerator._player);
        gameNode_t* expanded = newGameNode(maxplayer, gs2, currentNode);
        currentNode->children.push_back(expanded);
        return expanded;
    }

    // every action already has a child: keep the one with the top value
    gameNode_t* winner = new gameNode_t;
    winner->isEmpty = true;
    double topScore = 0;

    unsigned int idx = 0;
    while (idx < currentNode->children.size()) {
        gameNode_t* candidate = currentNode->children[idx];
        ++idx;

        const double score = nodeValue(candidate);
        if (!winner->isEmpty && !(score > topScore)) continue;

        // the placeholder is dropped as soon as a real child takes its place
        if (winner->isEmpty) delete winner;
        winner = candidate;
        topScore = score;
    }

    return winner;
}


/// UCB value of a child node: `UCB_C * exploitation + exploration`.
///
///  - exploitation: the node's average evaluation, negated unless the parent's
///    player is 1 (so that a higher value is always better for the player
///    choosing at the parent);
///  - exploration: the UCB1 term `sqrt(ln(parentVisits) / nodeVisits)`.
///
/// Fixes with respect to the previous version:
///  - the exploration term was written `sqrt(log(parentVisits/nodeVisits))`:
///    the division was inside the logarithm, and would additionally truncate
///    if the visit counters are integral;
///  - a node that has never been visited caused a division by zero; it now
///    gets the largest possible value so that it is tried first.
///
/// @param node  child to score; its `parent` is dereferenced and must not be NULL
/// @return      the node's value; larger means more attractive
double UCT::nodeValue(UCT::gameNode_t* node)
{
    // unvisited child: highest priority, and nothing to average yet
    if (node->totalVisits == 0) return std::numeric_limits<double>::max();

    const double nodeVisits = (double)node->totalVisits;
    const double parentVisits = (double)node->parent->totalVisits;

    double exploitation = node->totalEvaluation / nodeVisits;
    if (node->parent->player != 1) { // min node: lower evaluations are better
        exploitation = - exploitation;
    }

    // ln() is only meaningful here above one visit; below that there is no bonus
    double exploration = 0;
    if (parentVisits > 1) {
        exploration = sqrt(log(parentVisits) / nodeVisits);
    }

    return UCB_C*exploitation + exploration;
}