Reproducing the validation accuracy vs. learning rate curve on ResNet #67

@liulei277

Hello!
We tried to reproduce the experiment in Figure 16 of your paper (ResNet on CIFAR-10 at different widths, compared to a base network).
We made some modifications to examples/ResNet/main.py:

for width_mult in [0.5, 1.0, 2.0]:
    # log2(lr) from -3 to 0 in 7 steps of 0.5
    for log2lr in np.linspace(-3, 0, 7):
        # rebuild the network at the current width multiplier
        net = getattr(resnet, args.arch)(wm=width_mult)
        ...
        if args.optimizer == 'musgd':
            optimizer = MuSGD(net.parameters(), lr=2**log2lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)
        ...
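For reference, the `np.linspace(-3, 0, 7)` grid in the loop above corresponds to seven learning rates, each a factor of √2 apart, from 0.125 to 1.0. A small standalone sketch (numpy only, nothing taken from main.py) that lists them:

```python
import numpy as np

# Same grid as the sweep: 7 evenly spaced exponents from -3 to 0,
# i.e. steps of 0.5 in log2 space.
log2lrs = np.linspace(-3, 0, 7)
lrs = 2 ** log2lrs  # runs from 2**-3 = 0.125 up to 2**0 = 1.0

for log2lr, lr in zip(log2lrs, lrs):
    print(f"log2lr={log2lr:+.1f}  lr={lr:.4f}")
```

The grid is coarse (only one point per factor of √2), so the location of the optimum can only be resolved to within one grid step.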

And we ran the following command:

# mup
python main.py --load_base_shapes resnet18.bsh

Then we got the following plot:
[image]
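The μTransfer claim is that under μP the optimal learning rate stays roughly stable as width grows, so one quick check on a sweep like this is whether the argmax of validation accuracy over `log2lr` lines up across the three `width_mult` values. A minimal sketch, assuming the final accuracies are collected into a nested dict `results[width_mult][log2lr]` (a hypothetical layout; the helper names are also hypothetical and not part of main.py or mup):

```python
from typing import Dict

# Hypothetical layout: results[width_mult][log2lr] = final validation accuracy
Results = Dict[float, Dict[float, float]]


def best_log2lr_per_width(results: Results) -> Dict[float, float]:
    """For each width multiplier, return the log2 learning rate that gave
    the highest validation accuracy."""
    return {wm: max(accs, key=accs.get) for wm, accs in results.items()}


def optimum_is_stable(results: Results, tol: float = 0.5) -> bool:
    """Return True if the best log2 lr differs by at most `tol` across widths.

    With a grid spacing of 0.5 in log2 space, tol=0.5 lets the optimum
    move by one grid point.
    """
    best = list(best_log2lr_per_width(results).values())
    return max(best) - min(best) <= tol
```

The tolerance is a judgment call: with only seven grid points, a shift of one step (a factor of about 1.41 in lr) may be within run-to-run noise, so averaging over a few seeds per point would make the comparison more reliable.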

Is there anything wrong with our implementation? Thanks.