14 changes: 13 additions & 1 deletion .github/workflows/test.yml
@@ -36,4 +36,16 @@ jobs:
           pip install -e ".[dev]"
 
       - name: Run tests
-        run: pytest tests/ -v --tb=short -m "not slow"
+        run: |
+          pytest tests/ -v --tb=short -m "not slow" \
+            --cov=webnn_torch_export \
+            --cov-report=term-missing \
+            --cov-report=xml \
+            --cov-fail-under=75
+
+      - name: Upload coverage report
+        if: always()
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-xml
+          path: coverage.xml
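For orientation: `--cov-report=xml` writes a Cobertura-style `coverage.xml`, and `--cov-fail-under=75` fails the run when total coverage drops below 75%. Below is a rough sketch of checking a downloaded report by hand. It assumes the root `<coverage>` element carries a `line-rate` attribute and it ignores branch coverage, so it may not match pytest-cov's total exactly. The helper names and the sample document are hypothetical.

```python
import xml.etree.ElementTree as ET


def line_coverage_percent(xml_text: str) -> float:
    """Return overall line coverage (0-100) from a Cobertura-style report."""
    root = ET.fromstring(xml_text)
    # Assumption: the root element exposes line-rate as a fraction in [0, 1].
    return float(root.attrib["line-rate"]) * 100.0


def meets_threshold(xml_text: str, threshold: float = 75.0) -> bool:
    """Mirror the intent of --cov-fail-under for line coverage only."""
    return line_coverage_percent(xml_text) >= threshold


# Hypothetical, heavily trimmed report used only for illustration.
SAMPLE = '<coverage line-rate="0.8125" branch-rate="0"><packages/></coverage>'

if __name__ == "__main__":
    print(f"{line_coverage_percent(SAMPLE):.2f}% ok={meets_threshold(SAMPLE)}")
```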
120 changes: 120 additions & 0 deletions ADDING_OPS.md
@@ -0,0 +1,120 @@
# Adding ATen Ops

The exporter uses `torch.export.export`, which produces an ATen IR graph. Every
`call_function` node has a target like `aten.relu.default`. The dispatch table in
`webnn_op_mappings.py` maps `str(node.target)` → a method name on
`WebNNGraphGenerator`. If the target isn't in the table, the node is emitted as a
comment and the resulting `.webnn` file will fail to parse.
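
The lookup described above can be pictured with a small stand-alone sketch. This is not the project's code: `ToyGenerator`, the one-entry table, the `relu(...)` output text, and the `//` comment syntax are all assumptions made for illustration.

```python
from typing import Dict, List

# Hypothetical one-entry table; the real ATEN_OP_TABLE lives in webnn_op_mappings.py.
ATEN_OP_TABLE: Dict[str, str] = {
    "aten.relu.default": "_convert_relu",
}


class ToyGenerator:
    """Stand-in for WebNNGraphGenerator, for illustration only."""

    def _convert_relu(self, node, output: str, inputs: List[str]) -> str:
        return f"[{output}] = relu({inputs[0]});"

    def emit(self, target: str, output: str, inputs: List[str]) -> str:
        # Map the target string to a method name, then resolve it on self.
        method_name = ATEN_OP_TABLE.get(target)
        if method_name is None:
            # Unmapped target: a comment line is emitted (syntax assumed here).
            return f"// unsupported: {target}"
        return getattr(self, method_name)(None, output, inputs)


if __name__ == "__main__":
    gen = ToyGenerator()
    print(gen.emit("aten.relu.default", "operand_1", ["operand_0"]))
    print(gen.emit("aten.gelu.default", "operand_2", ["operand_1"]))
```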

## Quick reference: adding an op

**1. Add an entry to `ATEN_OP_TABLE` in `webnn_op_mappings.py`:**

```python
"aten.some_op.overload": "_convert_some_op",
```

**2. Implement `_convert_some_op` in `webnn_generator.py`:**

```python
def _convert_some_op(self, node: fx.Node, output: str, inputs: List[str]) -> str:
    x = inputs[0]
    # node.args holds positional args (the input tensor node typically comes first)
    # node.kwargs holds keyword args
    return f"[{output}] = webnnOp({x});"
```

The method receives:
- `output` — pre-allocated operand name for this node's result
- `inputs` — operand names for every `fx.Node` in `node.args`, in order
- `node.args` / `node.kwargs` — raw args (scalars, lists, bools) from the ATen schema

Return `None` instead of a string to suppress the node entry entirely (use this for
ops that bake a constant into `inline_constants` and override `node_to_operand`, like
`_convert_arange`).
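
One way to picture the `None` convention is the sketch below. It assumes the generator keeps a `node_to_operand` map and a list of emitted lines, as the text suggests. `ToyState`, `record`, and `convert_constant` are hypothetical names, not part of the exporter.

```python
from typing import Dict, List, Optional


class ToyState:
    """Minimal stand-in for the generator's bookkeeping (illustrative only)."""

    def __init__(self) -> None:
        self.node_to_operand: Dict[str, str] = {}
        self.inline_constants: Dict[str, list] = {}
        self.lines: List[str] = []

    def record(self, node_name: str, output: str, result: Optional[str]) -> None:
        if result is None:
            # The converter already redirected node_to_operand; emit nothing.
            return
        self.node_to_operand.setdefault(node_name, output)
        self.lines.append(result)


def convert_constant(state: ToyState, node_name: str) -> Optional[str]:
    # Bake the values into inline_constants and point the node at them.
    state.inline_constants["const_0"] = [0, 1, 2]
    state.node_to_operand[node_name] = "const_0"
    return None


if __name__ == "__main__":
    state = ToyState()
    state.record("relu", "operand_1", "[operand_1] = relu(operand_0);")
    state.record("arange", "operand_2", convert_constant(state, "arange"))
    print(state.lines, state.node_to_operand)
```

Downstream nodes that look up `arange` in `node_to_operand` would then receive the constant's name instead of the pre-allocated operand.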

**3. To find the exact target string for an unfamiliar op:**

```python
import torch

# `model` is your nn.Module and `x` an example input tensor
ep = torch.export.export(model, (x,))
for n in ep.graph.nodes:
    if n.op == "call_function":
        print(str(n.target), n.args, n.kwargs)
```

---

## Example patterns

### No-op
```python
# webnn_op_mappings.py
"aten.contiguous.default": "_convert_identity",
```

### Direct WebNN equivalent
```python
# webnn_op_mappings.py
"aten.bmm.default": "_convert_matmul",
```
`_convert_matmul` already exists; nothing else needed.

### Parameterised activation (e.g. relu6)
```python
# webnn_op_mappings.py
"aten.relu6.default": "_convert_relu6",

# webnn_generator.py
def _convert_relu6(self, node: fx.Node, output: str, inputs: List[str]) -> str:
    return f"[{output}] = clamp({inputs[0]}, minValue=0.0, maxValue=6.0);"
```
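
As a sanity check on the mapping, relu6 is `min(max(x, 0), 6)`, which is what a clamp to `[0, 6]` computes. A plain-Python reference is sketched below. The function names are hypothetical and nothing here calls WebNN or PyTorch.

```python
def relu6_reference(x: float) -> float:
    """Textbook definition of relu6."""
    return min(max(x, 0.0), 6.0)


def clamp(x: float, min_value: float, max_value: float) -> float:
    """Scalar clamp, mirroring the minValue/maxValue pair used in the converter."""
    return max(min_value, min(x, max_value))


if __name__ == "__main__":
    for value in (-3.0, 0.0, 2.5, 6.0, 9.0):
        print(value, relu6_reference(value), clamp(value, 0.0, 6.0))
```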

### Scalar-first subtract (`1 - x`)
```python
# webnn_op_mappings.py
"aten.rsub.Scalar": "_convert_rsub_scalar",

# webnn_generator.py
def _convert_rsub_scalar(self, node: fx.Node, output: str, inputs: List[str]) -> str:
    # rsub(x, scalar) = scalar - x
    scalar = node.args[1]
    c = self._create_inline_constant(float(scalar))
    return f"[{output}] = sub({c}, {inputs[0]});"
```

### Op that produces a constant (no graph node needed)
```python
# webnn_op_mappings.py
"aten.full.default": "_convert_full",

# webnn_generator.py — return None to skip nodes {} entry
def _convert_full(self, node: fx.Node, output: str, inputs: List[str]) -> Optional[str]:
    size = list(node.args[0])
    fill = node.args[1]
    dtype = node.kwargs.get("dtype", torch.float32) or torch.float32
    values = torch.full(size, fill, dtype=dtype)
    name = f"const_full_{self.operand_counter}"; self.operand_counter += 1
    self.inline_constants[name] = values
    self.operand_shapes[name] = list(values.shape)
    self.node_to_operand[node.name] = name  # redirect downstream refs
    return None
```

### Multi-step decomposition
```python
def _convert_select_int(self, node: fx.Node, output: str, inputs: List[str]) -> str:
    # aten.select.int(input, dim, index) → slice dim then squeeze
    x = inputs[0]
    in_shape = self._get_node_shape(node.args[0])
    dim = int(node.args[1]) % len(in_shape)
    idx = int(node.args[2])
    starts = [0] * len(in_shape); starts[dim] = idx
    sizes = list(in_shape); sizes[dim] = 1
    out_shape = in_shape[:dim] + in_shape[dim+1:]
    sliced = f"operand_{self.operand_counter}"; self.operand_counter += 1
    s_str = ", ".join(map(str, starts))
    sz_str = ", ".join(map(str, sizes))
    os_str = ", ".join(map(str, out_shape))
    return (f"[{sliced}] = slice({x}, starts=[{s_str}], sizes=[{sz_str}]);\n"
            f"\t[{output}] = reshape({sliced}, newShape=[{os_str}]);")
```
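
The shape arithmetic in this decomposition can be checked in isolation. The helper below is a hypothetical stand-alone version that works on static integer shapes only. It also wraps a negative `index`, which is an added assumption; the converter above passes the index through unchanged.

```python
from typing import List, Tuple


def select_as_slice(
    in_shape: List[int], dim: int, idx: int
) -> Tuple[List[int], List[int], List[int]]:
    """Return (starts, sizes, out_shape) for select(input, dim, idx)."""
    dim = dim % len(in_shape)  # normalise a negative dim
    if idx < 0:
        idx += in_shape[dim]  # assumption: wrap a negative index
    starts = [0] * len(in_shape)
    starts[dim] = idx
    sizes = list(in_shape)
    sizes[dim] = 1  # keep a single element along dim
    out_shape = in_shape[:dim] + in_shape[dim + 1:]  # drop dim (the "squeeze")
    return starts, sizes, out_shape


if __name__ == "__main__":
    print(select_as_slice([2, 3, 4], -1, 2))
    print(select_as_slice([2, 3, 4], 1, -1))
```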
241 changes: 2 additions & 239 deletions README.md
@@ -15,8 +15,8 @@ This is an early-stage experimental implementation for research and exploration.
### For Development

```bash
-# Clone the repository
-git clone https://github.com/yourusername/webnn_torch_export.git
+# Clone your forked repository
+git clone https://github.com/<yourusername>/webnn_torch_export.git
cd webnn_torch_export

# Install in editable mode with dev dependencies
@@ -88,243 +88,6 @@ new_model = nn.Sequential(
load_weights_from_safetensors(new_model, "model_weights.safetensors")
```

### Run Basic Example

```bash
# Using the installed command
webnn-export

# Or run directly
python -m webnn_torch_export.exporter
```

## Key Components

### CustomExporter

The `CustomExporter` class is a Dynamo backend that:
1. Receives FX graphs from PyTorch's compilation process
2. Converts them to a custom format (JSON)
3. Provides debug output to understand graph structure
4. Maintains execution compatibility

**Key methods:**
- `export_graph()`: Main callback that receives FX graphs
- `_convert_fx_to_custom_format()`: Converts FX graph to JSON
- `save_to_file()`: Exports graphs to JSON files

### Test Infrastructure

**Single Operator Tests** (`tests/test_single_ops.py`):
- `test_conv2d_export()`: Tests Conv2d export
- `test_matmul_export()`: Tests matmul export
- `test_linear_export()`: Tests Linear layer export
- `test_conv_with_different_configs()`: Parametrized tests for various Conv2d configurations
- `test_exported_graph_structure()`: Validates exported graph structure

**Integration Tests** (`tests/test_mnist_integration.py`):
- `SimplerMNISTClassifier`: Conv + ReLU + Linear
- `MNISTClassifier`: Full classifier with 2 conv blocks
- `test_simple_mnist_export()`: Exports simple model
- `test_full_mnist_export()`: Exports full model
- `test_mnist_inference_consistency()`: Tests consistency across multiple runs
- `test_mnist_batch_size_invariance()`: Tests with different batch sizes

## How It Works

### 1. Dynamo Backend Registration

```python
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
    # Your export logic here
    return gm

compiled_model = torch.compile(model, backend=custom_backend)
```

### 2. FX Graph Structure

When Dynamo compiles a model, it produces an FX graph with nodes representing:
- **placeholder**: Input tensors
- **call_function**: Function calls (e.g., `torch.relu`, `torch.matmul`)
- **call_module**: Module invocations (e.g., `conv1`, `fc1`)
- **call_method**: Tensor method calls (e.g., `x.flatten()`)
- **output**: Return values

### 3. Export Flow

```
PyTorch Model → torch.compile() → Dynamo → FX Graph → Custom Backend → Export Format
Your Export Logic
```

## Debug Output

With `debug=True`, the exporter prints:
- Complete FX graph representation
- Generated Python code
- Individual node details:
  - Node name and operation type
  - Target function/module
  - Arguments and keyword arguments
  - Tensor metadata (shapes, dtypes)

## Example Output

```
================================================================================
DYNAMO EXPORT CALLBACK TRIGGERED
================================================================================

Graph Module:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%conv1,), kwargs = {})
    return (relu,)

Node: x
Op: placeholder
Target: x
...
```

## Exported JSON Format

```json
{
  "nodes": [
    {
      "name": "x",
      "op": "placeholder",
      "target": "x",
      "args": [],
      "kwargs": {}
    },
    {
      "name": "conv1",
      "op": "call_module",
      "target": "conv1",
      "module": "conv1",
      "args": ["x"],
      "kwargs": {}
    }
  ],
  "graph_str": "graph(): ...",
  "code": "def forward(self, x): ..."
}
```

## Extending the Exporter

### Adding New Operator Support

When you export a model with unsupported operations, you'll get a **clear error message** showing exactly what's missing:

```
================================================================================
UNSUPPORTED OPERATION
================================================================================
Operation: layer_norm
Node: layer_norm_output
Target: <function layer_norm at 0x...>
Schema: aten::layer_norm(Tensor input, int[] normalized_shape, ...)
Args: ['input_tensor', '[3072]', 'weight', 'bias', '1e-5']
Kwargs: {}
================================================================================

This operation is not yet supported in WebNN export.
To add support, update webnn_op_mappings.py with a mapping for this operation.
```

This makes it easy to **incrementally add support** for operations as you need them.

**Quick Steps:**

1. **Run your export** - get the error showing the unsupported operation
2. **Add mapping** in `webnn_torch_export/webnn_op_mappings.py`:
   ```python
   TARGET_CONTAINS_TO_CONVERTER: Dict[str, ConverterFn] = {
       # ... existing mappings ...
       "layer_norm": lambda gen, node, output, inputs: gen._convert_layer_norm(node, output, inputs),
   }
   ```
3. **Implement converter** in `webnn_torch_export/webnn_generator.py`:
   ```python
   def _convert_layer_norm(self, node: fx.Node, output: str, inputs: List[str]) -> str:
       """Convert LayerNorm to WebNN"""
       input_tensor = inputs[0] if inputs else 'unknown'
       # ... conversion logic ...
       return f'[{output}] = layerNormalization({input_tensor}, ...);'
   ```
4. **Test** - run export again, repeat for next unsupported operation

**For detailed guidance, see [ADDING_OPS.md](ADDING_OPS.md)** - a comprehensive guide covering:
- How to map PyTorch operations to WebNN
- Common patterns (activations, normalization, matrix ops)
- Step-by-step walkthrough with examples
- WebNN operation reference
- Debugging tips

### Custom Export Format

Modify `_convert_fx_to_custom_format()` to output your desired format:
```python
def _convert_fx_to_custom_format(self, gm):
    # Convert to your format (protobuf, flatbuffer, etc.)
    my_format = convert_to_my_format(gm.graph)
    return my_format
```

## Development

### Running Tests

```bash
# Run all tests
pytest

# Run with coverage
pytest --cov=webnn_torch_export --cov-report=html

# Run specific markers
pytest -m "not slow"
pytest -m integration
```

### Building the Package

```bash
# Build distribution
python -m build

# Install locally
pip install -e ".[dev]"
```

## Requirements

- PyTorch 2.0+ (for `torch.compile` support)
- Python 3.8+

## Tips for Debugging

1. **Start with `debug=True`** to see full graph output
2. **Use single operator tests** to understand individual operations
3. **Check node metadata** for tensor shapes and types
4. **Verify correctness** by comparing original vs compiled outputs
5. **Examine exported JSON** to understand graph structure

## Next Steps

- Add support for more operators (pooling, normalization, etc.)
- Implement graph optimization passes
- Add serialization to binary formats (protobuf, flatbuffer)
- Handle dynamic shapes
- Support quantized models
- Add execution validation tests

## License

Apache License (2.0) (see LICENSE file)