-
Notifications
You must be signed in to change notification settings - Fork 13k
feat: Time-to-Move sampling #13707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
feat: Time-to-Move sampling #13707
Changes from 9 commits
800bf84
c3cd2a4
0b7d560
b3a0665
f3aebfa
ae54d7a
b715186
8dd41ef
d56a093
de97192
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,6 +46,42 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou | |
| destination[..., top:bottom, left:right] = source_portion + destination_portion | ||
| return destination | ||
|
|
||
def convert_rgb_mask_to_latent_mask(
    mask: torch.Tensor,
    k: int,
    spatial_downsample_h: int,
    spatial_downsample_w: int
) -> torch.Tensor:
    """Convert a [T, H, W] pixel-space mask to [T_latent, H_latent, W_latent].

    Frames are subsampled causally (frame 0 plus every k-th frame after it),
    and the spatial axes are shrunk independently by the given downsample
    factors, so non-square VAE strides are supported.
    """
    # Causal temporal sampling: keep frame 0, then every k-th frame after it.
    temporal = torch.cat([mask[:1], mask[1::k]], dim=0)  # [T_latent, H, W]

    # Add batch and channel axes so interpolate sees a 5D volume:
    # [Batch=1, Channels=1, Depth=T_latent, Height=H, Width=W].
    volume = temporal[None, None]

    # Target spatial size after downsampling (floor division).
    target_h = volume.shape[-2] // spatial_downsample_h
    target_w = volume.shape[-1] // spatial_downsample_w

    # Nearest-neighbor resize of H and W only; the temporal depth
    # (volume.shape[2]) is passed through unchanged.
    resized = torch.nn.functional.interpolate(
        volume,
        size=(volume.shape[2], target_h, target_w),
        mode="nearest",
    )

    # Drop the batch and channel axes: back to [T_latent, H_latent, W_latent].
    return resized[0, 0]
|
|
||
| class LatentCompositeMasked(IO.ComfyNode): | ||
| @classmethod | ||
| def define_schema(cls): | ||
|
|
@@ -73,8 +109,7 @@ def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.No | |
| return IO.NodeOutput(output) | ||
|
|
||
| composite = execute # TODO: remove | ||
|
|
||
|
|
||
|
|
||
| class ImageCompositeMasked(IO.ComfyNode): | ||
| @classmethod | ||
| def define_schema(cls): | ||
|
|
@@ -398,6 +433,28 @@ def execute(cls, mask, value) -> IO.NodeOutput: | |
|
|
||
| image_to_mask = execute # TODO: remove | ||
|
|
||
class RGBMaskToLatentMask(IO.ComfyNode):
    """Converts an RGB (pixel-space) mask into a latent-space mask sized for a
    causal video VAE (e.g., Wan)."""

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="RGBMasktoLatentMask",
            search_aliases=["rgb mask to latent mask", "rgb mask", "latent mask"],
            description="Converts an RGB mask to a latent-space mask for use with causal Video VAEs (e.g., Wan).",
            category="latent",
            inputs=[
                IO.Mask.Input("mask", optional=False),
                IO.Vae.Input("vae", optional=False),
            ],
            outputs=[IO.Mask.Output()],
        )

    @classmethod
    def execute(cls, mask, vae) -> IO.NodeOutput:
        # Work on a copy of the mask to remain non-destructive.
        mask_copy = mask.clone()
        downscale_ratio = vae.downscale_ratio

        # downscale_ratio[0] maps a pixel-space frame count to the latent
        # frame count; call it once and reuse the result instead of
        # evaluating it twice.
        t_latent = downscale_ratio[0](mask.shape[0])

        # Temporal stride k: distribute the remaining T-1 pixel frames over
        # the remaining t_latent-1 latent slots. Only t_latent - 1 == 0
        # needs the divide-by-zero guard; the previous "> 1" check also
        # forced k = 1 when t_latent - 1 == 1, dropping the computed stride.
        k = (mask.shape[0] - 1) // (t_latent - 1) if (t_latent - 1) > 0 else 1

        return IO.NodeOutput(convert_rgb_mask_to_latent_mask(
            mask_copy,
            k,
            spatial_downsample_h=downscale_ratio[1],
            spatial_downsample_w=downscale_ratio[2],
        ))
|
|
||
| # Mask Preview - original implement from | ||
| # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 | ||
|
|
@@ -439,6 +496,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]: | |
| FeatherMask, | ||
| GrowMask, | ||
| ThresholdMask, | ||
| RGBMaskToLatentMask, | ||
| MaskPreview, | ||
| ] | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.