Skip to content

cpu load checkpoint fixed  #4

Description

@fatalfeel

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def load_resnet():
    """Build the frozen ResNet-50 feature extractor from the glimpse checkpoint.

    Starts from torchvision's ``resnet50``, swaps the first convolution for a
    5x5, stride-1, padding-2 layer, keeps every child module except the last
    two, and then loads the weights stored under ``"model_dict"`` in the
    checkpoint at ``pretrained_glimpsemodel``.

    Returns:
        Whatever ``get_cuda`` returns for the frozen ``nn.Sequential`` model.
    """
    backbone = torchmodels.resnet50(pretrained=True)

    # Replace the stem convolution: 3 input channels, 64 output channels.
    backbone.conv1 = nn.Conv2d(
        3, 64, kernel_size=5, stride=1, padding=2, bias=False
    )
    # Drop the final two child modules so only the earlier layers remain.
    # NOTE(review): presumably the pooling and classifier head -- confirm.
    resnet = nn.Sequential(*list(backbone.children())[:-2])

    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine.
    checkpoint = T.load(pretrained_glimpsemodel, map_location=device)
    resnet.load_state_dict(checkpoint["model_dict"])

    # We fix the parameters of resnet and do not train it.
    for param in resnet.parameters():
        param.requires_grad = False

    return get_cuda(resnet)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions