From 5986565348348822e7d625dc27473bfaa3b0dbd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 16 Mar 2026 22:02:02 -0700 Subject: [PATCH 1/9] switch to aten IR for webnn export --- ADDING_OPS.md | 120 + README.md | 241 +- pyproject.toml | 3 - tests/conftest.py | 14 +- tests/models.py | 7 + tests/test_operations.py | 108 + ..._rearrange.py => test_shape_operations.py} | 2 +- tests/test_single_ops.py | 22 +- tests/test_unary_ops.py | 20 +- tests/test_upsample.py | 41 + webnn_torch_export/__init__.py | 4 - webnn_torch_export/exporter.py | 331 +- webnn_torch_export/webnn_generator.py | 2939 +++++++---------- webnn_torch_export/webnn_op_mappings.py | 244 +- 14 files changed, 1775 insertions(+), 2321 deletions(-) create mode 100644 ADDING_OPS.md rename tests/{test_rearrange.py => test_shape_operations.py} (94%) create mode 100644 tests/test_upsample.py diff --git a/ADDING_OPS.md b/ADDING_OPS.md new file mode 100644 index 0000000..da55ef9 --- /dev/null +++ b/ADDING_OPS.md @@ -0,0 +1,120 @@ +# Adding ATen Ops + +The exporter uses `torch.export.export` which produces an ATen IR graph. Every +`call_function` node has a target like `aten.relu.default`. The dispatch table in +`webnn_op_mappings.py` maps `str(node.target)` → a method name on +`WebNNGraphGenerator`. If the target isn't in the table the node is emitted as a +comment and the `.webnn` file will fail to parse. + +## Quick reference: adding an op + +**1. Add an entry to `ATEN_OP_TABLE` in `webnn_op_mappings.py`:** + +```python +"aten.some_op.overload": "_convert_some_op", +``` + +**2. Implement `_convert_some_op` in `webnn_generator.py`:** + +```python +def _convert_some_op(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] + # node.args holds positional args (first is always the input tensor node) + # node.kwargs holds keyword args + return f"[{output}] = webnnOp({x});" +``` + +The method receives: +- `output` — pre-allocated operand name for this node's result +- `inputs` — operand names for every `fx.Node` in `node.args`, in order +- `node.args` / `node.kwargs` — raw args (scalars, lists, bools) from the ATen schema + +Return `None` instead of a string to suppress the node entry entirely (use this for +ops that bake a constant into `inline_constants` and override `node_to_operand`, like +`_convert_arange`). + +**3. To find the exact target string for an unfamiliar op:** + +```python +ep = torch.export.export(model, (x,)) +for n in ep.graph.nodes: + if n.op == "call_function": + print(str(n.target), n.args, n.kwargs) +``` + +--- + +## Example patterns + +### No-op +```python +# webnn_op_mappings.py +"aten.contiguous.default": "_convert_identity", +``` + +### Direct WebNN equivalent +```python +# webnn_op_mappings.py +"aten.bmm.default": "_convert_matmul", +``` +`_convert_matmul` already exists; nothing else needed. + +### Parameterised activation (e.g. relu6) +```python +# webnn_op_mappings.py +"aten.relu6.default": "_convert_relu6", + +# webnn_generator.py +def _convert_relu6(self, node: fx.Node, output: str, inputs: List[str]) -> str: + return f"[{output}] = clamp({inputs[0]}, minValue=0.0, maxValue=6.0);" +``` + +### Scalar-first subtract (`1 - x`) +```python +# webnn_op_mappings.py +"aten.rsub.Scalar": "_convert_rsub_scalar", + +# webnn_generator.py +def _convert_rsub_scalar(self, node: fx.Node, output: str, inputs: List[str]) -> str: + # rsub(x, scalar) = scalar - x + scalar = node.args[1] + c = self._create_inline_constant(float(scalar)) + return f"[{output}] = sub({c}, {inputs[0]});" +``` + +### Op that produces a constant (no graph node needed) +```python +# webnn_op_mappings.py +"aten.full.default": "_convert_full", + +# webnn_generator.py — return None to skip nodes {} entry +def _convert_full(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]: + size = list(node.args[0]) + fill = node.args[1] + dtype = node.kwargs.get("dtype", torch.float32) or torch.float32 + values = torch.full(size, fill, dtype=dtype) + name = f"const_full_{self.operand_counter}"; self.operand_counter += 1 + self.inline_constants[name] = values + self.operand_shapes[name] = list(values.shape) + self.node_to_operand[node.name] = name # redirect downstream refs + return None +``` + +### Multi-step decomposition +```python +def _convert_select_int(self, node: fx.Node, output: str, inputs: List[str]) -> str: + # aten.select.int(input, dim, index) → slice dim then squeeze + x = inputs[0] + in_shape = self._get_node_shape(node.args[0]) + dim = int(node.args[1]) % len(in_shape) + idx = int(node.args[2]) + starts = [0] * len(in_shape); starts[dim] = idx + sizes = list(in_shape); sizes[dim] = 1 + out_shape = in_shape[:dim] + in_shape[dim+1:] + sliced = f"operand_{self.operand_counter}"; self.operand_counter += 1 + s_str = ", ".join(map(str, starts)) + sz_str = ", ".join(map(str, sizes)) + os_str = ", ".join(map(str, out_shape)) + return (f"[{sliced}] = slice({x}, starts=[{s_str}], sizes=[{sz_str}]);\n" + f"\t[{output}] = reshape({sliced}, newShape=[{os_str}]);") +``` \ No newline at end of file diff --git a/README.md b/README.md index 311a339..6f07890 100644 --- a/README.md +++ b/README.md @@ -15,8 +15,8 @@ This is an early-stage experimental implementation for research and exploration. ### For Development ```bash -# Clone the repository -git clone https://github.com/yourusername/webnn_torch_export.git +# Clone your forked repository +git clone https://github.com//webnn_torch_export.git cd webnn_torch_export # Install in editable mode with dev dependencies @@ -88,243 +88,6 @@ new_model = nn.Sequential( load_weights_from_safetensors(new_model, "model_weights.safetensors") ``` -### Run Basic Example - -```bash -# Using the installed command -webnn-export - -# Or run directly -python -m webnn_torch_export.exporter -``` - -## Key Components - -### CustomExporter - -The `CustomExporter` class is a Dynamo backend that: -1. Receives FX graphs from PyTorch's compilation process -2. Converts them to a custom format (JSON) -3. Provides debug output to understand graph structure -4. Maintains execution compatibility - -**Key methods:** -- `export_graph()`: Main callback that receives FX graphs -- `_convert_fx_to_custom_format()`: Converts FX graph to JSON -- `save_to_file()`: Exports graphs to JSON files - -### Test Infrastructure - -**Single Operator Tests** (`tests/test_single_ops.py`): -- `test_conv2d_export()`: Tests Conv2d export -- `test_matmul_export()`: Tests matmul export -- `test_linear_export()`: Tests Linear layer export -- `test_conv_with_different_configs()`: Parametrized tests for various Conv2d configurations -- `test_exported_graph_structure()`: Validates exported graph structure - -**Integration Tests** (`tests/test_mnist_integration.py`): -- `SimplerMNISTClassifier`: Conv + ReLU + Linear -- `MNISTClassifier`: Full classifier with 2 conv blocks -- `test_simple_mnist_export()`: Exports simple model -- `test_full_mnist_export()`: Exports full model -- `test_mnist_inference_consistency()`: Tests consistency across multiple runs -- `test_mnist_batch_size_invariance()`: Tests with different batch sizes - -## How It Works - -### 1. Dynamo Backend Registration - -```python -def custom_backend(gm: torch.fx.GraphModule, example_inputs): - # Your export logic here - return gm - -compiled_model = torch.compile(model, backend=custom_backend) -``` - -### 2. FX Graph Structure - -When Dynamo compiles a model, it produces an FX graph with nodes representing: -- **placeholder**: Input tensors -- **call_function**: Function calls (e.g., `torch.relu`, `torch.matmul`) -- **call_module**: Module invocations (e.g., `conv1`, `fc1`) -- **call_method**: Tensor method calls (e.g., `x.flatten()`) -- **output**: Return values - -### 3. Export Flow - -``` -PyTorch Model → torch.compile() → Dynamo → FX Graph → Custom Backend → Export Format - ↓ - Your Export Logic -``` - -## Debug Output - -With `debug=True`, the exporter prints: -- Complete FX graph representation -- Generated Python code -- Individual node details: - - Node name and operation type - - Target function/module - - Arguments and keyword arguments - - Tensor metadata (shapes, dtypes) - -## Example Output - -``` -================================================================================ -DYNAMO EXPORT CALLBACK TRIGGERED -================================================================================ - -Graph Module: -graph(): - %x : [num_users=1] = placeholder[target=x] - %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {}) - %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%conv1,), kwargs = {}) - return (relu,) - -Node: x - Op: placeholder - Target: x - ... -``` - -## Exported JSON Format - -```json -{ - "nodes": [ - { - "name": "x", - "op": "placeholder", - "target": "x", - "args": [], - "kwargs": {} - }, - { - "name": "conv1", - "op": "call_module", - "target": "conv1", - "module": "conv1", - "args": ["x"], - "kwargs": {} - } - ], - "graph_str": "graph(): ...", - "code": "def forward(self, x): ..." -} -``` - -## Extending the Exporter - -### Adding New Operator Support - -When you export a model with unsupported operations, you'll get a **clear error message** showing exactly what's missing: - -``` -================================================================================ -UNSUPPORTED OPERATION -================================================================================ -Operation: layer_norm -Node: layer_norm_output -Target: -Schema: aten::layer_norm(Tensor input, int[] normalized_shape, ...) -Args: ['input_tensor', '[3072]', 'weight', 'bias', '1e-5'] -Kwargs: {} -================================================================================ - -This operation is not yet supported in WebNN export. -To add support, update webnn_op_mappings.py with a mapping for this operation. -``` - -This makes it easy to **incrementally add support** for operations as you need them. - -**Quick Steps:** - -1. **Run your export** - get the error showing the unsupported operation -2. **Add mapping** in `webnn_torch_export/webnn_op_mappings.py`: - ```python - TARGET_CONTAINS_TO_CONVERTER: Dict[str, ConverterFn] = { - # ... existing mappings ... - "layer_norm": lambda gen, node, output, inputs: gen._convert_layer_norm(node, output, inputs), - } - ``` -3. **Implement converter** in `webnn_torch_export/webnn_generator.py`: - ```python - def _convert_layer_norm(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert LayerNorm to WebNN""" - input_tensor = inputs[0] if inputs else 'unknown' - # ... conversion logic ... - return f'[{output}] = layerNormalization({input_tensor}, ...);' - ``` -4. **Test** - run export again, repeat for next unsupported operation - -**For detailed guidance, see [ADDING_OPS.md](ADDING_OPS.md)** - a comprehensive guide covering: -- How to map PyTorch operations to WebNN -- Common patterns (activations, normalization, matrix ops) -- Step-by-step walkthrough with examples -- WebNN operation reference -- Debugging tips - -### Custom Export Format - -Modify `_convert_fx_to_custom_format()` to output your desired format: -```python -def _convert_fx_to_custom_format(self, gm): - # Convert to your format (protobuf, flatbuffer, etc.) - my_format = convert_to_my_format(gm.graph) - return my_format -``` - -## Development - -### Running Tests - -```bash -# Run all tests -pytest - -# Run with coverage -pytest --cov=webnn_torch_export --cov-report=html - -# Run specific markers -pytest -m "not slow" -pytest -m integration -``` - -### Building the Package - -```bash -# Build distribution -python -m build - -# Install locally -pip install -e ".[dev]" -``` - -## Requirements - -- PyTorch 2.0+ (for `torch.compile` support) -- Python 3.8+ - -## Tips for Debugging - -1. **Start with `debug=True`** to see full graph output -2. **Use single operator tests** to understand individual operations -3. **Check node metadata** for tensor shapes and types -4. **Verify correctness** by comparing original vs compiled outputs -5. **Examine exported JSON** to understand graph structure - -## Next Steps - -- Add support for more operators (pooling, normalization, etc.) -- Implement graph optimization passes -- Add serialization to binary formats (protobuf, flatbuffer) -- Handle dynamic shapes -- Support quantized models -- Add execution validation tests - ## License Apache License (2.0) (see LICENSE file) diff --git a/pyproject.toml b/pyproject.toml index 111e6f7..0dce9e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,6 @@ dev = [ "pywebnn @ git+https://github.com/gedoensmax/pywebnn.git@maximilianm/safetensor_support", ] -[project.scripts] -webnn-export = "webnn_torch_export.exporter:main" - [tool.pytest.ini_options] testpaths = ["tests"] python_files = ["test_*.py"] diff --git a/tests/conftest.py b/tests/conftest.py index ff70fc9..d5191d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,19 +46,21 @@ def assert_export_matches( with torch.no_grad(): expected = model(*example_input) if isinstance(example_input, tuple) else model(example_input) - compiled_model, exporter = export_model(model, example_input, debug=debug) + exporter, ep = export_model(model, example_input, debug=debug) + # ep.module() runs the exported ATen graph — should match the original model + exported_callable = ep.module() with torch.no_grad(): - actual = compiled_model(*example_input) if isinstance(example_input, tuple) else compiled_model(example_input) + actual = exported_callable(*example_input) if isinstance(example_input, tuple) else exported_callable(example_input) if not torch.allclose(expected, actual, rtol=rtol, atol=atol): max_diff = torch.max(torch.abs(expected - actual)).item() raise AssertionError( - f"Compiled output doesn't match PyTorch\n" + f"Exported output doesn't match PyTorch\n" f" Max diff: {max_diff:.2e} (rtol={rtol}, atol={atol})" ) - return compiled_model, exporter + return exported_callable, exporter def validate_webnn_execution( @@ -115,7 +117,7 @@ def validate_webnn_execution( expected_output = model(example_input) # Export to WebNN with executor - executor, exporter = export_model_with_weights( + result, ep_or_exporter = export_model_with_weights( model, example_input, webnn_path=webnn_path, @@ -123,6 +125,8 @@ def validate_webnn_execution( debug=debug, return_executor=True ) + executor = result + exporter = ep_or_exporter # Verify executor was created assert executor is not None, "WebNNExecutor was not created" diff --git a/tests/models.py b/tests/models.py index bb7f3d8..85a64dd 100644 --- a/tests/models.py +++ b/tests/models.py @@ -77,6 +77,13 @@ def forward(self, q, k, v): return F.scaled_dot_product_attention(q, k, v) +class SingleEinsum(nn.Module): + """Wrapper for testing torch.einsum with the '...n,d->...nd' pattern (used in Flux RoPE).""" + + def forward(self, a, b): + return torch.einsum("...n,d->...nd", a, b) + + # --------------------------------------------------------------------------- # Multiple-input models (used in test_multiple_inputs.py) # --------------------------------------------------------------------------- diff --git a/tests/test_operations.py b/tests/test_operations.py index b21e065..291aa01 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -2,6 +2,7 @@ import pytest import torch +import torch.nn as nn from .models import ( PointwiseActivationsModel, PointwiseArithmeticModel, @@ -92,3 +93,110 @@ def test_normalization_ops(normalization_model): x = torch.randn(2, 10) assert_export_matches(normalization_model, x, rtol=1e-4, atol=1e-4) validate_webnn_execution(normalization_model, x, rtol=1e-4, atol=1e-4) + + +# --------------------------------------------------------------------------- +# Chunk + getitem +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("chunks,dim,shape", [ + (2, -1, (2, 8)), + (3, 1, (1, 6, 4)), + (6, -1, (1, 12)), # Flux pattern: 6 chunks along last dim +], ids=["2chunks_last", "3chunks_dim1", "6chunks_flux"]) +def test_chunk(chunks, dim, shape): + """chunk splits a tensor and getitem indexes into the result.""" + torch._dynamo.reset() + class ChunkSum(nn.Module): + def __init__(self, n, d): + super().__init__() + self.n, self.d = n, d + def forward(self, x): + return sum(torch.chunk(x, self.n, dim=self.d)) + model = ChunkSum(chunks, dim) + x = torch.randn(*shape) + assert_export_matches(model, x, rtol=1e-4, atol=1e-4) + + +# --------------------------------------------------------------------------- +# Unbind + getitem +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dim,shape", [ + (0, (3, 4)), + (1, (2, 4, 5)), +], ids=["dim0", "dim1"]) +def test_unbind(dim, shape): + """unbind returns individual slices; getitem indexes them.""" + torch._dynamo.reset() + class UnbindSum(nn.Module): + def __init__(self, d): + super().__init__() + self.d = d + def forward(self, x): + return sum(torch.unbind(x, dim=self.d)) + model = UnbindSum(dim) + x = torch.randn(*shape) + assert_export_matches(model, x, rtol=1e-4, atol=1e-4) + + +# --------------------------------------------------------------------------- +# select.int +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dim,index,shape", [ + (0, 1, (4, 5)), + (-1, 0, (2, 3, 4)), + (2, 0, (2, 4, 6, 8)), +], ids=["dim0", "dim_neg", "higher_dim"]) +def test_select(dim, index, shape): + """select picks a single element along dim, removing that dimension.""" + torch._dynamo.reset() + class SelectModel(nn.Module): + def __init__(self, d, i): + super().__init__() + self.d, self.i = d, i + def forward(self, x): + return torch.select(x, self.d, self.i) + model = SelectModel(dim, index) + x = torch.randn(*shape) + assert_export_matches(model, x, rtol=1e-4, atol=1e-4) + + +# --------------------------------------------------------------------------- +# stack (list passed as args[0]) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("dim,shape", [ + (0, (4, 3)), + (-1, (2, 3, 3)), +], ids=["dim0", "dim_neg"]) +def test_stack(dim, shape): + """stack assembles tensors along a new dimension.""" + torch._dynamo.reset() + class StackModel(nn.Module): + def __init__(self, d): + super().__init__() + self.d = d + def forward(self, x): + a, b, c = x[..., 0], x[..., 1], x[..., 2] + return torch.stack([a, b, c], dim=self.d) + model = StackModel(dim) + x = torch.randn(*shape) + assert_export_matches(model, x, rtol=1e-4, atol=1e-4) + + +# --------------------------------------------------------------------------- +# type_as +# --------------------------------------------------------------------------- + +def test_type_as_same_dtype(): + """type_as with matching dtype becomes an identity.""" + torch._dynamo.reset() + class TypeAsModel(nn.Module): + def forward(self, x, ref): + return x.type_as(ref) + model = TypeAsModel() + x = torch.randn(3, 4) + ref = torch.randn(1) + assert_export_matches(model, (x, ref), rtol=1e-5, atol=1e-5) diff --git a/tests/test_rearrange.py b/tests/test_shape_operations.py similarity index 94% rename from tests/test_rearrange.py rename to tests/test_shape_operations.py index 7c6f00b..ac304ac 100644 --- a/tests/test_rearrange.py +++ b/tests/test_shape_operations.py @@ -1,4 +1,4 @@ -"""Test rearrange operation export""" +"""Test shape operation export""" import pytest import torch diff --git a/tests/test_single_ops.py b/tests/test_single_ops.py index 5facbed..204fa8c 100644 --- a/tests/test_single_ops.py +++ b/tests/test_single_ops.py @@ -5,7 +5,7 @@ import pytest import torch -from .models import SingleConv, SingleMatmul, SingleLinear, SingleMM, SingleAddMM, SingleScaledDotProduct +from .models import SingleConv, SingleMatmul, SingleLinear, SingleMM, SingleAddMM, SingleScaledDotProduct, SingleEinsum from .conftest import assert_export_matches, validate_webnn_execution @@ -111,3 +111,23 @@ def test_attention_op(attn_config, qkv_shape): assert_export_matches(model, (q, k, v), rtol=1e-3, atol=1e-3) validate_webnn_execution(model, (q, k, v), rtol=1e-3, atol=1e-3) + +# (id, a_shape, b_shape) — pattern '...n,d->...nd' +EINSUM_OPS = [ + ("2d_8x16", (4, 8), (16,)), + ("3d_2x4x8", (2, 4, 8), (16,)), +] + + +@pytest.mark.parametrize( + "a_shape,b_shape", + [(c[1], c[2]) for c in EINSUM_OPS], + ids=[c[0] for c in EINSUM_OPS], +) +def test_einsum_op(a_shape, b_shape): + torch._dynamo.reset() + model = SingleEinsum() + a = torch.randn(*a_shape) + b = torch.randn(*b_shape) + assert_export_matches(model, (a, b), rtol=1e-4, atol=1e-4) + diff --git a/tests/test_unary_ops.py b/tests/test_unary_ops.py index 9e10280..3dbcb79 100644 --- a/tests/test_unary_ops.py +++ b/tests/test_unary_ops.py @@ -1,6 +1,7 @@ """ Parametrized tests for unary element-wise operations. -Covers: exp, abs, sqrt, log, sigmoid, tanh, relu × multiple input shapes. +Covers: exp, abs, sqrt, log, sigmoid, tanh, relu, rsqrt, reciprocal, + pow.Scalar × multiple input shapes. """ import pytest @@ -24,13 +25,16 @@ def forward(self, x): # (id, model_class, input_factory) # input_factory takes shape tuple and returns a tensor safe for the op UNARY_OPS = [ - ("exp", torch.exp, lambda s: torch.randn(*s)), - ("abs", torch.abs, lambda s: torch.randn(*s)), - ("sqrt", torch.sqrt, lambda s: torch.randn(*s).abs() + 1e-3), - ("log", torch.log, lambda s: torch.randn(*s).abs() + 1e-3), - ("sigmoid", torch.sigmoid, lambda s: torch.randn(*s)), - ("tanh", torch.tanh, lambda s: torch.randn(*s)), - ("relu", F.relu, lambda s: torch.randn(*s)), + ("exp", torch.exp, lambda s: torch.randn(*s)), + ("abs", torch.abs, lambda s: torch.randn(*s)), + ("sqrt", torch.sqrt, lambda s: torch.randn(*s).abs() + 1e-3), + ("log", torch.log, lambda s: torch.randn(*s).abs() + 1e-3), + ("sigmoid", torch.sigmoid, lambda s: torch.randn(*s)), + ("tanh", torch.tanh, lambda s: torch.randn(*s)), + ("relu", F.relu, lambda s: torch.randn(*s)), + ("rsqrt", torch.rsqrt, lambda s: torch.rand(*s) + 0.1), + ("reciprocal", torch.reciprocal, lambda s: torch.rand(*s) + 0.1), + ("pow_scalar", lambda x: torch.pow(2.0, x), lambda s: torch.randn(*s) * 0.5), ] SHAPES = [ diff --git a/tests/test_upsample.py b/tests/test_upsample.py new file mode 100644 index 0000000..5ca782c --- /dev/null +++ b/tests/test_upsample.py @@ -0,0 +1,41 @@ +"""Tests for upsample / interpolate operations.""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from .conftest import assert_export_matches, validate_webnn_execution + + +# --------------------------------------------------------------------------- +# Parametrize over (id, scale_or_size kwarg, mode) +# --------------------------------------------------------------------------- + +UPSAMPLE_CASES = [ + # nearest — scale_factor + ("nearest_scale2", dict(scale_factor=2.0, mode="nearest"), (1, 4, 8, 8)), + ("nearest_scale3", dict(scale_factor=3.0, mode="nearest"), (1, 2, 4, 4)), + ("nearest_scale_xy", dict(scale_factor=(2.0, 3.0), mode="nearest"), (1, 4, 6, 4)), + # nearest — explicit output size + ("nearest_size", dict(size=(16, 16), mode="nearest"), (1, 4, 8, 8)), + # bilinear — scale_factor + ("bilinear_scale2", dict(scale_factor=2.0, mode="bilinear", align_corners=False), (1, 4, 8, 8)), + # bilinear — explicit output size + ("bilinear_size", dict(size=(20, 20), mode="bilinear", align_corners=False), (1, 4, 10, 10)), +] + + +@pytest.mark.parametrize( + "interp_kwargs,input_shape", + [(c[1], c[2]) for c in UPSAMPLE_CASES], + ids=[c[0] for c in UPSAMPLE_CASES], +) +def test_upsample(interp_kwargs, input_shape): + class UpsampleModel(nn.Module): + def forward(self, x): + return F.interpolate(x, **interp_kwargs) + + model = UpsampleModel() + x = torch.randn(*input_shape) + assert_export_matches(model, x, rtol=1e-5, atol=1e-5) + validate_webnn_execution(model, x, rtol=1e-4, atol=1e-4) diff --git a/webnn_torch_export/__init__.py b/webnn_torch_export/__init__.py index 5fbdf71..602c4c9 100644 --- a/webnn_torch_export/__init__.py +++ b/webnn_torch_export/__init__.py @@ -7,8 +7,6 @@ export_model, export_model_with_weights, load_weights_from_safetensors, - get_custom_backend, - get_exporter, ) # Optional import - WebNNExecutor requires webnn runtime @@ -26,7 +24,5 @@ "export_model", "export_model_with_weights", "load_weights_from_safetensors", - "get_custom_backend", - "get_exporter", "WebNNExecutor", # May be None if webnn not available ] diff --git a/webnn_torch_export/exporter.py b/webnn_torch_export/exporter.py index a00365e..daed1ef 100644 --- a/webnn_torch_export/exporter.py +++ b/webnn_torch_export/exporter.py @@ -1,222 +1,107 @@ """ -Custom PyTorch Exporter using torch.compile and Dynamo -Demonstrates how to build a custom export backend for PyTorch models +PyTorch → WebNN exporter using torch.export (ATen IR). + +Public API: + export_model(model, example_input) → (compiled_model, exporter) + export_model_with_weights(model, example_input, webnn_path, weights_path) """ import torch -from torch._dynamo.backends.common import aot_autograd -from torch.fx.passes.shape_prop import ShapeProp -from typing import Callable, List, Optional, Union, Tuple -import json -import struct +import torch.export +from typing import Union, Tuple from safetensors.torch import save_file, load_file + from .webnn_generator import WebNNGraphGenerator class CustomExporter: """ - Custom exporter that captures the FX graph from Dynamo and converts it - to a custom format. This is a minimal example to understand the flow. + Wraps a torch.export.ExportedProgram and converts it to the WebNN graph format. """ - def __init__(self, debug=True): + def __init__(self, ep: torch.export.ExportedProgram, debug: bool = False): + self.ep = ep self.debug = debug - self.exported_graphs = [] - self.model = None # Store reference to the original model - self.fx_graph_module = None # Store FX GraphModule for WebNN export self.webnn_generator = WebNNGraphGenerator() - - def export_graph(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - """ - This function receives the FX graph from Dynamo and can process it. - - Args: - gm: The torch.fx.GraphModule representing the traced computation - example_inputs: Example inputs used for tracing - """ - if self.debug: - print("\n" + "="*80) - print("DYNAMO EXPORT CALLBACK TRIGGERED") - print("="*80) - print("\nGraph Module:") - print(gm.graph) - print("\nCode:") - print(gm.code) - print("\n" + "="*80) - - # Propagate shapes through the graph to add metadata - # This ensures that all nodes have shape information available - try: - ShapeProp(gm).propagate(*example_inputs) - if self.debug: - print("\n" + "="*80) - print("Shape propagation successful") - print("="*80) - except Exception as e: - if self.debug: - print(f"\nWarning: Shape propagation failed: {e}") - print("Some operations may not have complete shape information") - - # Convert FX graph to custom format - graph_repr = self._convert_fx_to_custom_format(gm) - self.exported_graphs.append(graph_repr) - - # Store FX graph module for WebNN export - self.fx_graph_module = gm - - # Return the original graph module so it can still be executed - # Note: This needs to return the FX graph for Dynamo to work correctly. - # To get a WebNN executor instead, use export_model_with_weights(..., return_executor=True) - return gm - - def _convert_fx_to_custom_format(self, gm: torch.fx.GraphModule) -> dict: - """ - Convert FX graph to a custom format. - This is where you'd implement your actual export logic. - """ + # Compatibility: expose graph nodes in the same dict format as the old Dynamo backend + self.exported_graphs = [self._graph_to_dict(ep.graph_module)] + + if debug: + print("\n" + "=" * 80) + print("EXPORTED PROGRAM") + print("=" * 80) + ep.graph.print_tabular() + + # ------------------------------------------------------------------ + # Compatibility helpers + # ------------------------------------------------------------------ + + @staticmethod + def _graph_to_dict(gm: torch.fx.GraphModule) -> dict: + """Return graph nodes in the same dict format as the old Dynamo-based exporter.""" nodes = [] - for node in gm.graph.nodes: - node_info = { - 'name': node.name, - 'op': node.op, - 'target': str(node.target), - 'args': [str(arg) for arg in node.args], - 'kwargs': {k: str(v) for k, v in node.kwargs.items()}, - } - - # Add type-specific information - if node.op == 'call_function': - node_info['function'] = node.target.__name__ if hasattr(node.target, '__name__') else str(node.target) - elif node.op == 'call_method': - node_info['method'] = node.target - elif node.op == 'call_module': - node_info['module'] = node.target - - nodes.append(node_info) - - if self.debug: - print(f"\nNode: {node.name}") - print(f" Op: {node.op}") - print(f" Target: {node.target}") - print(f" Args: {node.args}") - print(f" Kwargs: {node.kwargs}") - if hasattr(node, 'meta') and 'tensor_meta' in node.meta: - print(f" Tensor Meta: {node.meta['tensor_meta']}") - - return { - 'nodes': nodes, - 'graph_str': str(gm.graph), - 'code': gm.code - } - - def save_to_file(self, filepath: str): - """Save exported graphs to a JSON file""" - with open(filepath, 'w') as f: - json.dump(self.exported_graphs, f, indent=2) - print(f"\nExported graphs saved to {filepath}") + nodes.append({ + "name": node.name, + "op": node.op, + "target": str(node.target), + "args": [str(a) for a in node.args], + "kwargs": {k: str(v) for k, v in node.kwargs.items()}, + }) + return {"nodes": nodes, "graph_str": str(gm.graph)} + + # ------------------------------------------------------------------ + # WebNN / weights serialisation + # ------------------------------------------------------------------ + + def save_to_webnn(self, filepath: str, graph_name: str = "model") -> None: + """Write the WebNN graph text file.""" + webnn_graph = self.webnn_generator.generate(self.ep, graph_name=graph_name) + with open(filepath, "w") as f: + f.write(webnn_graph) - def save_weights(self, model: torch.nn.Module, filepath: str): + def save_weights(self, filepath: str) -> None: """ - Save model weights to a safetensors file + Save model parameters and buffers to a safetensors file. - Args: - model: The PyTorch model whose weights to save - filepath: Path to save the safetensors file + Keys match the state_dict paths referenced in @weights(...) inside the + .webnn file (e.g. "conv1.weight", "bn.running_mean"). """ - state_dict = model.state_dict() - - # Add generated constants (like arange) to the state dict - if hasattr(self, 'webnn_generator') and hasattr(self.webnn_generator, 'generated_constants'): - for name, tensor in self.webnn_generator.generated_constants.items(): - # Use special prefix to distinguish generated constants - key = f'_generated.{name}' - state_dict[key] = tensor - + state_dict = { + **dict(self.ep.named_parameters()), + **dict(self.ep.named_buffers()), + } save_file(state_dict, filepath) - def set_model(self, model: torch.nn.Module): - """Store reference to the original model""" - self.model = model - - def save_to_webnn(self, filepath: str, graph_name: str = "model"): - """ - Save model as WebNN graph format - - Args: - filepath: Path to save the .webnn file - graph_name: Name for the WebNN graph - """ - if self.fx_graph_module is None: - raise ValueError("No FX graph available. Run export_model first.") - if self.model is None: - raise ValueError("No model reference. Run export_model first.") - - webnn_graph = self.webnn_generator.generate( - self.fx_graph_module, - self.model, - graph_name=graph_name - ) - - with open(filepath, 'w') as f: - f.write(webnn_graph) - -# Global exporter instance -_exporter = None - - -def get_custom_backend(debug=True): - """ - Factory function that returns a Dynamo backend using our custom exporter. - This is what you pass to torch.compile(backend=...) - """ - global _exporter - _exporter = CustomExporter(debug=debug) - - def custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - return _exporter.export_graph(gm, example_inputs) - - return custom_backend - - -def get_exporter(): - """Get the global exporter instance to access exported graphs""" - return _exporter +# --------------------------------------------------------------------------- +# High-level helpers +# --------------------------------------------------------------------------- def export_model( model: torch.nn.Module, example_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], - debug=True -): + debug: bool = False, +) -> Tuple["CustomExporter", torch.export.ExportedProgram]: """ - High-level API to export a model using our custom backend. + Export a model to ATen IR using torch.export. Args: - model: The PyTorch model to export - example_input: Example input tensor(s) for tracing. Can be: - - A single torch.Tensor - - A tuple of torch.Tensors for models with multiple inputs - debug: Whether to print debug information + model: The PyTorch model. + example_input: A single tensor or tuple of tensors. + debug: Print the exported graph. Returns: - Compiled model and exporter instance + (exporter, exported_program) """ - backend = get_custom_backend(debug=debug) - compiled_model = torch.compile(model, backend=backend) - - # Store model reference in exporter - if _exporter: - _exporter.set_model(model) + if not isinstance(example_input, tuple): + example_input = (example_input,) - # Run once to trigger export with torch.no_grad(): - if isinstance(example_input, tuple): - compiled_model(*example_input) - else: - compiled_model(example_input) + ep = torch.export.export(model, example_input) - return compiled_model, get_exporter() + exporter = CustomExporter(ep, debug=debug) + return exporter, ep def export_model_with_weights( @@ -224,78 +109,48 @@ def export_model_with_weights( example_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], webnn_path: str, weights_path: str, - debug=True, + debug: bool = False, graph_name: str = "model", - return_executor: bool = False + return_executor: bool = False, ): """ - Export model to WebNN format with weights. + Export model to WebNN format and save weights. Args: - model: The PyTorch model to export - example_input: Example input tensor(s) for tracing. Can be: - - A single torch.Tensor - - A tuple of torch.Tensors for models with multiple inputs - webnn_path: Path to save the WebNN graph file - weights_path: Path to save the weights safetensors file - debug: Whether to print debug information - graph_name: Name for the WebNN graph (default: "model") - return_executor: If True, returns a WebNNExecutor instead of compiled_model. - The executor wraps the WebNN graph and provides a PyTorch-like - interface with automatic tensor conversion. (default: False) + model: The PyTorch model. + example_input: Example input tensor(s). + webnn_path: Path to write the .webnn graph file. + weights_path: Path to write the .safetensors weights file. + debug: Print the exported graph. + graph_name: Name embedded in the .webnn file header. + return_executor: If True, returns a WebNNExecutor. Returns: - If return_executor=False: Tuple of (compiled_model, exporter) - If return_executor=True: Tuple of (webnn_executor, exporter) + If return_executor=False: (exporter, exported_program) + If return_executor=True: (webnn_executor, exported_program) """ - # Export the model graph - compiled_model, exporter = export_model(model, example_input, debug=debug) - - # Save WebNN format + exporter, ep = export_model(model, example_input, debug=debug) exporter.save_to_webnn(webnn_path, graph_name=graph_name) - - # Save weights to safetensors - exporter.save_weights(model, weights_path) + exporter.save_weights(weights_path) if return_executor: - # Import executor here to avoid import errors if webnn not available from .executor import WebNNExecutor - try: - # Create and return WebNN executor executor = WebNNExecutor(webnn_path, weights_path, example_input) - return executor, exporter + return executor, ep except ImportError as e: - print(f"Warning: Could not create WebNNExecutor: {e}") - print("Returning compiled PyTorch model instead.") - return compiled_model, exporter - - return compiled_model, exporter + print(f"Warning: WebNN runtime not available: {e}") + return exporter, ep + return exporter, ep -def load_weights_from_safetensors(model: torch.nn.Module, filepath: str, strict=True): - """ - Load weights from a safetensors file into a model. - - Args: - model: The PyTorch model to load weights into - filepath: Path to the safetensors file - strict: Whether to strictly enforce that the keys in state_dict match - Returns: - The model with loaded weights - """ +def load_weights_from_safetensors( + model: torch.nn.Module, filepath: str, strict: bool = True +) -> torch.nn.Module: + """Load weights from a safetensors file into a model.""" state_dict = load_file(filepath) model.load_state_dict(state_dict, strict=strict) - print(f"\nWeights loaded from {filepath}") - - # Print loading statistics - total_params = sum(p.numel() for p in state_dict.values()) - print(f"Total parameters loaded: {total_params:,}") - - return model - - - -if __name__ == '__main__': - main() + total = sum(p.numel() for p in state_dict.values()) + print(f"Loaded {total:,} parameters from {filepath}") + return model \ No newline at end of file diff --git a/webnn_torch_export/webnn_generator.py b/webnn_torch_export/webnn_generator.py index a8d1fee..32a4cd6 100644 --- a/webnn_torch_export/webnn_generator.py +++ b/webnn_torch_export/webnn_generator.py @@ -1,18 +1,18 @@ """ -WebNN Graph Generator - Converts PyTorch FX graphs to WebNN format +WebNN Graph Generator - Converts PyTorch ExportedProgram (ATen IR) to WebNN format. + +Entry point: WebNNGraphGenerator().generate(ep, graph_name) +where ep is a torch.export.ExportedProgram. """ -import inspect +import math +import sys import torch import torch.fx as fx -from typing import Dict, List, Tuple, Any -import math -import numpy as np - -from jinja2.lexer import ignored_tokens -from sympy import true +import torch.export +from typing import Dict, List, Optional, Tuple -from .webnn_op_mappings import resolve_pytorch_converter +from .webnn_op_mappings import resolve_aten_converter def isnumeric(obj): @@ -23,226 +23,211 @@ def isnumeric(obj): return False -def throw_unsupported(type, node, module_type="", module_class=""): - error_msg = ( +def throw_unsupported(kind, node): + msg = ( f"\n{'=' * 80}\n" - f"UNSUPPORTED {type}\n" + f"UNSUPPORTED {kind}\n" f"{'=' * 80}\n" - ) - if module_type: - error_msg += ( - f"{type} Type: {module_type}\n" - ) - if module_class: - error_msg += ( - f"{type} Class: {module_class}\n" - ) - error_msg += ( - f"Node: {node.name}\n" + f"Node : {node.name}\n" f"Target: {node.target}\n" - f"Args: {[str(arg) for arg in node.args]}\n" + f"Args : {[str(a) for a in node.args]}\n" f"Kwargs: {node.kwargs}\n" f"{'=' * 80}\n" ) - raise NotImplementedError(error_msg) - - -IGNORED_PLACEHOLDER_TOKENS = {'modules', 'buffers', 'parameters', 'self'} + raise NotImplementedError(msg) class WebNNGraphGenerator: - """Generates WebNN graph format from PyTorch FX GraphModule""" + """Generates WebNN graph format from a torch.export.ExportedProgram.""" def __init__(self): self.operand_counter = 1 - self.node_to_operand = {} - self.weight_operands = {} - self.operand_shapes = {} - self.inline_constants = {} # Store inline constants like scalars (embedded in .webnn file) - - def generate(self, gm: fx.GraphModule, model: torch.nn.Module, graph_name: str = "model") -> str: - """ - Generate WebNN graph format from FX graph + self.node_to_operand: Dict[str, str] = {} + # Maps FX placeholder node name → operand name (for parameters/buffers) + self.weight_operands: Dict[str, str] = {} + self.operand_shapes: Dict[str, List[int]] = {} + self.inline_constants: Dict[str, object] = {} + # Maps multi-output node name → list of per-output operand names (chunk/unbind/split) + self.multi_output_operands: Dict[str, List[str]] = {} + + # ------------------------------------------------------------------ + # Public entry point + # ------------------------------------------------------------------ + + def generate( + self, + ep: torch.export.ExportedProgram, + graph_name: str = "model", + ) -> str: + """Generate WebNN graph format from an ExportedProgram. Args: - gm: FX GraphModule from Dynamo - model: Original PyTorch model for weight extraction - graph_name: Name for the WebNN graph + ep: Result of torch.export.export(model, example_inputs). + graph_name: Name embedded in the .webnn file header. Returns: - WebNN graph as string + WebNN graph as a string. """ + # Reset state self.operand_counter = 1 self.node_to_operand = {} self.weight_operands = {} self.operand_shapes = {} - self.inline_constants = {} # Reset inline constants for each graph - - # Extract forward() parameter order so inputs are emitted in the right order - try: - sig = inspect.signature(model.forward) - param_names = [p for p in sig.parameters if p != 'self'] - except (ValueError, TypeError): - param_names = [] - - # Extract sections - inputs_section = self._extract_inputs(gm, param_names) - consts_section, weight_map = self._extract_weights(model) + self.inline_constants = {} + self.multi_output_operands = {} + + gm = ep.graph_module + sig = ep.graph_signature + + # Set of placeholder names that are actual model inputs (not weights) + user_inputs: set = set(sig.user_inputs) + + # Mapping: placeholder_node_name → state_dict key + param_map: Dict[str, str] = { + **sig.inputs_to_parameters, + **sig.inputs_to_buffers, + } + + # Named tensors for weight shape/dtype lookup + named_params: Dict[str, torch.Tensor] = { + **dict(ep.named_parameters()), + **dict(ep.named_buffers()), + } + + # Build sections + inputs_section = self._extract_inputs(gm, user_inputs) + consts_section = self._extract_weights(gm, param_map, named_params) nodes_section = self._convert_nodes(gm) inline_consts_section = self._extract_inline_constants() outputs_section = self._extract_outputs(gm) - # Combine all constants: inline (scalars), generated (arange), and weights - all_consts = '' + all_consts = "" if inline_consts_section: all_consts += inline_consts_section if consts_section: all_consts += consts_section - # Build WebNN graph graph = f'webnn_graph "{graph_name}" v1 {{\n' - graph += f' inputs {{ {inputs_section} }}\n' + graph += f" inputs {{ {inputs_section} }}\n" if all_consts: - graph += f' consts {{\n{all_consts} }}\n' - graph += f' nodes {{\n{nodes_section} }}\n' - graph += f' outputs {{ {outputs_section} }}\n' - graph += '}\n' + graph += f" consts {{\n{all_consts} }}\n" + graph += f" nodes {{\n{nodes_section} }}\n" + graph += f" outputs {{ {outputs_section} }}\n" + graph += "}\n" return graph - def _extract_inputs(self, gm: fx.GraphModule, param_names: List[str] = None) -> str: - """Extract input tensor declarations (only actual model inputs, not weights). + # ------------------------------------------------------------------ + # Input / weight extraction + # ------------------------------------------------------------------ - param_names: ordered list of parameter names from model.forward() signature. - When provided, the inputs section is emitted in that order so the executor - maps positional arguments correctly. - """ - inputs = [] # list of (sort_key, declaration) + def _extract_inputs( + self, gm: fx.GraphModule, user_inputs: set + ) -> str: + """Emit `inputs {}` section — only real model inputs, not parameters.""" + decls = [] for node in gm.graph.nodes: - if node.op == 'placeholder': - # Skip weight/parameter placeholders - only include actual inputs - node_name = str(node.name) - if set(node_name.split("_")).intersection(IGNORED_PLACEHOLDER_TOKENS): - continue - - if hasattr(node, 'meta') and 'tensor_meta' in node.meta: - tensor = node.meta['tensor_meta'] - shape = list(tensor.shape) - dtype = self._get_webnn_dtype(tensor.dtype) - shape_str = ', '.join(map(str, shape)) - name = node.name - if name.startswith('l_'): - name = name[len('l_'):] - name = name.rstrip("_") - self.node_to_operand[node.name] = name - sort_key = param_names.index(name) if (param_names and name in param_names) else len(inputs) - inputs.append((sort_key, f'{name}: {dtype}[{shape_str}]')) - else: - raise NotImplementedError(f"Dynamic inputs are not supported: {node.name} does not have tensor_meta") - - inputs.sort(key=lambda x: x[0]) - decls = [decl for _, decl in inputs] - return '; '.join(decls) + ';' if decls else '' - - def _extract_weights(self, model: torch.nn.Module) -> Tuple[str, Dict[str, str]]: - """Extract weight constants and create mapping""" + if node.op != "placeholder": + continue + if node.name not in user_inputs: + continue + + shape = self._get_node_shape(node) + dtype = self._get_webnn_dtype(self._get_node_dtype(node)) + shape_str = ", ".join(map(str, shape)) + # Use the parameter name from the graph signature (strip trailing '_') + name = node.name.rstrip("_") + self.node_to_operand[node.name] = name + decls.append(f"{name}: {dtype}[{shape_str}]") + + return "; ".join(decls) + ";" if decls else "" + + def _extract_weights( + self, + gm: fx.GraphModule, + param_map: Dict[str, str], + named_params: Dict[str, torch.Tensor], + ) -> str: + """Emit `consts {}` section for model parameters and buffers.""" consts = [] - weight_map = {} + for node in gm.graph.nodes: + if node.op != "placeholder": + continue + if node.name not in param_map: + continue + + state_key = param_map[node.name] + tensor = named_params.get(state_key) + if tensor is None: + continue - state_dict = model.state_dict() - for name, tensor in state_dict.items(): - operand_name = f'operand_{self.operand_counter}' + operand_name = f"weight_{self.operand_counter}" self.operand_counter += 1 shape = list(tensor.shape) dtype = self._get_webnn_dtype(tensor.dtype) - shape_str = ', '.join(map(str, shape)) + shape_str = ", ".join(map(str, shape)) - consts.append(f'\t{operand_name}: {dtype}[{shape_str}] @weights("{name}");') - weight_map[name] = operand_name - self.weight_operands[name] = operand_name + self.weight_operands[node.name] = operand_name self.operand_shapes[operand_name] = shape - return '\n'.join(consts) + '\n' if consts else '', weight_map + consts.append( + f'\t{operand_name}: {dtype}[{shape_str}] @weights("{state_key}");' + ) + + return "\n".join(consts) + "\n" if consts else "" + + # ------------------------------------------------------------------ + # Node conversion + # ------------------------------------------------------------------ def _convert_nodes(self, gm: fx.GraphModule) -> str: - """Convert FX nodes to WebNN operations""" + """Convert all call_function nodes to WebNN operations.""" operations = [] - - for i, node in enumerate(gm.graph.nodes): + for node in gm.graph.nodes: + if node.op != "call_function": + continue try: - if node.op == 'call_function': - op_str = self._map_pytorch_to_webnn_op(node) - if op_str: - operations.append(f'\t{op_str}') - elif node.op == 'call_module': - op_str = self._map_module_to_webnn_op(node, gm) - if op_str: - operations.append(f'\t{op_str}') - elif node.op == 'call_method': - op_str = self._map_method_to_webnn_op(node) - if op_str: - operations.append(f'\t{op_str}') + op_str = self._map_aten_to_webnn_op(node) + if op_str: + operations.append(f"\t{op_str}") except NotImplementedError as e: - input_operands = [self._get_input_operand(arg) for arg in node.args if isinstance(arg, fx.Node)] - inputs_str = ', '.join(input_operands) - operations.append(f'\t// invalid: {node.op} {node.target} inputs=[{inputs_str}] args={list(node.args)}') - - return '\n'.join(operations) + '\n' if operations else '' - - def _map_pytorch_to_webnn_op(self, node: fx.Node) -> str: - """Map PyTorch function to WebNN operation""" - # Get output operand - output_operand = self._get_operand_name(node) - - # Get input operands - input_operands = [self._get_input_operand(arg) for arg in node.args if isinstance(arg, fx.Node)] - - converter = resolve_pytorch_converter(node.target) - if converter: - return converter(self, node, output_operand, input_operands) - - throw_unsupported("Operation", node) - - def _map_method_to_webnn_op(self, node: fx.Node) -> str: - """Map PyTorch function to WebNN operation""" - # Get output operand - output_operand = self._get_operand_name(node) - - # Get input operands - input_operands = [self._get_input_operand(arg) for arg in node.args if isinstance(arg, fx.Node)] - - converter = resolve_pytorch_converter(node.target) - if converter: - return converter(self, node, output_operand, input_operands) - - throw_unsupported("Method", node) + input_operands = [ + self._get_input_operand(a) + for a in node.args + if isinstance(a, fx.Node) + ] + operations.append( + f"\t// unsupported: {node.target} " + f"inputs=[{', '.join(input_operands)}] args={list(node.args)}" + ) + + return "\n".join(operations) + "\n" if operations else "" + + def _map_aten_to_webnn_op(self, node: fx.Node) -> str: + output = self._get_operand_name(node) + inputs = [ + self._get_input_operand(a) + for a in node.args + if isinstance(a, fx.Node) + ] - def _map_module_to_webnn_op(self, node: fx.Node, gm: fx.GraphModule) -> str: - """Map PyTorch module call to WebNN operation""" - module = self._get_module(gm, node.target) - output_operand = self._get_operand_name(node) - input_operands = [self._get_input_operand(arg) for arg in node.args if isinstance(arg, fx.Node)] + method_name = resolve_aten_converter(node.target) + if method_name is None: + throw_unsupported("ATen op", node) - if isinstance(module, torch.nn.Conv2d): - return self._convert_conv2d_module(node, module, output_operand, input_operands) - elif isinstance(module, torch.nn.ReLU): - return f'[{output_operand}] = clamp({input_operands[0]}, minValue=0.0);' - elif isinstance(module, torch.nn.Linear): - return self._convert_linear_module(node, module, output_operand, input_operands) - else: - # Raise error for unsupported modules - module_type = type(module).__name__ - module_class = f"{type(module).__module__}.{type(module).__name__}" - throw_unsupported("Module", module_type, module_class, node) + method = getattr(self, method_name) + return method(node, output, inputs) - def _emit_conv2d(self, input_tensor: str, weight: str, bias_info, stride, padding, dilation, groups, output: str) -> str: - """Emit WebNN conv2d nodes, including bias reshape+add when bias is present. + # ------------------------------------------------------------------ + # Converter methods — each takes (node, output_operand, input_operands) + # ------------------------------------------------------------------ - bias_info: (bias_operand, num_channels) when a bias exists, else None. - stride/padding/dilation accept both lists and tuples. - """ + # --- Convolution --- + def _emit_conv2d( + self, input_tensor, weight, bias_info, stride, padding, dilation, groups, output + ) -> str: def as_pair(v): return list(v) if isinstance(v, (list, tuple)) else [v, v] @@ -252,37 +237,35 @@ def as_pair(v): params = [] if dilation != [1, 1]: - params.append(f'dilations=[{dilation[0]}, {dilation[1]}]') + params.append(f"dilations=[{dilation[0]}, {dilation[1]}]") params.append('filterLayout="oihw"') - params.append(f'groups={groups}') + params.append(f"groups={groups}") params.append('inputLayout="nchw"') if padding != [0, 0]: - params.append(f'pads=[{padding[0]}, {padding[0]}, {padding[1]}, {padding[1]}]') + params.append(f"pads=[{padding[0]}, {padding[0]}, {padding[1]}, {padding[1]}]") if stride != [1, 1]: - params.append(f'strides=[{stride[0]}, {stride[1]}]') + params.append(f"strides=[{stride[0]}, {stride[1]}]") - params_str = ', '.join(params) + params_str = ", ".join(params) if bias_info is not None: bias_operand, c = bias_info - # Reshape bias [C] → [1, C, 1, 1] for NCHW broadcast - reshaped_bias = f'operand_{self.operand_counter}' + reshaped_bias = f"operand_{self.operand_counter}" self.operand_counter += 1 - conv_out = f'operand_{self.operand_counter}' + conv_out = f"operand_{self.operand_counter}" self.operand_counter += 1 return ( - f'[{reshaped_bias}] = reshape({bias_operand}, newShape=[1, {c}, 1, 1]);\n' - f'\t[{conv_out}] = conv2d({input_tensor}, {weight}, {params_str});\n' - f'\t[{output}] = add({conv_out}, {reshaped_bias});' + f"[{reshaped_bias}] = reshape({bias_operand}, newShape=[1, {c}, 1, 1]);\n" + f"\t[{conv_out}] = conv2d({input_tensor}, {weight}, {params_str});\n" + f"\t[{output}] = add({conv_out}, {reshaped_bias});" ) - return f'[{output}] = conv2d({input_tensor}, {weight}, {params_str});' + return f"[{output}] = conv2d({input_tensor}, {weight}, {params_str});" def _convert_conv2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert torch.conv2d (call_function) to WebNN.""" - # torch.conv2d(input, weight, bias, stride, padding, dilation, groups) + """aten.conv2d.default(input, weight, bias, stride, padding, dilation, groups)""" args = node.args - input_tensor = inputs[0] if inputs else 'unknown' - weight = self._get_input_operand(args[1]) if len(args) > 1 else 'unknown' + input_tensor = inputs[0] if inputs else "unknown" + weight = self._get_input_operand(args[1]) if len(args) > 1 else "unknown" stride = args[3] if len(args) > 3 else [1, 1] padding = args[4] if len(args) > 4 else [0, 0] @@ -298,1702 +281,1228 @@ def _convert_conv2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: return self._emit_conv2d(input_tensor, weight, bias_info, stride, padding, dilation, groups, output) - def _convert_conv2d_module(self, node: fx.Node, module: torch.nn.Conv2d, output: str, inputs: List[str]) -> str: - """Convert Conv2d (call_module) to WebNN — delegates to _emit_conv2d.""" - input_tensor = inputs[0] if inputs else 'unknown' - weight = self.weight_operands.get(f'{node.target}.weight', 'unknown') + def _convert_convolution(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.convolution.default(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)""" + args = node.args + input_tensor = inputs[0] if inputs else "unknown" + weight = self._get_input_operand(args[1]) if len(args) > 1 else "unknown" + + stride = args[3] if len(args) > 3 else [1, 1] + padding = args[4] if len(args) > 4 else [0, 0] + dilation = args[5] if len(args) > 5 else [1, 1] + groups = args[8] if len(args) > 8 else 1 bias_info = None - if module.bias is not None: - bias_operand = self.weight_operands.get(f'{node.target}.bias') - if bias_operand: - bias_shape = self.operand_shapes.get(bias_operand, []) - bias_info = (bias_operand, bias_shape[0] if bias_shape else module.out_channels) + bias_node = args[2] if len(args) > 2 else None + if isinstance(bias_node, fx.Node): + bias_operand = self._get_input_operand(bias_node) + bias_shape = self.operand_shapes.get(bias_operand, []) + bias_info = (bias_operand, bias_shape[0] if bias_shape else 0) - return self._emit_conv2d(input_tensor, weight, bias_info, module.stride, module.padding, module.dilation, module.groups, output) + return self._emit_conv2d(input_tensor, weight, bias_info, stride, padding, dilation, groups, output) - def _convert_arithmetric(self, node: fx.Node, output: str, inputs: List[str], op: str) -> str: - """Convert arithmetic operations (add, sub, mul, div) to WebNN""" - if len(inputs) == 2: - return f'[{output}] = {op}({inputs[0]}, {inputs[1]});' - if len(node.args) == 2: - if len(inputs) == 1 and isnumeric(node.args[1]): - # Create an inline constant for the numeric value - const_operand = self._create_inline_constant(node.args[1]) - return f'[{output}] = {op}({inputs[0]}, {const_operand});' - elif len(inputs) == 1 and isnumeric(node.args[0]): - # Create an inline constant for the numeric value - const_operand = self._create_inline_constant(node.args[0]) - return f'[{output}] = {op}({inputs[0]}, {const_operand});' - raise NotImplementedError(f'Invalid {op} operation') - - def _convert_math(self, node: fx.Node, output: str, inputs: List[str], op: str) -> str: - """Convert math functions (sqrt, exp, log, cos, sin) to WebNN""" - input_tensor = inputs[0] if inputs else 'unknown' - return f'[{output}] = {op}({input_tensor});' + # --- Linear --- - def _convert_pow(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert power to WebNN pow""" - if len(inputs) >= 2: - return f'[{output}] = pow({inputs[0]}, {inputs[1]});' - raise NotImplementedError('Invalid pow operation') + def _convert_linear(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.linear.default(input, weight, bias)""" + if len(inputs) < 2: + raise NotImplementedError("linear requires at least 2 inputs") + input_tensor, weight = inputs[0], inputs[1] - def _convert_neg(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert negation to WebNN neg""" - input_tensor = inputs[0] if inputs else 'unknown' - return f'[{output}] = neg({input_tensor});' + if len(inputs) >= 3: + bias = inputs[2] + mm_out = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return ( + f"[{mm_out}] = gemm({input_tensor}, {weight}, bTranspose=true);\n" + f"\t[{output}] = add({mm_out}, {bias});" + ) + return f"[{output}] = gemm({input_tensor}, {weight}, bTranspose=true);" + + def _convert_addmm(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.addmm.default(bias, mat1, mat2) = bias + mat1 @ mat2""" + if len(inputs) < 3: + return self._convert_identity(node, output, inputs) + bias, mat1, mat2 = inputs[0], inputs[1], inputs[2] + stmts = [] + + input_node = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else None + input_shape = self._get_node_shape(input_node) if input_node else [] + if input_shape and len(input_shape) != 2: + batch = int(input_shape[0]) + features = int(math.prod(input_shape[1:])) + tmp = f"operand_{self.operand_counter}" + self.operand_counter += 1 + stmts.append(f"[{tmp}] = reshape({mat1}, newShape=[{batch}, {features}]);") + mat1 = tmp - def _convert_cast(self, node: fx.Node, output: str, inputs: List[str], dtype=None) -> str: - """Convert type casting (tensor.to) to WebNN cast""" - input_tensor = inputs[0] if inputs else 'unknown' + mm_out = f"operand_{self.operand_counter}" + self.operand_counter += 1 + stmts.append(f"[{mm_out}] = gemm({mat1}, {mat2});") + stmts.append(f"[{output}] = add({mm_out}, {bias});") + return "\n\t".join(stmts) - # Get target dtype from args - # to(dtype) or to(device, dtype) or to(tensor, dtype, ...) - target_dtype = None - if dtype is not None: - # this path represents tensor.float() or similar - target_dtype = dtype - else: - for arg in node.args[1:]: # Skip first arg (input tensor) - if isinstance(arg, torch.dtype): - target_dtype = arg - break - # Also check kwargs - if target_dtype is None and 'dtype' in node.kwargs: - target_dtype = node.kwargs['dtype'] - - # If no dtype found, just return identity (might be device-only cast) - # TODO ignore casts for now since there are some issues in ORT execution: - # RuntimeError: ONNX execution failed: onnx runtime failed: load model failed: This is an invalid model. In Node, ("cast_131", Cast, "", -1) : ("operand_389": tensor(float),) -> ("operand_390",) , Error Required attribute 'to' is missing. - return f'[{output}] = identity({input_tensor});' - if target_dtype is None: - return f'[{output}] = identity({input_tensor});' + def _convert_matmul(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.mm / aten.matmul""" + if len(inputs) < 2: + raise NotImplementedError("matmul requires 2 inputs") + a, b = inputs[0], inputs[1] + stmts = [] + + a_node = node.args[0] if isinstance(node.args[0], fx.Node) else None + b_node = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else None + a_shape = self._get_node_shape(a_node) if a_node else self.operand_shapes.get(a, []) + b_shape = self._get_node_shape(b_node) if b_node else self.operand_shapes.get(b, []) + out_shape = self._get_node_shape(node) + + if a_shape and len(a_shape) == 1: + tmp = f"operand_{self.operand_counter}" + self.operand_counter += 1 + stmts.append(f"[{tmp}] = reshape({a}, newShape=[1, {a_shape[0]}]);") + a = tmp + if b_shape and len(b_shape) == 1: + tmp = f"operand_{self.operand_counter}" + self.operand_counter += 1 + stmts.append(f"[{tmp}] = reshape({b}, newShape=[{b_shape[0]}, 1]);") + b = tmp + + needs_reshape = bool(out_shape) and len(out_shape) != 2 + gemm_out = output if not needs_reshape else f"operand_{self.operand_counter}" + if needs_reshape: + self.operand_counter += 1 + + stmts.append(f"[{gemm_out}] = gemm({a}, {b});") + if needs_reshape: + shape_str = ", ".join(str(d) for d in out_shape) + stmts.append(f"[{output}] = reshape({gemm_out}, newShape=[{shape_str}]);") + return "\n\t".join(stmts) - # Map PyTorch dtype to WebNN dtype - webnn_dtype = self._get_webnn_dtype(target_dtype) + def _convert_t(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.t — transpose a 2-D matrix.""" + input_tensor = inputs[0] if inputs else "unknown" + return f"[{output}] = transpose({input_tensor}, permutation=[1, 0]);" - return f'[{output}] = cast({input_tensor}, type={webnn_dtype});' + # --- Activations --- + + def _convert_relu(self, node: fx.Node, output: str, inputs: List[str]) -> str: + return f"[{output}] = clamp({inputs[0] if inputs else 'unknown'}, minValue=0.0);" def _convert_sigmoid(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert sigmoid to WebNN sigmoid""" - input_tensor = inputs[0] if inputs else 'unknown' - return f'[{output}] = sigmoid({input_tensor});' + return f"[{output}] = sigmoid({inputs[0] if inputs else 'unknown'});" def _convert_tanh(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert tanh to WebNN tanh""" - input_tensor = inputs[0] if inputs else 'unknown' - return f'[{output}] = tanh({input_tensor});' - - def _convert_softmax(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert softmax to WebNN softmax""" - input_tensor = inputs[0] if inputs else 'unknown' - # Get axis from kwargs or args - axis = node.kwargs.get('dim', -1) - if len(node.args) > 1: - axis = node.args[1] - return f'[{output}] = softmax({input_tensor}, axis={axis});' + return f"[{output}] = tanh({inputs[0] if inputs else 'unknown'});" def _convert_silu(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert SiLU (Swish) activation to WebNN operations. + x = inputs[0] if inputs else "unknown" + sig = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return f"[{sig}] = sigmoid({x});\n\t[{output}] = mul({x}, {sig});" + + def _convert_gelu(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """GELU via tanh approximation: 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))""" + x = inputs[0] if inputs else "unknown" + c1 = self._create_inline_constant(0.7978845608028654) # sqrt(2/pi) + c2 = self._create_inline_constant(0.044715) + c3 = self._create_inline_constant(3.0) + c_half = self._create_inline_constant(0.5) + c_one = self._create_inline_constant(1.0) + + def tmp(): + name = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return name - SiLU(x) = x * sigmoid(x) + x3, inner, scaled, tanh_in, tanh_out, one_plus, half_x, r = ( + tmp(), tmp(), tmp(), tmp(), tmp(), tmp(), tmp(), output + ) + stmts = [ + f"[{x3}] = pow({x}, {c3});", + f"[{inner}] = mul({x3}, {c2});", + f"[{inner}] = add({x}, {inner});", + f"[{scaled}] = mul({inner}, {c1});", + f"[{tanh_out}] = tanh({scaled});", + f"[{one_plus}] = add({tanh_out}, {c_one});", + f"[{half_x}] = mul({x}, {c_half});", + f"[{r}] = mul({half_x}, {one_plus});", + ] + return "\n\t".join(stmts) - This is a common activation function in modern neural networks, - especially in diffusion models and transformers. - """ - input_tensor = inputs[0] if inputs else 'unknown' + def _convert_hardtanh(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + args = node.args + min_val = args[1] if len(args) > 1 else 0.0 + max_val = args[2] if len(args) > 2 else 6.0 + return f"[{output}] = clamp({x}, minValue={min_val}, maxValue={max_val});" - # Step 1: Compute sigmoid(x) - sigmoid_operand = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step1 = f'[{sigmoid_operand}] = sigmoid({input_tensor});' + def _convert_clamp(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + args = node.args + params = [] + min_val = args[1] if len(args) > 1 else node.kwargs.get("min") + max_val = args[2] if len(args) > 2 else node.kwargs.get("max") + if min_val is not None: + params.append(f"minValue={min_val}") + if max_val is not None: + params.append(f"maxValue={max_val}") + return f"[{output}] = clamp({x}, {', '.join(params)});" - # Step 2: Multiply x * sigmoid(x) - step2 = f'[{output}] = mul({input_tensor}, {sigmoid_operand});' + # --- Normalization --- - return f'{step1}\n {step2}' + def _convert_batch_norm_aten(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.batch_norm.default(input, weight, bias, running_mean, running_var, + training, momentum, eps, cudnn_enabled)""" + args = node.args - def _convert_batch_norm(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert batch_norm to WebNN. + def get_op(idx): + n = args[idx] if len(args) > idx else None + return self._get_input_operand(n) if isinstance(n, fx.Node) else None - Dynamo always emits the same target for both cases: - batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps, ...) - When track_running_stats=False the running_mean/var args are None (not FX Nodes). + input_tensor = get_op(0) or "unknown" + weight = get_op(1) # gamma / scale + bias_op = get_op(2) # beta / bias + mean_op = get_op(3) # running_mean + var_op = get_op(4) # running_var + eps = args[7] if len(args) > 7 else 1e-5 - With running stats → batchNormalization(input, mean, var, axis=1, scale, bias, eps). - Without running stats → decompose over NCHW axes [0, 2, 3]. - """ + if mean_op and var_op: + params = [f"epsilon={eps}", "axis=1"] + if weight: + params.append(f"scale={weight}") + if bias_op: + params.append(f"bias={bias_op}") + # TODO fix parser to accept mean and variance as named args as well + # return f"[{output}] = batchNormalization({input_tensor}, mean={mean_op}, variance={var_op}, {', '.join(params)});" + return f"[{output}] = batchNormalization({input_tensor}, {mean_op}, {var_op}, {', '.join(params)});" + + # No running stats — layer-norm style decomposition over NCHW [0,2,3] + return self._batch_norm_decompose(input_tensor, weight, bias_op, eps, node, output) + + def _convert_batch_norm_no_training(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten._native_batch_norm_legit_no_training.default(input, weight, bias, running_mean, running_var, momentum, eps)""" args = node.args def get_op(idx): n = args[idx] if len(args) > idx else None return self._get_input_operand(n) if isinstance(n, fx.Node) else None - input_tensor = get_op(0) or 'unknown' - mean_op = get_op(1) # running_mean or None - var_op = get_op(2) # running_var or None - weight = get_op(3) # gamma - bias = get_op(4) # beta - eps = args[7] if len(args) > 7 else 1e-5 + input_tensor = get_op(0) or "unknown" + weight = get_op(1) + bias_op = get_op(2) + mean_op = get_op(3) + var_op = get_op(4) + eps = args[6] if len(args) > 6 else 1e-5 if mean_op and var_op: - # Eval with running stats — use WebNN batchNormalization - params = [f'epsilon={eps}', 'axis=1'] + params = [f"epsilon={eps}", "axis=1"] if weight: - params.append(f'scale={weight}') - if bias: - params.append(f'bias={bias}') - return f'[{output}] = batchNormalization({input_tensor}, {mean_op}, {var_op}, {", ".join(params)});' + params.append(f"scale={weight}") + if bias_op: + params.append(f"bias={bias_op}") + return f"[{output}] = batchNormalization({input_tensor}, {mean_op}, {var_op}, {', '.join(params)});" - # No running stats — decompose for NCHW (normalize over axes [0, 2, 3]) - input_shape = self._get_node_shape(args[0]) if args and isinstance(args[0], fx.Node) else [] + return self._batch_norm_decompose(input_tensor, weight, bias_op, eps, node, output) + + def _batch_norm_decompose(self, input_tensor, weight, bias_op, eps, node, output) -> str: + """Decompose batch norm into mean/var/normalize ops for NCHW [0,2,3].""" + input_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] C = input_shape[1] if len(input_shape) > 1 else 0 eps_c = self._create_inline_constant(float(eps)) - mean_t = f'operand_{self.operand_counter}'; self.operand_counter += 1 - centered = f'operand_{self.operand_counter}'; self.operand_counter += 1 - sq = f'operand_{self.operand_counter}'; self.operand_counter += 1 - var_t = f'operand_{self.operand_counter}'; self.operand_counter += 1 - var_eps = f'operand_{self.operand_counter}'; self.operand_counter += 1 - std_t = f'operand_{self.operand_counter}'; self.operand_counter += 1 - norm_t = f'operand_{self.operand_counter}'; self.operand_counter += 1 + def tmp(): + name = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return name + mean_t, centered, sq, var_t, var_eps, std_t, norm_t = ( + tmp(), tmp(), tmp(), tmp(), tmp(), tmp(), tmp() + ) steps = [ - f'[{mean_t}] = reduceMean({input_tensor}, axes=[0, 2, 3], keepDimensions=true);', - f'[{centered}] = sub({input_tensor}, {mean_t});', - f'[{sq}] = mul({centered}, {centered});', - f'[{var_t}] = reduceMean({sq}, axes=[0, 2, 3], keepDimensions=true);', - f'[{var_eps}] = add({var_t}, {eps_c});', - f'[{std_t}] = sqrt({var_eps});', - f'[{norm_t}] = div({centered}, {std_t});', + f"[{mean_t}] = reduceMean({input_tensor}, axes=[0, 2, 3], keepDimensions=true);", + f"[{centered}] = sub({input_tensor}, {mean_t});", + f"[{sq}] = mul({centered}, {centered});", + f"[{var_t}] = reduceMean({sq}, axes=[0, 2, 3], keepDimensions=true);", + f"[{var_eps}] = add({var_t}, {eps_c});", + f"[{std_t}] = sqrt({var_eps});", + f"[{norm_t}] = div({centered}, {std_t});", ] result = norm_t if weight and C: - w_shaped = f'operand_{self.operand_counter}'; self.operand_counter += 1 - scaled = f'operand_{self.operand_counter}'; self.operand_counter += 1 + w_shaped, scaled = tmp(), tmp() steps += [ - f'[{w_shaped}] = reshape({weight}, newShape=[1, {C}, 1, 1]);', - f'[{scaled}] = mul({result}, {w_shaped});', + f"[{w_shaped}] = reshape({weight}, newShape=[1, {C}, 1, 1]);", + f"[{scaled}] = mul({result}, {w_shaped});", ] result = scaled - if bias and C: - b_shaped = f'operand_{self.operand_counter}'; self.operand_counter += 1 + if bias_op and C: + b_shaped = tmp() steps += [ - f'[{b_shaped}] = reshape({bias}, newShape=[1, {C}, 1, 1]);', - f'[{output}] = add({result}, {b_shaped});', + f"[{b_shaped}] = reshape({bias_op}, newShape=[1, {C}, 1, 1]);", + f"[{output}] = add({result}, {b_shaped});", ] else: - steps.append(f'[{output}] = identity({result});') + steps.append(f"[{output}] = identity({result});") - return '\n\t'.join(steps) + return "\n\t".join(steps) def _convert_layer_norm(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert layer normalization to WebNN layerNormalization""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Extract parameters from args + """aten.layer_norm.default(input, normalized_shape, weight, bias, eps)""" + x = inputs[0] if inputs else "unknown" args = node.args - # layer_norm(input, normalized_shape, weight, bias, eps) weight = self._get_input_operand(args[2]) if len(args) > 2 and isinstance(args[2], fx.Node) else None - bias = self._get_input_operand(args[3]) if len(args) > 3 and isinstance(args[3], fx.Node) else None + bias_op = self._get_input_operand(args[3]) if len(args) > 3 and isinstance(args[3], fx.Node) else None eps = args[4] if len(args) > 4 else 1e-5 - params = [] if weight: - params.append(f'scale={weight}') - if bias: - params.append(f'bias={bias}') - params.append(f'epsilon={eps}') - - params_str = ', '.join(params) - return f'[{output}] = layerNormalization({input_tensor}, {params_str});' + params.append(f"scale={weight}") + if bias_op: + params.append(f"bias={bias_op}") + params.append(f"epsilon={eps}") + return f"[{output}] = layerNormalization({x}, {', '.join(params)});" def _convert_group_norm(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert group normalization to WebNN operations. - - Group normalization divides channels into groups and normalizes within each group. - This is commonly used in diffusion models and other architectures. - - Since WebNN doesn't have a native groupNorm, we decompose it into: - 1. Reshape to separate groups - 2. Normalize per group - 3. Reshape back - 4. Apply affine transform - """ - input_tensor = inputs[0] if inputs else 'unknown' - - # Extract parameters from args - # group_norm(input, num_groups, weight, bias, eps) + """aten.group_norm.default(input, num_groups, weight, bias, eps)""" + x = inputs[0] if inputs else "unknown" args = node.args num_groups = args[1] if len(args) > 1 else 32 weight = self._get_input_operand(args[2]) if len(args) > 2 and isinstance(args[2], fx.Node) else None - bias = self._get_input_operand(args[3]) if len(args) > 3 and isinstance(args[3], fx.Node) else None + bias_op = self._get_input_operand(args[3]) if len(args) > 3 and isinstance(args[3], fx.Node) else None eps = args[4] if len(args) > 4 else 1e-5 - - # Get input shape to understand dimensions - input_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(args[0], fx.Node) else [] + input_shape = self._get_node_shape(args[0]) if args and isinstance(args[0], fx.Node) else [] if not input_shape or len(input_shape) < 2: - # Fallback: use layerNormalization as approximation - # This is not exactly group norm but works for many cases params = [] if weight: - params.append(f'scale={weight}') - if bias: - params.append(f'bias={bias}') - params.append(f'epsilon={eps}') - params_str = ', '.join(params) - return f'[{output}] = layerNormalization({input_tensor}, {params_str}); // approximation of group_norm' - - # Full decomposition for group normalization - # Input shape: [N, C, *spatial] - N = input_shape[0] - C = input_shape[1] - spatial_dims = input_shape[2:] - - # Calculate group parameters - channels_per_group = C // num_groups - - # Step 1: Reshape to separate groups - # [N, C, *spatial] -> [N, num_groups, channels_per_group, *spatial] - reshape1 = f'operand_{self.operand_counter}' - self.operand_counter += 1 - reshape1_shape = [N, num_groups, channels_per_group] + spatial_dims - reshape1_str = ', '.join(map(str, reshape1_shape)) - step1 = f'[{reshape1}] = reshape({input_tensor}, newShape=[{reshape1_str}]);' - - # Step 2: Compute mean per group (reduce over channels_per_group and spatial dims) - # This requires reduceMean over specific axes - # Axes to reduce: [2, 3, 4, ...] (channels_per_group and all spatial dimensions) - mean_axes = list(range(2, 2 + 1 + len(spatial_dims))) - mean_operand = f'operand_{self.operand_counter}' - self.operand_counter += 1 - axes_str = ', '.join(map(str, mean_axes)) - step2 = f'[{mean_operand}] = reduceMean({reshape1}, axes=[{axes_str}], keepDimensions=true);' - - # Step 3: Subtract mean - centered = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step3 = f'[{centered}] = sub({reshape1}, {mean_operand});' - - # Step 4: Compute variance (mean of squared differences) - squared = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step4a = f'[{squared}] = mul({centered}, {centered});' - - var_operand = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step4b = f'[{var_operand}] = reduceMean({squared}, axes=[{axes_str}], keepDimensions=true);' - - # Step 5: Compute std = sqrt(var + eps) - var_eps = f'operand_{self.operand_counter}' - self.operand_counter += 1 - - # Create epsilon constant - eps_const = f'const_eps_{self.operand_counter}' - self.operand_counter += 1 - - eps_tensor = torch.tensor(eps, dtype=torch.float32) - self.inline_constants[eps_const] = eps_tensor - - step5a = f'[{var_eps}] = add({var_operand}, {eps_const});' - - std_operand = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step5b = f'[{std_operand}] = sqrt({var_eps});' + params.append(f"scale={weight}") + if bias_op: + params.append(f"bias={bias_op}") + params.append(f"epsilon={eps}") + return f"[{output}] = layerNormalization({x}, {', '.join(params)}); // approx group_norm" + + N, C = input_shape[0], input_shape[1] + spatial = input_shape[2:] + cpg = C // num_groups + eps_c = self._create_inline_constant(float(eps)) - # Step 6: Normalize: centered / std - normalized = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step6 = f'[{normalized}] = div({centered}, {std_operand});' + def tmp(): + name = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return name - # Step 7: Reshape back to original shape - reshaped_back = f'operand_{self.operand_counter}' - self.operand_counter += 1 - orig_shape_str = ', '.join(map(str, input_shape)) - step7 = f'[{reshaped_back}] = reshape({normalized}, newShape=[{orig_shape_str}]);' + r1 = tmp() + r1_shape = [N, num_groups, cpg] + spatial + mean_axes = list(range(2, 2 + 1 + len(spatial))) + axes_str = ", ".join(map(str, mean_axes)) - # Step 8: Apply affine transform if weight/bias provided - result = reshaped_back - steps = [step1, step2, step3, step4a, step4b, step5a, step5b, step6, step7] + mean_t, centered, sq, var_t, var_eps, std_t, norm_t, r_back = ( + tmp(), tmp(), tmp(), tmp(), tmp(), tmp(), tmp(), tmp() + ) + orig_str = ", ".join(map(str, input_shape)) + steps = [ + f"[{r1}] = reshape({x}, newShape=[{', '.join(map(str, r1_shape))}]);", + f"[{mean_t}] = reduceMean({r1}, axes=[{axes_str}], keepDimensions=true);", + f"[{centered}] = sub({r1}, {mean_t});", + f"[{sq}] = mul({centered}, {centered});", + f"[{var_t}] = reduceMean({sq}, axes=[{axes_str}], keepDimensions=true);", + f"[{var_eps}] = add({var_t}, {eps_c});", + f"[{std_t}] = sqrt({var_eps});", + f"[{norm_t}] = div({centered}, {std_t});", + f"[{r_back}] = reshape({norm_t}, newShape=[{orig_str}]);", + ] + result = r_back if weight: - reshaped_weight = f'operand_{self.operand_counter}' - self.operand_counter += 1 - weight_shape = [1] * len(input_shape) - weight_shape[1] = self.operand_shapes[weight][0] - weight_shape_str = ', '.join(map(str, weight_shape)) - step8a = f'[{reshaped_weight}] = reshape({weight}, newShape=[{weight_shape_str}]);' - steps.append(step8a) - scaled = f'operand_{self.operand_counter}' - self.operand_counter += 1 - step8b = f'[{scaled}] = mul({result}, {reshaped_weight});' - steps.append(step8b) + w_shaped, scaled = tmp(), tmp() + w_shape = [1] * len(input_shape) + w_shape[1] = self.operand_shapes.get(weight, [C])[0] + steps += [ + f"[{w_shaped}] = reshape({weight}, newShape=[{', '.join(map(str, w_shape))}]);", + f"[{scaled}] = mul({result}, {w_shaped});", + ] result = scaled - if bias: - reshaped_bias = f'operand_{self.operand_counter}' - self.operand_counter += 1 - bias_shape = [1] * len(input_shape) - bias_shape[1] = self.operand_shapes[bias][0] - bias_shape_str = ', '.join(map(str, bias_shape)) - step9a = f'[{reshaped_bias}] = reshape({bias}, newShape=[{bias_shape_str}]);' - steps.append(step9a) - step9b = f'[{output}] = add({result}, {reshaped_bias});' - steps.append(step9b) + if bias_op: + b_shaped = tmp() + b_shape = [1] * len(input_shape) + b_shape[1] = self.operand_shapes.get(bias_op, [C])[0] + steps += [ + f"[{b_shaped}] = reshape({bias_op}, newShape=[{', '.join(map(str, b_shape))}]);", + f"[{output}] = add({result}, {b_shaped});", + ] else: - # Rename final result to output - steps.append(f'[{output}] = identity({result});') + steps.append(f"[{output}] = identity({result});") - return '\n\t'.join(steps) + return "\n\t".join(steps) - def _convert_getitem(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert Python's getitem (indexing/slicing) to WebNN operations. - - Common patterns: - - tensor[:, None] -> unsqueeze - - tensor[0] -> slice - - tensor[:, 1:10] -> slice - - tensor[..., 0] -> slice - """ - input_tensor = inputs[0] if inputs else 'unknown' - - # Get the index/slice from args - if len(node.args) < 2: - return f'[{output}] = identity({input_tensor}); // getitem with no index' + # --- Pooling --- - index = node.args[1] + def _convert_max_pool2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.max_pool2d.default / max_pool2d_with_indices.default""" + x = inputs[0] if inputs else "unknown" + args = node.args + kernel = args[1] if len(args) > 1 else node.kwargs.get("kernel_size", [2, 2]) + stride = args[2] if len(args) > 2 else node.kwargs.get("stride", kernel) + padding = args[3] if len(args) > 3 else node.kwargs.get("padding", [0, 0]) + if not isinstance(kernel, (list, tuple)): + kernel = [kernel, kernel] + if not isinstance(stride, (list, tuple)): + stride = [stride, stride] + if not isinstance(padding, (list, tuple)): + padding = [padding, padding, padding, padding] + elif len(padding) == 2: + padding = [padding[0], padding[0], padding[1], padding[1]] - # Get input shape to help with dimension calculations - input_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] + params = [f"windowDimensions=[{kernel[0]}, {kernel[1]}]"] + if stride != [1, 1]: + params.append(f"strides=[{stride[0]}, {stride[1]}]") + if padding != [0, 0, 0, 0]: + params.append(f"padding=[{padding[0]}, {padding[1]}, {padding[2]}, {padding[3]}]") - # Handle single None (add dimension) - if index is None: - # tensor[None] -> unsqueeze at dimension 0 - # WebNN uses reshape to add dimensions - if input_shape: - new_shape = [1] + input_shape - shape_str = ', '.join(map(str, new_shape)) - return f'[{output}] = reshape({input_tensor}, newShape=[{shape_str}]);' - else: - return f'[{output}] = reshape({input_tensor}, newShape=[1, ...]);' - - # Handle tuple of indices/slices - if isinstance(index, tuple): - # Check for patterns like (:, None) which adds a dimension - none_positions = [i for i, idx in enumerate(index) if idx is None] - slice_positions = [i for i, idx in enumerate(index) if isinstance(idx, slice)] - - if none_positions and all(isinstance(idx, (slice, type(None))) for idx in index): - # This is unsqueeze operation - adding dimensions - # Example: tensor[:, None] adds dimension at position 1 - # Example: tensor[:, None, :] adds dimension at position 1 - - if not input_shape: - return f'[{output}] = identity({input_tensor});' - - # Build new shape by inserting 1s at None positions - output_shape = [] - input_dim = 0 - for i, idx in enumerate(index): - if idx is None: - output_shape.append(1) - elif isinstance(idx, slice): - if idx == slice(None, None, None): # Full slice (:) - if input_dim < len(input_shape): - output_shape.append(input_shape[input_dim]) - input_dim += 1 - else: - # Partial slice - need to calculate size - if input_dim < len(input_shape): - dim_size = input_shape[input_dim] - start = idx.start if idx.start is not None else 0 - stop = idx.stop if idx.stop is not None else dim_size - step = idx.step if idx.step is not None else 1 - sliced_size = (stop - start) // step - output_shape.append(sliced_size) - input_dim += 1 - elif isinstance(idx, int): - # Integer indexing removes the dimension - input_dim += 1 - # Don't add to output_shape (dimension is removed) - - # Add remaining dimensions - while input_dim < len(input_shape): - output_shape.append(input_shape[input_dim]) - input_dim += 1 - - shape_str = ', '.join(map(str, output_shape)) - return f'[{output}] = reshape({input_tensor}, newShape=[{shape_str}]);' - - # Handle pure slicing (no None) - elif not none_positions and all(isinstance(idx, (slice, int)) for idx in index): - # This is slice operation - # Example: tensor[0, :, 1:10] - - if not input_shape: - return f'[{output}] = identity({input_tensor});' - - # Check if it's just integer indexing (removes dimensions) - if all(isinstance(idx, int) for idx in index): - # All integer indices - results in a scalar or lower-rank tensor - # This requires gather operation which we may not have yet - raise NotImplementedError("getitem with all integer indices (gather operation needed)") - - # Mixed slice and integer indexing - # Build starts, sizes, and output shape - starts = [] - sizes = [] - output_shape = [] - for dim_idx, idx in enumerate(index): - if dim_idx >= len(input_shape): - break - - dim_size = input_shape[dim_idx] - - if isinstance(idx, slice): - start = idx.start if idx.start is not None else 0 - stop = idx.stop if idx.stop is not None else dim_size - step = idx.step if idx.step is not None else 1 - - if step != 1: - raise NotImplementedError("getitem with step != 1 not supported yet") - - starts.append(start) - size = stop - start - sizes.append(size) - output_shape.append(size) - elif isinstance(idx, int): - # Integer index - take single element - starts.append(idx) - sizes.append(1) - # Don't add to output_shape (dimension is squeezed) - - # Handle remaining dimensions (implicitly [:]) - for dim_idx in range(len(index), len(input_shape)): - starts.append(0) - sizes.append(input_shape[dim_idx]) - output_shape.append(input_shape[dim_idx]) - - # Generate slice operation - starts_str = ', '.join(map(str, starts)) - sizes_str = ', '.join(map(str, sizes)) - - sliced = f'operand_{self.operand_counter}' - self.operand_counter += 1 - slice_op = f'[{sliced}] = slice({input_tensor}, starts=[{starts_str}], sizes=[{sizes_str}]);' - - # If we removed dimensions (integer indexing), reshape to squeeze them - if len(output_shape) < len(sizes): - output_shape_str = ', '.join(map(str, output_shape)) - reshape_op = f'[{output}] = reshape({sliced}, newShape=[{output_shape_str}]);' - return f'{slice_op}\n {reshape_op}' - else: - return f'[{output}] = slice({input_tensor}, starts=[{starts_str}], sizes=[{sizes_str}]);' - - # Handle single integer index (e.g., tensor[0]) - elif isinstance(index, int): - if not input_shape: - return f'[{output}] = identity({input_tensor});' - - # Slice first dimension at index - starts = [index] + [0] * (len(input_shape) - 1) - sizes = [1] + input_shape[1:] - - starts_str = ', '.join(map(str, starts)) - sizes_str = ', '.join(map(str, sizes)) - - # Slice and then reshape to remove the first dimension - sliced = f'operand_{self.operand_counter}' - self.operand_counter += 1 - slice_op = f'[{sliced}] = slice({input_tensor}, starts=[{starts_str}], sizes=[{sizes_str}]);' + # max_pool2d_with_indices returns a tuple; the first element is the pooled tensor. + # The FX node itself represents index 0 (values), index 1 (indices) via getitem. + return f"[{output}] = maxPool2d({x}, {', '.join(params)});" - output_shape = input_shape[1:] - if output_shape: - output_shape_str = ', '.join(map(str, output_shape)) - reshape_op = f'[{output}] = reshape({sliced}, newShape=[{output_shape_str}]);' - return f'{slice_op}\n {reshape_op}' - else: - return slice_op + def _convert_avg_pool2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + args = node.args + kernel = args[1] if len(args) > 1 else [2, 2] + stride = args[2] if len(args) > 2 else kernel + padding = args[3] if len(args) > 3 else [0, 0] + if not isinstance(kernel, (list, tuple)): + kernel = [kernel, kernel] + if not isinstance(stride, (list, tuple)): + stride = [stride, stride] + if not isinstance(padding, (list, tuple)): + padding = [padding, padding, padding, padding] + elif len(padding) == 2: + padding = [padding[0], padding[0], padding[1], padding[1]] - # Handle single slice (e.g., tensor[1:10]) - elif isinstance(index, slice): - if not input_shape: - return f'[{output}] = identity({input_tensor});' + params = [f"windowDimensions=[{kernel[0]}, {kernel[1]}]"] + if stride != [1, 1]: + params.append(f"strides=[{stride[0]}, {stride[1]}]") + if padding != [0, 0, 0, 0]: + params.append(f"padding=[{padding[0]}, {padding[1]}, {padding[2]}, {padding[3]}]") + return f"[{output}] = averagePool2d({x}, {', '.join(params)});" - start = index.start if index.start is not None else 0 - stop = index.stop if index.stop is not None else input_shape[0] - step = index.step if index.step is not None else 1 + def _convert_global_avg_pool(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + return f"[{output}] = reduceMean({x}, axes=[2, 3], keepDimensions=true);" - if step != 1: - raise NotImplementedError("getitem with step != 1 not supported yet") + def _convert_reduce_mean(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + args = node.args + axes = None + keep = True + if "dim" in node.kwargs: + axes = node.kwargs["dim"] + elif len(args) > 1: + axes = args[1] + if "keepdim" in node.kwargs: + keep = node.kwargs["keepdim"] + elif len(args) > 2: + keep = args[2] + if axes is not None: + if not isinstance(axes, (list, tuple)): + axes = [axes] + axes_str = ", ".join(map(str, axes)) + return f"[{output}] = reduceMean({x}, axes=[{axes_str}], keepDimensions={'true' if keep else 'false'});" + return f"[{output}] = reduceMean({x});" - # Build full starts and sizes for all dimensions - starts = [start] + [0] * (len(input_shape) - 1) - sizes = [stop - start] + input_shape[1:] + # --- Arithmetic --- - starts_str = ', '.join(map(str, starts)) - sizes_str = ', '.join(map(str, sizes)) + def _make_arithmetic(self, op: str, node: fx.Node, output: str, inputs: List[str]) -> str: + if len(inputs) == 2: + return f"[{output}] = {op}({inputs[0]}, {inputs[1]});" + if len(inputs) == 1: + # scalar second operand + for arg in node.args: + if isnumeric(arg) and not isinstance(arg, fx.Node): + const = self._create_inline_constant(arg) + return f"[{output}] = {op}({inputs[0]}, {const});" + raise NotImplementedError(f"Invalid {op} operation: inputs={inputs} args={node.args}") - return f'[{output}] = slice({input_tensor}, starts=[{starts_str}], sizes=[{sizes_str}]);' + def _convert_add(self, node, output, inputs): + return self._make_arithmetic("add", node, output, inputs) - # Unknown index type - raise NotImplementedError(f"getitem with index type {type(index).__name__} not supported yet") + def _convert_sub(self, node, output, inputs): + return self._make_arithmetic("sub", node, output, inputs) - def _convert_rearrange(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert einops rearrange to WebNN reshape/transpose operations. - - Algorithm: - 1. Parse both sides of the '->' pattern into groups of elementary axes. - 2. Resolve each elementary axis size from kwargs, then lhs→input_shape, - then rhs→output_shape. - 3. Build the permutation that maps lhs elementary-axis order to rhs order. - 4. Emit up to three ops (skipping no-ops): - reshape (expand merged input axes) - transpose (reorder axes) - reshape (collapse merged output axes) - """ - input_tensor = inputs[0] if inputs else 'unknown' + def _convert_mul(self, node, output, inputs): + return self._make_arithmetic("mul", node, output, inputs) - if not isinstance(node.args[1], str): - raise NotImplementedError("Rearrange only supports string patterns") + def _convert_div(self, node, output, inputs): + return self._make_arithmetic("div", node, output, inputs) - pattern = node.args[1] - # kwargs carry axis sizes like pi=2, pj=2 - kwargs = {k: int(v) for k, v in node.kwargs.items() if isinstance(v, int)} + def _convert_neg(self, node: fx.Node, output: str, inputs: List[str]) -> str: + return f"[{output}] = neg({inputs[0] if inputs else 'unknown'});" - input_shape = self._get_node_shape(node.args[0]) - output_shape = self._get_node_shape(node) + def _convert_pow(self, node: fx.Node, output: str, inputs: List[str]) -> str: + if len(inputs) >= 2: + return f"[{output}] = pow({inputs[0]}, {inputs[1]});" + # Scalar exponent in args + x = inputs[0] if inputs else "unknown" + exp_val = node.args[1] if len(node.args) > 1 else None + if exp_val is not None and not isinstance(exp_val, fx.Node): + const = self._create_inline_constant(float(exp_val)) + return f"[{output}] = pow({x}, {const});" + raise NotImplementedError("pow: cannot determine exponent") + + def _convert_pow_scalar(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.pow.Scalar(scalar_base, tensor_exponent) — scalar ** tensor. + Decomposed as exp(log(base) * exponent). + """ + base = node.args[0] if node.args else 1.0 + x = inputs[0] if inputs else "unknown" + log_base = math.log(float(base)) + log_c = self._create_inline_constant(log_base) + scaled = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return f"[{scaled}] = mul({x}, {log_c});\n\t[{output}] = exp({scaled});" - if not input_shape or not output_shape: - raise NotImplementedError( - f"Rearrange requires static shapes, got input={input_shape}, output={output_shape}" - ) + # --- Elementwise math --- - lhs_str, rhs_str = pattern.split('->') - - def parse_side(s: str): - """Return a list of groups; each group is a list of axis name strings.""" - groups = [] - s = s.strip() - i = 0 - while i < len(s): - if s[i] == '(': - j = s.index(')', i) - inner = s[i + 1:j].strip().split() - groups.append(inner) - i = j + 1 - elif s[i] == ' ': - i += 1 - else: - j = i - while j < len(s) and s[j] not in (' ', '(', ')'): - j += 1 - token = s[i:j] - if token: - groups.append([token]) - i = j - return groups - - lhs_groups = parse_side(lhs_str) - rhs_groups = parse_side(rhs_str) - - lhs_axes = [ax for group in lhs_groups for ax in group] - rhs_axes = [ax for group in rhs_groups for ax in group] - - # Resolve axis sizes ----------------------------------------------- - # Any axis name that is a plain integer string is a literal size. - def is_literal(ax: str) -> bool: - return ax.lstrip('-').isdigit() - - axis_sizes = {ax: int(ax) for ax in lhs_axes + rhs_axes if is_literal(ax)} - axis_sizes.update(kwargs) - - for group, dim_size in zip(lhs_groups, input_shape): - if len(group) == 1: - axis_sizes[group[0]] = dim_size - else: - known = math.prod(axis_sizes[ax] for ax in group if ax in axis_sizes) - unknown = [ax for ax in group if ax not in axis_sizes] - if len(unknown) == 1: - axis_sizes[unknown[0]] = dim_size // known - - for group, dim_size in zip(rhs_groups, output_shape): - if len(group) == 1: - axis_sizes.setdefault(group[0], dim_size) - else: - known = math.prod(axis_sizes.get(ax, 1) for ax in group if ax in axis_sizes) - unknown = [ax for ax in group if ax not in axis_sizes] - if len(unknown) == 1: - axis_sizes[unknown[0]] = dim_size // known - - # Expanded (flat) shapes ------------------------------------------- - expanded_input_shape = [axis_sizes[ax] for ax in lhs_axes] - - # Permutation is built only over rhs axes that come from lhs. - # Axes in rhs but not in lhs (e.g. literal '1' insertions) are - # excluded here; the final reshape-to-output_shape inserts them. - lhs_axis_idx = {ax: i for i, ax in enumerate(lhs_axes)} - rhs_real_axes = [ax for ax in rhs_axes if ax in lhs_axis_idx and not is_literal(ax)] - permutation = [lhs_axis_idx[ax] for ax in rhs_real_axes] - - needs_expand = expanded_input_shape != list(input_shape) - needs_transpose = permutation != list(range(len(permutation))) - # Shape after expand + transpose — may differ from output_shape when - # literal '1' dimensions are inserted or merged groups are collapsed. - post_transpose_shape = [axis_sizes[ax] for ax in rhs_real_axes] - needs_collapse = post_transpose_shape != list(output_shape) - - if not needs_expand and not needs_transpose and not needs_collapse: - return f'[{output}] = identity({input_tensor});' - - # Build op sequence ------------------------------------------------ - ops = [] - if needs_expand: - ops.append(('reshape', expanded_input_shape)) - if needs_transpose: - ops.append(('transpose', permutation)) - if needs_collapse: - ops.append(('reshape', output_shape)) + def _convert_math_sqrt(self, node, output, inputs): + return f"[{output}] = sqrt({inputs[0] if inputs else 'unknown'});" - steps = [] - current = input_tensor - for idx, (op_type, param) in enumerate(ops): - is_last = (idx == len(ops) - 1) - out_name = output if is_last else f'operand_{self.operand_counter}' - if not is_last: - self.operand_counter += 1 + def _convert_math_exp(self, node, output, inputs): + return f"[{output}] = exp({inputs[0] if inputs else 'unknown'});" - if op_type == 'reshape': - shape_str = ', '.join(map(str, param)) - steps.append(f'[{out_name}] = reshape({current}, newShape=[{shape_str}]);') - else: # transpose - perm_str = ', '.join(map(str, param)) - steps.append(f'[{out_name}] = transpose({current}, permutation=[{perm_str}]);') + def _convert_math_abs(self, node, output, inputs): + return f"[{output}] = abs({inputs[0] if inputs else 'unknown'});" - current = out_name + def _convert_math_log(self, node, output, inputs): + return f"[{output}] = log({inputs[0] if inputs else 'unknown'});" - return '\n '.join(steps) + def _convert_math_cos(self, node, output, inputs): + return f"[{output}] = cos({inputs[0] if inputs else 'unknown'});" - def _convert_arange(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert arange to a pre-computed constant. - Since WebNN doesn't have arange, we compute it at export time and add as a constant. - """ - # Extract parameters from kwargs - kwargs = node.kwargs - start = kwargs.get('start', 0) - end = kwargs.get('end', None) - step = kwargs.get('step', 1) - dtype = kwargs.get('dtype', torch.float32) - - # Handle positional args if no kwargs - if end is None and node.args: - if len(node.args) == 1: - end = node.args[0] - elif len(node.args) >= 2: - start = node.args[0] - end = node.args[1] - if len(node.args) >= 3: - step = node.args[2] - - if end is None: - raise ValueError(f"arange requires 'end' parameter: {node}") - - # Generate arange values - values = torch.arange(start, end, step, dtype=dtype) + def _convert_math_sin(self, node, output, inputs): + return f"[{output}] = sin({inputs[0] if inputs else 'unknown'});" - # Store as a generated constant - const_name = f'const_arange_{self.operand_counter}' + def _convert_rsqrt(self, node, output, inputs): + """aten.rsqrt.default(x) = 1 / sqrt(x)""" + x = inputs[0] if inputs else "unknown" + sqrt_op = f"operand_{self.operand_counter}" self.operand_counter += 1 + one_c = self._create_inline_constant(1.0) + return f"[{sqrt_op}] = sqrt({x});\n\t[{output}] = div({one_c}, {sqrt_op});" - self.inline_constants[const_name] = values - self.operand_shapes[const_name] = list(values.shape) - - # Map this node to the constant operand - self.node_to_operand[node.name] = const_name - - def _convert_einsum(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert einsum (Einstein summation) to WebNN operations. + def _convert_reciprocal(self, node, output, inputs): + """aten.reciprocal.default(x) = 1 / x""" + x = inputs[0] if inputs else "unknown" + one_c = self._create_inline_constant(1.0) + return f"[{output}] = div({one_c}, {x});" - Einsum is a powerful operation that can express many tensor operations - through Einstein notation. Common patterns: - - Matrix multiplication: 'ij,jk->ik' - - Batch matrix multiply: 'bij,bjk->bik' - - Outer product: 'i,j->ij' - - Broadcasting: '...n,d->...nd' + # --- Shape manipulation --- - This implementation handles common patterns. Complex patterns may need - additional decomposition. - """ - # Get einsum pattern from args - # einsum(pattern, *tensors) - args = node.args - if not args: - raise NotImplementedError('Invalid einsum: no arguments') - - pattern = args[0] if isinstance(args[0], str) else None - if not pattern: - raise NotImplementedError('Invalid einsum: pattern not found') - - # Get input shapes - input_shapes = [] - for i, inp in enumerate(inputs): - if i < len(args) - 1: # Skip the pattern - node_arg = args[i + 1] - if isinstance(node_arg, fx.Node): - shape = self._get_node_shape(node_arg) - input_shapes.append(shape) - - # Parse the einsum pattern - # Pattern format: 'input1,input2,...->output' - if '->' not in pattern: - return f'// Unsupported einsum pattern (no ->): {pattern}' - - lhs, rhs = pattern.split('->') - input_patterns = [p.strip() for p in lhs.split(',')] - - # Handle common patterns - # Pattern: '...n,d->...nd' (outer product with broadcasting) - if len(input_patterns) == 2 and pattern.endswith('nd') and input_patterns[0].endswith('n') and input_patterns[1] == 'd': - # This is: [..., n] x [d] -> [..., n, d] - # Can be implemented as unsqueeze + broadcast - input1 = inputs[0] if len(inputs) > 0 else 'unknown' - input2 = inputs[1] if len(inputs) > 1 else 'unknown' - - if not input_shapes or len(input_shapes) < 2: - return f'// einsum {pattern}: shape information needed' - - shape1 = input_shapes[0] - shape2 = input_shapes[1] - - # Step 1: Unsqueeze input2 to add a dimension - # [d] -> [1, d] - unsqueezed2 = f'operand_{self.operand_counter}' - self.operand_counter += 1 - unsqueeze_shape = [1] + shape2 - unsqueeze_str = ', '.join(map(str, unsqueeze_shape)) - step1 = f'[{unsqueezed2}] = reshape({input2}, newShape=[{unsqueeze_str}]);' + def _convert_reshape(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" - # Step 2: Unsqueeze input1 to add dimension for broadcasting - # [..., n] -> [..., n, 1] - unsqueezed1 = f'operand_{self.operand_counter}' - self.operand_counter += 1 - unsqueeze1_shape = shape1 + [1] - unsqueeze1_str = ', '.join(map(str, unsqueeze1_shape)) - step2 = f'[{unsqueezed1}] = reshape({input1}, newShape=[{unsqueeze1_str}]);' - - # Step 3: Multiply (will broadcast automatically) - # [..., n, 1] * [1, d] -> [..., n, d] - step3 = f'[{output}] = mul({unsqueezed1}, {unsqueezed2});' - - return f'{step1}\n {step2}\n {step3}' - - # Pattern: 'ij,jk->ik' (matrix multiplication) - elif len(input_patterns) == 2 and len(input_patterns[0]) == 2 and len(input_patterns[1]) == 2 and len(rhs) == 2: - # Standard matrix multiplication - input1 = inputs[0] if len(inputs) > 0 else 'unknown' - input2 = inputs[1] if len(inputs) > 1 else 'unknown' - return f'[{output}] = matmul({input1}, {input2});' - - # Pattern: 'bij,bjk->bik' (batch matrix multiplication) - elif len(input_patterns) == 2 and len(input_patterns[0]) == 3 and len(input_patterns[1]) == 3 and len(rhs) == 3: - # Batch matrix multiplication - input1 = inputs[0] if len(inputs) > 0 else 'unknown' - input2 = inputs[1] if len(inputs) > 1 else 'unknown' - return f'[{output}] = matmul({input1}, {input2});' - - # Pattern: 'i,j->ij' (outer product) - elif len(input_patterns) == 2 and len(input_patterns[0]) == 1 and len(input_patterns[1]) == 1 and len(rhs) == 2: - input1 = inputs[0] if len(inputs) > 0 else 'unknown' - input2 = inputs[1] if len(inputs) > 1 else 'unknown' - - # Outer product: reshape to [n, 1] and [1, m], then multiply - if not input_shapes or len(input_shapes) < 2: - return f'// einsum {pattern}: shape information needed' - - shape1 = input_shapes[0] - shape2 = input_shapes[1] - - # Reshape input1 to [n, 1] - reshaped1 = f'operand_{self.operand_counter}' - self.operand_counter += 1 - reshape1_shape = shape1 + [1] - reshape1_str = ', '.join(map(str, reshape1_shape)) - step1 = f'[{reshaped1}] = reshape({input1}, newShape=[{reshape1_str}]);' + # Prefer static output shape from FX metadata + meta_shape = self._get_node_shape(node) + if meta_shape: + return f"[{output}] = reshape({x}, newShape=[{', '.join(map(str, meta_shape))}]);" - # Reshape input2 to [1, m] - reshaped2 = f'operand_{self.operand_counter}' - self.operand_counter += 1 - reshape2_shape = [1] + shape2 - reshape2_str = ', '.join(map(str, reshape2_shape)) - step2 = f'[{reshaped2}] = reshape({input2}, newShape=[{reshape2_str}]);' + # aten.flatten.using_ints(input, start_dim, end_dim) + if len(node.args) >= 2 and isinstance(node.args[1], int): + start_dim = int(node.args[1]) + end_dim = int(node.args[2]) if len(node.args) > 2 and isinstance(node.args[2], int) else -1 + if node.args and isinstance(node.args[0], fx.Node): + in_shape = self._get_node_shape(node.args[0]) + if in_shape: + rank = len(in_shape) + if end_dim < 0: + end_dim += rank + if 0 <= start_dim <= end_dim < rank: + flat_dim = math.prod(in_shape[start_dim:end_dim + 1]) + new_shape = in_shape[:start_dim] + [int(flat_dim)] + in_shape[end_dim + 1:] + return f"[{output}] = reshape({x}, newShape=[{', '.join(map(str, new_shape))}]);" - # Multiply - step3 = f'[{output}] = mul({reshaped1}, {reshaped2});' + # aten.reshape / aten.view: second arg is the shape list + if len(node.args) > 1: + new_shape = node.args[1] + if isinstance(new_shape, (list, tuple)): + return f"[{output}] = reshape({x}, newShape=[{', '.join(map(str, new_shape))}]);" - return f'{step1}\n {step2}\n {step3}' + raise NotImplementedError(f"Cannot determine reshape target shape for {node}") - # Pattern: 'ii->i' (diagonal extraction) - elif len(input_patterns) == 1 and len(input_patterns[0]) == 2 and len(rhs) == 1: - # Diagonal extraction - not directly supported, would need gather - raise NotImplementedError(f"einsum diagonal extraction pattern ({pattern}) not supported yet") + def _convert_permute(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.permute.default(input, dims_list) — dims is a single list arg.""" + x = inputs[0] if inputs else "unknown" + dims = node.args[1] if len(node.args) > 1 else None + if isinstance(dims, (list, tuple)): + perm_str = ", ".join(map(str, dims)) + return f"[{output}] = transpose({x}, permutation=[{perm_str}]);" + raise NotImplementedError(f"permute: cannot determine dims from {node.args}") - # Pattern: 'ij->ji' (transpose) - elif len(input_patterns) == 1 and len(input_patterns[0]) == 2 and len(rhs) == 2: - # Check if it's a transpose - if input_patterns[0][0] == rhs[1] and input_patterns[0][1] == rhs[0]: - input1 = inputs[0] if len(inputs) > 0 else 'unknown' - return f'[{output}] = transpose({input1}, permutation=[1, 0]);' + def _convert_transpose(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.transpose.int(input, dim0, dim1) — swaps two dims.""" + x = inputs[0] if inputs else "unknown" + if len(node.args) >= 3: + dim0, dim1 = int(node.args[1]), int(node.args[2]) + in_shape = self._get_node_shape(node.args[0]) if isinstance(node.args[0], fx.Node) else [] + rank = len(in_shape) + if rank == 0: + raise NotImplementedError("transpose: unknown rank") + perm = list(range(rank)) + d0 = dim0 % rank + d1 = dim1 % rank + perm[d0], perm[d1] = perm[d1], perm[d0] + return f"[{output}] = transpose({x}, permutation=[{', '.join(map(str, perm))}]);" + return f"[{output}] = transpose({x});" + + def _convert_unsqueeze(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + in_shape = self._get_node_shape(node.args[0]) if isinstance(node.args[0], fx.Node) else [] + dim = int(node.args[1]) if len(node.args) > 1 else 0 + rank = len(in_shape) + if dim < 0: + dim = rank + 1 + dim + new_shape = list(in_shape[:dim]) + [1] + list(in_shape[dim:]) + return f"[{output}] = reshape({x}, newShape=[{', '.join(map(str, new_shape))}]);" + + def _convert_squeeze(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + out_shape = self._get_node_shape(node) + if out_shape: + return f"[{output}] = reshape({x}, newShape=[{', '.join(map(str, out_shape))}]);" + return f"[{output}] = identity({x});" - # Unknown pattern - raise NotImplementedError(f"einsum pattern not yet supported: {pattern}") - - def _convert_scaled_dot_product_attention(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert scaled dot product attention to WebNN operations. - - Implements: Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V - - This is the core attention mechanism used in transformers and - was introduced as an optimized primitive in PyTorch 2.0. - - Args: - Q: Query tensor of shape [batch, num_heads, seq_len_q, head_dim] - K: Key tensor of shape [batch, num_heads, seq_len_k, head_dim] - V: Value tensor of shape [batch, num_heads, seq_len_v, head_dim] - - Returns: - Output tensor of shape [batch, num_heads, seq_len_q, head_dim] - """ - if len(inputs) < 3: - raise NotImplementedError('Invalid scaled_dot_product_attention: need Q, K, V inputs') - - Q = inputs[0] - K = inputs[1] - V = inputs[2] - - # Get shapes to calculate scaling factor - q_shape = self._get_node_shape(node.args[0]) if len(node.args) > 0 and isinstance(node.args[0], fx.Node) else [] - - # head_dim is the last dimension of Q - head_dim = q_shape[-1] if q_shape else 64 # default to 64 if unknown - - # Calculate scaling factor: 1 / sqrt(head_dim) - import math - scale_factor = 1.0 / math.sqrt(head_dim) - - steps = [] - - # Step 1: Transpose K to get K^T - # K shape: [batch, num_heads, seq_len_k, head_dim] - # K^T shape: [batch, num_heads, head_dim, seq_len_k] - # We transpose the last two dimensions: [... -2, -1] -> [... -1, -2] - k_transposed = f'operand_{self.operand_counter}' - self.operand_counter += 1 - - if len(node.args) > 1 and isinstance(node.args[1], fx.Node): - k_shape = self._get_node_shape(node.args[1]) - if k_shape and len(k_shape) >= 2: - # Build permutation for transpose - # For 4D: [0, 1, 2, 3] -> [0, 1, 3, 2] - perm = list(range(len(k_shape))) - perm[-2], perm[-1] = perm[-1], perm[-2] - perm_str = ', '.join(map(str, perm)) - steps.append(f'[{k_transposed}] = transpose({K}, permutation=[{perm_str}]);') - else: - steps.append(f'[{k_transposed}] = transpose({K}, permutation=[0, 1, 3, 2]);') - else: - steps.append(f'[{k_transposed}] = transpose({K}, permutation=[0, 1, 3, 2]);') - - # Step 2: Compute Q @ K^T - qk = f'operand_{self.operand_counter}' - self.operand_counter += 1 - steps.append(f'[{qk}] = matmul({Q}, {k_transposed});') - - # Step 3: Scale by 1/sqrt(head_dim) - # Create scale constant - scale_const = f'const_scale_{self.operand_counter}' - self.operand_counter += 1 - import torch - scale_tensor = torch.tensor(scale_factor, dtype=torch.float32) - self.inline_constants[scale_const] = scale_tensor.item() - - qk_scaled = f'operand_{self.operand_counter}' - self.operand_counter += 1 - steps.append(f'[{qk_scaled}] = mul({qk}, {scale_const});') - - # Step 4: Apply softmax along the last dimension - # softmax is computed over the key dimension (last dim) - attention_weights = f'operand_{self.operand_counter}' - self.operand_counter += 1 - steps.append(f'[{attention_weights}] = softmax({qk_scaled}, axis=-1);') - - # Step 5: Multiply attention weights with V - # attention_weights @ V - steps.append(f'[{output}] = matmul({attention_weights}, {V});') - - return '\n '.join(steps) - - def _convert_interpolate(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert interpolate (upsampling/downsampling) to WebNN operations. - - PyTorch's interpolate supports various modes: - - 'nearest': Nearest neighbor interpolation - - 'linear', 'bilinear', 'trilinear': Linear interpolation - - 'bicubic': Bicubic interpolation - - WebNN has resample2d for this operation. - """ - input_tensor = inputs[0] if inputs else 'unknown' - - # Get parameters - kwargs = node.kwargs - scale_factor = kwargs.get('scale_factor', None) - size = kwargs.get('size', None) - mode = kwargs.get('mode', 'nearest') - align_corners = kwargs.get('align_corners', None) - - # Get input shape - input_shape = self._get_node_shape(node.args[0]) if len(node.args) > 0 and isinstance(node.args[0], fx.Node) else [] - - if scale_factor is not None: - # Use scale factor - if isinstance(scale_factor, (int, float)): - # Same scale for all spatial dimensions - scales = [scale_factor, scale_factor] - elif isinstance(scale_factor, (list, tuple)): - scales = list(scale_factor) - else: - scales = [2.0, 2.0] # default - - scales_str = ', '.join(map(str, scales)) - - # WebNN mode mapping - webnn_mode = 'nearest-neighbor' if mode == 'nearest' else 'linear' - - return f'[{output}] = resample2d({input_tensor}, mode="{webnn_mode}", scales=[{scales_str}]);' - - elif size is not None: - # Use target size - if isinstance(size, (list, tuple)): - target_size = list(size) - else: - target_size = [size, size] - - # Calculate scales from target size - if input_shape and len(input_shape) >= 4: - # Input shape: [N, C, H, W] - current_h, current_w = input_shape[-2:] - target_h, target_w = target_size - scale_h = target_h / current_h - scale_w = target_w / current_w - scales_str = f'{scale_h}, {scale_w}' - else: - # Without shape info, use sizes directly - size_str = ', '.join(map(str, target_size)) - return f'[{output}] = resample2d({input_tensor}, sizes=[{size_str}]);' - - webnn_mode = 'nearest-neighbor' if mode == 'nearest' else 'linear' - return f'[{output}] = resample2d({input_tensor}, mode="{webnn_mode}", scales=[{scales_str}]);' + def _convert_concat(self, node: fx.Node, output: str, inputs: List[str]) -> str: + axis = 0 + if "dim" in node.kwargs: + axis = node.kwargs["dim"] + elif len(node.args) > 1 and not isinstance(node.args[1], fx.Node): + axis = node.args[1] + if inputs: + return f"[{output}] = concat([{', '.join(inputs)}], axis={axis});" + if len(node.args) >= 1 and isinstance(node.args[0], (list, tuple)): + ops = ", ".join(self._get_input_operand(n) for n in node.args[0] if isinstance(n, fx.Node)) + return f"[{output}] = concat([{ops}], axis={axis});" + raise NotImplementedError("concat: no inputs") + def _convert_stack(self, node: fx.Node, output: str, inputs: List[str]) -> str: + # dim may be in kwargs or in node.args[1] + if "dim" in node.kwargs: + dim = node.kwargs["dim"] + elif len(node.args) > 1 and isinstance(node.args[1], int): + dim = node.args[1] else: - return f'// interpolate: need either scale_factor or size' - - def _convert_relu(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert ReLU to WebNN clamp""" - input_tensor = inputs[0] if inputs else 'unknown' - return f'[{output}] = clamp({input_tensor}, minValue=0.0);' - - def _convert_clamp(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert clamp to WebNN clamp""" - input_tensor = inputs[0] if inputs else 'unknown' - args = node.args - min_val = args[1] if len(args) > 1 else None - max_val = args[2] if len(args) > 2 else None - - params = [] - if min_val is not None: - params.append(f'minValue={min_val}') - if max_val is not None: - params.append(f'maxValue={max_val}') - - params_str = ', '.join(params) - return f'[{output}] = clamp({input_tensor}, {params_str});' - - def _convert_hardtanh(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert hardtanh to WebNN clamp (typically ReLU6)""" - input_tensor = inputs[0] if inputs else 'unknown' - args = node.args - # hardtanh(input, min_val, max_val) - min_val = args[1] if len(args) > 1 else 0.0 - max_val = args[2] if len(args) > 2 else 6.0 - - return f'[{output}] = clamp({input_tensor}, maxValue={max_val}, minValue={min_val});' - - def _convert_addmm(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert addmm to WebNN gemm + add. - - torch.addmm(bias, mat1, mat2) = bias + mat1 @ mat2 - mat1: (M, K) — input - mat2: (K, N) — weight, already mm-ready (no transpose) - bias: (N,) or broadcastable - """ - if len(inputs) < 3: - return self._convert_identity(node, output, inputs) - - bias = inputs[0] - mat1 = inputs[1] - mat2 = inputs[2] - - stmts = [] - - # Flatten mat1 to rank-2 if needed (e.g. coming from a conv feature map) - input_node = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else None - input_shape = self._get_node_shape(input_node) if input_node is not None else [] - if input_shape and len(input_shape) != 2: - batch = int(input_shape[0]) - features = int(math.prod(input_shape[1:])) - tmp = f'operand_{self.operand_counter}' + dim = 0 + + # inputs may be empty when the tensor list is packed as node.args[0] + tensors = node.args[0] if isinstance(node.args[0], (list, tuple)) else [] + tensor_inputs = [self._get_input_operand(n) for n in tensors if isinstance(n, fx.Node)] + if not tensor_inputs: + tensor_inputs = inputs + if not tensor_inputs: + raise NotImplementedError("stack: no inputs") + + first = tensors[0] if tensors and isinstance(tensors[0], fx.Node) else None + in_shape = self._get_node_shape(first) if first else [] + steps = [] + unsqueezed = [] + for inp in tensor_inputs: + us = f"operand_{self.operand_counter}" self.operand_counter += 1 - stmts.append(f'[{tmp}] = reshape({mat1}, newShape=[{batch}, {features}]);') - mat1 = tmp - - # gemm(mat1, mat2) = mat1 @ mat2 — mat2 is already (K, N), no bTranspose - # Add bias separately to avoid WebNN runtime broadcast issues with 1-D c. - mm_out = f'operand_{self.operand_counter}' - self.operand_counter += 1 - stmts.append(f'[{mm_out}] = gemm({mat1}, {mat2});') - stmts.append(f'[{output}] = add({mm_out}, {bias});') - - return '\n '.join(stmts) - - def _convert_matmul(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert matmul to WebNN gemm, reshaping rank-1 inputs to rank-2.""" - if len(inputs) < 2: - raise ValueError("Gemm requires 2 inputs") - - a, b = inputs[0], inputs[1] - stmts = [] - - # Resolve shapes for both inputs - a_node = node.args[0] if len(node.args) > 0 and isinstance(node.args[0], fx.Node) else None - b_node = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else None - a_shape = self._get_node_shape(a_node) if a_node is not None else self.operand_shapes.get(a, []) - b_shape = self._get_node_shape(b_node) if b_node is not None else self.operand_shapes.get(b, []) - out_shape = self._get_node_shape(node) + unsqueezed.append(us) + if in_shape: + pos = dim if dim >= 0 else len(in_shape) + 1 + dim + new_shape = list(in_shape[:pos]) + [1] + list(in_shape[pos:]) + steps.append(f"[{us}] = reshape({inp}, newShape=[{', '.join(map(str, new_shape))}]);") + else: + steps.append(f"[{us}] = reshape({inp}, newShape=[1]);") + concat_dim = dim if dim >= 0 else len(in_shape) + 1 + dim + steps.append(f"[{output}] = concat([{', '.join(unsqueezed)}], axis={concat_dim});") + return "\n\t".join(steps) - # Reshape rank-1 inputs to rank-2 (gemm requires 2D operands) - if a_shape and len(a_shape) == 1: - tmp = f'operand_{self.operand_counter}' + def _convert_split(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + sections = node.args[1] if len(node.args) > 1 else None + dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + if isinstance(sections, (list, tuple)): + # Multi-output split: pre-allocate one operand per section + out_ops = [] + for _ in sections: + op = f"operand_{self.operand_counter}" + self.operand_counter += 1 + out_ops.append(op) + self.multi_output_operands[node.name] = out_ops + return f"[{', '.join(out_ops)}] = split({x}, splits=[{', '.join(map(str, sections))}], axis={dim});" + # Even-split: need shape to compute sizes + in_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] + if in_shape and sections is not None: + dim_n = int(dim) % len(in_shape) + dim_size = in_shape[dim_n] + n = int(sections) + base = dim_size // n + rem = dim_size % n + sizes = [base + (1 if i < rem else 0) for i in range(n)] + out_ops = [] + for _ in sizes: + op = f"operand_{self.operand_counter}" + self.operand_counter += 1 + out_ops.append(op) + self.multi_output_operands[node.name] = out_ops + return f"[{', '.join(out_ops)}] = split({x}, splits=[{', '.join(map(str, sizes))}], axis={dim});" + return f"[{output}] = split({x}, splits={sections}, axis={dim});" + + def _convert_chunk(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.chunk.default(tensor, chunks, dim) — decompose into slice ops per chunk.""" + x = inputs[0] if inputs else "unknown" + n_chunks = int(node.args[1]) if len(node.args) > 1 else 1 + dim = int(node.args[2]) if len(node.args) > 2 else 0 + in_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] + if not in_shape: + raise NotImplementedError(f"chunk: unknown input shape for {node.name}") + rank = len(in_shape) + dim = dim % rank + dim_size = in_shape[dim] + # Compute actual chunk sizes (last chunk may be smaller) + base = (dim_size + n_chunks - 1) // n_chunks + sizes = [] + remaining = dim_size + for _ in range(n_chunks): + s = min(base, remaining) + if s <= 0: + break + sizes.append(s) + remaining -= s + actual_n = len(sizes) + steps = [] + out_ops = [] + offset = 0 + for s in sizes: + op = f"operand_{self.operand_counter}" self.operand_counter += 1 - stmts.append(f'[{tmp}] = reshape({a}, newShape=[1, {a_shape[0]}]);') - a = tmp - - if b_shape and len(b_shape) == 1: - tmp = f'operand_{self.operand_counter}' + out_ops.append(op) + starts = [0] * rank + slice_sizes = list(in_shape) + starts[dim] = offset + slice_sizes[dim] = s + steps.append( + f"[{op}] = slice({x}, starts=[{', '.join(map(str, starts))}], " + f"sizes=[{', '.join(map(str, slice_sizes))}]);" + ) + offset += s + self.multi_output_operands[node.name] = out_ops + # Register the first output as this node's primary operand (satisfies downstream identity) + self.node_to_operand[node.name] = out_ops[0] + return "\n\t".join(steps) + + def _convert_unbind(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.unbind.int(tensor, dim) — split into individual tensors along dim.""" + x = inputs[0] if inputs else "unknown" + dim = int(node.args[1]) if len(node.args) > 1 else 0 + in_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] + if not in_shape: + raise NotImplementedError(f"unbind: unknown input shape for {node.name}") + rank = len(in_shape) + dim = dim % rank + dim_size = in_shape[dim] + out_shape = list(in_shape[:dim]) + list(in_shape[dim + 1:]) + steps = [] + out_ops = [] + for i in range(dim_size): + slice_op = f"operand_{self.operand_counter}" self.operand_counter += 1 - stmts.append(f'[{tmp}] = reshape({b}, newShape=[{b_shape[0]}, 1]);') - b = tmp - - # If the expected output is not rank-2, route gemm through an intermediate - # and reshape down (e.g. mat*vec produces a 1D result, not (1, n)) - needs_output_reshape = bool(out_shape) and len(out_shape) != 2 - gemm_out = output - if needs_output_reshape: - gemm_out = f'operand_{self.operand_counter}' + squeeze_op = f"operand_{self.operand_counter}" self.operand_counter += 1 + starts = [0] * rank + sizes = list(in_shape) + starts[dim] = i + sizes[dim] = 1 + steps.append( + f"[{slice_op}] = slice({x}, starts=[{', '.join(map(str, starts))}], " + f"sizes=[{', '.join(map(str, sizes))}]);" + ) + steps.append( + f"[{squeeze_op}] = reshape({slice_op}, newShape=[{', '.join(map(str, out_shape))}]);" + ) + out_ops.append(squeeze_op) + self.multi_output_operands[node.name] = out_ops + self.node_to_operand[node.name] = out_ops[0] + return "\n\t".join(steps) + + def _convert_getitem(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]: + """operator.getitem — index into multi-output results (chunk/unbind/split).""" + source = node.args[0] if node.args and isinstance(node.args[0], fx.Node) else None + idx = node.args[1] if len(node.args) > 1 else 0 + if source is not None and source.name in self.multi_output_operands: + operands = self.multi_output_operands[source.name] + if isinstance(idx, int) and 0 <= idx < len(operands): + # Alias this node to the pre-computed slice operand (no new op needed) + self.node_to_operand[node.name] = operands[idx] + return None + # Fallback: treat as identity of the source + if inputs: + return f"[{output}] = identity({inputs[0]});" + raise NotImplementedError(f"getitem: cannot resolve index {idx} for {node.name}") + + def _convert_select(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.select.int(tensor, dim, index) — select one slice along dim, removing it.""" + x = inputs[0] if inputs else "unknown" + dim = int(node.args[1]) if len(node.args) > 1 else 0 + index = int(node.args[2]) if len(node.args) > 2 else 0 + in_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] + if not in_shape: + raise NotImplementedError(f"select: unknown input shape for {node.name}") + rank = len(in_shape) + dim = dim % rank + if index < 0: + index = in_shape[dim] + index + starts = [0] * rank + sizes = list(in_shape) + starts[dim] = index + sizes[dim] = 1 + out_shape = list(in_shape[:dim]) + list(in_shape[dim + 1:]) + slice_op = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return ( + f"[{slice_op}] = slice({x}, starts=[{', '.join(map(str, starts))}], " + f"sizes=[{', '.join(map(str, sizes))}]);\n" + f"\t[{output}] = reshape({slice_op}, newShape=[{', '.join(map(str, out_shape))}]);" + ) - stmts.append(f'[{gemm_out}] = gemm({a}, {b});') - - if needs_output_reshape: - shape_str = ', '.join(str(d) for d in out_shape) - stmts.append(f'[{output}] = reshape({gemm_out}, newShape=[{shape_str}]);') - - return '\n '.join(stmts) - - def _convert_linear(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert linear function to WebNN gemm""" - # linear(input, weight, bias) - if len(inputs) >= 2: - input_tensor = inputs[0] - weight = inputs[1] - if weight in self.operand_shapes and len(self.operand_shapes[weight]) >= 2: - if len(inputs) >= 3: - bias = inputs[2] - # Use a separate add for the bias rather than gemm's c parameter, - # because the WebNN runtime does not reliably broadcast a 1-D c. - mm_out = f'operand_{self.operand_counter}' - self.operand_counter += 1 - return ( - f'[{mm_out}] = gemm({input_tensor}, {weight}, bTranspose=true);\n' - f'\t[{output}] = add({mm_out}, {bias});' - ) - else: - return f'[{output}] = gemm({input_tensor}, {weight}, bTranspose=true);' - - # TODO: other cases untested - raise NotImplementedError("This linear case is untested.") - input_node = node.args[0] if node.args and isinstance(node.args[0], fx.Node) else None - input_shape = self._get_node_shape(input_node) if input_node is not None else [] - # Gemm expects rank-2 input. Flatten when needed. - if input_shape and len(input_shape) != 2: - batch = int(input_shape[0]) if input_shape else 1 - features = int(math.prod(input_shape[1:])) if len(input_shape) > 1 else 1 - reshaped = f'operand_{self.operand_counter}' + def _convert_einsum(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.einsum.default(equation, operands) — decompose common patterns.""" + equation = node.args[0] if node.args else "" + tensors_arg = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], (list, tuple)) else [] + operands = [self._get_input_operand(n) for n in tensors_arg if isinstance(n, fx.Node)] + + # Pattern: '...n,d->...nd' (broadcast outer product: [...,n] x [d] -> [...,n,d]) + if equation == "...n,d->...nd" and len(operands) == 2: + a, b = operands[0], operands[1] + a_node = tensors_arg[0] if tensors_arg and isinstance(tensors_arg[0], fx.Node) else None + b_node = tensors_arg[1] if len(tensors_arg) > 1 and isinstance(tensors_arg[1], fx.Node) else None + a_shape = self._get_node_shape(a_node) if a_node else [] + out_shape = self._get_node_shape(node) + if not out_shape: + raise NotImplementedError(f"einsum '{equation}': cannot determine output shape") + rank_out = len(out_shape) + + def tmp(): + name = f"operand_{self.operand_counter}" self.operand_counter += 1 - reshape_stmt = f'[{reshaped}] = reshape({input_tensor}, newShape=[{batch}, {features}]);' - if len(inputs) >= 3: - bias = inputs[2] - gemm_stmt = f'[{output}] = gemm({reshaped}, {weight}, bTranspose=true, c={bias});' - else: - gemm_stmt = f'[{output}] = gemm({reshaped}, {weight}, bTranspose=true);' - return f'{reshape_stmt}\n {gemm_stmt}' - - def _convert_linear_module(self, node: fx.Node, module: torch.nn.Linear, output: str, inputs: List[str]) -> str: - """Convert Linear module to WebNN gemm""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Get weight and bias operands - weight_name = f'{node.target}.weight' - weight_operand = self.weight_operands.get(weight_name, 'unknown') - - if module.bias is not None: - bias_name = f'{node.target}.bias' - bias_operand = self.weight_operands.get(bias_name, 'unknown') - return f'[{output}] = gemm({input_tensor}, {weight_operand}, bTranspose=true, c={bias_operand});' - else: - return f'[{output}] = gemm({input_tensor}, {weight_operand}, bTranspose=true);' - - def _convert_global_avg_pool(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert global average pooling to WebNN reduceMean""" - input_tensor = inputs[0] if inputs else 'unknown' - # Global average pool is reduceMean over spatial dimensions (usually axes 2,3 for NCHW) - return f'[{output}] = reduceMean({input_tensor}, axes=[2, 3], keepDimensions=true);' - - def _convert_avg_pool2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert 2D average pooling to WebNN averagePool2d""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Extract parameters from args/kwargs - kernel_size = node.args[1] if len(node.args) > 1 else node.kwargs.get('kernel_size', [2, 2]) - stride = node.args[2] if len(node.args) > 2 else node.kwargs.get('stride', kernel_size) - padding = node.args[3] if len(node.args) > 3 else node.kwargs.get('padding', [0, 0]) - - # Ensure parameters are lists - if not isinstance(kernel_size, (list, tuple)): - kernel_size = [kernel_size, kernel_size] - if not isinstance(stride, (list, tuple)): - stride = [stride, stride] - if not isinstance(padding, (list, tuple)): - padding = [padding, padding, padding, padding] - elif len(padding) == 2: - padding = [padding[0], padding[0], padding[1], padding[1]] - - params = [] - params.append(f'windowDimensions=[{kernel_size[0]}, {kernel_size[1]}]') - if stride != [1, 1]: - params.append(f'strides=[{stride[0]}, {stride[1]}]') - if padding != [0, 0, 0, 0]: - params.append(f'padding=[{padding[0]}, {padding[1]}, {padding[2]}, {padding[3]}]') - - params_str = ', '.join(params) - return f'[{output}] = averagePool2d({input_tensor}, {params_str});' - - def _convert_max_pool2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert 2D max pooling to WebNN maxPool2d""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Extract parameters from args/kwargs - kernel_size = node.args[1] if len(node.args) > 1 else node.kwargs.get('kernel_size', [2, 2]) - stride = node.args[2] if len(node.args) > 2 else node.kwargs.get('stride', kernel_size) - padding = node.args[3] if len(node.args) > 3 else node.kwargs.get('padding', [0, 0]) - - # Ensure parameters are lists - if not isinstance(kernel_size, (list, tuple)): - kernel_size = [kernel_size, kernel_size] - if not isinstance(stride, (list, tuple)): - stride = [stride, stride] - if not isinstance(padding, (list, tuple)): - padding = [padding, padding, padding, padding] - elif len(padding) == 2: - padding = [padding[0], padding[0], padding[1], padding[1]] - - params = [] - params.append(f'windowDimensions=[{kernel_size[0]}, {kernel_size[1]}]') - if stride != [1, 1]: - params.append(f'strides=[{stride[0]}, {stride[1]}]') - if padding != [0, 0, 0, 0]: - params.append(f'padding=[{padding[0]}, {padding[1]}, {padding[2]}, {padding[3]}]') + return name + + # Unsqueeze a along last axis: [..., n] → [..., n, 1] + a_us = tmp() + a_us_shape = list(a_shape) + [1] + # Reshape b to [1, ..., 1, d] matching output rank + b_node_shape = self._get_node_shape(b_node) if b_node else [] + d = b_node_shape[0] if b_node_shape else out_shape[-1] + b_rs = tmp() + b_rs_shape = [1] * (rank_out - 1) + [d] + return "\n\t".join([ + f"[{a_us}] = reshape({a}, newShape=[{', '.join(map(str, a_us_shape))}]);", + f"[{b_rs}] = reshape({b}, newShape=[{', '.join(map(str, b_rs_shape))}]);", + f"[{output}] = mul({a_us}, {b_rs});", + ]) + + raise NotImplementedError(f"einsum: unsupported equation '{equation}'") - params_str = ', '.join(params) - return f'[{output}] = maxPool2d({input_tensor}, {params_str});' - - def _convert_reduce_mean(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert mean/reduce to WebNN reduceMean""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Get axes from args or kwargs - axes = None - keep_dims = True - - if 'dim' in node.kwargs: - axes = node.kwargs['dim'] - if not isinstance(axes, (list, tuple)): - axes = [axes] - elif 'axis' in node.kwargs: - axes = node.kwargs['axis'] - if not isinstance(axes, (list, tuple)): - axes = [axes] - elif len(node.args) > 1: - axes = node.args[1] - if not isinstance(axes, (list, tuple)): - axes = [axes] - - if 'keepdim' in node.kwargs: - keep_dims = node.kwargs['keepdim'] - elif 'keepdims' in node.kwargs: - keep_dims = node.kwargs['keepdims'] - - if axes: - axes_str = ', '.join(map(str, axes)) - keep_str = 'true' if keep_dims else 'false' - return f'[{output}] = reduceMean({input_tensor}, axes=[{axes_str}], keepDimensions={keep_str});' - else: - # Reduce over all axes - return f'[{output}] = reduceMean({input_tensor});' - - def _convert_transpose(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert transpose/permute to WebNN transpose""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Get permutation from args or kwargs - perm = None - if 'dims' in node.kwargs: - perm = node.kwargs['dims'] - elif len(node.args) > 1: - # For permute: permute(input, dim0, dim1, ...) - # For transpose: transpose(input, dim0, dim1) - if len(node.args) == 3: # transpose(input, dim0, dim1) - dim0, dim1 = node.args[1], node.args[2] - # Create permutation that swaps dim0 and dim1 - # Need to know rank from metadata - input_shape = self._get_node_shape(node.args[0]) if isinstance(node.args[0], fx.Node) else [] - rank = len(input_shape) - perm = list(range(rank)) - if dim0 < rank and dim1 < rank: - perm[dim0], perm[dim1] = perm[dim1], perm[dim0] - else: # permute(input, dim0, dim1, dim2, ...) - perm = list(node.args[1:]) - - if perm: - perm_str = ', '.join(map(str, perm)) - return f'[{output}] = transpose({input_tensor}, permutation=[{perm_str}]);' - else: - # Default transpose (reverse all dimensions) - return f'[{output}] = transpose({input_tensor});' - - def _convert_concat(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert concatenation to WebNN concat""" - # Get axis/dim - axis = node.kwargs.get('dim', 0) - if 'axis' in node.kwargs: - axis = node.kwargs['axis'] - elif len(node.args) > 1 and not isinstance(node.args[1], fx.Node): - axis = node.args[1] - - # Collect input tensors - # torch.cat takes a list/tuple as first argument - if len(inputs) > 0: - inputs_str = ', '.join(inputs) - return f'[{output}] = concat([{inputs_str}], axis={axis});' - else: - if len(node.args) >= 1 and isinstance(node.args[0], (list, tuple)): - concat_inputs = ",".join([self._get_operand_name(n) for n in node.args[0]]) - return f'[{output}] = concat([{concat_inputs}], axis={axis});' - raise NotImplementedError('Invalid concat: no inputs') - - def _convert_stack(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Convert stack to WebNN operations. - - Stack creates a new dimension and concatenates tensors along it. - For example: stack([a, b, c], dim=0) where a.shape=[2, 3] - results in output.shape=[3, 2, 3] - - Implementation: unsqueeze each input at the stack dimension, then concat + def _convert_slice(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.slice.Tensor(input, dim, start, end, step) + start / end can be None (meaning 0 / full dimension respectively). + end can also be 9223372036854775807 (sys.maxsize) meaning full dimension. """ - # Get dimension to stack along - dim = node.kwargs.get('dim', 0) - if 'axis' in node.kwargs: - dim = node.kwargs['axis'] - - if not inputs: - raise NotImplementedError('Invalid stack: no inputs') - - # Get shape of first input to understand dimensions - if len(node.args) > 0 and isinstance(node.args[0], (list, tuple)): - # Stack takes a list of tensors as first argument - first_tensor = node.args[0][0] if node.args[0] else None - if first_tensor and isinstance(first_tensor, fx.Node): - input_shape = self._get_node_shape(first_tensor) - else: - input_shape = [] - else: - input_shape = [] - - # Unsqueeze each input at the stack dimension - unsqueezed_operands = [] - steps = [] - - for i, inp in enumerate(inputs): - unsqueezed = f'operand_{self.operand_counter}' - self.operand_counter += 1 - unsqueezed_operands.append(unsqueezed) - - if input_shape: - # Calculate new shape with unsqueezed dimension - new_shape = list(input_shape) - # Insert 1 at the stack dimension - # Handle negative indexing - if dim < 0: - insert_pos = len(new_shape) + 1 + dim - else: - insert_pos = dim - new_shape.insert(insert_pos, 1) - - shape_str = ', '.join(map(str, new_shape)) - steps.append(f'[{unsqueezed}] = reshape({inp}, newShape=[{shape_str}]);') - else: - # Without shape info, still try to unsqueeze - # This might not work perfectly without knowing the actual shape - steps.append(f'[{unsqueezed}] = reshape({inp}, newShape=[..., 1, ...]);') - - # Concatenate all unsqueezed tensors along the stack dimension - unsqueezed_str = ', '.join(unsqueezed_operands) - # Handle negative indexing - if dim < 0 and input_shape: - concat_dim = len(input_shape) + 1 + dim - else: - concat_dim = dim - - steps.append(f'[{output}] = concat([{unsqueezed_str}], axis={concat_dim});') - - return '\n '.join(steps) - - def _convert_split(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert split to WebNN split""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Get split sizes or number of splits - split_size_or_sections = node.args[1] if len(node.args) > 1 else None - dim = node.kwargs.get('dim', 0) - if len(node.args) > 2: - dim = node.args[2] - - if isinstance(split_size_or_sections, (list, tuple)): - # Split into specific sizes - splits_str = ', '.join(map(str, split_size_or_sections)) - return f'[{output}] = split({input_tensor}, splits=[{splits_str}], axis={dim});' - else: - # Split into equal sections - return f'[{output}] = split({input_tensor}, splits={split_size_or_sections}, axis={dim});' + x = inputs[0] if inputs else "unknown" + args = node.args + dim = int(args[1]) if len(args) > 1 else 0 + start = args[2] if len(args) > 2 else None + end = args[3] if len(args) > 3 else None + step = int(args[4]) if len(args) > 4 and args[4] is not None else 1 + in_shape = self._get_node_shape(args[0]) if args and isinstance(args[0], fx.Node) else [] - def _convert_slice(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert slice to WebNN slice""" - input_tensor = inputs[0] if inputs else 'unknown' + if not in_shape: + return f"// slice: unknown input shape for {node.name}" - # torch.slice or tensor slicing - # Get parameters from args - dim = node.args[1] if len(node.args) > 1 else 0 - start = node.args[2] if len(node.args) > 2 else 0 - end = node.args[3] if len(node.args) > 3 else None - step = node.args[4] if len(node.args) > 4 else 1 + rank = len(in_shape) + dim = dim % rank + dim_size = in_shape[dim] - # Get input shape to calculate sizes - input_shape = self._get_node_shape(node.args[0]) if node.args and isinstance(node.args[0], fx.Node) else [] + # Resolve None / sentinel values + start = 0 if start is None else int(start) + end = dim_size if (end is None or end >= sys.maxsize) else min(int(end), dim_size) - if end is None and input_shape and dim < len(input_shape): - end = input_shape[dim] + # Handle negative indices + if start < 0: + start = max(0, dim_size + start) + if end < 0: + end = max(0, dim_size + end) - if end is not None: - size = (end - start) // step - # WebNN slice takes starts and sizes - # Build starts and sizes for all dimensions - starts = [0] * len(input_shape) - sizes = list(input_shape) - starts[dim] = start - sizes[dim] = size + size = max(0, (end - start + step - 1) // step) # ceiling div for step > 1 - starts_str = ', '.join(map(str, starts)) - sizes_str = ', '.join(map(str, sizes)) - return f'[{output}] = slice({input_tensor}, starts=[{starts_str}], sizes=[{sizes_str}]);' + # If the slice covers the full dimension it's a no-op + if start == 0 and size == dim_size and step == 1: + return f"[{output}] = identity({x});" - return f'// Slice with unknown dimensions' + starts = [0] * rank + sizes = list(in_shape) + starts[dim] = start + sizes[dim] = size + return f"[{output}] = slice({x}, starts=[{', '.join(map(str, starts))}], sizes=[{', '.join(map(str, sizes))}]);" def _convert_expand(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert expand to WebNN expand (broadcast)""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Get target shape from args - target_shape = node.args[1] if len(node.args) > 1 else None - - if target_shape: - if isinstance(target_shape, (list, tuple)): - shape_str = ', '.join(map(str, target_shape)) - else: - shape_str = str(target_shape) - return f'[{output}] = expand({input_tensor}, newShape=[{shape_str}]);' - - return f'// Expand with unknown shape' + x = inputs[0] if inputs else "unknown" + shape = node.args[1] if len(node.args) > 1 else None + if shape and isinstance(shape, (list, tuple)): + return f"[{output}] = expand({x}, newShape=[{', '.join(map(str, shape))}]);" + return f"// expand: unknown shape for {node.name}" + + def _convert_expand_as(self, node: fx.Node, output: str, inputs: List[str]) -> str: + x = inputs[0] if inputs else "unknown" + target_node = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else None + shape = self._get_node_shape(target_node) if target_node else [] + if shape: + return f"[{output}] = expand({x}, newShape=[{', '.join(map(str, shape))}]);" + return f"[{output}] = identity({x}); // expand_as: unknown target shape" def _convert_pad(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert pad to WebNN pad""" - input_tensor = inputs[0] if inputs else 'unknown' - - # Get padding from args + x = inputs[0] if inputs else "unknown" padding = node.args[1] if len(node.args) > 1 else [0, 0, 0, 0] - mode = node.kwargs.get('mode', 'constant') - value = node.kwargs.get('value', 0) - - # Convert padding format + mode = node.kwargs.get("mode", "constant") + value = node.kwargs.get("value", 0) if isinstance(padding, (list, tuple)): - # PyTorch padding is usually [left, right, top, bottom] for 2D - # WebNN expects [begin_0, end_0, begin_1, end_1, ...] - padding_str = ', '.join(map(str, padding)) + pad_str = ", ".join(map(str, padding)) + if mode == "constant": + return f'[{output}] = pad({x}, padding=[{pad_str}], mode="constant", value={value});' + return f'[{output}] = pad({x}, padding=[{pad_str}], mode="{mode}");' + return f"// pad: unknown parameters for {node.name}" - if mode == 'constant': - return f'[{output}] = pad({input_tensor}, padding=[{padding_str}], mode="constant", value={value});' - else: - return f'[{output}] = pad({input_tensor}, padding=[{padding_str}], mode="{mode}");' + # --- Upsampling --- - return f'// Pad with unknown parameters' + def _convert_upsample_nearest2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.upsample_nearest2d.vec(input, output_size, scale_factors) + aten.upsample_nearest2d.default(input, output_size, scales_h, scales_w) + """ + x = inputs[0] if inputs else "unknown" + args = node.args + target_str = str(node.target) + + if "vec" in target_str: + output_size = args[1] if len(args) > 1 else None # int[] or None + scale_factors = args[2] if len(args) > 2 else None # float[] or None + if scale_factors: + scales_str = ", ".join(str(s) for s in scale_factors) + return f'[{output}] = resample2d({x}, mode="nearest-neighbor", scales=[{scales_str}]);' + if output_size: + sizes_str = ", ".join(str(s) for s in output_size) + return f'[{output}] = resample2d({x}, mode="nearest-neighbor", sizes=[{sizes_str}]);' + else: + # .default(input, output_size, scales_h=None, scales_w=None) + output_size = args[1] if len(args) > 1 else None + scales_h = args[2] if len(args) > 2 else None + scales_w = args[3] if len(args) > 3 else None + if scales_h is not None and scales_w is not None: + return f'[{output}] = resample2d({x}, mode="nearest-neighbor", scales=[{scales_h}, {scales_w}]);' + if output_size: + sizes_str = ", ".join(str(s) for s in output_size) + return f'[{output}] = resample2d({x}, mode="nearest-neighbor", sizes=[{sizes_str}]);' + + raise NotImplementedError(f"upsample_nearest2d: need output_size or scale_factors, got args={args}") + + def _convert_upsample_bilinear2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.upsample_bilinear2d.vec / .default → resample2d with mode="linear" """ + x = inputs[0] if inputs else "unknown" + args = node.args + target_str = str(node.target) + + if "vec" in target_str: + output_size = args[1] if len(args) > 1 else None + scale_factors = args[3] if len(args) > 3 else None # arg[2] is align_corners + if scale_factors: + scales_str = ", ".join(str(s) for s in scale_factors) + return f'[{output}] = resample2d({x}, mode="linear", scales=[{scales_str}]);' + if output_size: + sizes_str = ", ".join(str(s) for s in output_size) + return f'[{output}] = resample2d({x}, mode="linear", sizes=[{sizes_str}]);' + else: + output_size = args[1] if len(args) > 1 else None + scales_h = args[3] if len(args) > 3 else None # arg[2] is align_corners + scales_w = args[4] if len(args) > 4 else None + if scales_h is not None and scales_w is not None: + return f'[{output}] = resample2d({x}, mode="linear", scales=[{scales_h}, {scales_w}]);' + if output_size: + sizes_str = ", ".join(str(s) for s in output_size) + return f'[{output}] = resample2d({x}, mode="linear", sizes=[{sizes_str}]);' - def _convert_tile(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert tile/repeat to WebNN tile""" - input_tensor = inputs[0] if inputs else 'unknown' + raise NotImplementedError(f"upsample_bilinear2d: need output_size or scale_factors, got args={args}") - # Get repetitions from args - reps = node.args[1] if len(node.args) > 1 else None + # --- Softmax / attention --- - if reps: - if isinstance(reps, (list, tuple)): - reps_str = ', '.join(map(str, reps)) - else: - reps_str = str(reps) - return f'[{output}] = tile({input_tensor}, repetitions=[{reps_str}]);' + def _convert_softmax(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.softmax.int(input, dim, half_to_float)""" + x = inputs[0] if inputs else "unknown" + axis = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + return f"[{output}] = softmax({x}, axis={axis});" - return f'// Tile with unknown repetitions' + def _convert_softmax_aten(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten._softmax.default(input, dim, half_to_float)""" + return self._convert_softmax(node, output, inputs) - def _convert_reshape(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """Convert reshape/view to WebNN reshape""" - input_tensor = inputs[0] if inputs else 'unknown' + def _convert_scaled_dot_product_attention(self, node: fx.Node, output: str, inputs: List[str]) -> str: + if len(inputs) < 3: + raise NotImplementedError("SDPA requires Q, K, V inputs") + Q, K, V = inputs[0], inputs[1], inputs[2] + q_shape = self._get_node_shape(node.args[0]) if isinstance(node.args[0], fx.Node) else [] + head_dim = q_shape[-1] if q_shape else 64 + scale = 1.0 / math.sqrt(head_dim) + scale_c = self._create_inline_constant(scale) + + def tmp(): + name = f"operand_{self.operand_counter}" + self.operand_counter += 1 + return name - # Prefer static output shape from FX metadata when available. - meta_shape = self._get_node_shape(node) - if meta_shape: - shape_str = ', '.join(map(str, meta_shape)) - return f'[{output}] = reshape({input_tensor}, newShape=[{shape_str}]);' + k_shape = self._get_node_shape(node.args[1]) if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else [] + if k_shape and len(k_shape) >= 2: + perm = list(range(len(k_shape))) + perm[-2], perm[-1] = perm[-1], perm[-2] + else: + perm = [0, 1, 3, 2] + perm_str = ", ".join(map(str, perm)) + + kt, qk, qk_sc, attn_w = tmp(), tmp(), tmp(), tmp() + return "\n\t".join([ + f"[{kt}] = transpose({K}, permutation=[{perm_str}]);", + f"[{qk}] = matmul({Q}, {kt});", + f"[{qk_sc}] = mul({qk}, {scale_c});", + f"[{attn_w}] = softmax({qk_sc}, axis=-1);", + f"[{output}] = matmul({attn_w}, {V});", + ]) + + # --- Compile-time constant generation --- + + def _convert_arange(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]: + """aten.arange.* — evaluate at export time, embed as @bytes(...) constant. + + Overload args: + arange.default(end) + arange.start(start, end) + arange.start_step(start, end, step) + dtype / device come in via kwargs. + """ + args = node.args + target_str = str(node.target) + if "start_step" in target_str: + start, end, step = args[0], args[1], args[2] + elif "start" in target_str: + start, end = args[0], args[1] + step = node.kwargs.get("step", 1) + else: + start, end, step = 0, args[0], 1 - # Handle flatten(input, start_dim, end_dim) style arguments. - if len(node.args) >= 2 and isinstance(node.args[1], int): - start_dim = int(node.args[1]) - end_dim = int(node.args[2]) if len(node.args) > 2 and isinstance(node.args[2], int) else -1 - if node.args and isinstance(node.args[0], fx.Node): - in_shape = self._get_node_shape(node.args[0]) - if in_shape: - rank = len(in_shape) - if end_dim < 0: - end_dim += rank - if 0 <= start_dim <= end_dim < rank: - flat_dim = math.prod(in_shape[start_dim:end_dim + 1]) - new_shape = in_shape[:start_dim] + [int(flat_dim)] + in_shape[end_dim + 1:] - shape_str = ', '.join(map(str, new_shape)) - return f'[{output}] = reshape({input_tensor}, newShape=[{shape_str}]);' + dtype = node.kwargs.get("dtype", torch.float32) or torch.float32 + values = torch.arange(start, end, step, dtype=dtype) - # Extract new shape from args - if len(node.args) > 1: - new_shape = node.args[1] - if isinstance(new_shape, (list, tuple)): - shape_str = ', '.join(map(str, new_shape)) - return f'[{output}] = reshape({input_tensor}, newShape=[{shape_str}]);' + const_name = f"const_arange_{self.operand_counter}" + self.operand_counter += 1 + self.inline_constants[const_name] = values + self.operand_shapes[const_name] = list(values.shape) + # Override so downstream ops reference the constant operand + self.node_to_operand[node.name] = const_name + return None # no entry emitted in nodes {} - # Last-resort: preserve rank or flatten using input metadata when available. - if node.args and isinstance(node.args[0], fx.Node): - in_shape = self._get_node_shape(node.args[0]) - if in_shape: - if len(in_shape) > 2: - batch = int(in_shape[0]) - features = int(math.prod(in_shape[1:])) - return f'[{output}] = reshape({input_tensor}, newShape=[{batch}, {features}]);' - shape_str = ', '.join(map(str, in_shape)) - return f'[{output}] = reshape({input_tensor}, newShape=[{shape_str}]);' + def _convert_full(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]: + """aten.full.default(size, fill_value, *, dtype) — bake into @bytes.""" + args = node.args + size = list(args[0]) if args else [] + fill_value = args[1] if len(args) > 1 else 0 + dtype = node.kwargs.get("dtype", torch.float32) or torch.float32 + values = torch.full(size, fill_value, dtype=dtype) + const_name = f"const_full_{self.operand_counter}" + self.operand_counter += 1 + self.inline_constants[const_name] = values + self.operand_shapes[const_name] = list(values.shape) + self.node_to_operand[node.name] = const_name + return None + + def _convert_full_zeros(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]: + """aten.zeros.default — bake into @bytes.""" + size = list(node.args[0]) if node.args else [] + dtype = node.kwargs.get("dtype", torch.float32) or torch.float32 + values = torch.zeros(size, dtype=dtype) + const_name = f"const_zeros_{self.operand_counter}" + self.operand_counter += 1 + self.inline_constants[const_name] = values + self.operand_shapes[const_name] = list(values.shape) + self.node_to_operand[node.name] = const_name + return None + + def _convert_full_ones(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]: + """aten.ones.default — bake into @bytes.""" + size = list(node.args[0]) if node.args else [] + dtype = node.kwargs.get("dtype", torch.float32) or torch.float32 + values = torch.ones(size, dtype=dtype) + const_name = f"const_ones_{self.operand_counter}" + self.operand_counter += 1 + self.inline_constants[const_name] = values + self.operand_shapes[const_name] = list(values.shape) + self.node_to_operand[node.name] = const_name + return None - return f'[{output}] = add({input_tensor}, {input_tensor});' + # --- No-ops / cast --- def _convert_identity(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """ - Emit a shape-preserving identity via reshape to keep the graph executable - when an op is not mapped yet. - """ - if len(inputs) == 1: - return f'[{output}] = identity({inputs[0]});' - raise NotImplementedError('Invalid identity operation') + if inputs: + return f"[{output}] = identity({inputs[0]});" + raise NotImplementedError("identity: no input") + + def _emit_cast(self, node: fx.Node, output: str, inputs: List[str], target_dtype: torch.dtype) -> str: + """Emit cast or identity depending on whether the WebNN dtype actually changes.""" + x = inputs[0] if inputs else "unknown" + input_node = node.args[0] if node.args and isinstance(node.args[0], fx.Node) else None + src_dtype = self._get_node_dtype(input_node) if input_node else torch.float32 + src = self._get_webnn_dtype(src_dtype) + tgt = self._get_webnn_dtype(target_dtype) + if src == tgt: + return f"[{output}] = identity({x});" + return f"[{output}] = cast({x}, type={tgt});" + + def _convert_cast(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten._to_copy.default / aten.to.dtype — dtype in args[1] or kwargs['dtype'].""" + x = inputs[0] if inputs else "unknown" + # aten.to.dtype: args = (input, dtype, ...) + # aten._to_copy: dtype lives in kwargs + target_dtype = None + if len(node.args) > 1 and isinstance(node.args[1], torch.dtype): + target_dtype = node.args[1] + if target_dtype is None: + target_dtype = node.kwargs.get("dtype") + if not isinstance(target_dtype, torch.dtype): + return f"[{output}] = identity({x});" + return self._emit_cast(node, output, inputs, target_dtype) + + def _convert_to_device(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.to.device(input, device, dtype, ...) — ignore device, cast dtype.""" + x = inputs[0] if inputs else "unknown" + # args = (input, device, dtype, non_blocking, copy, memory_format) + target_dtype = node.args[2] if len(node.args) > 2 else None + if not isinstance(target_dtype, torch.dtype): + return f"[{output}] = identity({x});" + return self._emit_cast(node, output, inputs, target_dtype) + + def _convert_type_as(self, node: fx.Node, output: str, inputs: List[str]) -> str: + """aten.type_as.default(input, other) — cast input to dtype of other.""" + x = inputs[0] if inputs else "unknown" + other_node = node.args[1] if len(node.args) > 1 and isinstance(node.args[1], fx.Node) else None + if other_node is None: + return f"[{output}] = identity({x});" + target_dtype = self._get_node_dtype(other_node) + return self._emit_cast(node, output, inputs, target_dtype) + + # ------------------------------------------------------------------ + # Output extraction + # ------------------------------------------------------------------ - def _create_inline_constant(self, value) -> str: - """ - Create an inline constant operand for a scalar value. - These constants are embedded directly in the .webnn file, not stored in safetensors. - """ - # Check if we already have this constant - for const_name, const_value in self.inline_constants.items(): - value_type = type(value) - if isinstance(value, torch.Tensor) and isinstance(const_value, torch.Tensor): - if torch.allclose(value, const_value, rtol=1e-5, atol=1e-8): - return const_name - elif isinstance(value, value_type) and isinstance(const_value, value_type): - return const_name - - # Create a new constant operand - const_name = f'const_scalar_{self.operand_counter}' - self.operand_counter += 1 + def _extract_outputs(self, gm: fx.GraphModule) -> str: + outputs = [] + for node in gm.graph.nodes: + if node.op == "output": + flat = node.args[0] + if isinstance(flat, (list, tuple)): + for arg in flat: + if isinstance(arg, fx.Node): + outputs.append(self._get_input_operand(arg)) + elif isinstance(flat, fx.Node): + outputs.append(self._get_input_operand(flat)) + return "; ".join(outputs) + ";" if outputs else "" - # Store the constant value - self.inline_constants[const_name] = value + # ------------------------------------------------------------------ + # Inline constant helpers + # ------------------------------------------------------------------ - return const_name + def _create_inline_constant(self, value) -> str: + for name, v in self.inline_constants.items(): + if type(v) is type(value): + if isinstance(value, torch.Tensor) and isinstance(v, torch.Tensor): + if torch.allclose(value, v): + return name + elif v == value: + return name + name = f"const_scalar_{self.operand_counter}" + self.operand_counter += 1 + self.inline_constants[name] = value + return name def _extract_inline_constants(self) -> str: - """Extract inline scalar constants that are embedded in the .webnn file""" consts = [] - for name, value in self.inline_constants.items(): - # Determine the type based on the value if isinstance(value, torch.Tensor): dtype = self._get_webnn_dtype(value.dtype) shape = list(value.shape) - shape_str = ', '.join(map(str, shape)) raw = value.cpu().numpy().tobytes() - byte_list = ', '.join(str(b) for b in raw) - consts.append(f'\t{name}: {dtype}[{shape_str}] @bytes([{byte_list}]);') + byte_list = ", ".join(str(b) for b in raw) + consts.append(f"\t{name}: {dtype}[{', '.join(map(str, shape))}] @bytes([{byte_list}]);") + elif isinstance(value, float): + consts.append(f"\t{name}: f32[] @scalar({value});") + elif isinstance(value, int): + consts.append(f"\t{name}: i32[] @scalar({value});") else: - if isinstance(value, float): - dtype = 'f32' - elif isinstance(value, int): - dtype = 'i32' - else: - dtype = 'f32' # default - - consts.append(f'\t{name}: {dtype}[] @scalar({value});') - - return '\n'.join(consts) + '\n' if consts else '' - - def _get_node_shape(self, node: fx.Node) -> List[int]: - """Best-effort extraction of static shape from FX node metadata.""" - if not hasattr(node, 'meta'): - return [] - meta = node.meta - val = meta.get('val') - if val is not None and hasattr(val, 'shape'): - return [int(d) for d in val.shape] - tensor_meta = meta.get('tensor_meta') - if tensor_meta is not None and hasattr(tensor_meta, 'shape'): - return [int(d) for d in tensor_meta.shape] - return [] - - def _extract_outputs(self, gm: fx.GraphModule) -> str: - """Extract output declarations""" - outputs = [] - for node in gm.graph.nodes: - if node.op == 'output': - # Output node contains the return value - if isinstance(node.args[0], (list, tuple)): - for arg in node.args[0]: - if isinstance(arg, fx.Node): - outputs.append(self._get_input_operand(arg)) - elif isinstance(node.args[0], fx.Node): - outputs.append(self._get_input_operand(node.args[0])) + consts.append(f"\t{name}: f32[] @scalar({value});") + return "\n".join(consts) + "\n" if consts else "" - return '; '.join(outputs) + ';' if outputs else '' + # ------------------------------------------------------------------ + # Operand name management + # ------------------------------------------------------------------ def _get_operand_name(self, node: fx.Node) -> str: - """Get or create operand name for a node""" - try: - if node.name not in self.node_to_operand: - self.node_to_operand[node.name] = f'operand_{self.operand_counter}' - self.operand_counter += 1 - except AttributeError as e: - print(f"Error: {e}") + if node.name not in self.node_to_operand: + self.node_to_operand[node.name] = f"operand_{self.operand_counter}" + self.operand_counter += 1 return self.node_to_operand[node.name] def _get_input_operand(self, node) -> str: - """Get operand name for an input node""" if isinstance(node, fx.Node): - # Map FX placeholders for parameters/buffers back to state_dict keys so - # ops reference declared const operands. - if node.op == 'placeholder': - key = self._placeholder_to_state_key(node.name) - if key in self.weight_operands: - return self.weight_operands[key] - # TODO there must be a better way to understand where '_''s are coming from in the first place and correctly parse this - # Fallback: names like `running_mean` contain underscores that - # _placeholder_to_state_key splits into separate dot-segments - # (e.g. "bn.running.mean" instead of "bn.running_mean"). - # Normalise both sides by collapsing dots and underscores and retry. - key_flat = key.replace('.', '_') - for state_key, operand in self.weight_operands.items(): - if state_key.replace('.', '_') == key_flat: - return operand + # Parameter / buffer placeholder → look up in weight_operands by node name + if node.op == "placeholder" and node.name in self.weight_operands: + return self.weight_operands[node.name] if node.name in self.node_to_operand: return self.node_to_operand[node.name] - else: - return self._get_operand_name(node) + return self._get_operand_name(node) return str(node) - def _placeholder_to_state_key(self, name: str) -> str: - """ - Convert Dynamo placeholder names like: - l_self_modules_features_modules_0_parameters_weight_ - to state_dict keys like: - features.0.weight - """ - key = name - if key.startswith('l_'): - key = key[2:] - hierarchy = key.split('_') - if hierarchy[-1] == '': - hierarchy.pop(-1) - - hierarchy = filter(lambda x: x not in IGNORED_PLACEHOLDER_TOKENS, hierarchy) - placeholder_name = '.'.join(hierarchy) - return placeholder_name - - def _get_webnn_dtype(self, dtype: torch.dtype) -> str: - """Convert PyTorch dtype to WebNN dtype""" - dtype_map = { - torch.float32: 'f32', - torch.float16: 'f16', - torch.int32: 'i32', - torch.int64: 'i64', - torch.int8: 'i8', - torch.uint8: 'u8', - } - return dtype_map.get(dtype, 'f32') - - def _get_module(self, gm: fx.GraphModule, target: str): - """Get module from GraphModule by target path""" - atoms = target.split('.') - mod = gm - for atom in atoms: - if not hasattr(mod, atom): - return None - mod = getattr(mod, atom) - return mod + # ------------------------------------------------------------------ + # Shape / dtype helpers + # ------------------------------------------------------------------ + + def _get_node_shape(self, node) -> List[int]: + if node is None or not hasattr(node, "meta"): + return [] + meta = node.meta + val = meta.get("val") + if val is not None and hasattr(val, "shape"): + return [int(d) for d in val.shape] + tm = meta.get("tensor_meta") + if tm is not None and hasattr(tm, "shape"): + return [int(d) for d in tm.shape] + return [] + + def _get_node_dtype(self, node) -> torch.dtype: + if not hasattr(node, "meta"): + return torch.float32 + meta = node.meta + val = meta.get("val") + if val is not None and hasattr(val, "dtype"): + return val.dtype + tm = meta.get("tensor_meta") + if tm is not None and hasattr(tm, "dtype"): + return tm.dtype + return torch.float32 + + def _get_webnn_dtype(self, dtype) -> str: + if not isinstance(dtype, torch.dtype): + return "f32" + return { + torch.float32: "f32", + torch.float16: "f16", + torch.bfloat16: "f32", # cast down to f32 + torch.int32: "i32", + torch.int64: "i64", + torch.int8: "i8", + torch.uint8: "u8", + }.get(dtype, "f32") diff --git a/webnn_torch_export/webnn_op_mappings.py b/webnn_torch_export/webnn_op_mappings.py index 41b2a84..be1f1d8 100644 --- a/webnn_torch_export/webnn_op_mappings.py +++ b/webnn_torch_export/webnn_op_mappings.py @@ -1,125 +1,155 @@ -"""Mapping helpers from PyTorch FX call_function targets to converter methods.""" +"""ATen op → WebNN converter dispatch table. -from typing import TYPE_CHECKING, Callable, Dict, List, Optional +All targets are matched by str(node.target) which for ATen ops gives +the canonical name like "aten.conv2d.default". +""" -import torch.fx as fx +from typing import Dict, Optional -if TYPE_CHECKING: - from .webnn_generator import WebNNGraphGenerator +# Maps str(aten_op) → converter method name on WebNNGraphGenerator. +ATEN_OP_TABLE: Dict[str, str] = { + # Convolution + "aten.conv2d.default": "_convert_conv2d", + "aten.convolution.default": "_convert_convolution", - -ConverterFn = Callable[["WebNNGraphGenerator", fx.Node, str, List[str]], str] - - -EXACT_TARGET_TO_CONVERTER: Dict[str, ConverterFn] = { - "": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "add"), - "": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "mul"), - "": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "div"), - "": lambda gen, node, output, inputs: gen._convert_getitem(node, output, inputs), - "": lambda gen, node, output, inputs: gen._convert_neg(node, output, inputs), -} - - -TARGET_NAME_TO_CONVERTER: Dict[str, ConverterFn] = { - # Convolution and Linear - "conv2d": lambda gen, node, output, inputs: gen._convert_conv2d(node, output, inputs), - "linear": lambda gen, node, output, inputs: gen._convert_linear(node, output, inputs), - "addmm": lambda gen, node, output, inputs: gen._convert_addmm(node, output, inputs), - "matmul": lambda gen, node, output, inputs: gen._convert_matmul(node, output, inputs), - "mm": lambda gen, node, output, inputs: gen._convert_matmul(node, output, inputs), - - # Normalization - "batch_norm": lambda gen, node, output, inputs: gen._convert_batch_norm(node, output, inputs), - "layer_norm": lambda gen, node, output, inputs: gen._convert_layer_norm(node, output, inputs), - "group_norm": lambda gen, node, output, inputs: gen._convert_group_norm(node, output, inputs), + # Linear / matmul + "aten.linear.default": "_convert_linear", + "aten.addmm.default": "_convert_addmm", + "aten.mm.default": "_convert_matmul", + "aten.matmul.default": "_convert_matmul", + "aten.t.default": "_convert_t", # Activations - "relu": lambda gen, node, output, inputs: gen._convert_relu(node, output, inputs), - "sigmoid": lambda gen, node, output, inputs: gen._convert_sigmoid(node, output, inputs), - "tanh": lambda gen, node, output, inputs: gen._convert_tanh(node, output, inputs), - "softmax": lambda gen, node, output, inputs: gen._convert_softmax(node, output, inputs), - "hardtanh": lambda gen, node, output, inputs: gen._convert_hardtanh(node, output, inputs), - "clamp": lambda gen, node, output, inputs: gen._convert_clamp(node, output, inputs), - "silu": lambda gen, node, output, inputs: gen._convert_silu(node, output, inputs), + "aten.relu.default": "_convert_relu", + "aten.relu_.default": "_convert_relu", + "aten.sigmoid.default": "_convert_sigmoid", + "aten.tanh.default": "_convert_tanh", + "aten.silu.default": "_convert_silu", + "aten.silu_.default": "_convert_silu", + "aten.hardtanh.default": "_convert_hardtanh", + "aten.hardtanh_.default": "_convert_hardtanh", + "aten.clamp.default": "_convert_clamp", + "aten.clamp_.default": "_convert_clamp", + "aten.gelu.default": "_convert_gelu", - # Arithmetic - "add": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "add"), - "sub": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "sub"), - "mul": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "mul"), - "div": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "div"), - - # Math functions - "sqrt": lambda gen, node, output, inputs: gen._convert_math(node, output, inputs, "sqrt"), - "exp": lambda gen, node, output, inputs: gen._convert_math(node, output, inputs, "exp"), - "abs": lambda gen, node, output, inputs: gen._convert_math(node, output, inputs, "abs"), - "log": lambda gen, node, output, inputs: gen._convert_math(node, output, inputs, "log"), - "cos": lambda gen, node, output, inputs: gen._convert_math(node, output, inputs, "cos"), - "sin": lambda gen, node, output, inputs: gen._convert_math(node, output, inputs, "sin"), - "pow": lambda gen, node, output, inputs: gen._convert_pow(node, output, inputs), + # Normalization + "aten.batch_norm.default": "_convert_batch_norm_aten", + "aten._native_batch_norm_legit_no_training.default": "_convert_batch_norm_no_training", + "aten.layer_norm.default": "_convert_layer_norm", + "aten.group_norm.default": "_convert_group_norm", # Pooling - "adaptive_avg_pool": lambda gen, node, output, inputs: gen._convert_global_avg_pool(node, output, inputs), - "avg_pool2d": lambda gen, node, output, inputs: gen._convert_avg_pool2d(node, output, inputs), - "max_pool2d": lambda gen, node, output, inputs: gen._convert_max_pool2d(node, output, inputs), - "mean": lambda gen, node, output, inputs: gen._convert_reduce_mean(node, output, inputs), - - # Tensor manipulation - "flatten": lambda gen, node, output, inputs: gen._convert_reshape(node, output, inputs), - "view": lambda gen, node, output, inputs: gen._convert_reshape(node, output, inputs), - "reshape": lambda gen, node, output, inputs: gen._convert_reshape(node, output, inputs), - "transpose": lambda gen, node, output, inputs: gen._convert_transpose(node, output, inputs), - "t": lambda gen, node, output, inputs: gen._convert_transpose(node, output, inputs), - "permute": lambda gen, node, output, inputs: gen._convert_transpose(node, output, inputs), - "concat": lambda gen, node, output, inputs: gen._convert_concat(node, output, inputs), - "cat": lambda gen, node, output, inputs: gen._convert_concat(node, output, inputs), - "stack": lambda gen, node, output, inputs: gen._convert_stack(node, output, inputs), - "split": lambda gen, node, output, inputs: gen._convert_split(node, output, inputs), - "slice": lambda gen, node, output, inputs: gen._convert_slice(node, output, inputs), - "expand": lambda gen, node, output, inputs: gen._convert_expand(node, output, inputs), - "pad": lambda gen, node, output, inputs: gen._convert_pad(node, output, inputs), - "tile": lambda gen, node, output, inputs: gen._convert_tile(node, output, inputs), - "to": lambda gen, node, output, inputs: gen._convert_cast(node, output, inputs), - "float": lambda gen, node, output, inputs: gen._convert_cast(node, output, inputs, "f32"), - "half": lambda gen, node, output, inputs: gen._convert_cast(node, output, inputs, "f16"), - - # Special operations - "rearrange": lambda gen, node, output, inputs: gen._convert_rearrange(node, output, inputs), - "arange": lambda gen, node, output, inputs: gen._convert_arange(node, output, inputs), - # "dropout": lambda gen, node, output, inputs: gen._convert_identity(node, output, inputs), - "einsum": lambda gen, node, output, inputs: gen._convert_einsum(node, output, inputs), - "scaled_dot_product_attention": lambda gen, node, output, inputs: gen._convert_scaled_dot_product_attention(node, output, inputs), - "interpolate": lambda gen, node, output, inputs: gen._convert_interpolate(node, output, inputs), - - # No OP - "identity": lambda gen, node, output, inputs: gen._convert_identity(node, output, inputs), - "contiguous": lambda gen, node, output, inputs: gen._convert_identity(node, output, inputs), -} + "aten.max_pool2d.default": "_convert_max_pool2d", + "aten.max_pool2d_with_indices.default": "_convert_max_pool2d", + "aten.avg_pool2d.default": "_convert_avg_pool2d", + "aten.adaptive_avg_pool2d.default": "_convert_global_avg_pool", + "aten.mean.dim": "_convert_reduce_mean", + "aten.mean.default": "_convert_reduce_mean", -SCHEMA_CONTAINS_TO_CONVERTER: Dict[str, ConverterFn] = { - "convolution": lambda gen, node, output, inputs: gen._convert_conv2d(node, output, inputs), - "relu": lambda gen, node, output, inputs: gen._convert_relu(node, output, inputs), - "add": lambda gen, node, output, inputs: gen._convert_arithmetric(node, output, inputs, "add"), + # Arithmetic + "aten.add.Tensor": "_convert_add", + "aten.add.Scalar": "_convert_add", + "aten.add_.Tensor": "_convert_add", + "aten.sub.Tensor": "_convert_sub", + "aten.sub.Scalar": "_convert_sub", + "aten.mul.Tensor": "_convert_mul", + "aten.mul.Scalar": "_convert_mul", + "aten.div.Tensor": "_convert_div", + "aten.div.Scalar": "_convert_div", + "aten.neg.default": "_convert_neg", + "aten.pow.Tensor_Scalar": "_convert_pow", + "aten.pow.Tensor_Tensor": "_convert_pow", + "aten.pow.Scalar": "_convert_pow_scalar", + + # Elementwise math + "aten.sqrt.default": "_convert_math_sqrt", + "aten.exp.default": "_convert_math_exp", + "aten.abs.default": "_convert_math_abs", + "aten.log.default": "_convert_math_log", + "aten.cos.default": "_convert_math_cos", + "aten.sin.default": "_convert_math_sin", + "aten.rsqrt.default": "_convert_rsqrt", + "aten.reciprocal.default": "_convert_reciprocal", + + # Tensor shape manipulation + "aten.reshape.default": "_convert_reshape", + "aten.view.default": "_convert_reshape", + "aten._unsafe_view.default": "_convert_reshape", + "aten.flatten.using_ints": "_convert_reshape", + "aten.permute.default": "_convert_permute", + "aten.transpose.int": "_convert_transpose", + "aten.unsqueeze.default": "_convert_unsqueeze", + "aten.squeeze.dim": "_convert_squeeze", + "aten.squeeze.default": "_convert_squeeze", + "aten.cat.default": "_convert_concat", + "aten.stack.default": "_convert_stack", + "aten.split.Tensor": "_convert_split", + "aten.split_with_sizes.default": "_convert_split", + "aten.chunk.default": "_convert_chunk", + "aten.unbind.int": "_convert_unbind", + "aten.select.int": "_convert_select", + "aten.einsum.default": "_convert_einsum", + "aten.slice.Tensor": "_convert_slice", + "aten.expand.default": "_convert_expand", + "aten.expand_as.default": "_convert_expand_as", + + # Padding + "aten.constant_pad_nd.default": "_convert_pad", + "aten.pad.default": "_convert_pad", + + # Upsampling + "aten.upsample_nearest2d.vec": "_convert_upsample_nearest2d", + "aten.upsample_nearest2d.default": "_convert_upsample_nearest2d", + "aten.upsample_bilinear2d.vec": "_convert_upsample_bilinear2d", + "aten.upsample_bilinear2d.default":"_convert_upsample_bilinear2d", + + # Softmax / attention + "aten.softmax.int": "_convert_softmax", + "aten._softmax.default": "_convert_softmax_aten", + "aten.scaled_dot_product_attention.default": "_convert_scaled_dot_product_attention", + + # Compile-time constant generation + "aten.arange.default": "_convert_arange", + "aten.arange.start": "_convert_arange", + "aten.arange.start_step": "_convert_arange", + "aten.full.default": "_convert_full", + "aten.zeros.default": "_convert_full_zeros", + "aten.ones.default": "_convert_full_ones", + + # No-ops / memory ops + "aten.clone.default": "_convert_identity", + "aten.contiguous.default": "_convert_identity", + "aten._assert_tensor_metadata.default": "_convert_identity", + "aten.contiguous.memory_format": "_convert_identity", + "aten.dropout.default": "_convert_identity", + "aten.alias.default": "_convert_identity", + + # Type cast + "aten._to_copy.default": "_convert_cast", + "aten.to.dtype": "_convert_cast", + "aten.to.device": "_convert_to_device", + "aten.to.dtype_layout": "_convert_cast", + "aten.type_as.default": "_convert_type_as", + + # Tuple/list indexing (from chunk, unbind, split) + "": "_convert_getitem", } -def resolve_pytorch_converter(target) -> Optional[ConverterFn]: - """Resolve converter callable for a PyTorch FX call_function target.""" - target_str = str(target) +def resolve_aten_converter(target) -> Optional[str]: + """Return converter method name for an ATen (or other) op target, or None.""" + return ATEN_OP_TABLE.get(str(target)) - if target_str in EXACT_TARGET_TO_CONVERTER: - return EXACT_TARGET_TO_CONVERTER[target_str] - if target_str in TARGET_NAME_TO_CONVERTER: - return TARGET_NAME_TO_CONVERTER[target_str] +if __name__ == "__main__": + from torch._decomp import _core_aten_decompositions_post_autograd - target_name = getattr(target, "__name__", None) - if target_name in TARGET_NAME_TO_CONVERTER: - return TARGET_NAME_TO_CONVERTER[target_name] + aten_ops = _core_aten_decompositions_post_autograd() + aten_ops_as_str = set(".".join((op.name().split("::"))) for op in aten_ops.keys()) + supported_aten_str = set(ATEN_OP_TABLE.keys()) - schema = getattr(target, "_schema", None) - if schema is not None: - schema_str = str(schema) - if schema_str in SCHEMA_CONTAINS_TO_CONVERTER.items(): - return SCHEMA_CONTAINS_TO_CONVERTER[schema_str] + missing_aten = aten_ops_as_str.difference(supported_aten_str) + # filter any op that contains "backward" + missing_aten_non_bwd = [op for op in missing_aten if "backward" not in op] - return None + print(missing_aten_non_bwd) \ No newline at end of file From 7f76c2dd48047a46a7740e3e100fb3bfa97060b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 16 Mar 2026 22:03:15 -0700 Subject: [PATCH 2/9] flux sample update --- examples/flux_klein/_common.py | 17 ++ examples/flux_klein/export_decoder.py | 105 +++++++ examples/flux_klein/export_flow.py | 121 ++++++++ examples/flux_klein/export_text_encoder.py | 38 +++ examples/flux_klein/main.py | 329 +-------------------- 5 files changed, 293 insertions(+), 317 deletions(-) create mode 100644 examples/flux_klein/_common.py create mode 100644 examples/flux_klein/export_decoder.py create mode 100644 examples/flux_klein/export_flow.py create mode 100644 examples/flux_klein/export_text_encoder.py diff --git a/examples/flux_klein/_common.py b/examples/flux_klein/_common.py new file mode 100644 index 0000000..c8e6e47 --- /dev/null +++ b/examples/flux_klein/_common.py @@ -0,0 +1,17 @@ +"""Shared setup for Flux Klein export scripts.""" + +import os +import torch + +OUTPUT_DIR = os.path.dirname(__file__) + + +def hf_login(): + hf_token = os.environ.get("HF_TOKEN", "") + if hf_token: + try: + from huggingface_hub import login + login(token=hf_token, add_to_git_credential=False) + print("Logged in to HuggingFace") + except Exception as e: + print(f"Note: Could not login to HuggingFace: {e}") \ No newline at end of file diff --git a/examples/flux_klein/export_decoder.py b/examples/flux_klein/export_decoder.py new file mode 100644 index 0000000..3dac2fc --- /dev/null +++ b/examples/flux_klein/export_decoder.py @@ -0,0 +1,105 @@ +""" +Export the Flux Klein AutoEncoder decoder to WebNN. + +Usage: + python export_decoder.py +""" + +import os +import torch +import numpy as np +from webnn_torch_export import export_model_with_weights +from flux2.util import load_ae +from _common import hf_login, OUTPUT_DIR + + +class DecoderOnly(torch.nn.Module): + """Wraps the AutoEncoder so only the decode path is exported.""" + + def __init__(self, ae): + super().__init__() + self.ae = ae + + def forward(self, z): + return self.ae.decode(z) + + +def main(): + hf_login() + torch.manual_seed(42) + device = "cpu" + + print("=" * 60) + print("Flux Klein — AutoEncoder Decoder Export") + print("=" * 60) + + print("\nLoading autoencoder...") + try: + autoencoder = load_ae("flux.2-klein-4b", device=device) + autoencoder.eval() + except SystemExit: + print("Could not load weights from HuggingFace — using placeholder model.") + from flux2.autoencoder import AutoEncoder, AutoEncoderParams + with torch.device(device): + autoencoder = AutoEncoder(AutoEncoderParams()) + autoencoder.eval() + + ae_params = sum(p.numel() for p in autoencoder.parameters()) + print(f"\nAutoEncoder Statistics:") + print(f" Total parameters : {ae_params:,}") + print(f" Model size (fp32): ~{ae_params * 4 / 1024**2:.1f} MB") + + sample_latent = torch.randn(1, 128, 32, 32) + print(f"\nInput latent shape: {list(sample_latent.shape)}") + + with torch.no_grad(): + decoded_image = autoencoder.decode(sample_latent) + print(f"Decoded image shape: {list(decoded_image.shape)}") + + decoder_model = DecoderOnly(autoencoder) + + webnn_path = os.path.join(OUTPUT_DIR, "flux_klein_decoder.webnn") + weights_path = os.path.join(OUTPUT_DIR, "flux_klein_decoder_weights.safetensors") + + print("\nExporting decoder to WebNN...") + try: + export_model_with_weights( + model=decoder_model, + example_input=sample_latent, + webnn_path=webnn_path, + weights_path=weights_path, + graph_name="flux_klein_decoder", + debug=False, + ) + print(f"Export successful!") + print(f" Graph : {webnn_path} ({os.path.getsize(webnn_path) / 1024:.1f} KB)") + print(f" Weights: {weights_path} ({os.path.getsize(weights_path) / 1024:.1f} KB)") + + # Validate against WebNN runtime if available + try: + import webnn + print("\nValidating against WebNN runtime...") + context = webnn.ML().create_context(device_type="auto") + webnn_graph = webnn.MLGraph.load(webnn_path, weights_path=weights_path) + input_name = webnn_graph.get_input_names()[0] + output_name = webnn_graph.get_output_names()[0] + webnn_output = context.compute( + webnn_graph, + {input_name: sample_latent.detach().cpu().numpy().astype(np.float32)}, + )[output_name] + torch_output = decoded_image.detach().cpu().numpy() + mae = float(np.mean(np.abs(webnn_output - torch_output))) + print(f"WebNN vs Torch MAE: {mae:.6f}") + except ImportError: + print("webnn package not available — skipping runtime validation.") + + except NotImplementedError as e: + print(f"Unsupported operation:\n{e}") + except Exception as e: + import traceback + print(f"Export error: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/flux_klein/export_flow.py b/examples/flux_klein/export_flow.py new file mode 100644 index 0000000..d2055d0 --- /dev/null +++ b/examples/flux_klein/export_flow.py @@ -0,0 +1,121 @@ +""" +Export the Flux Klein 4B flow model to WebNN. + +Usage: + python export_flow.py +""" + +import os +import torch +import numpy as np +from webnn_torch_export import export_model_with_weights +from flux2.util import load_flow_model +from _common import hf_login, OUTPUT_DIR + + +def main(): + hf_login() + torch.manual_seed(42) + device = "cpu" + + print("=" * 60) + print("Flux Klein 4B — Flow Model Export") + print("=" * 60) + + print("\nLoading flow model...") + try: + flow_model = load_flow_model("flux.2-klein-4b", debug_mode=True, device=device) + flow_model.eval() + except SystemExit: + print("Could not load weights from HuggingFace — using placeholder model.") + from flux2.model import Flux2, Klein4BParams + with torch.device(device): + flow_model = Flux2(Klein4BParams()).to(torch.bfloat16) + flow_model.eval() + + total_params = sum(p.numel() for p in flow_model.parameters()) + print(f"\nFlow Model Statistics:") + print(f" Total parameters : {total_params:,}") + print(f" Model size (fp32): ~{total_params * 4 / 1024**2:.1f} MB") + + # Sample inputs + batch_size = 1 + img_seq_len = 64 + txt_seq_len = 32 + inf_dtype = torch.bfloat16 + + sample_x = torch.randn(batch_size, img_seq_len, 128, dtype=inf_dtype) + sample_x_ids = torch.zeros(batch_size, img_seq_len, 4, dtype=torch.long) + sample_timesteps = torch.tensor([0.5], dtype=inf_dtype) + sample_ctx = torch.randn(batch_size, txt_seq_len, 7680, dtype=inf_dtype) + sample_ctx_ids = torch.zeros(batch_size, txt_seq_len, 4, dtype=torch.long) + sample_guidance = torch.tensor([1.0], dtype=inf_dtype) + + print("\nInput shapes:") + for name, t in [("x", sample_x), ("x_ids", sample_x_ids), + ("timesteps", sample_timesteps), ("ctx", sample_ctx), + ("ctx_ids", sample_ctx_ids), ("guidance", sample_guidance)]: + print(f" {name}: {list(t.shape)}") + + print("\nRunning forward pass...") + with torch.no_grad(): + output = flow_model( + x=sample_x, x_ids=sample_x_ids, timesteps=sample_timesteps, + ctx=sample_ctx, ctx_ids=sample_ctx_ids, guidance=sample_guidance + ) + print(f"Output shape: {output.shape}") + + webnn_path = os.path.join(OUTPUT_DIR, "flux_klein_flow.webnn") + weights_path = os.path.join(OUTPUT_DIR, "flux_klein_flow_weights.safetensors") + + print("\nExporting to WebNN...") + try: + export_model_with_weights( + model=flow_model, + example_input=(sample_x, sample_x_ids, sample_timesteps, + sample_ctx, sample_ctx_ids, sample_guidance), + webnn_path=webnn_path, + weights_path=weights_path, + graph_name="flux_klein_4b_flow", + debug=False, + ) + print(f"Export successful!") + print(f" Graph : {webnn_path} ({os.path.getsize(webnn_path) / 1024:.1f} KB)") + print(f" Weights: {weights_path} ({os.path.getsize(weights_path) / 1024:.1f} KB)") + + # Validate against WebNN runtime if available + try: + import webnn + print("\nValidating against WebNN runtime...") + context = webnn.ML().create_context(device_type="auto") + webnn_graph = webnn.MLGraph.load(webnn_path, weights_path=weights_path) + + # Build input dict — convert bfloat16 to float32 for WebNN + inputs = { + "x": sample_x.float().detach().cpu().numpy(), + "x_ids": sample_x_ids.detach().cpu().numpy(), + "timesteps": sample_timesteps.float().detach().cpu().numpy(), + "ctx": sample_ctx.float().detach().cpu().numpy(), + "ctx_ids": sample_ctx_ids.detach().cpu().numpy(), + "guidance": sample_guidance.float().detach().cpu().numpy(), + } + + webnn_output = context.compute(webnn_graph, inputs) + output_name = webnn_graph.get_output_names()[0] + torch_output = output.float().detach().cpu().numpy() + + mae = float(np.mean(np.abs(webnn_output[output_name] - torch_output))) + print(f"WebNN vs Torch MAE: {mae:.6f}") + except ImportError: + print("webnn package not available — skipping runtime validation.") + + except NotImplementedError as e: + print(f"Unsupported operation:\n{e}") + except Exception as e: + import traceback + print(f"Export error: {e}") + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/examples/flux_klein/export_text_encoder.py b/examples/flux_klein/export_text_encoder.py new file mode 100644 index 0000000..3563627 --- /dev/null +++ b/examples/flux_klein/export_text_encoder.py @@ -0,0 +1,38 @@ +""" +Export the Flux Klein text encoder (Qwen3-4B) to WebNN. + +Usage: + python export_text_encoder.py + +Note: The text encoder is very large (~16 GB). In production workflows it is +common to pre-compute and cache the text embeddings rather than running the +encoder at inference time, so exporting it is optional. +""" + +import os +import torch +from webnn_torch_export import export_model_with_weights +from _common import hf_login, OUTPUT_DIR + + +def main(): + hf_login() + torch.manual_seed(42) + device = "cpu" + + print("=" * 60) + print("Flux Klein — Text Encoder (Qwen3-4B) Export") + print("=" * 60) + print() + print("Text encoder export is not yet implemented.") + print("Qwen3-4B is a very large model; most pipelines pre-compute") + print("text embeddings and cache them instead of exporting the encoder.") + print() + print("To add support:") + print(" 1. Load the tokenizer and model from HuggingFace") + print(" 2. Wrap the encode step in a thin nn.Module") + print(" 3. Call export_model_with_weights() with a sample token tensor") + + +if __name__ == "__main__": + main() diff --git a/examples/flux_klein/main.py b/examples/flux_klein/main.py index 8e369d6..6fd3ef9 100644 --- a/examples/flux_klein/main.py +++ b/examples/flux_klein/main.py @@ -1,323 +1,18 @@ """ -Flux Klein Export Demo +Flux Klein WebNN Export — Entry Point -Demonstrates exporting a more complex multi-component diffusion model: -- Text Encoder (Qwen3-4B) -- Flow Model (Flux Klein 4B) -- AutoEncoder (VAE) +Runs all export scripts in sequence. You can also run each script individually: -This example shows how to handle models with multiple sub-components, -each of which can be exported separately. + python export_text_encoder.py # text encoder (Qwen3-4B) + python export_flow.py # flow / denoising model + python export_decoder.py # VAE decoder """ -import torch -import os -import numpy as np -from webnn_torch_export import export_model_with_weights -from flux2.util import load_flow_model, load_ae -from flux2.sampling import get_schedule, batched_prc_txt, batched_prc_img, denoise -import webnn +import export_text_encoder +import export_flow +import export_decoder -hf_token = os.environ.get("HF_TOKEN") - -# Login to HuggingFace if token is provided -if hf_token: - try: - from huggingface_hub import login - login(token=hf_token, add_to_git_credential=False) - print(f"Logged in to HuggingFace with token") - except Exception as e: - print(f"Note: Could not login to HuggingFace: {e}") - -def main(): - torch.manual_seed(42) - device = "cpu" # Using CPU for export compatibility - - print("=" * 60) - print("FLUX KLEIN 4B - WebNN Export Demo") - print("=" * 60) - print("\nThis example demonstrates exporting a complex diffusion model") - print("with multiple components: text encoder, flow model, and autoencoder.\n") - - # Model selection - model_name = "flux.2-klein-4b" - print(f"Loading model: {model_name}") - print("Note: First run will download ~8GB of model weights from HuggingFace") - - # ========================================================================= - # PART 1: Load Text Encoder - # ========================================================================= - print("\n" + "=" * 60) - print("PART 1: Text Encoder (Qwen3-4B)") - print("=" * 60) - - print("\nLoading text encoder...") - # text_encoder = load_text_encoder(model_name, device=device) - # text_encoder.eval() - - # Sample text input - sample_prompt = "a cat wearing sunglasses" - print(f"\nSample prompt: '{sample_prompt}'") - - # Encode text (simplified - real usage would involve tokenization) - # For demo purposes, we'll create a dummy encoded text tensor - # Real text encoding happens inside text_encoder with proper tokenization - print("\nSkipping text encoder export (uses Qwen3 - very large model)") - print("In production, you'd export this separately or use a pre-encoded cache") - output_dir = os.path.dirname(__file__) - - # ========================================================================= - # PART 2: Load Flow Model (Main Diffusion Model) - # ========================================================================= - export_flow_model = True - export_ae_model = False - if export_flow_model: - print("\n" + "=" * 60) - print("PART 2: Flow Model (Flux Klein 4B)") - print("=" * 60) - - print("\nLoading flow model...") - try: - flow_model = load_flow_model(model_name, debug_mode=True, device=device) - flow_model.eval() - except SystemExit: - print("\n" + "!" * 60) - print("ERROR: Could not load Flux Klein model") - print("!" * 60) - print("\nThis example requires access to the HuggingFace model repository.") - print("Please ensure:") - print(" 1. You have a HuggingFace account") - print(" 2. You've accepted the model license at:") - print(" https://huggingface.co/black-forest-labs/FLUX.2-klein-4B") - print(" 3. You're logged in: huggingface-cli login") - print("\nAlternatively, set environment variables to use local weights:") - print(" export KLEIN_4B_MODEL_PATH=/path/to/flux-2-klein-4b.safetensors") - print(" export AE_MODEL_PATH=/path/to/ae.safetensors") - print("\nFor now, this demo will continue with a placeholder model...") - print("!" * 60) - - # Create a minimal placeholder model for demonstration - from flux2.model import Flux2, Klein4BParams - with torch.device(device): - flow_model = Flux2(Klein4BParams()).to(torch.bfloat16) - flow_model.eval() - print("Using minimal placeholder model (not the real Flux Klein weights)") - - # Count parameters - total_params = sum(p.numel() for p in flow_model.parameters()) - trainable_params = sum(p.numel() for p in flow_model.parameters() if p.requires_grad) - print(f"\nFlow Model Statistics:") - print(f" - Total parameters: {total_params:,}") - print(f" - Trainable parameters: {trainable_params:,}") - print(f" - Model size (fp32): ~{total_params * 4 / 1024**2:.1f} MB") - - # Create sample inputs for the flow model - # The flow model takes: - # - x: latent image tensor [batch, seq_len, channels] - # - x_ids: position IDs for latents - # - timesteps: diffusion timestep - # - ctx: text embeddings [batch, text_seq_len, text_dim] - # - ctx_ids: position IDs for text - # - guidance: guidance scale - - batch_size = 1 - img_seq_len = 64 # Small for demo (real: ~1024 for 512x512) - txt_seq_len = 32 # Small for demo (real: ~256) - - # Note: these dimensions match Klein4BParams - latent_channels = 128 - text_embed_dim = 7680 # context_in_dim for Klein 4B - inf_dtype = torch.bfloat16 - sample_x = torch.randn(batch_size, img_seq_len, latent_channels, dtype=inf_dtype) - sample_x_ids = torch.zeros(batch_size, img_seq_len, 4, dtype=torch.long) # [t, h, w, l] - sample_timesteps = torch.tensor([0.5], dtype=inf_dtype) # Mid diffusion timestep - sample_ctx = torch.randn(batch_size, txt_seq_len, text_embed_dim, dtype=inf_dtype) - sample_ctx_ids = torch.zeros(batch_size, txt_seq_len, 4, dtype=torch.long) - sample_guidance = torch.tensor([1.0], dtype=inf_dtype) # Klein 4B uses guidance=1.0 - - print("\nInput shapes:") - print(f" - Latent image (x): {list(sample_x.shape)}") - print(f" - Image position IDs (x_ids): {list(sample_x_ids.shape)}") - print(f" - Timesteps: {list(sample_timesteps.shape)}") - print(f" - Text embeddings (ctx): {list(sample_ctx.shape)}") - print(f" - Text position IDs (ctx_ids): {list(sample_ctx_ids.shape)}") - print(f" - Guidance: {list(sample_guidance.shape)}") - - # Run inference - print("\nRunning flow model inference...") - with torch.no_grad(): - output = flow_model( - x=sample_x, - x_ids=sample_x_ids, - timesteps=sample_timesteps, - ctx=sample_ctx, - ctx_ids=sample_ctx_ids, - guidance=sample_guidance - ) - print(f"Output shape: {output.shape}") - - # Export flow model - flow_weights_path = os.path.join(output_dir, "flux_klein_flow_weights.safetensors") - flow_webnn_path = os.path.join(output_dir, "flux_klein_flow.webnn") - - print("\n" + "-" * 60) - print("Exporting Flow Model to WebNN...") - print("-" * 60) - - try: - compiled_model, exporter = export_model_with_weights( - model=flow_model, - example_input=(sample_x, sample_x_ids, sample_timesteps, - sample_ctx, sample_ctx_ids, sample_guidance), - webnn_path=flow_webnn_path, - weights_path=flow_weights_path, - graph_name="flux_klein_4b_flow", - debug=False - ) - print("Flow model export successful!") - except NotImplementedError as e: - print("Flow model export hit an unsupported operation:") - print(str(e)) - print("\nThis is expected for complex models like Flux Klein.") - print("You can incrementally add support for each operation in webnn_op_mappings.py") - except Exception as e: - print(f"Flow model export encountered an error: {e}") - import traceback - traceback.print_exc() - - # ========================================================================= - # PART 3: Load AutoEncoder - # ========================================================================= - if export_ae_model: - print("\n" + "=" * 60) - print("PART 3: AutoEncoder (VAE)") - print("=" * 60) - - print("\nLoading autoencoder...") - try: - autoencoder = load_ae(model_name, device=device) - autoencoder.eval() - except SystemExit: - print("Could not load autoencoder weights from HuggingFace.") - print("Using placeholder model without pretrained weights...") - from flux2.autoencoder import AutoEncoder, AutoEncoderParams - with torch.device(device): - autoencoder = AutoEncoder(AutoEncoderParams()) - autoencoder.eval() - - ae_params = sum(p.numel() for p in autoencoder.parameters()) - print(f"\nAutoEncoder Statistics:") - print(f" - Total parameters: {ae_params:,}") - print(f" - Model size (fp32): ~{ae_params * 4 / 1024**2:.1f} MB") - - # The autoencoder has two main functions: - # 1. encode: RGB image -> latent - # 2. decode: latent -> RGB image - - # Test decoder (latent -> image) - print("\nTesting decoder...") - sample_latent = torch.randn(1, 128, 32, 32) # [batch, channels, h, w] - print(f"Input latent shape: {list(sample_latent.shape)}") - - with torch.no_grad(): - decoded_image = autoencoder.decode(sample_latent) - print(f"Decoded image shape: {list(decoded_image.shape)}") - - # Export decoder - decoder_weights_path = os.path.join(output_dir, "flux_klein_decoder_weights.safetensors") - decoder_webnn_path = os.path.join(output_dir, "flux_klein_decoder.webnn") - - print("\n" + "-" * 60) - print("Exporting Decoder to WebNN...") - print("-" * 60) - - try: - # Create a wrapper for decoder only - class DecoderOnly(torch.nn.Module): - def __init__(self, ae): - super().__init__() - self.ae = ae - - def forward(self, z): - return self.ae.decode(z) - - decoder_model = DecoderOnly(autoencoder) - - compiled_decoder, decoder_exporter = export_model_with_weights( - model=decoder_model, - example_input=sample_latent, - webnn_path=decoder_webnn_path, - weights_path=decoder_weights_path, - graph_name="flux_klein_decoder", - debug=False - ) - print("Decoder export successful!") - - # Test WebNN decoder - print("\nTesting WebNN decoder...") - context = webnn.ML().create_context(device_type="auto") - webnn_graph = webnn.MLGraph.load( - decoder_webnn_path, - weights_path=decoder_weights_path - ) - input_name = webnn_graph.get_input_names()[0] - output_name = webnn_graph.get_output_names()[0] - webnn_output = context.compute( - webnn_graph, - {input_name: sample_latent.detach().cpu().numpy().astype(np.float32)}, - )[output_name] - torch_output = decoded_image.detach().cpu().numpy() - - mae = float(np.mean(np.abs(webnn_output - torch_output))) - print(f"WebNN vs Torch MAE: {mae:.6f}") - - except NotImplementedError as e: - print("Decoder export hit an unsupported operation:") - print(str(e)) - print("\nTo add support, implement the converter in webnn_generator.py") - except Exception as e: - print(f"Decoder export encountered an error: {e}") - import traceback - traceback.print_exc() - - # ========================================================================= - # Summary - # ========================================================================= - print("\n" + "=" * 60) - print("EXPORT SUMMARY") - print("=" * 60) - - print("\nFlux Klein 4B is a complex multi-component model:") - print(" 1. Text Encoder (Qwen3-4B): Encodes text prompts") - print(" 2. Flow Model: Main diffusion model for denoising") - print(" 3. AutoEncoder: Converts between pixel and latent space") - - print("\nExported components:") - if export_flow_model: - if os.path.exists(flow_webnn_path): - flow_size = os.path.getsize(flow_webnn_path) / 1024 - weights_size = os.path.getsize(flow_weights_path) / 1024 - print(f" - Flow Model WebNN: {flow_webnn_path}") - print(f" Size: {flow_size:.1f} KB (graph) + {weights_size:.1f} KB (weights)") - if export_ae_model: - if os.path.exists(decoder_webnn_path): - decoder_size = os.path.getsize(decoder_webnn_path) / 1024 - decoder_weights_size = os.path.getsize(decoder_weights_path) / 1024 - print(f" - Decoder WebNN: {decoder_webnn_path}") - print(f" Size: {decoder_size:.1f} KB (graph) + {decoder_weights_size:.1f} KB (weights)") - - print("\nKey Takeaways:") - print(" - Complex models can be exported component-by-component") - print(" - Each component has its own graph and weights") - print(" - Text encoder can be cached/pre-computed for efficiency") - print(" - Flow model runs iteratively during generation (4-50 steps)") - print(" - Decoder runs once at the end to convert latents to images") - - print("\nNote: Some operations in Flux Klein may not be fully supported") - print("by WebNN yet. This example demonstrates the export process and") - print("structure for complex multi-component models.") - print("\n" + "=" * 60) - - -if __name__ == '__main__': - main() +if __name__ == "__main__": + # export_text_encoder.main() + export_flow.main() + export_decoder.main() From 2337c11d496c072f0eab7a44c3a2d26135e8952c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 16 Mar 2026 22:21:58 -0700 Subject: [PATCH 3/9] bump pywebnn version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0dce9e2..4e6f8da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ dev = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", - "pywebnn @ git+https://github.com/gedoensmax/pywebnn.git@maximilianm/safetensor_support", + "pywebnn @ git+https://github.com/rustnn/pywebnn.git@safetensor_support", ] [tool.pytest.ini_options] From cf01aa1fb6eef6c7146e8fbb8234b9f2f81ed250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 16 Mar 2026 22:33:19 -0700 Subject: [PATCH 4/9] set softmax axis as -1 is not supported --- webnn_torch_export/webnn_generator.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/webnn_torch_export/webnn_generator.py b/webnn_torch_export/webnn_generator.py index 32a4cd6..528aca6 100644 --- a/webnn_torch_export/webnn_generator.py +++ b/webnn_torch_export/webnn_generator.py @@ -1238,6 +1238,8 @@ def _convert_softmax(self, node: fx.Node, output: str, inputs: List[str]) -> str """aten.softmax.int(input, dim, half_to_float)""" x = inputs[0] if inputs else "unknown" axis = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) + if axis == -1: + axis = len(self._get_node_shape(node.args[0])) - 1 return f"[{output}] = softmax({x}, axis={axis});" def _convert_softmax_aten(self, node: fx.Node, output: str, inputs: List[str]) -> str: @@ -1267,11 +1269,12 @@ def tmp(): perm_str = ", ".join(map(str, perm)) kt, qk, qk_sc, attn_w = tmp(), tmp(), tmp(), tmp() + axis = len(q_shape) - 1 return "\n\t".join([ f"[{kt}] = transpose({K}, permutation=[{perm_str}]);", f"[{qk}] = matmul({Q}, {kt});", f"[{qk_sc}] = mul({qk}, {scale_c});", - f"[{attn_w}] = softmax({qk_sc}, axis=-1);", + f"[{attn_w}] = softmax({qk_sc}, axis={axis});", f"[{output}] = matmul({attn_w}, {V});", ]) From 56b187eae5a4bf57075d6dc78747e94724bce63e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 16 Mar 2026 22:49:58 -0700 Subject: [PATCH 5/9] only positive concat axis --- tests/models.py | 15 +++++++++++++++ tests/test_operations.py | 23 ++++++++++++++++++++++- webnn_torch_export/webnn_generator.py | 6 ++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/models.py b/tests/models.py index 85a64dd..dd43da7 100644 --- a/tests/models.py +++ b/tests/models.py @@ -167,6 +167,21 @@ def forward(self, x): return self.norm(x) +# --------------------------------------------------------------------------- +# Concat models (used in test_concat.py) +# --------------------------------------------------------------------------- + +class ConcatModel(nn.Module): + """Concatenate any number of tensors along the given axis.""" + + def __init__(self, axis: int): + super().__init__() + self.axis = axis + + def forward(self, *tensors): + return torch.cat(tensors, dim=self.axis) + + # --------------------------------------------------------------------------- # Batch norm models (used in test_batch_norm.py) # --------------------------------------------------------------------------- diff --git a/tests/test_operations.py b/tests/test_operations.py index 291aa01..83b5e55 100644 --- a/tests/test_operations.py +++ b/tests/test_operations.py @@ -4,12 +4,13 @@ import torch import torch.nn as nn from .models import ( + ConcatModel, + NormalizationOpsModel, PointwiseActivationsModel, PointwiseArithmeticModel, PointwiseMathModel, ReductionOpsModel, ShapeOpsModel, - NormalizationOpsModel, ) from .conftest import assert_export_matches, validate_webnn_execution @@ -186,6 +187,26 @@ def forward(self, x): assert_export_matches(model, x, rtol=1e-4, atol=1e-4) +# --------------------------------------------------------------------------- +# concat (aten.cat) +# --------------------------------------------------------------------------- + +# 4-D shape (2, 3, 4, 5) — supports axes 0, 1, 3 and their negative equivalents +@pytest.mark.parametrize("n_inputs,axis", [ + (2, 0), (2, 1), (2, 3), (2, -1), (2, -3), + (5, 0), (5, 1), (5, 3), (5, -1), (5, -3), +], ids=[f"{n}in_ax{a}" for n, a in [ + (2, 0), (2, 1), (2, 3), (2, -1), (2, -3), + (5, 0), (5, 1), (5, 3), (5, -1), (5, -3), +]]) +def test_concat(n_inputs, axis): + torch._dynamo.reset() + model = ConcatModel(axis) + inputs = tuple(torch.randn(2, 3, 4, 5) for _ in range(n_inputs)) + assert_export_matches(model, inputs) + validate_webnn_execution(model, inputs) + + # --------------------------------------------------------------------------- # type_as # --------------------------------------------------------------------------- diff --git a/webnn_torch_export/webnn_generator.py b/webnn_torch_export/webnn_generator.py index 528aca6..ce980af 100644 --- a/webnn_torch_export/webnn_generator.py +++ b/webnn_torch_export/webnn_generator.py @@ -876,6 +876,12 @@ def _convert_concat(self, node: fx.Node, output: str, inputs: List[str]) -> str: axis = node.kwargs["dim"] elif len(node.args) > 1 and not isinstance(node.args[1], fx.Node): axis = node.args[1] + if axis < 0: + tensor_list = node.args[0] + first_tensor = tensor_list[0] if isinstance(tensor_list, (list, tuple)) else tensor_list + ndim = len(self._get_node_shape(first_tensor)) + if ndim > 0: + axis = ndim + axis if inputs: return f"[{output}] = concat([{', '.join(inputs)}], axis={axis});" if len(node.args) >= 1 and isinstance(node.args[0], (list, tuple)): From f2db057e9849fa486e9961de9874cc105df09625 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Fri, 20 Mar 2026 09:23:58 -0700 Subject: [PATCH 6/9] rebase safetensors to latest --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4e6f8da..385fe98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ dependencies = [ dev = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", - "pywebnn @ git+https://github.com/rustnn/pywebnn.git@safetensor_support", + "pywebnn @ git+https://github.com/gedoensmax/pywebnn.git@safetensor_support", ] [tool.pytest.ini_options] From cd4b2c1213b30ffbd638a2af614aec7c4659f425 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Thu, 9 Apr 2026 14:42:31 +0200 Subject: [PATCH 7/9] safetensor in main --- pyproject.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 385fe98..cf14721 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ readme = "README.md" requires-python = ">=3.8" license = {text = "Apache License (2.0)"} authors = [ - {name = "Your Name", email = "your.email@example.com"} + {name = "Maximilian Müller", email = "maximilianm@nvidia.com"} ] classifiers = [ "Development Status :: 3 - Alpha", @@ -22,6 +22,8 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", ] dependencies = [ "torch>=2.0.0", @@ -35,7 +37,7 @@ dependencies = [ dev = [ "pytest>=7.0.0", "pytest-cov>=4.0.0", - "pywebnn @ git+https://github.com/gedoensmax/pywebnn.git@safetensor_support", + "pywebnn @ git+https://github.com/rustnn/pywebnn.git", ] [tool.pytest.ini_options] From a603242d52b67234de72e7d01b97eca1656967ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 13 Apr 2026 15:23:11 +0200 Subject: [PATCH 8/9] adopt fixes in webnn graph --- webnn_torch_export/webnn_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webnn_torch_export/webnn_generator.py b/webnn_torch_export/webnn_generator.py index ce980af..f7d24c0 100644 --- a/webnn_torch_export/webnn_generator.py +++ b/webnn_torch_export/webnn_generator.py @@ -242,7 +242,7 @@ def as_pair(v): params.append(f"groups={groups}") params.append('inputLayout="nchw"') if padding != [0, 0]: - params.append(f"pads=[{padding[0]}, {padding[0]}, {padding[1]}, {padding[1]}]") + params.append(f"padding=[{padding[0]}, {padding[0]}, {padding[1]}, {padding[1]}]") if stride != [1, 1]: params.append(f"strides=[{stride[0]}, {stride[1]}]") From 8d7a0eee16f128f130f1cd1be53f81a234dc140d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= Date: Mon, 13 Apr 2026 16:20:19 +0200 Subject: [PATCH 9/9] expand code coverage and add to tests --- .github/workflows/test.yml | 14 +- tests/conftest.py | 32 +++ tests/models.py | 262 +++++++++++++++++++++++- tests/test_shape_operations.py | 52 ++++- tests/test_single_ops.py | 52 ++++- webnn_torch_export/webnn_generator.py | 34 +-- webnn_torch_export/webnn_op_mappings.py | 1 - 7 files changed, 406 insertions(+), 41 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8716057..84d66b4 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,4 +36,16 @@ jobs: pip install -e ".[dev]" - name: Run tests - run: pytest tests/ -v --tb=short -m "not slow" + run: | + pytest tests/ -v --tb=short -m "not slow" \ + --cov=webnn_torch_export \ + --cov-report=term-missing \ + --cov-report=xml \ + --cov-fail-under=75 + + - name: Upload coverage report + if: always() + uses: actions/upload-artifact@v4 + with: + name: coverage-xml + path: coverage.xml diff --git a/tests/conftest.py b/tests/conftest.py index d5191d9..c4b2820 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,6 +63,38 @@ def assert_export_matches( return exported_callable, exporter +def assert_generates_webnn( + model: torch.nn.Module, + example_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + debug: bool = False, +) -> str: + """ + Export a model and generate its WebNN graph text, returning the text. + + Does NOT require the WebNN runtime — only exercises the generator. + Useful for asserting that the generator produces valid output for a given op. + """ + from webnn_torch_export import export_model + import tempfile + import os + + if not isinstance(example_input, tuple): + example_input = (example_input,) + + model.eval() + exporter, _ = export_model(model, example_input, debug=debug) + + with tempfile.NamedTemporaryFile(mode="w", suffix=".webnn", delete=False) as f: + path = f.name + try: + exporter.save_to_webnn(path) + with open(path) as f: + return f.read() + finally: + if os.path.exists(path): + os.unlink(path) + + def validate_webnn_execution( model: torch.nn.Module, example_input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], diff --git a/tests/models.py b/tests/models.py index dd43da7..67d1045 100644 --- a/tests/models.py +++ b/tests/models.py @@ -16,9 +16,13 @@ class SingleConv(nn.Module): """Wrapper for testing single Conv2d operation""" - def __init__(self, in_channels=3, out_channels=16, kernel_size=3, padding=1): + def __init__(self, in_channels=3, out_channels=16, kernel_size=3, padding=1, + stride=1, dilation=1, bias=True): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding) + self.conv = nn.Conv2d( + in_channels, out_channels, kernel_size, + padding=padding, stride=stride, dilation=dilation, bias=bias, + ) def forward(self, x): return self.conv(x) @@ -259,6 +263,37 @@ def forward(self, x): return x +class SplitEven(nn.Module): + """Split tensor into equal-sized chunks along a given dim.""" + + def __init__(self, split_size: int, dim: int = 0): + super().__init__() + self.split_size = split_size + self.dim = dim + + def forward(self, x): + chunks = torch.split(x, self.split_size, dim=self.dim) + # Sum all chunks so the model has a single tensor output + out = chunks[0] + for c in chunks[1:]: + out = out + c + return out + + +class SplitWithSizes(nn.Module): + """Split tensor into variable-sized sections along a given dim.""" + + def __init__(self, split_sizes, dim: int = 0): + super().__init__() + self.split_sizes = split_sizes + self.dim = dim + + def forward(self, x): + chunks = torch.split(x, self.split_sizes, dim=self.dim) + # Concatenate back so the model has a single tensor output + return torch.cat(chunks, dim=self.dim) + + class MNISTClassifier(nn.Module): """ Full MNIST classifier: @@ -281,3 +316,226 @@ def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return x + + +# --------------------------------------------------------------------------- +# Activation function wrappers +# --------------------------------------------------------------------------- + +class SingleGELU(nn.Module): + def forward(self, x): + return F.gelu(x) + + +class SingleSiLU(nn.Module): + def forward(self, x): + return F.silu(x) + + +class SingleHardtanh(nn.Module): + def forward(self, x): + return F.hardtanh(x, min_val=-1.0, max_val=1.0) + + +class SingleClamp(nn.Module): + def forward(self, x): + return torch.clamp(x, min=-1.0, max=1.0) + + +# --------------------------------------------------------------------------- +# Elementwise math wrappers +# --------------------------------------------------------------------------- + +class SingleNeg(nn.Module): + def forward(self, x): + return torch.neg(x) + + +class SingleCos(nn.Module): + def forward(self, x): + return torch.cos(x) + + +class SingleSin(nn.Module): + def forward(self, x): + return torch.sin(x) + + +class SinglePowTensor(nn.Module): + """aten.pow.Tensor_Scalar — tensor raised to a scalar exponent.""" + def forward(self, x): + return x ** 2 + + +# --------------------------------------------------------------------------- +# Shape manipulation wrappers +# --------------------------------------------------------------------------- + +class SingleUnsqueeze(nn.Module): + def __init__(self, dim: int = 1): + super().__init__() + self.dim = dim + + def forward(self, x): + return x.unsqueeze(self.dim) + + +class SingleSqueeze(nn.Module): + def __init__(self, dim: int = 1): + super().__init__() + self.dim = dim + + def forward(self, x): + return x.squeeze(self.dim) + + +class SingleCat(nn.Module): + def __init__(self, dim: int = 0): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.cat([x, x], dim=self.dim) + + +class SingleStack(nn.Module): + def __init__(self, dim: int = 0): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.stack([x, x], dim=self.dim) + + +class SingleChunk(nn.Module): + def __init__(self, n: int = 3, dim: int = 0): + super().__init__() + self.n = n + self.dim = dim + + def forward(self, x): + chunks = torch.chunk(x, self.n, dim=self.dim) + out = chunks[0] + for c in chunks[1:]: + out = out + c + return out + + +class SingleUnbind(nn.Module): + def __init__(self, dim: int = 0): + super().__init__() + self.dim = dim + + def forward(self, x): + parts = torch.unbind(x, dim=self.dim) + out = parts[0] + for p in parts[1:]: + out = out + p + return out + + +class SingleSelect(nn.Module): + def __init__(self, dim: int = 0, index: int = 0): + super().__init__() + self.dim = dim + self.index = index + + def forward(self, x): + return x.select(self.dim, self.index) + + +class SingleSlice(nn.Module): + def forward(self, x): + return x[:, 1:3] + + +class SingleExpand(nn.Module): + """Input must be shape (1, D); expands to (N, D).""" + def __init__(self, n: int = 3): + super().__init__() + self.n = n + + def forward(self, x): + return x.expand(self.n, -1) + + +class SinglePad(nn.Module): + def __init__(self, padding=(1, 1)): + super().__init__() + self.padding = padding + + def forward(self, x): + return F.pad(x, self.padding) + + +class SingleTranspose(nn.Module): + def __init__(self, dim0: int = 0, dim1: int = 1): + super().__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x): + return x.transpose(self.dim0, self.dim1) + + +# --------------------------------------------------------------------------- +# Reduction / pooling wrappers +# --------------------------------------------------------------------------- + +class SingleMeanDim(nn.Module): + def __init__(self, dim=-1, keepdim=False): + super().__init__() + self.dim = dim + self.keepdim = keepdim + + def forward(self, x): + return x.mean(dim=self.dim, keepdim=self.keepdim) + + +class SingleGlobalAvgPool(nn.Module): + def forward(self, x): + return F.adaptive_avg_pool2d(x, 1) + + +class SingleGroupNorm(nn.Module): + def __init__(self, num_groups: int = 2, num_channels: int = 4): + super().__init__() + self.gn = nn.GroupNorm(num_groups, num_channels) + + def forward(self, x): + return self.gn(x) + + +# --------------------------------------------------------------------------- +# Constant generation wrappers +# --------------------------------------------------------------------------- + +class AddWithZeros(nn.Module): + """Uses aten.zeros.default to create a zero tensor and adds it to x.""" + def forward(self, x): + z = torch.zeros(x.shape[-1]) + return x + z + + +class AddWithOnes(nn.Module): + """Uses aten.ones.default to create an all-ones tensor and adds it to x.""" + def forward(self, x): + o = torch.ones(x.shape[-1]) + return x + o + + +class AddWithFull(nn.Module): + """Uses aten.full.default to create a constant tensor and adds it to x.""" + def forward(self, x): + c = torch.full((x.shape[-1],), 0.5) + return x + c + + +# --------------------------------------------------------------------------- +# Type cast wrappers +# --------------------------------------------------------------------------- + +class CastToFloat16AndBack(nn.Module): + """Cast to float16 then back to float32 (exercises _convert_cast).""" + def forward(self, x): + return x.to(torch.float16).to(torch.float32) diff --git a/tests/test_shape_operations.py b/tests/test_shape_operations.py index ac304ac..950c524 100644 --- a/tests/test_shape_operations.py +++ b/tests/test_shape_operations.py @@ -2,8 +2,9 @@ import pytest import torch -from .models import RearrangeModel +from .models import RearrangeModel, SplitEven, SplitWithSizes from .conftest import assert_export_matches, validate_webnn_execution +import torch._dynamo def test_rearrange_export(): @@ -18,3 +19,52 @@ def test_rearrange_export(): assert output.shape == (1, 3, 8, 8) validate_webnn_execution(rearrange_model, rearrange_input, rtol=1e-5, atol=1e-5) + +# --------------------------------------------------------------------------- +# Even split (aten.split.Tensor) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("split_size,dim,shape", [ + (2, 0, (6, 4)), # 3 chunks along dim 0 + (4, 1, (2, 8, 3)), # 2 chunks along dim 1 + (1, -1, (3, 4)), # 4 chunks along last dim (negative index) +]) +def test_split_even_export(split_size, dim, shape): + torch._dynamo.reset() + model = SplitEven(split_size, dim) + x = torch.randn(*shape) + assert_export_matches(model, x, rtol=1e-5, atol=1e-5) + + +# --------------------------------------------------------------------------- +# Uneven / variable-size split (aten.split_with_sizes.default) +# --------------------------------------------------------------------------- + +@pytest.mark.parametrize("split_sizes,dim,shape", [ + ([2, 3, 1], 0, (6, 4)), # uneven along dim 0 + ([3, 5], 1, (2, 8, 3)), # uneven along dim 1 + ([1, 2, 1], -1, (3, 4)), # uneven along last dim +]) +def test_split_with_sizes_export(split_sizes, dim, shape): + torch._dynamo.reset() + model = SplitWithSizes(split_sizes, dim) + x = torch.randn(*shape) + assert_export_matches(model, x, rtol=1e-5, atol=1e-5) + + +# --------------------------------------------------------------------------- +# WebNN execution tests +# --------------------------------------------------------------------------- + +def test_split_even_webnn(): + torch._dynamo.reset() + model = SplitEven(split_size=2, dim=0) + x = torch.randn(6, 4) + validate_webnn_execution(model, x, rtol=1e-4, atol=1e-4) + + +def test_split_with_sizes_webnn(): + torch._dynamo.reset() + model = SplitWithSizes([2, 3, 1], dim=0) + x = torch.randn(6, 4) + validate_webnn_execution(model, x, rtol=1e-4, atol=1e-4) diff --git a/tests/test_single_ops.py b/tests/test_single_ops.py index 204fa8c..de96c21 100644 --- a/tests/test_single_ops.py +++ b/tests/test_single_ops.py @@ -10,22 +10,44 @@ class ConvOpConfig: - def __init__(self, in_cannels, out_channels, kernel_size, stride, padding): - self.in_cannels = in_cannels + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, bias=True): + self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding + self.dilation = dilation + self.bias = bias def name(self): - return f"conv_{self.in_cannels}x{self.out_channels}_{self.kernel_size}x{self.kernel_size}" - - -CONV_OPS = [ # config, input_shape - (ConvOpConfig(16, 32, 3, stride=1, padding=1), (1, 16, 10, 10)), - (ConvOpConfig(1, 32, 5, stride=1, padding=2), (1, 1, 28, 28)), - (ConvOpConfig(3, 64, 3, stride=1, padding=1), (1, 3, 28, 28)), - (ConvOpConfig(16, 32, 1, stride=1, padding=0), (1, 16, 28, 28)), + parts = [f"conv_{self.in_channels}x{self.out_channels}_k{self.kernel_size}"] + if self.stride != 1: + parts.append(f"s{self.stride}") + if self.dilation != 1: + parts.append(f"d{self.dilation}") + if not self.bias: + parts.append("nobias") + return "_".join(parts) + + +CONV_OPS = [ # (config, input_shape) + # baseline + (ConvOpConfig(16, 32, 3, padding=1), (1, 16, 10, 10)), + (ConvOpConfig(1, 32, 5, padding=2), (1, 1, 28, 28)), + (ConvOpConfig(3, 64, 3, padding=1), (1, 3, 28, 28)), + (ConvOpConfig(16, 32, 1, padding=0), (1, 16, 28, 28)), + # no bias + (ConvOpConfig(16, 32, 3, padding=1, bias=False), (1, 16, 10, 10)), + (ConvOpConfig(3, 64, 3, padding=1, bias=False), (1, 3, 28, 28)), + # stride > 1 + (ConvOpConfig(16, 32, 3, stride=2, padding=1), (1, 16, 14, 14)), + (ConvOpConfig(3, 32, 3, stride=2, padding=1), (1, 3, 28, 28)), + (ConvOpConfig(16, 32, 3, stride=2, padding=1, bias=False), (1, 16, 14, 14)), + # dilation > 1 (padding = dilation to preserve spatial size) + (ConvOpConfig(16, 32, 3, dilation=2, padding=2), (1, 16, 14, 14)), + (ConvOpConfig(3, 32, 3, dilation=2, padding=2), (1, 3, 28, 28)), + (ConvOpConfig(16, 32, 3, dilation=2, padding=2, bias=False), (1, 16, 14, 14)), ] class GemmOpConfig: @@ -78,7 +100,15 @@ def name(self): ) def test_conv_op(conv_config, input_shape): torch._dynamo.reset() - model = SingleConv(conv_config.in_cannels, conv_config.out_channels, conv_config.kernel_size, conv_config.stride) + model = SingleConv( + conv_config.in_channels, + conv_config.out_channels, + conv_config.kernel_size, + padding=conv_config.padding, + stride=conv_config.stride, + dilation=conv_config.dilation, + bias=conv_config.bias, + ) x = torch.randn(*input_shape) assert_export_matches(model, x, rtol=1e-3) validate_webnn_execution(model, x) diff --git a/webnn_torch_export/webnn_generator.py b/webnn_torch_export/webnn_generator.py index f7d24c0..e216887 100644 --- a/webnn_torch_export/webnn_generator.py +++ b/webnn_torch_export/webnn_generator.py @@ -280,27 +280,7 @@ def _convert_conv2d(self, node: fx.Node, output: str, inputs: List[str]) -> str: bias_info = (bias_operand, bias_shape[0] if bias_shape else 0) return self._emit_conv2d(input_tensor, weight, bias_info, stride, padding, dilation, groups, output) - - def _convert_convolution(self, node: fx.Node, output: str, inputs: List[str]) -> str: - """aten.convolution.default(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups)""" - args = node.args - input_tensor = inputs[0] if inputs else "unknown" - weight = self._get_input_operand(args[1]) if len(args) > 1 else "unknown" - - stride = args[3] if len(args) > 3 else [1, 1] - padding = args[4] if len(args) > 4 else [0, 0] - dilation = args[5] if len(args) > 5 else [1, 1] - groups = args[8] if len(args) > 8 else 1 - - bias_info = None - bias_node = args[2] if len(args) > 2 else None - if isinstance(bias_node, fx.Node): - bias_operand = self._get_input_operand(bias_node) - bias_shape = self.operand_shapes.get(bias_operand, []) - bias_info = (bias_operand, bias_shape[0] if bias_shape else 0) - - return self._emit_conv2d(input_tensor, weight, bias_info, stride, padding, dilation, groups, output) - + # --- Linear --- def _convert_linear(self, node: fx.Node, output: str, inputs: List[str]) -> str: @@ -928,6 +908,8 @@ def _convert_split(self, node: fx.Node, output: str, inputs: List[str]) -> str: x = inputs[0] if inputs else "unknown" sections = node.args[1] if len(node.args) > 1 else None dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0) + if dim == -1: + dim = len(self._get_node_shape(node.args[0])) - 1 if isinstance(sections, (list, tuple)): # Multi-output split: pre-allocate one operand per section out_ops = [] @@ -942,10 +924,12 @@ def _convert_split(self, node: fx.Node, output: str, inputs: List[str]) -> str: if in_shape and sections is not None: dim_n = int(dim) % len(in_shape) dim_size = in_shape[dim_n] - n = int(sections) - base = dim_size // n - rem = dim_size % n - sizes = [base + (1 if i < rem else 0) for i in range(n)] + split_size = int(sections) + n = dim_size // split_size + rem = dim_size % split_size + sizes = [split_size] * n + if rem: + sizes.append(rem) out_ops = [] for _ in sizes: op = f"operand_{self.operand_counter}" diff --git a/webnn_torch_export/webnn_op_mappings.py b/webnn_torch_export/webnn_op_mappings.py index be1f1d8..b3449c2 100644 --- a/webnn_torch_export/webnn_op_mappings.py +++ b/webnn_torch_export/webnn_op_mappings.py @@ -10,7 +10,6 @@ ATEN_OP_TABLE: Dict[str, str] = { # Convolution "aten.conv2d.default": "_convert_conv2d", - "aten.convolution.default": "_convert_convolution", # Linear / matmul "aten.linear.default": "_convert_linear",