diff --git a/python/src/coreai_models/models/macos/mixtral.py b/python/src/coreai_models/models/macos/mixtral.py index 2ff0a77..aa7f1ca 100644 --- a/python/src/coreai_models/models/macos/mixtral.py +++ b/python/src/coreai_models/models/macos/mixtral.py @@ -30,15 +30,12 @@ def __init__(self, dim: int, hidden_dim: int, num_experts: int, top_k: int) -> N self.switch_mlp = SwitchGLU(dim, hidden_dim, num_experts) def forward(self, x: torch.Tensor) -> torch.Tensor: - gates = self.gate(x) - gates = torch.softmax(gates, dim=-1, dtype=torch.float32) + router_logits = self.gate(x).to(torch.float32) - active_experts_scores, active_experts_indices = torch.topk( - gates, self.top_k, dim=-1, largest=True + top_logits, active_experts_indices = torch.topk( + router_logits, self.top_k, dim=-1, largest=True ) - - active_experts_scores /= active_experts_scores.sum(dim=-1, keepdim=True) - active_experts_scores = active_experts_scores.to(x.dtype) + active_experts_scores = torch.softmax(top_logits, dim=-1).to(x.dtype) y_active_experts = self.switch_mlp(x, active_experts_indices) active_experts_scores = active_experts_scores.unsqueeze(-1) diff --git a/python/tests/test_model_units/test_models/test_macos_layer_counts/test_mixtral.py b/python/tests/test_model_units/test_models/test_macos_layer_counts/test_mixtral.py index 89e2d78..fccb5f5 100644 --- a/python/tests/test_model_units/test_models/test_macos_layer_counts/test_mixtral.py +++ b/python/tests/test_model_units/test_models/test_macos_layer_counts/test_mixtral.py @@ -126,7 +126,7 @@ "cos": 1, "decomposable.broadcasting_add": 5, "decomposable.broadcasting_batch_matmul": 7, - "decomposable.broadcasting_divide": 1, + "decomposable.broadcasting_divide": 0, "decomposable.broadcasting_mul": 11, "decomposable.broadcasting_sub": 1, "gather_along_axis": 4, @@ -135,7 +135,7 @@ "name": 21, "output": 6, "reduce_mean": 1, - "reduce_sum": 2, + "reduce_sum": 1, "reshape": 16, "rsqrt": 1, "silu": 1, @@ -168,7 +168,7 @@ "create_token": 1, "decomposable.broadcasting_add": 5, "decomposable.broadcasting_batch_matmul": 8, - "decomposable.broadcasting_divide": 1, + "decomposable.broadcasting_divide": 0, "decomposable.broadcasting_mul": 11, "decomposable.broadcasting_sub": 1, "gather_along_axis": 4, @@ -180,7 +180,7 @@ "output": 6, "read_handle": 4, "reduce_mean": 1, - "reduce_sum": 2, + "reduce_sum": 1, "reshape": 21, "rsqrt": 1, "silu": 1,