Add weighted einsum implementation #671
norm_expr = handler(NormalizeIntp)(evaluate)(expr)

@jax.jit
def jitted_einsum(*args):
[perf] einsum() re-normalizes and re-jits on every call.
Each call rebuilds the symbolic expression, re-runs the full NormalizeIntp pipeline, and defines a fresh @jax.jit closure that is invoked once and discarded — so jax's per-wrapper compilation cache never hits. Timing shows a flat ~170 ms/call across repeated calls with identical spec+shapes. Callers that wrap einsum in an outer jax.jit (as the benchmark does) are fine since everything folds into their trace, but bare callers pay full normalization + XLA compile on every invocation. Caching norm_expr/the jitted function keyed on (subscripts, shapes, dtypes) would make repeat calls nearly free.
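A minimal sketch of the caching suggested here, under stated assumptions: `jnp.einsum` stands in for the normalized expression, and `_build_jitted`, `cached_einsum`, and `build_count` are hypothetical names that do not appear in the PR. The point is only that the expensive build step runs once per (subscripts, shapes, dtypes) key.

```python
import functools

import jax
import jax.numpy as jnp

# Counts how often the expensive build step runs (illustration only).
build_count = 0


@functools.lru_cache(maxsize=None)
def _build_jitted(subscripts, shapes, dtypes):
    """Hypothetical stand-in for "normalize the expression, then jit it"."""
    global build_count
    build_count += 1

    # Assumption: in the PR, normalization would happen here. jnp.einsum
    # is used only as a placeholder for the normalized expression.
    @jax.jit
    def jitted(*args):
        return jnp.einsum(subscripts, *args)

    return jitted


def cached_einsum(subscripts, *args):
    args = tuple(jnp.asarray(a) for a in args)
    # The cache key must be hashable: tuples of shapes and dtype names.
    shapes = tuple(a.shape for a in args)
    dtypes = tuple(str(a.dtype) for a in args)
    return _build_jitted(subscripts, shapes, dtypes)(*args)


a = jnp.ones((3, 4))
b = jnp.ones((4, 5))
out1 = cached_einsum("ij,jk->ik", a, b)
out2 = cached_einsum("ij,jk->ik", a, b)  # reuses the cached jitted function
```

Because the same jitted wrapper is returned on the second call, JAX's own compilation cache can also hit, rather than seeing a fresh closure each time.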
Seems like a good point? Was there some reason not to address this?
eb8680 left a comment
It's a nice start but it seems like a lot of this is tightly coupled to JAX and opt_einsum, to a degree that will make it difficult to use for anything other than implementing einsum.
return expr, fresh


class BindDimsBindDims(ObjectInterpretation):
This seems like it should be default behavior for bind_dims?
)
contract = used - elsewhere

# dispatching monoid2.plus on symbolic terms causes an infinite
This comment is confusing. I would expect the correct version of this rule to always make progress - it should push monoid.reduce over monoid2.plus, and since there's no rule elsewhere that does the opposite this non-termination issue should never arise. We should not need to depend on re-parenthesizing finitary products.
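As an illustration of what "pushing the reduction over the product" means numerically, here is a small sketch on plain arrays. It assumes the usual reading in which `monoid.reduce` is a sum and `monoid2.plus` is a product, and it does not use the PR's term representation; the array names are made up for the example.

```python
import jax.numpy as jnp
import numpy as np

a = jnp.array([1.0, 2.0, 3.0])  # depends only on index i
b = jnp.array([10.0, 20.0])     # depends only on index j

# Reduction applied to the whole product: sum over j of a[i] * b[j].
naive = jnp.sum(a[:, None] * b[None, :], axis=1)

# Reduction pushed inward past the factor that does not mention j:
# a[i] * (sum over j of b[j]).
pushed = a * jnp.sum(b)
```

Each application of this rewrite strictly shrinks the scope of the reduction, which is why a rule of this shape would be expected to terminate.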
return ()

@implements(Monoid.reduce)
def _(self, monoid: Monoid, body, streams: Streams):
This rule seems like it's trying to do too much at once, primarily as a side effect of treating the opt_einsum optimizer like magic. We'd probably have a much easier time debugging and generalizing the behavior here if we first break it into smaller steps.
return fwd()

stream_vars = set(streams.keys())
if not stream_vars:
Shouldn't this be caught by another rule already? Why is this check necessary here?
if not stream_vars:
    return fwd()

# grab sizes of reduction dimensions and any dimensions of the factors
This is all very specific to dense arrays and opt_einsum - we're giving up a lot of generality/simplicity and not getting much in return given how simple opt_einsum is under the hood.
@implements(Monoid.reduce)
def _(self, monoid: Monoid, body, streams: Streams):
    if monoid is not Sum:
Should this rule handle Sum.reduce rather than Monoid.reduce?
with non-concrete bounds): every ``v()`` becomes
``unbind_dims(streams[k], fresh_v)`` -- a gather.

The two passes cannot be fused: the direct-index pass must see a bare
We should always be suspicious of cases where fusion seems impossible, because that is basically equivalent to saying that some term cannot be put into our desired normal form by evaluation. It's more likely that something is off about the rule - I'm finding this function and its return type very hard to parse.
return fwd()

factors = body.args
if len(factors) < 2 or not all(
I'm a little confused about why this rule is handling the N-ary case len(factors) > 2, which seems to be adding a lot of complexity and muddying the division of labor with the main distributive rule above. I would only ever expect this rule to invoke the kernel jnp.tensordot on two fully concrete arrays, so it seems odd to carry around tail as well.
I also suspect a lot of this is reinventing internal logic inside jnp.einsum. It's probably easier to generate a jnp.einsum primitive call rather than a tensordot call and let JAX handle the reduction to tensordot.
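For reference, a small sketch of the equivalence this suggestion relies on: a pairwise contraction written as `jnp.tensordot` with explicit axis pairs can be written as a single `jnp.einsum` call with a subscript string. The shapes and index letters are arbitrary choices for the example, not taken from the PR.

```python
import jax.numpy as jnp
import numpy as np

a = jnp.arange(24.0).reshape(2, 3, 4)  # axes i, j, k
b = jnp.arange(60.0).reshape(4, 3, 5)  # axes k, j, l

# tensordot: pair a's axis 1 (j) with b's axis 1 (j),
# and a's axis 2 (k) with b's axis 0 (k).
via_tensordot = jnp.tensordot(a, b, axes=((1, 2), (1, 0)))

# einsum: the same contraction, with the axis bookkeeping in the string.
via_einsum = jnp.einsum("ijk,kjl->il", a, b)
```

Generating the subscript string only requires naming each dimension, whereas the tensordot form requires tracking axis positions in both operands.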
class ReduceOrderContraction(ObjectInterpretation):
    """Reorder a large product before contraction using an ``opt_einsum`` path.

    Matches ``monoid.reduce(monoid2.plus(f1, ..., fn), streams)`` where
What are this rule's implicit preconditions on streams? I don't see them listed or checked anywhere but I don't think it can accept arbitrary dependent streams even if they have Array-valued elements.
Adds an einsum implementation at handlers.jax.einsum, which produces a weighted term, normalizes it, and jits it. Performance is comparable with the jax einsum implementation.

Adds optimization rules:
ReduceSumProductContraction replaces Sum.reduce(Product.plus(A, B), streams) with a call to jnp.tensordot.
ReduceOrderContraction uses opt_einsum to choose a contraction ordering, producing pairwise reductions.
ArrayReduce now has a fast path for reduction over arange streams that produces slices instead of gathers.

Changes to existing operations:
bind_dims now introduces missing dimensions so bind_dims(unbind_dims(t, x), x, y) reduces to a tensor with two leading dimensions instead of failing to reduce.

Benchmarks: