Skip to content
177 changes: 177 additions & 0 deletions comfy_extras/nodes_custom_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,142 @@
from comfy_api.latest import ComfyExtension, io
import re

def video_latent_composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False):
# destination/source shape: [B, C, F, H, W]
source = source.to(destination.device)

if resize_source:
target_size = (source.shape[2], destination.shape[3], destination.shape[4])
source = torch.nn.functional.interpolate(
source,
size=target_size,
mode="trilinear",
align_corners=False
)

x_latent = x // multiplier
y_latent = y // multiplier

if mask is None:
mask = torch.ones_like(source)
else:
mask = mask.to(destination.device, copy=True)
mask = mask.unsqueeze(0).unsqueeze(0)
mask_target_size = (mask.shape[2], source.shape[3], source.shape[4])
mask = torch.nn.functional.interpolate(
mask,
size=mask_target_size,
mode="trilinear",
align_corners=False
)

dst_h, dst_w = destination.shape[3], destination.shape[4]
src_h, src_w = source.shape[3], source.shape[4]

visible_h = max(0, min(y_latent + src_h, dst_h) - max(0, y_latent))
visible_w = max(0, min(x_latent + src_w, dst_w) - max(0, x_latent))

if visible_h <= 0 or visible_w <= 0:
return destination

src_top = max(0, -y_latent)
src_left = max(0, -x_latent)
dst_top = max(0, y_latent)
dst_left = max(0, x_latent)

m = mask[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w]
s = source[:, :, :, src_top:src_top+visible_h, src_left:src_left+visible_w]
d = destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w]

destination[:, :, :, dst_top:dst_top+visible_h, dst_left:dst_left+visible_w] = (m * s) + ((1.0 - m) * d)

return destination

def time_to_move_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask, denoise=1.0, start_step=None, time_to_move_last_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None):

sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
model_sampling = model.get_model_object("model_sampling")
process_latent_out = model.get_model_object("process_latent_out")
process_latent_in = model.get_model_object("process_latent_in")

reference_latent_image = latent_image.clone()

reference_sigmas = sampler.sigmas
reference_noise = noise.clone()

if last_step == None or last_step > steps:
last_step = steps

if time_to_move_last_step == None or time_to_move_last_step > last_step:
time_to_move_last_step = last_step

if start_step == None:
start_step = 0

total_iterations = min(last_step, steps) - start_step
if total_iterations <= 0:
return latent_image.to(
device=comfy.model_management.intermediate_device(),
dtype=comfy.model_management.intermediate_dtype(),
)

for i in range(total_iterations):
if i > 0:
#don't add new noise to samples after first step taken
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")

temp_start = start_step + i

if temp_start < last_step - 1:
temp_force_full_denoise = False
else:
temp_force_full_denoise = force_full_denoise

samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=temp_start, last_step=temp_start + 1, force_full_denoise=temp_force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)

if temp_start < time_to_move_last_step:
scale = reference_sigmas[temp_start + 1].to(noise.device)

if torch.count_nonzero(reference_latent_image) > 0: #Don't shift the empty latent image.
noisy = model_sampling.noise_scaling(scale, reference_noise, process_latent_in(reference_latent_image))
noisy = model_sampling.inverse_noise_scaling(scale, noisy)
noisy = process_latent_out(noisy)
else:
noisy = reference_latent_image

noisy.to(samples.device)

samples = video_latent_composite(samples, noisy, 0, 0, latent_mask, multiplier=1, resize_source=True)

latent_image = samples

samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples


def time_to_move_common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, latent_mask, denoise=1.0, disable_noise=False, start_step=None, time_to_move_last_step = None, last_step=None, force_full_denoise=False):
latent_image = latent["samples"]
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))

if disable_noise:
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
else:
batch_inds = latent["batch_index"] if "batch_index" in latent else None
noise = comfy.sample.prepare_noise(latent_image, seed, batch_inds)

noise_mask = None
if "noise_mask" in latent:
noise_mask = latent["noise_mask"]

callback = latent_preview.prepare_callback(model, steps)
disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
samples = time_to_move_sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask,
denoise=denoise, start_step=start_step, time_to_move_last_step = time_to_move_last_step, last_step=last_step,
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
out = latent.copy()
out.pop("downscale_ratio_spacial", None)
out["samples"] = samples
return (out, )

class BasicScheduler(io.ComfyNode):
@classmethod
Expand Down Expand Up @@ -978,6 +1114,46 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
return io.NodeOutput(out, out_denoised)

sample = execute

class TimeToMoveKSamplerAdvanced(io.ComfyNode):
Comment thread
Pizzawookiee marked this conversation as resolved.
@classmethod

def define_schema(cls):
return io.Schema(
node_id="TimeToMoveKSamplerAdvanced",
category="sampling/time_to_move",
inputs=[
io.Model.Input("model"),
io.Combo.Input("add_noise", options=["enable", "disable"], advanced=True),
io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01),
io.Combo.Input("sampler_name", options = comfy.samplers.KSampler.SAMPLERS),
io.Combo.Input("scheduler", options = comfy.samplers.KSampler.SCHEDULERS),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent_image"),
io.Mask.Input("latent_mask", tooltip = "Make sure mask is the same length as the latents rather than the original video."),
io.Int.Input("start_at_step", default = 0, min = 0, max = 10000, advanced = True, tooltip = "Generally should set at a step greater than 0."),
io.Int.Input("time_to_move_end_at_step", default = 0, min = 0, max = 10000, advanced = True, tooltip = "Generally should set at a step greater than 0 and less than total number of steps."),
io.Int.Input("end_at_step", default = 10000, min = 0, max = 10000, advanced = True, tooltip = "Use just like typical end_at_step with normal KSamplerAdvanced"),
io.Combo.Input("return_with_leftover_noise", options=["disable", "enable"], advanced = True),
],
outputs=[
io.Latent.Output(display_name="latent"),
]
)

@classmethod
def execute(cls, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask, start_at_step, time_to_move_end_at_step, end_at_step, return_with_leftover_noise, denoise=1.0) -> io.NodeOutput:
force_full_denoise = True
if return_with_leftover_noise == "enable":
force_full_denoise = False
disable_noise = False
if add_noise == "disable":
disable_noise = True

return time_to_move_common_ksampler(model, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, latent_mask, denoise=denoise, disable_noise=disable_noise, start_step=start_at_step, time_to_move_last_step = time_to_move_end_at_step, last_step=end_at_step, force_full_denoise=force_full_denoise)

class AddNoise(io.ComfyNode):
@classmethod
Expand Down Expand Up @@ -1087,6 +1263,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
DisableNoise,
AddNoise,
SamplerCustomAdvanced,
TimeToMoveKSamplerAdvanced,
ManualSigmas,
]

Expand Down
62 changes: 60 additions & 2 deletions comfy_extras/nodes_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,42 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination

def convert_rgb_mask_to_latent_mask(
mask: torch.Tensor,
k: int,
spatial_downsample_h: int,
spatial_downsample_w: int
) -> torch.Tensor:
"""
Converts [T, H, W] mask to [T_latent, H_latent, W_latent].
Handles non-square spatial downsampling.
"""
# 1. Temporal Sampling
# Select first frame and every k-th frame thereafter
mask0 = mask[0:1]
mask1 = mask[1::k]
sampled = torch.cat([mask0, mask1], dim=0) # [T_latent, H, W]
Comment thread
Pizzawookiee marked this conversation as resolved.

# 2. Prepare for Spatial Interpolation
# Shape: [Batch=1, Channels=1, Depth=T_latent, Height=H, Width=W]
sampled = sampled.unsqueeze(0).unsqueeze(0)

# 3. Calculate New Spatial Dimensions
h_latent = sampled.shape[-2] // spatial_downsample_h
w_latent = sampled.shape[-1] // spatial_downsample_w

# 4. Interpolate
# We maintain the temporal count (sampled.shape[2])
# but resize H and W independently
pooled = torch.nn.functional.interpolate(
sampled,
size=(sampled.shape[2], h_latent, w_latent),
mode="nearest"
)

# 5. Return to [T_latent, H_latent, W_latent]
return pooled.squeeze(0).squeeze(0)

class LatentCompositeMasked(IO.ComfyNode):
@classmethod
def define_schema(cls):
Expand Down Expand Up @@ -73,8 +109,7 @@ def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.No
return IO.NodeOutput(output)

composite = execute # TODO: remove



class ImageCompositeMasked(IO.ComfyNode):
@classmethod
def define_schema(cls):
Expand Down Expand Up @@ -398,6 +433,28 @@ def execute(cls, mask, value) -> IO.NodeOutput:

image_to_mask = execute # TODO: remove

class RGBMaskToLatentMask(IO.ComfyNode):
Comment thread
Pizzawookiee marked this conversation as resolved.
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RGBMasktoLatentMask",
search_aliases=["rgb mask to latent mask", "rgb mask", "latent mask"],
description="Converts an RGB mask to a latent-space mask for use with causal Video VAEs (e.g., Wan).",
category="latent",
inputs=[
IO.Mask.Input("mask", optional=False),
IO.Vae.Input("vae", optional=False),
],
outputs=[IO.Mask.Output()],
)

@classmethod
def execute(cls, mask, vae) -> IO.NodeOutput:
# Ensure we work on a copy of the mask to remain non-destructive
mask_copy = mask.clone()
downscale_ratio = vae.downscale_ratio
k = (mask.shape[0] - 1) // (downscale_ratio[0](mask.shape[0]) - 1) if (downscale_ratio[0](mask.shape[0]) - 1) > 1 else 1
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

k guard condition > 1 silently misbehaves when T_latent == 2.

The condition (downscale_ratio[0](mask.shape[0]) - 1) > 1 evaluates to False when T_latent - 1 == 1 (i.e., T_latent == 2), so k is hard-clamped to 1 instead of the correct k = T - 1. A mask with 81 input frames would be returned with all 81 temporal frames rather than the expected 2, silently producing a shape mismatch downstream. The guard only needs to avoid division by zero, so the threshold should be > 0.

🐛 Proposed fix
-        k = (mask.shape[0] - 1) // (downscale_ratio[0](mask.shape[0]) - 1) if (downscale_ratio[0](mask.shape[0]) - 1) > 1 else 1
+        t_latent = downscale_ratio[0](mask.shape[0])
+        k = (mask.shape[0] - 1) // (t_latent - 1) if (t_latent - 1) > 0 else 1

The refactor also avoids calling the downscale_ratio[0] callable twice.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@comfy_extras/nodes_mask.py` at line 458, The computation for k incorrectly
uses the guard (downscale_ratio[0](mask.shape[0]) - 1) > 1 which forces k=1 when
T_latent==2; change it to only guard against division by zero by checking > 0
and call the downscale callable once: store ratio =
downscale_ratio[0](mask.shape[0]) then compute k = (mask.shape[0] - 1) // (ratio
- 1) if (ratio - 1) > 0 else 1 so T_latent==2 yields the correct k = T - 1 and
avoids the double call to downscale_ratio[0].

return IO.NodeOutput(convert_rgb_mask_to_latent_mask(mask_copy, k, spatial_downsample_h = downscale_ratio[1], spatial_downsample_w = downscale_ratio[2]))

# Mask Preview - original implement from
# https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81
Expand Down Expand Up @@ -439,6 +496,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
FeatherMask,
GrowMask,
ThresholdMask,
RGBMaskToLatentMask,
MaskPreview,
]

Expand Down