-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathValueIteration.py
More file actions
67 lines (58 loc) · 2.39 KB
/
Copy pathValueIteration.py
File metadata and controls
67 lines (58 loc) · 2.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import numpy as np
import matplotlib.pyplot as plt
class ValueIteration:
    """Tabular value iteration for a finite Markov decision process.

    State values are stored in ``self.values`` and refined by repeated
    Bellman optimality backups; ``self.policy`` holds the greedy policy
    once :meth:`train` has run (it is ``None`` before that).
    """

    def __init__(self, reward_function, transition_model, gamma):
        """Store the model and initialise every state value to zero.

        Args:
            reward_function: Per-state rewards, indexed by state. It is passed
                through ``np.nan_to_num``, so NaN entries become 0 (and
                infinities become large finite numbers).
            transition_model: Array indexed as ``[state, action]`` whose
                entries are multiplied element-wise with the value vector;
                expected shape ``(num_states, num_actions, num_states)``.
            gamma: Discount factor applied to the expected next-state value.
        """
        self.num_states = transition_model.shape[0]
        self.num_actions = transition_model.shape[1]
        self.reward_function = np.nan_to_num(reward_function)
        self.transition_model = transition_model
        self.gamma = gamma
        self.values = np.zeros(self.num_states)
        self.policy = None

    def _action_values(self, s):
        """Return the one-step look-ahead value of every action in state ``s``."""
        # Row ``a`` of transition_model[s] weights the current value estimates
        # by the transition weights of action ``a``; summing over the last
        # axis gives the expected next-state value per action.
        expected_next = np.sum(self.transition_model[s] * self.values, axis=-1)
        return self.reward_function[s] + self.gamma * expected_next

    def one_iteration(self):
        """Run one sweep of Bellman optimality backups over all states.

        Values are updated in place, so states later in the sweep already
        see the new values of earlier states.

        Returns:
            The largest absolute change of any state value in this sweep.
        """
        delta = 0.0
        for s in range(self.num_states):
            previous = self.values[s]
            self.values[s] = np.max(self._action_values(s))
            delta = max(delta, abs(previous - self.values[s]))
        return delta

    def get_policy(self):
        """Return the greedy policy with respect to the current values.

        Ties between equally good actions are broken uniformly at random
        with ``np.random.choice``.

        Returns:
            Integer array of length ``num_states`` holding one action index
            per state.
        """
        policy = np.empty(self.num_states, dtype=int)
        for s in range(self.num_states):
            q_values = self._action_values(s)
            best_actions = np.flatnonzero(q_values == np.max(q_values))
            policy[s] = np.random.choice(best_actions)
        return policy

    def train(self, tol=1e-3, plot=True, max_iter=None):
        """Sweep until the value change drops to ``tol``, then set the policy.

        Args:
            tol: Sweeping stops once a sweep's largest value change is no
                longer greater than this threshold.
            plot: If truthy, plot the per-sweep value change with matplotlib.
            max_iter: Optional cap on the total number of sweeps. ``None``
                (the default) means no cap, in which case the loop may not
                terminate if the sweeps never converge (e.g. ``gamma >= 1``).
                At least one sweep is always performed.
        """
        delta = self.one_iteration()
        delta_history = [delta]
        while delta > tol and (max_iter is None or len(delta_history) < max_iter):
            delta = self.one_iteration()
            delta_history.append(delta)
        self.policy = self.get_policy()

        if plot:
            fig, ax = plt.subplots(1, 1, figsize=(3, 2), dpi=200)
            ax.plot(np.arange(len(delta_history)) + 1, delta_history, marker='o', markersize=4,
                    alpha=0.7, color='#2ca02c', label=r'$\gamma= $' + f'{self.gamma}')
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Delta')
            ax.legend()
            plt.tight_layout()
            plt.show()