package org.bwapi.proxy;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.util.HashSet;
import java.util.Set;

import org.bwapi.proxy.messages.BasicTypes.BoxedBoolean;
import org.bwapi.proxy.messages.GameMessages.StaticGameData;
import org.bwapi.proxy.messages.GameMessages.UnitId;
import org.bwapi.proxy.messages.Messages.FrameCommands;
import org.bwapi.proxy.messages.Messages.FrameMessage;
import org.bwapi.proxy.messages.Messages.PlayerText;
import org.bwapi.proxy.model.Bwta;
import org.bwapi.proxy.model.Game;
import org.bwapi.proxy.model.Player;
import org.bwapi.proxy.model.Position;
import org.bwapi.proxy.model.ROUnit;
import org.bwapi.proxy.model.Unit;
import org.bwapi.proxy.model.UnitType;
import org.bwapi.proxy.util.Pair;

import com.google.protobuf.Message;
import com.google.protobuf.AbstractMessage.Builder;

public class ProxyServer implements Runnable {
	/** Default server port; small port arguments are treated as offsets from this base. */
	public static final int PORT_BASE = 12345;

	/** Port arguments below this value are offsets from {@link #PORT_BASE}, not absolute ports. */
	private static final int MAX_PORT_OFFSET = 16;

	/** How long accept() blocks before the interrupt flag is re-checked, in milliseconds. */
	private static final int ACCEPT_TIMEOUT_MS = 2000;

	/** Size of the buffers wrapped around the socket streams, in bytes. */
	private static final int STREAM_BUFFER_SIZE = 50 * 1024;

	/** The heartbeat file is rewritten every this many frames. */
	private static final int HEARTBEAT_INTERVAL = 100;

	private final ProxyBotFactory factory;
	private final int port;
	/** File to write periodic game status to, or null to disable heartbeats. */
	private final String heartbeatFilename;

	/** Number of frames processed so far in the current game. */
	private int myCurrentFrameNum = 0;

	/**
	 * Creates a server that will create its bot via {@code factory}.
	 *
	 * @param factory           creates the bot once the first frame has been read
	 * @param port              port to listen on
	 * @param heartbeatFilename file to write periodic status to, or null for none
	 */
	public ProxyServer(ProxyBotFactory factory, int port, String heartbeatFilename) {
		this.port = port;
		this.factory = factory;
		this.heartbeatFilename = heartbeatFilename;
	}

	/**
	 * Command-line entry point.
	 *
	 * Usage: {@code ProxyServer <ProxyBotFactory class> [port] [heartbeatFile]}
	 */
	public static void main(String[] args) {
		ProxyBotFactory factory = getFactory(args);
		// args[0] is the factory class name (see getFactory), so the optional
		// port and heartbeat file follow it.
		int port = extractPort(args.length > 1 ? args[1] : null);
		String heartbeatFilename = args.length > 2 ? args[2] : null;
		new ProxyServer(factory, port, heartbeatFilename).run();
	}

	/**
	 * Interprets a port argument.
	 *
	 * @param arg null for {@link #PORT_BASE}; a value below 16 is added to
	 *            {@link #PORT_BASE}; any other value is used as the port itself
	 * @return the port to listen on
	 * @throws NumberFormatException if {@code arg} is not an integer
	 */
	public static int extractPort(String arg) {
		if (arg == null) {
			return PORT_BASE;
		}
		int value = Integer.parseInt(arg);
		return value < MAX_PORT_OFFSET ? PORT_BASE + value : value;
	}

	/**
	 * Instantiates the ProxyBotFactory class named by {@code args[0]} via its
	 * no-argument constructor.
	 *
	 * @throws RuntimeException if no class is given or it cannot be instantiated
	 */
	private static ProxyBotFactory getFactory(String[] args) {
		if (args.length == 0) {
			throw new RuntimeException("Need to specify a ProxyBotFactory class as a command-line argument!");
		}
		try {
			return (ProxyBotFactory) Class.forName(args[0]).getDeclaredConstructor().newInstance();
		} catch (Exception e) {
			System.out.println("Could not instantiate an object of class " + args[0]);
			e.printStackTrace();
			throw new RuntimeException(e);
		}
	}

	/**
	 * Starts the ProxyBot.
	 *
	 * Opens a server socket, waits for one client connection and runs a single
	 * game over it. The server socket is closed afterwards. Interrupting the
	 * thread stops the wait within {@link #ACCEPT_TIMEOUT_MS}, and the interrupt
	 * flag is left set for the caller.
	 */
	public void run() {
		ServerSocket serverSocket = null;
		try {
			serverSocket = new ServerSocket(port);
			// Periodically time out of accept() so an interrupt can be noticed.
			serverSocket.setSoTimeout(ACCEPT_TIMEOUT_MS);
			System.out.println("Waiting for client connection");
			// isInterrupted() does not clear the flag (unlike Thread.interrupted()),
			// so one interrupt reliably ends the loop and stays visible to callers.
			while (!Thread.currentThread().isInterrupted()) {
				Socket clientSocket;
				try {
					clientSocket = serverSocket.accept();
				} catch (SocketTimeoutException ste) {
					continue; // no client yet; re-check the interrupt flag
				}
				System.out.println("Client connected");
				// Only a single game is served per run.
				runGame(clientSocket);
				return;
			}
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			if (serverSocket != null) {
				try {
					serverSocket.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/**
	 * Reads one message framed by a varint length prefix
	 * (protobuf {@code mergeDelimitedFrom}) and builds it.
	 */
	@SuppressWarnings("unchecked")
	private <T> T readMessage(InputStream input, Builder<?> b) throws IOException {
		b.mergeDelimitedFrom(input);
		return (T) b.build();
	}

	/** Encodes {@code i} as 4 big-endian bytes (ByteBuffer's default order). */
	byte[] intToBytes(int i) {
		ByteBuffer bb = ByteBuffer.allocate(4);
		bb.putInt(i);
		return bb.array();
	}

	/**
	 * Writes {@code m} prefixed by its size as a 4-byte big-endian int, then
	 * flushes. Note this framing differs from the varint framing used by
	 * {@link #readMessage}; presumably the peer expects exactly this — confirm
	 * before changing either side.
	 */
	private void writeMessage(OutputStream output, Message m) throws IOException {
		output.write(intToBytes(m.getSerializedSize()));
		m.writeTo(output);
		output.flush();
	}

	/**
	 * Manages communication with StarCraft for one game and closes
	 * {@code socket} when done.
	 *
	 * Protocol: read StaticGameData, send an ACK, then for each frame read a
	 * FrameMessage and reply with FrameCommands until the game is over, the
	 * thread is interrupted or the connection fails.
	 */
	private void runGame(Socket socket) {
		ProxyBot bot = null;
		try {
			InputStream input = new BufferedInputStream(socket.getInputStream(), STREAM_BUFFER_SIZE);
			OutputStream output = new BufferedOutputStream(socket.getOutputStream(), STREAM_BUFFER_SIZE);

			// 1. Initial game information.
			StaticGameData data = readMessage(input, StaticGameData.newBuilder());
			myCurrentFrameNum = 0;
			Game g = Game.getInstance();
			g.init();
			g.readStaticGameData(data);

			// 2. Acknowledge the static data.
			writeMessage(output, BoxedBoolean.newBuilder().setIsWinner(false).build());

			String botName = null;
			String botCode = null;

			// 3. Per-frame updates.
			while (!Thread.currentThread().isInterrupted()) {
				FrameMessage frame = readMessage(input, FrameMessage.newBuilder());
				if (frame.hasTerrainInfo()) {
					Bwta.getInstance().setTerrainData(frame.getTerrainInfo());
				}
				g.readFrameMessage(frame);

				// The bot is created lazily, after the first frame has been read.
				if (bot == null) {
					bot = factory.getBot(g);
					Pair<String, String> nameAndCode = inferBotName(bot.getClass());
					if (nameAndCode != null) {
						botName = nameAndCode.getFirst();
						botCode = nameAndCode.getSecond();
					}
					bot.onStart();
				}

				processCallbacks(frame, bot, g);

				if (frame.getGameover()) {
					bot.onEnd(frame.hasIsWinner() && frame.getIsWinner());
					break;
				}

				bot.onFrame();
				if (heartbeatFilename != null && myCurrentFrameNum % HEARTBEAT_INTERVAL == 0) {
					writeHeartbeat(g);
				}

				FrameCommands.Builder cb = g.flushCommands();
				if (botName != null) {
					cb.setBotName(botName);
					cb.setBotCode(botCode);
				}
				writeMessage(output, cb.build()); // writeMessage flushes

				myCurrentFrameNum++;
			}

			if (Thread.currentThread().isInterrupted()) {
				System.out.println("Thread Interrupted.");
				return;
			}
			System.out.println("Game Ended");
		} catch (SocketException e) {
			System.out.println("StarCraft has disconnected");
			if (bot != null) {
				bot.onDroppedConnection();
			}
		} catch (Exception e) {
			e.printStackTrace();
		} finally {
			try {
				socket.close();
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
	}

	/**
	 * Rewrites the heartbeat file. It contains the current frame number, our
	 * supply used/total, and, for each distinct unit type among our units,
	 * {@code allUnitCount} for that type. Failures are reported but not fatal,
	 * since the heartbeat is informational only.
	 */
	private void writeHeartbeat(Game g) {
		PrintWriter pw = null;
		try {
			pw = new PrintWriter(heartbeatFilename);
			Player p = g.self();
			pw.println("Frame number: " + myCurrentFrameNum);
			pw.println("Supply: " + p.supplyUsed() + "/" + p.supplyTotal());
			Set<UnitType> reported = new HashSet<UnitType>();
			for (ROUnit u : p.getUnits()) {
				UnitType type = u.getType();
				// Set.add returns false for duplicates, so each type is printed once.
				if (reported.add(type)) {
					pw.println(type.getName() + ": " + p.allUnitCount(type));
				}
			}
		} catch (Exception e) {
			System.out.println("Couldn't write game status to file " + heartbeatFilename + ": " + e);
		} finally {
			if (pw != null) {
				pw.close();
			}
		}
	}

	/**
	 * Reads the bot's name and code from the first two lines of the classpath
	 * resource {@code /proxy-info.txt}, looked up via {@code clss}.
	 *
	 * @return (name, code) with surrounding whitespace trimmed, or null if the
	 *         resource is missing, has fewer than two lines, or cannot be read
	 */
	private Pair<String, String> inferBotName(Class<?> clss) {
		InputStream strm = clss.getResourceAsStream("/proxy-info.txt");
		if (strm == null) {
			return null;
		}
		try {
			BufferedReader reader = new BufferedReader(new InputStreamReader(strm, "UTF-8"));
			String name = reader.readLine();
			String code = reader.readLine();
			if (name == null || code == null) {
				return null;
			}
			return Pair.makePair(name.trim(), code.trim());
		} catch (IOException e) {
			return null;
		} finally {
			try {
				strm.close();
			} catch (IOException ignored) {
				// Nothing useful to do; the data (if any) has already been read.
			}
		}
	}

	/**
	 * Dispatches the per-frame event callbacks carried by {@code frame} to
	 * {@code bot}, in this order: text events, players leaving, unit events,
	 * nuke detection.
	 */
	private void processCallbacks(FrameMessage frame, ProxyBot bot, Game g) {
		for (String sentText : frame.getSentTextList()) {
			bot.onSendText(sentText);
		}
		for (PlayerText pt : frame.getReceivedTextsList()) {
			bot.onReceiveText(g.getPlayer(pt.getPlayer()), pt.getText());
		}
		for (int id : frame.getLeftplayersList()) {
			bot.onPlayerLeft(g.getPlayer(id));
		}

		processUnitCallbacksByProxy(frame, bot, g);

		if (frame.hasNukeDetect()) {
			bot.onNukeDetect(new Position(frame.getNukeDetect().getX(), frame.getNukeDetect().getY()));
		}
	}

	/**
	 * Dispatches the unit events reported in {@code frame} to {@code bot},
	 * wrapping each unit id in a new {@link Unit}. The dispatch order is
	 * deliberate; see the inline comments.
	 */
	private void processUnitCallbacksByProxy(FrameMessage frame, ProxyBot bot, Game game) {
		// Creation occurs before showing.
		// This is unreliable.
		for (UnitId u : frame.getCreatedUnitsList()) {
			bot.onUnitCreate(new Unit(u.getId()));
		}
		for (UnitId u : frame.getShownUnitsList()) {
			bot.onUnitShow(new Unit(u.getId()));
		}
		// Renegades occur before morphs (for Infested Command Centers).
		for (UnitId u : frame.getRenegadedUnitsList()) {
			bot.onUnitRenegade(new Unit(u.getId()));
		}
		for (UnitId u : frame.getMorphedUnitsList()) {
			bot.onUnitMorph(new Unit(u.getId()));
		}
		// When a unit gets killed, hiding comes before being destroyed.
		for (UnitId u : frame.getHiddenUnitsList()) {
			bot.onUnitHide(new Unit(u.getId()));
		}
		// This is unreliable.
		for (UnitId u : frame.getDestroyedUnitsList()) {
			bot.onUnitDestroy(new Unit(u.getId()));
		}
	}
}
