Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -392,15 +392,25 @@ $(SIGNATURES)

Helper function for transforming tuples. Used internally, to help type inference. Use via
`transfom_tuple`.
"""
_transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ::Tuple{}) =
(), logjac_zero(flag, _ensure_float(eltype(x))), index

function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index, ts)
tfirst = first(ts)
yfirst, ℓfirst, index′ = transform_with(flag, tfirst, x, index)
yrest, ℓrest, index′′ = _transform_tuple(flag, x, index′, Base.tail(ts))
(yfirst, yrest...), ℓfirst + ℓrest, index′′
Implemented as a `@generated` straight-line unroll over the static tuple length.
Equivalent to the natural `Base.tail` recursion, but emits non-recursive code
so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on
tuples of length ≥ 33 (EnzymeAD/Enzyme.jl#3104).
Comment on lines +396 to +399

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems very internal for being part of a docstrings? It also might change again in case of upstream compiler or Enzyme changes.

Suggested change
Implemented as a `@generated` straight-line unroll over the static tuple length.
Equivalent to the natural `Base.tail` recursion, but emits non-recursive code
so that `Enzyme.autodiff` does not hit `AssertionError("conv == 37")` on
tuples of length 33 (EnzymeAD/Enzyme.jl#3104).

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an internal helper function anyway and not part of the API. (I like to document my internal functions too, I know this is not common to do so). As far as I am concerned this is fine.

"""
@generated function _transform_tuple(flag::LogJacFlag, x::AbstractVector, index,
ts::Tuple{Vararg{AbstractTransform,N}}) where {N}
N == 0 && return :(((), logjac_zero(flag, _ensure_float(eltype(x))), index))
ys = [Symbol(:y_, i) for i in 1:N]
ℓs = [Symbol(:ℓ_, i) for i in 1:N]
calls = [:(($(ys[i]), $(ℓs[i]), idx) = transform_with(flag, ts[$i], x, idx))
for i in 1:N]
ℓ_sum = foldl((a, b) -> :($a + $b), ℓs)
return quote
idx = index

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a separate idx variable needed? Couldn't we just operate with index?

$(calls...)
(($(ys...),), $ℓ_sum, idx)
end
end

"""
Expand Down
Loading