# Immediate reward R(s) for each of the six states, indexed by state number.
R = [-0.1, -0.1, 1, -0.1, -0.1, -0.05]

# States where the process ends; their utility is pinned to R(s) each sweep.
termStates = [2]

# Discount factor applied to future utility.
# NOTE(review): 0.01 discounts the future very heavily — confirm intended.
gamma = 0.01

# Action identifiers; the first four are also the movement directions.
NORTH, EAST, SOUTH, WEST, DONOTHING = 0, 1, 2, 3, 4
actions = [NORTH, EAST, SOUTH, WEST, DONOTHING]
def actStr(x):
    """Return the display name of action index ``x``; ``-1`` (no action) gives "-"."""
    names = ["North", "East", "South", "West", "Nothing"]
    return "-" if x == -1 else names[x]

# The six states, numbered 0-5.
states = range(6)

# Movement directions (every action except DONOTHING).
directions = [NORTH, EAST, SOUTH, WEST]

# Direction obtained by turning 90 degrees left / right of a given direction.
leftTurn = {NORTH: WEST, EAST: NORTH, SOUTH: EAST, WEST: SOUTH}
rightTurn = {NORTH: EAST, EAST: SOUTH, SOUTH: WEST, WEST: NORTH}

# adjacency[s][d] is the state reached by moving from state s in direction d.
adjacency = [{NORTH: 3, EAST: 1}, {WEST: 0, NORTH: 4, EAST: 2},
             {WEST: 1, NORTH: 5}, {SOUTH: 0, EAST: 4},
             {WEST: 3, SOUTH: 1, EAST: 5}, {WEST: 4, SOUTH: 2}]

# Moving in a direction with no neighbour leaves the agent in place, so add a
# self-loop for every missing direction. The loop variable is deliberately not
# named `dir`: at module scope that would shadow the builtin for the whole file.
for state, neighbours in enumerate(adjacency):
    for direction in directions:
        neighbours.setdefault(direction, state)

# Transition model T(s, a, s').
def T(initialState, action, nextState):
    """Return the probability of reaching ``nextState`` by taking ``action`` in ``initialState``.

    DONOTHING stays in place with certainty. A movement action goes in the
    intended direction with probability 0.9 and veers to the left or right
    (per ``leftTurn`` / ``rightTurn``) with probability 0.05 each. When several
    of those outcomes land in the same state, their probabilities add up.
    """
    if action == DONOTHING:
        return 1 if nextState == initialState else 0

    neighbours = adjacency[initialState]
    outcomes = ((action, 0.9),
                (leftTurn[action], 0.05),
                (rightTurn[action], 0.05))
    # Accumulate in the same order as the intended/left/right outcomes so the
    # floating-point result matches a plain running total.
    prob = 0
    for direction, weight in outcomes:
        if neighbours[direction] == nextState:
            prob += weight
    return prob

def expectedUtility(t, initialState, action):
    """Return sum over s' of T(initialState, action, s') * U[s'][t].

    The value is also appended to ``EU[initialState][action]`` so the table
    printer can show it for iteration ``t``.
    """
    total = sum(T(initialState, action, succ) * U[succ][t]
                for succ in states)
    EU[initialState][action].append(total)
    return total

def maxExpectedUtility(t, initialState):
    """Return the best expected utility available from ``initialState`` at iteration ``t``.

    Terminal states return ``R[initialState]``. In that case -1 is recorded as
    the best action in ``BA``, and "-" is recorded as every action's entry in
    ``EU``. For other states, every action is evaluated with
    ``expectedUtility``. The maximum is appended to ``MEU`` and the
    maximizing action to ``BA``.
    """
    if initialState in termStates:
        MEU[initialState].append(R[initialState])
        BA[initialState].append(-1)
        for a in actions:
            EU[initialState][a].append("-")
        return R[initialState]

    # Build a real list of (utility, action) pairs. The original computed
    # max() three times over a zip(), which is a one-shot iterator on
    # Python 3, so the second max() would fail on an exhausted iterator.
    # Comparing tuples breaks utility ties in favour of the higher action index.
    utils = [(expectedUtility(t, initialState, action), action)
             for action in actions]
    bestUtility, bestAction = max(utils)
    MEU[initialState].append(bestUtility)
    BA[initialState].append(bestAction)
    return bestUtility

    
        
# Per-state histories indexed by iteration. U[s] starts as [0] (utility 0 at
# t = 0); EU[s][a], BA[s] and MEU[s] start empty and grow by one entry per sweep.
U = [[0] for _s in states]
EU = [[[] for _a in actions] for _s in states]
BA = [[] for _s in states]
MEU = [[] for _s in states]

def iterStep(t):
    """Run one value-iteration sweep, appending U[s][t+1] for every state.

    Expected utilities are computed from U[.][t], so every state is updated
    from the previous iteration's values. Terminal states keep U = R(s). Other
    states get R(s) + gamma * (best expected utility).
    """
    for s in states:
        if s in termStates:
            # Still called so the MEU/BA/EU tables get their terminal rows.
            maxExpectedUtility(t, s)
            updated = R[s]
        else:
            updated = R[s] + gamma * maxExpectedUtility(t, s)
        U[s].append(updated)

def fs(x):
    """Format a table cell: strings pass through unchanged, numbers use "%f"."""
    return x if isinstance(x, str) else "%f" % x

def printStep(t):
    """Print one LaTeX table row per state for iteration ``t``, then a hline.

    Each row contains, in order:
      * the iteration number;
      * the 1-based state number;
      * the expected utility of each action ("-" for terminal states);
      * the name of the best action;
      * the maximum expected utility;
      * the updated utility ``U[s][t+1]``.
    """
    for s in states:
        row = (t,
               s + 1,
               fs(EU[s][NORTH][t]),
               fs(EU[s][EAST][t]),
               fs(EU[s][SOUTH][t]),
               fs(EU[s][WEST][t]),
               fs(EU[s][DONOTHING][t]),
               actStr(BA[s][t]),
               MEU[s][t],
               U[s][t + 1])
        print("%d & %d & %s & %s & %s & %s & %s & %s & %f & %f \\\\" % row)
    # The backslash is escaped explicitly: "\h" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The printed text is the same "\hline".
    print("\\hline")
        
# Driver: iterate until the largest per-state change in utility drops below
# EPSILON * (1 - gamma) / gamma, which is the standard value-iteration
# termination test. Stop after at most maxT sweeps.
EPSILON = 0.0000001
DELTAMAX = EPSILON * (1.0 - gamma) / (gamma)
print(DELTAMAX)
delta = 1e6  # large sentinel so the first sweep always runs
maxT = 1000
for t in range(maxT):
    if delta < DELTAMAX:
        break
    iterStep(t)
    printStep(t)
    delta = max(abs(U[s][t + 1] - U[s][t]) for s in states)

