#include "AbstractLayer.h"
#include "ABCD.h"
#include "UCT.h"
#include "EvaluationFunctionBasic.h"

AbstractLayer::AbstractLayer()
{
    // Start from an empty abstract state: discard any unit groups still stored in the
    // shared game state.
    informationManager->gameState.enemyUnits.clear();
    informationManager->gameState.friendlyUnits.clear();

    // Re-populate it with every enemy unit and with our own buildings only.
    // Friendly non-worker units are added separately by addSquadToGameState().
    informationManager->gameState.addAllEnemyUnits();
    informationManager->gameState.addSelfBuildings();
}

// Add every non-worker unit of `squad` to the abstract game state and remember which
// abstract group the squad belongs to.
//
// addFriendlyUnit() returns the abstract group id each unit was placed in; the squad is
// mapped to the most frequent of those ids. On ties the smallest id wins, because
// std::map iterates keys in ascending order and only a strictly larger count replaces the
// current best. One group id can be shared by several squads.
// A squad with no non-worker units (empty, or workers only) is not registered at all.
void AbstractLayer::addSquadToGameState(SquadAgent* squad)
{
    std::map<unsigned int, unsigned int> groupIdFrequency;

    // Add each unit to the game state and count how many landed in each abstract group.
    for (CombatUnitSet::const_iterator i = squad->_squadUnits.begin(); i != squad->_squadUnits.end(); ++i) {
        if ((*i)->_unit->getType().isWorker()) continue; // workers are skipped (not added)
        const unsigned int abstractGroupID = informationManager->gameState.addFriendlyUnit((*i)->_unit);
        ++groupIdFrequency[abstractGroupID];
    }

    // No unit was added, so there is no group to pick. Without this guard bestGroup would
    // never be assigned and the squad would be stored under an indeterminate key.
    if (groupIdFrequency.empty()) return;

    // Assign the squad to the most common group id.
    unsigned int bestGroup = groupIdFrequency.begin()->first;
    unsigned int maxFrequency = 0;
    for (std::map<unsigned int, unsigned int>::const_iterator i = groupIdFrequency.begin(); i != groupIdFrequency.end(); ++i) {
        if (i->second > maxFrequency) {
            bestGroup = i->first;
            maxFrequency = i->second;
        }
    }

    // One group id can have many squads.
    _idToSquad[bestGroup].insert(squad);
}

// Execute search and return best targetPosition for each squad
std::map<SquadAgent*, BWAPI::Position> AbstractLayer::searchBestOrders()
{
    // compare game states
    if (!informationManager->lastGameState.gameover()) {
        int misplacedUnits = 0;
        int totalUnits = 0;
        //int totalBuildings = 0;
        //LOG(informationManager->lastGameState.toString());
        //LOG(informationManager->gameState.toString());
        std::vector<GameState::unitState_t> friendlyUnits1 = informationManager->lastGameState.friendlyUnits;
        std::vector<GameState::unitState_t> friendlyUnits2 = informationManager->gameState.friendlyUnits;
        for (unsigned int i=0; i < friendlyUnits1.size(); i++) {
            //BWAPI::UnitType unitType(friendlyUnits1[i].unitTypeId);
            for (unsigned int j=0; j < friendlyUnits2.size(); j++) {
                if (friendlyUnits1[i].unitTypeId == friendlyUnits2[j].unitTypeId &&
                    friendlyUnits1[i].regionId == friendlyUnits2[j].regionId) {
                    misplacedUnits += abs(friendlyUnits1[i].numUnits-friendlyUnits2[j].numUnits);
                    int maxUnits = std::max(friendlyUnits1[i].numUnits,friendlyUnits2[j].numUnits);
                    totalUnits += maxUnits;
                    //if (unitType.isBuilding()) totalBuildings += maxUnits;
                    break;
                }
                // if we didn't find a match, it's a misplaced unit
                if (j == friendlyUnits2.size()-1) {
                    misplacedUnits += friendlyUnits1[i].numUnits;
                    totalUnits += friendlyUnits1[i].numUnits;
                    //if (unitType.isBuilding()) totalBuildings += friendlyUnits1[i].numUnits;
                }
            }
        }
        for (unsigned int i=0; i < friendlyUnits2.size(); i++) {
            //BWAPI::UnitType unitType(friendlyUnits2[i].unitTypeId);
            for (unsigned int j=0; j < friendlyUnits1.size(); j++) {
                if (friendlyUnits1[j].unitTypeId == friendlyUnits2[i].unitTypeId &&
                    friendlyUnits1[j].regionId == friendlyUnits2[i].regionId) {
                        break;
                }
                // if we didn't find a match, it's a misplaced unit
                if (j == friendlyUnits1.size()-1) {
                    misplacedUnits += friendlyUnits2[i].numUnits;
                    totalUnits += friendlyUnits2[i].numUnits;
                    //if (unitType.isBuilding()) totalBuildings += friendlyUnits2[i].numUnits;
                }
            }
        }
        LOG("Correct Units: " << totalUnits-misplacedUnits << " of " << totalUnits << " jaccard: " << float(totalUnits-misplacedUnits)/float(totalUnits));
    }

    // now that we have all the units in the game state, compute expected end frame
    informationManager->gameState.expectedEndFrame();
    // and forward until next point decision
    // we can move Forward or set our orders to 0 to find the best inmidate action
    //informationManager->gameState.moveForward();
    //LOG(informationManager->gameState.toString());
    informationManager->gameState.resetFriendlyActions();
    //LOG(informationManager->gameState.toString());

    // Search algorithm
    std::string algorithm = LoadConfigString("high_level_search", "algorithm", "ABCD");
    playerActions_t bestActions;
    EvaluationFunctionBasic ef;
    if (algorithm == "ABCD") {
        int depth = LoadConfigInt("ABCD", "depth", 1);
        ABCD searchAlg = ABCD(depth, &ef);
        bestActions = searchAlg.start(true, informationManager->gameState);
    } else if (algorithm == "MCTSCD") {
        int depth = LoadConfigInt("MCTSCD", "depth", 1);
        int iterations = LoadConfigInt("MCTSCD", "iterations");
        int maxSimTime = LoadConfigInt("MCTSCD", "max_simulation_time");
        UCT searchAlg = UCT(depth, &ef, iterations, maxSimTime);
        bestActions = searchAlg.start(true, informationManager->gameState);
    } else {
        // get random actions
        ActionGenerator moveGenerator = ActionGenerator(informationManager->gameState, true);
        bestActions = moveGenerator.getRandomAction();
    }

    // update last gameState
    informationManager->lastGameState = informationManager->gameState;
    informationManager->lastGameState.execute(bestActions, true);
    informationManager->lastGameState.moveForward(HIGH_LEVEL_REFRESH);
    informationManager->lastGameState.mergeGroups();

    //LOG("Best actions: ");
    std::map<SquadAgent*, BWAPI::Position> bestOrders;
    int groupID;
	uint8_t orderId;
	uint8_t targetRegionId;
    BWAPI::Position targetPosition;
    for(playerActions_t::const_iterator i = bestActions.begin(); i!=bestActions.end(); ++i) {
//         groupID = (*i).first;
//         orderId = (*i).second.first;
//         targetRegionId = (*i).second.second;
        // Unpacking playerActions_t (00000000 TTTTTTTT AAAAAAAA GGGGGGGG)
        //LOG("Uncoding Action");
        int playerAction = (*i);
        groupID = playerAction & 0xFF; // preserve the 7 righmost bits
        orderId = (playerAction >> 8) & 0xFF;
        targetRegionId = playerAction >> 16;
        std::set<SquadAgent*> squadSet = _idToSquad[groupID];
//         if (squadSet.empty()) {
//             LOG( groupID << " action: " << (int)orderId << " region: " << (int)targetRegionId << " " << informationManager->gameState.getAbstractOrderName((int)orderId) << " OFF");
//         } else {
//             LOG( groupID << " action: " << (int)orderId << " region: " << (int)targetRegionId << " " << informationManager->gameState.getAbstractOrderName((int)orderId) << " ON");
//         }

        for(std::set<SquadAgent*>::const_iterator squad = squadSet.begin(); squad!=squadSet.end(); ++squad) {
            targetPosition = informationManager->gameState.getCenterRegionId((int)targetRegionId);
            bestOrders[*squad] = targetPosition;
//             LOG("  - Group ID: " << groupID << ", action: " << informationManager->gameState.getAbstractOrderName((int)orderId) << " region: " << (int)targetRegionId 
//                 << "(" << targetPosition.x() << "," << targetPosition.y() << ") squad (" << *squad << ")");
        }
    }

    return bestOrders;
}

// True when the abstract game state holds at least one friendly group and those groups
// are not all buildings.
bool AbstractLayer::hasFriendlyUnits()
{
    // Nothing friendly at all.
    if (informationManager->gameState.friendlyUnits.empty()) {
        return false;
    }
    // Something friendly exists; it only counts if it is not buildings-only.
    return !informationManager->gameState.hasOnlyBuildings(informationManager->gameState.friendlyUnits);
}

float AbstractLayer::getEvaluation()
{
    EvaluationFunctionBasic ef;
    return ef.evaluate(true, false, informationManager->gameState, 0);
}

void AbstractLayer::printBranchingStats()
{
   // branching factor start for root node
    ActionGenerator actions = ActionGenerator(informationManager->gameState, true);
    double highLevelFriendly = actions.getHighLevelFriendlyActions();
    double highLevelEnemy = actions.getHighLevelEnemyActions();
    double highLevelTotal = highLevelFriendly * highLevelEnemy;
    double lowLevelFriendly = actions.getLowLevelFriendlyActions();
    double lowLevelEnemy = actions.getLowLevelEnemyActions();
    double lowLevelTotal = lowLevelFriendly * lowLevelEnemy;
    double sparcraftFriendly = actions.getSparcraftFriendlyActions();
    double sparcraftEnemy = actions.getSparcraftEnemyActions();
    double sparcraftTotal = sparcraftFriendly * sparcraftEnemy;
    LOG("LL friend: " << lowLevelFriendly << " LL enemy: " << lowLevelEnemy << " LL total: " << lowLevelTotal <<
        " Spar friend: " << sparcraftFriendly << " Spar enemy: " << sparcraftEnemy << " Spar total: " << sparcraftTotal <<
        " HL friend: " << highLevelFriendly << " HL enemy: " << highLevelEnemy << " HL total: " << highLevelTotal);
}