/*	-----------------------------------------------------------------------------
	M A A S C R A F T

	StarCraft: Brood War - Bot

	Author: Dennis Soemers
	Maastricht University
	-----------------------------------------------------------------------------
*/

/*
	Implementation of UnitUtils.h
*/

#include "CommonIncludes.h"

#include "Distances.h"
#include "MathConstants.h"
#include "OpponentTracker.h"
#include "UnitUtils.h"

using namespace BWAPI;
using namespace Distances;

// Implementations of all radius / distance related functions largely based on BWAPI implementations, but
// made more efficient (hopefully) by replacing distance computations by squared-distance comparisons
//
// BWAPI implementations: https://github.com/bwapi/bwapi/blob/master/bwapi/BWAPILIB/Source/Unit.cpp

/*
	Returns true when the unit is ready to fire right now: both its ground and air
	weapon cooldowns are zero and it is not already starting an attack.
*/
const bool UnitUtils::canAttack(const Unit unit)
{
	const bool weaponsReady = (unit->getGroundWeaponCooldown() == 0) && (unit->getAirWeaponCooldown() == 0);
	return weaponsReady && !unit->isStartingAttack();
}

/*
	Returns true when the unit may be given a move command without cancelling an attack:
	either it could attack right now (canAttack), or it is not currently in an attack frame.
*/
const bool UnitUtils::canMove(const Unit unit)
{
	if(canAttack(unit))
		return true;

	return !unit->isAttackFrame();
}

const bool UnitUtils::isInCombat(const Unit unit)
{
	UnitType type = unit->getType();

	if(type.isFlyer())
	{
		return (OpponentTracker::Instance()->getAirThreat(unit->getPosition()) > 0.0f											||
				!(UnitUtils::getUnitsInWeaponRange(unit, type.airWeapon(), Filter::IsEnemy && Filter::IsDetected).empty())		||
				!(UnitUtils::getUnitsInWeaponRange(unit, type.groundWeapon(), Filter::IsEnemy && Filter::IsDetected).empty())	||
				!(Broodwar->getUnitsInRadius(unit->getPosition(), 400, Filter::IsEnemy && Filter::IsBuilding).empty())				);
	}
	else
	{
		return (OpponentTracker::Instance()->getGroundThreat(unit->getPosition()) > 0.0f										||
				!(UnitUtils::getUnitsInWeaponRange(unit, type.airWeapon(), Filter::IsEnemy && Filter::IsDetected).empty())		||
				!(UnitUtils::getUnitsInWeaponRange(unit, type.groundWeapon(), Filter::IsEnemy && Filter::IsDetected).empty())	||
				!(Broodwar->getUnitsInRadius(unit->getPosition(), 400, Filter::IsEnemy && Filter::IsBuilding).empty())				);
	}
}

const float UnitUtils::dpsAir(const Unit unit)
{
	UnitType type = unit->getType();

	if(type == UnitTypes::Protoss_High_Templar)
	{
		return 112.0f / 3.0f;
	}
	else if(type == UnitTypes::Protoss_Dark_Archon)
	{
		return 50.0f;
	}
	else if(type == UnitTypes::Protoss_Carrier)
	{
		return 8 * 6 * (24.0f / 37);
	}

	WeaponType weapon = type.airWeapon();

	if(weapon == WeaponTypes::None)
		return -1.0f;
	else
	{
		Player player = unit->getPlayer();
		return (player->damage(weapon) * (24.0f / weapon.damageCooldown()));
	}
}

const float UnitUtils::dpsAir(const UnitData& data)
{
	UnitType type = data.unitType;

	if(type == UnitTypes::Protoss_High_Templar)
	{
		return 112.0f / 3.0f;
	}
	else if(type == UnitTypes::Protoss_Dark_Archon)
	{
		return 50.0f;
	}
	else if(type == UnitTypes::Protoss_Carrier)
	{
		return 8 * 6 * (24.0f / 37);
	}

	WeaponType weapon = type.airWeapon();

	if(weapon == WeaponTypes::None)
		return -1.0f;
	else
	{
		Player player = data.unitPtr->getPlayer();
		return (player->damage(weapon) * (24.0f / weapon.damageCooldown()));
	}
}

const float UnitUtils::dpsGround(const Unit unit)
{
	UnitType type = unit->getType();

	if(type == UnitTypes::Terran_Medic)
	{
		return 1.0f;
	}
	else if(type == UnitTypes::Protoss_High_Templar)
	{
		return 112.0f / 3.0f;
	}
	else if(type == UnitTypes::Protoss_Dark_Archon)
	{
		return 50.0f;
	}
	else if(type == UnitTypes::Protoss_Reaver)
	{
		return 100 * (24.0f / 60);
	}
	else if(type == UnitTypes::Protoss_Carrier)
	{
		return 8 * 6 * (24.0f / 37);
	}

	WeaponType weapon = type.groundWeapon();

	if(weapon == WeaponTypes::None)
		return -1.0f;
	else
	{
		Player player = unit->getPlayer();

		float groundDPS = (player->damage(weapon) * (24.0f / player->weaponDamageCooldown(type)));

		if(type == UnitTypes::Terran_Firebat || type == UnitTypes::Protoss_Zealot)
			groundDPS = 2.0f * groundDPS;

		return groundDPS;
	}
}

const float UnitUtils::dpsGround(const UnitData& data)
{
	UnitType type = data.unitType;

	if(type == UnitTypes::Terran_Medic)
	{
		return 1.0f;
	}
	else if(type == UnitTypes::Protoss_High_Templar)
	{
		return 112.0f / 3.0f;
	}
	else if(type == UnitTypes::Protoss_Dark_Archon)
	{
		return 50.0f;
	}
	else if(type == UnitTypes::Protoss_Reaver)
	{
		return 100 * (24.0f / 60);
	}
	else if(type == UnitTypes::Protoss_Carrier)
	{
		return 8 * 6 * (24.0f / 37);
	}
	else if(type == UnitTypes::Terran_Vulture_Spider_Mine)
	{
		return 50.0f;
	}

	WeaponType weapon = type.groundWeapon();

	if(weapon == WeaponTypes::None)
		return -1.0f;
	else
	{
		Player player = data.unitPtr->getPlayer();

		float groundDPS = (player->damage(weapon) * (24.0f / player->weaponDamageCooldown(type)));

		if(type == UnitTypes::Terran_Firebat || type == UnitTypes::Protoss_Zealot)
			groundDPS = 2.0f * groundDPS;

		return groundDPS;
	}
}

const float UnitUtils::dpsToTarget(const Unit attacker, const Unit target)
{
	UnitType attackerType = attacker->getType();
	UnitType targetType = target->getType();

	float dps = 0.0f;
	WeaponType weapon = WeaponTypes::None;

	if(targetType.isFlyer() || target->isLifted())
	{
		dps = dpsAir(attacker);
		weapon = attackerType.airWeapon();
	}
	else
	{
		dps = dpsGround(attacker);
		weapon = attackerType.groundWeapon();
	}

	if(dps > 0.0f)
	{
		if(weapon.damageType() == DamageTypes::Explosive)
		{
			if(targetType.size() == UnitSizeTypes::Small)
				dps *= 0.5f;
			else if(targetType.size() == UnitSizeTypes::Medium)
				dps *= 0.75f;
		}
		else if(weapon.damageType() == DamageTypes::Concussive)
		{
			if(targetType.size() == UnitSizeTypes::Large)
				dps *= 0.25f;
			else if(targetType.size() == UnitSizeTypes::Medium)
				dps *= 0.5f;
		}
	}

	if(attackerType == UnitTypes::Terran_Bunker && (attacker->isCompleted() || !attacker->isVisible()))
		dps += 38.4f;					// assuming 4 marines inside bunker

	if(attacker->isVisible() && !attacker->isCompleted())
		return dps *= 0.1f;

	return dps;
}

const float UnitUtils::damageToTarget(const Unit attacker, const Unit target)
{
	UnitType attackerType = attacker->getType();
	UnitType targetType = target->getType();

	float damage = 0.0f;
	WeaponType weapon = WeaponTypes::None;

	if(targetType.isFlyer() || target->isLifted())
	{
		weapon = attackerType.airWeapon();

		if(weapon == WeaponTypes::None)
		{
			damage = -1.0f;
		}
		else
		{
			Player player = attacker->getPlayer();

			damage = (player->damage(weapon));
		}
	}
	else
	{
		weapon = attackerType.groundWeapon();

		if(weapon == WeaponTypes::None)
		{
			damage = -1.0f;
		}
		else
		{
			Player player = attacker->getPlayer();

			damage = (player->damage(weapon));

			if(attackerType == UnitTypes::Terran_Firebat || attackerType == UnitTypes::Protoss_Zealot)
				damage = 2.0f * damage;
		}
	}

	if(damage > 0.0f)
	{
		if(weapon.damageType() == DamageTypes::Explosive)
		{
			if(targetType.size() == UnitSizeTypes::Small)
				damage *= 0.5f;
			else if(targetType.size() == UnitSizeTypes::Medium)
				damage *= 0.75f;
		}
		else if(weapon.damageType() == DamageTypes::Concussive)
		{
			if(targetType.size() == UnitSizeTypes::Large)
				damage *= 0.25f;
			else if(targetType.size() == UnitSizeTypes::Medium)
				damage *= 0.5f;
		}
	}

	return damage;
}

/*
	Effective health of a live unit: current hit points plus current shields.
*/
const int UnitUtils::hp(const Unit unit)
{
	const int hull = unit->getHitPoints();
	const int shields = unit->getShields();
	return hull + shields;
}

/*
	Effective health of a remembered unit: the stored `lastHP` plus the shields currently
	reported by the unit pointer.

	NOTE(review): shields are read live while HP comes from the stored data; confirm this
	is intended for units that are not currently visible.
*/
const int UnitUtils::hp(const UnitData& data)
{
	const int shields = data.unitPtr->getShields();
	return (data.lastHP + shields);
}

const int UnitUtils::rangeAir(const Unit unit)
{
	UnitType type = unit->getType();
	WeaponType weapon = type.airWeapon();

	if(weapon == WeaponTypes::None)
	{
		if(type == UnitTypes::Protoss_High_Templar)
			return 288;
		else if(type == UnitTypes::Protoss_Dark_Archon)
			return 320;
		else if(type == UnitTypes::Protoss_Carrier)
			return 384;
		else
			return 0;
	}
	else
	{
		Player player = unit->getPlayer();
		return (player->weaponMaxRange(weapon));
	}
}

const int UnitUtils::rangeAir(const UnitData& data)
{
	UnitType type = data.unitType;
	WeaponType weapon = type.airWeapon();

	if(weapon == WeaponTypes::None)
	{
		if(type == UnitTypes::Protoss_High_Templar)
			return 288;
		else if(type == UnitTypes::Protoss_Dark_Archon)
			return 320;
		else if(type == UnitTypes::Protoss_Carrier)
			return 384;
		else
			return 0;
	}
	else
	{
		Player player = data.unitPtr->getPlayer();
		return (player->weaponMaxRange(weapon));
	}
}

const int UnitUtils::rangeGround(const Unit unit)
{
	UnitType type = unit->getType();
	WeaponType weapon = type.groundWeapon();

	if(weapon == WeaponTypes::None)
	{
		if(type == UnitTypes::Protoss_High_Templar)
			return 288;
		else if(type == UnitTypes::Protoss_Dark_Archon)
			return 320;
		else if(type == UnitTypes::Protoss_Reaver)
			return 256;
		else if(type == UnitTypes::Protoss_Carrier)
			return 384;
		else
			return 0;
	}
	else
	{
		Player player = unit->getPlayer();
		return (player->weaponMaxRange(weapon));
	}
}

const int UnitUtils::rangeGround(const UnitData& data)
{
	UnitType type = data.unitType;
	WeaponType weapon = type.groundWeapon();

	if(weapon == WeaponTypes::None)
	{
		if(type == UnitTypes::Protoss_High_Templar)
			return 288;
		else if(type == UnitTypes::Protoss_Dark_Archon)
			return 320;
		else if(type == UnitTypes::Protoss_Reaver)
			return 256;
		else if(type == UnitTypes::Protoss_Carrier)
			return 384;
		else
			return 0;
	}
	else if(type == UnitTypes::Terran_Siege_Tank_Siege_Mode)
	{
		return 12 * 32;
	}
	else
	{
		Player player = data.unitPtr->getPlayer();
		return (player->weaponMaxRange(weapon));
	}
}

const bool UnitUtils::isUnitValid(const Unit unit)
{
	return (unit != nullptr											&&
			(unit->isCompleted() || unit->getType().isBuilding())	&&
			unit->getHitPoints() > 0								&& 
			unit->exists()											&& 
			unit->getType() != UnitTypes::Unknown					&&
			unit->getPosition() != Positions::Invalid				&&
			unit->getPosition() != Positions::Unknown				&&
			unit->getPosition() != Positions::None						);
}

/*
	Returns the unit in `units` with the smallest squared distance to `center`
	(Distances::getSquaredDistance(Position, Unit)), or nullptr if `units` is empty.
	Ties keep the unit encountered first in iteration order.
*/
const Unit UnitUtils::getClosestUnit(const Position& center, const Unitset& units)
{
	Unit nearest = nullptr;
	int bestDistSq = MathConstants::MAX_INT;

	for(Unit candidate : units)
	{
		const int candidateDistSq = getSquaredDistance(center, candidate);
		if(candidateDistSq >= bestDistSq)
			continue;

		bestDistSq = candidateDistSq;
		nearest = candidate;
	}

	return nearest;
}

/*
	Returns the unit in `units` with the smallest squared distance to `unit`
	(Distances::getSquaredDistance(Unit, Unit)), or nullptr if `units` is empty.
	Ties keep the unit encountered first in iteration order.
*/
const Unit UnitUtils::getClosestUnit(const Unit unit, const Unitset& units)
{
	Unit nearest = nullptr;
	int bestDistSq = MathConstants::MAX_INT;

	for(Unit candidate : units)
	{
		const int candidateDistSq = getSquaredDistance(unit, candidate);
		if(candidateDistSq >= bestDistSq)
			continue;

		bestDistSq = candidateDistSq;
		nearest = candidate;
	}

	return nearest;
}

/*
	Closest unit to `center` among the units within `radius` of it that satisfy `pred`.
	Returns nullptr if there are none.
*/
const Unit UnitUtils::getClosestUnitInRadius(const Position& center, const int radius, const UnitFilter& pred)
{
	const Unitset candidates = getUnitsInRadius(center, radius, pred);
	return getClosestUnit(center, candidates);
}

/*
	Closest unit to `unit` among the other units within `radius` of it that satisfy `pred`.
	Returns nullptr if there are none.
*/
const Unit UnitUtils::getClosestUnitInRadius(const Unit unit, const int radius, const UnitFilter& pred)
{
	const Unitset candidates = getUnitsInRadius(unit, radius, pred);
	return getClosestUnit(unit, candidates);
}

/*
	Among the units in the rectangle [left, top, right, bottom] that satisfy `pred`, returns
	the one whose position is closest to `center` (squared position-to-position distance).
	Returns nullptr if the rectangle holds no matching unit.
*/
const Unit UnitUtils::getClosestUnitInRectangle(const Position& center, const int left, const int top, 
												const int right, const int bottom, const UnitFilter& pred)
{
	const Unitset candidates = Broodwar->getUnitsInRectangle(left, top, right, bottom, pred);

	Unit nearest = nullptr;
	int bestDistSq = MathConstants::MAX_INT;

	for(Unit candidate : candidates)
	{
		const int candidateDistSq = getSquaredDistance(center, candidate->getPosition());
		if(candidateDistSq >= bestDistSq)
			continue;

		bestDistSq = candidateDistSq;
		nearest = candidate;
	}

	return nearest;
}

// radius in pixels!
const Unitset UnitUtils::getUnitsInRadius(const Position& center, const int radius, const UnitFilter& pred)
{
	// The bounding square is a cheap pre-filter; the exact circle test uses squared
	// distances so no square root is needed.
	const int radiusSq = radius * radius;

	auto insideCircle = [&](Unit candidate) -> bool
	{
		if(!withinSquaredDistance(center, candidate->getPosition(), radiusSq))
			return false;

		// An invalid (empty) predicate accepts everything
		return (!pred.isValid() || pred(candidate));
	};

	return Broodwar->getUnitsInRectangle(center.x - radius,
										 center.y - radius,
										 center.x + radius,
										 center.y + radius,
										 insideCircle);
}

// radius in pixels!
const Unitset UnitUtils::getUnitsInRadius(const Unit unit, const int radius, const UnitFilter& pred)
{
	// Returns Unitset::none for a unit that no longer exists
	if(!unit->exists())
		return Unitset::none;

	const int radiusSq = radius * radius;
	const Position origin = unit->getPosition();

	// Excludes `unit` itself; the distance test is center-to-center on squared distances
	auto isNearbyMatch = [&](Unit other) -> bool
	{
		if(other == unit)
			return false;

		if(!withinSquaredDistance(origin, other->getPosition(), radiusSq))
			return false;

		return (!pred.isValid() || pred(other));
	};

	// The unit's bounding box grown by `radius` serves as a cheap pre-filter
	return Broodwar->getUnitsInRectangle(unit->getLeft() - radius,
										 unit->getTop() - radius,
										 unit->getRight() + radius,
										 unit->getBottom() + radius,
										 isNearbyMatch);
}

/*
	Units (other than `unit` itself) that `weapon`, fired by `unit`, could currently hit.
	A unit qualifies if it:
	  - is not invincible;
	  - lies between the weapon's minimum range and the owner's (upgraded) maximum range,
	    measured with squared distances;
	  - is on a layer the weapon can target (airborne units need targetsAir(), ground
	    units need targetsGround());
	  - satisfies `pred` (an invalid/empty predicate accepts everything).

	Returns an empty set for WeaponTypes::None. Adapted from BWAPI's
	UnitInterface::getUnitsInWeaponRange.
*/
const Unitset UnitUtils::getUnitsInWeaponRange(const Unit unit, const WeaponType weapon, const UnitFilter& pred)
{
	if(weapon == WeaponTypes::None)
		return Unitset();

	Player player = unit->getPlayer();
	const int max = player->weaponMaxRange(weapon);
	const int maxSq = max*max;
	const int min = weapon.minRange();
	const int minSq = min*min;

	auto isTargetable = [&](Unit u) -> bool
	{
		if(u == unit || u->isInvincible())
			return false;

		// Weapon distance check (a minimum range of 0 means no lower bound)
		const int distSq = getSquaredDistance(unit, u);
		if((min && distSq < minSq) || distSq > maxSq)
			return false;

		// Layer check: previously only the air side was tested, so air-only weapons
		// (e.g. Corsair, Valkyrie) reported ground units as being in range.
		const bool airborne = (u->isFlying() || u->isLifted());
		if(airborne ? !weapon.targetsAir() : !weapon.targetsGround())
			return false;

		return (!pred.isValid() || pred(u));
	};

	return Broodwar->getUnitsInRectangle(unit->getLeft() - max,
										 unit->getTop() - max,
										 unit->getRight() + max,
										 unit->getBottom() + max,
										 isTargetable);
}

/*
	Subset of `units` (excluding `unit` itself) that `weapon`, fired by `unit`, could
	currently hit. A unit qualifies if it:
	  - is not invincible;
	  - lies between the weapon's minimum range and the owner's (upgraded) maximum range,
	    measured with squared distances;
	  - is on a layer the weapon can target (airborne units need targetsAir(), ground
	    units need targetsGround()).

	This applies the same targetability rules as the UnitFilter overload. The old version
	skipped the layer check, so e.g. a Zealot reported flying units as "in range".

	Returns an empty set for WeaponTypes::None.
*/
const Unitset UnitUtils::getUnitsInWeaponRange(const Unit unit, const WeaponType weapon, const BWAPI::Unitset& units)
{
	Unitset unitsInRange;

	if(weapon == WeaponTypes::None)
		return unitsInRange;

	Player player = unit->getPlayer();
	const int max = player->weaponMaxRange(weapon);
	const int maxSq = max*max;
	const int min = weapon.minRange();
	const int minSq = min*min;

	for(Unit u : units)
	{
		if(u == unit || u->isInvincible())
			continue;

		// Weapon distance check (a minimum range of 0 means no lower bound)
		const int distSq = getSquaredDistance(unit, u);
		if((min && distSq < minSq) || distSq > maxSq)
			continue;

		// Layer check: the weapon must be able to target the unit's current layer
		const bool airborne = (u->isFlying() || u->isLifted());
		if(airborne ? !weapon.targetsAir() : !weapon.targetsGround())
			continue;

		unitsInRange.push_back(u);
	}

	return unitsInRange;
}