# plot_fig.py
# Forked from Bartopt/code4MRPL.
import numpy as np
import matplotlib.pyplot as plt
import pickle
import os
# Base font size for the figure: used as matplotlib's "font.size" rcParam,
# and the axis labels are drawn at fontSize + 10.
fontSize = 20
def load_ret(expdir, ret_file):
    """Load a pickled result object from ``<expdir>/<ret_file>.pkl``.

    Args:
        expdir: Directory that contains the pickle file.
        ret_file: File name without the ``.pkl`` extension.

    Returns:
        Whatever object was stored in the pickle file.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        pickle.UnpicklingError: If the file is not a valid pickle.

    Note:
        Unpickling can execute arbitrary code, so only point this at result
        files produced by your own (trusted) experiment runs.
    """
    path = os.path.join(expdir, '{}.pkl'.format(ret_file))
    with open(path, 'rb') as f:
        return pickle.load(f)
def PlotStd(x, data, color, label, xOffset=0, alpha=0.1, *, show_std=False):
    """Plot the mean of ``data`` over axis 0, optionally with a +/-1 std band.

    Draws on the current pyplot axes; nothing is returned.

    Args:
        x: X coordinates for the mean curve (array-like).
        data: Array-like that is averaged over axis 0.
            NOTE(review): presumably shaped (n_runs, n_points) -- confirm
            against the pickles being loaded.
        color: Matplotlib colour used for the curve (and the band).
        label: Legend label attached to the mean curve.
        xOffset: Constant added to every x coordinate before plotting.
        alpha: Opacity of the std band. Only used when ``show_std`` is True.
        show_std: When True, also shade the region between ``mean - std`` and
            ``mean + std``. Defaults to False, which draws the mean curve only.
    """
    # todo: using seaborn instead
    # np.asarray lets callers pass a plain list for x; list + int would raise.
    xs = np.asarray(x) + xOffset
    m = np.mean(data, axis=0)
    plt.plot(xs, m, color=color, linewidth=4, label=label)
    if show_std:
        std = np.std(data, axis=0)
        plt.fill_between(xs, m + std, m - std, color=color, alpha=alpha)
# ---------------------------------------------------------------------------
# Script section: load two evaluation result files and plot their mean curves.
# ---------------------------------------------------------------------------

# Edit this path: directory the result pickles are loaded from.
expdir1 = 'output/pih-meta/MRPL-2/eval_trajectories/'

# Results with (demo1) and without (demo0) demonstrations; squeeze() drops
# any length-1 axes from the loaded arrays.
y1 = load_ret(expdir1, 'ret_demo1_acctextTrue_meanFalse').squeeze()
y2 = load_ret(expdir1, 'ret_demo0_acctextTrue_meanFalse').squeeze()

# Global figure style; applied before the first pyplot call creates the figure.
plt.rcParams.update({"figure.figsize": (15, 10), "font.size": fontSize})

# One x position per column of y1; the same x is reused for both curves.
x = np.arange(y1.shape[1])

# xOffset=1 shifts the x positions up by one, so they start at 1 instead of 0.
PlotStd(x, y1, color='limegreen', label='MRPL', xOffset=1, alpha=0.1)
PlotStd(x, y2, color='red', label='MRPL-NoDemo', xOffset=1, alpha=0.1)

plt.xlabel("Test-time rollout numbers", fontsize=fontSize + 10)
plt.ylabel("Return", fontsize=fontSize + 10)
plt.legend()
plt.show()

# Optional: write the figure to disk.
# f = plt.gcf()
# f.savefig('tmp.png')