
attMIL-regression breaks when adding cont_labels #15

@tsorz

Description


Hi!

When I pass a list of continuous labels via --cont_labels to marugoto.mil crossval, the program crashes:

python -m marugoto.mil crossval \
    --clini-excel $cliniPath \
    --slide-csv $slidePath  \
    --feature-dir $featurePath \
    --target-label $target \
    --output-path $outPath \
    --n_splits 3 \
    --cont_labels "PLT,INR,aPTT,BILIRUBIN,gGT,CHE"

The error I get is "AttributeError: 'tuple' object has no attribute 'shape'" (full error message below).
The program breaks in marugoto/mil/_mil.py, line 147 (per the traceback), at:
batch = train_dl.one_batch()

Some additional behaviors I noticed:

  • Adding --cont_labels to the main marugoto branch works fine.
  • The ds variable created in marugoto/mil/data.py, line 139, is structured differently in the main and attMIL branches.
def _make_multi_input_dataset(
    *,
    bags: Sequence[Iterable[Path]],
    targets: Tuple[FunctionTransformer, Sequence[Any]],
    add_features: Iterable[Tuple[Any, Sequence[Any]]],
    bag_size: Optional[int] = None
) -> MapDataset:
    target_enc, targs = targets
    assert len(bags) == len(targs), \
        'number of bags and ground truths does not match!'
    for i, (_, vals) in enumerate(add_features):
        assert len(vals) == len(targs), \
            f'number of additional attributes #{i} and ground truths does not match!'

    bag_ds = BagDataset(bags, bag_size=bag_size)

    add_ds = MapDataset(
        _splat_concat,
        *[
            EncodedDataset(enc, vals)
            for enc, vals in add_features
        ])

    targ_ds = EncodedDataset(target_enc, targs)
    ############ !!! Different behavior spotted here !!!
    ds = MapDataset(
        _attach_add_to_bag_and_zip_with_targ,
        bag_ds,
        add_ds,
        targ_ds,
    ) 
    ############
    return ds
  • Indexing ds[0] in attMIL gives [(features_tensor, int_n_instances), [[np.array]], [tensor_target]], where the np.array contains the raw, unencoded values of the cont_labels features.

  • Indexing ds[0] in main gives (tensor_features, int_n_instances, tensor_target). The cont_labels are encoded and concatenated onto tensor_features, giving a tensor of shape (n_instances, n_features + n_cont_labels).
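To make the difference between the two layouts concrete, here is a minimal NumPy sketch of what the main-branch layout amounts to: the per-patient continuous covariates are cast to float, standardized, repeated once per instance row and concatenated onto the bag's feature matrix. The helper name attach_add_features_to_bag and the mean/std standardization are assumptions for illustration only; this is not marugoto's actual encoder or its _attach_add_to_bag_and_zip_with_targ function. In torch the same idea would use torch.cat together with Tensor.expand.

```python
import numpy as np


def attach_add_features_to_bag(bag, n_instances, add_vals, means, stds):
    """Hypothetical helper (not part of marugoto): standardize a patient's
    continuous covariates and append them to every instance of its bag."""
    # Raw clinical values may arrive as an object-dtype array; cast to float
    # first so the result can be collated into a numeric batch.
    add = (np.asarray(add_vals, dtype=np.float32) - means) / stds
    # Repeat the covariate vector once per row of the bag: (rows, n_cont).
    add_rows = np.broadcast_to(add, (bag.shape[0], add.shape[0]))
    # Concatenate along the feature axis: (rows, n_features + n_cont).
    feats = np.concatenate([bag, add_rows], axis=1).astype(np.float32)
    return feats, n_instances


# Illustrative inputs only: a 5-instance bag with 4 features and six raw
# covariate values (one per cont_label) stored as an object array.
bag = np.zeros((5, 4), dtype=np.float32)
raw = np.array([150, 1.1, 30.0, 0.8, 40, 7.5], dtype=object)
feats, n = attach_add_features_to_bag(bag, 5, raw, np.zeros(6), np.ones(6))
print(feats.shape, feats.dtype)  # (5, 10) float32
```

Broadcasting over bag.shape[0] rather than n_instances means any padding rows added to reach bag_size also receive the covariates, which keeps every row of the tensor the same width.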

Error Message:

Traceback (most recent call last):
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 179, in create_batch
    try: return (fa_collate,fa_convert)[self.prebatched](b)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in fa_collate
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 52, in <listcomp>
    else type(t[0])([fa_collate(s) for s in zip(*t)]) if isinstance(b, Sequence)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 51, in fa_collate
    return (default_collate(t) if isinstance(b, _collate_types)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 120, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py", line 170, in collate_numpy_array_fn
    raise TypeError(default_collate_err_msg_format.format(elem.dtype))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found object

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/main.py", line 5, in <module>
    Fire({
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 466, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/helpers.py", line 359, in categorical_crossval
    learn = _crossval_train(
  File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/helpers.py", line 420, in _crossval_train
    learn = train(
  File "/mnt/bulk/tsorznechay/LiverHVPG/Python_Modules/marugoto-regression/marugoto/mil/_mil.py", line 147, in train
    batch = train_dl.one_batch()
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 187, in one_batch
    with self.fake_l.no_multiproc(): res = first(self)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastcore/basics.py", line 660, in first
    return next(x, None)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 127, in __iter__
    for b in _loaders[self.fake_l.num_workers==0](self.fake_l):
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
    data = self._next_data()
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 671, in _next_data
    data = self._dataset_fetcher.fetch(index) # may raise StopIteration
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 43, in fetch
    data = next(self.dataset_iter)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 138, in create_batches
    yield from map(self.do_batch, self.chunkify(res))
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 183, in do_batch
    def do_batch(self, b): return self.retain(self.create_batch(self.before_batch(b)), b)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 181, in create_batch
    if not self.prebatched: collate_error(e,b)
  File "/mnt/bulk/tsorznechay/LiverHVPG/y/envs/MaruReg/lib/python3.10/site-packages/fastai/data/load.py", line 75, in collate_error
    if i == 0: shape_a, type_a = item[idx].shape, item[idx].__class__.__name__
AttributeError: 'tuple' object has no attribute 'shape'
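Reading the traceback, the AttributeError looks like a secondary failure: while formatting its error message, fastai's collate_error calls .shape on each item and trips over the (features, n_instances) tuple. The underlying error appears to be the earlier TypeError, where torch's default_collate rejects a numpy array whose dtype is object. That would fit the [[np.array]] of raw cont_labels values in the attMIL item if that array has dtype object (an assumption; one common way this happens is selecting several clinical columns where at least one was read as non-numeric). The helper find_object_arrays below is hypothetical, not part of marugoto or fastai. It walks a nested dataset item and lists the index paths of object-dtype numpy arrays. Torch tensors and plain numbers are skipped because default_collate accepts them.

```python
import numpy as np


def find_object_arrays(sample, path="ds[0]"):
    """Hypothetical debugging helper: return the index paths of numpy arrays
    with dtype=object inside a nested list/tuple dataset item."""
    found = []
    if isinstance(sample, np.ndarray):
        # default_collate raises TypeError for object-dtype arrays.
        if sample.dtype == object:
            found.append(path)
    elif isinstance(sample, (list, tuple)):
        # Recurse into containers, extending the path with the child index.
        for i, child in enumerate(sample):
            found.extend(find_object_arrays(child, f"{path}[{i}]"))
    return found


# Mock item mimicking the attMIL layout described above (illustrative values).
item = [
    (np.zeros((3, 4), dtype=np.float32), 3),
    [[np.array([150, "1.1", None], dtype=object)]],
    [np.array([12.0])],
]
print(find_object_arrays(item))  # ['ds[0][1][0][0]']
```

Calling find_object_arrays(ds[0]) on the real attMIL dataset should point at the cont_labels array if this is indeed the cause.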
