#include "CombatSimulatorBasic.h"
#include "InformationManager.h"

CombatSimulatorBasic::CombatSimulatorBasic()
	: unitDPF(nullptr),
	unitHP(nullptr)
{
}

CombatSimulatorBasic::CombatSimulatorBasic(std::vector<DpsLearner::DPFtype> *newUnitDPF, 
	std::vector<DpsLearner::HPtype> *newUnitHP)
	: unitDPF(newUnitDPF),
	unitHP(newUnitHP)
{
}

// Aggregates the damage-per-frame and hit points of every unit group in `army`.
// DPF is split into air-only, ground-only and "both" (usable against either layer);
// HP is split into air (flyers) and ground.
// When learned tables were supplied (unitDPF != nullptr) they are used directly;
// otherwise values come from BWAPI's static weapon and unit data.
CombatSimulatorBasic::combatStats_t CombatSimulatorBasic::getCombatStats(const UnitStateVector &army)
{
	combatStats_t stats;

	if (unitDPF != nullptr) {
		// Learned statistics, indexed by unit type id.
		for (auto group : army) {
			const auto& dpf = (*unitDPF)[group->unitTypeId];
			if (dpf.both > 0) {
				stats.bothDPF += group->numUnits * dpf.both;
			} else {
				stats.airDPF += group->numUnits * dpf.air;
				stats.groundDPF += group->numUnits * dpf.ground;
			}

			const auto& hp = (*unitHP)[group->unitTypeId];
			stats.airHP += group->numUnits * hp.air;
			stats.groundHP += group->numUnits * hp.ground;
		}
		return stats;
	}

	// Static BWAPI data: DPF of a weapon is its damage divided by its cooldown.
	auto damagePerFrame = [](BWAPI::WeaponType weapon) {
		return (double)weapon.damageAmount() / (double)weapon.damageCooldown();
	};

	for (auto group : army) {
		BWAPI::UnitType type(group->unitTypeId);
		const bool attacksAir = type.airWeapon().damageAmount() > 0;
		const bool attacksGround = type.groundWeapon().damageAmount() > 0;

		if (attacksAir && attacksGround) {
			// Dual-purpose unit: its weaker weapon is counted as damage usable against either layer.
			stats.bothDPF += group->numUnits * (std::min)(damagePerFrame(type.airWeapon()), damagePerFrame(type.groundWeapon()));
		} else {
			if (attacksAir)
				stats.airDPF += group->numUnits * damagePerFrame(type.airWeapon());
			if (attacksGround)
				stats.groundDPF += group->numUnits * damagePerFrame(type.groundWeapon());
			// Firebats and Zealots have two weapons, but BWAPI's damage only accounts for one,
			// so their ground DPF is added a second time.
			if (type == BWAPI::UnitTypes::Terran_Firebat || type == BWAPI::UnitTypes::Protoss_Zealot)
				stats.groundDPF += group->numUnits * damagePerFrame(type.groundWeapon());
		}

		// Shields count as hit points; flyers go to the air pool, everything else to ground.
		auto& hpPool = type.isFlyer() ? stats.airHP : stats.groundHP;
		hpPool += group->numUnits * (double)(type.maxShields() + type.maxHitPoints());
	}

	return stats;
}


// Frames needed to inflict `hp` damage at a rate of `dpf` damage per frame.
// Returns 0 when there is nothing to destroy (hp not positive) and INT_MAX, used
// as an "impossible" sentinel, when there is something to destroy but no damage.
static double framesToDestroy(double hp, double dpf)
{
	if (hp > 0) return (dpf == 0) ? INT_MAX : hp / dpf;
	return 0;
}

// Frames an attacker needs to wipe out a defender's air and ground units.
// The attacker's dual-purpose ("both") DPF is added to whichever layer is slower
// to destroy with layer-specific damage alone. When one layer cannot be destroyed
// at all (INT_MAX) but the other can, the destroyable layer decides the result;
// otherwise the slower layer does.
static int framesToDestroyArmy(double defenderAirHP, double defenderGroundHP,
	double attackerAirDPF, double attackerGroundDPF, double attackerBothDPF)
{
	double airFrames = framesToDestroy(defenderAirHP, attackerAirDPF);
	double groundFrames = framesToDestroy(defenderGroundHP, attackerGroundDPF);
	if (attackerBothDPF > 0) {
		if (airFrames > groundFrames)
			airFrames = framesToDestroy(defenderAirHP, attackerAirDPF + attackerBothDPF);
		else
			groundFrames = framesToDestroy(defenderGroundHP, attackerGroundDPF + attackerBothDPF);
	}

	if (airFrames == INT_MAX && groundFrames > 0) return static_cast<int>(groundFrames);
	if (groundFrames == INT_MAX && airFrames > 0) return static_cast<int>(airFrames);
	return static_cast<int>((std::max)(airFrames, groundFrames));
}

// Estimates how many frames a fight between the two armies lasts.
// Units that can attack both air and ground put their damage against the layer
// that takes longer to kill ("method 2"; the older "method 1" ignored them).
// @param friendStats     aggregated stats of the friendly army
// @param enemyStats      aggregated stats of the enemy army
// @param timeToKillEnemy out: frames the friendly army needs to destroy the enemy
// @return the shorter of the two kill times, i.e. when one army is wiped out
int CombatSimulatorBasic::getCombatLength(combatStats_t friendStats, combatStats_t enemyStats, int &timeToKillEnemy)
{
	// Both directions go through the same helper so the friend- and enemy-side
	// computations stay symmetric.
	timeToKillEnemy = framesToDestroyArmy(enemyStats.airHP, enemyStats.groundHP,
		friendStats.airDPF, friendStats.groundDPF, friendStats.bothDPF);
	const int timeToKillFriend = framesToDestroyArmy(friendStats.airHP, friendStats.groundHP,
		enemyStats.airDPF, enemyStats.groundDPF, enemyStats.bothDPF);
	return (std::min)(timeToKillEnemy, timeToKillFriend);
}



// Convenience overload: estimated combat length, in frames, between the friendly
// and enemy sides of `army`. The per-side kill time is computed but discarded.
int CombatSimulatorBasic::getCombatLength(GameState::army_t* army)
{
	int unusedTimeToKillEnemy = 0;
	return getCombatLength(getCombatStats(army->friendly), getCombatStats(army->enemy), unusedTimeToKillEnemy);
}

// Applies the outcome of `frames` frames of combat between the units in
// `armyInCombat` to the full army lists in `army`. frames == 0 means "fight until
// one side is destroyed". If the fight ends within the window, the loser's
// in-combat groups are removed entirely and the winner takes the loser's damage
// for the full combat length; otherwise both sides take partial losses.
void CombatSimulatorBasic::simulateCombat(GameState::army_t* armyInCombat, GameState::army_t* army, int frames)
{
	if (!canSimulate(armyInCombat, army)) return;

	combatStats_t friendStats = getCombatStats(armyInCombat->friendly);
	combatStats_t enemyStats = getCombatStats(armyInCombat->enemy);

	int timeToKillEnemy;
	const int combatLength = getCombatLength(friendStats, enemyStats, timeToKillEnemy);

	const bool fightIsDecided = (frames == 0 || frames > combatLength);
	if (!fightIsDecided) {
		// Both armies survive the window: each side absorbs the other's damage for `frames`.
		removeSomeUnitsFromArmy(&armyInCombat->friendly, &army->friendly, &enemyStats, (double)frames);
		removeSomeUnitsFromArmy(&armyInCombat->enemy, &army->enemy, &friendStats, (double)frames);
		return;
	}

	// The enemy is the loser whenever the combat ends when the enemy dies
	// (this includes ties); otherwise the friendly army loses.
	const bool enemyLoses = (combatLength == timeToKillEnemy);
	UnitStateVector* loserInCombat = enemyLoses ? &armyInCombat->enemy : &armyInCombat->friendly;
	UnitStateVector* loserUnits = enemyLoses ? &army->enemy : &army->friendly;
	UnitStateVector* winnerInCombat = enemyLoses ? &armyInCombat->friendly : &armyInCombat->enemy;
	UnitStateVector* winnerUnits = enemyLoses ? &army->friendly : &army->enemy;
	combatStats_t* loserStats = enemyLoses ? &enemyStats : &friendStats;

	removeAllUnitsFromArmy(loserInCombat, loserUnits);
	removeSomeUnitsFromArmy(winnerInCombat, winnerUnits, loserStats, (double)combatLength);
}

// Deletes every group listed in `loserInCombat` and erases it from `loserUnits`.
// Groups are matched by pointer identity; a group missing from `loserUnits` is
// logged as an error and skipped. `loserInCombat` itself is not modified, so it
// still holds the (now deleted) pointers afterwards.
void CombatSimulatorBasic::removeAllUnitsFromArmy(UnitStateVector* loserInCombat, UnitStateVector* loserUnits)
{
	for (unitGroup_t* group : *loserInCombat) {
		UnitStateVector::iterator match = std::find(loserUnits->begin(), loserUnits->end(), group);
		if (match == loserUnits->end()) {
			LOG("[ERROR] Group unit not found");
			continue;
		}
		delete *match;
		loserUnits->erase(match);
	}
}

// Applies `frames` frames of the loser's damage to the groups in `winnerInCombat`,
// in list order, removing whole units from the matching groups in `winnerUnits`.
// The damage is split into three integer pools (air-only, ground-only, shared).
// Each group is hit first by the pool specific to its layer (air for flyers,
// ground otherwise) if that pool still has damage, else by the shared pool. A
// group whose total HP (shields + hit points) is covered is deleted and erased.
// Otherwise a layer-specific pool is merged into the shared pool, and the shared
// pool either finishes the group or removes as many whole units as it can.
void CombatSimulatorBasic::removeSomeUnitsFromArmy(UnitStateVector* winnerInCombat, UnitStateVector* winnerUnits, combatStats_t* loserStats, double frames)
{
	int airPool = (int)(loserStats->airDPF * frames);
	int groundPool = (int)(loserStats->groundDPF * frames);
	int sharedPool = (int)(loserStats->bothDPF * frames);

	// Frees `group` and erases it from the winner's unit list (matched by pointer).
	auto destroyGroup = [&](unitGroup_t* group) {
		UnitStateVector::iterator found = std::find(winnerUnits->begin(), winnerUnits->end(), group);
		if (found != winnerUnits->end()) {
			delete *found;
			winnerUnits->erase(found);
		} else {
			LOG("[ERROR] Group unit not found");
		}
	};

	for (unitGroup_t* group : *winnerInCombat) {
		// Nothing left to absorb.
		if (airPool <= 0 && groundPool <= 0 && sharedPool <= 0) break;

		BWAPI::UnitType type(group->unitTypeId);
		const int hpPerUnit = type.maxShields() + type.maxHitPoints();
		const int groupHP = group->numUnits * hpPerUnit;

		int& pool = type.isFlyer()
			? (airPool > 0 ? airPool : sharedPool)
			: (groundPool > 0 ? groundPool : sharedPool);

		if (pool >= groupHP) {
			destroyGroup(group);
			pool -= groupHP;
			continue;
		}

		if (&pool != &sharedPool) {
			// The layer-specific pool cannot finish this group alone: hand what is left to the shared pool.
			sharedPool += pool;
			pool = 0;
		}

		if (sharedPool >= groupHP) {
			destroyGroup(group);
			sharedPool -= groupHP;
		} else {
			const int unitsKilled = (int)trunc((double)sharedPool / (double)hpPerUnit);
			group->numUnits -= unitsKilled;
			sharedPool -= unitsKilled * hpPerUnit;
		}
	}
}
