Symbolic differentiation for juliac-safe dynamic Dirichlet BCs #309

Merged
cmhamel merged 1 commit into main from feature/symbolic-differentiator on Jun 8, 2026

@lxmota (Contributor) commented Jun 8, 2026

Summary

The DirichletBCs{F} juliac-safe constructor previously hard-coded func_dot and func_dot_dot to a literal "0.0" expression — silently zeroing prescribed velocity and acceleration for any time-dependent BC under juliac --trim. The default Julia path uses ForwardDiff (doesn't survive juliac), and Symbolics.jl is too heavy to pull in just for this.

This PR adds a hand-rolled recursive tree-rewrite differentiator over FEC's own closed expression grammar (10 unary + 5 binary operators). No third-party AD, no closures — pure data, juliac-clean.

  • _differentiate(::Node{T, D}, var_idx) — recursive tree rewrite producing the symbolic ∂/∂x_{var_idx}. Trivial constant-folding (0+x=x, 0·x=0, 1·x=x) keeps derivative trees proportional in size.
  • differentiate(::ScalarExpressionFunction, var_name) — public API returning another ScalarExpressionFunction.
  • Second inner constructor on ScalarExpressionFunction accepting a prebuilt Expression so differentiate can wrap its tree without re-parsing.
  • Parser fix: lowered unary-minus right-binding-power from 100 to 25. Before this change, -t^2 parsed as (-t)^2; standard math precedence is -(t^2). Without the fix, the Gaussian-pulse form exp(-t^2/(2τ^2)) evaluates incorrectly.
  • DirichletBCs{F} juliac path now calls Expressions.differentiate(bc.func, "t") twice at construction. The Julia closure path (ForwardDiff) is unchanged.
  • Dropped the broken 2D-only ["x", "y", "t"] var_names hardcoding in the juliac path. The user's ScalarExpressionFunction carries its own var_names through expr.metadata, so the constructor no longer needs to invent one — 3D problems now work.
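The recursive rewrite with trivial constant folding described in the list above can be sketched as follows. This is a minimal illustration under stated assumptions, not FEC's implementation: the types `Ex`, `Lit`, `Var`, `Unary`, `Binary` and the functions `deriv` and `evalex` are hypothetical stand-ins for FEC's `Node{T, D}` and `_differentiate`, and only four unary and two binary operators are covered instead of the full 10 + 5.

```julia
# Hypothetical mini expression tree for illustration; not FEC's Node{T, D}.
abstract type Ex end
struct Lit    <: Ex; val::Float64; end
struct Var    <: Ex; idx::Int; end
struct Unary  <: Ex; op::Symbol; arg::Ex; end
struct Binary <: Ex; op::Symbol; lhs::Ex; rhs::Ex; end

is_zero(e::Ex) = e isa Lit && e.val == 0.0
is_one(e::Ex)  = e isa Lit && e.val == 1.0

# Smart constructors applying the folds 0+x=x, 0*x=0, 1*x=x.
add(a::Ex, b::Ex) = is_zero(a) ? b : is_zero(b) ? a : Binary(:+, a, b)
function mul(a::Ex, b::Ex)
    (is_zero(a) || is_zero(b)) && return Lit(0.0)
    is_one(a) && return b
    is_one(b) && return a
    return Binary(:*, a, b)
end
neg(a::Ex) = is_zero(a) ? a : Unary(:neg, a)

# Symbolic partial derivative with respect to variable number i.
deriv(e::Lit, i::Int) = Lit(0.0)
deriv(e::Var, i::Int) = Lit(e.idx == i ? 1.0 : 0.0)
function deriv(e::Unary, i::Int)
    da = deriv(e.arg, i)                       # chain rule: inner derivative
    e.op === :neg && return neg(da)
    e.op === :exp && return mul(e, da)
    e.op === :sin && return mul(Unary(:cos, e.arg), da)
    e.op === :cos && return neg(mul(Unary(:sin, e.arg), da))
    error("unsupported unary op: $(e.op)")
end
function deriv(e::Binary, i::Int)
    da, db = deriv(e.lhs, i), deriv(e.rhs, i)
    e.op === :+ && return add(da, db)                          # sum rule
    e.op === :* && return add(mul(da, e.rhs), mul(e.lhs, db))  # product rule
    error("unsupported binary op: $(e.op)")
end

# Plain recursive evaluation, used here only to check the derivatives.
evalex(e::Lit, x) = e.val
evalex(e::Var, x) = x[e.idx]
function evalex(e::Unary, x)
    a = evalex(e.arg, x)
    e.op === :neg && return -a
    e.op === :exp && return exp(a)
    e.op === :sin && return sin(a)
    e.op === :cos && return cos(a)
    error("unsupported unary op: $(e.op)")
end
function evalex(e::Binary, x)
    a, b = evalex(e.lhs, x), evalex(e.rhs, x)
    e.op === :+ && return a + b
    e.op === :* && return a * b
    error("unsupported binary op: $(e.op)")
end

# Usage: f(t) = 2 t^2 with t as variable 1, differentiated twice.
t   = Var(1)
f   = mul(Lit(2.0), mul(t, t))
df  = deriv(f, 1)    # evaluates like 4t
ddf = deriv(df, 1)   # evaluates like 4
```

Because the folds live in the `add`/`mul`/`neg` constructors, every rule in `deriv` gets them for free; differentiating `f` with respect to a variable it does not contain collapses to a single zero literal rather than a tree of zero products.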

The same differentiate is the building block downstream Carina needs for traveling-wave IC mode (v₀ = ±c·∂u₀/∂s, a₀ = c²·∂²u₀/∂s²).
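The traveling-wave relations quoted above follow from the chain rule. As a sketch, assume the solution depends on position and time only through a phase $s(x,t)$ with $\partial s/\partial t = \pm c$ (the sign convention for $s$ is an assumption here; the PR text does not state it):

```latex
\begin{aligned}
u(x,t) &= u_0\bigl(s(x,t)\bigr), \qquad \frac{\partial s}{\partial t} = \pm c, \\
v_0 &= \left.\frac{\partial u}{\partial t}\right|_{t=0}
     = \pm c\,\frac{\partial u_0}{\partial s}, \\
a_0 &= \left.\frac{\partial^2 u}{\partial t^2}\right|_{t=0}
     = (\pm c)^2\,\frac{\partial^2 u_0}{\partial s^2}
     = c^2\,\frac{\partial^2 u_0}{\partial s^2}.
\end{aligned}
```

Both initial fields therefore need only spatial derivatives of the initial profile $u_0$, which is what a symbolic `differentiate` with respect to a spatial variable supplies.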

@codecov (bot) commented Jun 8, 2026

Codecov Report

❌ Patch coverage is 96.33028% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 71.59%. Comparing base (4eb7192) to head (74781f7).
⚠️ Report is 4 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| src/Expressions.jl | 96.22% | 4 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #309      +/-   ##
==========================================
+ Coverage   65.77%   71.59%   +5.82%     
==========================================
  Files          56       56              
  Lines        4894     5112     +218     
==========================================
+ Hits         3219     3660     +441     
+ Misses       1675     1452     -223     

The `DirichletBCs{F}` juliac-safe constructor previously hard-coded
`func_dot` and `func_dot_dot` to a literal `"0.0"` expression and
documented "for now will only work for static" — silently zeroing the
prescribed velocity and acceleration for any time-dependent BC under
`juliac --trim`. The default Julia path uses ForwardDiff, which doesn't
survive juliac. Pulling in Symbolics.jl just to round-trip a derivative
doesn't survive it either.

This change adds a hand-rolled recursive tree-rewrite differentiator over
FEC's own closed expression grammar (10 unary + 5 binary operators),
plus minor parser fixes:

- `src/Expressions.jl`:
  - `_differentiate(::Node{T, D}, var_idx)` — pure tree rewrite producing
    the symbolic ∂/∂x_{var_idx}. Trivial constant-folding (`0+x=x`,
    `0·x=0`, `1·x=x`) keeps derivative trees proportional in size.
  - `differentiate(::ScalarExpressionFunction, var_name)` — public API
    returning another ScalarExpressionFunction.
  - Second inner constructor on ScalarExpressionFunction accepting a
    prebuilt Expression (used to wrap the result of differentiate).
  - Parser fix: lower unary-minus right-binding-power from 100 to 25.
    Before this change, `-t^2` parsed as `(-t)^2`; standard math
    precedence is `-(t^2)`. The Gaussian-pulse form `exp(-t^2/(2τ^2))`
    doesn't evaluate correctly otherwise.

- `src/bcs/DirichletBCs.jl`: replace the hard-coded `zero_func` with
  `Expressions.differentiate(bc.func, "t")` for the first derivative and
  one more call on that for the second. Drops the broken 2D-only
  `["x", "y", "t"]` var_names hardcoding (the user's expression already
  carries its own var_names through metadata, so the constructor no
  longer needs to invent one). The Julia closure path (ForwardDiff)
  stays unchanged.

- `test/TestExpressions.jl`: 5 new test items exercising the parser
  precedence fix, each unary/binary op's derivative, Gaussian-pulse first
  and second time derivatives, spatial derivatives for traveling-wave
  ICs, and error on unknown variable.

- `test/TestBCs.jl`: 2 new test items exercising
  `DirichletBCs{ScalarExpressionFunction}` with `2 t²` and a Gaussian
  pulse, validating that vals_dot and vals_dot_dot match the analytical
  derivatives at multiple time samples.

Full suite: 18183/18183 (+98 new assertions).
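
The effect of the unary-minus binding power can be reproduced with a tiny Pratt-style evaluator. This is a sketch under assumptions, not FEC's parser: the names `tokenize`, `Parser`, `parse_expr`, `eval_expr` and the infix binding powers in `INFIX_BP` are hypothetical; only the values 100 and 25 for unary minus come from the text above. The sketch handles numbers, the variable `t`, parentheses, and `+ - * / ^`, but not function calls such as `exp`.

```julia
# Assumed infix binding powers (left, right); "^" is right-associative.
const INFIX_BP = Dict("+" => (10, 11), "-" => (10, 11),
                      "*" => (20, 21), "/" => (20, 21),
                      "^" => (31, 30))

# Split an ASCII expression into number tokens and single-character tokens.
function tokenize(s::AbstractString)
    toks = String[]
    i = firstindex(s)
    while i <= lastindex(s)
        c = s[i]
        if isspace(c)
            i = nextind(s, i)
        elseif isdigit(c) || c == '.'
            j = i
            while j <= lastindex(s) && (isdigit(s[j]) || s[j] == '.')
                j = nextind(s, j)
            end
            push!(toks, s[i:prevind(s, j)])
            i = j
        else
            push!(toks, string(c))
            i = nextind(s, i)
        end
    end
    return toks
end

mutable struct Parser
    toks::Vector{String}
    pos::Int
    neg_rbp::Int     # right binding power of unary minus
    t::Float64       # value substituted for the variable "t"
end

peek_tok(p::Parser) = p.pos <= length(p.toks) ? p.toks[p.pos] : ""
function next_tok!(p::Parser)
    tok = peek_tok(p)
    p.pos += 1
    return tok
end

# Parse and evaluate in one pass; min_bp is the weakest operator we may consume.
function parse_expr(p::Parser, min_bp::Int)
    tok = next_tok!(p)
    lhs = if tok == "-"
        -parse_expr(p, p.neg_rbp)      # operand extends while operators bind >= neg_rbp
    elseif tok == "("
        inner = parse_expr(p, 0)
        next_tok!(p) == ")" || error("expected ')'")
        inner
    elseif tok == "t"
        p.t
    else
        parse(Float64, tok)
    end
    while true
        op = peek_tok(p)
        haskey(INFIX_BP, op) || break
        lbp, rbp = INFIX_BP[op]
        lbp < min_bp && break
        next_tok!(p)
        rhs = parse_expr(p, rbp)
        lhs = op == "+" ? lhs + rhs :
              op == "-" ? lhs - rhs :
              op == "*" ? lhs * rhs :
              op == "/" ? lhs / rhs : lhs ^ rhs
    end
    return lhs
end

eval_expr(src::AbstractString, t::Real; neg_rbp::Int) =
    parse_expr(Parser(tokenize(src), 1, neg_rbp, Float64(t)), 0)

old = eval_expr("-t^2", 3.0; neg_rbp = 100)  # minus grabs only t: (-t)^2
new = eval_expr("-t^2", 3.0; neg_rbp = 25)   # minus waits for t^2: -(t^2)
```

With these assumed powers, 25 sits between multiplication (20/21) and exponentiation (30/31), so the minus operand absorbs `^` but stops at `*` and `/`.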
@lxmota force-pushed the feature/symbolic-differentiator branch from 81686ae to 74781f7 on June 8, 2026 01:48
@lxmota requested a review from cmhamel on June 8, 2026 01:49

@cmhamel (Contributor) left a comment

Thanks @lxmota. Glad to see someone else is getting some use out of the juliac --trim stuff. The errors can be painful, but the payoffs are huge.

I haven't been brave enough to try any of the GPU stack.

I know that MPI.jl has a few things that need to be fixed to get a parallel application that's trimmable.

@cmhamel merged commit 0ddce93 into main on Jun 8, 2026
10 of 13 checks passed
@cmhamel deleted the feature/symbolic-differentiator branch on June 8, 2026 02:19
lxmota added a commit that referenced this pull request Jun 8, 2026
Replace the recursive-Node storage of ScalarExpressionFunction with a
fixed-size NTuple of FlatNode records, making the function itself isbits.
The isbits property dissolves the GPU↔juliac conflict that the prior
juliac-safe path created: the function passes through KernelAbstractions
kernels as a value argument (works on CPU, CUDA, and ROCm backends), and
also survives `juliac --trim` since the storage is pure data with no
closures, no `@eval`, and no RuntimeGeneratedFunctions.

This supersedes the symbolic-differentiator approach in PR #309 (recursive
Node storage), keeping the differentiation logic and the parser
precedence fix but moving the persistent representation to the flat form.
The symbolic differentiator operates on the recursive Node{T, D} for
clean expression rewriting (constant folding composes naturally); the
boundary with the flat storage is `_unflatten` / `_flatten`, called once
per `differentiate` invocation at TOML-parse time.

Changes:
- src/Expressions.jl:
  - Add `FlatNode{T}` (1-based indexed, isbits) and FEC_EXPR_MAX_NODES=256
    cap (Carina's biggest second-derivative tree is ~100 nodes; 256 gives
    2-3× headroom at ~6 KB per function).
  - Add `_flatten(::Node{T,D})` (preorder DFS) and `_unflatten` inverse.
  - Add `_apply_unary_op` / `_apply_binary_op` op-code dispatch and
    `_eval_node` recursive walker over the flat tuple.
  - Rewrite ScalarExpressionFunction to store the flat NTuple; existing
    call signatures `(var::T)`, `(vars::AbstractVector{T})`,
    `(X::SVector{ND, T})`, `(X::SVector{ND, T}, t::T)` all preserved so
    downstream consumers (DirichletBCs, InitialConditions, AppTools) need
    no changes.
  - Add `_differentiate(::Node, var_idx)` recursive symbolic rewriter
    covering the 15 supported ops with trivial constant folding.
  - Add public `differentiate(f, var_idx::Integer)` and convenience
    `differentiate(f, var_names, var_name)` overload.
  - Parser fix: lower unary-minus right-binding-power from 100 to 25 so
    `-t^2` parses as `-(t^2)` per standard math precedence.

- src/bcs/DirichletBCs.jl: juliac-safe `DirichletBCs{F}` constructor now
  uses `Expressions.differentiate(bc.func, Int(bc.func.num_vars))` for
  func_dot and one more call for func_dot_dot.  Convention: time is the
  last variable in the user's expression.  Drops the broken 2D-only
  `["x", "y", "t"]` hardcoding.

- test/TestExpressions.jl: 8 new test items covering isbits property,
  parser precedence, each unary/binary op's derivative, Gaussian-pulse
  first and second time derivatives, traveling-wave spatial derivative,
  and the var-name overload + error path.

- test/TestBCs.jl: 2 new test items walking
  `DirichletBCs{ScalarExpressionFunction}` end to end with `2t²` and a
  Gaussian pulse, validating vals/vals_dot/vals_dot_dot.

Full suite: 18187/18187 (+102 new assertions over main).
Carina suite: 158/158 against this FEC branch.

# Conflicts:
#	src/Expressions.jl
#	src/bcs/DirichletBCs.jl
#	test/TestBCs.jl
#	test/TestExpressions.jl
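
For reference, the analytical derivatives that the two BC test cases named above are checked against can be written out by hand. This is a worked derivation, not taken from the test files; $\tau$ is the pulse-width parameter of the Gaussian form quoted in the PR.

```latex
\begin{aligned}
g(t) &= 2t^2, & \dot g(t) &= 4t, & \ddot g(t) &= 4, \\
f(t) &= \exp\!\left(-\frac{t^2}{2\tau^2}\right), &
\dot f(t) &= -\frac{t}{\tau^2}\,f(t), &
\ddot f(t) &= \left(\frac{t^2}{\tau^4} - \frac{1}{\tau^2}\right) f(t).
\end{aligned}
```

With the pre-fix precedence, `-t^2` would have been read as $(-t)^2 = t^2$, turning the decaying pulse into a growing exponential, which is why the parser fix and the Gaussian test travel together.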
lxmota added a commit that referenced this pull request Jun 8, 2026
Supersedes #309.  Rewrites ScalarExpressionFunction's storage from the
recursive Node{T, DEFAULT_MAX_DEGREE} tree introduced by #309 to a
fixed-size NTuple of FlatNode records, making the function itself isbits.
The isbits property dissolves the GPU↔juliac conflict that #309's
recursive-Node storage would have created downstream in Carina: the
function now passes through KernelAbstractions kernels as a value
argument (so BC evaluation runs in-place on the GPU with no host↔device
sync per time step), AND survives `juliac --trim` since the storage is
pure data — no closures, no `@eval`, no `RuntimeGeneratedFunctions`.

The symbolic differentiation logic from #309 carries over.  It still
operates on a recursive Node{T, D} form (cleanest for constant folding
through tree rewriting); the boundary with the persistent flat storage
is `_unflatten` / `_flatten`, called once per `differentiate` invocation
at TOML-parse time.

Diff is +320 / -138: the deletions are #309's recursive-Node
infrastructure (the old struct definition, four call methods, and the
`differentiate(f, var_name::String)` public API).

Changes:
- src/Expressions.jl:
  - Remove #309's recursive ScalarExpressionFunction (storage was
    `Expression{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type}`) and its
    string-keyed differentiate.
  - Add `FlatNode{T}` (1-based indexed, isbits) and FEC_EXPR_MAX_NODES=256
    cap (Carina's biggest second-derivative tree is ~100 nodes; 256
    gives 2-3× headroom at ~6 KB per function).
  - Add `_flatten(::Node{T,D})` (preorder DFS) and `_unflatten` inverse.
  - Add `_apply_unary_op` / `_apply_binary_op` op-code dispatch and
    `_eval_node` recursive walker over the flat tuple.
  - Replacement ScalarExpressionFunction stores the flat NTuple; all
    existing call signatures preserved (`(var::T)`,
    `(vars::AbstractVector{T})`, `(X::SVector{ND, T})`,
    `(X::SVector{ND, T}, t::T)`) so downstream consumers
    (DirichletBCs, InitialConditions, AppTools) need no changes.
  - Replacement `differentiate(f, var_idx::Integer)` with a convenience
    overload `differentiate(f, var_names, var_name)`.  The integer form
    is needed because var_names is `Vector{String}` — not isbits — and
    can't be stored on the function.

- src/bcs/DirichletBCs.jl: juliac-safe `DirichletBCs{F}` constructor
  switches from #309's `differentiate(bc.func, "t")` to
  `Expressions.differentiate(bc.func, Int(bc.func.num_vars))`,
  following the convention that the last variable in the user's
  expression is time.

- test/TestExpressions.jl: rewrite #309's differentiate tests onto the
  integer-index API; add isbits property check; keep the parser
  precedence test from #309.

- test/TestBCs.jl: replace #309's juliac-safe DirichletBCs tests with
  flat-form equivalents.

Full suite: 18187/18187 (no net assertion count change vs #309 — the
test items are equivalent, only the API form differs).
Carina suite: 158/158 against this FEC branch via local path dep.

Migration note for downstream: the only breaking change is
`differentiate(f, var_name::String)` → `differentiate(f, var_idx::Integer)`,
with a string-name convenience overload requiring an explicit var_names
list.  Internal users (FEC's DirichletBCs) already updated.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
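
The flat storage described in this commit message can be sketched as below. It is a minimal illustration under assumptions, not FEC's code: `FlatRec`, `MAX_NODES`, `flatten`, `eval_flat`, the `T*` tree types, and the op-code numbering are all hypothetical stand-ins for `FlatNode{T}`, `FEC_EXPR_MAX_NODES`, `_flatten`, and `_eval_node`, and the cap is 16 rather than 256 to keep the example small.

```julia
# Hypothetical recursive tree (stand-in for the recursive Node form).
abstract type Tree end
struct TConst  <: Tree; val::Float64; end
struct TVar    <: Tree; idx::Int; end
struct TUnary  <: Tree; op::UInt8; arg::Tree; end             # op: 1 = neg, 2 = exp
struct TBinary <: Tree; op::UInt8; lhs::Tree; rhs::Tree; end  # op: 1 = +,   2 = *

# One fixed-size record per node; children are 1-based indices into the tuple.
struct FlatRec
    kind::UInt8    # 0 = const, 1 = var, 2 = unary, 3 = binary
    op::UInt8
    val::Float64
    var::Int32
    left::Int32
    right::Int32
end

const MAX_NODES = 16   # illustrative cap only
const EMPTY = FlatRec(0x00, 0x00, 0.0, 0, 0, 0)

# Preorder DFS: reserve the parent's slot first, then flatten the children.
function flatten!(out::Vector{FlatRec}, n::Tree)
    push!(out, EMPTY)
    me = length(out)
    if n isa TConst
        out[me] = FlatRec(0x00, 0x00, n.val, 0, 0, 0)
    elseif n isa TVar
        out[me] = FlatRec(0x01, 0x00, 0.0, n.idx, 0, 0)
    elseif n isa TUnary
        l = flatten!(out, n.arg)
        out[me] = FlatRec(0x02, n.op, 0.0, 0, l, 0)
    else
        l = flatten!(out, n.lhs)
        r = flatten!(out, n.rhs)
        out[me] = FlatRec(0x03, n.op, 0.0, 0, l, r)
    end
    return me
end

# Pad to a fixed-length tuple so the result is plain bits.
function flatten(n::Tree)
    out = FlatRec[]
    flatten!(out, n)
    length(out) <= MAX_NODES || error("expression exceeds MAX_NODES")
    nodes = ntuple(i -> i <= length(out) ? out[i] : EMPTY, MAX_NODES)
    return nodes, length(out)
end

# Recursive walker over the flat tuple, starting at index i.
function eval_flat(nodes::NTuple{N, FlatRec}, i::Integer, x) where {N}
    n = nodes[i]
    n.kind == 0x00 && return n.val
    n.kind == 0x01 && return x[n.var]
    a = eval_flat(nodes, n.left, x)
    if n.kind == 0x02
        return n.op == 0x01 ? -a : exp(a)
    end
    b = eval_flat(nodes, n.right, x)
    return n.op == 0x01 ? a + b : a * b
end

# Usage: exp(-(t*t)) with t as variable 1.
tree = TUnary(0x02, TUnary(0x01, TBinary(0x02, TVar(1), TVar(1))))
nodes, count = flatten(tree)
value = eval_flat(nodes, 1, (0.5,))
```

The recursive `Tree` holds abstractly typed children and so cannot be plain bits, while the padded tuple of `FlatRec` can; that difference is the property the commit message relies on. Whether this particular sketch compiles inside a GPU kernel is not something the example demonstrates.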
lxmota added a commit that referenced this pull request Jun 8, 2026