Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 46 additions & 61 deletions dain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,88 +7,73 @@
class DAIN_Layer(nn.Module):
def __init__(self, mode='adaptive_avg', mean_lr=0.00001, gate_lr=0.001, scale_lr=0.00001, input_dim=144):
super(DAIN_Layer, self).__init__()
assert mode in [None, 'avg', 'adaptive_avg', 'adaptive_scale', 'full'], f'Unsupported mode: {mode}!'\
'Use one of: None, "avg", "adaptive_avg", "adaptive_scale", "full". '
print("Mode = ", mode)

self.mode = mode
self.mean_lr = mean_lr
self.gate_lr = gate_lr
self.scale_lr = scale_lr

# Parameters for adaptive average
# Parameters for adaptive average; aka Dain(1)
self.mean_layer = nn.Linear(input_dim, input_dim, bias=False)
self.mean_layer.weight.data = torch.FloatTensor(data=np.eye(input_dim, input_dim))

# Parameters for adaptive std
self.scaling_layer = nn.Linear(input_dim, input_dim, bias=False)
self.scaling_layer.weight.data = torch.FloatTensor(data=np.eye(input_dim, input_dim))
# Parameters for adaptive scaling; Dain(1+2)
if mode == 'adaptive_scale' or mode == 'full':
self.scaling_layer = nn.Linear(input_dim, input_dim, bias=False)
self.scaling_layer.weight.data = torch.FloatTensor(data=np.eye(input_dim, input_dim))

# Parameters for adaptive scaling
self.gating_layer = nn.Linear(input_dim, input_dim)
# Parameters for adaptive gating; Dain(1+2+3)
if mode == 'full':
self.gating_layer = nn.Linear(input_dim, input_dim)

self.eps = 1e-8

def forward(self, x):
# Expecting (n_samples, dim, n_feature_vectors)

## Other methods:
# Nothing to normalize
if self.mode == None:
pass

return x
# Do simple average normalization
elif self.mode == 'avg':
avg = torch.mean(x, 2)
avg = avg.resize(avg.size(0), avg.size(1), 1)
x = x - avg

# Perform only the first step (adaptive averaging)
elif self.mode == 'adaptive_avg':
avg = torch.mean(x, 2)
adaptive_avg = self.mean_layer(avg)
adaptive_avg = adaptive_avg.resize(adaptive_avg.size(0), adaptive_avg.size(1), 1)
x = x - adaptive_avg

# Perform the first + second step (adaptive averaging + adaptive scaling )
elif self.mode == 'adaptive_scale':

# Step 1:
avg = torch.mean(x, 2)
adaptive_avg = self.mean_layer(avg)
adaptive_avg = adaptive_avg.resize(adaptive_avg.size(0), adaptive_avg.size(1), 1)
x = x - adaptive_avg

# Step 2:
std = torch.mean(x ** 2, 2)
std = torch.sqrt(std + self.eps)
adaptive_std = self.scaling_layer(std)
adaptive_std[adaptive_std <= self.eps] = 1

adaptive_std = adaptive_std.resize(adaptive_std.size(0), adaptive_std.size(1), 1)
x = x / (adaptive_std)

elif self.mode == 'full':

# Step 1:
avg = torch.mean(x, 2)
adaptive_avg = self.mean_layer(avg)
adaptive_avg = adaptive_avg.resize(adaptive_avg.size(0), adaptive_avg.size(1), 1)
x = x - adaptive_avg

# # Step 2:
std = torch.mean(x ** 2, 2)
std = torch.sqrt(std + self.eps)
adaptive_std = self.scaling_layer(std)
adaptive_std[adaptive_std <= self.eps] = 1

adaptive_std = adaptive_std.resize(adaptive_std.size(0), adaptive_std.size(1), 1)
x = x / adaptive_std

# Step 3:
avg = torch.mean(x, 2)
gate = F.sigmoid(self.gating_layer(avg))
gate = gate.resize(gate.size(0), gate.size(1), 1)
x = x * gate

else:
assert False

return x
return x

## DAIN:
# Perform the first step: adaptive averaging; DAIN(1)
# Step 1:
avg = torch.mean(x, 2)
adaptive_avg = self.mean_layer(avg)
adaptive_avg = adaptive_avg.resize(adaptive_avg.size(0), adaptive_avg.size(1), 1)
x = x - adaptive_avg
if self.mode == 'adaptive_avg':
return x

# Perform the second step: adaptive averaging + adaptive scaling; DAIN(1+2)
# Step 2:
std = torch.mean(x ** 2, 2)
std = torch.sqrt(std + self.eps)
adaptive_std = self.scaling_layer(std)
adaptive_std[adaptive_std <= self.eps] = 1

adaptive_std = adaptive_std.resize(adaptive_std.size(0), adaptive_std.size(1), 1)
x = x / adaptive_std
if self.mode == 'adaptive_scale':
return x

# Perform the third step: adaptuve avg + adative scale + gating; DAIN(1+2+3)
# Step 3:
avg = torch.mean(x, 2)
gate = F.sigmoid(self.gating_layer(avg))
gate = gate.resize(gate.size(0), gate.size(1), 1)
x = x * gate
if self.mode == 'full':
return x

assert False, "You fool! Should not reach here."
15 changes: 9 additions & 6 deletions train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
def lob_epoch_trainer(model, loader, lr=0.0001, optimizer=optim.RMSprop):
model.train()

model_optimizer = optimizer([
{'params': model.base.parameters()},
{'params': model.dean.mean_layer.parameters(), 'lr': lr * model.dean.mean_lr},
{'params': model.dean.scaling_layer.parameters(), 'lr': lr * model.dean.scale_lr},
{'params': model.dean.gating_layer.parameters(), 'lr': lr * model.dean.gate_lr},
], lr=lr)
dean_params = [{'params': model.base.parameters()}]
if model.dean.mode in ['adaptive_avg', 'adaptive_scale', 'full']:
dean_params.append({'params': model.dean.mean_layer.parameters(), 'lr': lr * model.dean.mean_lr})
if model.dean.mode in ['adaptive_scale', 'full']:
dean_params.append({'params': model.dean.scaling_layer.parameters(), 'lr': lr * model.dean.scale_lr})
if model.dean.mode == 'full':
dean_params.append({'params': model.dean.gating_layer.parameters(), 'lr': lr * model.dean.gate_lr})

model_optimizer = optimizer(dean_params, lr=lr)

criterion = CrossEntropyLoss()
train_loss, counter = 0, 0
Expand Down