#include "DpsLearner.h"
#include "comutil.h" // for _bstr_t (string conversions)

#define MAX_UNIT_TYPE BWAPI::UnitTypes::Enum::MAX

// Builds the lookup tables used by the learner:
//  - getUnitTypeID: unit type name (as written in the combat logs) -> BWAPI type
//  - unitTypeDPF, typePriority, bordaCount: empty learned statistics, one slot per type ID
//  - unitTypeDPFbwapi[attacker][target]: theoretical damage per frame from BWAPI weapon data
//  - unitDPFbwapi / unitHPbwapi: per-type theoretical DPF (air/ground/both) and max HP + shields
DpsLearner::DpsLearner():
armyDestroyed(0),
armyReinforcement(0),
armyPeace(0),
armyGameEnd(0),
moreThanTwoArmies(0),
armyKilledWithoutKills(0)
{
	// first we need to populate the structure to get the UnitType enum from string
	for (auto unitType : BWAPI::UnitTypes::allUnitTypes()) {
		getUnitTypeID[unitType.c_str()] = unitType;
	}

	// init unitTypeDPF
	unitTypeDPF.resize(MAX_UNIT_TYPE);
	for (int i = 0; i < MAX_UNIT_TYPE; ++i) {
		unitTypeDPF[i].resize(MAX_UNIT_TYPE);
	}
	// init priority
	typePriority.resize(MAX_UNIT_TYPE);
	bordaCount.resize(MAX_UNIT_TYPE);

	// Fraction of a weapon's damage dealt to a target of the given size, as in
	// Brood War (and BWAPI's own damage calculation):
	//   Concussive: Small 100%, Medium 50%, Large 25%
	//   Explosive:  Small 50%,  Medium 75%, Large 100%
	//   any other damage type: 100%
	auto sizeMultiplier = [](BWAPI::DamageType damageType, BWAPI::UnitSizeType targetSize) -> double {
		if (damageType == BWAPI::DamageTypes::Concussive) {
			if (targetSize == BWAPI::UnitSizeTypes::Large) return 0.25;
			if (targetSize == BWAPI::UnitSizeTypes::Medium) return 0.5;
		} else if (damageType == BWAPI::DamageTypes::Explosive) {
			if (targetSize == BWAPI::UnitSizeTypes::Small) return 0.5;
			if (targetSize == BWAPI::UnitSizeTypes::Medium) return 0.75;
		}
		return 1.0;
	};

	// In the case of Firebats and Zealots, the damage returned by BWAPI is not right,
	// since they have two weapons: their ground damage must be doubled.
	auto hasDoubleGroundAttack = [](BWAPI::UnitType type) {
		return type == BWAPI::UnitTypes::Terran_Firebat || type == BWAPI::UnitTypes::Protoss_Zealot;
	};

	// cache for unitTypeDPF from BWAPI: theoretical DPF of type i against type j
	unitTypeDPFbwapi.resize(MAX_UNIT_TYPE);
	for (int i = 0; i < MAX_UNIT_TYPE; ++i) {
		BWAPI::UnitType unitType(i);
		unitTypeDPFbwapi[i].resize(MAX_UNIT_TYPE);
		for (int j = 0; j < MAX_UNIT_TYPE; ++j) {
			BWAPI::UnitType enemyType(j);
			double DPF = 0.0;

			if (!enemyType.isFlyer() && unitType.groundWeapon().damageAmount() > 0) {
				BWAPI::WeaponType weapon = unitType.groundWeapon();
				DPF = (double)weapon.damageAmount() * sizeMultiplier(weapon.damageType(), enemyType.size());
				if (hasDoubleGroundAttack(unitType)) DPF *= 2;
				DPF = DPF / weapon.damageCooldown();
			} else if (enemyType.isFlyer() && unitType.airWeapon().damageAmount() > 0) {
				BWAPI::WeaponType weapon = unitType.airWeapon();
				DPF = (double)weapon.damageAmount() * sizeMultiplier(weapon.damageType(), enemyType.size());
				DPF = DPF / weapon.damageCooldown();
			}
			unitTypeDPFbwapi[i][j].DPF = DPF;
		}
	}

	// cache for unit DPF (ignoring target size) and max HP from BWAPI
	unitDPF.resize(MAX_UNIT_TYPE);
	unitDPFbwapi.resize(MAX_UNIT_TYPE);
	unitHPbwapi.resize(MAX_UNIT_TYPE);
	for (int i = 0; i < MAX_UNIT_TYPE; ++i) {
		BWAPI::UnitType unitType(i);

		if (unitType.airWeapon().damageAmount() > 0)
			unitDPFbwapi[i].air = (double)unitType.airWeapon().damageAmount() / (double)unitType.airWeapon().damageCooldown();
		if (unitType.groundWeapon().damageAmount() > 0)
			unitDPFbwapi[i].ground = (double)unitType.groundWeapon().damageAmount() / (double)unitType.groundWeapon().damageCooldown();
		if (hasDoubleGroundAttack(unitType))
			unitDPFbwapi[i].ground *= 2;

		// a unit with both weapons is rated by its weaker one
		if (unitType.airWeapon().damageAmount() > 0 && unitType.groundWeapon().damageAmount() > 0)
			unitDPFbwapi[i].both = std::min(unitDPFbwapi[i].ground, unitDPFbwapi[i].air);

		if (unitType.isFlyer())
			unitHPbwapi[i].air = unitType.maxShields() + unitType.maxHitPoints();
		else
			unitHPbwapi[i].ground = unitType.maxShields() + unitType.maxHitPoints();
	}
}


void DpsLearner::parseReplayData(std::string replayFolder)
{
	WIN32_FIND_DATA FindFileData;
	std::ifstream currentFile;
	std::string line, fileName;
	std::stringstream stream;
	std::streampos pos;

	combatInfo newCombat;
	int parserState = ParserState::COMBAT_START;
	std::string dummy;
	bool army1;
	bool combatStart;
	int lineNumber;

	// get path to the folder
	unsigned found = replayFolder.find_last_of("/\\");
	std::string path = replayFolder.substr(0, found);

	HANDLE hFind = FindFirstFile((_bstr_t)replayFolder.c_str(), &FindFileData);
	while (hFind != INVALID_HANDLE_VALUE) {

		fileName = path + FindFileData.cFileName;
// 		LOG("Analyzing file " << fileName);

		currentFile.open(fileName.c_str());
		lineNumber = 0;
		while (std::getline(currentFile, line)) {
			lineNumber++;
// 			LOG(line.c_str());
			if (parserState == ParserState::COMBAT_START) {
				if (line.find("ARMY_DESTROYED") != std::string::npos) {
					armyDestroyed++;
					// clear newCombat
					newCombat.armySize1.clear();
					newCombat.armySize2.clear();
					newCombat.armyUnits1.clear();
					newCombat.armyUnitsEnd1.clear();
					newCombat.armyUnits2.clear();
					newCombat.armyUnitsEnd2.clear();
					newCombat.kills.clear();

// 					LOG(line.c_str());
					// get start and end frames
					stream << line;
					std::getline(stream, dummy, ','); // NEW_COMBAT
					std::getline(stream, dummy, ','); newCombat.startFrame = atoi(dummy.c_str());
					std::getline(stream, dummy, ','); newCombat.endFrame = atoi(dummy.c_str());
					// LOG("start: " << startFrame << " end: " << endFrame);
					stream.str(std::string()); // reset the stream
					stream.clear();
					parserState = ParserState::ARMIES;
					army1 = true;
					combatStart = true;
// 					LOG("--------- Army 1 START");
				} else if (line.find("REINFORCEMENT") != std::string::npos) {
					armyReinforcement++;
				} else if (line.find("PEACE") != std::string::npos) {
					armyPeace++;
				} else if (line.find("GAME_END") != std::string::npos) {
// 					LOG("Army end found it at line " << lineNumber);
					armyGameEnd++;
				}
			} else if (parserState == ParserState::ARMIES) {
				if (line.find("ARMY_START") != std::string::npos) {
					while (std::getline(currentFile, line)) {
						lineNumber++;
						if (line.find("ARMY") == std::string::npos
							&& line.find("KILLS") == std::string::npos
							&& line.find("NEW_COMBAT") == std::string::npos) {
							// collect unit data
// 							LOG(line.c_str());
							stream << line;
							unitInfo unit;
							std::getline(stream, dummy, ','); unit.ID = atoi(dummy.c_str());
							std::getline(stream, unit.typeName, ',');
							std::getline(stream, dummy, ','); unit.x = atoi(dummy.c_str());
							std::getline(stream, dummy, ','); unit.y = atoi(dummy.c_str());
							std::getline(stream, dummy, ','); unit.HP = atoi(dummy.c_str());
							std::getline(stream, dummy, ','); unit.shield = atoi(dummy.c_str());
							std::getline(stream, dummy, ','); // energy
							stream.str(std::string()); // reset the stream
							stream.clear();
							unit.typeID = getUnitTypeID[unit.typeName];
							if (unit.typeID > BWAPI::UnitTypes::AllUnits) {
								LOG("Unkwon unit type: " << unit.typeID << " " << unit.typeName);
							}
							if (army1) {
								if (combatStart) {
									newCombat.armyUnits1[unit.ID] = unit;
									newCombat.armySize1[unit.typeID]++;
								} else newCombat.armyUnitsEnd1[unit.ID] = unit;
							} else {
								if (combatStart) {
									newCombat.armyUnits2[unit.ID] = unit;
									newCombat.armySize2[unit.typeID]++;
								} else newCombat.armyUnitsEnd2[unit.ID] = unit;
							}
						} else { // new army
							if (army1) {
								army1 = false;
// 								LOG("--------- Army 2");
							} else if (combatStart) {
								army1 = true;
								combatStart = false;
// 								LOG("--------- Army 1 END");
							} else {
								// sometimes we do not have kills (Scanner sweep)
								if (line.find("NEW_COMBAT") != std::string::npos){
// 									LOG("--------- END COMBAT");
// 									allCombats.push_back(newCombat);
									armyKilledWithoutKills++;
									parserState = ParserState::COMBAT_START;
								// sometimes we have more than two armies
								} else if (line.find("ARMY") != std::string::npos){
// 									LOG("--------- FOUND A THIRD ARMY");
									moreThanTwoArmies++;
									// look until the next NEW_COMBAT
									while (std::getline(currentFile, line)) {
										lineNumber++;
										if (line.find("NEW_COMBAT") != std::string::npos) break;
									}
									parserState = ParserState::COMBAT_START;
								} else {
// 									LOG("--------- KILLS");
									parserState = ParserState::KILLS;
								}
								// move line pointer to previous line
								currentFile.seekg(pos, std::ios::beg);
								lineNumber--;
								break;
							}
						}
						pos = currentFile.tellg();
					}
				}
			} else if (parserState == ParserState::KILLS) {
				while (std::getline(currentFile, line)) {
					lineNumber++;
					if (line.find("NEW_COMBAT") == std::string::npos){
// 						LOG(line.c_str());
						unitKilledInfo kill;
						stream << line;
						std::getline(stream, dummy, ','); kill.unitID = atoi(dummy.c_str());
						std::getline(stream, dummy, ','); kill.frame = atoi(dummy.c_str());
						stream.str(std::string()); // reset the stream
						stream.clear();
						newCombat.kills.push_back(kill);
					} else {
// 						LOG("--------- END COMBAT");
						allCombats.push_back(newCombat);
						parserState = ParserState::COMBAT_START;
						// move line pointer to previous line
						currentFile.seekg(pos, std::ios::beg);
						lineNumber--;
						break;
					}
					pos = currentFile.tellg();
				}
			}
		}
		currentFile.close();



		if (!FindNextFile(hFind, &FindFileData)) {
			FindClose(hFind);
			hFind = INVALID_HANDLE_VALUE;
		}
	}
}

// Learns damage-per-frame (DPF) statistics from every combat in allCombats.
//   skipTransports: skip combats where either army has a unit type with cargo
//                   space (spaceProvided() > 0).
//   onlyOneType:    only use combats where each army has a single unit type.
// Each kill is attributed to the opposing army through unitKilled(), which
// accumulates totalDamage / totalTime in unitTypeDPF[attacker][target].
// Then, for every non-hero type that can attack or cast spells (IDs up to 163):
//   - unitTypeDPF[t1][t2].DPF = totalDamage / totalTime, or the BWAPI theoretical
//     DPF when no time was recorded for that pair;
//   - unitDPF[t1] keeps the largest learned DPF for its air / ground weapon,
//     and "both" = min(ground, air) for types with both weapons;
//   - DPFcases[matchup] counts pairs with a DPF (size) and BWAPI fallbacks (noCases).
// This overload does not call clear(), so learned totals add up across calls.
void DpsLearner::calculateDPS(bool skipTransports, bool onlyOneType)
{
	LOG("Start calculating DPS from " << allCombats.size() << " combats, skip transports: " << skipTransports << " onleyOneType: " << onlyOneType );
	// feed stats from parser data
	int lastFrameArmy1, lastFrameArmy2;
	int sizeArmy1canAttackGround, sizeArmy1canAttackAir, sizeArmy1canAttackBoth;
	int sizeArmy2canAttackGround, sizeArmy2canAttackAir, sizeArmy2canAttackBoth;

	combatsProcessed = 0;
	bool skipCombatFlag = false; // needed for nested loop

	// iterate by value on purpose: unitKilled() decrements the armySize maps of this copy
	for (auto combat : allCombats) {
		lastFrameArmy1 = combat.startFrame;
		lastFrameArmy2 = combat.startFrame;
		// count the total units in the army depends on target type
		sizeArmy1canAttackGround = sizeArmy1canAttackAir = sizeArmy1canAttackBoth = 0;
		sizeArmy2canAttackGround = sizeArmy2canAttackAir = sizeArmy2canAttackBoth = 0;
		if (onlyOneType) {
			// we only process the combat of 1 type vs 1 type
			if (combat.armySize1.size() > 1) continue;
			if (combat.armySize2.size() > 1) continue;
		}
		for (auto army : combat.armySize1) {
			if (skipTransports && army.first.spaceProvided() > 0) {
				skipCombatFlag = true;
				break;
			}
			if (canAttackAirUnits(army.first)) {
				if (canAttackGroundUnits(army.first)) sizeArmy1canAttackBoth += army.second;
				else sizeArmy1canAttackAir += army.second;
			} else if (canAttackGroundUnits(army.first)) sizeArmy1canAttackGround += army.second;
			// if the unit cannot attack air or ground (bunker, spell caster, ...) consider attack type both
			else sizeArmy1canAttackBoth += army.second;
		}
		for (auto army : combat.armySize2) {
			if (skipTransports && army.first.spaceProvided() > 0) {
				skipCombatFlag = true;
				break;
			}
			if (canAttackAirUnits(army.first)) {
				if (canAttackGroundUnits(army.first)) sizeArmy2canAttackBoth += army.second;
				else sizeArmy2canAttackAir += army.second;
			} else if (canAttackGroundUnits(army.first)) sizeArmy2canAttackGround += army.second;
			// if the unit cannot attack air or ground (bunker, spell caster, ...) consider attack type both
			else sizeArmy2canAttackBoth += army.second;
		}
		if (skipCombatFlag) {
			skipCombatFlag = false;
			continue;
		}

		// TODO killed units don't count their DPS (i.e. the survivor's DPS are overestimated)
		for (auto killed : combat.kills) {
			if (combat.armyUnits1.find(killed.unitID) != combat.armyUnits1.end()) {
				// army1 unit killed
				unitKilled(killed, lastFrameArmy2, combat.armyUnits1,
							sizeArmy2canAttackGround, sizeArmy2canAttackAir, sizeArmy2canAttackBoth,
							combat.armySize2, combat.armySize1,
							sizeArmy1canAttackGround, sizeArmy1canAttackAir, sizeArmy1canAttackBoth);
			} else {
				// army2 unit killed
				unitKilled(killed, lastFrameArmy1, combat.armyUnits2,
					sizeArmy1canAttackGround, sizeArmy1canAttackAir, sizeArmy1canAttackBoth,
					combat.armySize1, combat.armySize2,
					sizeArmy2canAttackGround, sizeArmy2canAttackAir, sizeArmy2canAttackBoth);
			}
		}
		++combatsProcessed;
	} // end for all combats

	// Add spider mines damage to Vultures
//	for (auto unitType : BWAPI::UnitTypes::allUnitTypes()) {
//		unitTypeDPF[BWAPI::UnitTypes::Terran_Vulture][unitType].totalDamage += unitTypeDPF[BWAPI::UnitTypes::Terran_Vulture_Spider_Mine][unitType].totalDamage;
//		unitTypeDPF[BWAPI::UnitTypes::Terran_Vulture][unitType].totalTime += unitTypeDPF[BWAPI::UnitTypes::Terran_Vulture_Spider_Mine][unitType].totalTime;
//	}

	// Calculate DPF

	// cleaning stats
	DPFcases.clear();
	for (int matchupInt = TVT; matchupInt != NONE; matchupInt++) {
		DPFcases.insert(std::pair<Matchups, DPFstats>(static_cast<Matchups>(matchupInt), DPFstats()));
	}

	// Types with a larger ID are not of interest here. They are skipped with
	// `continue` rather than `break`: allUnitTypes() is not guaranteed to iterate
	// in type-ID order, so stopping early could drop interesting types.
	const int lastInterestingTypeID = 163;
	for (auto unitType1 : BWAPI::UnitTypes::allUnitTypes()) {
		if (unitType1.getID() > lastInterestingTypeID) continue;
		if ((unitType1.canAttack() || unitType1.isSpellcaster()) && !unitType1.isHero()) {
			DPFtype dpfByOneType;
			for (auto unitType2 : BWAPI::UnitTypes::allUnitTypes()) {
				if (unitType2.getID() > lastInterestingTypeID) continue;
				if ((unitType2.canAttack() || unitType2.isSpellcaster()) && !unitType2.isHero()) {
					// identify the matchup
					Matchups vsType = getMatchupType(unitType1, unitType2);

					// update the DPF
					if (unitTypeDPF[unitType1][unitType2].totalTime == 0) {
						// no learned data: fall back to the BWAPI theoretical DPF
						if (unitTypeDPFbwapi[unitType1][unitType2].DPF > 0) {
							DPFcases[vsType].size++;
							DPFcases[vsType].noCases++;
							unitTypeDPF[unitType1][unitType2].DPF = unitTypeDPFbwapi[unitType1][unitType2].DPF;
						}
					} else {
						DPFcases[vsType].size++;
						unitTypeDPF[unitType1][unitType2].DPF = unitTypeDPF[unitType1][unitType2].totalDamage / (double)unitTypeDPF[unitType1][unitType2].totalTime;

						// always get the max for dpfByOneType
						if (unitType1.airWeapon().damageAmount() > 0) {
							if (dpfByOneType.air == 0) dpfByOneType.air = unitTypeDPF[unitType1][unitType2].DPF;
							else dpfByOneType.air = std::max(dpfByOneType.air, unitTypeDPF[unitType1][unitType2].DPF);
						}
						if (unitType1.groundWeapon().damageAmount() > 0) {
							if (dpfByOneType.ground == 0) dpfByOneType.ground = unitTypeDPF[unitType1][unitType2].DPF;
							else dpfByOneType.ground = std::max(dpfByOneType.ground, unitTypeDPF[unitType1][unitType2].DPF);
						}
						if (unitType1.airWeapon().damageAmount() > 0 && unitType1.groundWeapon().damageAmount() > 0) {
							dpfByOneType.both = std::min(dpfByOneType.ground, dpfByOneType.air);
						}
					}
				}
			}
			unitDPF[unitType1] = dpfByOneType;
		}
	}
}

// Same learning as calculateDPS(bool, bool), but over `combats`, skipping every
// combat i with indices[i] == indexToSkip (presumably a held-out fold; confirm
// with callers). Assumes `indices` has at least combats.size() entries.
// Calls clear() first, so the learned statistics reflect only this call.
// `combats` is taken by value, so the armySize decrements done by unitKilled()
// do not reach the caller's vector.
void DpsLearner::calculateDPS(std::vector<combatInfo> combats, int* indices, int indexToSkip)
{
	// be sure the learn stuff is clear
	clear();

	// feed stats from parser data
	int lastFrameArmy1, lastFrameArmy2;
	int sizeArmy1canAttackGround, sizeArmy1canAttackAir, sizeArmy1canAttackBoth;
	int sizeArmy2canAttackGround, sizeArmy2canAttackAir, sizeArmy2canAttackBoth;

	combatsProcessed = 0;

	for (std::size_t i = 0; i < combats.size(); ++i) {
		if (indices[i] == indexToSkip) continue;

		lastFrameArmy1 = combats[i].startFrame;
		lastFrameArmy2 = combats[i].startFrame;
		// count the total units in the army depends on target type
		sizeArmy1canAttackGround = sizeArmy1canAttackAir = sizeArmy1canAttackBoth = 0;
		sizeArmy2canAttackGround = sizeArmy2canAttackAir = sizeArmy2canAttackBoth = 0;

		for (auto army : combats[i].armySize1) {
			if (canAttackAirUnits(army.first)) {
				if (canAttackGroundUnits(army.first)) sizeArmy1canAttackBoth += army.second;
				else sizeArmy1canAttackAir += army.second;
			} else if (canAttackGroundUnits(army.first)) sizeArmy1canAttackGround += army.second;
			// if the unit cannot attack air or ground (bunker, spell caster, ...) consider attack type both
			else sizeArmy1canAttackBoth += army.second;
		}
		for (auto army : combats[i].armySize2) {
			if (canAttackAirUnits(army.first)) {
				if (canAttackGroundUnits(army.first)) sizeArmy2canAttackBoth += army.second;
				else sizeArmy2canAttackAir += army.second;
			} else if (canAttackGroundUnits(army.first)) sizeArmy2canAttackGround += army.second;
			// if the unit cannot attack air or ground (bunker, spell caster, ...) consider attack type both
			else sizeArmy2canAttackBoth += army.second;
		}

		// TODO killed units don't count their DPS (i.e. the survivor's DPS are overestimated)
		for (auto killed : combats[i].kills) {
			if (combats[i].armyUnits1.find(killed.unitID) != combats[i].armyUnits1.end()) {
				// army1 unit killed
				unitKilled(killed, lastFrameArmy2, combats[i].armyUnits1,
					sizeArmy2canAttackGround, sizeArmy2canAttackAir, sizeArmy2canAttackBoth,
					combats[i].armySize2, combats[i].armySize1,
					sizeArmy1canAttackGround, sizeArmy1canAttackAir, sizeArmy1canAttackBoth);
			} else {
				// army2 unit killed
				unitKilled(killed, lastFrameArmy1, combats[i].armyUnits2,
					sizeArmy1canAttackGround, sizeArmy1canAttackAir, sizeArmy1canAttackBoth,
					combats[i].armySize1, combats[i].armySize2,
					sizeArmy2canAttackGround, sizeArmy2canAttackAir, sizeArmy2canAttackBoth);
			}
		}
		++combatsProcessed;
	} // end for all combats

	// Calculate DPF

	// cleaning stats
	DPFcases.clear();
	for (int matchupInt = TVT; matchupInt != NONE; matchupInt++) {
		DPFcases.insert(std::pair<Matchups, DPFstats>(static_cast<Matchups>(matchupInt), DPFstats()));
	}

	// Types with a larger ID are not of interest here. They are skipped with
	// `continue` rather than `break`: allUnitTypes() is not guaranteed to iterate
	// in type-ID order, so stopping early could drop interesting types.
	const int lastInterestingTypeID = 163;
	for (auto unitType1 : BWAPI::UnitTypes::allUnitTypes()) {
		if (unitType1.getID() > lastInterestingTypeID) continue;
		if ((unitType1.canAttack() || unitType1.isSpellcaster()) && !unitType1.isHero()) {
			DPFtype dpfByOneType;
			for (auto unitType2 : BWAPI::UnitTypes::allUnitTypes()) {
				if (unitType2.getID() > lastInterestingTypeID) continue;
				if ((unitType2.canAttack() || unitType2.isSpellcaster()) && !unitType2.isHero()) {
					// identify the matchup
					Matchups vsType = getMatchupType(unitType1, unitType2);

					// update the DPF
					if (unitTypeDPF[unitType1][unitType2].totalTime == 0) {
						// no learned data: fall back to the BWAPI theoretical DPF
						if (unitTypeDPFbwapi[unitType1][unitType2].DPF > 0) {
							DPFcases[vsType].size++;
							DPFcases[vsType].noCases++;
							unitTypeDPF[unitType1][unitType2].DPF = unitTypeDPFbwapi[unitType1][unitType2].DPF;
						}
					} else {
						DPFcases[vsType].size++;
						unitTypeDPF[unitType1][unitType2].DPF = unitTypeDPF[unitType1][unitType2].totalDamage / (double)unitTypeDPF[unitType1][unitType2].totalTime;

						// always get the max for dpfByOneType
						if (unitType1.airWeapon().damageAmount() > 0) {
							if (dpfByOneType.air == 0) dpfByOneType.air = unitTypeDPF[unitType1][unitType2].DPF;
							else dpfByOneType.air = std::max(dpfByOneType.air, unitTypeDPF[unitType1][unitType2].DPF);
						}
						if (unitType1.groundWeapon().damageAmount() > 0) {
							if (dpfByOneType.ground == 0) dpfByOneType.ground = unitTypeDPF[unitType1][unitType2].DPF;
							else dpfByOneType.ground = std::max(dpfByOneType.ground, unitTypeDPF[unitType1][unitType2].DPF);
						}
						if (unitType1.airWeapon().damageAmount() > 0 && unitType1.groundWeapon().damageAmount() > 0) {
							dpfByOneType.both = std::min(dpfByOneType.ground, dpfByOneType.air);
						}
					}
				}
			}
			unitDPF[unitType1] = dpfByOneType;
		}
	}
}

// Classifies the race matchup of two unit types, independent of their order
// (e.g. Zerg vs Protoss and Protoss vs Zerg are both PVZ).
// Returns NONE when either type's race is not Terran, Protoss or Zerg.
DpsLearner::Matchups DpsLearner::getMatchupType(BWAPI::UnitType type1, BWAPI::UnitType type2)
{
	const BWAPI::Race race1 = type1.getRace();
	const BWAPI::Race race2 = type2.getRace();

	// true when {race1, race2} is the unordered pair {a, b}
	auto isMatchup = [&race1, &race2](const BWAPI::Race& a, const BWAPI::Race& b) {
		return (race1 == a && race2 == b) || (race1 == b && race2 == a);
	};

	if (isMatchup(BWAPI::Races::Terran, BWAPI::Races::Terran)) return TVT;
	if (isMatchup(BWAPI::Races::Terran, BWAPI::Races::Protoss)) return TVP;
	if (isMatchup(BWAPI::Races::Terran, BWAPI::Races::Zerg)) return TVZ;
	if (isMatchup(BWAPI::Races::Protoss, BWAPI::Races::Protoss)) return PVP;
	if (isMatchup(BWAPI::Races::Protoss, BWAPI::Races::Zerg)) return PVZ;
	if (isMatchup(BWAPI::Races::Zerg, BWAPI::Races::Zerg)) return ZVZ;
	return DpsLearner::Matchups::NONE;
}

// Returns the short label of a matchup ("TVT", "TVP", ...).
// Any other value, including NONE, is reported as "NONE".
std::string DpsLearner::getMatchupName(Matchups matchId)
{
	const char* label = "NONE";
	switch (matchId) {
	case TVT: label = "TVT"; break;
	case TVP: label = "TVP"; break;
	case TVZ: label = "TVZ"; break;
	case PVP: label = "PVP"; break;
	case PVZ: label = "PVZ"; break;
	case ZVZ: label = "ZVZ"; break;
	default:  break;
	}
	return std::string(label);
}

void DpsLearner::damageUpperBound()
{
	// cleaning stats
	DPFbounded.clear();
	for (int matchupInt = TVT; matchupInt != NONE; matchupInt++) {
		DPFbounded.insert(std::pair<Matchups, int>(static_cast<Matchups>(matchupInt), 0));
	}

	for (auto unitType1 : BWAPI::UnitTypes::allUnitTypes()) {
		for (auto unitType2 : BWAPI::UnitTypes::allUnitTypes()) {
			if (unitTypeDPFbwapi[unitType1][unitType2].DPF > 0 &&
				unitTypeDPFbwapi[unitType1][unitType2].DPF < unitTypeDPF[unitType1][unitType2].DPF) {
				unitTypeDPF[unitType1][unitType2].DPF = unitTypeDPFbwapi[unitType1][unitType2].DPF;

				Matchups vsType = getMatchupType(unitType1, unitType2);
				DPFbounded[vsType]++;
			}
		}
	}
}

// Attributes the death of `killed` to the opposing (attacking) army and updates the bookkeeping.
//   killed:            the dead unit's ID and the frame it died
//   lastFrame:         frame of the attacking army's previous kill (combat start at first);
//                      set to killed.frame
//   armyUnits:         units of the victim's army at combat start, by unit ID
//   sizeArmyCanAttack{Ground,Air,Both}:  attacking-army unit counts by target capability
//   armySizeAttacking: attacking-army unit counts by unit type
//   armySizeDefending: victim-army unit counts by unit type; the victim's type is decremented
//   sizeArmyCanAttack{Ground,Air,Both}2: victim-army counts by capability; decremented for the victim
// The victim's HP + shields (as recorded at combat start) is split evenly among the attacking
// units able to hit it. Units that can attack neither air nor ground (bunker, spell caster,
// transporter, ...) also get a share. Each attacker type accumulates its share and the
// elapsed frames in unitTypeDPF[attackerType][victimType].
void DpsLearner::unitKilled(unitKilledInfo killed, int &lastFrame, std::map<int, unitInfo> armyUnits, 
							int sizeArmyCanAttackGround, int sizeArmyCanAttackAir, int sizeArmyCanAttackBoth, 
							std::map<BWAPI::UnitType, int> armySizeAttacking, std::map<BWAPI::UnitType, int> &armySizeDefending,
							int &sizeArmyCanAttackGround2, int &sizeArmyCanAttackAir2, int &sizeArmyCanAttackBoth2)
{
	int framesToKill = killed.frame - lastFrame;
	const unitInfo& victim = armyUnits[killed.unitID];
	int damageDone = victim.HP + victim.shield;
	int unitTypeIdKilled = victim.typeID;
	BWAPI::UnitType unitTypeKilled(unitTypeIdKilled);

	int totalUnitsAttacking = sizeArmyCanAttackBoth;
	if (unitTypeKilled.isFlyer()) totalUnitsAttacking += sizeArmyCanAttackAir;
	else totalUnitsAttacking += sizeArmyCanAttackGround;

	// Damage can only be attributed if some attacker could hit the victim, but the
	// victim is dead either way, so the bookkeeping below always runs.
	if (totalUnitsAttacking > 0) {
		double damageSplit = (double)damageDone / (double)totalUnitsAttacking;

		for (const auto& typeSize : armySizeAttacking) {
			if (typeSize.second <= 0) continue; // no units of this type left
			if (typeSize.first > BWAPI::UnitTypes::AllUnits) continue; // unknown unit
			// check if we can attack the target to apply damage
			BWAPI::UnitType unitType(typeSize.first);
			const bool canHitVictim = unitTypeKilled.isFlyer() ? canAttackAirUnits(unitType) : canAttackGroundUnits(unitType);
			const bool attacksNothing = !canAttackAirUnits(unitType) && !canAttackGroundUnits(unitType); // bunker, spell caster, transporter, ...
			if (!canHitVictim && !attacksNothing) continue;

			double damageIncrease = damageSplit * (double)typeSize.second;
			auto& stats = unitTypeDPF[typeSize.first][unitTypeIdKilled];
			// TODO change all numeric_limits and use their constant values
			if (stats.totalDamage > DBL_MAX - damageIncrease) {
				DEBUG("Damage Overflow! unitKilledID " << killed.unitID << " frame: " << killed.frame);
			}
			if (framesToKill > 0 && stats.totalTime > INT_MAX - framesToKill) {
				DEBUG("Frames Overflow! unitKilledID " << killed.unitID << " frame: " << killed.frame);
			}
			if (stats.totalTime + framesToKill < 0) {
				DEBUG("Frames Negative! unitKilledID " << killed.unitID << " frame: " << killed.frame);
			}
			stats.totalDamage += damageIncrease;
			stats.totalTime += framesToKill;
			stats.numCombats++;
		}
	} else {
		//DEBUG("Damage done without units");
	}

	// remove the lost unit from its army: the per-type counts are keyed by unit TYPE, not unit ID
	armySizeDefending[unitTypeKilled]--;
	const bool killedCanAttackAir = canAttackAirUnits(unitTypeKilled);
	const bool killedCanAttackGround = canAttackGroundUnits(unitTypeKilled);
	if (killedCanAttackAir && killedCanAttackGround) --sizeArmyCanAttackBoth2;
	else if (killedCanAttackAir) --sizeArmyCanAttackAir2;
	else if (killedCanAttackGround) --sizeArmyCanAttackGround2;
	else --sizeArmyCanAttackBoth2; // non-attackers are counted as "both" by calculateDPS
	lastFrame = killed.frame;
}

// computes type priority using Borda Count http://en.wikipedia.org/wiki/Borda_count
// we give points when first unit type is killed
void DpsLearner::calculatePriorityFirstBorda()
{
	int unitTypeID;
	bool checkingDecisionsArmy1, checkingDecisionsArmy2;
	std::map<BWAPI::UnitType, int> typesToDestroy1, typesToDestroy2;
	bool skipCombatFlag = false;

	for (auto combat : allCombats) {
		typesToDestroy1.clear();
		typesToDestroy2.clear();

		// first check if we can decide targets
		if (combat.armySize1.size() < 2) {
			checkingDecisionsArmy2 = false;
		} else {
			checkingDecisionsArmy2 = true;
			typesToDestroy1 = combat.armySize1;
		}
		if (combat.armySize2.size() < 2) {
			checkingDecisionsArmy1 = false;
		} else {
			checkingDecisionsArmy1 = true;
			typesToDestroy2 = combat.armySize2;
		}

		// if target is clear go to next combat
		if (!checkingDecisionsArmy1 && !checkingDecisionsArmy1) continue;

		// skip combats with mines or transports
		for (auto army : combat.armySize1) {
			if (army.first.spaceProvided() > 0 || army.first == BWAPI::UnitTypes::Terran_Vulture_Spider_Mine) {
				skipCombatFlag = true;
				break;
			}
		}
		for (auto army : combat.armySize2) {
			if (army.first.spaceProvided() > 0 || army.first == BWAPI::UnitTypes::Terran_Vulture_Spider_Mine) {
				skipCombatFlag = true;
				break;
			}
		}
		if (skipCombatFlag) {
			skipCombatFlag = false;
			continue;
		}


		// give points when a new type is destroyed
		std::pair<std::set<int>::iterator, bool> ret;
		for (auto killedInfo : combat.kills) {
			if (combat.armyUnits1.find(killedInfo.unitID) != combat.armyUnits1.end()) {
				// unit killed is from army1
				if (checkingDecisionsArmy2) {
					unitTypeID = combat.armyUnits1[killedInfo.unitID].typeID;
					BWAPI::UnitType unitType(unitTypeID);
					// look if the type is still there to kill it
					if (typesToDestroy1.find(unitType) != typesToDestroy1.end()) {
						typesToDestroy1.erase(unitType);
// 						typePriority[unitTypeID] += typesToDestroy1.size();
						bordaCount[unitTypeID].score += typesToDestroy1.size();
						++bordaCount[unitTypeID].frequency;
						if (typesToDestroy1.size() < 2) checkingDecisionsArmy2 = false;
					}
				}
			} else {
				// unit killed is from army2
				if (checkingDecisionsArmy1) {
					unitTypeID = combat.armyUnits2[killedInfo.unitID].typeID;
					BWAPI::UnitType unitType(unitTypeID);
					// look if the type is still there to kill it
					if (typesToDestroy2.find(unitType) != typesToDestroy2.end()) {
						typesToDestroy2.erase(unitType);
// 						typePriority[unitTypeID] += typesToDestroy2.size();
						bordaCount[unitTypeID].score += typesToDestroy2.size();
						++bordaCount[unitTypeID].frequency;
						if (typesToDestroy2.size() < 2) checkingDecisionsArmy1 = false;
					}
				}
			}
			
			// need keep giving points?
			if (!checkingDecisionsArmy1 && !checkingDecisionsArmy1) break;
		}
	}

	// calculate avg. Borda count
	for (auto unitType : BWAPI::UnitTypes::allUnitTypes()) {
		if (bordaCount[unitType].frequency == 0) continue; // skip if no data
		typePriority[unitType] = (float)bordaCount[unitType].score / (float)bordaCount[unitType].frequency;
	}
}

// computes type priority using Borda Count http://en.wikipedia.org/wiki/Borda_count
// we give points when first unit type is killed
void DpsLearner::calculatePriorityLastBorda()
{
	int unitTypeID;
	bool checkingDecisionsArmy1, checkingDecisionsArmy2;
	std::map<BWAPI::UnitType, int> typesToDestroy1, typesToDestroy2;

	for (auto combat : allCombats) {
		typesToDestroy1.clear();
		typesToDestroy2.clear();

		// first check if we can decide targets
		if (combat.armySize1.size() < 2) {
			checkingDecisionsArmy2 = false;
		} else {
			checkingDecisionsArmy2 = true;
			typesToDestroy1 = combat.armySize1;
		}
		if (combat.armySize2.size() < 2) {
			checkingDecisionsArmy1 = false;
		} else {
			checkingDecisionsArmy1 = true;
			typesToDestroy2 = combat.armySize2;
		}

		// if target is clear go to next combat
		if (!checkingDecisionsArmy1 && !checkingDecisionsArmy1) continue;


		// give points when a new type is destroyed
		std::pair<std::set<int>::iterator, bool> ret;
		for (auto killedInfo : combat.kills) {
			if (combat.armyUnits1.find(killedInfo.unitID) != combat.armyUnits1.end()) {
				// unit killed is from army1
				if (checkingDecisionsArmy2) {
					unitTypeID = combat.armyUnits1[killedInfo.unitID].typeID;
					BWAPI::UnitType unitType(unitTypeID);
					// look if the type is still there to kill it
					if (typesToDestroy1.find(unitType) != typesToDestroy1.end()) {
						typesToDestroy1[unitType]--;
						if (typesToDestroy1[unitType] <= 0) {
							typesToDestroy1.erase(unitType);
// 							typePriority[unitTypeID] += typesToDestroy1.size();
							bordaCount[unitTypeID].score += typesToDestroy1.size();
							++bordaCount[unitTypeID].frequency;
							if (typesToDestroy1.size() < 2) checkingDecisionsArmy2 = false;
						}
					}
				}
			} else {
				// unit killed is from army2
				if (checkingDecisionsArmy1) {
					unitTypeID = combat.armyUnits2[killedInfo.unitID].typeID;
					BWAPI::UnitType unitType(unitTypeID);
					// look if the type is still there to kill it
					if (typesToDestroy2.find(unitType) != typesToDestroy2.end()) {
						typesToDestroy2[unitType]--;
						if (typesToDestroy2[unitType] <= 0) {
							typesToDestroy2.erase(unitType);
// 							typePriority[unitTypeID] += typesToDestroy2.size();
							bordaCount[unitTypeID].score += typesToDestroy2.size();
							++bordaCount[unitTypeID].frequency;
							if (typesToDestroy2.size() < 2) checkingDecisionsArmy1 = false;
						}
					}
				}
			}

			// need keep giving points?
			if (!checkingDecisionsArmy1 && !checkingDecisionsArmy1) break;
		}
	}

	// calculate avg. Borda count
	for (auto unitType : BWAPI::UnitTypes::allUnitTypes()) {
		if (bordaCount[unitType].frequency == 0) continue; // skip if no data
		typePriority[unitType] = (float)bordaCount[unitType].score / (float)bordaCount[unitType].frequency;
	}
}

// Returns the unit type ID of unit `unitID` in `combat`. Army 1 is searched
// first; otherwise the unit is assumed to belong to army 2. `combat` is a copy,
// so an ID missing from both armies only default-inserts into that copy, and the
// default unitInfo's typeID is returned.
int DpsLearner::getUnitTypefromID(combatInfo combat, int unitID)
{
	auto inArmy1 = combat.armyUnits1.find(unitID);
	if (inArmy1 != combat.armyUnits1.end()) {
		return inArmy1->second.typeID;
	}
	// not found in army 1, should be in army 2
	return combat.armyUnits2[unitID].typeID;
}

void DpsLearner::clear()
{
	for (int i = 0; i < MAX_UNIT_TYPE; ++i) {
		unitTypeDPF[i].clear();
		unitTypeDPF[i].resize(MAX_UNIT_TYPE);
	}
	typePriority.clear();
	typePriority.resize(MAX_UNIT_TYPE);
	bordaCount.clear();
	bordaCount.resize(MAX_UNIT_TYPE);

	armyDestroyed = 0;
	armyReinforcement = 0;
	armyPeace = 0;
	armyGameEnd = 0;
}