diff --git a/cosmos_framework/data/vfm/sequence_packing/packers.py b/cosmos_framework/data/vfm/sequence_packing/packers.py index f311f878..a4532fd1 100644 --- a/cosmos_framework/data/vfm/sequence_packing/packers.py +++ b/cosmos_framework/data/vfm/sequence_packing/packers.py @@ -257,6 +257,11 @@ def pack_input_sequence( ) vision_split_len = 0 + # Per-item split lengths for multi-control attention routing. + # Only tracked when control_weights are present (inference-only); + # skipped during training to avoid unnecessary side effects. + track_item_split_lens = gen_data_clean.control_weights is not None + sample_item_split_lens: list[int] = [] # Controlnet-style transfer: when set, all vision items share the same # temporal mRoPE grid. We snapshot the offset before the loop and # rewind to it before each item, so every item produces identical @@ -359,6 +364,10 @@ def pack_input_sequence( vision_temporal_positions=vision_temporal_positions, ) vision_split_len += item_split_len + if track_item_split_lens: + sample_item_split_lens.append(item_split_len) + if track_item_split_lens: + packed_seq.vision_item_split_lens.append(sample_item_split_lens) sample_len += vision_split_len else: diff --git a/cosmos_framework/data/vfm/sequence_packing/types.py b/cosmos_framework/data/vfm/sequence_packing/types.py index ade4a1ba..0fdd4ef3 100644 --- a/cosmos_framework/data/vfm/sequence_packing/types.py +++ b/cosmos_framework/data/vfm/sequence_packing/types.py @@ -155,6 +155,18 @@ class PackedSequence: action: ModalityData | None = None sound: ModalityData | None = None + # Multi-control transfer: per-sample list of per-vision-item token counts. + # For a multi-control transfer sample with N controls + 1 noisy target, + # vision_item_split_lens[i] = [L_ctrl0, L_ctrl1, ..., L_ctrlN-1, L_noisy]. + # Used by cosmos3_vfm_network.py to derive gen-relative control/noisy ranges + # for multi_control_two_way_attention. + vision_item_split_lens: list[list[int]] = field(default_factory=list) + + # Per-sample per-control weights for multi-control weighted V-scaling. + # Parallel to vision_item_split_lens[i][:-1] (excludes noisy item). + # None for non-transfer or standard single-control samples. + control_weights: list[list[float]] | None = None + def finalize( self, gen_data_clean: GenerationDataClean, @@ -262,6 +274,9 @@ def finalize( # Temporal causal null_action_supertokens=self.null_action_supertokens, num_action_tokens_per_supertoken=self.num_action_tokens_per_supertoken, + # Multi-control transfer + vision_item_split_lens=list(self.vision_item_split_lens), + control_weights=gen_data_clean.control_weights, ) def to_cuda(self) -> None: diff --git a/cosmos_framework/inference/args.py b/cosmos_framework/inference/args.py index da8dea22..7095fa7e 100644 --- a/cosmos_framework/inference/args.py +++ b/cosmos_framework/inference/args.py @@ -246,6 +246,16 @@ class TransferArgs(ArgsBase): """Resolved transfer inference arguments for a single control hint.""" control_path: ResolvedFilePathOrUrl | None = None + weight: float = 1.0 + """Strength of this control signal in the weighted multi-control attention aggregation.""" + + @pydantic.field_validator("weight", mode="before") + @classmethod + def _coerce_none_weight(cls, v: float | None) -> float: + # ``TransferOverrides.weight`` defaults to None ("unset") and is resolved into + # this required float via ``_build`` (which does not strip None). An unset weight + # resolves to the neutral 1.0 → equal weighting / single-control parity. + return 1.0 if v is None else v class EdgeTransferArgs(TransferArgs): @@ -262,6 +272,9 @@ class TransferOverrides(OverridesBase): control_path: ResolvedFilePathOrUrl | None = None """Path or URL to pre-computed control input.""" + weight: float | None = None + """Override the control weight for multi-control weighted attention aggregation.""" + def download(self, output_dir: Path): if self.control_path is not None: self.control_path = download_file(self.control_path, output_dir, "transfer_control") diff --git a/cosmos_framework/inference/transfer.py b/cosmos_framework/inference/transfer.py index 57760e7c..61a99b49 100644 --- a/cosmos_framework/inference/transfer.py +++ b/cosmos_framework/inference/transfer.py @@ -155,8 +155,21 @@ def build_transfer_batch( prompt: str, negative_prompt: str | None, share_vision_temporal_positions: bool, + control_weights: list[float] | None = None, ) -> dict[str, object]: - """Build the ``[ctrl_1, ..., ctrl_N, target]`` batch for transfer inference.""" + """Build the ``[ctrl_1, ..., ctrl_N, target]`` batch for transfer inference. + + ``control_weights`` is a per-control scalar (default 1.0 each). Weights are + normalised to sum to 1 before use. In ``multi_control_two_way_attention`` N + independent maskless SDPA passes are computed (one per control), each with + KV = [text | ctrl_i | noisy]. The final noisy output is the weighted sum: + + noisy_out = w_1 * noisy_out_1 + ... + w_N * noisy_out_N + + All SDPA calls are maskless so Flash Attention is always active. + When ``N=1`` the single weight normalises to 1.0, reproducing the original + ``two_way_attention`` behaviour exactly. + """ control_5ds = [cv.unsqueeze(0).cuda().to(dtype=torch.bfloat16) for cv in control_videos] target_5d = target_video.unsqueeze(0).cuda().to(dtype=torch.bfloat16) num_vision_items = len(control_5ds) + 1 @@ -165,6 +178,16 @@ def build_transfer_batch( else: condition_frame_indexes = [] + if control_weights is None: + control_weights = [1.0] * len(control_5ds) + assert len(control_weights) == len(control_5ds), ( + f"control_weights length {len(control_weights)} must match number of controls {len(control_5ds)}" + ) + assert all(w >= 0 for w in control_weights), f"control_weights must all be non-negative, got {control_weights}" + total = sum(control_weights) + assert total > 0, f"control_weights must have a positive sum, got {control_weights}" + control_weights = [w / total for w in control_weights] + size = torch.tensor([[height, width, height, width]], dtype=torch.float32).cuda() batch: dict[str, object] = { "dataset_name": "video_transfer", @@ -174,6 +197,9 @@ def build_transfer_batch( "padding_mask": torch.zeros(1, 1, height, width).cuda(), "num_frames": torch.tensor([num_frames]).cuda(), "num_vision_items_per_sample": [num_vision_items], + # Per-control weights for multi-control weighted attention aggregation. + # Shape: [num_samples], each element is a list of floats (one per control). + "control_weights": [control_weights], "is_preprocessed": True, # share_vision_temporal_positions must match the trained checkpoint's # SequencePlan regime; mismatched flag → frame-drift between control and @@ -512,6 +538,7 @@ def generate_transfer_sample( prompt=prompt, negative_prompt=negative_prompt, share_vision_temporal_positions=share_temporal, + control_weights=[h.weight for h in hints.values()], ) outputs = model.generate_samples_from_batch( data_batch, diff --git a/cosmos_framework/model/vfm/mot/attention.py b/cosmos_framework/model/vfm/mot/attention.py index 9f950b9f..5c55ddfa 100644 --- a/cosmos_framework/model/vfm/mot/attention.py +++ b/cosmos_framework/model/vfm/mot/attention.py @@ -56,6 +56,14 @@ def __init__( self.num_action_tokens_per_supertoken = num_action_tokens_per_supertoken self.null_action_supertokens = null_action_supertokens + # Multi-control transfer fields (set post-construction in cosmos3_vfm_network.py). + # Gen-relative token ranges for each control stream, one tuple (start, end) per control. + self.control_stream_token_ranges: list[tuple[int, int]] | None = None + # Gen-relative token range (start, end) for the noisy target tokens. + self.noisy_token_range: tuple[int, int] | None = None + # Per-control scalar weights; parallel to control_stream_token_ranges. + self.control_weights: list[float] | None = None + AttentionMaskType = SplitInfo @@ -233,6 +241,139 @@ def three_way_attention( return out_all +def multi_control_two_way_attention( + packed_query_states: SequencePack, + packed_key_states: SequencePack, + packed_value_states: SequencePack, + split_info: SplitInfo, +) -> SequencePack: + """Two-way attention for multi-control transfer inference. + + N independent single-control attention passes; noisy output = weighted sum. + + Layout of the "full/gen" segment (mirrors the packed batch built by ``build_transfer_batch``): + + full = [ctrl_1 | ctrl_2 | ... | ctrl_N | noisy] + + For each control i, one independent maskless SDPA is computed: + + ctrl_i and noisy both attend to KV = [text | ctrl_i | noisy] + + The final outputs are: + - ctrl_i output: from pass i only + - noisy output: w_1 * noisy_out_1 + ... + w_N * noisy_out_N (weighted sum) + + All SDPA calls are maskless → Flash Attention is always active. + N=1, w=1.0 → identical to ``two_way_attention``. + + Padding safety: + Both ``get_causal_seq`` and ``get_full_only_seq`` can return padded rows. + We unpad to valid token counts before each SDPA so that padded rows + never enter the softmax denominator. + + Args: + packed_query/key/value_states: SequencePack for a single sample. + split_info: SplitInfo carrying ``control_stream_token_ranges``, + ``noisy_token_range``, and ``control_weights`` (all must be non-None). + """ + assert split_info.control_stream_token_ranges is not None + assert split_info.noisy_token_range is not None + assert split_info.control_weights is not None + + ctrl_ranges = split_info.control_stream_token_ranges + noisy_s, noisy_e = split_info.noisy_token_range + weights = split_info.control_weights + + # ── 1. Text self-attention (causal, unchanged) ─────────────────────────── + causal_q, causal_q_offsets = get_causal_seq(packed_query_states) + causal_k, causal_k_offsets = get_causal_seq(packed_key_states) + causal_v, _ = get_causal_seq(packed_value_states) + + use_dont_care_mask = causal_q_offsets is causal_k_offsets + causal_res = attention( + causal_q.unsqueeze(0), + causal_k.unsqueeze(0), + causal_v.unsqueeze(0), + cumulative_seqlen_Q=causal_q_offsets, + cumulative_seqlen_KV=causal_k_offsets, + max_seqlen_Q=packed_query_states["max_causal_len"], + max_seqlen_KV=packed_query_states["max_causal_len"], + is_causal=True, + causal_type=CausalType.DontCare if use_dont_care_mask else CausalType.TopLeft, + ) + causal_out = causal_res.squeeze(0).flatten(-2, -1) # [N_text, Hq*D] + + # ── 2. Extract unpadded full/gen tokens ────────────────────────────────── + full_q, full_q_offsets = get_full_only_seq(packed_query_states) + full_k, _ = get_full_only_seq(packed_key_states) + full_v, _ = get_full_only_seq(packed_value_states) + + n_text = int(causal_k_offsets[-1]) + n_full = int(full_q_offsets[-1]) + + # Unpad to avoid padded rows entering the softmax denominator. + causal_k_v = causal_k[:n_text] # [N_text, Hkv, D] + causal_v_v = causal_v[:n_text] # [N_text, Hkv, D] + full_q_v = full_q[:n_full] # [N_full, Hq, D] + full_k_v = full_k[:n_full] # [N_full, Hkv, D] + full_v_v = full_v[:n_full] # [N_full, Hkv, D] + + noisy_q = full_q_v[noisy_s:noisy_e] # [N_noisy, Hq, D] + noisy_k = full_k_v[noisy_s:noisy_e] # [N_noisy, Hkv, D] + noisy_v = full_v_v[noisy_s:noisy_e] # [N_noisy, Hkv, D] + + def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor: + """Maskless attention using cosmos_framework attention() → [N_q, Hq*D].""" + n_q, n_kv = q.shape[0], k.shape[0] + seqlens_q = torch.tensor([n_q], dtype=torch.int32, device=q.device) + seqlens_kv = torch.tensor([n_kv], dtype=torch.int32, device=k.device) + res = attention( + q.unsqueeze(0), # [1, N_q, Hq, D] + k.unsqueeze(0), # [1, N_kv, Hkv, D] + v.unsqueeze(0), # [1, N_kv, Hkv, D] + seqlens_Q=seqlens_q, + seqlens_KV=seqlens_kv, + max_seqlen_Q=n_q, + max_seqlen_KV=n_kv, + ) # [1, N_q, Hq, D] + return res.squeeze(0).flatten(-2, -1) # [N_q, Hq*D] + + # ── 3. N independent single-control passes ──────────────────────────────── + # For each control i: KV = [text | ctrl_i | noisy] — maskless SDPA. + # ctrl_i attends to [text, ctrl_i, noisy] → stored directly in full_out. + # noisy attends to [text, ctrl_i, noisy] → accumulated as weighted sum. + full_out_v = full_q_v.new_zeros(n_full, causal_out.shape[-1]) + noisy_out_acc: torch.Tensor | None = None + + for i, (cs, ce) in enumerate(ctrl_ranges): + ctrl_k_i = full_k_v[cs:ce] + ctrl_v_i = full_v_v[cs:ce] + ctrl_q_i = full_q_v[cs:ce] + + # KV context for this pass: [text | ctrl_i | noisy] + kv_k_i = torch.cat([causal_k_v, ctrl_k_i, noisy_k], dim=0) + kv_v_i = torch.cat([causal_v_v, ctrl_v_i, noisy_v], dim=0) + + # ctrl_i output — stored directly + full_out_v[cs:ce] = _sdpa(ctrl_q_i, kv_k_i, kv_v_i) + + # noisy output for pass i — accumulate weighted sum + noisy_out_i = _sdpa(noisy_q, kv_k_i, kv_v_i) + if noisy_out_acc is None: + noisy_out_acc = weights[i] * noisy_out_i + else: + noisy_out_acc = noisy_out_acc + weights[i] * noisy_out_i + + assert noisy_out_acc is not None + full_out_v[noisy_s:noisy_e] = noisy_out_acc + + # Re-pad to original shape so downstream layers see consistent tensor sizes. + full_out = full_q.new_zeros(full_q.shape[0], full_out_v.shape[-1]) + full_out[:n_full] = full_out_v + + return from_mode_splits(causal_out, full_out, packed_query_states) + + def dispatch_attention( packed_query_states: SequencePack, packed_key_states: SequencePack, @@ -242,7 +383,14 @@ def dispatch_attention( memory_value: MemoryValue | None = None, ) -> tuple[SequencePack, KVToStore | None]: assert memory_value is None, "Base dispatch_attention does not handle MemoryValue" - if isinstance(attention_mask, SplitInfo) and attention_mask.is_three_way: + if isinstance(attention_mask, SplitInfo) and attention_mask.control_stream_token_ranges is not None: + output = multi_control_two_way_attention( + packed_query_states, + packed_key_states, + packed_value_states, + attention_mask, + ) + elif isinstance(attention_mask, SplitInfo) and attention_mask.is_three_way: output = three_way_attention( packed_query_states, packed_key_states, diff --git a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py index 8323a8aa..1592f282 100644 --- a/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py +++ b/cosmos_framework/model/vfm/mot/cosmos3_vfm_network.py @@ -9,7 +9,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from cosmos_framework.model.vfm.mot.attention import build_packed_sequence +from cosmos_framework.model.vfm.mot.attention import SplitInfo, build_packed_sequence from cosmos_framework.model.vfm.mot.context_parallel_utils import ( get_context_parallel_last_hidden_state, get_context_parallel_sharded_sequence, @@ -1085,6 +1085,55 @@ def forward( pad_for_cuda_graphs=self.pad_for_cuda_graphs, ) + # ── Multi-control transfer: annotate SplitInfo with per-item ranges ────── + # Activated only when packed_seq carries control_weights, i.e. the caller + # has set up a multi-control batch via build_transfer_batch. + # + # multi_control_two_way_attention runs N independent maskless SDPA passes, + # one per control. For each pass i, KV = [text | ctrl_i | noisy]. + # The final noisy output is the weighted sum of the N pass outputs: + # noisy_out = w_1 * noisy_out_1 + ... + w_N * noisy_out_N + # All SDPA calls are maskless → Flash Attention always active. + # N=1, w=1.0 → identical to two_way_attention. + # + # CP compatibility: control_stream_token_ranges are gen-relative global + # offsets computed here, before CP sharding. Ulysses CP restores the full + # sequence on every rank (via all-to-all) before calling dispatch_attention, + # so the global ranges are valid indices inside multi_control_two_way_attention. + if ( + isinstance(attention_meta, SplitInfo) + and packed_seq.control_weights is not None + and packed_seq.vision_item_split_lens + ): + # For multi-control, each sample must have N controls + 1 noisy item + # (items 0..N-2 are controls, item N-1 is the noisy target). + # Only batch_size=1 is supported; assert to catch misuse early. + assert len(packed_seq.vision_item_split_lens) == 1, ( + f"Multi-control transfer requires batch_size=1, got {len(packed_seq.vision_item_split_lens)} samples." + ) + item_lens = packed_seq.vision_item_split_lens[0] # [L_ctrl0, L_ctrl1, ..., L_noisy] + weights = packed_seq.control_weights[0] # [w_ctrl0, w_ctrl1, ...] + assert len(item_lens) > 1, ( + f"Multi-control requires at least 1 control + 1 noisy item; got vision_item_split_lens={item_lens}." + ) + assert len(weights) == len(item_lens) - 1, ( + f"control_weights length ({len(weights)}) must equal number of control items ({len(item_lens) - 1})." + ) + ctrl_ranges: list[tuple[int, int]] = [] + cursor = 0 + for lens in item_lens[:-1]: # all but last = control streams + ctrl_ranges.append((cursor, cursor + lens)) + cursor += lens + noisy_range = (cursor, cursor + item_lens[-1]) + n_gen = int(vision_sequence_indexes.shape[0]) if vision_sequence_indexes is not None else 0 + assert noisy_range[1] == n_gen, ( + f"vision_item_split_lens sums to {noisy_range[1]} gen tokens but packed tensor has " + f"{n_gen}; packing inconsistency detected." + ) + attention_meta.control_stream_token_ranges = ctrl_ranges + attention_meta.noisy_token_range = noisy_range + attention_meta.control_weights = weights + input_pack, packed_position_ids = get_context_parallel_sharded_sequence( attn_implementation=self.config.joint_attn_implementation, input_pack=input_pack, diff --git a/cosmos_framework/model/vfm/omni_mot_model.py b/cosmos_framework/model/vfm/omni_mot_model.py index ea2f1b82..d136a801 100644 --- a/cosmos_framework/model/vfm/omni_mot_model.py +++ b/cosmos_framework/model/vfm/omni_mot_model.py @@ -2915,6 +2915,7 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration: }, step=iteration, ) + control_weights: list[list[float]] | None = data_batch.get("control_weights", None) return GenerationDataClean( batch_size=batch_size, is_image_batch=is_image_batch, @@ -2931,6 +2932,7 @@ def get_data_and_condition(self, data_batch: dict[str, torch.Tensor], iteration: action_domain_id=action_domain_id, num_vision_items_per_sample=num_vision_items_per_sample, raw_action_dim=raw_action_dim, + control_weights=control_weights, ) def _normalize_video_databatch_inplace( diff --git a/cosmos_framework/model/vfm/utils/data_and_condition.py b/cosmos_framework/model/vfm/utils/data_and_condition.py index 44aaa85b..8b738b8b 100644 --- a/cosmos_framework/model/vfm/utils/data_and_condition.py +++ b/cosmos_framework/model/vfm/utils/data_and_condition.py @@ -46,6 +46,11 @@ class GenerationDataClean: action_domain_id: list[torch.Tensor] | None = None # per-sample domain IDs, None when no action samples raw_action_dim: list[torch.Tensor] | None = None # raw action dimension, used adding masks to loss calculation + # Multi-control transfer: per-sample list of per-control weights. + # Shape: [num_samples], each element is a list of floats (one per control stream). + # None for non-transfer or single-control samples. + control_weights: list[list[float]] | None = None + @dataclass(slots=True) class GenerationDataNoised: