Skip to content

Replace recursive-Node ScalarExpressionFunction with isbits flat form#310

Merged
cmhamel merged 2 commits into
mainfrom
feature/flat-scalar-expression-function
Jun 9, 2026
Merged

Replace recursive-Node ScalarExpressionFunction with isbits flat form#310
cmhamel merged 2 commits into
mainfrom
feature/flat-scalar-expression-function

Conversation

@lxmota

@lxmota lxmota commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Summary

Supersedes #309. Three-commit PR:

  1. Rewrites ScalarExpressionFunction's storage from the recursive Node{T, DEFAULT_MAX_DEGREE} form introduced by Symbolic differentiation for juliac-safe dynamic Dirichlet BCs #309 to a fixed-size NTuple of FlatNode records, making the function itself isbits.
  2. Routes ScalarExpressionFunction through symbolic time derivatives in DirichletBCFunction so the untyped DirichletBCs(mesh, dof, bcs) constructor — the one create_parameters calls — produces fully isbits {F, F, F} BC funcs automatically, without callers needing to thread an F type parameter.
  3. Adds a typed Base.size(::AbstractContinuousField, ::Int)::Int overload so the trim verifier can resolve the call site Carina uses to report node counts in the setup log.

The isbits property dissolves the GPU↔juliac conflict that #309's recursive-Node storage would have created downstream in Carina:

  • BC evaluation in FEC.update_bc_values! runs in-place on the GPU with no host↔device sync per time step (the function passes through KernelAbstractions kernels as a value argument).
  • The function also survives juliac --trim — the storage is pure data, no closures, no @eval, no RuntimeGeneratedFunctions.

The symbolic differentiation logic from #309 carries over. It still operates on the recursive Node{T, D} form (cleanest for constant folding through tree rewriting); the boundary with the persistent flat storage is _unflatten / _flatten, called once per differentiate invocation at TOML-parse time.

What this change deletes (from main, ex-#309)

  • struct ScalarExpressionFunction{T} <: AbstractExpressionFunction{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type} — the recursive-Node-backed struct.
  • Its four call methods (the new struct preserves the same signatures).
  • differentiate(f, var_name::String) — the string-keyed public API (replaced by the integer form; a string convenience overload remains, but requires an explicit var_names list since the function no longer stores it).

What this change adds

Commit 1 — Replace recursive-Node ScalarExpressionFunction with isbits flat form (7c79988)

  • src/Expressions.jl

    • FlatNode{T} struct (1-based indices, isbits) and FEC_EXPR_MAX_NODES = 256 cap. Carina's biggest second-derivative tree (Gaussian-pulse BC) is ~100 nodes; 256 gives 2-3× headroom at ~6 KB per function.
    • _flatten(::Node{T,D}) (preorder DFS) and _unflatten inverse.
    • _apply_unary_op / _apply_binary_op open-coded op-code dispatch and _eval_node recursive walker over the flat tuple.
    • Replacement ScalarExpressionFunction storing the flat NTuple; all existing call signatures preserved ((var::T), (vars::AbstractVector{T}), (X::SVector{ND, T}), (X::SVector{ND, T}, t::T)) so downstream consumers in FEC and Carina need no changes.
  • src/bcs/DirichletBCs.jl — juliac-safe DirichletBCs{F} constructor switches from Expressions.differentiate(bc.func, "t") to Expressions.differentiate(bc.func, Int(bc.func.num_vars)), following the convention that the last variable in the user's expression is time.

  • test/TestExpressions.jl — rewrites Symbolic differentiation for juliac-safe dynamic Dirichlet BCs #309's differentiate tests onto the integer-index API; adds isbits property check; keeps the parser precedence test from Symbolic differentiation for juliac-safe dynamic Dirichlet BCs #309.

  • test/TestBCs.jl — replaces Symbolic differentiation for juliac-safe dynamic Dirichlet BCs #309's juliac-safe DirichletBCs tests with flat-form equivalents.

Diff: +320 / -138 lines across 4 files (large minus block = the recursive-Node infrastructure being removed).

Commit 2 — Auto-route ScalarExpressionFunction through symbolic time derivatives (727fc7c)

  • src/bcs/DirichletBCs.jl — new DirichletBCFunction(func::ScalarExpressionFunction) specialization that takes time derivatives via Expressions.differentiate (not ForwardDiff), producing an all-isbits DirichletBCFunction{F, F, F} from a single positional arg. Three motivations:
    1. The legacy DirichletBCFunction(func::F) where F <: Function constructor builds closures around ForwardDiff.derivative, which the SEF call signature ((X::SVector{ND, T}, t::T) where T <: Number) cannot consume — ForwardDiff.Dual{Tag, Float64, 1} is not Float64. Without this overload, any caller passing a ScalarExpressionFunction through create_parameters would have crashed at first BC evaluation.
    2. With this overload, the untyped DirichletBCs(mesh, dof, bcs) container constructor — which is what create_parameters already routes through — gives callers the juliac-safe symbolic path automatically. Callers no longer need to opt in via DirichletBCs{F}(mesh, dof, bcs).
    3. The resulting DirichletBCFunction{F, F, F} is isbits and identically typed to what the typed constructor produces, so GPU dispatch is unchanged.

Diff: +16 / -0 lines, all in DirichletBCs.jl.

Commit 3 — Add typed size(::AbstractContinuousField, ::Int)::Int overload (4042615)

  • src/Fields.jl — direct two-arg method with concrete Int(NF) arithmetic in the body and ::Int return type. Surfaced by a juliac --trim dry-run of Carina: the generic Base.size(A, d)::Int fallback runs size(A)[d] and infers the result as ::Any whenever the field's NF type parameter isn't fully bound at the call site. Carina's setup log emits size(field, 2) to report node counts, so the abstract-::Any flowed through several downstream lines as unresolved calls. The two-arg overload keeps inference exact regardless of how the field's type is bound upstream.

Diff: +12 / -1 lines, all in Fields.jl.

Test plan

  • FEC suite: 18187 / 18187 (commit 1 matches Symbolic differentiation for juliac-safe dynamic Dirichlet BCs #309's count — same test items, integer-index API; commits 2 and 3 added no new tests but their behaviour is exercised by the existing juliac-safe DirichletBCs{F} test and any user of size(::AbstractContinuousField, ::Int)).
  • Carina suite against this branch via local path dep: 188 / 188.
  • isbitstype(ScalarExpressionFunction{Float64}) confirmed at REPL.
  • isbitstype(DirichletBCFunction(::ScalarExpressionFunction)) confirmed — wholly isbits, no captured closures.
  • Gaussian-pulse 1st and 2nd time derivatives match analytical to machine precision at multiple sample points.

Migration note

The only breaking API change vs #309 is differentiate(f, var_name::String)differentiate(f, var_idx::Integer). A string-name convenience overload differentiate(f, var_names, var_name) remains for the common parse-time case where the caller has the var-name list in scope. Internal users (FEC's DirichletBCs juliac path) are already updated.

The var_names list is no longer stored on the ScalarExpressionFunction (it's a Vector{String} — not isbits — and would have broken the GPU kernel argument story). Callers that need to look up names later should keep their own copy.

The SEF overload added in commit 2 and the size overload added in commit 3 are purely additive — DirichletBCFunction(::Function) continues to work unchanged for any non-SEF callable, and size(::AbstractContinuousField) (one-arg) is untouched.

🤖 Generated with Claude Code

@codecov

codecov Bot commented Jun 8, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 91.45299% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 68.73%. Comparing base (0ddce93) to head (4042615).

Files with missing lines Patch % Lines
src/bcs/DirichletBCs.jl 37.50% 5 Missing ⚠️
src/Expressions.jl 96.15% 4 Missing ⚠️
src/Fields.jl 80.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #310      +/-   ##
==========================================
- Coverage   71.59%   68.73%   -2.87%     
==========================================
  Files          56       56              
  Lines        5112     5204      +92     
==========================================
- Hits         3660     3577      -83     
- Misses       1452     1627     +175     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@lxmota lxmota force-pushed the feature/flat-scalar-expression-function branch from b6cebaa to 413ad38 Compare June 8, 2026 04:40
@lxmota lxmota requested a review from cmhamel June 8, 2026 04:41
@lxmota lxmota force-pushed the feature/flat-scalar-expression-function branch from 413ad38 to 31b8253 Compare June 8, 2026 04:45
Supersedes #309.  Rewrites ScalarExpressionFunction's storage from the
recursive Node{T, DEFAULT_MAX_DEGREE} tree introduced by #309 to a
fixed-size NTuple of FlatNode records, making the function itself isbits.
The isbits property dissolves the GPU↔juliac conflict that #309's
recursive-Node storage would have created downstream in Carina: the
function now passes through KernelAbstractions kernels as a value
argument (so BC evaluation runs in-place on the GPU with no host↔device
sync per time step), AND survives `juliac --trim` since the storage is
pure data — no closures, no `@eval`, no `RuntimeGeneratedFunctions`.

The symbolic differentiation logic from #309 carries over.  It still
operates on a recursive Node{T, D} form (cleanest for constant folding
through tree rewriting); the boundary with the persistent flat storage
is `_unflatten` / `_flatten`, called once per `differentiate` invocation
at TOML-parse time.

Diff is +320 / -138: the deletions are #309's recursive-Node
infrastructure (the old struct definition, four call methods, and the
`differentiate(f, var_name::String)` public API).

Changes:
- src/Expressions.jl:
  - Remove #309's recursive ScalarExpressionFunction (storage was
    `Expression{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type}`) and its
    string-keyed differentiate.
  - Add `FlatNode{T}` (1-based indexed, isbits) and FEC_EXPR_MAX_NODES=256
    cap (Carina's biggest second-derivative tree is ~100 nodes; 256
    gives 2-3× headroom at ~6 KB per function).
  - Add `_flatten(::Node{T,D})` (preorder DFS) and `_unflatten` inverse.
  - Add `_apply_unary_op` / `_apply_binary_op` op-code dispatch and
    `_eval_node` recursive walker over the flat tuple.
  - Replacement ScalarExpressionFunction stores the flat NTuple; all
    existing call signatures preserved (`(var::T)`,
    `(vars::AbstractVector{T})`, `(X::SVector{ND, T})`,
    `(X::SVector{ND, T}, t::T)`) so downstream consumers
    (DirichletBCs, InitialConditions, AppTools) need no changes.
  - Replacement `differentiate(f, var_idx::Integer)` with a convenience
    overload `differentiate(f, var_names, var_name)`.  The integer form
    is needed because var_names is `Vector{String}` — not isbits — and
    can't be stored on the function.

- src/bcs/DirichletBCs.jl: juliac-safe `DirichletBCs{F}` constructor
  switches from #309's `differentiate(bc.func, "t")` to
  `Expressions.differentiate(bc.func, Int(bc.func.num_vars))`,
  following the convention that the last variable in the user's
  expression is time.

- test/TestExpressions.jl: rewrite #309's differentiate tests onto the
  integer-index API; add isbits property check; keep the parser
  precedence test from #309.

- test/TestBCs.jl: replace #309's juliac-safe DirichletBCs tests with
  flat-form equivalents.

Full suite: 18187/18187 (no net assertion count change vs #309 — the
test items are equivalent, only the API form differs).
Carina suite: 158/158 against this FEC branch via local path dep.

Migration note for downstream: the only breaking change is
`differentiate(f, var_name::String)` → `differentiate(f, var_idx::Integer)`,
with a string-name convenience overload requiring an explicit var_names
list.  Internal users (FEC's DirichletBCs) already updated.
@lxmota lxmota force-pushed the feature/flat-scalar-expression-function branch from 31b8253 to 7c79988 Compare June 8, 2026 04:48
@lxmota lxmota changed the title Flat ScalarExpressionFunction (isbits, GPU + juliac safe) Replace recursive-Node ScalarExpressionFunction with isbits flat form Jun 8, 2026
Add a `DirichletBCFunction(func::ScalarExpressionFunction)` specialization
that takes time derivatives symbolically via `Expressions.differentiate`,
producing an all-`isbits` `DirichletBCFunction{F, F, F}` rather than the
ForwardDiff closures the generic `<: Function` constructor builds.  This
lets the untyped `DirichletBCs(mesh, dof, bcs)` container constructor
called from `create_parameters` give callers the juliac-safe symbolic
path "for free" — they no longer need to thread an `F` type parameter or
call `DirichletBCs{F}(...)` explicitly to opt in.

Carina passes `ScalarExpressionFunction` BC funcs through
`create_parameters`, which routed straight to the ForwardDiff branch;
ForwardDiff's `Dual` numbers do not satisfy the `T <: Number` constraint
on `ScalarExpressionFunction{T}`'s call sites, so the closures would have
crashed at first evaluation.  With this specialization the same flow
yields symbolic derivatives and an `isbits` result.

@cmhamel cmhamel left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed a stale dependency but otherwise looks great to me!

@cmhamel cmhamel force-pushed the feature/flat-scalar-expression-function branch from 75f7ca2 to 727fc7c Compare June 9, 2026 07:20
@cmhamel cmhamel merged commit fa2eb78 into main Jun 9, 2026
16 of 24 checks passed
@cmhamel cmhamel deleted the feature/flat-scalar-expression-function branch June 9, 2026 07:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants