diff --git a/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py b/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py index b7cd806f..e20abc26 100644 --- a/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py +++ b/groot/vla/model/dreamzero/action_head/wan_flow_matching_action_tf.py @@ -544,7 +544,7 @@ def encode_prompt(self, input_ids, attention_mask): prompt_emb = self.text_encoder(input_ids, attention_mask) prompt_emb = prompt_emb.clone().to(dtype=torch.bfloat16) for i, v in enumerate(seq_lens): - prompt_emb[:, v:] = 0 + prompt_emb[i, v:] = 0 return prompt_emb def _ensure_vae_on_device(self, ref_tensor):