
topKgate loss issues #178

@powermano

Description


The loss of the gate is calculated here, but does it have any effect on training? Where is this loss used?
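For context: in GShard/Switch-style mixture-of-experts training, a gate's auxiliary (load-balancing) loss usually affects training only if the caller adds it to the task loss, typically scaled by a small coefficient, before backpropagation. A value that is only stored on the module, as the gating code quoted here does with `self.loss`, contributes no gradient by itself. The snippet does not show whether this repository adds it anywhere. The plain-Python sketch below illustrates that combination pattern. `FakeGate`, `collect_gate_losses`, `training_objective`, and `aux_loss_weight` are hypothetical names, and the 0.01 default is an illustrative assumption.

```python
from dataclasses import dataclass
from typing import Iterable, List


@dataclass
class FakeGate:
    """Hypothetical stand-in for a gate module that stores its aux loss in `.loss`."""
    loss: float


def collect_gate_losses(modules: Iterable[object]) -> List[float]:
    # Gather `.loss` from every module that has one, mirroring a walk
    # over all submodules of a model in a real framework.
    return [m.loss for m in modules if getattr(m, "loss", None) is not None]


def training_objective(task_loss: float, modules: Iterable[object],
                       aux_loss_weight: float = 0.01) -> float:
    # In an autograd framework, only terms summed into this value would
    # receive gradients. A loss left on the module is simply ignored.
    return task_loss + aux_loss_weight * sum(collect_gate_losses(modules))


# Two gates with stored losses plus an unrelated module without `.loss`.
modules = [FakeGate(loss=1.2), FakeGate(loss=0.8), object()]
print(training_objective(2.0, modules))
```

In a PyTorch model, the same idea would usually mean iterating over the model's submodules, summing the gate losses, and adding them to the task loss before calling `backward()`.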

        logits = self.wg(input)  # dim: [b x s, num_experts]
        if self.k == 1:
            self.loss, self.gate_log, gates1_s, dispatch_mask, retval = top1gating(
                logits,
                self.capacity_factor if self.training else self.eval_capacity_factor,
                is_expert_slicing=self.is_expert_slicing,
                fp16_mode=self.fp16_mode,
                nonpadding=nonpadding,
                logits_gumbel=self.logits_gumbel if self.training else 0,
                token_drop_type=self.token_drop_type,
                straight_through=self.straight_through,
                straight_through_temperature=self.straight_through_temperature,
                balance_ratio=self.balance_ratio,
                gate_log_req=self.gate_log_req,
                lid=lid,
                tutel_cumsum_sub_one=self.tutel_cumsum_sub_one,
            )
            return gates1_s, dispatch_mask, retval
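The snippet does not show how `top1gating` computes the loss it returns. A common choice for top-1 gates is the load-balancing loss from GShard and the Switch Transformer. With E experts, let me_e be the mean softmax probability assigned to expert e and ce_e the fraction of tokens whose top-1 choice is e. The loss is then E · Σ_e me_e · ce_e. It equals 1 under perfectly uniform routing and approaches E as routing collapses onto a single expert. The plain-Python sketch below assumes that formulation and ignores capacity limits, token dropping, Gumbel noise, and the other options passed in the call above. The repository's version may differ, and `softmax` and `top1_load_balancing_loss` are hypothetical helpers.

```python
import math
from typing import List, Sequence


def softmax(row: Sequence[float]) -> List[float]:
    # Subtract the max for numerical stability before exponentiating.
    top = max(row)
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def top1_load_balancing_loss(logits: Sequence[Sequence[float]]) -> float:
    """GShard/Switch-style auxiliary loss for top-1 routing (sketch).

    logits: one row per token, one column per expert ([tokens, num_experts]).
    """
    num_tokens = len(logits)
    num_experts = len(logits[0])
    gates = [softmax(row) for row in logits]

    # me[e]: mean gate probability assigned to expert e over all tokens.
    me = [sum(g[e] for g in gates) / num_tokens for e in range(num_experts)]

    # ce[e]: fraction of tokens whose top-1 (argmax) expert is e.
    counts = [0] * num_experts
    for g in gates:
        counts[max(range(num_experts), key=g.__getitem__)] += 1
    ce = [c / num_tokens for c in counts]

    return num_experts * sum(p * f for p, f in zip(me, ce))


balanced = [[5.0, 0.0], [0.0, 5.0]]   # each expert receives one token
collapsed = [[5.0, 0.0], [5.0, 0.0]]  # both tokens go to expert 0
print(top1_load_balancing_loss(balanced), top1_load_balancing_loss(collapsed))
```

Computing this value is not enough on its own. It only pushes the gate toward balanced routing if it is added to the objective that is backpropagated.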
