/*	-----------------------------------------------------------------------------
	M A A S C R A F T

	StarCraft: Brood War - Bot

	Author: Dennis Soemers
	Maastricht University
	-----------------------------------------------------------------------------
*/

#include <algorithm>
#include <memory>

#include "CommonIncludes.h"

#include "ArrayUtils.h"
#include "Clustering.h"
#include "Distances.h"
#include "MapAnalyser.h"
#include "OpponentTracker.h"
#include "UnitUtils.h"
#include "VectorUtils.h"

#pragma warning( push )
#pragma warning( disable : 4018 )
#pragma warning( disable : 4244 )

using namespace BWAPI;
using namespace std;

/*
	Utility functions, only used in this file. Therefore not declared in Clustering.h!
*/

// Returns a pointer to an array of bools indicating whether or not certain Units in a given search space are within a given
// radius of each other.
//
// The resulting array can be used together with the indices at which 2 Units occur in the searchSpace vector to test whether 2 specific Units
// are within range of each other afterwards by passing the array and the indices to the isWithinDistance() function
bool* constructDistanceArray(const Unitset& searchSpace, const double radius, bool sameRegionOnly = false)
{
	OpponentTracker* opponentTracker = OpponentTracker::Instance();
	MapAnalyser* mapAnalyser = MapAnalyser::Instance();
	const int n = searchSpace.size();
	const double radiusSquared = radius*radius;

	// comparing first Unit to all others will require n elements. Then, second Unit to all others requires (n - 1) elements. Etc.
	// Last element compares last Unit to itself, requiring only 1 element.
	//
	// So we need n + (n - 1) + (n - 2) + ... + (1) elements = the (n)th Triangular Number ( http://en.wikipedia.org/wiki/Triangular_number )
	// = n * (n + 1) / 2		(guarantueed to be an integer result, since n * (n + 1) is multiplication of an even with an uneven number, resulting in even

	const int arr_size = n * (n + 1) / 2;

	bool* results = new bool[arr_size]();

	int high_index;
	int index = 0;
	for(int low_index = 0; low_index < n; ++low_index)
	{
		results[index] = true;			// unit should always trivially be within range of itself
		++index;

		Position pos1 = searchSpace[low_index]->getPosition();
		if(pos1 == Positions::Unknown)
			pos1 = opponentTracker->getLastPosition(searchSpace[low_index]);

		if(sameRegionOnly)
		{
			for(high_index = low_index + 1; high_index < n; ++high_index)
			{
				Position pos2 = searchSpace[high_index]->getPosition();
				if(pos2 == Positions::Unknown)
					pos2 = opponentTracker->getLastPosition(searchSpace[high_index]);

				results[index] = (mapAnalyser->positionsInSameRegion(pos1, pos2) &&
									Distances::withinSquaredDistanceInt(pos1.x, pos1.y, pos2.x, pos2.y, radiusSquared));
				++index;
			}
		}
		else
		{
			for(high_index = low_index + 1; high_index < n; ++high_index)
			{
				Position pos2 = searchSpace[high_index]->getPosition();
				if(pos2 == Positions::Unknown)
					pos2 = opponentTracker->getLastPosition(searchSpace[high_index]);

				results[index] = Distances::withinSquaredDistanceInt(pos1.x, pos1.y, pos2.x, pos2.y, radiusSquared);
				++index;
			}
		}
	}

	return results;
}

// Looks up, in a table built by constructDistanceArray() for n data points, whether the data points
// at index_1 and index_2 are within range of each other. The order of the two indices does not matter.
//
// The table stores, for every index i, the entries for the pairs (i, i), (i, i + 1), ..., (i, n - 1)
// back to back. Row i is therefore preceded by n + (n - 1) + ... + (n - i + 1) entries, which equals
// i * n - i * (i - 1) / 2.
//
// Returns false if either index lies outside [0, n) instead of reading outside the table.
bool isWithinDistance(const bool arr[], int index_1, int index_2, const int n)
{
	if(index_1 > index_2)
		std::swap(index_1, index_2);

	// now we know for sure that index_1 <= index_2, so it suffices to check the lower bound
	// of the smaller index and the upper bound of the larger index
	if(index_1 < 0 || index_2 >= n)
		return false;

	// closed form of the offset of row index_1; i * (i - 1) is always even, so the division is exact
	const int rowStart = (index_1 * n) - (index_1 * (index_1 - 1) / 2);

	return arr[rowStart + (index_2 - index_1)];
}

// Collects, in ascending order, the indices of all Units in searchSpace which distanceArray marks as
// being in range of the Unit at unit_index.
//
// The lookup for the pair (unit_index, unit_index) is part of the scan, so for a table built by
// constructDistanceArray() the result WILL contain unit_index itself.
vector<int> const getUnitsWithinDistance(const int unit_index, const Unitset& searchSpace, bool distanceArray[])
{
	vector<int> inRange;
	const int unitCount = searchSpace.size();

	int candidate = 0;
	while(candidate < unitCount)
	{
		const bool closeEnough = isWithinDistance(distanceArray, unit_index, candidate, unitCount);
		if(closeEnough)
			inRange.push_back(candidate);

		++candidate;
	}

	return inRange;
}

/*
	Implementation of Clustering.h

	Algorithm used is Density-Based Spatial Clustering of Applications with Noise (DBSCAN)
	as described here: http://en.wikipedia.org/wiki/DBSCAN
*/

// Returns the cluster of Units from searchSpace which is grown from the given root Unit.
//
// Two Units count as neighbours if they are within radius of each other. The root only seeds a cluster
// if at least clusterThreshold Units (the root itself included) are within radius of it. The cluster is
// then expanded through every reached Unit which itself has at least clusterThreshold Units in range.
//
// Returns an empty Unitset if root does not occur in searchSpace, or if root has too few Units in range.
Unitset Clustering::getClusterForUnit(const Unit root, const Unitset& searchSpace, 
									const double radius, const size_t clusterThreshold)
{
	const int searchSize = static_cast<int>(searchSpace.size());

	// find index of root in searchSpace
	int rootIndex = -1;
	for(int i = 0; i < searchSize; ++i)
	{
		if(root == searchSpace[i])
		{
			rootIndex = i;
			break;
		}
	}

	// Without a valid index, the code below would index searchSpace and the distance table with -1.
	// This check therefore has to be active in every build; only the warning is debug-only.
	// Checking before building the distance table also avoids the quadratic work in this case.
	if(rootIndex == -1)
	{
#ifdef MAASCRAFT_DEBUG
		LOG_WARNING("Clustering::getClusterForUnit() called with root not in searchSpace")
#endif
		return Unitset();
	}

	// owned by this function; released automatically on every exit path (including exceptions)
	const std::unique_ptr<bool[]> distanceArray(constructDistanceArray(searchSpace, radius));

	// Possible statuses for Units:
	//	0 = default/unvisited
	//	1 = visited
	//	2 = in cluster (and therefore also visited)
	std::vector<int> status(static_cast<size_t>(searchSize), 0);

	Unitset cluster;
	std::vector<int> neighbourPts = getUnitsWithinDistance(rootIndex, searchSpace, distanceArray.get());

	if(neighbourPts.size() >= clusterThreshold)
	{
		// expand cluster
		cluster.push_back(searchSpace[rootIndex]);
		status[rootIndex] = 2;

		// neighbourPts grows while we iterate over it, so it is walked by index and its size is re-read every step
		for(size_t j = 0; j < neighbourPts.size(); ++j)
		{
			const int neighbourIndex = neighbourPts[j];

			if(status[neighbourIndex] == 0)		// not visited
			{
				status[neighbourIndex] = 1;
				const std::vector<int> neighbourPtsPrime = getUnitsWithinDistance(neighbourIndex, searchSpace, distanceArray.get());
				if(neighbourPtsPrime.size() >= clusterThreshold)
					VectorUtils::append(neighbourPts, neighbourPtsPrime);
			}

			if(status[neighbourIndex] != 2)		// not in the cluster yet
			{
				status[neighbourIndex] = 2;
				cluster.push_back(searchSpace[neighbourIndex]);
			}
		}
	}

	return cluster;
}

// Groups the Units in searchSpace into clusters using DBSCAN.
//
// Two Units count as neighbours if they are within radius of each other. An unvisited Unit seeds a new
// cluster if at least clusterThreshold Units (the Unit itself included) are within radius of it. The
// cluster is then expanded through every reached Unit which itself has at least clusterThreshold Units
// in range.
//
// Returns one Unitset per cluster found. Units which end up in no cluster do not appear in the result.
vector<Unitset> Clustering::getClustersForUnits(const Unitset& searchSpace, const double radius, const size_t clusterThreshold)
{
	const size_t searchSize = searchSpace.size();

	// owned by this function; released automatically on every exit path (including exceptions)
	const std::unique_ptr<bool[]> distanceArray(constructDistanceArray(searchSpace, radius));

	// Possible statuses for Units:
	//	0 = default/unvisited
	//	1 = visited
	//	2 = in cluster (and therefore also visited)
	std::vector<int> status(searchSize, 0);

	std::vector<Unitset> clusters;

	for(size_t i = 0; i < searchSize; ++i)
	{
		if(status[i] != 0)		// already visited
			continue;

		status[i] = 1;
		std::vector<int> neighbourPts = getUnitsWithinDistance(static_cast<int>(i), searchSpace, distanceArray.get());

		if(neighbourPts.size() < clusterThreshold)
			continue;			// too few Units in range to seed a cluster

		Unitset C;

		// expand cluster
		C.push_back(searchSpace[i]);
		status[i] = 2;

		// neighbourPts grows while we iterate over it, so it is walked by index and its size is re-read every step
		for(size_t j = 0; j < neighbourPts.size(); ++j)
		{
			const int neighbourIndex = neighbourPts[j];

			if(status[neighbourIndex] == 0)		// not visited
			{
				status[neighbourIndex] = 1;
				const std::vector<int> neighbourPtsPrime = getUnitsWithinDistance(neighbourIndex, searchSpace, distanceArray.get());
				if(neighbourPtsPrime.size() >= clusterThreshold)
					VectorUtils::append(neighbourPts, neighbourPtsPrime);
			}

			if(status[neighbourIndex] != 2)		// not in any cluster yet
			{
				status[neighbourIndex] = 2;
				C.push_back(searchSpace[neighbourIndex]);
			}
		}

		clusters.push_back(std::move(C));
	}

	return clusters;
}

// Groups the Units in searchSpace into clusters using DBSCAN, where two Units only count as neighbours
// if they are within radius of each other AND pass the same-region test of constructDistanceArray().
//
// An unvisited Unit seeds a new cluster if it has at least clusterThreshold neighbours (the Unit itself
// included). The cluster is then expanded through every reached Unit which itself has at least
// clusterThreshold neighbours.
//
// Returns one Unitset per cluster found. Units which end up in no cluster do not appear in the result.
vector<Unitset> Clustering::getClustersForUnitsSameRegion(const Unitset& searchSpace, const double radius, const size_t clusterThreshold)
{
	const size_t searchSize = searchSpace.size();

	// owned by this function; released automatically on every exit path (including exceptions)
	// third argument: restrict neighbours to pairs of Units in the same region
	const std::unique_ptr<bool[]> distanceArray(constructDistanceArray(searchSpace, radius, true));

	// Possible statuses for Units:
	//	0 = default/unvisited
	//	1 = visited
	//	2 = in cluster (and therefore also visited)
	std::vector<int> status(searchSize, 0);

	std::vector<Unitset> clusters;

	for(size_t i = 0; i < searchSize; ++i)
	{
		if(status[i] != 0)		// already visited
			continue;

		status[i] = 1;
		std::vector<int> neighbourPts = getUnitsWithinDistance(static_cast<int>(i), searchSpace, distanceArray.get());

		if(neighbourPts.size() < clusterThreshold)
			continue;			// too few neighbours to seed a cluster

		Unitset C;

		// expand cluster
		C.push_back(searchSpace[i]);
		status[i] = 2;

		// neighbourPts grows while we iterate over it, so it is walked by index and its size is re-read every step
		for(size_t j = 0; j < neighbourPts.size(); ++j)
		{
			const int neighbourIndex = neighbourPts[j];

			if(status[neighbourIndex] == 0)		// not visited
			{
				status[neighbourIndex] = 1;
				const std::vector<int> neighbourPtsPrime = getUnitsWithinDistance(neighbourIndex, searchSpace, distanceArray.get());
				if(neighbourPtsPrime.size() >= clusterThreshold)
					VectorUtils::append(neighbourPts, neighbourPtsPrime);
			}

			if(status[neighbourIndex] != 2)		// not in any cluster yet
			{
				status[neighbourIndex] = 2;
				C.push_back(searchSpace[neighbourIndex]);
			}
		}

		clusters.push_back(std::move(C));
	}

	return clusters;
}

#pragma warning( pop )