Aligning logits with labels through two shifts? #6

@Xlun

During data preparation in main.py:

def collate_fn(examples, device):
    token_ids = torch.tensor(
        [example['token_ids'] for example in examples], device=device)
    return {'input_ids': token_ids[:, :-1], 'labels': token_ids[:, 1:]}

def train_chunk(.......):
    ..........
    batch = collate_fn(
        examples=examples[i:i+per_device_batch_size], device=fabric.device)
    input_ids, labels = batch['input_ids'], batch['labels']

During loss computation in modeling_llama.py:

class LlamaForCausalLM(LlamaPreTrainedModel):
....................
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

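Why is the shift that aligns predictions with ground-truth targets applied once when the data samples are prepared as model input, and then applied a second time when the loss is computed inside the model?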
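
To make the concern concrete, here is a minimal sketch of what the two slicing operations quoted above do to the indices when they are applied one after the other. It uses plain Python lists instead of torch tensors so it runs without dependencies. The toy `token_ids` values and the `logit_sources` stand-in for the model output are hypothetical and only track positions. The sketch assumes nothing else in the repo compensates for the second shift, which I have not verified.

```python
# Hypothetical toy sequence; the integers stand in for token ids.
token_ids = [10, 11, 12, 13, 14]

# Shift 1 (as in collate_fn in main.py): drop the last token from the
# inputs and the first token from the labels.
input_ids = token_ids[:-1]   # [10, 11, 12, 13]
labels = token_ids[1:]       # [11, 12, 13, 14]

# Stand-in for the model output (assumption, not the real forward pass):
# position t of the "logits" is tagged with the input token at position t,
# so we can see which input position is paired with which target.
logit_sources = list(input_ids)

# Shift 2 (as in the LlamaForCausalLM loss): drop the last logit position
# and the first label.
shift_logit_sources = logit_sources[:-1]  # [10, 11, 12]
shift_labels = labels[1:]                 # [12, 13, 14]

# Each pair is (input token at the predicting position, target token).
pairs_after_both = list(zip(shift_logit_sources, shift_labels))

# For comparison: the pairing produced by a single shift on the raw sequence.
pairs_after_one = list(zip(token_ids[:-1], token_ids[1:]))

print(pairs_after_both)  # [(10, 12), (11, 13), (12, 14)]
print(pairs_after_one)   # [(10, 11), (11, 12), (12, 13), (13, 14)]
```

In this toy setup a single shift pairs each position with the next token, while applying both shifts pairs each position with the token two steps ahead and drops one training pair. Whether the actual training code ends up in this situation depends on details outside the two snippets quoted here.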
