diff --git a/maia3/models.py b/maia3/models.py index 3b23ed5..23e62c8 100644 --- a/maia3/models.py +++ b/maia3/models.py @@ -333,6 +333,8 @@ def __init__(self, cfg): self.promo_bias_proj = nn.Linear(cfg.head_hid_dim, 4, bias=False) # 4 promotion types: q, r, b, n + self.rank7_indices = [chess.square(file, 6) for file in range(8)] # squares 48-55 + self.rank8_indices = [chess.square(file, 7) for file in range(8)] # squares 56-63 def interpolate_elo(self, elos): @@ -375,26 +377,16 @@ def forward(self, tokens, self_elos, oppo_elos): scores_base = torch.einsum("bid,bjd->bij", sq_from, sq_to) / math.sqrt(self.cfg.head_hid_dim) scores_flat = scores_base.reshape(x.size(0), 64 * 64) # (B, 4096) - rank7_indices = [chess.square(file, 6) for file in range(8)] # squares 48-55 - rank8_indices = [chess.square(file, 7) for file in range(8)] # squares 56-63 - - rank8_features = sq_to[:, rank8_indices, :] # (B, 8, head_hid_dim) + rank8_features = sq_to[:, self.rank8_indices, :] # (B, 8, head_hid_dim) promo_biases = self.promo_bias_proj(rank8_features) * math.sqrt(self.cfg.head_hid_dim) # (B, 8, 4) for q,r,b,n - promotion_logits = [] - for from_file in range(8): # source file (a-h) - from_sq = rank7_indices[from_file] - for to_file in range(8): # target file (a-h) - to_sq = rank8_indices[to_file] - base_score = scores_base[:, from_sq, to_sq] # (B,) - for piece_idx in range(4): # q=0, r=1, b=2, n=3 - bias = promo_biases[:, to_file, piece_idx] # (B,) - promotion_logits.append((base_score + bias).unsqueeze(1)) - promotion_logits = torch.cat(promotion_logits, dim=1) # (B, 256) + base = scores_base[:, self.rank7_indices][:, :, self.rank8_indices] # (B, 8, 8) + promotion_logits = (base.unsqueeze(-1) + promo_biases.unsqueeze(1)).reshape(x.size(0), 256) # (B, 256) + logits_move = torch.cat([scores_flat, promotion_logits], dim=1) # (B, 4352) x = self.last_ln(x.mean(dim=1)) logits_value = self.fc_value(F.relu(self.fc_value_hid(x))) # (B, 3) logits_ponder = self.fc_ponder(F.relu(self.fc_ponder_hid(x))) # (B, 1) - return logits_move, logits_value, logits_ponder.squeeze(1) # (B, 4352), (B, 3), (B,) \ No newline at end of file + return logits_move, logits_value, logits_ponder.squeeze(1) # (B, 4352), (B, 3), (B,)