#include "GameState.h"
#include "InformationManager.h"

using namespace BWAPI;

// Default constructor. No explicit initialization is performed here; the unit
// lists are (re)built on each call to UpdateGameState().
GameState::GameState() {}

void GameState::UpdateGameState()
{
	PlayerSet players = Broodwar->getPlayers();
	std::vector<unitState_t>* listToSave;
	Unit* unit;
	Position targetPosition;

	for (PlayerSet::iterator it = players.begin(); it != players.end(); it++) {
		Player* player = *it;
		UnitSet units = player->getUnits();
		
		if ( player == Broodwar->self() ) listToSave = &friendlyUnits;
		else if (Broodwar->self()->isEnemy(player)) listToSave = &enemyUnits;
		else continue;
		
		listToSave->clear();

		for (UnitSet::iterator it2 = units.begin(); it2 != units.end(); it2++) {
			unit = *it2;
			// ignore buildings
			if (unit->getType().isBuilding()) continue;
			// ignore units training
			if (!unit->isCompleted()) continue;

			unitState_t unitState;
			unitState.unitTypeId = unit->getType().getID();
			unitState.numUnits = 1;
			unitState.regionId = informationManager->_regionIdMap[unit->getTilePosition().x()][unit->getTilePosition().y()];
			targetPosition = unit->getTargetPosition();
			if (targetPosition != Positions::None) {
				TilePosition tile = (TilePosition)targetPosition;
				unitState.targetRegionId = informationManager->_regionIdMap[tile.x()][tile.y()];
			} else {
				unitState.targetRegionId = 0;
			}

			// unitState.orderId = unit->getOrder().getID();
			unitState.orderId = getAbstractOrderID(unit->getOrder().getID(), unitState.regionId, unitState.targetRegionId);

			addUnit(unitState, *listToSave);
		}
	}

	GameState::expectedEndFrame();
}

// Merges `unit` into `vectorState`.
// If an entry with the same unit type, region, abstract order and target
// region already exists, that entry's numUnits is incremented by one (the
// incoming unit.numUnits is not added). Otherwise `unit` is appended as-is.
// NOTE(review): getCombatTime reads numUnits into a uint8_t; if the field is
// that narrow it wraps after 255 merged units — confirm in GameState.h.
void GameState::addUnit(unitState_t unit, std::vector<unitState_t>& vectorState)
{
	for (std::vector<unitState_t>::iterator entry = vectorState.begin(); entry != vectorState.end(); ++entry) {
		const bool sameGroup = entry->unitTypeId == unit.unitTypeId
			&& entry->regionId == unit.regionId
			&& entry->orderId == unit.orderId
			&& entry->targetRegionId == unit.targetRegionId;
		if (sameGroup) {
			++entry->numUnits;
			return;
		}
	}
	// No matching group yet: start a new one.
	vectorState.push_back(unit);
}

// Calculate the expected end frame for each order
void GameState::expectedEndFrame()
{
	combatList.clear();
	regionsInCombat.clear();
	std::vector<unitState_t>* vectorState;
	for (int i=0; i<2; i++) {
		if (i==0) vectorState = &friendlyUnits;
		else vectorState = &enemyUnits;
		for (unsigned int j=0; j < vectorState->size(); j++) {
			vectorState->at(j).endFrame = -1;
			if (vectorState->at(j).orderId == abstractOrder::Unknown) {
				vectorState->at(j).endFrame = -2;
			} else if (vectorState->at(j).orderId == abstractOrder::Move) {
				vectorState->at(j).endFrame = GameState::getMoveTime(vectorState->at(j).unitTypeId, vectorState->at(j).regionId, vectorState->at(j).targetRegionId);
			} else if (vectorState->at(j).orderId == abstractOrder::Attack) {
				// keep track of the regions with units in attack state
				regionsInCombat.insert(vectorState->at(j).regionId);
			}
			// add all units to the combat list since some units can attack units that they aren't in attack state
			if (i==0) combatList[vectorState->at(j).regionId].first.insert(& vectorState->at(j));
			else combatList[vectorState->at(j).regionId].second.insert(& vectorState->at(j));
		}
	}
	// TODO some units aren't attacking in the same region!!!
	if (combatList.size() > 0) {
		GameState::getCombatTime();
	}
}

// Estimates the frame at which units of type `unitTypeId` moving from the
// region/chokepoint `regionId` to `targetRegionId` will arrive. The estimate
// uses the ground distance between the two centers and the non-upgraded top
// speed.
//
// Returns:
//   - the current frame when source and target are the same id;
//   - -3 when either center is unknown, or when the unit type has no positive
//     top speed (no estimate is possible);
//   - current frame + distance / speed otherwise.
int GameState::getMoveTime(int unitTypeId, int regionId, int targetRegionId)
{
	if (regionId == targetRegionId)
		return BWAPI::Broodwar->getFrameCount();

	// Get the center of the region or chokepoint
	BWAPI::Position pos1 = getCenterRegionId(regionId);
	BWAPI::Position pos2 = getCenterRegionId(targetRegionId);
	if (pos1 == BWAPI::Positions::None || pos2 == BWAPI::Positions::None)
		return -3;

	BWAPI::UnitType unitType(unitTypeId);
	double speed = unitType.topSpeed(); // non-upgraded top speed in pixels per frame
	// A zero speed would divide by zero below, and converting the resulting
	// infinity to int is undefined behavior.
	if (speed <= 0)
		return -3;

	int distance = BWTA::getGroundDistance2(BWAPI::TilePosition(pos1), BWAPI::TilePosition(pos2)); // distance in pixels

	return BWAPI::Broodwar->getFrameCount() + static_cast<int>(static_cast<double>(distance) / speed);
}

// TODO move this to BWTA
// Returns the center of the BWTA region registered under `regionId`.
// If there is none, it returns the center of the chokepoint registered under
// that id. If neither exists, it returns BWAPI::Positions::None.
// NOTE(review): if _regionFromID / _chokePointFromID are std::maps, operator[]
// inserts a NULL entry for every unknown id — consider find() instead.
BWAPI::Position GameState::getCenterRegionId(int regionId)
{
	BWTA::Region* region = informationManager->_regionFromID[regionId];
	if (region != NULL)
		return region->getCenter();

	BWTA::Chokepoint* chokepoint = informationManager->_chokePointFromID[regionId];
	if (chokepoint != NULL)
		return chokepoint->getCenter();

	return BWAPI::Positions::None;
}

void GameState::getCombatTime()
{
	//for (combatList_t::iterator it=combatList.begin(); it!=combatList.end(); ++it) {
	for (std::set<int>::iterator it=regionsInCombat.begin(); it!=regionsInCombat.end(); ++it) {
		// get friendly stats
		combatStats_t friendStats;
		friendStats.airDPF = friendStats.airHP = friendStats.groundDPF = friendStats.groundHP = 0;
		//std::set<unitState_t*> friendSet = (*it).second.first;
		std::set<unitState_t*> friendSet = combatList[(*it)].first;
		for (std::set<unitState_t*>::iterator it2=friendSet.begin(); it2!=friendSet.end(); ++it2) {
			BWAPI::UnitType unitType( (*it2)->unitTypeId );
			uint8_t numUnits = (*it2)->numUnits;
			GameState::getCombatStats(friendStats, unitType, numUnits);
		}
		// get enemy stats
		combatStats_t enemyStats;
		enemyStats.airDPF = enemyStats.airHP = enemyStats.groundDPF = enemyStats.groundHP = 0;
		//std::set<unitState_t*> enemySet = (*it).second.second;
		std::set<unitState_t*> enemySet = combatList[(*it)].second;
		for (std::set<unitState_t*>::iterator it2=enemySet.begin(); it2!=enemySet.end(); ++it2) {
			BWAPI::UnitType unitType( (*it2)->unitTypeId );
			uint8_t numUnits = (*it2)->numUnits;
			GameState::getCombatStats(enemyStats, unitType, numUnits);
		}
		// calculate end combat time
		double timeToKillEnemyAir = (enemyStats.airHP>0)? (friendStats.airDPF == 0)? 99999 : enemyStats.airHP/friendStats.airDPF : 0;
		double timeToKillEnemyGround = (enemyStats.groundHP>0)? (friendStats.groundDPF == 0)? 99999 : enemyStats.groundHP/friendStats.groundDPF : 0;
		double timeToKillFriendAir = (friendStats.airHP>0)? (enemyStats.airDPF == 0 )? 99999 : friendStats.airHP/enemyStats.airDPF : 0;
		double timeToKillFriendGround = (friendStats.groundHP>0)? (enemyStats.groundDPF == 0)? 99999 : friendStats.groundHP/enemyStats.groundDPF : 0;

		double timeToKillEnemy = (std::max)(timeToKillEnemyAir,timeToKillEnemyGround);
		double timeToKillFriend = (std::max)(timeToKillFriendAir,timeToKillFriendGround);
		int combatEnd = (int)(std::min)(timeToKillEnemy,timeToKillFriend) + BWAPI::Broodwar->getFrameCount();
		//combatEnd = BWAPI::Broodwar->getFrameCount() + timeToKillEnemyGround;

		// update end combat time
		for (std::set<unitState_t*>::iterator it2=friendSet.begin(); it2!=friendSet.end(); ++it2) {
			if ((*it2)->orderId == abstractOrder::Attack) (*it2)->endFrame = combatEnd;
		}
		for (std::set<unitState_t*>::iterator it2=enemySet.begin(); it2!=enemySet.end(); ++it2) {
			if ((*it2)->orderId == abstractOrder::Attack) (*it2)->endFrame = combatEnd;
		}
	}
}

// Adds the combined combat strength of `numUnits` units of `unitType` into
// `combatStats`, using BWAPI's static (non-upgraded) data:
//   airDPF / groundDPF  += numUnits * damageAmount / damageCooldown for each
//                          weapon that deals damage;
//   airHP  (flyers) or groundHP (everything else)
//                       += numUnits * (maxShields + maxHitPoints).
void GameState::getCombatStats(combatStats_t & combatStats, BWAPI::UnitType unitType, uint8_t numUnits )
{
	BWAPI::WeaponType airWeapon = unitType.airWeapon();
	BWAPI::WeaponType groundWeapon = unitType.groundWeapon();

	if (airWeapon.damageAmount() > 0)
		combatStats.airDPF += numUnits * ((double)airWeapon.damageAmount() / airWeapon.damageCooldown());
	if (groundWeapon.damageAmount() > 0)
		combatStats.groundDPF += numUnits * ((double)groundWeapon.damageAmount() / groundWeapon.damageCooldown());

	// Firebats and Zealots have two weapons but BWAPI reports only one, so
	// their ground damage is counted a second time.
	const bool hasTwoGroundWeapons = (unitType == UnitTypes::Terran_Firebat || unitType == UnitTypes::Protoss_Zealot);
	if (hasTwoGroundWeapons)
		combatStats.groundDPF += numUnits * ((double)groundWeapon.damageAmount() / groundWeapon.damageCooldown());

	const double durabilityPerUnit = (double)(unitType.maxShields() + unitType.maxHitPoints());
	if (!unitType.isFlyer()) {
		combatStats.groundHP += numUnits * durabilityPerUnit;
	} else {
		combatStats.airHP += numUnits * durabilityPerUnit;
	}
}

// Maps a concrete BWAPI order id onto the coarse abstract order used by the
// game-state model:
//   mineral / gas harvesting orders       -> abstractOrder::Unknown
//                                            (Mineral / Gas mappings disabled)
//   Move, Follow, ComputerReturn          -> abstractOrder::Move
//   AttackUnit, AttackMove, AttackTile    -> abstractOrder::Attack if the
//                                            target region equals the current
//                                            one, abstractOrder::Move otherwise
//   Repair, MedicHeal1                    -> abstractOrder::Heal
//   anything else                         -> abstractOrder::Unknown
int GameState::getAbstractOrderID(int orderId, int currentRegion, int targetRegion)
{
	// Mineral harvesting (was: return abstractOrder::Mineral;)
	if (orderId == BWAPI::Orders::MoveToMinerals ||
		orderId == BWAPI::Orders::WaitForMinerals ||
		orderId == BWAPI::Orders::MiningMinerals ||
		orderId == BWAPI::Orders::Harvest3 ||
		orderId == BWAPI::Orders::Harvest4 ||
		orderId == BWAPI::Orders::ReturnMinerals) {
		return abstractOrder::Unknown;
	}

	// Gas harvesting (was: return abstractOrder::Gas;)
	if (orderId == BWAPI::Orders::MoveToGas ||
		orderId == BWAPI::Orders::Harvest1 ||
		orderId == BWAPI::Orders::Harvest2 ||
		orderId == BWAPI::Orders::WaitForGas ||
		orderId == BWAPI::Orders::HarvestGas ||
		orderId == BWAPI::Orders::ReturnGas) {
		return abstractOrder::Unknown;
	}

	if (orderId == BWAPI::Orders::Move ||
		orderId == BWAPI::Orders::Follow ||
		orderId == BWAPI::Orders::ComputerReturn) {
		return abstractOrder::Move;
	}

	if (orderId == BWAPI::Orders::AttackUnit ||
		orderId == BWAPI::Orders::AttackMove ||
		orderId == BWAPI::Orders::AttackTile) {
		// An attack aimed at another region is still travelling there.
		if (currentRegion != targetRegion)
			return abstractOrder::Move;
		return abstractOrder::Attack;
	}

	if (orderId == BWAPI::Orders::Repair ||
		orderId == BWAPI::Orders::MedicHeal1) {
		return abstractOrder::Heal;
	}

	return abstractOrder::Unknown;
}

// Abstracts `order` with getAbstractOrderID() and returns the entry of
// abstractOrder::name for the resulting abstract order id.
const std::string GameState::getAbstractOrderName(BWAPI::Order order, int currentRegion, int targetRegion)
{
	return abstractOrder::name[getAbstractOrderID(order.getID(), currentRegion, targetRegion)];
}

// Returns the entry of abstractOrder::name for abstract order id `abstractId`.
// No range check is done here; callers must pass a valid abstract order id.
const std::string GameState::getAbstractOrderName(int abstractId)
{
	const std::string orderName = abstractOrder::name[abstractId];
	return orderName;
}