Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 192 additions & 1 deletion src/Expressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,10 @@ function _nud(p::Parser, t::Token, ::Type{T}) where T <: Number
return Node{T}(; val = t.value)
elseif t.id == OPERATOR
if t.op == BINARY_MINUS
val = _parse_statement(p, 100)
# Unary minus. Right-binding-power 25 sits between `*`/`/` (20)
# and `^` (30) — matches standard math precedence so `-t^2` parses
# as `-(t^2)` rather than `(-t)^2`.
val = _parse_statement(p, 25)
return Node{T}(; op = UNARY_MINUS, l = val)
else
error("Unexpected operator in _nud. Found operator $(t.op)")
Expand Down Expand Up @@ -386,6 +389,20 @@ struct ScalarExpressionFunction{T <: Number} <: AbstractExpressionFunction{T, No
expr = Expression(ast; operators, var_names)
new{T}(expr, length(var_names))
end

"""
$(TYPEDSIGNATURES)

Build a `ScalarExpressionFunction` directly from a prebuilt
`DynamicExpressions.Expression` — used by [`differentiate`](@ref) to wrap
the result of a tree rewrite without round-tripping through the parser.
"""
function ScalarExpressionFunction{T}(
expr::Expression{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type},
num_vars::Int
) where T <: Number
new{T}(expr, num_vars)
end
end

Base.eltype(::ScalarExpressionFunction{T}) where T <: Number = T
Expand Down Expand Up @@ -453,4 +470,178 @@ function (f::VectorExpressionFunction)(X::SVector{ND, T}, t::T) where {ND, T <:
return map(func -> func(X, t), f.exprs)
end

########################################################
# Symbolic differentiation on Node{T, D} trees.
#
# The grammar is finite (10 unary + 5 binary operators), so the chain rule
# can be applied by tree rewriting in ~80 lines of pure Julia. This avoids
# pulling in ForwardDiff/Zygote/Symbolics, so the result survives
# `juliac --trim`. All helpers operate on values; no closures are formed.
########################################################

@inline _is_const(n::Node) = n.degree == 0 && n.constant
@inline _is_zero(n::Node) = _is_const(n) && iszero(n.val)
@inline _is_one(n::Node) = _is_const(n) && isone(n.val)

# Smart constructors that fold trivial constants so derivative trees stay
# proportional in size to the input. Each returns a fresh `Node{T, D}`.
function _add(a::Node{T, D}, b::Node{T, D}) where {T, D}
_is_zero(a) && return b
_is_zero(b) && return a
return Node{T, D}(; op = BINARY_PLUS, l = a, r = b)
end

function _sub(a::Node{T, D}, b::Node{T, D}) where {T, D}
_is_zero(b) && return a
_is_zero(a) && return _neg(b)
return Node{T, D}(; op = BINARY_MINUS, l = a, r = b)
end

function _mul(a::Node{T, D}, b::Node{T, D}) where {T, D}
(_is_zero(a) || _is_zero(b)) && return Node{T, D}(; val = zero(T))
_is_one(a) && return b
_is_one(b) && return a
return Node{T, D}(; op = BINARY_MULTIPLY, l = a, r = b)
end

function _div(a::Node{T, D}, b::Node{T, D}) where {T, D}
_is_zero(a) && return Node{T, D}(; val = zero(T))
_is_one(b) && return a
return Node{T, D}(; op = BINARY_DIVIDE, l = a, r = b)
end

function _neg(a::Node{T, D}) where {T, D}
_is_zero(a) && return a
return Node{T, D}(; op = UNARY_MINUS, l = a)
end

function _pow_int(a::Node{T, D}, k::Int) where {T, D}
# Build a^k for a small positive integer constant k (>= 1).
k == 1 && return a
return Node{T, D}(; op = BINARY_POWER, l = a, r = Node{T, D}(; val = T(k)))
end

"""
$(TYPEDSIGNATURES)

Pure tree-rewrite differentiator. Recurses over `node` returning a new
`Node{T, D}` representing `∂ node / ∂ x_{var_idx}`, where `var_idx` is the
1-based feature index of the variable to differentiate with respect to.
"""
function _differentiate(node::Node{T, D}, var_idx::Int) where {T, D}
if node.degree == 0
if node.constant
return Node{T, D}(; val = zero(T))
else
return Node{T, D}(;
val = node.feature == var_idx ? one(T) : zero(T)
)
end
elseif node.degree == 1
u = node.l
du = _differentiate(u, var_idx)
op = node.op
if op == UNARY_MINUS
return _neg(du)
elseif op == FUNC_COS
sin_u = Node{T, D}(; op = FUNC_SIN, l = u)
return _mul(_neg(sin_u), du)
elseif op == FUNC_COSH
sinh_u = Node{T, D}(; op = FUNC_SINH, l = u)
return _mul(sinh_u, du)
elseif op == FUNC_EXP
return _mul(node, du)
elseif op == FUNC_LOG
return _div(du, u)
elseif op == FUNC_SIN
cos_u = Node{T, D}(; op = FUNC_COS, l = u)
return _mul(cos_u, du)
elseif op == FUNC_SINH
cosh_u = Node{T, D}(; op = FUNC_COSH, l = u)
return _mul(cosh_u, du)
elseif op == FUNC_SQRT
two_sqrt = _mul(Node{T, D}(; val = T(2)), node)
return _div(du, two_sqrt)
elseif op == FUNC_TAN
cos_u = Node{T, D}(; op = FUNC_COS, l = u)
return _div(du, _pow_int(cos_u, 2))
elseif op == FUNC_TANH
cosh_u = Node{T, D}(; op = FUNC_COSH, l = u)
return _div(du, _pow_int(cosh_u, 2))
end
error("differentiate: unhandled unary op $op")
elseif node.degree == 2
u, v = node.l, node.r
op = node.op
if op == BINARY_PLUS
return _add(_differentiate(u, var_idx), _differentiate(v, var_idx))
elseif op == BINARY_MINUS
return _sub(_differentiate(u, var_idx), _differentiate(v, var_idx))
elseif op == BINARY_MULTIPLY
du, dv = _differentiate(u, var_idx), _differentiate(v, var_idx)
return _add(_mul(du, v), _mul(u, dv))
elseif op == BINARY_DIVIDE
du, dv = _differentiate(u, var_idx), _differentiate(v, var_idx)
num = _sub(_mul(du, v), _mul(u, dv))
return _div(num, _pow_int(v, 2))
elseif op == BINARY_POWER
# Common cases: constant exponent or constant base get clean
# derivatives. General case uses u^v * (v' log u + v u' / u).
if _is_const(v)
# d/dx(u^c) = c · u^(c-1) · u'
du = _differentiate(u, var_idx)
c = v
c_minus_1 = Node{T, D}(; val = c.val - one(T))
u_pow = Node{T, D}(; op = BINARY_POWER, l = u, r = c_minus_1)
return _mul(_mul(c, u_pow), du)
elseif _is_const(u)
# d/dx(c^v) = c^v · log(c) · v'
dv = _differentiate(v, var_idx)
log_c = Node{T, D}(; val = log(u.val))
return _mul(_mul(node, log_c), dv)
else
# general: d/dx(u^v) = u^v · (v' log u + v u'/u)
du = _differentiate(u, var_idx)
dv = _differentiate(v, var_idx)
log_u = Node{T, D}(; op = FUNC_LOG, l = u)
term1 = _mul(dv, log_u)
term2 = _div(_mul(v, du), u)
return _mul(node, _add(term1, term2))
end
end
error("differentiate: unhandled binary op $op")
end
error("differentiate: unhandled degree $(node.degree)")
end

"""
$(TYPEDSIGNATURES)

Return the symbolic derivative of `f` with respect to variable `var_name`.

The returned function is a `ScalarExpressionFunction` over the same variable
list as `f` — only the underlying expression tree changes. Differentiation
is implemented as a recursive tree rewrite over FEC's closed grammar (10
unary + 5 binary operators), so there is no dependency on ForwardDiff,
Zygote, or Symbolics; the result survives `juliac --trim`.

Supported operators: cos, cosh, exp, log, sin, sinh, sqrt, tan, tanh, unary
minus, +, -, *, /, ^.

```julia
f = ScalarExpressionFunction{Float64}("a * exp(-(t - tc)^2 / (2 * τ^2))",
["a", "tc", "τ", "t"])
f_dot = differentiate(f, "t")
f_dot_dot = differentiate(f_dot, "t")
```
"""
function differentiate(f::ScalarExpressionFunction{T}, var_name::String) where T
var_names = f.expr.metadata.var_names
idx = findfirst(==(var_name), var_names)
@assert idx !== nothing "variable \"$var_name\" not in $(var_names)"
deriv_tree = _differentiate(f.expr.tree, idx)
deriv_expr = Expression(deriv_tree; operators, var_names)
return ScalarExpressionFunction{T}(deriv_expr, f.num_vars)
end

end # module
16 changes: 9 additions & 7 deletions src/bcs/DirichletBCs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,11 @@ struct DirichletBCs{
)
end

# juliac safe, for now will only work for static
# need to change bc input to have bindings
# for user provided first and/or second derivatives
# juliac-safe path: derives `func_dot` and `func_dot_dot` symbolically via
# [`differentiate`](@ref) on the user's expression tree, so dynamic
# Dirichlet BCs work under `juliac --trim` without requiring ForwardDiff,
# Zygote, Symbolics, or user-supplied derivatives. F is expected to be
# `Expressions.ScalarExpressionFunction{T}`.
function DirichletBCs{F}(mesh::AbstractMesh, dof, bcs_input) where {F <: Function}
bc_funcs = DirichletBCFunction{F, F, F}[]
if length(bcs_input) == 0
Expand All @@ -275,11 +277,11 @@ struct DirichletBCs{
return new{typeof(bc_funcs), IV, RV}(bc_cache, bc_funcs)
end

# TODO change me, will fail if F is not an ExpressionFunction
# and not a 2d func time-dependent
zero_func = F("0.0", ["x", "y", "t"])
for bc in bcs_input
push!(bc_funcs, DirichletBCFunction{F, F, F}(bc.func, zero_func, zero_func))
func_dot = Expressions.differentiate(bc.func, "t")
func_dot_dot = Expressions.differentiate(func_dot, "t")
push!(bc_funcs,
DirichletBCFunction{F, F, F}(bc.func, func_dot, func_dot_dot))
end

bc_caches = DirichletBCContainer.((mesh,), (dof,), bcs_input)
Expand Down
61 changes: 61 additions & 0 deletions test/TestBCs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,67 @@ end
@test all(A[dof.dirichlet_dofs] .≈ 4.0)
end

@testitem "BCs - juliac-safe dynamic DirichletBCs (symbolic derivatives)" setup=[BCHelper] begin
# Mirrors `test_dirichlet_update_bc_values!` but goes through the
# `DirichletBCs{F}` juliac-safe constructor, which computes the BC's first
# and second time derivatives symbolically via Expressions.differentiate.
# No ForwardDiff, no user-supplied derivatives, no Symbolics.
import FiniteElementContainers: update_bc_values!
import FiniteElementContainers.Expressions: ScalarExpressionFunction

u = VectorFunction(fspace, "displ")
dof = DofManager(u)

# g(t) = 2 t^2 → g'(t) = 4 t → g''(t) = 4
F = ScalarExpressionFunction{Float64}
bc_func = F("2 * t^2", ["x", "y", "t"])
bc_in = DirichletBC("displ_x", bc_func; sideset_name = "sset_1")
bcs = DirichletBCs{F}(mesh, dof, DirichletBC[bc_in])

X = mesh.nodal_coords
t = 3.0
update_bc_values!(bcs, X, t)
@test all(bcs.bc_cache.vals .≈ 18.0)
@test all(bcs.bc_cache.vals_dot .≈ 12.0)
@test all(bcs.bc_cache.vals_dot_dot .≈ 4.0)

U = create_field(dof); V = create_field(dof); A = create_field(dof)
update_field_dirichlet_bcs!(U, V, A, bcs)
@test all(U[dof.dirichlet_dofs] .≈ 18.0)
@test all(V[dof.dirichlet_dofs] .≈ 12.0)
@test all(A[dof.dirichlet_dofs] .≈ 4.0)
end

@testitem "BCs - juliac-safe DirichletBCs Gaussian pulse" setup=[BCHelper] begin
# Exercises a realistic Gaussian-pulse BC of the form used in the
# Norma-ported clamped-bar test: g(t) = a · exp(-(t-tc)^2 / (2 τ^2)).
# All three of g, g', g'' are produced by symbolic differentiation alone.
import FiniteElementContainers: update_bc_values!
import FiniteElementContainers.Expressions: ScalarExpressionFunction

u = VectorFunction(fspace, "displ")
dof = DofManager(u)

a, tc, τ = 1.0e-3, 2.5e-4, 5.0e-5
F = ScalarExpressionFunction{Float64}
bc_func = F("1.0e-3 * exp(-(t - 2.5e-4)^2 / (2 * (5.0e-5)^2))",
["x", "y", "t"])
bc_in = DirichletBC("displ_x", bc_func; sideset_name = "sset_1")
bcs = DirichletBCs{F}(mesh, dof, DirichletBC[bc_in])

X = mesh.nodal_coords
for t in (1.5e-4, 2.5e-4, 3.5e-4)
update_bc_values!(bcs, X, t)
η = t - tc
g_ = a * exp(-η^2 / (2 * τ^2))
gp_ = -(η / τ^2) * g_
gpp_= (η^2 / τ^4 - 1 / τ^2) * g_
@test all(bcs.bc_cache.vals .≈ g_ )
@test all(bcs.bc_cache.vals_dot .≈ gp_ )
@test all(bcs.bc_cache.vals_dot_dot .≈ gpp_)
end
end

@testitem "BCs - test_neumann_bc_input" setup=[BCHelper] begin
bc = NeumannBC("my_var", dummy_func_1, "my_sset")
@test bc.var_name == "my_var"
Expand Down
Loading
Loading