diff --git a/src/mouse_tracking/utils/writers.py b/src/mouse_tracking/utils/writers.py index 5f3102c..e48dd6a 100644 --- a/src/mouse_tracking/utils/writers.py +++ b/src/mouse_tracking/utils/writers.py @@ -594,6 +594,8 @@ def write_pose_clip( for key, attrs in all_attrs.items(): for cur_attr, data in attrs.items(): out_f[key].attrs.create(cur_attr, data) + if len(adjusted_clip_idxs) > 0: + out_f["poseest"].attrs.create("clip_start_frame", clip_idxs[0]) def downgrade_pose_file(pose_h5_path, disable_id: bool = False): diff --git a/tests/utils/writers/test_write_pose_clip.py b/tests/utils/writers/test_write_pose_clip.py index 74868e4..9c8ff28 100644 --- a/tests/utils/writers/test_write_pose_clip.py +++ b/tests/utils/writers/test_write_pose_clip.py @@ -865,3 +865,28 @@ def test_comprehensive_pose_file_structure(): for file_path in [in_pose_file, out_pose_file]: if os.path.exists(file_path): os.unlink(file_path) + + +def test_writes_clip_start_frame_attribute(): + """Test that clip_start_frame attribute is written to the output pose file.""" + with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_in: + in_pose_file = tmp_in.name + with tempfile.NamedTemporaryFile(suffix=".h5", delete=False) as tmp_out: + out_pose_file = tmp_out.name + + try: + with h5py.File(in_pose_file, "w") as f: + poseest = f.create_group("poseest") + poseest.create_dataset( + "points", data=np.random.rand(20, 1, 12, 2).astype(np.float32) + ) + poseest.attrs["version"] = [3, 0] + + write_pose_clip(in_pose_file, out_pose_file, range(5, 15)) + + with h5py.File(out_pose_file, "r") as f: + assert f["poseest"].attrs["clip_start_frame"] == 5 + finally: + for path in [in_pose_file, out_pose_file]: + if os.path.exists(path): + os.unlink(path)