diff --git a/src/aggregation.jl b/src/aggregation.jl index d42f57a..5dfbce6 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -392,15 +392,25 @@ $(SIGNATURES) Helper function for transforming tuples. Used internally, to help type inference. Use via `transfom_tuple`. -""" -_transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) = - (), logjac_zero(flag, _ensure_float(eltype(x))), index -function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts) - tfirst = first(ts) - yfirst, ℓfirst, index′ = transform_with(flag, tfirst, x, index) - yrest, ℓrest, index′′ = _transform_tuple(flag, x, index′, Base.tail(ts)) - (yfirst, yrest...), ℓfirst + ℓrest, index′′ +Implemented as a `@generated` straight-line unroll over the static tuple length. +Equivalent to the natural `Base.tail` recursion, but emits non-recursive code +so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on +tuples of length ≥ 33 (EnzymeAD/Enzyme.jl#3104). +""" +@generated function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, + ts::Tuple{Vararg{AbstractTransform,N}}) where {N} + N == 0 && return :(((), logjac_zero(flag, _ensure_float(eltype(x))), index)) + ys = [Symbol(:y_, i) for i in 1:N] + ℓs = [Symbol(:ℓ_, i) for i in 1:N] + calls = [:(($(ys[i]), $(ℓs[i]), idx) = transform_with(flag, ts[$i], x, idx)) + for i in 1:N] + ℓ_sum = foldl((a, b) -> :($a + $b), ℓs) + return quote + idx = index + $(calls...) + (($(ys...),), $ℓ_sum, idx) + end end """