package undermind;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.bwapi.proxy.model.Color;
import org.bwapi.proxy.model.Game;
import org.bwapi.proxy.model.Position;
import org.bwapi.proxy.model.ROUnit;
import org.bwapi.proxy.model.TilePosition;
import org.bwapi.proxy.model.Unit;

import edu.berkeley.nlp.starcraft.util.FastPriorityQueue;


/**
 * Tile-level A* path finder over the current BWAPI map.
 *
 * <p>On construction the map is pre-processed into {@link #ttdm}, each tile's BFS distance to the
 * nearest unwalkable tile. {@link #transitionCost} uses it to penalise routes that hug walls.
 * Three searches share one A* core and differ only in how neighbours are generated:
 * <ul>
 * <li>{@link #getPath(TilePosition, TilePosition)}: 8-connected, terrain walkability only;</li>
 * <li>{@link #getPath2(TilePosition, TilePosition)}: 8-connected, and also skips tiles holding a
 * building or three or more units;</li>
 * <li>{@link #getPath(TilePosition, TilePosition, List)}: 4-connected, restricted to
 * {@link #grid}, skipping banned tiles and requiring extra clearance (wall-in planning).</li>
 * </ul>
 * The first two give up after {@link #SEARCH_TIMEOUT_MS}; the banned-list search has no time limit.
 */
public class AStarSearch {
	/**
	 * Tiles that {@link #isSuperWalkable} accepts without the clearance check. Never assigned in
	 * this class; callers set it before using the banned-list search.
	 */
	public List<TilePosition> chokeBuffer;
	/** Wall-clock time (ms) at which the current path request started; read by the timeout check. */
	long time1 = 0;
	Game mGame;
	int mapheight;
	int mapwidth;
	/** Scale of the wall-hugging penalty: entering a tile at {@link #ttdm} distance d costs an extra edgePenalty / d. */
	static final double edgePenalty = 10.0;
	/**
	 * Tiles the banned-list search may enter. Never assigned in this class; callers must set it
	 * before that search is used.
	 */
	public Set<TilePosition> grid;
	/** Extra margin, in walk-tile coordinates, that {@link #isSuperWalkable} requires to be walkable around a tile. */
	public int walkableTolerance = 4;

	/**
	 * Tile to BFS distance (in tile steps, diagonal steps counting as one) from the nearest
	 * unwalkable tile. Unwalkable tiles map to 0.
	 */
	Map<TilePosition, Integer> ttdm;

	/** Wall-clock budget for the timed searches, in milliseconds. */
	private static final long SEARCH_TIMEOUT_MS = 2000;

	/** Successor-generation modes understood by {@link #runSearch}. */
	private static final int SUCCESSORS_TERRAIN = 0;
	private static final int SUCCESSORS_UNIT_AWARE = 1;
	private static final int SUCCESSORS_BANNED = 2;

	/** Captures the current game's map size and precomputes {@link #ttdm}. */
	public AStarSearch() {
		mGame = Game.getInstance();
		mapheight = mGame.getMapHeight();
		mapwidth = mGame.getMapWidth();
		ttdm = new HashMap<TilePosition, Integer>();
		preprocessMap();
	}

	/** A search-tree node: a tile, the cost of the path that reached it, and its heuristic estimate. */
	class Node {
		/** Predecessor on the path from the start, or null for the start node. */
		Node parent;
		TilePosition state;
		/** Accumulated path cost g from the start tile. */
		double backCost;
		/** Heuristic estimate h of the remaining cost to the goal. */
		double heuristic;

		public Node(Node parent, TilePosition state, double backCost, double heuristic) {
			this.parent = parent;
			this.state = state;
			this.backCost = backCost;
			this.heuristic = heuristic;
		}
	}

	/**
	 * Fills {@link #ttdm} with a multi-source BFS from every unwalkable tile.
	 *
	 * <p>Neighbours come from {@link #fastSuccessors(TilePosition)}, so diagonal steps count the
	 * same as straight ones. The result is therefore an approximation of true clearance, which is
	 * good enough for the edge penalty.
	 */
	public void preprocessMap() {
		LinkedList<TilePosition> fringe = new LinkedList<TilePosition>();

		// Seed the BFS with every unwalkable tile at distance 0.
		for (int x = 0; x < mapwidth; x++) {
			for (int y = 0; y < mapheight; y++) {
				TilePosition tile = new TilePosition(x, y);
				if (!isWalkable(tile)) {
					fringe.add(tile);
					ttdm.put(tile, 0);
				}
			}
		}

		// ttdm's key set doubles as the "already discovered" set.
		while (!fringe.isEmpty()) {
			TilePosition tile = fringe.removeFirst();
			int nextDistance = ttdm.get(tile) + 1;
			for (TilePosition succ : fastSuccessors(tile)) {
				if (succ == null || ttdm.containsKey(succ)) continue;
				ttdm.put(succ, nextDistance);
				fringe.addLast(succ);
			}
		}
	}

	/**
	 * Finds a path over terrain only (8-connected).
	 *
	 * @return tiles from start to end inclusive, or null if none was found within the time budget
	 */
	public List<TilePosition> getPath(TilePosition start, TilePosition end) {
		time1 = System.currentTimeMillis();
		return buildPath(runSearch(start, end, SUCCESSORS_TERRAIN, null, true));
	}

	/**
	 * Finds a path that also avoids tiles occupied by a building or by three or more units
	 * (8-connected, see {@link #fastSuccessors2}).
	 *
	 * @return tiles from start to end inclusive, or null if none was found within the time budget
	 */
	public List<TilePosition> getPath2(TilePosition start, TilePosition end) {
		time1 = System.currentTimeMillis();
		return buildPath(runSearch(start, end, SUCCESSORS_UNIT_AWARE, null, true));
	}

	/**
	 * Finds a 4-connected path inside {@link #grid} that avoids {@code banned} tiles and
	 * low-clearance tiles (see {@link #isSuperWalkable}). This search has no time limit.
	 * Requires {@link #grid}, {@code banned} and preferably {@link #chokeBuffer} to be non-null.
	 *
	 * @return tiles from start to end inclusive, or null if the goal is unreachable
	 */
	public List<TilePosition> getPath(TilePosition start, TilePosition end, List<TilePosition> banned) {
		time1 = System.currentTimeMillis();
		return buildPath(runSearch(start, end, SUCCESSORS_BANNED, banned, false));
	}

	/**
	 * Shared A* core with a closed set.
	 *
	 * @param successorMode one of the SUCCESSORS_* constants, selecting the neighbour function
	 * @param banned tiles to avoid; only read in {@link #SUCCESSORS_BANNED} mode
	 * @param timed whether to abort once {@link #SEARCH_TIMEOUT_MS} has elapsed since {@link #time1}
	 * @return the goal node (follow {@link Node#parent} back to the start), or null if no path was found
	 */
	private Node runSearch(TilePosition start, TilePosition end, int successorMode,
			List<TilePosition> banned, boolean timed) {
		// An off-map start would index outside the closed set; report "no path" instead.
		if (!inBounds(start)) return null;

		FastPriorityQueue<Node> fringe = new FastPriorityQueue<Node>();
		boolean[][] closed = new boolean[mapwidth][mapheight];
		enqueueNode(new Node(null, start, 0, computeHeuristic(start, end)), fringe);

		int endx = end.x();
		int endy = end.y();

		while (!fringe.isEmpty()) {
			if (timed && System.currentTimeMillis() - time1 > SEARCH_TIMEOUT_MS) {
				return null;
			}
			Node node = fringe.next();
			TilePosition state = node.state;
			int x = state.x();
			int y = state.y();

			// A tile may be queued several times; only its first (cheapest) pop is expanded.
			if (closed[x][y]) continue;
			if (x == endx && y == endy) {
				return node;
			}
			closed[x][y] = true;

			// Every successor function bounds-checks its tiles, so indexing closed[][] is safe.
			for (TilePosition pos : successors(state, successorMode, banned)) {
				if (pos == null || closed[pos.x()][pos.y()]) continue;
				Node child = new Node(node, pos, node.backCost + transitionCost(state, pos),
						computeHeuristic(pos, end));
				enqueueNode(child, fringe);
			}
		}
		return null;
	}

	/** Dispatches to the neighbour function selected by {@code mode}. */
	private TilePosition[] successors(TilePosition state, int mode, List<TilePosition> banned) {
		switch (mode) {
		case SUCCESSORS_UNIT_AWARE:
			return fastSuccessors2(state);
		case SUCCESSORS_BANNED:
			return fastSuccessors(state, banned);
		default:
			return fastSuccessors(state);
		}
	}

	/** Walks parent links from {@code goal} back to the start; returns the tiles start-first, or null for a null goal. */
	private List<TilePosition> buildPath(Node goal) {
		if (goal == null) return null;
		List<TilePosition> path = new ArrayList<TilePosition>();
		for (Node n = goal; n != null; n = n.parent) {
			path.add(n.state);
		}
		Collections.reverse(path);
		return path;
	}

	/**
	 * Queues {@code node} keyed by its f = g + h value.
	 * The value is negated, presumably because FastPriorityQueue pops the highest priority first.
	 */
	private void enqueueNode(Node node, FastPriorityQueue<Node> fringe) {
		fringe.setPriority(node, -(node.backCost + node.heuristic));
	}

	/**
	 * 0.75 times the Manhattan distance. This never exceeds the cheapest move cost (1 per straight
	 * step, 1.5 per diagonal, penalties non-negative), so the heuristic stays admissible.
	 */
	private double computeHeuristic(TilePosition state, TilePosition end) {
		int xd = Math.abs(state.x() - end.x());
		int yd = Math.abs(state.y() - end.y());
		return (xd + yd) * 0.75;
	}

	/**
	 * Cost of stepping from {@code state} to {@code pos}: 0 for no move, 1 for a straight step,
	 * 1.5 for a diagonal one, plus {@link #edgePenalty} divided by pos's {@link #ttdm} distance.
	 * A tile with no ttdm entry (e.g. on a map without unwalkable tiles) gets no penalty.
	 */
	public double transitionCost(TilePosition state, TilePosition pos) {
		boolean sameX = state.x() == pos.x();
		boolean sameY = state.y() == pos.y();
		double cost;
		if (sameX && sameY) {
			cost = 0.0;
		} else if (sameX || sameY) {
			cost = 1.0;
		} else {
			cost = 1.5;
		}

		// Penalty for being close to edges; 0 distance yields +Infinity, effectively forbidding the tile.
		Integer clearance = ttdm.get(pos);
		if (clearance != null) {
			cost += edgePenalty / clearance;
		}
		return cost;
	}

	/**
	 * Returns the 8 neighbours of {@code state}; slot k holds the k-th (dx, dy) direction in
	 * row-major order (skipping (0,0)), or null if that tile is not {@link #isWalkable}.
	 */
	public TilePosition[] fastSuccessors(TilePosition state) {
		TilePosition[] res = new TilePosition[8];
		int k = 0;
		int x = state.x();
		int y = state.y();

		for (int dx = -1; dx <= 1; ++dx) {
			for (int dy = -1; dy <= 1; ++dy) {
				if (dx == 0 && dy == 0) continue;
				TilePosition newpos = new TilePosition(x + dx, y + dy);
				if (isWalkable(newpos)) {
					res[k] = newpos;
				}
				k++;
			}
		}
		return res;
	}

	/**
	 * Like {@link #fastSuccessors(TilePosition)} but also drops tiles occupied by a building or by
	 * three or more units. Each unit counts 1 and each building 1001; a total above 2 blocks the tile.
	 * Blocked tiles do not consume a slot, so slot positions do not map to directions here.
	 */
	public TilePosition[] fastSuccessors2(TilePosition state) {
		TilePosition[] res = new TilePosition[8];
		int k = 0;
		int x = state.x();
		int y = state.y();

		for (int dx = -1; dx <= 1; ++dx) {
			for (int dy = -1; dy <= 1; ++dy) {
				if (dx == 0 && dy == 0) continue;
				TilePosition newpos = new TilePosition(x + dx, y + dy);
				int count = 0;
				for (ROUnit u : mGame.unitsOnTile(newpos)) {
					count++;
					if (u.getType().isBuilding())
						count += 1000;
					if (count > 2) break; // already blocked; no need to inspect further units
				}
				if (count > 2)
					continue;
				if (isWalkable(newpos)) {
					res[k] = newpos;
				}
				k++;
			}
		}
		return res;
	}

	/**
	 * Returns the number of tiles on the terrain-only path from s to g (both endpoints included),
	 * or {@link Double#POSITIVE_INFINITY} when no path is found (unreachable or timed out).
	 */
	public double getGroundDistance(TilePosition s, TilePosition g) {
		List<TilePosition> path = getPath(s, g);
		return path == null ? Double.POSITIVE_INFINITY : path.size();
	}

	/** Draws a small white circle at the centre of every tile on {@code path}; a null path draws nothing. */
	public void drawPath(List<TilePosition> path) {
		if (path == null) return;
		for (TilePosition tp : path) {
			mGame.drawCircleMap(Position.centerOfTile(tp), 3, Color.WHITE, true);
		}
	}

	/** Debug overlay: writes each tile's {@link #ttdm} distance at the tile's centre. */
	public void drawttdm() {
		for (Map.Entry<TilePosition, Integer> ti : ttdm.entrySet()) {
			mGame.drawTextMap(Position.centerOfTile(ti.getKey()), ti.getValue().toString());
		}
	}

	/** True if {@code tp} lies inside the map's tile bounds. */
	private boolean inBounds(TilePosition tp) {
		return tp.x() >= 0 && tp.x() < mapwidth && tp.y() >= 0 && tp.y() < mapheight;
	}

	/**
	 * True if {@code tp} is on the map and every walk-tile coordinate in its 4x4 block
	 * (tile coordinates scaled by 4) is walkable.
	 */
	public boolean isWalkable(TilePosition tp) {
		// Check bounds first so off-map neighbours never reach Game.isWalkable.
		if (!inBounds(tp)) return false;
		int wx = tp.x() * 4;
		int wy = tp.y() * 4;
		for (int i = 0; i < 4; i++) {
			for (int j = 0; j < 4; j++) {
				if (!mGame.isWalkable(wx + i, wy + j)) {
					return false;
				}
			}
		}
		return true;
	}

	/**
	 * Returns 4-connected neighbours of {@code state} that are in {@link #grid}, not in
	 * {@code banned}, and {@link #isSuperWalkable}. Non-qualifying directions are left null.
	 * WARNING: only meant for wall-in planning; it reads {@link #grid} (must be non-null).
	 */
	public TilePosition[] fastSuccessors(TilePosition state, List<TilePosition> banned) {
		TilePosition[] res = new TilePosition[4];
		int k = 0;
		int x = state.x();
		int y = state.y();

		for (int dx = -1; dx <= 1; ++dx) {
			for (int dy = -1; dy <= 1; ++dy) {
				// Orthogonal moves only.
				if (dx == 0 && dy == 0) continue;
				if (dx != 0 && dy != 0) continue;
				TilePosition newpos = new TilePosition(x + dx, y + dy);
				if (banned.contains(newpos) || !grid.contains(newpos)) continue;
				if (isSuperWalkable(newpos)) {
					res[k] = newpos;
				}
				k++;
			}
		}
		return res;
	}

	/**
	 * True if {@code tp} is on the map and either listed in {@link #chokeBuffer} or has every
	 * walk-tile coordinate within {@link #walkableTolerance} of its 4x4 block walkable.
	 */
	public boolean isSuperWalkable(TilePosition tp) {
		// Bounds first: an off-map tile accepted here would be indexed by the search's closed set.
		if (!inBounds(tp)) return false;
		if (chokeBuffer != null && chokeBuffer.contains(tp)) return true;
		int wx = tp.x() * 4;
		int wy = tp.y() * 4;
		for (int i = -walkableTolerance; i < walkableTolerance + 4; i++) {
			for (int j = -walkableTolerance; j < walkableTolerance + 4; j++) {
				if (!mGame.isWalkable(wx + i, wy + j)) {
					return false;
				}
			}
		}
		return true;
	}

	/**
	 * Orders {@code u} to move along a unit-aware path toward {@code goal}.
	 * The target is the 4th path tile if the path is longer than 3 tiles, else the 2nd; no order is
	 * issued when the path is just the start tile. The path is also drawn. If no path is found,
	 * the unit is sent straight to {@code goal}.
	 */
	public void makeUnitPath(Unit u, TilePosition goal) {
		List<TilePosition> path = getPath2(u.getTilePosition(), goal);
		if (path != null) {
			if (path.size() > 3) u.move(path.get(3));
			else if (path.size() > 1) u.move(path.get(1));
			drawPath(path);
		}
		else u.move(goal);
	}

}
