#include "BananaBrain.h"

#include <cmath>

// Evaluates `potential_function` for `unit` as if it were located at `position`
// and returns the value accumulated in the temporary UnitPotential.
static double calculate_potential(Unit unit,Position position,std::function<void(UnitPotential&)>& potential_function)
{
	UnitPotential probe(unit, position);
	// The callback contributes its terms to the probe.
	potential_function(probe);
	const double result = probe.value();
	return result;
}

// Moves `unit` towards a position with a higher potential.
//
// Starting from the unit's predicted position (or its actual position when
// check_collision() rejects the prediction), performs a greedy hill-climb of at
// most 16 steps. Each step evaluates the eight neighbours at `step_size` pixels
// (8 for flying units, 4 otherwise) and moves to the one with the strictly
// highest potential. The search stops early when no neighbour improves on the
// current position.
//
// For flying units whose type has a halt distance above 1, a move that is longer
// than top speed times the remaining latency frames, yet shorter than the halt
// distance, is stretched along the same direction to 1.02 times the halt distance.
// NOTE(review): presumably this keeps the flyer from decelerating before it
// reaches a nearby target - confirm against the intended micro behaviour.
//
// unit               - the unit to steer.
// potential_function - callback that accumulates the potential for a candidate
//                      position into the UnitPotential it is given.
//
// Returns true when a move order was issued through unit_move(), false when the
// chosen position equals the starting position.
bool unit_potential(Unit unit,std::function<void(UnitPotential&)> potential_function)
{
	Position initial_position = predict_position(unit);
	if (!check_collision(unit, initial_position)) initial_position = unit->getPosition();
	Position current_position = initial_position;
	double current_potential = calculate_potential(unit, current_position, potential_function);
	
	const int step_size = unit->isFlying() ? 8 : 4;
	
	// Candidate offsets: axis-aligned neighbours first, then the diagonals.
	// The order only matters for ties, since a candidate must be strictly better to win.
	const Position neighbour_deltas[] = {
		Position(-step_size, 0), Position(step_size, 0), Position(0, -step_size), Position(0, step_size),
		Position(-step_size, -step_size), Position(step_size, -step_size), Position(step_size, step_size), Position(-step_size, step_size)
	};
	
	for (int step = 0; step < 16; step++) {
		Position best_position = current_position;
		double best_potential = current_potential;
		bool better_position_found = false;
		
		for (const Position& delta_position : neighbour_deltas) {
			Position new_position = current_position + delta_position;
			// Only consider positions on the map that check_collision() accepts.
			if (new_position.isValid() && check_collision(unit, new_position)) {
				double new_potential = calculate_potential(unit, new_position, potential_function);
				if (new_potential > best_potential) {
					best_position = new_position;
					best_potential = new_potential;
					better_position_found = true;
				}
			}
		}
		
		// Local maximum reached.
		if (!better_position_found) break;
		current_position = best_position;
		current_potential = best_potential;
	}
	
	current_position.makeValid();
	
	const UnitType unit_type = unit->getType();
	if (unit->isFlying() && unit_type.haltDistance() > 1) {
		int distance = initial_position.getApproxDistance(current_position);
		// haltDistance() is compared against distance * 256 and divided by 256.0 below,
		// i.e. it is treated as being expressed in 1/256ths of a pixel.
		// `distance` is strictly positive inside this branch: it exceeds a non-negative
		// product, so the divisions below are safe.
		if (distance > unit->getPlayer()->topSpeed(unit_type) * Broodwar->getRemainingLatencyFrames() &&
			distance * 256 < unit_type.haltDistance()) {
			Position delta = current_position - initial_position;
			double d_halt = 1.02 * unit_type.haltDistance() / 256.0;
			double dx = d_halt * delta.x / distance;
			double dy = d_halt * delta.y / distance;
			// Round to nearest symmetrically. Truncating (value + 0.5) rounds negative
			// components towards zero, which could shorten the stretched move to less
			// than the halt distance for moves in negative directions.
			Position new_position = initial_position + Position(static_cast<int>(std::lround(dx)), static_cast<int>(std::lround(dy)));
			if (new_position.isValid()) current_position = new_position;
		}
	}
	
	if (current_position != initial_position) {
		unit_move(unit, current_position);
		return true;
	} else {
		return false;
	}
}

// Adds a potential term relative to an object of the given `type` at `position`.
// The edge-to-edge offset between this unit (at the evaluated position) and that
// object is passed to distance_to_bwcircle() with `max_distance`; only a negative
// result contributes, scaled by `potential`.
// NOTE(review): a negative result presumably means "closer than max_distance" -
// confirm against distance_to_bwcircle().
void UnitPotential::add_potential(UnitType type,Position position,double potential,int max_distance)
{
	const Position gap = edge_to_edge_delta(unit_->getType(), position_, type, position);
	const double signed_distance = distance_to_bwcircle(gap, max_distance);
	// Nothing to add unless the signed distance is strictly negative.
	if (!(signed_distance < 0.0)) return;
	value_ += (signed_distance * potential);
}

// Convenience overload: adds a potential term for another unit, using its type
// and its predicted position.
void UnitPotential::add_potential(Unit unit,double potential,int max_distance)
{
	const UnitType other_type = unit->getType();
	const Position other_position = predict_position(unit);
	add_potential(other_type, other_position, potential, max_distance);
}

// Adds a potential term relative to a single point.
// The edge-to-point offset from this unit (at the evaluated position) to
// `position` is passed to distance_to_bwcircle() with `max_distance`; only a
// negative result contributes, scaled by `potential`.
void UnitPotential::add_potential(Position position,double potential,int max_distance)
{
	const Position gap = edge_to_point_delta(unit_->getType(), position_, position);
	const double signed_distance = distance_to_bwcircle(gap, max_distance);
	// Nothing to add unless the signed distance is strictly negative.
	if (!(signed_distance < 0.0)) return;
	value_ += (signed_distance * potential);
}

// Adds an unbounded potential term relative to a point: the Euclidean length of
// the edge-to-point offset multiplied by `potential`. With a negative `potential`
// the value increases as the evaluated position gets closer to `position`.
void UnitPotential::add_potential(Position position,double potential)
{
	const Position offset = edge_to_point_delta(unit_->getType(), position_, position);
	const auto squared_length = offset.x * offset.x + offset.y * offset.y;
	const double length = std::sqrt(squared_length);
	value_ += (length * potential);
}

// Adds a repelling term for every unit in `units` that can currently attack this
// unit. Units that are `skip_unit`, incomplete, disabled or unpowered are ignored.
// The term uses the other unit's maximum weapon range against this unit
// (air or ground, depending on whether this unit is flying) plus a 24 pixel margin.
void UnitPotential::repel_units(const std::vector<Unit>& units,Unit skip_unit)
{
	for (Unit candidate : units) {
		if (candidate == skip_unit) continue;
		if (!candidate->isCompleted()) continue;
		if (is_disabled(candidate)) continue;
		if (!candidate->isPowered()) continue;
		
		const int threat_range = weapon_max_range(candidate, unit_->isFlying());
		// A negative range is treated as "cannot attack this unit".
		if (threat_range < 0) continue;
		add_potential(candidate, 1.0, threat_range + 24);
	}
}

void UnitPotential::repel_storms()
{
	double largest_storm_evasion_distance = 0.0;
	std::vector<Position> positions = micro_manager.list_existing_storm_positions();
	
	for (auto& position : positions) {
		Position delta = edge_to_point_delta(unit_->getType(), position_, position);
		double distance = distance_to_bwcircle(delta, WeaponTypes::Psionic_Storm.outerSplashRadius() + 32);
		largest_storm_evasion_distance = std::max(largest_storm_evasion_distance, -distance);
	}
	
	value_ += (-largest_storm_evasion_distance);
}

void UnitPotential::repel_detectors(double potential)
{
	double largest_detection_evasion_distance = 0.0;
	
	for (auto& enemy_unit : information_manager.enemy_units()) {
		int range = enemy_unit->detection_range();
		if (range >= 0 &&
			(!enemy_unit->unit->exists() || enemy_unit->unit->isPowered())) {
			Position delta = edge_to_edge_delta(unit_->getType(), position_, enemy_unit->type, enemy_unit->position);
			double distance = distance_to_bwcircle(delta, range + 32);
			largest_detection_evasion_distance = std::max(largest_detection_evasion_distance, -distance);
		}
	}
	
	value_ += ((-largest_detection_evasion_distance) * potential);
}

void UnitPotential::repel_emps()
{
	double largest_emp_evasion_distance = 0.0;
	
	for (auto& bullet : Broodwar->getBullets()) {
		if (bullet->getType() == BulletTypes::EMP_Missile && bullet->getTargetPosition().isValid()) {
			Position delta = edge_to_point_delta(unit_->getType(), position_, bullet->getTargetPosition());
			double distance = distance_to_square(delta, WeaponTypes::EMP_Shockwave.outerSplashRadius() + 32);
			largest_emp_evasion_distance = std::max(largest_emp_evasion_distance, -distance);
		}
	}
	
	value_ += (-largest_emp_evasion_distance);
}

void UnitPotential::repel_buildings()
{
	if (unit_->isFlying()) return;
	
	for (auto& entry : information_manager.all_units()) {
		const InformationUnit& information_unit = entry.second;
		if (information_unit.type.isBuilding() && !information_unit.flying) {
			add_potential(information_unit.type, information_unit.position, 1.0, 32);
		}
	}
}

// Penalises positions for which distance_to_terrain() reports less than 32.
// The penalty equals the shortfall. Does nothing when this unit is flying.
void UnitPotential::repel_terrain()
{
	if (unit_->isFlying()) return;
	
	const int shortfall = distance_to_terrain(position_) - 32;
	if (shortfall >= 0) return;
	// shortfall is negative here, so this lowers the value.
	value_ += static_cast<double>(shortfall);
}

// Adds kiting terms for every unit in `units` that can attack this unit.
// Units that are `skip_unit`, incomplete, disabled or unpowered are ignored.
//
// For each remaining threat:
//  - if this unit cannot attack it, repel at the threat's max range plus 24 pixels;
//  - otherwise, if the threat has a minimum range and this unit's current distance
//    is closer to that minimum range than to the threat's maximum range, attract
//    towards the threat's position;
//  - otherwise repel at this unit's own attack range minus 4 pixels.
void UnitPotential::kite_units(const std::vector<Unit>& units,Unit skip_unit)
{
	for (Unit threat : units) {
		if (threat == skip_unit) continue;
		if (!threat->isCompleted()) continue;
		if (is_disabled(threat)) continue;
		if (!threat->isPowered()) continue;
		
		const int threat_range = weapon_max_range(threat, unit_->isFlying());
		// A negative range is treated as "cannot attack this unit".
		if (threat_range < 0) continue;
		
		const int own_range = weapon_max_range(unit_, threat->isFlying());
		if (own_range < 0) {
			// Cannot fight back: stay outside the threat's range with a margin.
			add_potential(threat, 1.0, threat_range + 24);
			continue;
		}
		
		bool move_inside_min_range = false;
		const int threat_min_range = weapon_min_range(threat);
		if (threat_min_range > 0) {
			// Uses the unit's actual distance, not the evaluated position.
			const int current_distance = unit_->getDistance(threat);
			move_inside_min_range = (current_distance - threat_min_range < threat_range - current_distance);
		}
		
		if (move_inside_min_range) {
			add_potential(threat->getPosition(), -1.0);
		} else {
			add_potential(threat, 1.0, own_range - 4);
		}
	}
}
