Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
311 changes: 247 additions & 64 deletions src/Expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -370,66 +370,251 @@ $(TYPEDEF)
"""
abstract type AbstractExpressionFunction{T, N, D} <: Function end

########################################################
# Flat expression-tree representation
#
# `ScalarExpressionFunction` stores its expression tree as a fixed-size
# `NTuple` of `FlatNode` records — a struct of plain integer fields and a
# value of type `T`. The whole thing is `isbits`, which means:
#
# * It passes through KernelAbstractions kernels as a value argument and
# can be called on the GPU directly — no host↔device round-trip per
# time step for BC evaluation.
# * It survives `juliac --trim`: no closures, no `@eval`, no
# `RuntimeGeneratedFunctions`, just data.
#
# `FEC_EXPR_MAX_NODES` is the maximum tree size. Carina's largest
# inlined expression today is ~25 nodes, but second-derivative trees from
# the symbolic differentiator can grow to ~100 nodes for Gaussian-pulse
# BCs; 256 gives 2-3× headroom at ~6 KB per function — negligible memory
# for typical BC/IC counts. Trees that exceed the cap raise at
# construction time.
########################################################

const FEC_EXPR_MAX_NODES = 256

"""
$(TYPEDEF)
$(TYPEDFIELDS)

One node of a flattened expression tree. Children are referenced by
1-based index into the parent array; index 0 marks "no child". Leaves
carry either a constant `val` or a `feature` (1-based variable index).
"""
struct FlatNode{T <: Number}
degree::UInt8 # 0 = leaf, 1 = unary, 2 = binary
op::UInt8 # FUNC_* or BINARY_* op code; 0 for leaves
constant::Bool # leaf form: true ⇒ use `val`, false ⇒ use `feature`
val::T # leaf value when `constant`
feature::UInt16 # leaf variable index (1-based) when !`constant`
l_idx::UInt16 # index of left child (0 if leaf)
r_idx::UInt16 # index of right child (0 if leaf or unary)
end

@inline FlatNode{T}() where T <: Number =
FlatNode{T}(UInt8(0), UInt8(0), false, zero(T), UInt16(0), UInt16(0), UInt16(0))

# Flatten a recursive `Node{T, D}` (produced by FEC's Pratt parser) into a
# preorder NTuple of `FlatNode{T}` records. The first entry is the root;
# children follow in DFS order.
function _flatten_visit!(buf::Vector{FlatNode{T}}, node::Node{T, D})::UInt16 where {T, D}
if node.degree == 0
if node.constant
push!(buf, FlatNode{T}(UInt8(0), UInt8(0), true,
T(node.val), UInt16(0),
UInt16(0), UInt16(0)))
else
push!(buf, FlatNode{T}(UInt8(0), UInt8(0), false,
zero(T), UInt16(node.feature),
UInt16(0), UInt16(0)))
end
return UInt16(length(buf))
elseif node.degree == 1
push!(buf, FlatNode{T}()) # reserve own slot
my_idx = length(buf)
l = _flatten_visit!(buf, node.l)
buf[my_idx] = FlatNode{T}(UInt8(1), UInt8(node.op), false,
zero(T), UInt16(0), l, UInt16(0))
return UInt16(my_idx)
else # degree == 2
push!(buf, FlatNode{T}()) # reserve own slot
my_idx = length(buf)
l = _flatten_visit!(buf, node.l)
r = _flatten_visit!(buf, node.r)
buf[my_idx] = FlatNode{T}(UInt8(2), UInt8(node.op), false,
zero(T), UInt16(0), l, r)
return UInt16(my_idx)
end
end

function _flatten(root::Node{T, D}) where {T, D}
buf = FlatNode{T}[]
sizehint!(buf, FEC_EXPR_MAX_NODES)
_flatten_visit!(buf, root)
n_active = length(buf)
n_active <= FEC_EXPR_MAX_NODES || error(
"expression too large: $n_active nodes (max $FEC_EXPR_MAX_NODES)"
)
# Pad to fixed length with default nodes so the resulting NTuple type
# has a constant size at the type level.
while length(buf) < FEC_EXPR_MAX_NODES
push!(buf, FlatNode{T}())
end
nodes = NTuple{FEC_EXPR_MAX_NODES, FlatNode{T}}(buf)
return nodes, UInt16(n_active)
end

# Inverse used by `differentiate`: rebuild a recursive `Node{T, D}` from a
# flat NTuple so the existing recursive symbolic differentiator can run on
# it. Called once per `differentiate` invocation, off the GPU.
function _unflatten(nodes::NTuple{N, FlatNode{T}}, idx::Integer = 1) where {N, T}
n = nodes[idx]
if n.degree == 0
if n.constant
return Node{T, DEFAULT_MAX_DEGREE}(; val = n.val)
else
return Node{T, DEFAULT_MAX_DEGREE}(; feature = Int(n.feature))
end
elseif n.degree == 1
l = _unflatten(nodes, n.l_idx)
return Node{T, DEFAULT_MAX_DEGREE}(; op = Int(n.op), l = l)
else
l = _unflatten(nodes, n.l_idx)
r = _unflatten(nodes, n.r_idx)
return Node{T, DEFAULT_MAX_DEGREE}(; op = Int(n.op), l = l, r = r)
end
end

# Op-code dispatch — open-coded if/elseif so the compiler can inline
# branches inside KA kernels without resorting to a function table.
@inline function _apply_unary_op(::Type{T}, op::UInt8, u::T) where T <: Number
if op == UInt8(FUNC_MINUS); return -u
elseif op == UInt8(FUNC_COS); return cos(u)
elseif op == UInt8(FUNC_COSH); return cosh(u)
elseif op == UInt8(FUNC_EXP); return exp(u)
elseif op == UInt8(FUNC_LOG); return log(u)
elseif op == UInt8(FUNC_SIN); return sin(u)
elseif op == UInt8(FUNC_SINH); return sinh(u)
elseif op == UInt8(FUNC_SQRT); return sqrt(u)
elseif op == UInt8(FUNC_TAN); return tan(u)
elseif op == UInt8(FUNC_TANH); return tanh(u)
end
return T(NaN)
end

@inline function _apply_binary_op(::Type{T}, op::UInt8, u::T, v::T) where T <: Number
if op == UInt8(BINARY_PLUS); return u + v
elseif op == UInt8(BINARY_MINUS); return u - v
elseif op == UInt8(BINARY_MULTIPLY); return u * v
elseif op == UInt8(BINARY_DIVIDE); return u / v
elseif op == UInt8(BINARY_POWER); return u ^ v
end
return T(NaN)
end

# Recursive evaluator over the flat NTuple. Depth is bounded by the
# expression tree height (≤ ~10 for the expressions Carina uses today),
# so GPUCompiler handles the recursion without stack pressure.
function _eval_node(nodes::NTuple{N, FlatNode{T}}, idx::UInt16,
vars) where {N, T}
n = nodes[idx]
if n.degree == 0
if n.constant
return n.val
else
return T(vars[n.feature])
end
elseif n.degree == 1
u = _eval_node(nodes, n.l_idx, vars)
return _apply_unary_op(T, n.op, u)
else
u = _eval_node(nodes, n.l_idx, vars)
v = _eval_node(nodes, n.r_idx, vars)
return _apply_binary_op(T, n.op, u, v)
end
end

"""
$(TYPEDEF)
$(TYPEDFIELDS)

Scalar expression function as a flat, `isbits` value — usable as a
KernelAbstractions kernel argument and trim-mode safe under `juliac`.
The trailing variable is conventionally time; FEC's juliac-safe
`DirichletBCs` constructor uses `num_vars` as the time-derivative index.
"""
struct ScalarExpressionFunction{T <: Number} <: AbstractExpressionFunction{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type}
expr::Expression{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type}
num_vars::Int
struct ScalarExpressionFunction{T <: Number} <: AbstractExpressionFunction{T, FlatNode{T}, ntuple_type}
nodes::NTuple{FEC_EXPR_MAX_NODES, FlatNode{T}}
n_active::UInt16
num_vars::UInt8

"""
$(TYPEDSIGNATURES)

Parse `string` as an expression in the variable namespace `var_names`
and store the resulting tree in flat form. `var_names` is consumed by
the parser to bind identifiers to feature indices; it is not retained
on the resulting function.
"""
function ScalarExpressionFunction{T}(string::String, var_names::Vector{String}) where T <: Number
p = Parser{T}(string, var_names)
# params = _find_parameters(p)
_reset!(p)
ast = _parse_statement(p, 0)
expr = Expression(ast; operators, var_names)
new{T}(expr, length(var_names))
nodes, n_active = _flatten(ast)
new{T}(nodes, n_active, UInt8(length(var_names)))
end

"""
$(TYPEDSIGNATURES)

Build a `ScalarExpressionFunction` directly from a prebuilt
`DynamicExpressions.Expression` — used by [`differentiate`](@ref) to wrap
the result of a tree rewrite without round-tripping through the parser.
Build a `ScalarExpressionFunction` from a prebuilt flat NTuple — used
internally by [`differentiate`](@ref) to wrap the result of a tree
rewrite without round-tripping through the parser.
"""
function ScalarExpressionFunction{T}(
expr::Expression{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type},
num_vars::Int
nodes::NTuple{FEC_EXPR_MAX_NODES, FlatNode{T}},
n_active::UInt16,
num_vars::Integer
) where T <: Number
new{T}(expr, num_vars)
new{T}(nodes, n_active, UInt8(num_vars))
end
end

Base.eltype(::ScalarExpressionFunction{T}) where T <: Number = T

function (f::ScalarExpressionFunction)(var::T) where T <: Number
# Single scalar
function (f::ScalarExpressionFunction{T})(var::T) where T <: Number
@assert f.num_vars == 1
return f.expr(SMatrix{1, 1, T, 1}(var))[1]
return _eval_node(f.nodes, UInt16(1), SVector{1, T}(var))
end

# fall back function, not that efficient though
function (f::ScalarExpressionFunction)(vars::AbstractVector{T}) where T <: Number
vars = reshape(vars, length(vars), 1)
return f.expr(vars)[1]
# Vector of variable values (no `t` overload)
function (f::ScalarExpressionFunction{T})(vars::AbstractVector{T}) where T <: Number
@assert length(vars) == Int(f.num_vars) "expected $(Int(f.num_vars)) variables, got $(length(vars))"
return _eval_node(f.nodes, UInt16(1), vars)
end

# for ic type funcs
function (f::ScalarExpressionFunction)(X::SVector{ND, T}) where {ND, T <: Number}
@assert f.num_vars == ND "You need $ND variables for this function"
X = SMatrix{ND, 1, T, ND}(X.data)
return f.expr(X)[1]
# IC-style call: ND spatial coords (no time variable in the expression).
function (f::ScalarExpressionFunction{T})(X::SVector{ND, T}) where {ND, T <: Number}
@assert Int(f.num_vars) == ND "You need $ND variables for this function"
return _eval_node(f.nodes, UInt16(1), X)
end

# for bc type funcs
function (f::ScalarExpressionFunction)(X::SVector{ND, T}, t::T) where {ND, T <: Number}
@assert f.num_vars == ND + 1 "You need $(ND + 1) variables for this function"
vars = SMatrix{ND + 1, 1, T, ND + 1}(X..., t)
return f.expr(vars)[1]
# BC-style call: ND spatial coords + scalar time, packed into a stack-
# allocated SVector so the call survives KA kernels.
function (f::ScalarExpressionFunction{T})(X::SVector{ND, T}, t::T) where {ND, T <: Number}
@assert Int(f.num_vars) == ND + 1 "You need $(ND + 1) variables for this function"
if ND == 1
vars = SVector{2, T}(X[1], t)
elseif ND == 2
vars = SVector{3, T}(X[1], X[2], t)
elseif ND == 3
vars = SVector{4, T}(X[1], X[2], X[3], t)
else
# Generic path (very unlikely; covered for completeness). Allocates.
vars = T[X...; t]
end
return _eval_node(f.nodes, UInt16(1), vars)
end

"""
Expand Down Expand Up @@ -471,20 +656,21 @@ function (f::VectorExpressionFunction)(X::SVector{ND, T}, t::T) where {ND, T <:
end

########################################################
# Symbolic differentiation on Node{T, D} trees.
# Symbolic differentiation on the recursive Node form.
#
# The grammar is finite (10 unary + 5 binary operators), so the chain rule
# can be applied by tree rewriting in ~80 lines of pure Julia. This avoids
# pulling in ForwardDiff/Zygote/Symbolics, so the result survives
# `juliac --trim`. All helpers operate on values; no closures are formed.
# The differentiator is a pure tree-rewrite over FEC's closed grammar
# (10 unary + 5 binary ops). It operates on `Node{T, D}` because the
# constant-folding helpers compose recursively; the boundary with
# `ScalarExpressionFunction` is `_unflatten` / `_flatten`, called once
# per `differentiate` invocation.
########################################################

@inline _is_const(n::Node) = n.degree == 0 && n.constant
@inline _is_zero(n::Node) = _is_const(n) && iszero(n.val)
@inline _is_one(n::Node) = _is_const(n) && isone(n.val)

# Smart constructors that fold trivial constants so derivative trees stay
# proportional in size to the input. Each returns a fresh `Node{T, D}`.
# proportional in size to the input.
function _add(a::Node{T, D}, b::Node{T, D}) where {T, D}
_is_zero(a) && return b
_is_zero(b) && return a
Expand Down Expand Up @@ -516,18 +702,12 @@ function _neg(a::Node{T, D}) where {T, D}
end

function _pow_int(a::Node{T, D}, k::Int) where {T, D}
# Build a^k for a small positive integer constant k (>= 1).
k == 1 && return a
return Node{T, D}(; op = BINARY_POWER, l = a, r = Node{T, D}(; val = T(k)))
end

"""
$(TYPEDSIGNATURES)

Pure tree-rewrite differentiator. Recurses over `node` returning a new
`Node{T, D}` representing `∂ node / ∂ x_{var_idx}`, where `var_idx` is the
1-based feature index of the variable to differentiate with respect to.
"""
# Recursive tree-rewrite differentiator. Returns a new tree representing
# ∂node/∂x_{var_idx}.
function _differentiate(node::Node{T, D}, var_idx::Int) where {T, D}
if node.degree == 0
if node.constant
Expand Down Expand Up @@ -585,8 +765,6 @@ function _differentiate(node::Node{T, D}, var_idx::Int) where {T, D}
num = _sub(_mul(du, v), _mul(u, dv))
return _div(num, _pow_int(v, 2))
elseif op == BINARY_POWER
# Common cases: constant exponent or constant base get clean
# derivatives. General case uses u^v * (v' log u + v u' / u).
if _is_const(v)
# d/dx(u^c) = c · u^(c-1) · u'
du = _differentiate(u, var_idx)
Expand Down Expand Up @@ -617,31 +795,36 @@ end
"""
$(TYPEDSIGNATURES)

Return the symbolic derivative of `f` with respect to variable `var_name`.
Return the symbolic derivative of `f` with respect to the variable whose
1-based feature index is `var_idx`. Differentiation is implemented as a
recursive tree rewrite over FEC's closed grammar (10 unary + 5 binary
operators) — no dependency on ForwardDiff, Zygote, or Symbolics.

The returned function is a `ScalarExpressionFunction` over the same variable
list as `f` — only the underlying expression tree changes. Differentiation
is implemented as a recursive tree rewrite over FEC's closed grammar (10
unary + 5 binary operators), so there is no dependency on ForwardDiff,
Zygote, or Symbolics; the result survives `juliac --trim`.
The result is a fresh `ScalarExpressionFunction` over the same variable
slots; the trailing variable is conventionally time.
"""
function differentiate(f::ScalarExpressionFunction{T}, var_idx::Integer) where T
@assert 1 <= Int(var_idx) <= Int(f.num_vars) "var_idx $(var_idx) out of range 1..$(Int(f.num_vars))"
tree = _unflatten(f.nodes)
deriv_tree = _differentiate(tree, Int(var_idx))
nodes, n_active = _flatten(deriv_tree)
return ScalarExpressionFunction{T}(nodes, n_active, f.num_vars)
end

Supported operators: cos, cosh, exp, log, sin, sinh, sqrt, tan, tanh, unary
minus, +, -, *, /, ^.
"""
$(TYPEDSIGNATURES)

```julia
f = ScalarExpressionFunction{Float64}("a * exp(-(t - tc)^2 / (2 * τ^2))",
["a", "tc", "τ", "t"])
f_dot = differentiate(f, "t")
f_dot_dot = differentiate(f_dot, "t")
```
Convenience overload that resolves `var_name` against an explicit
`var_names` list, then delegates to the integer form. Useful when the
caller still has the var-name list in scope (typical at TOML parse time);
runtime hot paths should call the integer form directly.
"""
function differentiate(f::ScalarExpressionFunction{T}, var_name::String) where T
var_names = f.expr.metadata.var_names
idx = findfirst(==(var_name), var_names)
function differentiate(f::ScalarExpressionFunction{T},
var_names::AbstractVector{<:AbstractString},
var_name::AbstractString) where T
idx = findfirst(==(var_name), var_names)
@assert idx !== nothing "variable \"$var_name\" not in $(var_names)"
deriv_tree = _differentiate(f.expr.tree, idx)
deriv_expr = Expression(deriv_tree; operators, var_names)
return ScalarExpressionFunction{T}(deriv_expr, f.num_vars)
return differentiate(f, idx)
end

end # module
Loading
Loading