Add weighted einsum implementation #671

Open
jfeser wants to merge 72 commits into staging-weighted from jf-weighted-einsum

jfeser commented Jun 4, 2026

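Adds an einsum implementation at handlers.jax.einsum, which produces a weighted term, normalizes it, and jits it. In the benchmarks below, mean times are within about 25% of the jax einsum implementation for most specs; the exceptions are batched_matmul (2.42x), batched_chain (2.11x), triangle (1.53x), and scalar_scaled_reduce (1.40x).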

Adds optimization rules:

  • ReduceSumProductContraction replaces Sum.reduce(Product.plus(A, B), streams) with a call to jnp.tensordot.
  • ReduceOrderContraction uses opt_einsum to choose a contraction ordering, producing pairwise reductions.
  • ArrayReduce now has a fast path for reduction over arange streams that produces slices instead of gathers.
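As a rough illustration of the first and third rules, the sketch below checks two equivalences in plain JAX: a sum over an elementwise product matches a single jnp.tensordot call, and indexing with a contiguous arange selects the same elements as a slice. It does not use the effectful rule classes themselves; the array names and shapes are assumptions chosen for the example.

```python
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
A = jnp.asarray(rng.normal(size=(3, 4)), dtype=jnp.float32)
B = jnp.asarray(rng.normal(size=(4, 5)), dtype=jnp.float32)

# Naive form: materialize the full (i, j, k) product, then sum out j.
naive = jnp.sum(A[:, :, None] * B[None, :, :], axis=1)

# Contracted form: one tensordot over the shared index j.
contracted = jnp.tensordot(A, B, axes=([1], [0]))

# Fast-path idea: indexing with a contiguous arange picks the same
# elements as a slice, so the gather can be replaced by a slice.
x = jnp.arange(10.0)
gathered = x[jnp.arange(2, 6)]
sliced = x[2:6]
```

The contracted form avoids building the intermediate (3, 4, 5) array, which is presumably the motivation for rewriting the reduction into a tensordot.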

Changes to existing operations:

  • bind_dims now introduces missing dimensions so bind_dims(unbind_dims(t, x), x, y) reduces to a tensor with two leading dimensions instead of failing to reduce.

Benchmarks:

---------------------------------------------------------------------------------------------- benchmark 'spec=,ij,ij->': 2 tests ----------------------------------------------------------------------------------------------
Name (time in us)                                          Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[scalar_scaled_reduce-jax]            98.2500 (1.0)      239.1250 (1.0)      116.1235 (1.0)       9.6665 (1.0)      112.9580 (1.0)       5.5317 (1.0)       784;779        8.6115 (1.0)        7463           1
test_einsum_bench[scalar_scaled_reduce-effectful]     140.4160 (1.43)     270.6670 (1.13)     162.6965 (1.40)     11.6362 (1.20)     158.8330 (1.41)     10.7089 (1.94)      743;336        6.1464 (0.71)       4989           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=,ij,ij-_.svg
------------------------------------------------------------------------------- benchmark 'spec=a,abi,bcij,cdij->ij': 2 tests --------------------------------------------------------------------------------
Name (time in us)                              Min                 Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[mixed_rank-jax]           6.0000 (1.0)      117.6250 (1.06)     7.7225 (1.0)      3.9413 (1.0)      7.2500 (1.0)      0.4590 (1.0)      447;1090      129.4921 (1.0)       22141           1
test_einsum_bench[mixed_rank-effectful]     6.0840 (1.01)     110.4580 (1.0)      8.5189 (1.10)     5.7894 (1.47)     7.4999 (1.03)     0.6669 (1.45)     465;1120      117.3867 (0.91)      13461           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=a,abi,bcij,cdij-_ij.svg
------------------------------------------------------------------------------- benchmark 'spec=ab,bc,cd,de,ef->af': 2 tests -------------------------------------------------------------------------------
Name (time in us)                           Min                 Max               Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[chain_5-effectful]     6.7910 (1.0)      135.4170 (1.08)      9.7131 (1.0)      4.5718 (1.0)      9.7080 (1.0)      2.0411 (5.43)      410;520      102.9536 (1.0)       22202           1
test_einsum_bench[chain_5-jax]           7.0001 (1.03)     125.3340 (1.0)      10.4208 (1.07)     4.7951 (1.05)     9.7920 (1.01)     0.3759 (1.0)      389;3123       95.9622 (0.93)      17094           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ab,bc,cd,de,ef-_af.svg
---------------------------------------------------------------------------------- benchmark 'spec=ab,bc,cd,de->ae': 2 tests -----------------------------------------------------------------------------------
Name (time in us)                            Min                 Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[chain_4-jax]            8.4171 (1.0)      429.7080 (2.99)     20.7793 (1.0)      12.7268 (1.68)     17.5835 (1.0)      14.6249 (1.34)     2317;614       48.1247 (1.0)       23646           1
test_einsum_bench[chain_4-effectful]     10.2919 (1.22)     143.9170 (1.0)      25.8598 (1.24)      7.5661 (1.0)      24.1250 (1.37)     10.8740 (1.0)      2315;192       38.6700 (0.80)      16950           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ab,bc,cd,de-_ae.svg
------------------------------------------------------------------------------------ benchmark 'spec=ab,bc,cd->ad': 2 tests -----------------------------------------------------------------------------------
Name (time in us)                                Min                Max              Mean            StdDev            Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[skewed_chain-jax]           5.4580 (1.0)      95.0421 (1.34)     6.4028 (1.00)     2.6042 (1.21)     6.1250 (1.0)      0.3750 (1.0)      315;1955      156.1825 (1.00)      25026           1
test_einsum_bench[skewed_chain-effectful]     5.5410 (1.02)     71.0420 (1.0)      6.3803 (1.0)      2.1519 (1.0)      6.1670 (1.01)     0.4170 (1.11)      146;611      156.7315 (1.0)       15979           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ab,bc,cd-_ad.svg
---------------------------------------------------------------------------------------- benchmark 'spec=abc,cde->abde': 2 tests -----------------------------------------------------------------------------------------
Name (time in us)                                       Min                 Max               Mean            StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[tensor_contraction-effectful]     12.2080 (1.0)      167.9170 (1.0)      32.2843 (1.15)     8.0095 (1.0)      30.6671 (1.09)      6.4160 (1.0)     1874;1082       30.9748 (0.87)      19247           1
test_einsum_bench[tensor_contraction-jax]           12.3340 (1.01)     182.7499 (1.09)     28.1879 (1.0)      9.8916 (1.23)     28.0830 (1.0)      11.7499 (1.83)     4717;204       35.4762 (1.0)       15790           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=abc,cde-_abde.svg
------------------------------------------------------------------------------------- benchmark 'spec=ai,bi,ci,di->abcd': 2 tests -------------------------------------------------------------------------------------
Name (time in us)                                     Min                 Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[star_contraction-jax]           13.3750 (1.0)      152.1250 (1.0)      30.0178 (1.0)      5.3877 (1.0)      28.0420 (1.0)      5.1670 (1.00)     1504;458       33.3135 (1.0)       15926           1
test_einsum_bench[star_contraction-effectful]     13.6250 (1.02)     167.6669 (1.10)     30.8312 (1.03)     5.4680 (1.01)     29.5000 (1.05)     5.1669 (1.0)      1480;398       32.4347 (0.97)      14252           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ai,bi,ci,di-_abcd.svg
------------------------------------------------------------------------------------ benchmark 'spec=bi,ij,bj->b': 2 tests ------------------------------------------------------------------------------------
Name (time in us)                             Min                 Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[bilinear-jax]           27.4171 (1.0)      172.2500 (1.0)      47.2378 (1.0)      5.6295 (1.0)      47.9170 (1.0)      3.3760 (1.0)      1536;596       21.1695 (1.0)       12960           1
test_einsum_bench[bilinear-effectful]     33.7500 (1.23)     179.4590 (1.04)     55.8158 (1.18)     8.5959 (1.53)     55.2500 (1.15)     7.7090 (2.28)     1966;700       17.9161 (0.85)      11101           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=bi,ij,bj-_b.svg
------------------------------------------------------------------------------------- benchmark 'spec=bii->b': 2 tests ------------------------------------------------------------------------------------
Name (time in ms)                                 Min               Max              Mean            StdDev            Median               IQR            Outliers       OPS            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[batched_trace-effectful]     1.1437 (1.0)      1.5298 (1.0)      1.2884 (1.00)     0.0683 (1.0)      1.2852 (1.00)     0.0648 (1.00)       225;52  776.1822 (1.00)        750           1
test_einsum_bench[batched_trace-jax]           1.1455 (1.00)     2.1825 (1.43)     1.2856 (1.0)      0.0835 (1.22)     1.2815 (1.0)      0.0647 (1.0)         90;21  777.8369 (1.0)         401           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=bii-_b.svg
-------------------------------------------------------------------------------------- benchmark 'spec=bij,bjk,bkl->bil': 2 tests --------------------------------------------------------------------------------------
Name (time in us)                                  Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[batched_chain-jax]           35.6250 (1.0)      500.2080 (1.0)       55.2209 (1.0)      19.8227 (1.0)       54.8330 (1.0)       7.8320 (1.0)       273;404       18.1091 (1.0)       13469           1
test_einsum_bench[batched_chain-effectful]     72.7091 (2.04)     840.2080 (1.68)     116.2667 (2.11)     44.2612 (2.23)     106.8331 (1.95)     35.6039 (4.55)      470;174        8.6009 (0.47)       8111           1
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=bij,bjk,bkl-_bil.svg
----------------------------------------------------------------------------------------- benchmark 'spec=bij,bjk->bik': 2 tests -----------------------------------------------------------------------------------------
Name (time in us)                                   Min                   Max               Mean              StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[batched_matmul-jax]           22.5409 (1.0)        133.1250 (1.0)      37.9738 (1.0)        8.0505 (1.0)      36.3750 (1.0)       9.0000 (1.0)      2817;336       26.3340 (1.0)       15010           1
test_einsum_bench[batched_matmul-effectful]     37.1670 (1.65)     9,025.7920 (67.80)    92.0579 (2.42)     175.6027 (21.81)    79.2915 (2.18)     23.9171 (2.66)       52;608       10.8627 (0.41)       9938           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=bij,bjk-_bik.svg
-------------------------------------------------------------------------------------------- benchmark 'spec=i,j->ij': 2 tests --------------------------------------------------------------------------------------------
Name (time in us)                                   Min                   Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[outer_product-effectful]     143.4590 (1.0)      1,954.0000 (4.71)     194.3566 (1.01)     49.9831 (1.42)     181.5840 (1.01)     27.4265 (1.05)      298;322        5.1452 (0.99)       2733           1
test_einsum_bench[outer_product-jax]           155.2500 (1.08)       414.7920 (1.0)      191.5506 (1.0)      35.2963 (1.0)      179.9580 (1.0)      26.0209 (1.0)       421;389        5.2206 (1.0)        3420           1
---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=i,j-_ij.svg
----------------------------------------------------------------------------------------- benchmark 'spec=ii->': 2 tests ----------------------------------------------------------------------------------------
Name (time in us)                           Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[trace-effectful]     420.2920 (1.0)      784.4160 (1.0)      496.0164 (1.0)      36.9181 (1.0)      489.0830 (1.0)      61.6348 (1.06)        651;3        2.0161 (1.0)        1807           1
test_einsum_bench[trace-jax]           423.2910 (1.01)     820.2920 (1.05)     505.9884 (1.02)     41.8712 (1.13)     499.8749 (1.02)     58.1354 (1.0)        637;21        1.9763 (0.98)       1993           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ii-_.svg
----------------------------------------------------------------------------------------- benchmark 'spec=ii->i': 2 tests ------------------------------------------------------------------------------------------
Name (time in us)                              Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[diagonal-jax]           250.6670 (1.0)      541.7500 (1.0)      359.6303 (1.00)     38.6073 (1.0)      356.8330 (1.00)     59.5420 (1.0)         894;3        2.7806 (1.00)       2678           1
test_einsum_bench[diagonal-effectful]     253.8330 (1.01)     614.8330 (1.13)     358.9530 (1.0)      41.4854 (1.07)     355.1249 (1.0)      65.6875 (1.10)        773;5        2.7859 (1.0)        2224           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ii-_i.svg
--------------------------------------------------------------------------------------------- benchmark 'spec=iij->ij': 2 tests ---------------------------------------------------------------------------------------------
Name (time in us)                                   Min                   Max                Mean              StdDev              Median                 IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[diagonal_keep-effectful]     516.3749 (1.0)      1,767.0831 (1.44)     738.0344 (1.03)     126.5156 (1.24)     727.3749 (1.04)     128.2087 (1.17)       248;26        1.3550 (0.97)       1171           1
test_einsum_bench[diagonal_keep-jax]           523.5000 (1.01)     1,228.4170 (1.0)      714.4403 (1.0)      101.8507 (1.0)      700.3749 (1.0)      109.6150 (1.0)        204;10        1.3997 (1.0)         709           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=iij-_ij.svg
--------------------------------------------------------------------------------------- benchmark 'spec=ij,ij->ij': 2 tests ----------------------------------------------------------------------------------------
Name (time in us)                             Min                    Max               Mean              StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[hadamard-effectful]     37.9171 (1.0)      47,111.5421 (279.04)   57.6196 (1.12)     522.6699 (47.44)    47.5410 (1.0)      2.5410 (1.0)        3;1061       17.3552 (0.89)       8120           1
test_einsum_bench[hadamard-jax]           38.4591 (1.01)        168.8330 (1.0)      51.2457 (1.0)       11.0177 (1.0)      47.5420 (1.00)     2.5830 (1.02)     829;1264       19.5138 (1.0)        9123           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ij,ij-_ij.svg
------------------------------------------------------------------------------------- benchmark 'spec=ij,jk,ik->': 2 tests -------------------------------------------------------------------------------------
Name (time in us)                             Min                 Max               Mean            StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[triangle-jax]            9.4160 (1.0)      111.5830 (1.0)      19.9889 (1.0)      6.7050 (1.49)     18.0420 (1.0)      11.5420 (2.22)      2950;87       50.0279 (1.0)       16998           1
test_einsum_bench[triangle-effectful]     14.5419 (1.54)     138.5830 (1.24)     30.6376 (1.53)     4.4994 (1.0)      29.0840 (1.61)      5.2081 (1.0)      1886;273       32.6396 (0.65)      18619           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ij,jk,ik-_.svg
----------------------------------------------------------------------------------------- benchmark 'spec=ij,jk,ki->': 2 tests ----------------------------------------------------------------------------------------
Name (time in us)                                     Min                 Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[trace_of_product-jax]           20.5830 (1.0)      161.2920 (1.03)     37.7087 (1.0)      5.5924 (1.0)      35.3330 (1.0)      5.6669 (1.06)     1014;273       26.5191 (1.0)       10248           1
test_einsum_bench[trace_of_product-effectful]     22.2080 (1.08)     156.0420 (1.0)      39.0765 (1.04)     5.6737 (1.01)     37.3750 (1.06)     5.3331 (1.0)      2383;340       25.5908 (0.96)      13715           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ij,jk,ki-_.svg
------------------------------------------------------------------------------------ benchmark 'spec=ij,jk->ik': 2 tests ------------------------------------------------------------------------------------
Name (time in us)                           Min                 Max               Mean            StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[matmul-effectful]     13.5420 (1.0)      122.5410 (1.04)     28.9971 (1.0)      4.2935 (1.0)      27.2919 (1.0)      4.4999 (1.0)      1396;331       34.4862 (1.0)       12645           1
test_einsum_bench[matmul-jax]           13.6250 (1.01)     117.9170 (1.0)      29.1672 (1.01)     4.7132 (1.10)     27.3751 (1.00)     4.5419 (1.01)     1787;489       34.2851 (0.99)      17493           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ij,jk-_ik.svg
-------------------------------------------------------------------------------------- benchmark 'spec=ij,kl->ijkl': 2 tests ---------------------------------------------------------------------------------------
Name (time in us)                              Min                 Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[outer_4d-effectful]     143.9580 (1.0)      780.1671 (1.78)     205.5479 (1.03)     47.3344 (1.16)     191.0001 (1.01)     34.1875 (1.17)      572;417        4.8650 (0.97)       3303           1
test_einsum_bench[outer_4d-jax]           144.4580 (1.00)     438.5000 (1.0)      200.1329 (1.0)      40.9241 (1.0)      188.3125 (1.0)      29.2296 (1.0)       376;293        4.9967 (1.0)        2300           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ij,kl-_ijkl.svg
---------------------------------------------------------------------------------------- benchmark 'spec=ij->ji': 2 tests ---------------------------------------------------------------------------------------
Name (time in us)                              Min                 Max               Mean            StdDev             Median               IQR             Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[transpose-jax]           13.7499 (1.0)      129.1250 (1.0)      18.3844 (1.0)      5.5233 (1.0)      16.9580 (1.0)      0.9169 (1.0)      1572;4013       54.3939 (1.0)       24641           1
test_einsum_bench[transpose-effectful]     13.8331 (1.01)     132.0840 (1.02)     19.0492 (1.04)     6.4886 (1.17)     17.0830 (1.01)     1.3750 (1.50)     1403;2424       52.4957 (0.97)      17609           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ij-_ji.svg
------------------------------------------------------------------------------------ benchmark 'spec=ijk,jl,kl->il': 2 tests ------------------------------------------------------------------------------------
Name (time in us)                              Min                 Max               Mean            StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[hyperedge-jax]           19.7079 (1.0)      137.6250 (1.0)      37.3558 (1.0)      6.5955 (1.0)      34.7500 (1.0)       6.0411 (1.0)       954;363       26.7696 (1.0)        9816           1
test_einsum_bench[hyperedge-effectful]     22.9171 (1.16)     157.9170 (1.15)     39.2796 (1.05)     7.8602 (1.19)     37.7085 (1.09)     10.3750 (1.72)     1366;246       25.4585 (0.95)      11284           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ijk,jl,kl-_il.svg
-------------------------------------------------------------------------------------------- benchmark 'spec=ijk->': 2 tests --------------------------------------------------------------------------------------------
Name (time in us)                                 Min                   Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[full_reduce-jax]           221.5840 (1.0)      1,653.5419 (4.43)     250.3371 (1.02)     36.9706 (2.29)     240.2080 (1.0)      14.2083 (1.26)      251;406        3.9946 (0.98)       3701           1
test_einsum_bench[full_reduce-effectful]     222.9169 (1.01)       373.6251 (1.0)      246.3825 (1.0)      16.1692 (1.0)      240.8330 (1.00)     11.2919 (1.0)       458;348        4.0587 (1.0)        3811           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Generated histogram: benchmark_20260604_181000-spec=ijk-_.svg
------------------------------------------------------------------------------------------ benchmark 'spec=ijk->k': 2 tests ------------------------------------------------------------------------------------------
Name (time in us)                                   Min                 Max               Mean             StdDev             Median               IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_einsum_bench[partial_reduce-jax]           63.9589 (1.0)      295.3330 (1.0)      83.7702 (1.0)       9.2439 (1.0)      80.7910 (1.01)     5.1659 (1.0)     1144;1172       11.9374 (1.0)        9820           1
test_einsum_bench[partial_reduce-effectful]     66.1660 (1.03)     402.9160 (1.36)     84.5012 (1.01)     13.2542 (1.43)     80.3340 (1.0)      5.3851 (1.04)     857;1260       11.8341 (0.99)       9589           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Comment thread effectful/handlers/jax/_terms.py
Comment thread effectful/handlers/jax/monoid.py Outdated
Comment thread effectful/handlers/jax/monoid.py Outdated
Comment thread effectful/ops/monoid.py
Comment thread effectful/handlers/jax/monoid.py Outdated
norm_expr = handler(NormalizeIntp)(evaluate)(expr)

@jax.jit
def jitted_einsum(*args):
Contributor Author

[perf] einsum() re-normalizes and re-jits on every call.

Each call rebuilds the symbolic expression, re-runs the full NormalizeIntp pipeline, and defines a fresh @jax.jit closure that is invoked once and discarded — so jax's per-wrapper compilation cache never hits. Timing shows a flat ~170 ms/call across repeated calls with identical spec+shapes. Callers that wrap einsum in an outer jax.jit (as the benchmark does) are fine since everything folds into their trace, but bare callers pay full normalization + XLA compile on every invocation. Caching norm_expr/the jitted function keyed on (subscripts, shapes, dtypes) would make repeat calls nearly free.
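
A minimal sketch of the suggested cache, assuming the expensive step can be packaged as one function. `CachedEinsum`, `signature`, and the `build` callback are hypothetical names for illustration and are not part of this PR; `build` stands in for "construct the term, run NormalizeIntp, wrap the result in jax.jit".

```python
from typing import Any, Callable, Dict, Hashable, Tuple


def signature(subscripts: str, operands: Tuple[Any, ...]) -> Hashable:
    """Cache key: the spec plus each operand's shape and dtype (never its values)."""
    return (
        subscripts,
        tuple((tuple(x.shape), str(x.dtype)) for x in operands),
    )


class CachedEinsum:
    """Memoize an expensive `build(subscripts, operands) -> callable` step.

    `build` is a hypothetical stand-in for normalization plus jit wrapping.
    """

    def __init__(self, build: Callable[[str, Tuple[Any, ...]], Callable[..., Any]]):
        self._build = build
        self._compiled: Dict[Hashable, Callable[..., Any]] = {}
        self.builds = 0  # number of cache misses, kept for inspection

    def __call__(self, subscripts: str, *operands: Any) -> Any:
        key = signature(subscripts, operands)
        fn = self._compiled.get(key)
        if fn is None:
            # Miss: pay the build cost once for this spec/shape/dtype combination.
            fn = self._build(subscripts, operands)
            self._compiled[key] = fn
            self.builds += 1
        return fn(*operands)
```

Because the key ignores array values, repeat calls with the same spec and shapes reuse one compiled function, and `builds` only increments on a miss.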

Contributor

Seems like a good point? Was there some reason not to address this?

Comment thread effectful/handlers/jax/monoid.py Outdated
Comment thread effectful/handlers/jax/monoid.py Outdated
@jfeser jfeser requested a review from eb8680 June 4, 2026 21:56
@jfeser jfeser marked this pull request as ready for review June 4, 2026 21:56
Contributor

@eb8680 eb8680 left a comment

It's a nice start but it seems like a lot of this is tightly coupled to JAX and opt_einsum, to a degree that will make it difficult to use for anything other than implementing einsum.

    return expr, fresh


class BindDimsBindDims(ObjectInterpretation):
Contributor

This seems like it should be default behavior for bind_dims?

)
contract = used - elsewhere

# dispatching monoid2.plus on symbolic terms causes an infinite
Contributor

This comment is confusing. I would expect the correct version of this rule to always make progress - it should push monoid.reduce over monoid2.plus, and since there's no rule elsewhere that does the opposite this non-termination issue should never arise. We should not need to depend on re-parenthesizing finitary products.

        return ()

    @implements(Monoid.reduce)
    def _(self, monoid: Monoid, body, streams: Streams):
Contributor

This rule seems like it's trying to do too much at once, primarily as a side effect of treating the opt_einsum optimizer like magic. We'd probably have a much easier time debugging and generalizing the behavior here if we first break it into smaller steps.

            return fwd()

        stream_vars = set(streams.keys())
        if not stream_vars:
Contributor

Shouldn't this be caught by another rule already? Why is this check necessary here?

        if not stream_vars:
            return fwd()

        # grab sizes of reduction dimensions and any dimensions of the factors
Contributor

This is all very specific to dense arrays and opt_einsum - we're giving up a lot of generality/simplicity and not getting much in return given how simple opt_einsum is under the hood.
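
To make the "simple under the hood" point concrete, here is a rough sketch of a greedy pairwise ordering. It is not opt_einsum's actual implementation; `greedy_path` and its cost model (size of the intermediate result only) are assumptions chosen for brevity. Each step contracts the pair whose result is smallest, keeping only the dimensions still needed by the output or by the remaining factors.

```python
from itertools import combinations
from math import prod
from typing import Dict, List, Sequence, Tuple


def greedy_path(
    inputs: Sequence[str], output: str, sizes: Dict[str, int]
) -> List[Tuple[int, int]]:
    """Return pairs of positions to contract, indexing the shrinking factor list.

    After each step the two chosen factors are removed and their result is
    appended to the end of the list.
    """
    terms = [frozenset(t) for t in inputs]
    out = frozenset(output)
    path: List[Tuple[int, int]] = []
    while len(terms) > 1:
        best = None
        for i, j in combinations(range(len(terms)), 2):
            others = [t for k, t in enumerate(terms) if k not in (i, j)]
            # Dimensions that must survive: output dims and dims used elsewhere.
            keep = out.union(*others)
            result = (terms[i] | terms[j]) & keep
            cost = prod(sizes[d] for d in result)
            if best is None or cost < best[0]:
                best = (cost, i, j, result)
        _, i, j, result = best
        path.append((i, j))
        terms = [t for k, t in enumerate(terms) if k not in (i, j)] + [result]
    return path
```

The dimensions dropped at each step (`terms[i] | terms[j]` minus `keep`) play the same role as `contract = used - elsewhere` in the rule under review, and nothing in the sketch depends on the factors being dense arrays.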


    @implements(Monoid.reduce)
    def _(self, monoid: Monoid, body, streams: Streams):
        if monoid is not Sum:
Contributor

Should this rule handle Sum.reduce rather than Monoid.reduce?

with non-concrete bounds): every ``v()`` becomes
``unbind_dims(streams[k], fresh_v)`` -- a gather.

The two passes cannot be fused: the direct-index pass must see a bare
Contributor

We should always be suspicious of cases where fusion seems impossible, because that is basically equivalent to saying that some term cannot be put into our desired normal form by evaluation. It's more likely that something is off about the rule - I'm finding this function and its return type very hard to parse.

            return fwd()

        factors = body.args
        if len(factors) < 2 or not all(
Contributor

@eb8680 eb8680 Jun 4, 2026

I'm a little confused about why this rule is handling the N-ary case len(factors) > 2, which seems to be adding a lot of complexity and muddying the division of labor with the main distributive rule above. I would only ever expect this rule to invoke the kernel jnp.tensordot on two fully concrete arrays - it seems odd to carry around tail as well.

I also suspect a lot of this is reinventing internal logic inside jnp.einsum. It's probably easier to generate a jnp.einsum primitive call rather than a tensordot call and let JAX handle the reduction to tensordot.
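
As a sketch of what generating an einsum call for the binary case might involve, the helper below builds the subscript string from the two factors' dimension names and the set being contracted. `pairwise_spec` is a hypothetical helper, not code from this PR, and it assumes each factor's dimensions are given as an ordered sequence of hashable names.

```python
import string
from typing import Collection, Hashable, Sequence


def pairwise_spec(
    a_dims: Sequence[Hashable],
    b_dims: Sequence[Hashable],
    contract: Collection[Hashable],
) -> str:
    """Build an einsum spec that sums out `contract` across two factors."""
    # Distinct dimension names in order of first appearance.
    names = list(dict.fromkeys([*a_dims, *b_dims]))
    if len(names) > len(string.ascii_letters):
        raise ValueError("too many distinct dimensions for an einsum string")
    letter = dict(zip(names, string.ascii_letters))
    lhs_a = "".join(letter[d] for d in a_dims)
    lhs_b = "".join(letter[d] for d in b_dims)
    # Output keeps every dimension that is not contracted away.
    out = "".join(letter[d] for d in names if d not in contract)
    return f"{lhs_a},{lhs_b}->{out}"
```

The resulting string could then be passed along with the two concrete arrays as `jnp.einsum(spec, a, b)`, leaving the choice of underlying kernel to JAX.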

class ReduceOrderContraction(ObjectInterpretation):
    """Reorder a large product before contraction using an ``opt_einsum`` path.

    Matches ``monoid.reduce(monoid2.plus(f1, ..., fn), streams)`` where
Contributor

What are this rule's implicit preconditions on streams? I don't see them listed or checked anywhere but I don't think it can accept arbitrary dependent streams even if they have Array-valued elements.
