diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 65cdba6..a41d3b2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,6 +66,15 @@ jobs: restore-keys: | ${{ runner.os }}-target- + # Cache the ResNet-50 ONNX model separately so it survives Cargo.lock churn + # and is not evicted as collateral damage when the target/ cache is invalidated. + # Static key — bump the suffix only if the upstream model URL changes. + - name: Cache ResNet-50 ONNX model + uses: actions/cache@v4 + with: + path: target/test-data + key: resnet50-onnx-opset16-v1 + - name: Check formatting run: make fmt-check diff --git a/Cargo.lock b/Cargo.lock index dc64b9c..8f61f74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anstream" version = "0.6.21" @@ -47,7 +53,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -58,7 +64,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -94,6 +100,16 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +[[package]] +name = "cc" +version = "1.2.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -181,6 +197,17 @@ dependencies = [ "crypto-common", ] +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "either" version = "1.15.0" @@ -200,7 +227,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -209,12 +236,33 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "fixedbitset" version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -225,6 +273,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -253,6 +312,13 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", + "serde", + "serde_core", +] [[package]] name = "heck" @@ -260,6 +326,109 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + [[package]] name = "indexmap" version = "2.13.0" @@ -303,6 +472,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + [[package]] name = "log" version = "0.4.29" @@ -333,6 +508,12 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + [[package]] name = "pest" version = "2.8.4" @@ -386,6 +567,15 @@ dependencies = [ "indexmap", ] +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -460,9 +650,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.42" +version = "1.0.45" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" dependencies = [ "proc-macro2", ] @@ -502,6 +692,20 @@ version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustix" version = "1.1.3" @@ -512,7 +716,53 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "safetensors" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" +dependencies = [ + "hashbrown", + "serde", + "serde_json", ] [[package]] @@ -569,12 +819,36 @@ dependencies = [ "digest", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + [[package]] name = "strsim" version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.111" @@ -586,6 +860,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tempfile" version = "3.24.0" @@ -593,10 +878,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" dependencies = [ "fastrand", - "getrandom", + "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -619,6 +904,16 @@ dependencies = [ "syn", ] +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + [[package]] name = "typenum" version = "1.19.0" @@ -637,6 +932,45 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "url" +version = "2.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + [[package]] name = "utf8parse" version = "0.2.2" @@ -649,6 +983,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + [[package]] name = "wasip2" version = "1.0.1+wasi-0.2.4" @@ -665,13 +1005,16 @@ dependencies = [ "anyhow", "base64", "clap", + "half", "pest", "pest_derive", "prost", + "safetensors", "serde", "serde_json", "tempfile", "thiserror", + "ureq", "webnn-onnx-utils", ] @@ -690,12 +1033,39 @@ dependencies = [ "thiserror", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.7", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -705,12 +1075,105 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "wit-bindgen" version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" +[[package]] +name = "writeable" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "zerocopy" version = "0.8.32" @@ -731,6 +1194,66 @@ dependencies = [ "syn", ] +[[package]] +name = "zerofrom" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zmij" version = "0.1.7" diff --git a/Cargo.toml b/Cargo.toml index e2738bf..90c6498 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,3 +29,4 @@ webnn-onnx-utils = { git = "https://github.com/rustnn/webnn-onnx-utils", branch [dev-dependencies] tempfile = "3.8" +ureq = { version = "2.10", default-features = false, features = ["tls"] } diff --git a/src/onnx/convert.rs b/src/onnx/convert.rs index 96469c1..f6e659b 100644 --- a/src/onnx/convert.rs +++ b/src/onnx/convert.rs @@ -44,9 +44,15 @@ pub enum OnnxError { } /// Sanitize ONNX identifiers for WebNN DSL compatibility -/// Replaces problematic characters that would confuse the parser +/// Replaces problematic characters that would confuse the parser, and prefixes +/// digit-leading names (e.g. anonymous ONNX outputs like "495") with `_` so they +/// remain parseable in the .webnn text format. pub fn sanitize_identifier(name: &str) -> String { - identifiers::sanitize_for_webnn(name) + let base = identifiers::sanitize_for_webnn(name); + match base.chars().next() { + Some(c) if c.is_ascii_digit() => format!("_{}", base), + _ => base, + } } /// Convert ONNX data type code to WebNN DataType using shared utilities @@ -431,6 +437,253 @@ fn infer_shape( Some(target) } + // Pooling: maxPool / averagePool / global variants. Only handles fully-static inputs. + "MaxPool" | "AveragePool" => { + let ins = node.input.as_slice(); + if ins.is_empty() { + return None; + } + let x_shape = value_shapes.get(ins[0].as_str())?.clone(); + if x_shape.len() < 3 { + return None; + } + let spatial_rank = x_shape.len() - 2; + + let mut auto_pad = String::from("NOTSET"); + let mut strides: Vec = vec![1; spatial_rank]; + let mut dilations: Vec = vec![1; spatial_rank]; + let mut pads: Vec = vec![0; 2 * spatial_rank]; + let mut kernel_shape: Vec = Vec::new(); + let mut ceil_mode = false; + for attr in node.attribute.as_slice() { + match attr.name.as_str() { + "auto_pad" => { + if let Ok(s) = String::from_utf8(attr.s.clone()) { + if !s.is_empty() { + auto_pad = s; + } + } + } + "kernel_shape" if !attr.ints.is_empty() => kernel_shape = attr.ints.clone(), + "strides" if !attr.ints.is_empty() => strides = attr.ints.clone(), + "dilations" if !attr.ints.is_empty() => dilations = attr.ints.clone(), + "pads" if !attr.ints.is_empty() => pads = attr.ints.clone(), + "ceil_mode" => ceil_mode = attr.i != 0, + _ => {} + } + } + if kernel_shape.len() != spatial_rank + || strides.len() != spatial_rank + || dilations.len() != spatial_rank + || pads.len() != 2 * spatial_rank + { + return None; + } + + let mut out_spatial = Vec::with_capacity(spatial_rank); + for i in 0..spatial_rank { + let in_dim = x_shape[2 + i]; + let k = kernel_shape[i]; + let s = strides[i]; + let d = dilations[i]; + let dilated_k = d * (k - 1) + 1; + let out_dim = match auto_pad.as_str() { + "SAME_UPPER" | "SAME_LOWER" => (in_dim + s - 1) / s, + "VALID" => (in_dim - dilated_k) / s + 1, + _ => { + let pad_begin = pads[i]; + let pad_end = pads[i + spatial_rank]; + let numerator = in_dim + pad_begin + pad_end - dilated_k; + if ceil_mode { + (numerator + s - 1) / s + 1 + } else { + numerator / s + 1 + } + } + }; + if out_dim < 0 { + return None; + } + out_spatial.push(out_dim); + } + + let mut out = vec![x_shape[0], x_shape[1]]; + out.extend(out_spatial); + Some(out) + } + + "GlobalMaxPool" | "GlobalAveragePool" => { + let ins = node.input.as_slice(); + if ins.is_empty() { + return None; + } + let x_shape = value_shapes.get(ins[0].as_str())?.clone(); + if x_shape.len() < 3 { + return None; + } + let mut out = vec![x_shape[0], x_shape[1]]; + out.extend(std::iter::repeat_n(1i64, x_shape.len() - 2)); + Some(out) + } + + "Flatten" => { + let ins = node.input.as_slice(); + if ins.is_empty() { + return None; + } + let x_shape = value_shapes.get(ins[0].as_str())?.clone(); + let axis = node + .attribute + .as_slice() + .iter() + .find(|a| a.name.as_str() == "axis") + .map(|a| a.i) + .unwrap_or(1); + let rank = x_shape.len() as i64; + let norm = if axis < 0 { axis + rank } else { axis }; + if norm < 0 || norm > rank { + return None; + } + let norm = norm as usize; + let outer: i64 = if norm == 0 { + 1 + } else { + x_shape[..norm].iter().product() + }; + let inner: i64 = if norm == x_shape.len() { + 1 + } else { + x_shape[norm..].iter().product() + }; + Some(vec![outer, inner]) + } + + // Convolution / transposed convolution: derive output spatial dims. + // Only handles fully-static inputs. Higher-rank cases fall through to None. + "Conv" | "ConvTranspose" => { + let ins = node.input.as_slice(); + if ins.len() < 2 { + return None; + } + let x_shape = value_shapes.get(ins[0].as_str())?.clone(); + let w_shape = value_shapes.get(ins[1].as_str()).cloned().or_else(|| { + initializers + .get(ins[1].as_str()) + .map(|t| t.dims.as_slice().to_vec()) + })?; + if x_shape.len() < 3 || w_shape.len() < 3 { + return None; + } + let spatial_rank = x_shape.len() - 2; + if w_shape.len() != x_shape.len() { + return None; + } + + // Read attributes. + let mut auto_pad = String::from("NOTSET"); + let mut strides: Vec = vec![1; spatial_rank]; + let mut dilations: Vec = vec![1; spatial_rank]; + let mut pads: Vec = vec![0; 2 * spatial_rank]; + let mut kernel_shape: Vec = w_shape[2..].to_vec(); + let mut group: i64 = 1; + let mut output_padding: Vec = vec![0; spatial_rank]; + let mut output_shape_attr: Vec = Vec::new(); + for attr in node.attribute.as_slice() { + match attr.name.as_str() { + "auto_pad" => { + if let Ok(s) = String::from_utf8(attr.s.clone()) { + if !s.is_empty() { + auto_pad = s; + } + } + } + "strides" if !attr.ints.is_empty() => strides = attr.ints.clone(), + "dilations" if !attr.ints.is_empty() => dilations = attr.ints.clone(), + "pads" if !attr.ints.is_empty() => pads = attr.ints.clone(), + "kernel_shape" if !attr.ints.is_empty() => kernel_shape = attr.ints.clone(), + "group" if attr.i > 0 => group = attr.i, + "output_padding" if !attr.ints.is_empty() => output_padding = attr.ints.clone(), + "output_shape" if !attr.ints.is_empty() => { + output_shape_attr = attr.ints.clone() + } + _ => {} + } + } + if strides.len() != spatial_rank + || dilations.len() != spatial_rank + || kernel_shape.len() != spatial_rank + || pads.len() != 2 * spatial_rank + || output_padding.len() != spatial_rank + { + return None; + } + let _ = group; // not needed for shape inference + + let transpose = op == "ConvTranspose"; + // Output channel count. + let m = if transpose { + // Filter layout for ConvTranspose: (C_in, M/group, kSpatial...). + // M = w_shape[1] * group, but with default group=1 we just use w_shape[1]. + w_shape[1] * group + } else { + w_shape[0] + }; + + // If output_shape attr is provided (ConvTranspose), it directly tells us H/W. + if transpose && !output_shape_attr.is_empty() { + let sizes = if output_shape_attr.len() == spatial_rank { + output_shape_attr.clone() + } else if output_shape_attr.len() == x_shape.len() { + output_shape_attr[2..].to_vec() + } else { + return None; + }; + let mut out = vec![x_shape[0], m]; + out.extend(sizes); + return Some(out); + } + + let mut out_spatial = Vec::with_capacity(spatial_rank); + for i in 0..spatial_rank { + let in_dim = x_shape[2 + i]; + let k = kernel_shape[i]; + let s = strides[i]; + let d = dilations[i]; + let dilated_k = d * (k - 1) + 1; + + let out_dim = match auto_pad.as_str() { + "SAME_UPPER" | "SAME_LOWER" if !transpose => { + // Standard "SAME": out = ceil(in / stride) + (in_dim + s - 1) / s + } + "SAME_UPPER" | "SAME_LOWER" if transpose => { + // For transpose: out = in * stride + in_dim * s + } + "VALID" if !transpose => (in_dim - dilated_k) / s + 1, + "VALID" if transpose => (in_dim - 1) * s + dilated_k, + _ => { + // explicit pads (NOTSET) — pads layout: [b1, b2, ..., bk, e1, e2, ..., ek] + let pad_begin = pads[i]; + let pad_end = pads[i + spatial_rank]; + if transpose { + (in_dim - 1) * s - pad_begin - pad_end + dilated_k + output_padding[i] + } else { + (in_dim + pad_begin + pad_end - dilated_k) / s + 1 + } + } + }; + if out_dim < 0 { + return None; + } + out_spatial.push(out_dim); + } + + let mut out = vec![x_shape[0], m]; + out.extend(out_spatial); + Some(out) + } + "Slice" => { let ins = node.input.as_slice(); if ins.is_empty() { diff --git a/src/onnx/ops/activation.rs b/src/onnx/ops/activation.rs index adb1687..d661a14 100644 --- a/src/onnx/ops/activation.rs +++ b/src/onnx/ops/activation.rs @@ -24,6 +24,7 @@ impl OpHandler for ActivationHandler { | "Erf" | "Cos" | "Sin" + | "Identity" ) } @@ -53,6 +54,7 @@ impl OpHandler for ActivationHandler { "Erf" => "erf", "Cos" => "cos", "Sin" => "sin", + "Identity" => "identity", _ => { return Err(OnnxError::UnsupportedOp { op: op_type.to_string(), diff --git a/src/onnx/ops/conv.rs b/src/onnx/ops/conv.rs new file mode 100644 index 0000000..8061364 --- /dev/null +++ b/src/onnx/ops/conv.rs @@ -0,0 +1,954 @@ +// Convolution operators: Conv, ConvTranspose +// +// Maps ONNX Conv/ConvTranspose to WebNN conv2d/convTranspose2d (NCHW layout). +// +// ONNX layout assumptions (the spec defaults): +// * input X : (N, C_in, ...spatial) +// * filter W: Conv -> (M, C_in / group, kH, kW, ...) +// ConvTranspose -> (C_in, M / group, kH, kW, ...) +// * bias B : (M,) (optional) +// +// WebNN defaults match ONNX: +// * inputLayout = "nchw" +// * filterLayout = "oihw" for conv2d +// * filterLayout = "iohw" for convTranspose2d +// +// Spatial dimensionality: +// * 2D (4-D input) -> emitted as conv2d / convTranspose2d directly +// * 1D (3-D input) -> emulated as reshape -> conv2d -> reshape +// * Anything else (1D w/o shape info, 3D, etc.) -> UnsupportedOp error. + +use crate::ast::Node; +use crate::onnx::convert::{sanitize_identifier, OnnxError}; +use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler}; +use crate::protos::onnx::NodeProto; +use serde_json::{json, Map, Value}; + +pub struct ConvHandler; + +impl OpHandler for ConvHandler { + fn supports(&self, op_type: &str) -> bool { + matches!(op_type, "Conv" | "ConvTranspose") + } + + fn convert( + &self, + node: &NodeProto, + context: &ConversionContext, + ) -> Result { + let op_type = node.op_type.as_str(); + let node_name = if !node.name.is_empty() { + node.name.as_str().to_string() + } else { + "unnamed".to_string() + }; + + match op_type { + "Conv" => self.convert_conv(node, &node_name, context, false), + "ConvTranspose" => self.convert_conv(node, &node_name, context, true), + _ => Err(OnnxError::UnsupportedOp { + op: op_type.to_string(), + node: node_name, + }), + } + } +} + +#[derive(Debug, Clone)] +struct ConvAttrs { + auto_pad: String, + dilations: Option>, + group: i64, + kernel_shape: Option>, + pads: Option>, + strides: Option>, + output_padding: Option>, + output_shape: Option>, +} + +fn parse_conv_attrs(node: &NodeProto) -> ConvAttrs { + let mut attrs = ConvAttrs { + auto_pad: "NOTSET".to_string(), + dilations: None, + group: 1, + kernel_shape: None, + pads: None, + strides: None, + output_padding: None, + output_shape: None, + }; + + for attr in node.attribute.as_slice() { + match attr.name.as_str() { + "auto_pad" => { + if let Ok(s) = String::from_utf8(attr.s.clone()) { + if !s.is_empty() { + attrs.auto_pad = s; + } + } + } + "dilations" if !attr.ints.is_empty() => { + attrs.dilations = Some(attr.ints.clone()); + } + "group" if attr.i > 0 => { + attrs.group = attr.i; + } + "kernel_shape" if !attr.ints.is_empty() => { + attrs.kernel_shape = Some(attr.ints.clone()); + } + "pads" if !attr.ints.is_empty() => { + attrs.pads = Some(attr.ints.clone()); + } + "strides" if !attr.ints.is_empty() => { + attrs.strides = Some(attr.ints.clone()); + } + "output_padding" if !attr.ints.is_empty() => { + attrs.output_padding = Some(attr.ints.clone()); + } + "output_shape" if !attr.ints.is_empty() => { + attrs.output_shape = Some(attr.ints.clone()); + } + _ => {} + } + } + + attrs +} + +/// Look up a tensor's shape from value_shapes / initializers. +fn lookup_shape(name: &str, context: &ConversionContext) -> Option> { + if let Some(s) = context.value_shapes.get(name) { + return Some(s.clone()); + } + let sanitized = sanitize_identifier(name); + if let Some(s) = context.value_shapes.get(&sanitized) { + return Some(s.clone()); + } + if let Some(init) = context.initializers.get(name) { + return Some(init.dims.as_slice().to_vec()); + } + None +} + +/// Map ONNX auto_pad string to WebNN autoPad option string. +fn map_auto_pad(auto_pad: &str) -> &'static str { + match auto_pad { + "SAME_UPPER" => "same-upper", + "SAME_LOWER" => "same-lower", + // VALID and NOTSET both map to explicit padding; for VALID we'll also zero pads out. + _ => "explicit", + } +} + +/// Convert ONNX pads layout [b1, b2, ..., bk, e1, e2, ..., ek] +/// to WebNN padding layout [b1, e1, b2, e2, ..., bk, ek]. +fn onnx_pads_to_webnn(pads: &[i64], spatial_rank: usize) -> Vec { + if pads.len() != 2 * spatial_rank { + return pads.to_vec(); + } + let mut out = Vec::with_capacity(2 * spatial_rank); + for i in 0..spatial_rank { + out.push(pads[i]); + out.push(pads[i + spatial_rank]); + } + out +} + +impl ConvHandler { + fn convert_conv( + &self, + node: &NodeProto, + node_name: &str, + context: &ConversionContext, + transpose: bool, + ) -> Result { + let op_label = if transpose { "ConvTranspose" } else { "Conv" }; + let inputs = node.input.as_slice(); + if inputs.len() < 2 || inputs.len() > 3 { + return Err(OnnxError::InvalidShape(format!( + "{} expects 2 or 3 inputs (X, W[, B]), got {}", + op_label, + inputs.len() + ))); + } + + let input_raw = inputs[0].to_string(); + let filter_raw = inputs[1].to_string(); + let bias_raw = inputs.get(2).map(|s| s.to_string()); + + let input_id = context.resolve_input(&input_raw); + let filter_id = context.resolve_input(&filter_raw); + let bias_id = bias_raw.as_ref().map(|n| context.resolve_input(n)); + + let output_name = if node.output.as_slice().is_empty() { + format!("{}_output", node_name) + } else { + sanitize_identifier(&node.output.as_slice()[0].to_string()) + }; + + let attrs = parse_conv_attrs(node); + + // Determine spatial rank. Prefer the explicit kernel_shape attribute, then the + // filter's declared shape, then the input's spatial rank. + let filter_shape = lookup_shape(&filter_raw, context); + let input_shape = lookup_shape(&input_raw, context); + let spatial_rank = if let Some(ks) = attrs.kernel_shape.as_ref() { + ks.len() + } else if let Some(fs) = filter_shape.as_ref() { + if fs.len() >= 2 { + fs.len() - 2 + } else { + return Err(OnnxError::InvalidShape(format!( + "{}: filter '{}' has rank {} (need >= 2)", + op_label, + filter_raw, + fs.len() + ))); + } + } else if let Some(is) = input_shape.as_ref() { + if is.len() >= 2 { + is.len() - 2 + } else { + return Err(OnnxError::InvalidShape(format!( + "{}: cannot determine spatial rank from input '{}' of rank {}", + op_label, + input_raw, + is.len() + ))); + } + } else { + return Err(OnnxError::InvalidShape(format!( + "{}: cannot determine spatial rank — provide kernel_shape attribute or filter/input shape info", + op_label, + ))); + }; + + match spatial_rank { + 2 => self.emit_conv_2d( + node_name, + &output_name, + &input_id, + &filter_id, + bias_id.as_deref(), + &attrs, + transpose, + node, + ), + 1 => self.emit_conv_1d_via_2d( + node_name, + &output_name, + &input_id, + &filter_id, + bias_id.as_deref(), + &attrs, + transpose, + node, + input_shape.as_deref(), + filter_shape.as_deref(), + ), + _ => Err(OnnxError::UnsupportedOp { + op: format!("{}{}D", op_label, spatial_rank), + node: node_name.to_string(), + }), + } + } + + /// Build the options map shared by conv2d / convTranspose2d. + fn build_conv2d_options( + &self, + attrs: &ConvAttrs, + transpose: bool, + ) -> Result, OnnxError> { + let mut options = Map::new(); + + let strides = attrs.strides.clone().unwrap_or_else(|| vec![1, 1]); + let dilations = attrs.dilations.clone().unwrap_or_else(|| vec![1, 1]); + let pads = attrs.pads.clone().unwrap_or_else(|| vec![0, 0, 0, 0]); + + if strides.len() != 2 { + return Err(OnnxError::InvalidShape(format!( + "conv2d: strides must have length 2, got {:?}", + strides + ))); + } + if dilations.len() != 2 { + return Err(OnnxError::InvalidShape(format!( + "conv2d: dilations must have length 2, got {:?}", + dilations + ))); + } + + options.insert("strides".to_string(), json!(strides)); + options.insert("dilations".to_string(), json!(dilations)); + + let mapped_auto_pad = map_auto_pad(&attrs.auto_pad); + // Always emit padding for explicit case (or VALID, which uses zero pads). + if mapped_auto_pad == "explicit" { + // VALID means no padding regardless of `pads` attribute. + let effective_pads = if attrs.auto_pad == "VALID" { + vec![0, 0, 0, 0] + } else { + onnx_pads_to_webnn(&pads, 2) + }; + if effective_pads.len() != 4 { + return Err(OnnxError::InvalidShape(format!( + "conv2d: pads must yield 4 values for 2D, got {:?}", + effective_pads + ))); + } + options.insert("padding".to_string(), json!(effective_pads)); + } else { + options.insert("autoPad".to_string(), json!(mapped_auto_pad)); + } + + if attrs.group != 1 { + options.insert("groups".to_string(), json!(attrs.group)); + } + + if transpose { + if let Some(op) = attrs.output_padding.as_ref() { + if op.len() == 2 { + options.insert("outputPadding".to_string(), json!(op)); + } + } + if let Some(os) = attrs.output_shape.as_ref() { + // ONNX output_shape is the full N×C×H×W; WebNN outputSizes is the spatial part + // [H, W]. Accept either form for robustness. + let sizes: Vec = if os.len() == 2 { + os.clone() + } else if os.len() >= 2 { + os[os.len() - 2..].to_vec() + } else { + Vec::new() + }; + if sizes.len() == 2 { + options.insert("outputSizes".to_string(), json!(sizes)); + } + } + } + + Ok(options) + } + + #[allow(clippy::too_many_arguments)] + fn emit_conv_2d( + &self, + _node_name: &str, + output_name: &str, + input_id: &str, + filter_id: &str, + bias_id: Option<&str>, + attrs: &ConvAttrs, + transpose: bool, + node: &NodeProto, + ) -> Result { + let webnn_op = if transpose { + "convTranspose2d" + } else { + "conv2d" + }; + let options = self.build_conv2d_options(attrs, transpose)?; + + let mut inputs_vec = vec![input_id.to_string(), filter_id.to_string()]; + if let Some(b) = bias_id { + inputs_vec.push(b.to_string()); + } + + let mut result = ConversionResult::new(vec![Node { + id: output_name.to_string(), + op: webnn_op.to_string(), + inputs: inputs_vec, + options, + outputs: None, + }]); + + if let Some(onnx_out) = node.output.as_slice().first() { + result + .output_mappings + .insert(onnx_out.to_string(), output_name.to_string()); + } + Ok(result) + } + + /// Emulate a 1D convolution by reshaping the input/filter to 4-D (W=1), + /// running conv2d, then reshaping the output back to 3-D. + #[allow(clippy::too_many_arguments)] + fn emit_conv_1d_via_2d( + &self, + node_name: &str, + output_name: &str, + input_id: &str, + filter_id: &str, + bias_id: Option<&str>, + attrs: &ConvAttrs, + transpose: bool, + node: &NodeProto, + input_shape: Option<&[i64]>, + filter_shape: Option<&[i64]>, + ) -> Result { + let input_shape = input_shape.ok_or_else(|| { + OnnxError::InvalidShape(format!( + "1D Conv emulation requires known shape for input of node {}", + node_name + )) + })?; + let filter_shape = filter_shape.ok_or_else(|| { + OnnxError::InvalidShape(format!( + "1D Conv emulation requires known shape for filter of node {}", + node_name + )) + })?; + if input_shape.len() != 3 || filter_shape.len() != 3 { + return Err(OnnxError::InvalidShape(format!( + "1D Conv emulation expects rank-3 input/filter, got input {:?} filter {:?}", + input_shape, filter_shape + ))); + } + + // Extend 1D attrs to 2D by appending a trailing dim of "1" (a no-op extra dim). + let mut attrs_2d = attrs.clone(); + attrs_2d.strides = Some(extend_with_one(attrs.strides.as_deref(), 1, 2)); + attrs_2d.dilations = Some(extend_with_one(attrs.dilations.as_deref(), 1, 2)); + attrs_2d.pads = Some(extend_pads_to_2d(attrs.pads.as_deref())); + attrs_2d.kernel_shape = attrs + .kernel_shape + .as_ref() + .map(|ks| extend_with_one(Some(ks.as_slice()), 1, 2)); + if transpose { + attrs_2d.output_padding = Some(extend_with_one(attrs.output_padding.as_deref(), 0, 2)); + // output_shape only makes sense for the full spatial range; we drop it for the + // 1D-via-2D rewrite to avoid mis-encoding [H] vs [H, W]=1. + attrs_2d.output_shape = None; + } + + let options = self.build_conv2d_options(&attrs_2d, transpose)?; + + let reshape_in_id = sanitize_identifier(&format!("{}_x4d", node_name)); + let reshape_w_id = sanitize_identifier(&format!("{}_w4d", node_name)); + let conv_id = sanitize_identifier(&format!("{}_conv2d", node_name)); + + let in_4d_shape: Vec = vec![input_shape[0], input_shape[1], input_shape[2], 1]; + let w_4d_shape: Vec = vec![filter_shape[0], filter_shape[1], filter_shape[2], 1]; + + let mut reshape_in_opts = Map::new(); + reshape_in_opts.insert("newShape".to_string(), json!(in_4d_shape)); + let mut reshape_w_opts = Map::new(); + reshape_w_opts.insert("newShape".to_string(), json!(w_4d_shape)); + + let mut nodes = vec![ + Node { + id: reshape_in_id.clone(), + op: "reshape".to_string(), + inputs: vec![input_id.to_string()], + options: reshape_in_opts, + outputs: None, + }, + Node { + id: reshape_w_id.clone(), + op: "reshape".to_string(), + inputs: vec![filter_id.to_string()], + options: reshape_w_opts, + outputs: None, + }, + ]; + + let webnn_op = if transpose { + "convTranspose2d" + } else { + "conv2d" + }; + let mut conv_inputs = vec![reshape_in_id.clone(), reshape_w_id.clone()]; + if let Some(b) = bias_id { + conv_inputs.push(b.to_string()); + } + nodes.push(Node { + id: conv_id.clone(), + op: webnn_op.to_string(), + inputs: conv_inputs, + options, + outputs: None, + }); + + // Reshape back to 3D. We can compute the spatial output dim with the standard formula + // but we conservatively rely on shape inference downstream by using -1. + let out_shape: Vec = + vec![json!(input_shape[0]), json!(filter_shape[0]), json!(-1i64)]; + let mut final_reshape_opts = Map::new(); + final_reshape_opts.insert("newShape".to_string(), json!(out_shape)); + nodes.push(Node { + id: output_name.to_string(), + op: "reshape".to_string(), + inputs: vec![conv_id.clone()], + options: final_reshape_opts, + outputs: None, + }); + + let mut result = ConversionResult::new(nodes); + if let Some(onnx_out) = node.output.as_slice().first() { + result + .output_mappings + .insert(onnx_out.to_string(), output_name.to_string()); + } + Ok(result) + } +} + +fn extend_with_one(src: Option<&[i64]>, fill: i64, target_len: usize) -> Vec { + let mut out = src.map(|v| v.to_vec()).unwrap_or_default(); + while out.len() < target_len { + out.push(fill); + } + out +} + +/// Extend an ONNX-style pads list (1D = [begin, end]) to 2D ([begin_h, begin_w, end_h, end_w]) +/// by appending zero padding on the trailing dimension. +fn extend_pads_to_2d(pads: Option<&[i64]>) -> Vec { + match pads { + Some(p) if p.len() == 2 => vec![p[0], 0, p[1], 0], + Some(p) if p.len() == 4 => p.to_vec(), + _ => vec![0, 0, 0, 0], + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protos::onnx::{AttributeProto, NodeProto}; + use std::collections::HashMap; + + fn make_node( + op_type: &str, + inputs: Vec<&str>, + outputs: Vec<&str>, + attrs: Vec, + ) -> NodeProto { + NodeProto { + op_type: op_type.to_string(), + name: format!("test_{}", op_type.to_lowercase()), + input: inputs.iter().map(|s| s.to_string()).collect(), + output: outputs.iter().map(|s| s.to_string()).collect(), + attribute: attrs, + ..Default::default() + } + } + + fn int_attr(name: &str, value: i64) -> AttributeProto { + AttributeProto { + name: name.to_string(), + i: value, + ..Default::default() + } + } + + fn ints_attr(name: &str, values: Vec) -> AttributeProto { + AttributeProto { + name: name.to_string(), + ints: values, + ..Default::default() + } + } + + fn string_attr(name: &str, value: &str) -> AttributeProto { + AttributeProto { + name: name.to_string(), + s: value.as_bytes().to_vec(), + ..Default::default() + } + } + + fn make_context<'a>( + initializers: &'a HashMap, + value_shapes: &'a HashMap>, + const_values: &'a HashMap>, + value_ids: &'a HashMap, + value_types: &'a HashMap, + ) -> ConversionContext<'a> { + ConversionContext { + initializers, + value_shapes, + value_shape_dims: crate::onnx::ops::empty_value_shape_dims(), + const_values, + value_ids, + value_types, + } + } + + #[test] + fn supports_conv_ops() { + let h = ConvHandler; + assert!(h.supports("Conv")); + assert!(h.supports("ConvTranspose")); + assert!(!h.supports("MatMul")); + assert!(!h.supports("Pool")); + } + + #[test] + fn conv2d_basic_defaults() { + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w"], + vec!["y"], + vec![ints_attr("kernel_shape", vec![3, 3])], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 3, 224, 224]); + value_shapes.insert("w".to_string(), vec![64, 3, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes.len(), 1); + let n = &result.nodes[0]; + assert_eq!(n.op, "conv2d"); + assert_eq!(n.id, "y"); + assert_eq!(n.inputs, vec!["x", "w"]); + assert_eq!(n.options.get("strides"), Some(&json!([1, 1]))); + assert_eq!(n.options.get("dilations"), Some(&json!([1, 1]))); + assert_eq!(n.options.get("padding"), Some(&json!([0, 0, 0, 0]))); + // group=1 should not be emitted. + assert!(n.options.get("groups").is_none()); + // No autoPad option for explicit (default). + assert!(n.options.get("autoPad").is_none()); + } + + #[test] + fn conv2d_with_strides_pads_dilations_groups() { + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w", "b"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + ints_attr("strides", vec![2, 2]), + ints_attr("pads", vec![1, 1, 1, 1]), + ints_attr("dilations", vec![1, 1]), + int_attr("group", 4), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 4, 112, 112]); + value_shapes.insert("w".to_string(), vec![8, 1, 3, 3]); + value_shapes.insert("b".to_string(), vec![8]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes.len(), 1); + let n = &result.nodes[0]; + assert_eq!(n.op, "conv2d"); + assert_eq!(n.inputs, vec!["x", "w", "b"]); + assert_eq!(n.options.get("strides"), Some(&json!([2, 2]))); + assert_eq!(n.options.get("dilations"), Some(&json!([1, 1]))); + // ONNX pads [b1, b2, e1, e2] -> WebNN padding [b1, e1, b2, e2] + assert_eq!(n.options.get("padding"), Some(&json!([1, 1, 1, 1]))); + assert_eq!(n.options.get("groups"), Some(&json!(4))); + } + + #[test] + fn conv2d_pads_layout_reordered() { + // Asymmetric pads to verify the ONNX->WebNN reordering. + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + // ONNX layout: [top, left, bottom, right] + ints_attr("pads", vec![1, 2, 3, 4]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 3, 32, 32]); + value_shapes.insert("w".to_string(), vec![8, 3, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + // WebNN layout: [top, bottom, left, right] = [1, 3, 2, 4] + assert_eq!(n.options.get("padding"), Some(&json!([1, 3, 2, 4]))); + } + + #[test] + fn conv2d_auto_pad_same_upper() { + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + string_attr("auto_pad", "SAME_UPPER"), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 3, 32, 32]); + value_shapes.insert("w".to_string(), vec![8, 3, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + assert_eq!(n.options.get("autoPad"), Some(&json!("same-upper"))); + assert!(n.options.get("padding").is_none()); + } + + #[test] + fn conv2d_auto_pad_valid_zeroes_pads() { + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + string_attr("auto_pad", "VALID"), + // even if pads attribute is set, VALID forces zero padding + ints_attr("pads", vec![1, 1, 1, 1]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 3, 32, 32]); + value_shapes.insert("w".to_string(), vec![8, 3, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + assert_eq!(n.options.get("padding"), Some(&json!([0, 0, 0, 0]))); + assert!(n.options.get("autoPad").is_none()); + } + + #[test] + fn conv_transpose_basic() { + let h = ConvHandler; + let node = make_node( + "ConvTranspose", + vec!["x", "w"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + ints_attr("strides", vec![2, 2]), + ints_attr("output_padding", vec![1, 1]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 16, 32, 32]); + value_shapes.insert("w".to_string(), vec![16, 8, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes.len(), 1); + let n = &result.nodes[0]; + assert_eq!(n.op, "convTranspose2d"); + assert_eq!(n.options.get("strides"), Some(&json!([2, 2]))); + assert_eq!(n.options.get("outputPadding"), Some(&json!([1, 1]))); + } + + #[test] + fn conv_transpose_output_shape_full_form() { + let h = ConvHandler; + // ONNX output_shape is typically N×C×H×W; we should pick H×W for outputSizes. + let node = make_node( + "ConvTranspose", + vec!["x", "w"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + ints_attr("output_shape", vec![1, 8, 64, 64]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 16, 32, 32]); + value_shapes.insert("w".to_string(), vec![16, 8, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + assert_eq!(n.options.get("outputSizes"), Some(&json!([64, 64]))); + } + + #[test] + fn conv1d_emulated_via_2d() { + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3]), + ints_attr("strides", vec![2]), + ints_attr("pads", vec![1, 1]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 16, 64]); + value_shapes.insert("w".to_string(), vec![8, 16, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + // reshape input -> reshape filter -> conv2d -> reshape output + assert_eq!(result.nodes.len(), 4); + assert_eq!(result.nodes[0].op, "reshape"); + assert_eq!(result.nodes[1].op, "reshape"); + assert_eq!(result.nodes[2].op, "conv2d"); + assert_eq!(result.nodes[3].op, "reshape"); + // Strides/dilations/pads extended to 2D + let conv = &result.nodes[2]; + assert_eq!(conv.options.get("strides"), Some(&json!([2, 1]))); + assert_eq!(conv.options.get("dilations"), Some(&json!([1, 1]))); + // ONNX 1D pads [1, 1] -> 2D [1, 0, 1, 0] -> WebNN [1, 1, 0, 0] + assert_eq!(conv.options.get("padding"), Some(&json!([1, 1, 0, 0]))); + } + + #[test] + fn conv_3d_unsupported() { + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["x", "w"], + vec!["y"], + vec![ints_attr("kernel_shape", vec![3, 3, 3])], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 3, 16, 16, 16]); + value_shapes.insert("w".to_string(), vec![8, 3, 3, 3, 3]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let err = h.convert(&node, &ctx).unwrap_err(); + match err { + OnnxError::UnsupportedOp { op, .. } => { + assert!(op.contains("3D"), "expected 3D in op label, got {}", op); + } + other => panic!("expected UnsupportedOp, got {:?}", other), + } + } + + #[test] + fn conv_resolves_input_aliases() { + // Ensure input IDs go through ConversionContext::resolve_input. + let h = ConvHandler; + let node = make_node( + "Conv", + vec!["onnx_x", "onnx_w"], + vec!["y"], + vec![ints_attr("kernel_shape", vec![3, 3])], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("onnx_x".to_string(), vec![1, 3, 32, 32]); + value_shapes.insert("onnx_w".to_string(), vec![8, 3, 3, 3]); + let const_values = HashMap::new(); + let mut value_ids = HashMap::new(); + value_ids.insert("onnx_x".to_string(), "x_id".to_string()); + value_ids.insert("onnx_w".to_string(), "w_id".to_string()); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes[0].inputs, vec!["x_id", "w_id"]); + } + + #[test] + fn onnx_pads_to_webnn_reorders() { + // ONNX 2D: [top, left, bottom, right] -> WebNN: [top, bottom, left, right] + assert_eq!(onnx_pads_to_webnn(&[1, 2, 3, 4], 2), vec![1, 3, 2, 4]); + // ONNX 1D: [begin, end] -> WebNN: [begin, end] + assert_eq!(onnx_pads_to_webnn(&[5, 6], 1), vec![5, 6]); + } +} diff --git a/src/onnx/ops/mod.rs b/src/onnx/ops/mod.rs index 379b50e..3d0166d 100644 --- a/src/onnx/ops/mod.rs +++ b/src/onnx/ops/mod.rs @@ -9,10 +9,12 @@ use std::sync::OnceLock; pub mod activation; pub mod comparison; pub mod conditional; +pub mod conv; pub mod conversion; pub mod elementwise; pub mod matmul; pub mod normalization; +pub mod pool; pub mod reduction; pub mod reshape; pub mod scatter; @@ -21,10 +23,12 @@ pub mod utility; use activation::ActivationHandler; use comparison::ComparisonHandler; use conditional::ConditionalHandler; +use conv::ConvHandler; use conversion::ConversionHandler; use elementwise::ElementwiseHandler; use matmul::MatMulHandler; use normalization::NormalizationHandler; +use pool::PoolHandler; use reduction::ReductionHandler; use reshape::ReshapeHandler; use scatter::ScatterHandler; @@ -150,6 +154,8 @@ impl OpRegistry { pub fn new() -> Self { let handlers: Vec> = vec![ Box::new(MatMulHandler), + Box::new(ConvHandler), + Box::new(PoolHandler), Box::new(ElementwiseHandler), Box::new(ComparisonHandler), Box::new(ConditionalHandler), diff --git a/src/onnx/ops/pool.rs b/src/onnx/ops/pool.rs new file mode 100644 index 0000000..dd2ac72 --- /dev/null +++ b/src/onnx/ops/pool.rs @@ -0,0 +1,975 @@ +// Pooling operators: MaxPool, AveragePool, GlobalMaxPool, GlobalAveragePool +// +// Maps ONNX pooling ops to WebNN maxPool2d / averagePool2d (NCHW layout). +// +// ONNX MaxPool / AveragePool attributes (spatial-rank-aware): +// * kernel_shape: required, length = spatial_rank +// * strides: default = [1; spatial_rank] +// * dilations: default = [1; spatial_rank] (MaxPool only) +// * pads: default = [0; 2*spatial_rank], layout [b1, b2, ..., e1, e2, ...] +// * auto_pad: NOTSET | SAME_UPPER | SAME_LOWER | VALID +// * ceil_mode: 0 (floor) | 1 (ceil) +// * count_include_pad (AveragePool): 0 (default) | 1 +// * storage_order (MaxPool): 0 (row major, default) | 1 (column major) — not exposed in WebNN +// +// Global pooling variants take no attributes and pool over the entire spatial volume. + +use crate::ast::Node; +use crate::onnx::convert::{sanitize_identifier, OnnxError}; +use crate::onnx::ops::{ConversionContext, ConversionResult, OpHandler}; +use crate::protos::onnx::NodeProto; +use serde_json::{json, Map, Value}; + +pub struct PoolHandler; + +impl OpHandler for PoolHandler { + fn supports(&self, op_type: &str) -> bool { + matches!( + op_type, + "MaxPool" | "AveragePool" | "GlobalMaxPool" | "GlobalAveragePool" + ) + } + + fn convert( + &self, + node: &NodeProto, + context: &ConversionContext, + ) -> Result { + let op_type = node.op_type.as_str(); + let node_name = if !node.name.is_empty() { + node.name.as_str().to_string() + } else { + "unnamed".to_string() + }; + + match op_type { + "MaxPool" => self.convert_pool(node, &node_name, context, PoolKind::Max), + "AveragePool" => self.convert_pool(node, &node_name, context, PoolKind::Average), + "GlobalMaxPool" => self.convert_global_pool(node, &node_name, context, PoolKind::Max), + "GlobalAveragePool" => { + self.convert_global_pool(node, &node_name, context, PoolKind::Average) + } + _ => Err(OnnxError::UnsupportedOp { + op: op_type.to_string(), + node: node_name, + }), + } + } +} + +#[derive(Debug, Clone, Copy)] +enum PoolKind { + Max, + Average, +} + +impl PoolKind { + fn webnn_op(self) -> &'static str { + match self { + PoolKind::Max => "maxPool2d", + PoolKind::Average => "averagePool2d", + } + } +} + +#[derive(Debug, Clone)] +struct PoolAttrs { + kernel_shape: Option>, + strides: Option>, + dilations: Option>, + pads: Option>, + auto_pad: String, + ceil_mode: bool, + count_include_pad: bool, +} + +fn parse_pool_attrs(node: &NodeProto) -> PoolAttrs { + let mut attrs = PoolAttrs { + kernel_shape: None, + strides: None, + dilations: None, + pads: None, + auto_pad: "NOTSET".to_string(), + ceil_mode: false, + count_include_pad: false, + }; + + for attr in node.attribute.as_slice() { + match attr.name.as_str() { + "auto_pad" => { + if let Ok(s) = String::from_utf8(attr.s.clone()) { + if !s.is_empty() { + attrs.auto_pad = s; + } + } + } + "kernel_shape" if !attr.ints.is_empty() => { + attrs.kernel_shape = Some(attr.ints.clone()); + } + "strides" if !attr.ints.is_empty() => { + attrs.strides = Some(attr.ints.clone()); + } + "dilations" if !attr.ints.is_empty() => { + attrs.dilations = Some(attr.ints.clone()); + } + "pads" if !attr.ints.is_empty() => { + attrs.pads = Some(attr.ints.clone()); + } + "ceil_mode" => { + attrs.ceil_mode = attr.i != 0; + } + "count_include_pad" => { + attrs.count_include_pad = attr.i != 0; + } + _ => {} + } + } + + attrs +} + +fn lookup_shape(name: &str, context: &ConversionContext) -> Option> { + if let Some(s) = context.value_shapes.get(name) { + return Some(s.clone()); + } + let sanitized = sanitize_identifier(name); + if let Some(s) = context.value_shapes.get(&sanitized) { + return Some(s.clone()); + } + None +} + +fn map_auto_pad(auto_pad: &str) -> &'static str { + match auto_pad { + "SAME_UPPER" => "same-upper", + "SAME_LOWER" => "same-lower", + _ => "explicit", + } +} + +fn onnx_pads_to_webnn(pads: &[i64], spatial_rank: usize) -> Vec { + if pads.len() != 2 * spatial_rank { + return pads.to_vec(); + } + let mut out = Vec::with_capacity(2 * spatial_rank); + for i in 0..spatial_rank { + out.push(pads[i]); + out.push(pads[i + spatial_rank]); + } + out +} + +impl PoolHandler { + fn convert_pool( + &self, + node: &NodeProto, + node_name: &str, + context: &ConversionContext, + kind: PoolKind, + ) -> Result { + let op_label = match kind { + PoolKind::Max => "MaxPool", + PoolKind::Average => "AveragePool", + }; + + let inputs = node.input.as_slice(); + if inputs.len() != 1 { + return Err(OnnxError::InvalidShape(format!( + "{} expects 1 input, got {}", + op_label, + inputs.len() + ))); + } + // Reject the optional second MaxPool output (indices) — WebNN has no equivalent. + if matches!(kind, PoolKind::Max) && node.output.as_slice().len() > 1 { + return Err(OnnxError::UnsupportedOp { + op: "MaxPool(with indices output)".to_string(), + node: node_name.to_string(), + }); + } + + let input_raw = inputs[0].to_string(); + let input_id = context.resolve_input(&input_raw); + let input_shape = lookup_shape(&input_raw, context); + + let output_name = if node.output.as_slice().is_empty() { + format!("{}_output", node_name) + } else { + sanitize_identifier(&node.output.as_slice()[0].to_string()) + }; + + let attrs = parse_pool_attrs(node); + + let kernel = attrs + .kernel_shape + .clone() + .ok_or_else(|| OnnxError::MissingAttribute { + attr: "kernel_shape".to_string(), + op: op_label.to_string(), + })?; + let spatial_rank = kernel.len(); + + match spatial_rank { + 2 => self.emit_pool_2d( + node, + node_name, + &output_name, + &input_id, + &attrs, + &kernel, + kind, + ), + 1 => self.emit_pool_1d_via_2d( + node, + node_name, + &output_name, + &input_id, + &attrs, + &kernel, + kind, + input_shape.as_deref(), + ), + _ => Err(OnnxError::UnsupportedOp { + op: format!("{}{}D", op_label, spatial_rank), + node: node_name.to_string(), + }), + } + } + + fn build_pool_2d_options( + &self, + attrs: &PoolAttrs, + kernel: &[i64], + kind: PoolKind, + ) -> Result, OnnxError> { + let mut options = Map::new(); + let strides = attrs.strides.clone().unwrap_or_else(|| vec![1, 1]); + let dilations = attrs.dilations.clone().unwrap_or_else(|| vec![1, 1]); + let pads = attrs.pads.clone().unwrap_or_else(|| vec![0, 0, 0, 0]); + if strides.len() != 2 || dilations.len() != 2 || kernel.len() != 2 { + return Err(OnnxError::InvalidShape(format!( + "pool2d: expected length-2 kernel/strides/dilations, got kernel={:?} strides={:?} dilations={:?}", + kernel, strides, dilations + ))); + } + + options.insert("windowDimensions".to_string(), json!(kernel)); + options.insert("strides".to_string(), json!(strides)); + // AveragePool in ONNX has no dilations; only emit dilations when non-default + // to keep generated calls minimal for the average case. + if matches!(kind, PoolKind::Max) || dilations.iter().any(|&d| d != 1) { + options.insert("dilations".to_string(), json!(dilations)); + } + + let mapped_auto_pad = map_auto_pad(&attrs.auto_pad); + if mapped_auto_pad == "explicit" { + let effective_pads = if attrs.auto_pad == "VALID" { + vec![0, 0, 0, 0] + } else { + onnx_pads_to_webnn(&pads, 2) + }; + if effective_pads.len() != 4 { + return Err(OnnxError::InvalidShape(format!( + "pool2d: padding must yield 4 values for 2D, got {:?}", + effective_pads + ))); + } + options.insert("padding".to_string(), json!(effective_pads)); + } else { + options.insert("autoPad".to_string(), json!(mapped_auto_pad)); + } + + if attrs.ceil_mode { + options.insert("roundingType".to_string(), json!("ceil")); + } + + if matches!(kind, PoolKind::Average) && attrs.count_include_pad { + // The WebNN spec does not currently expose a count_include_pad knob; surface + // a clear error rather than silently producing incorrect results. + return Err(OnnxError::UnsupportedOp { + op: "AveragePool(count_include_pad=1)".to_string(), + node: "".to_string(), + }); + } + + Ok(options) + } + + #[allow(clippy::too_many_arguments)] + fn emit_pool_2d( + &self, + node: &NodeProto, + _node_name: &str, + output_name: &str, + input_id: &str, + attrs: &PoolAttrs, + kernel: &[i64], + kind: PoolKind, + ) -> Result { + let options = self.build_pool_2d_options(attrs, kernel, kind)?; + let mut result = ConversionResult::new(vec![Node { + id: output_name.to_string(), + op: kind.webnn_op().to_string(), + inputs: vec![input_id.to_string()], + options, + outputs: None, + }]); + + if let Some(onnx_out) = node.output.as_slice().first() { + result + .output_mappings + .insert(onnx_out.to_string(), output_name.to_string()); + } + Ok(result) + } + + /// Emulate a 1D pool by reshaping the input to 4-D (trailing W=1), pooling, then reshaping + /// the output back to 3-D. + #[allow(clippy::too_many_arguments)] + fn emit_pool_1d_via_2d( + &self, + node: &NodeProto, + node_name: &str, + output_name: &str, + input_id: &str, + attrs: &PoolAttrs, + kernel: &[i64], + kind: PoolKind, + input_shape: Option<&[i64]>, + ) -> Result { + let input_shape = input_shape.ok_or_else(|| { + OnnxError::InvalidShape(format!( + "1D pool emulation requires known shape for input of node {}", + node_name + )) + })?; + if input_shape.len() != 3 { + return Err(OnnxError::InvalidShape(format!( + "1D pool emulation expects rank-3 input, got {:?}", + input_shape + ))); + } + + let mut attrs_2d = attrs.clone(); + attrs_2d.strides = Some(extend_with(attrs.strides.as_deref(), 1, 2)); + attrs_2d.dilations = Some(extend_with(attrs.dilations.as_deref(), 1, 2)); + attrs_2d.pads = Some(extend_pads_to_2d(attrs.pads.as_deref())); + let kernel_2d: Vec = { + let mut k = kernel.to_vec(); + if k.len() == 1 { + k.push(1); + } + k + }; + attrs_2d.kernel_shape = Some(kernel_2d.clone()); + + let options = self.build_pool_2d_options(&attrs_2d, &kernel_2d, kind)?; + + let reshape_in_id = sanitize_identifier(&format!("{}_x4d", node_name)); + let pool_id = sanitize_identifier(&format!("{}_pool2d", node_name)); + + let in_4d: Vec = vec![input_shape[0], input_shape[1], input_shape[2], 1]; + let mut reshape_in_opts = Map::new(); + reshape_in_opts.insert("newShape".to_string(), json!(in_4d)); + + let nodes = vec![ + Node { + id: reshape_in_id.clone(), + op: "reshape".to_string(), + inputs: vec![input_id.to_string()], + options: reshape_in_opts, + outputs: None, + }, + Node { + id: pool_id.clone(), + op: kind.webnn_op().to_string(), + inputs: vec![reshape_in_id], + options, + outputs: None, + }, + Node { + id: output_name.to_string(), + op: "reshape".to_string(), + inputs: vec![pool_id], + options: { + let mut m = Map::new(); + m.insert( + "newShape".to_string(), + json!([input_shape[0], input_shape[1], -1i64]), + ); + m + }, + outputs: None, + }, + ]; + + let mut result = ConversionResult::new(nodes); + if let Some(onnx_out) = node.output.as_slice().first() { + result + .output_mappings + .insert(onnx_out.to_string(), output_name.to_string()); + } + Ok(result) + } + + fn convert_global_pool( + &self, + node: &NodeProto, + node_name: &str, + context: &ConversionContext, + kind: PoolKind, + ) -> Result { + let op_label = match kind { + PoolKind::Max => "GlobalMaxPool", + PoolKind::Average => "GlobalAveragePool", + }; + let inputs = node.input.as_slice(); + if inputs.len() != 1 { + return Err(OnnxError::InvalidShape(format!( + "{} expects 1 input, got {}", + op_label, + inputs.len() + ))); + } + + let input_raw = inputs[0].to_string(); + let input_id = context.resolve_input(&input_raw); + let input_shape = lookup_shape(&input_raw, context).ok_or_else(|| { + OnnxError::InvalidShape(format!( + "{}: input '{}' shape is unknown — required to determine spatial window size", + op_label, input_raw + )) + })?; + if input_shape.len() < 3 { + return Err(OnnxError::InvalidShape(format!( + "{}: input must be at least rank-3 (N, C, spatial...), got {:?}", + op_label, input_shape + ))); + } + + let output_name = if node.output.as_slice().is_empty() { + format!("{}_output", node_name) + } else { + sanitize_identifier(&node.output.as_slice()[0].to_string()) + }; + + let spatial = &input_shape[2..]; + match spatial.len() { + 2 => { + let mut options = Map::new(); + options.insert("windowDimensions".to_string(), json!(spatial.to_vec())); + // Strides default to 1 — fine since the window covers the whole spatial. + let mut result = ConversionResult::new(vec![Node { + id: output_name.clone(), + op: kind.webnn_op().to_string(), + inputs: vec![input_id], + options, + outputs: None, + }]); + if let Some(onnx_out) = node.output.as_slice().first() { + result + .output_mappings + .insert(onnx_out.to_string(), output_name); + } + Ok(result) + } + 1 => { + // Reshape to 4-D (trailing 1), pool with windowDimensions=[L, 1], reshape back. + let reshape_in_id = sanitize_identifier(&format!("{}_x4d", node_name)); + let pool_id = sanitize_identifier(&format!("{}_pool2d", node_name)); + let in_4d: Vec = vec![input_shape[0], input_shape[1], spatial[0], 1]; + let mut reshape_in_opts = Map::new(); + reshape_in_opts.insert("newShape".to_string(), json!(in_4d)); + + let mut pool_opts = Map::new(); + pool_opts.insert("windowDimensions".to_string(), json!([spatial[0], 1])); + + let nodes = vec![ + Node { + id: reshape_in_id.clone(), + op: "reshape".to_string(), + inputs: vec![input_id], + options: reshape_in_opts, + outputs: None, + }, + Node { + id: pool_id.clone(), + op: kind.webnn_op().to_string(), + inputs: vec![reshape_in_id], + options: pool_opts, + outputs: None, + }, + Node { + id: output_name.clone(), + op: "reshape".to_string(), + inputs: vec![pool_id], + options: { + let mut m = Map::new(); + m.insert( + "newShape".to_string(), + json!([input_shape[0], input_shape[1], 1i64]), + ); + m + }, + outputs: None, + }, + ]; + + let mut result = ConversionResult::new(nodes); + if let Some(onnx_out) = node.output.as_slice().first() { + result + .output_mappings + .insert(onnx_out.to_string(), output_name); + } + Ok(result) + } + _ => Err(OnnxError::UnsupportedOp { + op: format!("{}{}D", op_label, spatial.len()), + node: node_name.to_string(), + }), + } + } +} + +fn extend_with(src: Option<&[i64]>, fill: i64, target_len: usize) -> Vec { + let mut out = src.map(|v| v.to_vec()).unwrap_or_default(); + while out.len() < target_len { + out.push(fill); + } + out +} + +fn extend_pads_to_2d(pads: Option<&[i64]>) -> Vec { + match pads { + Some(p) if p.len() == 2 => vec![p[0], 0, p[1], 0], + Some(p) if p.len() == 4 => p.to_vec(), + _ => vec![0, 0, 0, 0], + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::protos::onnx::{AttributeProto, NodeProto}; + use std::collections::HashMap; + + fn make_node( + op_type: &str, + inputs: Vec<&str>, + outputs: Vec<&str>, + attrs: Vec, + ) -> NodeProto { + NodeProto { + op_type: op_type.to_string(), + name: format!("test_{}", op_type.to_lowercase()), + input: inputs.iter().map(|s| s.to_string()).collect(), + output: outputs.iter().map(|s| s.to_string()).collect(), + attribute: attrs, + ..Default::default() + } + } + + fn int_attr(name: &str, value: i64) -> AttributeProto { + AttributeProto { + name: name.to_string(), + i: value, + ..Default::default() + } + } + + fn ints_attr(name: &str, values: Vec) -> AttributeProto { + AttributeProto { + name: name.to_string(), + ints: values, + ..Default::default() + } + } + + fn string_attr(name: &str, value: &str) -> AttributeProto { + AttributeProto { + name: name.to_string(), + s: value.as_bytes().to_vec(), + ..Default::default() + } + } + + fn make_context<'a>( + initializers: &'a HashMap, + value_shapes: &'a HashMap>, + const_values: &'a HashMap>, + value_ids: &'a HashMap, + value_types: &'a HashMap, + ) -> ConversionContext<'a> { + ConversionContext { + initializers, + value_shapes, + value_shape_dims: crate::onnx::ops::empty_value_shape_dims(), + const_values, + value_ids, + value_types, + } + } + + #[test] + fn supports_pool_ops() { + let h = PoolHandler; + assert!(h.supports("MaxPool")); + assert!(h.supports("AveragePool")); + assert!(h.supports("GlobalMaxPool")); + assert!(h.supports("GlobalAveragePool")); + assert!(!h.supports("Conv")); + } + + #[test] + fn maxpool2d_basic() { + let h = PoolHandler; + let node = make_node( + "MaxPool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + ints_attr("strides", vec![2, 2]), + ints_attr("pads", vec![1, 1, 1, 1]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 64, 112, 112]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes.len(), 1); + let n = &result.nodes[0]; + assert_eq!(n.op, "maxPool2d"); + assert_eq!(n.inputs, vec!["x"]); + assert_eq!(n.options.get("windowDimensions"), Some(&json!([3, 3]))); + assert_eq!(n.options.get("strides"), Some(&json!([2, 2]))); + assert_eq!(n.options.get("padding"), Some(&json!([1, 1, 1, 1]))); + assert_eq!(n.options.get("dilations"), Some(&json!([1, 1]))); + } + + #[test] + fn maxpool2d_pads_layout_reordered() { + let h = PoolHandler; + let node = make_node( + "MaxPool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + ints_attr("pads", vec![1, 2, 3, 4]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 64, 32, 32]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + // ONNX [top, left, bottom, right] -> WebNN [top, bottom, left, right] = [1, 3, 2, 4] + assert_eq!( + result.nodes[0].options.get("padding"), + Some(&json!([1, 3, 2, 4])) + ); + } + + #[test] + fn maxpool2d_with_ceil_mode() { + let h = PoolHandler; + let node = make_node( + "MaxPool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![2, 2]), + int_attr("ceil_mode", 1), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 8, 7, 7]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!( + result.nodes[0].options.get("roundingType"), + Some(&json!("ceil")) + ); + } + + #[test] + fn maxpool2d_auto_pad_same_upper() { + let h = PoolHandler; + let node = make_node( + "MaxPool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3, 3]), + string_attr("auto_pad", "SAME_UPPER"), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 8, 32, 32]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + assert_eq!(n.options.get("autoPad"), Some(&json!("same-upper"))); + assert!(n.options.get("padding").is_none()); + } + + #[test] + fn averagepool2d_basic() { + let h = PoolHandler; + let node = make_node( + "AveragePool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![2, 2]), + ints_attr("strides", vec![2, 2]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 8, 14, 14]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + assert_eq!(n.op, "averagePool2d"); + assert_eq!(n.options.get("windowDimensions"), Some(&json!([2, 2]))); + // AveragePool: dilations not emitted unless non-default + assert!(n.options.get("dilations").is_none()); + } + + #[test] + fn averagepool_count_include_pad_rejected() { + let h = PoolHandler; + let node = make_node( + "AveragePool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![2, 2]), + int_attr("count_include_pad", 1), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 8, 14, 14]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let err = h.convert(&node, &ctx).unwrap_err(); + match err { + OnnxError::UnsupportedOp { op, .. } => { + assert!(op.contains("count_include_pad")); + } + other => panic!("expected UnsupportedOp, got {:?}", other), + } + } + + #[test] + fn global_average_pool_2d() { + let h = PoolHandler; + let node = make_node("GlobalAveragePool", vec!["x"], vec!["y"], vec![]); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 2048, 7, 7]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes.len(), 1); + let n = &result.nodes[0]; + assert_eq!(n.op, "averagePool2d"); + assert_eq!(n.options.get("windowDimensions"), Some(&json!([7, 7]))); + } + + #[test] + fn global_max_pool_2d() { + let h = PoolHandler; + let node = make_node("GlobalMaxPool", vec!["x"], vec!["y"], vec![]); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 16, 14, 14]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + let n = &result.nodes[0]; + assert_eq!(n.op, "maxPool2d"); + assert_eq!(n.options.get("windowDimensions"), Some(&json!([14, 14]))); + } + + #[test] + fn maxpool_missing_kernel_shape_errors() { + let h = PoolHandler; + let node = make_node("MaxPool", vec!["x"], vec!["y"], vec![]); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 8, 14, 14]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let err = h.convert(&node, &ctx).unwrap_err(); + match err { + OnnxError::MissingAttribute { attr, .. } => { + assert_eq!(attr, "kernel_shape"); + } + other => panic!("expected MissingAttribute, got {:?}", other), + } + } + + #[test] + fn maxpool_rejects_indices_output() { + let h = PoolHandler; + let node = make_node( + "MaxPool", + vec!["x"], + vec!["y", "indices"], + vec![ints_attr("kernel_shape", vec![2, 2])], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 8, 14, 14]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let err = h.convert(&node, &ctx).unwrap_err(); + match err { + OnnxError::UnsupportedOp { op, .. } => { + assert!(op.contains("indices")); + } + other => panic!("expected UnsupportedOp, got {:?}", other), + } + } + + #[test] + fn maxpool1d_emulated_via_2d() { + let h = PoolHandler; + let node = make_node( + "MaxPool", + vec!["x"], + vec!["y"], + vec![ + ints_attr("kernel_shape", vec![3]), + ints_attr("strides", vec![2]), + ints_attr("pads", vec![1, 1]), + ], + ); + let initializers = HashMap::new(); + let mut value_shapes = HashMap::new(); + value_shapes.insert("x".to_string(), vec![1, 16, 64]); + let const_values = HashMap::new(); + let value_ids = HashMap::new(); + let value_types = HashMap::new(); + let ctx = make_context( + &initializers, + &value_shapes, + &const_values, + &value_ids, + &value_types, + ); + + let result = h.convert(&node, &ctx).unwrap(); + assert_eq!(result.nodes.len(), 3); // reshape -> pool -> reshape + assert_eq!(result.nodes[0].op, "reshape"); + assert_eq!(result.nodes[1].op, "maxPool2d"); + assert_eq!(result.nodes[2].op, "reshape"); + let pool = &result.nodes[1]; + assert_eq!(pool.options.get("windowDimensions"), Some(&json!([3, 1]))); + assert_eq!(pool.options.get("strides"), Some(&json!([2, 1]))); + // ONNX 1D pads [1, 1] -> 2D [1, 0, 1, 0] -> WebNN [1, 1, 0, 0] + assert_eq!(pool.options.get("padding"), Some(&json!([1, 1, 0, 0]))); + } +} diff --git a/src/onnx/ops/reshape.rs b/src/onnx/ops/reshape.rs index 01ebdfd..459a509 100644 --- a/src/onnx/ops/reshape.rs +++ b/src/onnx/ops/reshape.rs @@ -23,6 +23,7 @@ impl OpHandler for ReshapeHandler { | "Squeeze" | "Tile" | "Expand" + | "Flatten" ) } @@ -47,6 +48,7 @@ impl OpHandler for ReshapeHandler { "Squeeze" => self.convert_squeeze(node, &node_name, context), "Tile" => self.convert_tile(node, &node_name, context), "Expand" => self.convert_expand(node, &node_name, context), + "Flatten" => self.convert_flatten(node, &node_name, context), _ => Err(OnnxError::UnsupportedOp { op: op_type.to_string(), node: node_name, @@ -1295,6 +1297,94 @@ impl ReshapeHandler { Ok(result) } + + /// Convert ONNX Flatten to WebNN reshape. + /// + /// ONNX Flatten reshapes the input to a 2-D matrix `(d_0 * ... * d_{axis-1}, + /// d_axis * ... * d_{n-1})` where `axis` defaults to 1 and may be negative. + fn convert_flatten( + &self, + node: &NodeProto, + node_name: &str, + context: &ConversionContext, + ) -> Result { + let inputs = node.input.as_slice(); + if inputs.len() != 1 { + return Err(OnnxError::InvalidShape(format!( + "Flatten expects 1 input, got {}", + inputs.len() + ))); + } + + let mut axis: i64 = 1; + for attr in node.attribute.as_slice() { + if attr.name.as_str() == "axis" { + axis = attr.i; + break; + } + } + + let output_name = if node.output.as_slice().is_empty() { + format!("{}_output", node_name) + } else { + sanitize_identifier(&node.output.as_slice()[0].to_string()) + }; + + let input_raw = inputs[0].to_string(); + let input_id = context.resolve_input(&input_raw); + let input_shape = context + .value_shapes + .get(&input_raw) + .or_else(|| context.value_shapes.get(&sanitize_identifier(&input_raw))) + .cloned() + .ok_or_else(|| { + OnnxError::InvalidShape(format!("Flatten: input '{}' shape is unknown", input_raw)) + })?; + + let rank = input_shape.len() as i64; + let normalized_axis = if axis < 0 { axis + rank } else { axis }; + if normalized_axis < 0 || normalized_axis > rank { + return Err(OnnxError::InvalidShape(format!( + "Flatten axis {} out of range for input rank {}", + axis, rank + ))); + } + let axis_usize = normalized_axis as usize; + + // ONNX semantics: axis == 0 means output [1, prod(shape)]; axis == rank means [prod(shape), 1]. + let outer: i64 = if axis_usize == 0 { + 1 + } else { + input_shape[..axis_usize].iter().product() + }; + let inner: i64 = if axis_usize == input_shape.len() { + 1 + } else { + input_shape[axis_usize..].iter().product() + }; + + let mut options = Map::new(); + options.insert("newShape".to_string(), serde_json::json!([outer, inner])); + + let mut result = ConversionResult::new(vec![Node { + id: output_name.clone(), + op: "reshape".to_string(), + inputs: vec![input_id], + options, + outputs: None, + }]); + + if let Some(out) = node.output.as_slice().first() { + result + .output_mappings + .insert(out.to_string(), output_name.clone()); + if let Some(dtype) = context.value_types.get(&input_raw) { + result.output_types.insert(out.to_string(), dtype.clone()); + } + } + + Ok(result) + } } #[cfg(test)] diff --git a/src/onnx/shape_inference.rs b/src/onnx/shape_inference.rs index 65ca7e5..711cdff 100644 --- a/src/onnx/shape_inference.rs +++ b/src/onnx/shape_inference.rs @@ -653,6 +653,225 @@ fn infer_node_shape(node: &NodeProto, ctx: &InferenceResult) -> Option> Some(out) } } + "MaxPool" | "AveragePool" => { + let ins = node.input.as_slice(); + if ins.is_empty() { + return None; + } + let x_shape = ctx.value_shapes.get(ins[0].as_str())?.clone(); + if x_shape.len() < 3 { + return None; + } + let spatial_rank = x_shape.len() - 2; + + let mut auto_pad = String::from("NOTSET"); + let mut strides: Vec = vec![1; spatial_rank]; + let mut dilations: Vec = vec![1; spatial_rank]; + let mut pads: Vec = vec![0; 2 * spatial_rank]; + let mut kernel_shape: Vec = Vec::new(); + let mut ceil_mode = false; + for attr in node.attribute.as_slice() { + match attr.name.as_str() { + "auto_pad" => { + if let Ok(s) = String::from_utf8(attr.s.clone()) { + if !s.is_empty() { + auto_pad = s; + } + } + } + "kernel_shape" if !attr.ints.is_empty() => kernel_shape = attr.ints.clone(), + "strides" if !attr.ints.is_empty() => strides = attr.ints.clone(), + "dilations" if !attr.ints.is_empty() => dilations = attr.ints.clone(), + "pads" if !attr.ints.is_empty() => pads = attr.ints.clone(), + "ceil_mode" => ceil_mode = attr.i != 0, + _ => {} + } + } + if kernel_shape.len() != spatial_rank + || strides.len() != spatial_rank + || dilations.len() != spatial_rank + || pads.len() != 2 * spatial_rank + { + return None; + } + + let mut out_spatial = Vec::with_capacity(spatial_rank); + for i in 0..spatial_rank { + let in_dim = x_shape[2 + i]; + let k = kernel_shape[i]; + let s = strides[i]; + let d = dilations[i]; + let dilated_k = d * (k - 1) + 1; + let out_dim = match auto_pad.as_str() { + "SAME_UPPER" | "SAME_LOWER" => (in_dim + s - 1) / s, + "VALID" => (in_dim - dilated_k) / s + 1, + _ => { + let pad_begin = pads[i]; + let pad_end = pads[i + spatial_rank]; + let numerator = in_dim + pad_begin + pad_end - dilated_k; + if ceil_mode { + (numerator + s - 1) / s + 1 + } else { + numerator / s + 1 + } + } + }; + if out_dim < 0 { + return None; + } + out_spatial.push(out_dim); + } + + let mut out = vec![x_shape[0], x_shape[1]]; + out.extend(out_spatial); + Some(out) + } + "GlobalMaxPool" | "GlobalAveragePool" => { + let ins = node.input.as_slice(); + if ins.is_empty() { + return None; + } + let x_shape = ctx.value_shapes.get(ins[0].as_str())?.clone(); + if x_shape.len() < 3 { + return None; + } + let mut out = vec![x_shape[0], x_shape[1]]; + out.extend(std::iter::repeat_n(1i64, x_shape.len() - 2)); + Some(out) + } + "Flatten" => { + let ins = node.input.as_slice(); + if ins.is_empty() { + return None; + } + let x_shape = ctx.value_shapes.get(ins[0].as_str())?.clone(); + let axis = node + .attribute + .as_slice() + .iter() + .find(|a| a.name.as_str() == "axis") + .map(|a| a.i) + .unwrap_or(1); + let rank = x_shape.len() as i64; + let norm = if axis < 0 { axis + rank } else { axis }; + if norm < 0 || norm > rank { + return None; + } + let norm = norm as usize; + let outer: i64 = if norm == 0 { + 1 + } else { + x_shape[..norm].iter().product() + }; + let inner: i64 = if norm == x_shape.len() { + 1 + } else { + x_shape[norm..].iter().product() + }; + Some(vec![outer, inner]) + } + "Conv" | "ConvTranspose" => { + let ins = node.input.as_slice(); + if ins.len() < 2 { + return None; + } + let x_shape = ctx.value_shapes.get(ins[0].as_str())?.clone(); + let w_shape = ctx.value_shapes.get(ins[1].as_str())?.clone(); + if x_shape.len() < 3 || w_shape.len() != x_shape.len() { + return None; + } + let spatial_rank = x_shape.len() - 2; + + let mut auto_pad = String::from("NOTSET"); + let mut strides: Vec = vec![1; spatial_rank]; + let mut dilations: Vec = vec![1; spatial_rank]; + let mut pads: Vec = vec![0; 2 * spatial_rank]; + let mut kernel_shape: Vec = w_shape[2..].to_vec(); + let mut group: i64 = 1; + let mut output_padding: Vec = vec![0; spatial_rank]; + let mut output_shape_attr: Vec = Vec::new(); + for attr in node.attribute.as_slice() { + match attr.name.as_str() { + "auto_pad" => { + if let Ok(s) = String::from_utf8(attr.s.clone()) { + if !s.is_empty() { + auto_pad = s; + } + } + } + "strides" if !attr.ints.is_empty() => strides = attr.ints.clone(), + "dilations" if !attr.ints.is_empty() => dilations = attr.ints.clone(), + "pads" if !attr.ints.is_empty() => pads = attr.ints.clone(), + "kernel_shape" if !attr.ints.is_empty() => kernel_shape = attr.ints.clone(), + "group" if attr.i > 0 => group = attr.i, + "output_padding" if !attr.ints.is_empty() => output_padding = attr.ints.clone(), + "output_shape" if !attr.ints.is_empty() => { + output_shape_attr = attr.ints.clone() + } + _ => {} + } + } + if strides.len() != spatial_rank + || dilations.len() != spatial_rank + || kernel_shape.len() != spatial_rank + || pads.len() != 2 * spatial_rank + || output_padding.len() != spatial_rank + { + return None; + } + + let transpose = op == "ConvTranspose"; + let m = if transpose { + w_shape[1] * group + } else { + w_shape[0] + }; + + if transpose && !output_shape_attr.is_empty() { + let sizes = if output_shape_attr.len() == spatial_rank { + output_shape_attr.clone() + } else if output_shape_attr.len() == x_shape.len() { + output_shape_attr[2..].to_vec() + } else { + return None; + }; + let mut out = vec![x_shape[0], m]; + out.extend(sizes); + return Some(out); + } + + let mut out_spatial = Vec::with_capacity(spatial_rank); + for i in 0..spatial_rank { + let in_dim = x_shape[2 + i]; + let k = kernel_shape[i]; + let s = strides[i]; + let d = dilations[i]; + let dilated_k = d * (k - 1) + 1; + let out_dim = match auto_pad.as_str() { + "SAME_UPPER" | "SAME_LOWER" if !transpose => (in_dim + s - 1) / s, + "SAME_UPPER" | "SAME_LOWER" if transpose => in_dim * s, + "VALID" if !transpose => (in_dim - dilated_k) / s + 1, + "VALID" if transpose => (in_dim - 1) * s + dilated_k, + _ => { + let pad_begin = pads[i]; + let pad_end = pads[i + spatial_rank]; + if transpose { + (in_dim - 1) * s - pad_begin - pad_end + dilated_k + output_padding[i] + } else { + (in_dim + pad_begin + pad_end - dilated_k) / s + 1 + } + } + }; + if out_dim < 0 { + return None; + } + out_spatial.push(out_dim); + } + + let mut out = vec![x_shape[0], m]; + out.extend(out_spatial); + Some(out) + } "ReduceMean" | "ReduceSum" | "ReduceMax" | "ReduceMin" => { let input = node.input.as_slice().first()?; let input_shape = ctx.value_shapes.get(input)?; diff --git a/tests/resnet50_conversion.rs b/tests/resnet50_conversion.rs new file mode 100644 index 0000000..bead15d --- /dev/null +++ b/tests/resnet50_conversion.rs @@ -0,0 +1,234 @@ +//! End-to-end conversion test for ResNet-50 (ONNX Model Zoo). +//! +//! The downloaded model is cached at `target/test-data/resnet50_Opset16.onnx`. +//! If a copy already exists at the repository root (the same file used during +//! manual testing) the test uses that and skips the download. + +use std::fs; +use std::path::{Path, PathBuf}; + +use webnn_graph::ast::GraphJson; +use webnn_graph::emit_js::{emit_builder_js, emit_weights_loader_js}; +use webnn_graph::onnx::convert::{convert_onnx, ConvertOptions}; +use webnn_graph::serialize::{serialize_graph_to_wg_text, SerializeOptions}; +use webnn_graph::validate::{validate_graph, validate_weights}; +use webnn_graph::weights::WeightsManifest; + +const MODEL_URL: &str = "https://media.githubusercontent.com/media/onnx/models/refs/heads/main/Computer_Vision/resnet50_Opset16_timm/resnet50_Opset16.onnx"; +const MODEL_FILENAME: &str = "resnet50_Opset16.onnx"; +// Exact size of the LFS-resolved blob. Catches partial downloads, LFS pointer +// files (~130 B), and upstream model swaps. Bump this if MODEL_URL ever points +// to a different file. +const EXPECTED_MODEL_SIZE_BYTES: u64 = 102_146_206; + +fn locate_or_download_model() -> PathBuf { + // Download and cache the model under target/test-data/. + let repo_root = Path::new(env!("CARGO_MANIFEST_DIR")); + let cache_dir = repo_root.join("target").join("test-data"); + fs::create_dir_all(&cache_dir).expect("create test-data cache dir"); + let cached = cache_dir.join(MODEL_FILENAME); + if cached.exists() && file_size(&cached) == EXPECTED_MODEL_SIZE_BYTES { + return cached; + } + + eprintln!("Downloading {} -> {}", MODEL_URL, cached.display()); + let mut builder = ureq::AgentBuilder::new().timeout(std::time::Duration::from_secs(600)); + // Honor the conventional proxy env vars so the test works behind a corporate + // proxy without code changes. ureq does not read these automatically. + if let Some(proxy_url) = proxy_from_env() { + match ureq::Proxy::new(&proxy_url) { + Ok(proxy) => { + eprintln!("Using proxy from environment: {}", proxy_url); + builder = builder.proxy(proxy); + } + Err(e) => eprintln!("Ignoring malformed proxy env var '{}': {}", proxy_url, e), + } + } + let agent = builder.build(); + let response = agent + .get(MODEL_URL) + .call() + .expect("download ResNet-50 model"); + let mut out = fs::File::create(&cached).expect("create cache file"); + std::io::copy(&mut response.into_reader(), &mut out).expect("stream model to disk"); + + let downloaded_size = file_size(&cached); + assert_eq!( + downloaded_size, EXPECTED_MODEL_SIZE_BYTES, + "downloaded model size mismatch (got {} bytes, expected {}). \ + Likely a Git LFS pointer file, a truncated download, or the upstream \ + model was swapped — bump EXPECTED_MODEL_SIZE_BYTES if intentional.", + downloaded_size, EXPECTED_MODEL_SIZE_BYTES, + ); + + cached +} + +fn file_size(path: &Path) -> u64 { + fs::metadata(path).map(|m| m.len()).unwrap_or(0) +} + +/// Read the conventional HTTPS/HTTP proxy environment variables. The HTTPS variant +/// is preferred since the model URL is HTTPS, but we fall back to plain HTTP_PROXY +/// for environments that only set one. +fn proxy_from_env() -> Option { + for var in ["HTTPS_PROXY", "https_proxy", "HTTP_PROXY", "http_proxy"] { + if let Ok(v) = std::env::var(var) { + if !v.is_empty() { + return Some(v); + } + } + } + None +} + +fn count_op(graph: &GraphJson, op: &str) -> usize { + graph.nodes.iter().filter(|n| n.op == op).count() +} + +#[test] +fn resnet50_converts_validates_and_emits_js() { + let model_path = locate_or_download_model(); + let tmp = tempfile::tempdir().expect("create temp dir"); + + let weights_path = tmp.path().join("resnet50.weights"); + let manifest_path = tmp.path().join("resnet50.manifest.json"); + let webnn_path = tmp.path().join("resnet50.webnn"); + + let options = ConvertOptions { + extract_weights: true, + output_path: webnn_path.to_string_lossy().into_owned(), + weights_path: Some(weights_path.to_string_lossy().into_owned()), + manifest_path: Some(manifest_path.to_string_lossy().into_owned()), + free_dim_overrides: Default::default(), + optimize: true, + experimental_dynamic_inputs: false, + }; + + let graph = convert_onnx(&model_path, options).expect("convert ResNet-50"); + + // ResNet-50 stem + 4 stages of bottlenecks: expect a large but predictable op mix. + let conv2d = count_op(&graph, "conv2d"); + let max_pool = count_op(&graph, "maxPool2d"); + let avg_pool = count_op(&graph, "averagePool2d"); + let relu = count_op(&graph, "relu"); + let add = count_op(&graph, "add"); + let matmul = count_op(&graph, "matmul"); + let reshape = count_op(&graph, "reshape"); + + eprintln!( + "ResNet-50 conversion produced: conv2d={conv2d}, maxPool2d={max_pool}, \ + averagePool2d={avg_pool}, relu={relu}, add={add}, matmul={matmul}, reshape={reshape}, \ + total_nodes={total}", + total = graph.nodes.len() + ); + + // Exact counts: this is a fixed, versioned ONNX file so any drift here means + // either the converter changed semantics or the upstream model was swapped. + assert_eq!(conv2d, 53, "expected exactly 53 conv2d nodes"); + assert_eq!(max_pool, 1, "expected exactly 1 maxPool2d node (stem)"); + assert_eq!( + avg_pool, 1, + "expected exactly 1 averagePool2d node (global pool head)" + ); + assert_eq!(relu, 49, "expected exactly 49 relu nodes"); + assert_eq!( + add, 17, + "expected exactly 17 add nodes (16 residual + FC bias)" + ); + assert_eq!(matmul, 1, "expected exactly 1 matmul (final FC layer)"); + assert_eq!(reshape, 1, "expected exactly 1 reshape (from Flatten)"); + assert_eq!(graph.nodes.len(), 124, "expected exactly 124 total nodes"); + + // The graph has a single image input and a single classification output. + assert_eq!(graph.inputs.len(), 1, "expected exactly 1 input"); + assert_eq!(graph.outputs.len(), 1, "expected exactly 1 output"); + + // Stem conv: 7×7 kernel, stride 2, padding 3 on every side. Locating it via + // the windowDimensions option keeps the assertion robust to op ordering. + let stem_conv = graph + .nodes + .iter() + .find(|n| n.op == "conv2d" && n.options.get("strides") == Some(&serde_json::json!([2, 2]))) + .expect("stem conv2d (stride 2) not found"); + assert_eq!( + stem_conv.options.get("padding"), + Some(&serde_json::json!([3, 3, 3, 3])), + "stem conv2d should have padding [3,3,3,3]" + ); + + // MaxPool stem: 3×3 window, stride 2, padding 1 on every side. + let max_pool_node = graph + .nodes + .iter() + .find(|n| n.op == "maxPool2d") + .expect("maxPool2d not found"); + assert_eq!( + max_pool_node.options.get("windowDimensions"), + Some(&serde_json::json!([3, 3])), + ); + assert_eq!( + max_pool_node.options.get("strides"), + Some(&serde_json::json!([2, 2])), + ); + assert_eq!( + max_pool_node.options.get("padding"), + Some(&serde_json::json!([1, 1, 1, 1])), + ); + + // GlobalAveragePool over the final 7×7 feature map. + let avg_pool_node = graph + .nodes + .iter() + .find(|n| n.op == "averagePool2d") + .expect("averagePool2d not found"); + assert_eq!( + avg_pool_node.options.get("windowDimensions"), + Some(&serde_json::json!([7, 7])), + ); + + // Flatten lowered to reshape [1, 2048] (batch=1, FC input dim). + let flatten_reshape = graph + .nodes + .iter() + .find(|n| { + n.op == "reshape" && n.options.get("newShape") == Some(&serde_json::json!([1, 2048])) + }) + .expect("flatten-as-reshape [1, 2048] not found"); + let _ = flatten_reshape; // assertion above is enough; kept binding for clarity. + + // Validate against the extracted manifest. + let manifest_text = fs::read_to_string(&manifest_path).expect("read manifest"); + let manifest: WeightsManifest = serde_json::from_str(&manifest_text).expect("parse manifest"); + validate_graph(&graph).expect("graph passes structural validation"); + validate_weights(&graph, &manifest).expect("manifest matches graph constants"); + + // Round-trip through the .webnn text format (this used to fail for ResNet-50 + // because some intermediate ONNX values are pure digits, e.g. "495", which the + // grammar rejects unless sanitize_identifier prefixes them). + let serialized = serialize_graph_to_wg_text(&graph, SerializeOptions::default()) + .expect("serialize to .webnn text"); + let reparsed = + webnn_graph::parser::parse_wg_text(&serialized).expect("re-parse serialized .webnn"); + assert_eq!(reparsed.nodes.len(), graph.nodes.len()); + assert_eq!(reparsed.inputs.len(), graph.inputs.len()); + assert_eq!(reparsed.outputs.len(), graph.outputs.len()); + + // Emit JS and spot-check that all four new operators reached the output. + let js = format!("{}\n{}", emit_weights_loader_js(), emit_builder_js(&graph)); + for needle in [ + "MLGraphBuilder", + "builder[\"conv2d\"]", + "builder[\"maxPool2d\"]", + "builder[\"averagePool2d\"]", + "builder[\"reshape\"]", + "builder[\"add\"]", + "builder[\"relu\"]", + "builder.build(outputs)", + ] { + assert!( + js.contains(needle), + "emitted JS missing expected token: {needle}", + ); + } +}