Skip to content

fix: handle non-tensor sample_kwargs in static_input_surface#1

Open
shjwudp wants to merge 2 commits into
buptzyb:mainfrom
shjwudp:handle_non_tensor_sample_kwargs
Open

fix: handle non-tensor sample_kwargs in static_input_surface#1
shjwudp wants to merge 2 commits into
buptzyb:mainfrom
shjwudp:handle_non_tensor_sample_kwargs

Conversation

@shjwudp

@shjwudp shjwudp commented Jun 26, 2026

Copy link
Copy Markdown

Fix non-tensor sample_kwargs in static_input_surface + positional args in replay

Two commits fixing crashes when frameworks pass mixed tensor/non-tensor
keyword arguments and positional args during replay.

Commit 1: fix: handle non-tensor sample_kwargs in static_input_surface

When sample_kwargs contains non-tensor values (e.g. attention_mask=None,
img_shapes=[[1,64,64]]), tree_flatten passes them into static_input_surface.
All existing .requires_grad and .data_ptr() accesses crash on None.

Fix: 6 guards across 4 locations:

  • _run_warmup_backward: i is not None and i.requires_grad
  • num_required_grad_sample_args: isinstance(arg, torch.Tensor) and arg.requires_grad
  • Backward capture inputs (x2): i is not None and i.requires_grad
  • Graphed.forward copy loop: inputs[i] is not None

Commit 2: fix: handle positional args in functionalized during graph replay

Frameworks may pass captured kwargs as positional args during replay
(e.g. Attention.forward(hidden_states, attention_mask=mask)). The previous
strict kwargs_keys validation would reject this.

Fix: Remove the strict key in user_kwargs validation. Reconstruct the
capture-time arg order by checking both user_kwargs (by name) and user_args
(by position) for each key in kwargs_keys.

Diff

src/te_graph_runtime/graph.py | 28 +++++++++++++++++-----------
1 file changed, 17 insertions(+), 11 deletions(-)

Impact

Frameworks using sample_kwargs with non-tensor values or positional args no longer crash during
warmup, capture, or replay.

shjwudp added 2 commits June 26, 2026 15:41
sample_kwargs may contain non-tensor values (None, lists, bools)
that tree_flatten passes through into the static input surface.
Guard all .requires_grad and .data_ptr() accesses with None/type
checks to avoid AttributeError crashes when frameworks pass
mixed tensor/non-tensor keyword arguments.

- _run_warmup_backward: guard .requires_grad with 'is not None'
- num_required_grad_sample_args: handle non-tensor with isinstance check
- backward capture inputs: guard .requires_grad with 'is not None'
- Graphed.forward copy loop: guard inputs[i] with 'is not None'
Some frameworks pass captured kwargs as positional args during
replay (e.g. Attention.forward hidden_states).  The previous
strict kwargs_keys validation would reject this.  Now we:

1. Remove the strict key-in-user_kwargs validation.
2. Reconstruct the capture-time arg order by checking both
   user_kwargs (by name) and user_args (by position) in
   kwargs_keys order.

Also removes the now-redundant flatten_user_args since all
args are merged into flatten_user_kwargs in the right order.
kwarg_values.append(user_pos_args.pop(0))
# else: key was a default not passed — skip (not a tensor)
flatten_user_kwargs, _ = _tree_flatten(kwarg_values)
func_args = tuple(flatten_user_kwargs) + module_params

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker: this drops normal replay positional inputs. user_args is only consumed as a fallback source for captured kwargs, and func_args is built from flatten_user_kwargs + module_params, so graphed(x) with no captured kwargs passes only module params, and graphed(x, scale=scale) drops x. This needs to preserve the flattened explicit user args and then append the captured kwarg values in the same order used during capture.

def _run_warmup_backward(func_idx, func, outputs, warmup_iter, callable_idx) -> None:
static_input_surface = per_callable_static_input_surfaces[func_idx]
inputs = tuple(i for i in static_input_surface if i.requires_grad)
inputs = tuple(i for i in static_input_surface if i is not None and i.requires_grad)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blocker: this still crashes for non-None, non-tensor kwargs such as the img_shapes=[[1,64,64]] example in the PR body. i is not None and i.requires_grad will call .requires_grad on a Python list/int/etc. The guard needs to be isinstance(i, torch.Tensor) and i.requires_grad here and in the corresponding backward-capture paths below.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants