-
Notifications
You must be signed in to change notification settings - Fork 59
Expand file tree
/
Copy pathutils.py
More file actions
152 lines (125 loc) · 4.99 KB
/
Copy pathutils.py
File metadata and controls
152 lines (125 loc) · 4.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import numpy as np
import torch
import matplotlib.pyplot as plt
def one_hot(y, num_class):
    """Encode integer class labels as one-hot float vectors.

    Args:
        y: 1-D integer tensor of class indices, each in ``[0, num_class)``.
        num_class: Number of classes (width of the encoding).

    Returns:
        A float32 tensor of shape ``(len(y), num_class)`` on the same device
        as ``y``, with 1 in each row's label column and 0 elsewhere.
    """
    # Allocate on y's device so CUDA labels don't hit a CPU/GPU mismatch in
    # scatter_, and cast to int64 because scatter_ requires a LongTensor index.
    out = torch.zeros((len(y), num_class), device=y.device)
    return out.scatter_(1, y.long().unsqueeze(1), 1)
def DBindex(cl_data_file):
    """Compute the Davies-Bouldin index of per-class feature clusters.

    For each class i the score is the max over j != i of
    ``(s_i + s_j) / ||mu_i - mu_j||``, where ``mu`` is the class mean and
    ``s`` is the root-mean-square Euclidean distance of that class's samples
    to its mean. The index is the mean of these per-class scores.

    Args:
        cl_data_file: Mapping from class label to an array-like of that
            class's feature vectors, assumed 2-D ``(num_samples, feat_dim)``.

    Returns:
        The Davies-Bouldin index as a numpy float.

    Raises:
        ValueError: If fewer than two classes are given; the index needs at
            least one other class to compare against.
    """
    class_list = list(cl_data_file.keys())
    cl_num = len(class_list)
    if cl_num < 2:
        raise ValueError(f"DBindex needs at least 2 classes, got {cl_num}")

    cl_means = []
    stds = []
    for cl in class_list:
        feats = cl_data_file[cl]
        mean = np.mean(feats, axis=0)
        cl_means.append(mean)
        stds.append(np.sqrt(np.mean(np.sum(np.square(feats - mean), axis=1))))

    # Pairwise distances between class means via broadcasting:
    # (C, 1, D) - (1, C, D) -> (C, C, D) -> (C, C).
    means = np.array(cl_means)
    diffs = means[:, None, :] - means[None, :, :]
    mdists = np.sqrt(np.sum(np.square(diffs), axis=2))

    DBs = [
        np.max([(stds[i] + stds[j]) / mdists[i, j] for j in range(cl_num) if j != i])
        for i in range(cl_num)
    ]
    return np.mean(DBs)
def sparsity(cl_data_file):
    """Return the mean number of non-zero features per sample, averaged over classes.

    Each class is reduced to the average (over its samples) of how many
    entries are non-zero. The result is the unweighted mean of those
    per-class averages, so every class counts equally regardless of size.

    Args:
        cl_data_file: Mapping from class label to an iterable of per-sample
            feature arrays.
    """
    def _nonzero_count(sample):
        return np.sum(sample != 0)

    per_class_avg = [
        np.mean([_nonzero_count(sample) for sample in samples])
        for samples in cl_data_file.values()
    ]
    return np.mean(per_class_avg)
"""
Helpers for plotting adaptation-difficulty figures.
"""
def _load_attribute_matrix(filename):
    """Parse a whitespace-separated text file of numbers into a 2-D float64 tensor."""
    rows = []
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            fields = line.split()
            if fields:  # tolerate blank lines, e.g. a trailing empty line
                rows.append([float(v) for v in fields])
    # np.array of Python floats is float64, so the resulting tensor is too.
    return torch.from_numpy(np.array(rows))


def read_attr_dists(trainloader, dataset):
    """Load the class-attribute matrix and base-class indices for *dataset*.

    Args:
        trainloader: Training data loader. Used for 'SUN' (attributes come
            from ``trainloader.dataset.meta['attr_labels']``) and for 'SUN'
            and 'AWA2' (base classes come from ``trainloader.dataset.cl_list``).
            Ignored for 'CUB'.
        dataset: One of 'SUN', 'CUB' or 'AWA2'.

    Returns:
        ``(attr_dists, base_ind)``. ``attr_dists`` is a tensor whose rows are
        indexed by class id. It is float32 for SUN and float64 for CUB/AWA2,
        as read. ``base_ind`` is a sorted list of base-class ids.

    Raises:
        NotImplementedError: If *dataset* is not one of the names above.
    """
    if dataset == 'SUN':
        print("attribute distance for SUN!")
        attr_dists_array = np.array(trainloader.dataset.meta['attr_labels']).astype('float32')
        attr_dists = torch.from_numpy(attr_dists_array)
        base_ind = np.unique(trainloader.dataset.cl_list).tolist()
    elif dataset == 'CUB':
        print("attribute distance for CUB!")
        attr_dists = _load_attribute_matrix(
            'filelists/CUB/CUB_200_2011/masked_class_attribute_labels.txt')
        # The even-numbered classes among CUB's 200 are used as base classes.
        base_ind = list(range(0, 200, 2))
    elif dataset == 'AWA2':
        print("attribute distance for AWA2!")
        attr_dists = _load_attribute_matrix('filelists/AWA2/class_attribute_label.txt')
        base_ind = np.unique(trainloader.dataset.cl_list).tolist()
    else:
        # Previously the exception was constructed but never raised, which
        # surfaced as an UnboundLocalError at the return statement.
        raise NotImplementedError(
            f"read_attr_dists: dataset {dataset!r} is not implemented "
            "(expected 'SUN', 'CUB' or 'AWA2')")
    return attr_dists, base_ind
def get_attr_distance(trainloader, dataset, num_tasks=10000, n_way=5):
    """Sample random base-class tasks and gather their class-attribute vectors.

    Each task is ``n_way`` distinct base classes drawn with ``random.sample``.
    This uses the global ``random`` state, so seed ``random`` beforehand for
    reproducible tasks.

    Args:
        trainloader: Forwarded to ``read_attr_dists``.
        dataset: Dataset name, forwarded to ``read_attr_dists``.
        num_tasks: Number of tasks to sample (default 10000).
        n_way: Number of classes per task (default 5).

    Returns:
        all_cls_dists: The full attribute matrix returned by ``read_attr_dists``.
        base_dists: ``(num_tasks, attr_dim)`` mean attribute vector of each task.
        base_cls_dists: ``(num_tasks, n_way, attr_dim)`` attribute vectors of
            every sampled class in every task.
    """
    import random

    attr_dists, base_ind = read_attr_dists(trainloader, dataset)
    sc_cls_lists = [random.sample(base_ind, n_way) for _ in range(num_tasks)]
    # A single advanced-indexing gather builds the same (num_tasks, n_way, attr_dim)
    # tensor as slicing each task and stacking. The reshape keeps that shape
    # even when num_tasks == 0.
    task_index = torch.tensor(sc_cls_lists, dtype=torch.long).reshape(num_tasks, n_way)
    base_cls_dists = attr_dists[task_index]
    base_dists = base_cls_dists.mean(1)
    return attr_dists, base_dists, base_cls_dists
def interval_avg(acc_all, dist_all, num_bins=9):
    """Average accuracy and distance within equal-width distance bins.

    ``[min(dist_all), max(dist_all)]`` is split into ``num_bins`` equal-width
    bins, and each (distance, accuracy) pair is placed in exactly one bin:
    the first bin whose upper edge is >= the distance. A value on a shared
    edge therefore goes to the lower bin. A value just past the last computed
    edge, caused by floating-point round-off at the maximum, stays in the
    last bin. When all distances are equal, everything lands in one bin.

    Args:
        acc_all: Accuracies, paired element-wise with ``dist_all``.
        dist_all: Distances used for binning.
        num_bins: Number of bins (default 9).

    Returns:
        ``(acc_avg_inr, dis_avg_inr)``: per-bin mean accuracy and mean distance
        in increasing-distance order, for non-empty bins only. Both lists are
        empty when ``dist_all`` is empty.
    """
    if np.size(dist_all) == 0:
        return [], []
    min_d = np.min(dist_all)
    max_d = np.max(dist_all)
    inr = (max_d - min_d) / num_bins
    # Bin i spans up to min_d + (i + 1) * inr (same edge formula as before).
    upper_edges = [min_d + (i + 1) * inr for i in range(num_bins)]

    acc_inr = [0] * num_bins
    dis_inr = [0] * num_bins
    cout_inr = [0] * num_bins
    for dis, acc in zip(dist_all, acc_all):
        if not (min_d <= dis <= max_d):
            # Only NaNs can fail this check; they were never binned before either.
            continue
        i = next((k for k, hi in enumerate(upper_edges) if dis <= hi), num_bins - 1)
        acc_inr[i] += acc
        dis_inr[i] += dis
        cout_inr[i] += 1

    acc_avg_inr, dis_avg_inr = [], []
    for acc, dis, num in zip(acc_inr, dis_inr, cout_inr):
        if num != 0:
            acc_avg_inr.append(1.0 * acc / num)
            dis_avg_inr.append(1.0 * dis / num)
    return acc_avg_inr, dis_avg_inr
def plot_fig(acc_all, dist_all, save_path='dist_acc.pdf'):
    """Plot accuracy against distance with binned averages overlaid, and save it.

    Every (distance, accuracy) pair is drawn as a point. The per-bin averages
    from ``interval_avg`` are drawn as red crosses joined by a red line and
    labelled with their accuracy to one decimal place. The binned averages
    are also printed.

    Args:
        acc_all: Accuracies, paired element-wise with ``dist_all``.
        dist_all: Distances, plotted on the x-axis.
        save_path: Output file for the figure (default ``'dist_acc.pdf'``).
    """
    acc_avg_inr, dis_avg_inr = interval_avg(acc_all, dist_all)
    print("acc_avg_inr:", acc_avg_inr)
    print("dis_avg_inr:", dis_avg_inr)

    # Draw on a dedicated figure so content already on pyplot's "current"
    # figure cannot leak into the saved plot. Close exactly that figure,
    # even if saving fails.
    fig, ax = plt.subplots()
    try:
        ax.scatter(dist_all, acc_all)
        ax.scatter(dis_avg_inr, acc_avg_inr, s=40, marker='x', c='red')
        for x, y in zip(dis_avg_inr, acc_avg_inr):
            # Label offset is in data units: slightly left and 1.5 above the point.
            ax.annotate(f"{y:.1f}", xy=(x, y), xytext=(x - 0.005, y + 1.5),
                        color='r', weight='heavy')
        ax.plot(dis_avg_inr, acc_avg_inr, c='red')
        fig.savefig(save_path)
    finally:
        plt.close(fig)