Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cosmos_framework/data/vfm/sequence_packing/packers.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def pack_input_sequence(
)

vision_split_len = 0
# Per-item split lengths for multi-control attention routing.
# Only tracked when control_weights are present (inference-only);
# skipped during training to avoid unnecessary side effects.
track_item_split_lens = gen_data_clean.control_weights is not None
sample_item_split_lens: list[int] = []
# Controlnet-style transfer: when set, all vision items share the same
# temporal mRoPE grid. We snapshot the offset before the loop and
# rewind to it before each item, so every item produces identical
Expand Down Expand Up @@ -359,6 +364,10 @@ def pack_input_sequence(
vision_temporal_positions=vision_temporal_positions,
)
vision_split_len += item_split_len
if track_item_split_lens:
sample_item_split_lens.append(item_split_len)
if track_item_split_lens:
packed_seq.vision_item_split_lens.append(sample_item_split_lens)
sample_len += vision_split_len

else:
Expand Down
15 changes: 15 additions & 0 deletions cosmos_framework/data/vfm/sequence_packing/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@ class PackedSequence:
action: ModalityData | None = None
sound: ModalityData | None = None

# Multi-control transfer: per-sample list of per-vision-item token counts.
# For a multi-control transfer sample with N controls + 1 noisy target,
# vision_item_split_lens[i] = [L_ctrl0, L_ctrl1, ..., L_ctrlN-1, L_noisy].
# Used by cosmos3_vfm_network.py to derive gen-relative control/noisy ranges
# for multi_control_two_way_attention.
vision_item_split_lens: list[list[int]] = field(default_factory=list)

# Per-sample per-control weights for multi-control weighted V-scaling.
# Parallel to vision_item_split_lens[i][:-1] (excludes noisy item).
# None for non-transfer or standard single-control samples.
control_weights: list[list[float]] | None = None

def finalize(
self,
gen_data_clean: GenerationDataClean,
Expand Down Expand Up @@ -262,6 +274,9 @@ def finalize(
# Temporal causal
null_action_supertokens=self.null_action_supertokens,
num_action_tokens_per_supertoken=self.num_action_tokens_per_supertoken,
# Multi-control transfer
vision_item_split_lens=list(self.vision_item_split_lens),
control_weights=gen_data_clean.control_weights,
)

def to_cuda(self) -> None:
Expand Down
13 changes: 13 additions & 0 deletions cosmos_framework/inference/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ class TransferArgs(ArgsBase):
"""Resolved transfer inference arguments for a single control hint."""

control_path: ResolvedFilePathOrUrl | None = None
weight: float = 1.0
"""Strength of this control signal in the weighted multi-control attention aggregation."""

@pydantic.field_validator("weight", mode="before")
@classmethod
def _coerce_none_weight(cls, v: float | None) -> float:
    """Replace an unset (``None``) weight with the neutral value ``1.0``.

    Runs as a "before" validator on ``weight``. Per the original author's
    note, ``TransferOverrides.weight`` defaults to ``None`` ("unset") and
    reaches this field through ``_build``, which does not strip ``None``
    values. Mapping ``None`` to ``1.0`` gives equal weighting, i.e. parity
    with the single-control case. Any other value is returned untouched.
    """
    if v is None:
        return 1.0
    return v


class EdgeTransferArgs(TransferArgs):
Expand All @@ -262,6 +272,9 @@ class TransferOverrides(OverridesBase):
control_path: ResolvedFilePathOrUrl | None = None
"""Path or URL to pre-computed control input."""

weight: float | None = None
"""Override the control weight for multi-control weighted attention aggregation."""

def download(self, output_dir: Path):
if self.control_path is not None:
self.control_path = download_file(self.control_path, output_dir, "transfer_control")
Expand Down
29 changes: 28 additions & 1 deletion cosmos_framework/inference/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,21 @@ def build_transfer_batch(
prompt: str,
negative_prompt: str | None,
share_vision_temporal_positions: bool,
control_weights: list[float] | None = None,
) -> dict[str, object]:
"""Build the ``[ctrl_1, ..., ctrl_N, target]`` batch for transfer inference."""
"""Build the ``[ctrl_1, ..., ctrl_N, target]`` batch for transfer inference.

``control_weights`` is a per-control scalar (default 1.0 each). Weights are
normalised to sum to 1 before use. In ``multi_control_two_way_attention`` N
independent maskless SDPA passes are computed (one per control), each with
KV = [text | ctrl_i | noisy]. The final noisy output is the weighted sum:

noisy_out = w_1 * noisy_out_1 + ... + w_N * noisy_out_N

All SDPA calls are maskless so Flash Attention is always active.
When ``N=1`` the single weight normalises to 1.0, reproducing the original
``two_way_attention`` behaviour exactly.
"""
control_5ds = [cv.unsqueeze(0).cuda().to(dtype=torch.bfloat16) for cv in control_videos]
target_5d = target_video.unsqueeze(0).cuda().to(dtype=torch.bfloat16)
num_vision_items = len(control_5ds) + 1
Expand All @@ -165,6 +178,16 @@ def build_transfer_batch(
else:
condition_frame_indexes = []

if control_weights is None:
control_weights = [1.0] * len(control_5ds)
assert len(control_weights) == len(control_5ds), (
f"control_weights length {len(control_weights)} must match number of controls {len(control_5ds)}"
)
assert all(w >= 0 for w in control_weights), f"control_weights must all be non-negative, got {control_weights}"
total = sum(control_weights)
assert total > 0, f"control_weights must have a positive sum, got {control_weights}"
control_weights = [w / total for w in control_weights]

size = torch.tensor([[height, width, height, width]], dtype=torch.float32).cuda()
batch: dict[str, object] = {
"dataset_name": "video_transfer",
Expand All @@ -174,6 +197,9 @@ def build_transfer_batch(
"padding_mask": torch.zeros(1, 1, height, width).cuda(),
"num_frames": torch.tensor([num_frames]).cuda(),
"num_vision_items_per_sample": [num_vision_items],
# Per-control weights for multi-control weighted attention aggregation.
# Shape: [num_samples], each element is a list of floats (one per control).
"control_weights": [control_weights],
"is_preprocessed": True,
# share_vision_temporal_positions must match the trained checkpoint's
# SequencePlan regime; mismatched flag → frame-drift between control and
Expand Down Expand Up @@ -512,6 +538,7 @@ def generate_transfer_sample(
prompt=prompt,
negative_prompt=negative_prompt,
share_vision_temporal_positions=share_temporal,
control_weights=[h.weight for h in hints.values()],
)
outputs = model.generate_samples_from_batch(
data_batch,
Expand Down
150 changes: 149 additions & 1 deletion cosmos_framework/model/vfm/mot/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ def __init__(
self.num_action_tokens_per_supertoken = num_action_tokens_per_supertoken
self.null_action_supertokens = null_action_supertokens

# Multi-control transfer fields (set post-construction in cosmos3_vfm_network.py).
# Gen-relative token ranges for each control stream, one tuple (start, end) per control.
self.control_stream_token_ranges: list[tuple[int, int]] | None = None
# Gen-relative token range (start, end) for the noisy target tokens.
self.noisy_token_range: tuple[int, int] | None = None
# Per-control scalar weights; parallel to control_stream_token_ranges.
self.control_weights: list[float] | None = None


AttentionMaskType = SplitInfo

Expand Down Expand Up @@ -233,6 +241,139 @@ def three_way_attention(
return out_all


def multi_control_two_way_attention(
    packed_query_states: SequencePack,
    packed_key_states: SequencePack,
    packed_value_states: SequencePack,
    split_info: SplitInfo,
) -> SequencePack:
    """Two-way attention for multi-control transfer inference.

    Runs N independent single-control attention passes; the noisy-target
    output is the weighted sum of the per-pass noisy outputs.

    Expected layout of the "full/gen" segment (as described by the original
    author; it mirrors the batch built by ``build_transfer_batch``):

        full = [ctrl_1 | ctrl_2 | ... | ctrl_N | noisy]

    For each control i, the key/value context is:

        KV_i = [text | ctrl_i | noisy]

    and two attention calls are issued against it, neither with a mask:
      - ctrl_i queries  -> written directly to the ctrl_i output rows
      - noisy queries   -> scaled by ``weights[i]`` and accumulated

    Final outputs:
      - ctrl_i rows: from pass i only
      - noisy rows:  w_1 * noisy_out_1 + ... + w_N * noisy_out_N

    NOTE(review): the original author states that the maskless calls keep
    Flash Attention active and that N=1 with w=1.0 reproduces
    ``two_way_attention``; neither can be confirmed from this function alone.

    Padding safety:
        The tensors returned by ``get_causal_seq`` / ``get_full_only_seq`` may
        contain padded rows. They are sliced down to the valid token counts
        (last entry of the respective offsets) before the per-control
        attention calls so padded rows never act as keys/values, and the
        result is re-padded with zeros afterwards.

    Args:
        packed_query_states: Query ``SequencePack`` (single sample).
        packed_key_states: Key ``SequencePack`` (single sample).
        packed_value_states: Value ``SequencePack`` (single sample).
        split_info: ``SplitInfo`` carrying ``control_stream_token_ranges``,
            ``noisy_token_range`` and ``control_weights`` (all non-None).
            Ranges are half-open ``(start, end)`` offsets into the unpadded
            full/gen segment.

    Returns:
        The ``SequencePack`` produced by ``from_mode_splits`` from the text
        (causal) output and the re-padded full/gen output.

    Raises:
        AssertionError: if a required ``split_info`` field is ``None``, if no
            control range is given, if the number of weights differs from the
            number of control ranges, or if any range falls outside the
            unpadded full/gen segment.
    """
    assert split_info.control_stream_token_ranges is not None
    assert split_info.noisy_token_range is not None
    assert split_info.control_weights is not None

    ctrl_ranges = split_info.control_stream_token_ranges
    noisy_s, noisy_e = split_info.noisy_token_range
    weights = split_info.control_weights

    # Fail fast, before any attention work, on inputs that previously surfaced
    # as a bare IndexError mid-loop (too few weights), were silently ignored
    # (too many weights), or tripped a message-less assert at the very end
    # (no controls).
    assert len(ctrl_ranges) > 0, "multi_control_two_way_attention requires at least one control stream."
    assert len(weights) == len(ctrl_ranges), (
        f"control_weights length ({len(weights)}) must equal the number of "
        f"control streams ({len(ctrl_ranges)})."
    )

    # -- 1. Text self-attention (causal, unchanged) ---------------------------
    causal_q, causal_q_offsets = get_causal_seq(packed_query_states)
    causal_k, causal_k_offsets = get_causal_seq(packed_key_states)
    causal_v, _ = get_causal_seq(packed_value_states)

    # Identity (not equality) check: when Q and K share the very same offsets
    # object the "don't care" causal type is selected, otherwise top-left.
    use_dont_care_mask = causal_q_offsets is causal_k_offsets
    causal_res = attention(
        causal_q.unsqueeze(0),
        causal_k.unsqueeze(0),
        causal_v.unsqueeze(0),
        cumulative_seqlen_Q=causal_q_offsets,
        cumulative_seqlen_KV=causal_k_offsets,
        max_seqlen_Q=packed_query_states["max_causal_len"],
        max_seqlen_KV=packed_query_states["max_causal_len"],
        is_causal=True,
        causal_type=CausalType.DontCare if use_dont_care_mask else CausalType.TopLeft,
    )
    causal_out = causal_res.squeeze(0).flatten(-2, -1)  # heads and head-dim merged

    # -- 2. Extract unpadded full/gen tokens ----------------------------------
    full_q, full_q_offsets = get_full_only_seq(packed_query_states)
    full_k, _ = get_full_only_seq(packed_key_states)
    full_v, _ = get_full_only_seq(packed_value_states)

    # The last offset entry is used as the count of valid (non-padded) rows.
    n_text = int(causal_k_offsets[-1])
    n_full = int(full_q_offsets[-1])

    # Python slicing clamps silently, so an out-of-range token range would
    # yield truncated Q/K/V and a wrong result with no error. Check explicitly.
    assert 0 <= noisy_s <= noisy_e <= n_full, (
        f"noisy_token_range {(noisy_s, noisy_e)} is outside the {n_full} valid full/gen tokens."
    )
    assert all(0 <= cs <= ce <= n_full for cs, ce in ctrl_ranges), (
        f"control_stream_token_ranges {ctrl_ranges} fall outside the {n_full} valid full/gen tokens."
    )

    # Unpad so padded rows never enter the softmax denominator.
    causal_k_v = causal_k[:n_text]
    causal_v_v = causal_v[:n_text]
    full_q_v = full_q[:n_full]
    full_k_v = full_k[:n_full]
    full_v_v = full_v[:n_full]

    noisy_q = full_q_v[noisy_s:noisy_e]
    noisy_k = full_k_v[noisy_s:noisy_e]
    noisy_v = full_v_v[noisy_s:noisy_e]

    def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Maskless attention of ``q`` over ``k``/``v``; heads are flattened in the result."""
        n_q, n_kv = q.shape[0], k.shape[0]
        seqlens_q = torch.tensor([n_q], dtype=torch.int32, device=q.device)
        seqlens_kv = torch.tensor([n_kv], dtype=torch.int32, device=k.device)
        res = attention(
            q.unsqueeze(0),
            k.unsqueeze(0),
            v.unsqueeze(0),
            seqlens_Q=seqlens_q,
            seqlens_KV=seqlens_kv,
            max_seqlen_Q=n_q,
            max_seqlen_KV=n_kv,
        )
        return res.squeeze(0).flatten(-2, -1)

    # -- 3. N independent single-control passes -------------------------------
    # Rows of the full/gen segment not covered by any control range or by the
    # noisy range stay zero in the output.
    full_out_v = full_q_v.new_zeros(n_full, causal_out.shape[-1])
    noisy_out_acc: torch.Tensor | None = None

    for weight, (cs, ce) in zip(weights, ctrl_ranges):
        # KV context for this pass: [text | ctrl_i | noisy]
        kv_k_i = torch.cat([causal_k_v, full_k_v[cs:ce], noisy_k], dim=0)
        kv_v_i = torch.cat([causal_v_v, full_v_v[cs:ce], noisy_v], dim=0)

        # Control output comes from its own pass only.
        full_out_v[cs:ce] = _sdpa(full_q_v[cs:ce], kv_k_i, kv_v_i)

        # Noisy output of this pass contributes with its control's weight.
        weighted_noisy = weight * _sdpa(noisy_q, kv_k_i, kv_v_i)
        noisy_out_acc = weighted_noisy if noisy_out_acc is None else noisy_out_acc + weighted_noisy

    assert noisy_out_acc is not None  # guaranteed by the non-empty check above
    full_out_v[noisy_s:noisy_e] = noisy_out_acc

    # Re-pad to the original row count so downstream layers see the same
    # tensor sizes they passed in.
    full_out = full_q.new_zeros(full_q.shape[0], full_out_v.shape[-1])
    full_out[:n_full] = full_out_v

    return from_mode_splits(causal_out, full_out, packed_query_states)


def dispatch_attention(
packed_query_states: SequencePack,
packed_key_states: SequencePack,
Expand All @@ -242,7 +383,14 @@ def dispatch_attention(
memory_value: MemoryValue | None = None,
) -> tuple[SequencePack, KVToStore | None]:
assert memory_value is None, "Base dispatch_attention does not handle MemoryValue"
if isinstance(attention_mask, SplitInfo) and attention_mask.is_three_way:
if isinstance(attention_mask, SplitInfo) and attention_mask.control_stream_token_ranges is not None:
output = multi_control_two_way_attention(
packed_query_states,
packed_key_states,
packed_value_states,
attention_mask,
)
elif isinstance(attention_mask, SplitInfo) and attention_mask.is_three_way:
output = three_way_attention(
packed_query_states,
packed_key_states,
Expand Down
51 changes: 50 additions & 1 deletion cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_utils import PreTrainedModel

from cosmos_framework.model.vfm.mot.attention import build_packed_sequence
from cosmos_framework.model.vfm.mot.attention import SplitInfo, build_packed_sequence
from cosmos_framework.model.vfm.mot.context_parallel_utils import (
get_context_parallel_last_hidden_state,
get_context_parallel_sharded_sequence,
Expand Down Expand Up @@ -1085,6 +1085,55 @@ def forward(
pad_for_cuda_graphs=self.pad_for_cuda_graphs,
)

# ── Multi-control transfer: annotate SplitInfo with per-item ranges ──────
# Activated only when packed_seq carries control_weights, i.e. the caller
# has set up a multi-control batch via build_transfer_batch.
#
# multi_control_two_way_attention runs N independent maskless SDPA passes,
# one per control. For each pass i, KV = [text | ctrl_i | noisy].
# The final noisy output is the weighted sum of the N pass outputs:
# noisy_out = w_1 * noisy_out_1 + ... + w_N * noisy_out_N
# All SDPA calls are maskless → Flash Attention always active.
# N=1, w=1.0 → identical to two_way_attention.
#
# CP compatibility: control_stream_token_ranges are gen-relative global
# offsets computed here, before CP sharding. Ulysses CP restores the full
# sequence on every rank (via all-to-all) before calling dispatch_attention,
# so the global ranges are valid indices inside multi_control_two_way_attention.
if (
isinstance(attention_meta, SplitInfo)
and packed_seq.control_weights is not None
and packed_seq.vision_item_split_lens
):
# For multi-control, each sample must have N controls + 1 noisy item
# (items 0..N-1 are controls, item N is the noisy target).
# Only batch_size=1 is supported; assert to catch misuse early.
assert len(packed_seq.vision_item_split_lens) == 1, (
f"Multi-control transfer requires batch_size=1, got {len(packed_seq.vision_item_split_lens)} samples."
)
item_lens = packed_seq.vision_item_split_lens[0] # [L_ctrl0, L_ctrl1, ..., L_noisy]
weights = packed_seq.control_weights[0] # [w_ctrl0, w_ctrl1, ...]
assert len(item_lens) > 1, (
f"Multi-control requires at least 1 control + 1 noisy item; got vision_item_split_lens={item_lens}."
)
assert len(weights) == len(item_lens) - 1, (
f"control_weights length ({len(weights)}) must equal number of control items ({len(item_lens) - 1})."
)
ctrl_ranges: list[tuple[int, int]] = []
cursor = 0
for lens in item_lens[:-1]: # all but last = control streams
ctrl_ranges.append((cursor, cursor + lens))
cursor += lens
noisy_range = (cursor, cursor + item_lens[-1])
n_gen = int(vision_sequence_indexes.shape[0]) if vision_sequence_indexes is not None else 0
assert noisy_range[1] == n_gen, (
f"vision_item_split_lens sums to {noisy_range[1]} gen tokens but packed tensor has "
f"{n_gen}; packing inconsistency detected."
)
attention_meta.control_stream_token_ranges = ctrl_ranges
attention_meta.noisy_token_range = noisy_range
attention_meta.control_weights = weights

input_pack, packed_position_ids = get_context_parallel_sharded_sequence(
attn_implementation=self.config.joint_attn_implementation,
input_pack=input_pack,
Expand Down
2 changes: 2 additions & 0 deletions cosmos_framework/model/vfm/omni_mot_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2915,6 +2915,7 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration:
},
step=iteration,
)
control_weights: list[list[float]] | None = data_batch.get("control_weights", None)
return GenerationDataClean(
batch_size=batch_size,
is_image_batch=is_image_batch,
Expand All @@ -2931,6 +2932,7 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration:
action_domain_id=action_domain_id,
num_vision_items_per_sample=num_vision_items_per_sample,
raw_action_dim=raw_action_dim,
control_weights=control_weights,
)

def _normalize_video_databatch_inplace(
Expand Down
Loading