from settings import *

# the class we use to store the map and to run dynamic-programming value and policy updates on it
class Grid:
    """Grid world that stores a map and runs dynamic-programming updates on it.

    The map file is comma separated, one grid row per line. Each cell token is
    'X' (blocked), 'T' (terminal) or a number giving the reward of a walkable
    cell. STATE_WALKABLE / STATE_BLOCKED / STATE_TERMINAL, LEGAL_ACTIONS and
    RL_GAMMA are module-level names, presumably supplied by ``settings``.

    update_values() performs one Bellman-equation sweep under the current
    policy; update_policy() makes the policy greedy with respect to the
    current value estimates.
    """

    # set up all the default values for the grid and read in the map from a given file
    def __init__(self, filename):
        self.__values  = []         # values[row][col]  = current value estimate for state (row,col)
        self.__rewards = []         # rewards[row][col] = reward of state (row,col), added in that state's backup
        self.__grid    = []         # grid[row][col]    = STATE_WALKABLE, STATE_BLOCKED or STATE_TERMINAL
        self.__policy  = []         # policy[row][col][action] = probability of taking LEGAL_ACTIONS[action] at state (row,col)
        self.__rows    = 0          # number of rows in the grid
        self.__cols    = 0          # number of columns in the grid
        self.__load_data(filename)  # load the grid data from a given file
        self.__set_initial_values() # every state starts with value estimate 0
        self.__set_initial_policy() # every walkable state starts with an equiprobable policy

    def rows(self):                 return self.__rows
    def cols(self):                 return self.__cols
    # NOTE: returns the internal 2D list itself, not a copy
    def get_values(self):           return self.__values
    def get_state(self, r, c):      return self.__grid[r][c]
    def get_value(self, r, c):      return self.__values[r][c]
    def get_reward(self, r, c):     return self.__rewards[r][c]
    def get_policy(self, r, c):     return self.__policy[r][c]
    def get_min_value(self):        return min(min(row) for row in self.__values)
    def get_max_value(self):        return max(max(row) for row in self.__values)

    # loads the grid data from a given file name
    def __load_data(self, filename):
        """Parse the comma-separated map file into the grid and reward tables.

        'X' -> blocked cell with reward 0, 'T' -> terminal cell with reward 0,
        any other token -> walkable cell whose reward is the token as a float
        (tokens that merely contain 'X', e.g. 'X1', get reward 0).
        Blank lines, such as a trailing empty line, are ignored.

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a walkable token is not a valid float.
        """
        # 'with' guarantees the file is closed even if parsing raises
        with open(filename, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                # a blank line would otherwise reach float('') and raise
                if not line:
                    continue
                grid_row, reward_row = [], []
                for token in line.split(","):
                    token = token.strip()
                    if token == 'X':
                        grid_row.append(STATE_BLOCKED)
                        reward_row.append(0)
                    elif token == 'T':
                        grid_row.append(STATE_TERMINAL)
                        reward_row.append(0)
                    else:
                        grid_row.append(STATE_WALKABLE)
                        reward_row.append(0 if 'X' in token else float(token))
                self.__grid.append(grid_row)
                self.__rewards.append(reward_row)
        # the column count is taken from the first row; an empty file gives a 0x0 grid
        self.__rows = len(self.__grid)
        self.__cols = len(self.__grid[0]) if self.__grid else 0

    # sets the initial value estimate of the state so that each state has value 0
    def __set_initial_values(self):
        self.__values = [[0] * self.cols() for _ in range(self.rows())]

    # sets the initial equiprobable policy for all states in the grid
    def __set_initial_policy(self):
        """Build policy[row][col][action_index] (index into LEGAL_ACTIONS).

        Walkable states get probability 1/k on each of their k legal actions.
        Non-walkable states, and walkable states with no legal action at all,
        get an all-zero policy (there is nothing to move to).
        """
        num_actions = len(LEGAL_ACTIONS)
        policy = []
        for r in range(self.rows()):
            policy_row = []
            for c in range(self.cols()):
                if self.get_state(r, c) != STATE_WALKABLE:
                    policy_row.append([0] * num_actions)
                    continue
                # binary array: legal[i] == 1 iff LEGAL_ACTIONS[i] is legal here
                legal = self.__get_action_legality(r, c)
                num_legal = sum(legal)
                if num_legal == 0:
                    # e.g. a walkable cell boxed in by walls; avoids division by zero
                    policy_row.append([0] * num_actions)
                else:
                    policy_row.append([i / num_legal for i in legal])
            policy.append(policy_row)
        self.__policy = policy

    # check whether we can make a action from a given state
    def __is_legal_action(self, row, col, action):
        # action is a (row delta, col delta) pair
        new_row, new_col = row + action[0], col + action[1]
        # moving off the grid is never legal
        if new_row < 0 or new_col < 0 or new_row >= self.rows() or new_col >= self.cols(): return False
        # legal unless the destination cell is blocked (terminal cells may be entered)
        return self.get_state(new_row, new_col) != STATE_BLOCKED

    # returns a binary array of length len(LEGAL_ACTIONS)
    # array[i] == 1 if LEGAL_ACTIONS[i] is legal, and 0 if it is not
    def __get_action_legality(self, row, col):
        return [1 if self.__is_legal_action(row, col, action) else 0 for action in LEGAL_ACTIONS]

    # update the value estimated via dynamic programming
    def update_values(self):
        """Replace every value estimate with one Bellman-equation backup.

        For each walkable state s:
            V'(s) = sum_a policy[s][a] * (reward(s) + RL_GAMMA * V(s + a))
        summed over legal actions only. Non-walkable (blocked/terminal) states
        keep value 0. All backups read the old values (synchronous sweep).
        """
        new_values = [[0] * self.cols() for _ in range(self.rows())]
        for r in range(self.rows()):
            for c in range(self.cols()):
                # blocked and terminal states are never updated
                if self.get_state(r, c) != STATE_WALKABLE: continue
                state_policy = self.get_policy(r, c)
                # every action from this state uses the reward of the current state
                reward = self.get_reward(r, c)
                for a, is_legal in enumerate(self.__get_action_legality(r, c)):
                    # illegal actions contribute nothing to the expectation
                    if not is_legal: continue
                    dr, dc = LEGAL_ACTIONS[a]
                    new_values[r][c] += state_policy[a] * (reward + RL_GAMMA * self.get_value(r + dr, c + dc))
        self.__values = new_values

    # update the current policy via greedy selection
    def update_policy(self):
        """Make each walkable state's policy greedy w.r.t. current values.

        Probability is split evenly among the legal actions whose destination
        has the maximal value estimate; every other action gets 0.0. Policies
        are updated in place, so lists returned by get_policy stay current.
        """
        for r in range(self.rows()):
            for c in range(self.cols()):
                # blocked and terminal states have no policy to improve
                if self.get_state(r, c) != STATE_WALKABLE: continue
                legal = self.__get_action_legality(r, c)
                values = [0] * len(LEGAL_ACTIONS)
                # -inf sentinel: any real value estimate beats it, however negative
                max_value, num_max = float('-inf'), 0
                for a, is_legal in enumerate(legal):
                    if not is_legal: continue
                    dr, dc = LEGAL_ACTIONS[a]
                    values[a] = self.get_value(r + dr, c + dc)
                    if values[a] > max_value:
                        max_value, num_max = values[a], 1
                    elif values[a] == max_value:
                        num_max += 1
                # the division only runs for a legal maximising action, so num_max >= 1 there
                state_policy = self.__policy[r][c]
                for a, is_legal in enumerate(legal):
                    state_policy[a] = (1.0 / num_max) if (is_legal and values[a] == max_value) else 0.0
