diff --git a/src/Expressions.jl b/src/Expressions.jl index 194ba3f4..17a769d8 100644 --- a/src/Expressions.jl +++ b/src/Expressions.jl @@ -328,7 +328,10 @@ function _nud(p::Parser, t::Token, ::Type{T}) where T <: Number return Node{T}(; val = t.value) elseif t.id == OPERATOR if t.op == BINARY_MINUS - val = _parse_statement(p, 100) + # Unary minus. Right-binding-power 25 sits between `*`/`/` (20) + # and `^` (30) — matches standard math precedence so `-t^2` parses + # as `-(t^2)` rather than `(-t)^2`. + val = _parse_statement(p, 25) return Node{T}(; op = UNARY_MINUS, l = val) else error("Unexpected operator in _nud. Found operator $(t.op)") @@ -386,6 +389,20 @@ struct ScalarExpressionFunction{T <: Number} <: AbstractExpressionFunction{T, No expr = Expression(ast; operators, var_names) new{T}(expr, length(var_names)) end + + """ + $(TYPEDSIGNATURES) + + Build a `ScalarExpressionFunction` directly from a prebuilt + `DynamicExpressions.Expression` — used by [`differentiate`](@ref) to wrap + the result of a tree rewrite without round-tripping through the parser. + """ + function ScalarExpressionFunction{T}( + expr::Expression{T, Node{T, DEFAULT_MAX_DEGREE}, ntuple_type}, + num_vars::Int + ) where T <: Number + new{T}(expr, num_vars) + end end Base.eltype(::ScalarExpressionFunction{T}) where T <: Number = T @@ -453,4 +470,178 @@ function (f::VectorExpressionFunction)(X::SVector{ND, T}, t::T) where {ND, T <: return map(func -> func(X, t), f.exprs) end +######################################################## +# Symbolic differentiation on Node{T, D} trees. +# +# The grammar is finite (10 unary + 5 binary operators), so the chain rule +# can be applied by tree rewriting in ~80 lines of pure Julia. This avoids +# pulling in ForwardDiff/Zygote/Symbolics, so the result survives +# `juliac --trim`. All helpers operate on values; no closures are formed. +######################################################## + +@inline _is_const(n::Node) = n.degree == 0 && n.constant +@inline _is_zero(n::Node) = _is_const(n) && iszero(n.val) +@inline _is_one(n::Node) = _is_const(n) && isone(n.val) + +# Smart constructors that fold trivial constants so derivative trees stay +# proportional in size to the input. Each returns a fresh `Node{T, D}`. +function _add(a::Node{T, D}, b::Node{T, D}) where {T, D} + _is_zero(a) && return b + _is_zero(b) && return a + return Node{T, D}(; op = BINARY_PLUS, l = a, r = b) +end + +function _sub(a::Node{T, D}, b::Node{T, D}) where {T, D} + _is_zero(b) && return a + _is_zero(a) && return _neg(b) + return Node{T, D}(; op = BINARY_MINUS, l = a, r = b) +end + +function _mul(a::Node{T, D}, b::Node{T, D}) where {T, D} + (_is_zero(a) || _is_zero(b)) && return Node{T, D}(; val = zero(T)) + _is_one(a) && return b + _is_one(b) && return a + return Node{T, D}(; op = BINARY_MULTIPLY, l = a, r = b) +end + +function _div(a::Node{T, D}, b::Node{T, D}) where {T, D} + _is_zero(a) && return Node{T, D}(; val = zero(T)) + _is_one(b) && return a + return Node{T, D}(; op = BINARY_DIVIDE, l = a, r = b) +end + +function _neg(a::Node{T, D}) where {T, D} + _is_zero(a) && return a + return Node{T, D}(; op = UNARY_MINUS, l = a) +end + +function _pow_int(a::Node{T, D}, k::Int) where {T, D} + # Build a^k for a small positive integer constant k (>= 1). + k == 1 && return a + return Node{T, D}(; op = BINARY_POWER, l = a, r = Node{T, D}(; val = T(k))) +end + +""" +$(TYPEDSIGNATURES) + +Pure tree-rewrite differentiator. Recurses over `node` returning a new +`Node{T, D}` representing `∂ node / ∂ x_{var_idx}`, where `var_idx` is the +1-based feature index of the variable to differentiate with respect to. +""" +function _differentiate(node::Node{T, D}, var_idx::Int) where {T, D} + if node.degree == 0 + if node.constant + return Node{T, D}(; val = zero(T)) + else + return Node{T, D}(; + val = node.feature == var_idx ? one(T) : zero(T) + ) + end + elseif node.degree == 1 + u = node.l + du = _differentiate(u, var_idx) + op = node.op + if op == UNARY_MINUS + return _neg(du) + elseif op == FUNC_COS + sin_u = Node{T, D}(; op = FUNC_SIN, l = u) + return _mul(_neg(sin_u), du) + elseif op == FUNC_COSH + sinh_u = Node{T, D}(; op = FUNC_SINH, l = u) + return _mul(sinh_u, du) + elseif op == FUNC_EXP + return _mul(node, du) + elseif op == FUNC_LOG + return _div(du, u) + elseif op == FUNC_SIN + cos_u = Node{T, D}(; op = FUNC_COS, l = u) + return _mul(cos_u, du) + elseif op == FUNC_SINH + cosh_u = Node{T, D}(; op = FUNC_COSH, l = u) + return _mul(cosh_u, du) + elseif op == FUNC_SQRT + two_sqrt = _mul(Node{T, D}(; val = T(2)), node) + return _div(du, two_sqrt) + elseif op == FUNC_TAN + cos_u = Node{T, D}(; op = FUNC_COS, l = u) + return _div(du, _pow_int(cos_u, 2)) + elseif op == FUNC_TANH + cosh_u = Node{T, D}(; op = FUNC_COSH, l = u) + return _div(du, _pow_int(cosh_u, 2)) + end + error("differentiate: unhandled unary op $op") + elseif node.degree == 2 + u, v = node.l, node.r + op = node.op + if op == BINARY_PLUS + return _add(_differentiate(u, var_idx), _differentiate(v, var_idx)) + elseif op == BINARY_MINUS + return _sub(_differentiate(u, var_idx), _differentiate(v, var_idx)) + elseif op == BINARY_MULTIPLY + du, dv = _differentiate(u, var_idx), _differentiate(v, var_idx) + return _add(_mul(du, v), _mul(u, dv)) + elseif op == BINARY_DIVIDE + du, dv = _differentiate(u, var_idx), _differentiate(v, var_idx) + num = _sub(_mul(du, v), _mul(u, dv)) + return _div(num, _pow_int(v, 2)) + elseif op == BINARY_POWER + # Common cases: constant exponent or constant base get clean + # derivatives. General case uses u^v * (v' log u + v u' / u). + if _is_const(v) + # d/dx(u^c) = c · u^(c-1) · u' + du = _differentiate(u, var_idx) + c = v + c_minus_1 = Node{T, D}(; val = c.val - one(T)) + u_pow = Node{T, D}(; op = BINARY_POWER, l = u, r = c_minus_1) + return _mul(_mul(c, u_pow), du) + elseif _is_const(u) + # d/dx(c^v) = c^v · log(c) · v' + dv = _differentiate(v, var_idx) + log_c = Node{T, D}(; val = log(u.val)) + return _mul(_mul(node, log_c), dv) + else + # general: d/dx(u^v) = u^v · (v' log u + v u'/u) + du = _differentiate(u, var_idx) + dv = _differentiate(v, var_idx) + log_u = Node{T, D}(; op = FUNC_LOG, l = u) + term1 = _mul(dv, log_u) + term2 = _div(_mul(v, du), u) + return _mul(node, _add(term1, term2)) + end + end + error("differentiate: unhandled binary op $op") + end + error("differentiate: unhandled degree $(node.degree)") +end + +""" +$(TYPEDSIGNATURES) + +Return the symbolic derivative of `f` with respect to variable `var_name`. + +The returned function is a `ScalarExpressionFunction` over the same variable +list as `f` — only the underlying expression tree changes. Differentiation +is implemented as a recursive tree rewrite over FEC's closed grammar (10 +unary + 5 binary operators), so there is no dependency on ForwardDiff, +Zygote, or Symbolics; the result survives `juliac --trim`. + +Supported operators: cos, cosh, exp, log, sin, sinh, sqrt, tan, tanh, unary +minus, +, -, *, /, ^. + +```julia +f = ScalarExpressionFunction{Float64}("a * exp(-(t - tc)^2 / (2 * τ^2))", + ["a", "tc", "τ", "t"]) +f_dot = differentiate(f, "t") +f_dot_dot = differentiate(f_dot, "t") +``` +""" +function differentiate(f::ScalarExpressionFunction{T}, var_name::String) where T + var_names = f.expr.metadata.var_names + idx = findfirst(==(var_name), var_names) + @assert idx !== nothing "variable \"$var_name\" not in $(var_names)" + deriv_tree = _differentiate(f.expr.tree, idx) + deriv_expr = Expression(deriv_tree; operators, var_names) + return ScalarExpressionFunction{T}(deriv_expr, f.num_vars) +end + end # module diff --git a/src/bcs/DirichletBCs.jl b/src/bcs/DirichletBCs.jl index 7de25076..662682ff 100644 --- a/src/bcs/DirichletBCs.jl +++ b/src/bcs/DirichletBCs.jl @@ -263,9 +263,11 @@ struct DirichletBCs{ ) end - # juliac safe, for now will only work for static - # need to change bc input to have bindings - # for user provided first and/or second derivatives + # juliac-safe path: derives `func_dot` and `func_dot_dot` symbolically via + # [`differentiate`](@ref) on the user's expression tree, so dynamic + # Dirichlet BCs work under `juliac --trim` without requiring ForwardDiff, + # Zygote, Symbolics, or user-supplied derivatives. F is expected to be + # `Expressions.ScalarExpressionFunction{T}`. function DirichletBCs{F}(mesh::AbstractMesh, dof, bcs_input) where {F <: Function} bc_funcs = DirichletBCFunction{F, F, F}[] if length(bcs_input) == 0 @@ -275,11 +277,11 @@ struct DirichletBCs{ return new{typeof(bc_funcs), IV, RV}(bc_cache, bc_funcs) end - # TODO change me, will fail if F is not an ExpressionFunction - # and not a 2d func time-dependent - zero_func = F("0.0", ["x", "y", "t"]) for bc in bcs_input - push!(bc_funcs, DirichletBCFunction{F, F, F}(bc.func, zero_func, zero_func)) + func_dot = Expressions.differentiate(bc.func, "t") + func_dot_dot = Expressions.differentiate(func_dot, "t") + push!(bc_funcs, + DirichletBCFunction{F, F, F}(bc.func, func_dot, func_dot_dot)) end bc_caches = DirichletBCContainer.((mesh,), (dof,), bcs_input) diff --git a/test/TestBCs.jl b/test/TestBCs.jl index b95e73c7..6d255f9f 100644 --- a/test/TestBCs.jl +++ b/test/TestBCs.jl @@ -104,6 +104,67 @@ end @test all(A[dof.dirichlet_dofs] .≈ 4.0) end +@testitem "BCs - juliac-safe dynamic DirichletBCs (symbolic derivatives)" setup=[BCHelper] begin + # Mirrors `test_dirichlet_update_bc_values!` but goes through the + # `DirichletBCs{F}` juliac-safe constructor, which computes the BC's first + # and second time derivatives symbolically via Expressions.differentiate. + # No ForwardDiff, no user-supplied derivatives, no Symbolics. + import FiniteElementContainers: update_bc_values! + import FiniteElementContainers.Expressions: ScalarExpressionFunction + + u = VectorFunction(fspace, "displ") + dof = DofManager(u) + + # g(t) = 2 t^2 → g'(t) = 4 t → g''(t) = 4 + F = ScalarExpressionFunction{Float64} + bc_func = F("2 * t^2", ["x", "y", "t"]) + bc_in = DirichletBC("displ_x", bc_func; sideset_name = "sset_1") + bcs = DirichletBCs{F}(mesh, dof, DirichletBC[bc_in]) + + X = mesh.nodal_coords + t = 3.0 + update_bc_values!(bcs, X, t) + @test all(bcs.bc_cache.vals .≈ 18.0) + @test all(bcs.bc_cache.vals_dot .≈ 12.0) + @test all(bcs.bc_cache.vals_dot_dot .≈ 4.0) + + U = create_field(dof); V = create_field(dof); A = create_field(dof) + update_field_dirichlet_bcs!(U, V, A, bcs) + @test all(U[dof.dirichlet_dofs] .≈ 18.0) + @test all(V[dof.dirichlet_dofs] .≈ 12.0) + @test all(A[dof.dirichlet_dofs] .≈ 4.0) +end + +@testitem "BCs - juliac-safe DirichletBCs Gaussian pulse" setup=[BCHelper] begin + # Exercises a realistic Gaussian-pulse BC of the form used in the + # Norma-ported clamped-bar test: g(t) = a · exp(-(t-tc)^2 / (2 τ^2)). + # All three of g, g', g'' are produced by symbolic differentiation alone. + import FiniteElementContainers: update_bc_values! + import FiniteElementContainers.Expressions: ScalarExpressionFunction + + u = VectorFunction(fspace, "displ") + dof = DofManager(u) + + a, tc, τ = 1.0e-3, 2.5e-4, 5.0e-5 + F = ScalarExpressionFunction{Float64} + bc_func = F("1.0e-3 * exp(-(t - 2.5e-4)^2 / (2 * (5.0e-5)^2))", + ["x", "y", "t"]) + bc_in = DirichletBC("displ_x", bc_func; sideset_name = "sset_1") + bcs = DirichletBCs{F}(mesh, dof, DirichletBC[bc_in]) + + X = mesh.nodal_coords + for t in (1.5e-4, 2.5e-4, 3.5e-4) + update_bc_values!(bcs, X, t) + η = t - tc + g_ = a * exp(-η^2 / (2 * τ^2)) + gp_ = -(η / τ^2) * g_ + gpp_= (η^2 / τ^4 - 1 / τ^2) * g_ + @test all(bcs.bc_cache.vals .≈ g_ ) + @test all(bcs.bc_cache.vals_dot .≈ gp_ ) + @test all(bcs.bc_cache.vals_dot_dot .≈ gpp_) + end +end + @testitem "BCs - test_neumann_bc_input" setup=[BCHelper] begin bc = NeumannBC("my_var", dummy_func_1, "my_sset") @test bc.var_name == "my_var" diff --git a/test/TestExpressions.jl b/test/TestExpressions.jl index f167ebae..b5b280be 100644 --- a/test/TestExpressions.jl +++ b/test/TestExpressions.jl @@ -184,3 +184,146 @@ end @test val[2] ≈ -100.0 @test val[3] ≈ 15.0 * exp(5.0) end + +@testitem "ScalarExpressionFunction - parser precedence: unary minus vs ^" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction + # Standard math: `-t^2` parses as `-(t^2)`, not `(-t)^2`. + f = ScalarExpressionFunction{Float64}("-t^2", ["t"]) + @test f([3.0]) ≈ -9.0 + @test f([-3.0]) ≈ -9.0 + + # Numeric base: -2^2 = -(2^2) = -4 + g = ScalarExpressionFunction{Float64}("-2^2", String[]) + @test g(Float64[]) ≈ -4.0 + + # Right-side unary minus inside power: 2^-2 = 0.25 + h = ScalarExpressionFunction{Float64}("2^-2", String[]) + @test h(Float64[]) ≈ 0.25 + + # Multiplicative remains unaffected (commutativity hides any change) + p = ScalarExpressionFunction{Float64}("-x*y", ["x", "y"]) + @test p([2.0, 3.0]) ≈ -6.0 +end + +@testitem "differentiate - constants and variables" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction, differentiate + # d/dx of a constant is zero + f = ScalarExpressionFunction{Float64}("3.14", ["x"]) + fp = differentiate(f, "x") + @test fp([1.7]) ≈ 0.0 + @test fp([-9.3]) ≈ 0.0 + + # d/dx of x is 1, of y is 0 + g = ScalarExpressionFunction{Float64}("x", ["x", "y"]) + gp = differentiate(g, "x") + @test gp([1.0, 2.0]) ≈ 1.0 + gq = differentiate(g, "y") + @test gq([1.0, 2.0]) ≈ 0.0 +end + +@testitem "differentiate - each unary operator" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction, differentiate + # (func name, analytical derivative at x) + cases = [ + ("cos(x)", x -> -sin(x)), + ("cosh(x)", x -> sinh(x)), + ("exp(x)", x -> exp(x)), + ("log(x)", x -> 1.0 / x), + ("sin(x)", x -> cos(x)), + ("sinh(x)", x -> cosh(x)), + ("sqrt(x)", x -> 1.0 / (2 * sqrt(x))), + ("tan(x)", x -> 1.0 / cos(x)^2), + ("tanh(x)", x -> 1.0 / cosh(x)^2), + ] + for (expr, dexpr) in cases + f = ScalarExpressionFunction{Float64}(expr, ["x"]) + fp = differentiate(f, "x") + for x in (0.3, 0.7, 1.4, 2.6) + @test fp([x]) ≈ dexpr(x) rtol=1e-12 + end + end + + # Unary minus: d/dx(-x) = -1 + f = ScalarExpressionFunction{Float64}("-x", ["x"]) + fp = differentiate(f, "x") + @test fp([5.0]) ≈ -1.0 +end + +@testitem "differentiate - binary operators" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction, differentiate + + # +, -, *, / + f = ScalarExpressionFunction{Float64}("x + y", ["x", "y"]) + @test differentiate(f, "x")([2.0, 3.0]) ≈ 1.0 + @test differentiate(f, "y")([2.0, 3.0]) ≈ 1.0 + + f = ScalarExpressionFunction{Float64}("x - y", ["x", "y"]) + @test differentiate(f, "x")([2.0, 3.0]) ≈ 1.0 + @test differentiate(f, "y")([2.0, 3.0]) ≈ -1.0 + + f = ScalarExpressionFunction{Float64}("x * y", ["x", "y"]) + @test differentiate(f, "x")([2.0, 3.0]) ≈ 3.0 + @test differentiate(f, "y")([2.0, 3.0]) ≈ 2.0 + + f = ScalarExpressionFunction{Float64}("x / y", ["x", "y"]) + @test differentiate(f, "x")([2.0, 3.0]) ≈ 1/3 + @test differentiate(f, "y")([2.0, 3.0]) ≈ -2/9 + + # Power: constant exponent — d/dx(x^c) = c x^{c-1} + f = ScalarExpressionFunction{Float64}("x^3", ["x"]) + fp = differentiate(f, "x") + @test fp([2.0]) ≈ 12.0 + + # Power: constant base — d/dx(c^x) = c^x log(c) + f = ScalarExpressionFunction{Float64}("2^x", ["x"]) + fp = differentiate(f, "x") + @test fp([3.0]) ≈ 8.0 * log(2.0) + + # Power: general — d/dx(x^x) = x^x (log(x) + 1) + f = ScalarExpressionFunction{Float64}("x^x", ["x"]) + fp = differentiate(f, "x") + @test fp([2.0]) ≈ 2.0^2.0 * (log(2.0) + 1.0) +end + +@testitem "differentiate - Gaussian pulse to 2nd derivative" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction, differentiate + # g(t) = a * exp(-(t - tc)^2 / (2 τ^2)) + # g' = -((t-tc)/τ^2) * g + # g'' = ((t-tc)^2/τ^4 - 1/τ^2) * g + g = ScalarExpressionFunction{Float64}( + "a * exp(-(t - tc)^2 / (2 * tau^2))", ["a", "tc", "tau", "t"]) + gp = differentiate(g, "t") + gpp = differentiate(gp, "t") + a, tc, τ = 1.0e-3, 2.5e-4, 5.0e-5 + for t in (0.0, 1.0e-4, 2.5e-4, 4.0e-4, 5.0e-4) + η = t - tc + g_ = a * exp(-η^2 / (2 * τ^2)) + gp_ = -(η / τ^2) * g_ + gpp_= (η^2 / τ^4 - 1 / τ^2) * g_ + @test g([a, tc, τ, t]) ≈ g_ rtol=1e-12 + @test gp([a, tc, τ, t]) ≈ gp_ rtol=1e-12 + @test gpp([a, tc, τ, t]) ≈ gpp_ rtol=1e-12 + end +end + +@testitem "differentiate - spatial derivatives (traveling wave IC)" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction, differentiate + # u₀(z) = a * exp(-z^2 / (2 s^2)) + # u₀'(z) = -(z/s^2) * u₀ + u0 = ScalarExpressionFunction{Float64}( + "a * exp(-z^2 / (2 * s^2))", ["a", "s", "z"]) + duz = differentiate(u0, "z") + a, s = 0.01, 0.02 + for z in (-0.04, -0.01, 0.0, 0.01, 0.04) + u_ = a * exp(-z^2 / (2 * s^2)) + du_ = -(z / s^2) * u_ + @test u0([a, s, z]) ≈ u_ rtol=1e-12 + @test duz([a, s, z]) ≈ du_ rtol=1e-12 + end +end + +@testitem "differentiate - error on unknown variable" begin + import FiniteElementContainers.Expressions: ScalarExpressionFunction, differentiate + f = ScalarExpressionFunction{Float64}("x", ["x"]) + @test_throws AssertionError differentiate(f, "y") +end