Switch the exporter to aten IR #1

Merged
mtavenrath merged 9 commits into rustnn:main from gedoensmax:maximilianm/aten_ir on Apr 28, 2026

Conversation

@gedoensmax

Contributor

Parsing the FX graph was a bit too much, so I lowered to aten, which seems to be a good compromise.
I also updated the flux sample.

@gedoensmax

Contributor Author

@mtavenrath I think the unit test failures are not caused by the exporter.

For Conv I am getting this error:

tests/conftest.py:151: in validate_webnn_execution
    if not torch.allclose(exp, actual, rtol=rtol, atol=atol):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   RuntimeError: The size of tensor a (26) must match the size of tensor b (24) at non-singleton dimension 3

This indicates that the padding is parsed or forwarded to ONNX incorrectly during deserialization of the graph below:

webnn_graph "model" v1 {
  inputs { x: f32[1, 1, 28, 28]; }
  consts {
	weight_1: f32[32, 1, 5, 5] @weights("conv.weight");
	weight_2: f32[32] @weights("conv.bias");
  }
  nodes {
	[operand_4] = reshape(weight_2, newShape=[1, 32, 1, 1]);
	[operand_5] = conv2d(x, weight_1, filterLayout="oihw", groups=1, inputLayout="nchw", pads=[1, 1, 1, 1]);
	[operand_3] = add(operand_5, operand_4);
  }
  outputs { operand_3; }
}
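
The two sizes in the error message line up with the usual convolution output-size formula. The sketch below is only an illustration: `conv2d_output_size` is a hypothetical helper, not part of the exporter or of webnn-graph, and it assumes stride 1 and dilation 1 because the graph states neither.

```python
def conv2d_output_size(in_size, kernel, pad_begin=0, pad_end=0, stride=1, dilation=1):
    """Output size of a convolution along one spatial axis."""
    # A dilated kernel covers dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_begin + pad_end - effective_kernel) // stride + 1


# Values taken from the graph above: 28x28 input, 5x5 filter, pads=[1, 1, 1, 1].
with_padding = conv2d_output_size(28, 5, pad_begin=1, pad_end=1)
# Same convolution if the pads attribute were dropped somewhere along the way.
without_padding = conv2d_output_size(28, 5)

print(with_padding, without_padding)
```

Under these assumptions, a width of 26 corresponds to the padding being applied and 24 to the padding being ignored, which would fit the suspicion that `pads` is lost before the graph reaches ONNX.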

@gedoensmax

Contributor Author

The concat operator seems to have issues as well:

webnn_graph "model" v1 {
  inputs { tensors_0: f32[2, 3, 4, 5]; tensors_1: f32[2, 3, 4, 5]; }
  nodes {
	[operand_1] = concat([tensors_0, tensors_1], axis=0);
  }
  outputs { operand_1; }
}

Runtime error:

E   RuntimeError: ONNX execution failed: onnx runtime failed: load model failed: This is an invalid model. In Node, ("concat_0", Concat, "", -1) : () -> ("operand_1": tensor(float),) , Error Node(concat_0) with schema(::Concat:13) has input size 0 not in range [min=1, max=2147483647].
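
The error reports a Concat node that reached ONNX with zero inputs, although the graph lists two. As a rough sketch of what a correct result would look like, the hypothetical helper below (`concat_shape` is an illustrative name, not an API from this project) computes the expected output shape and rejects an empty input list.

```python
def concat_shape(shapes, axis):
    """Expected output shape when concatenating tensors of the given shapes."""
    if not shapes:
        # Mirrors the ONNX schema rule that Concat needs at least one input.
        raise ValueError("concat needs at least one input")
    first = list(shapes[0])
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError("all inputs must have the same rank")
        for dim, (a, b) in enumerate(zip(first, shape)):
            if dim != axis and a != b:
                raise ValueError(f"dimension {dim} differs: {a} vs {b}")
    result = first[:]
    # Only the concatenation axis grows; every other dimension is unchanged.
    result[axis] = sum(shape[axis] for shape in shapes)
    return result


# Shapes from the graph above, concatenated along axis 0.
expected = concat_shape([[2, 3, 4, 5], [2, 3, 4, 5]], axis=0)
print(expected)
```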

@mtavenrath

Contributor

There is a PR that fixes several webnn-graph issues: rustnn/webnn-graph#10. Can you give it a try?

@gedoensmax

Contributor Author

This did not resolve the issues, as far as I can tell. I retriggered the CI, since pywebnn points to webnn-graph main anyway.

@mtavenrath

Contributor

This CL uses a batched GEMM in the .webnn export, which is not a supported WebNN operation. I'm going to bring this up with the WebNN WG, given that matmul supports batching.

@tarekziade


For now, maybe we can skip this one with a note.

@mtavenrath

Contributor

@gedoensmax already replaced the GEMM with matmul, mul, and add operations. This PR can successfully convert all three phases of flux 2 klein 8b to a .webnn file.
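
The thread does not show how the replacement was done, so the following is only a sketch of the general idea: a batched GEMM, `alpha * (A @ B) + beta * C`, can be written with matmul, mul, and add alone. The function names are hypothetical, the tensors are plain nested lists, and transposes and broadcasting of `C` are left out.

```python
def matmul(a, b):
    """2-D matrix product on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def batched_gemm_decomposed(a, b, c, alpha=1.0, beta=1.0):
    """Compute alpha * (a @ b) + beta * c for every batch entry."""
    out = []
    for a_i, b_i, c_i in zip(a, b, c):
        product = matmul(a_i, b_i)                              # matmul
        scaled = [[alpha * v for v in row] for row in product]  # mul by alpha
        bias = [[beta * v for v in row] for row in c_i]         # mul by beta
        out.append([[s + t for s, t in zip(r1, r2)]             # add
                    for r1, r2 in zip(scaled, bias)])
    return out


# One batch entry with 2x2 matrices, small enough to check by hand.
a = [[[1, 2], [3, 4]]]
b = [[[5, 6], [7, 8]]]
c = [[[1, 1], [1, 1]]]
result = batched_gemm_decomposed(a, b, c, alpha=2, beta=3)
print(result)
```

When `alpha` and `beta` are both 1, the two mul steps become no-ops and only matmul and add remain.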

@mtavenrath mtavenrath merged commit df1a17f into rustnn:main Apr 28, 2026
0 of 2 checks passed