More options of input/output types in coord_check #80

@francois-rozet

Hello 👋

I am having issues using mup.get_coord_data because some of my modules return dataclass objects. Currently, only dict, list, tuple, and tensor outputs are supported. It would be great, and fairly easy, to support dataclasses as well.

I think that the only code to modify would be

mup/mup/coord_check.py

Lines 129 to 148 in 1981497

def get_stat(d, x, fdict):
    if isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            _d = copy(d)
            _d['module'] += f'[{i}]'
            get_stat(_d, _x, fdict)
    elif isinstance(x, dict):
        for name, _x in x.items():
            _d = copy(d)
            _d['module'] += f'[{name}]'
            get_stat(_d, _x, fdict)
    elif isinstance(x, torch.Tensor):
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x).item()
        records.append(_d)
    elif x is None:
        pass
    else:
        raise NotImplementedError(f'Unexpected output type: {type(x)}')

I can do a PR for that.
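
A minimal sketch of what the extra branch could look like, written as a standalone function so it runs without torch. Several things here are assumptions and not part of mup: the name collect_stats, passing records explicitly (in coord_check.py it comes from the enclosing scope), the is_leaf predicate standing in for the isinstance(x, torch.Tensor) check, the `.field` naming for dataclass fields, and the example Output class. The only new logic is the is_dataclass branch.

```python
from copy import copy
from dataclasses import dataclass, fields, is_dataclass


def collect_stats(d, x, fdict, records,
                  is_leaf=lambda v: isinstance(v, (int, float))):
    """Standalone variant of get_stat with an added dataclass branch.

    `is_leaf` is a hypothetical stand-in for the torch.Tensor check,
    so plain numbers play the role of tensors in this sketch.
    """
    if isinstance(x, (tuple, list)):
        for i, _x in enumerate(x):
            _d = copy(d)
            _d['module'] += f'[{i}]'
            collect_stats(_d, _x, fdict, records, is_leaf)
    elif isinstance(x, dict):
        for name, _x in x.items():
            _d = copy(d)
            _d['module'] += f'[{name}]'
            collect_stats(_d, _x, fdict, records, is_leaf)
    elif is_dataclass(x) and not isinstance(x, type):
        # is_dataclass is also True for dataclass *classes*, hence the
        # extra check. fields() is used instead of asdict() because
        # asdict() deep-copies and recursively converts the contents.
        for field in fields(x):
            _d = copy(d)
            _d['module'] += f'.{field.name}'  # assumed naming scheme
            collect_stats(_d, getattr(x, field.name), fdict, records, is_leaf)
    elif is_leaf(x):
        _d = copy(d)
        for fname, f in fdict.items():
            _d[fname] = f(x)  # the original calls f(x).item() on tensors
        records.append(_d)
    elif x is None:
        pass
    else:
        raise NotImplementedError(f'Unexpected output type: {type(x)}')


@dataclass
class Output:
    """Hypothetical module output used only to exercise the sketch."""
    logits: float
    hidden: tuple
    aux: object = None
```

Calling collect_stats({'module': 'net'}, Output(2.0, (1.0, -3.0)), {'abs': abs}, records) with an empty list records would append one record per numeric leaf, skipping the None field, in the same way the existing branches skip None outputs.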
