From 9dc1575433cd15bfb0bbbb7d3bc51337a990b00a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 19:59:14 +0100 Subject: [PATCH 1/9] Refactor to use SparseConnectivityTracer --- src/Jutul.jl | 2 +- src/ad/sparsity.jl | 72 +++++++++++++++++++++++----------------------- test/sparsity.jl | 17 ++++++----- 3 files changed, 45 insertions(+), 46 deletions(-) diff --git a/src/Jutul.jl b/src/Jutul.jl index 569d59468..39a6e4d7a 100644 --- a/src/Jutul.jl +++ b/src/Jutul.jl @@ -41,7 +41,7 @@ module Jutul import ForwardDiff import DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian, AutoForwardDiff - import SparseConnectivityTracer: TracerLocalSparsityDetector + import SparseConnectivityTracer: TracerLocalSparsityDetector, jacobian_sparsity import SparseMatrixColorings: GreedyColoringAlgorithm # Timing diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index 231e391a5..68d05e787 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -1,19 +1,3 @@ -# function unpack_tag(v::Type{ForwardDiff.Dual{T, F, N}}, t::Symbol = :entity) where {T, F, N} -# @info "???" -# if v isa Tuple -# if t == :entity -# out = T[2] -# elseif t == :model -# out = T[1] -# else -# out = T -# end -# else -# out = T -# end -# return unpack_tag(out) -# end - function unpack_tag(v::Type{ForwardDiff.Dual{T, F, N}}) where {T, F, N} return unpack_tag(T) end @@ -29,8 +13,9 @@ end unpack_tag(A::AbstractArray, arg...) = unpack_tag(eltype(A), arg...) unpack_tag(::Any, arg...) = nothing -struct SparsityTracingWrapper{T, N, D} <: AbstractArray{T, N} +struct SparsityTracingWrapper{T, N, D, AD<:AbstractVector} <: AbstractArray{T, N} data::Array{D, N} + advec::AD end """ @@ -41,8 +26,9 @@ produces outputs that have the same value as the wrapped type, but contains a SparsityTracing seeded value with seed equal to the column index (if matrix) or linear index (if vector). """ -function SparsityTracingWrapper(x::AbstractArray{T, N}) where {T, N} - return SparsityTracingWrapper{Float64, N, T}(x) +function SparsityTracingWrapper(x::AbstractArray{T, N}, advec) where {T, N} + size(x)[end] == length(advec) || error("Length of advec must match last dimension of x") + return SparsityTracingWrapper{Float64, N, T, typeof(advec)}(x, advec) end Base.parent(A::SparsityTracingWrapper) = A.data @@ -78,40 +64,31 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D end function Base.getindex(A::SparsityTracingWrapper, i::Int, j::Int) - return as_tracer(A.data[i, j], j) + return A.data[i, j]*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D} - n, m = size(A) + m = size(A)[2] zero_ix = ix - 1 i = (zero_ix ÷ m) + 1 j = mod(zero_ix, m) + 1 - return as_tracer(A.data[i, j], j) + return A.data[i, j]*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D} - return as_tracer(A.data[i], i) -end - -function as_tracer(x::Real, leaf_init) - return Jutul.ST.ADval{Float64}(value(x), Jutul.ST.DerivLeaf(leaf_init)) -end - -function as_tracer(x, leaf_init) - return x + return A.data[i]*A.advec[i] end -function create_mock_state(state, tag, entities = ad_entities(state); subkeys = nothing) +function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing) no_provided_keys = isnothing(subkeys) mock_state = JutulStorage() - n = entities[tag].n for k in keys(state) v = state[k] tag_matches = unpack_tag(v) == tag key_matches = no_provided_keys || (k in subkeys) if tag_matches && key_matches # Assign mock value with tracer - new_v = SparsityTracingWrapper(v) + new_v = SparsityTracingWrapper(v, X_tracer) else # Assign mock value as doubles new_v = as_value(v) @@ -144,8 +121,29 @@ function ad_entities(state) end function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[tag].n) - mstate = create_mock_state(state, tag, entities) - mstate0 = create_mock_state(state0, tag, entities) + # ne = count_entities(model.domain, k) + @info "??" N tag + function x_to_evaluated(X::AbstractVector{T}) where T + out = zeros(T, N) + eq_buf = zeros(T, n) + mstate = create_mock_state(state, tag, X) + mstate0 = create_mock_state(state0, tag, X) + + for i in 1:N + @inbounds F!(eq_buf, mstate, mstate0, i) + # Take the sum over all return values to reduce to scalar. + # This should accumulate the full "entity" pattern if some + # equations have a different stencil. + out[i] = sum(eq_buf) + end + return out + end + + S = jacobian_sparsity(x_to_evaluated, ones(N), TracerLocalSparsityDetector()) + + @info S + error("Not implemented yet.") + out = ST.create_advec(zeros(n)) J = [Vector{Int64}() for i in 1:N] @@ -175,6 +173,8 @@ function determine_sparsity_simple(F, model, state, state0 = nothing; variant = end end for (k, v) in entities + ne = count_entities(model.domain, k) + @info "??" ne mstate = create_mock_state(state, k, entities, subkeys = subkeys) if isnothing(state0) f_ad = F(mstate) diff --git a/test/sparsity.jl b/test/sparsity.jl index da379fb93..f5f5e1833 100644 --- a/test/sparsity.jl +++ b/test/sparsity.jl @@ -6,6 +6,9 @@ using Jutul, Test test_mat = zeros(m, n) test_vec = zeros(n) + # Use NaN to see what values get tagged for testing purposes + internal_vec = collect(-1:-1:-n) + for i in 1:n test_vec[i] = i for j in 1:m @@ -13,21 +16,17 @@ using Jutul, Test end end - vec_st = Jutul.SparsityTracingWrapper(test_vec) + vec_st = Jutul.SparsityTracingWrapper(test_vec, internal_vec) for i in 1:n v = vec_st[i] - @test v isa Jutul.ST.ADval - @test v.derivnode.index == i - @test v.val == test_vec[i] + @test v == test_vec[i]*-i end - mat_st = Jutul.SparsityTracingWrapper(test_mat) + mat_st = Jutul.SparsityTracingWrapper(test_mat, internal_vec) for i in 1:n for j in 1:m v = mat_st[j, i] - @test v isa Jutul.ST.ADval - @test v.derivnode.index == i - @test v.val == test_mat[j, i] + @test v == test_mat[j, i]*-i end end @@ -49,7 +48,7 @@ using Jutul, Test end end end - +## @testset "ad_tags" begin v = allocate_array_ad(1, diag_pos = 1, tag = Cells()) @test Jutul.value(v[1]) isa Float64 From 7038ab43b3c987cfdc8af8b04a5e06f138db6bb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 20:04:25 +0100 Subject: [PATCH 2/9] Update sparsity.jl --- src/ad/sparsity.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index 68d05e787..b67ee5d04 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -64,7 +64,7 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D end function Base.getindex(A::SparsityTracingWrapper, i::Int, j::Int) - return A.data[i, j]*A.advec[j] + return value(A.data[i, j])*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D} @@ -72,11 +72,11 @@ function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D} zero_ix = ix - 1 i = (zero_ix ÷ m) + 1 j = mod(zero_ix, m) + 1 - return A.data[i, j]*A.advec[j] + return value(A.data[i, j])*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D} - return A.data[i]*A.advec[i] + return value(A.data[i])*A.advec[i] end function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing) From a4d729083a68a0796d5d6980a959bff8007ef496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 20:15:35 +0100 Subject: [PATCH 3/9] Basic example runs --- src/Jutul.jl | 2 +- src/ad/sparsity.jl | 38 +++++++++++++++++--------------------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/Jutul.jl b/src/Jutul.jl index 39a6e4d7a..a713840e9 100644 --- a/src/Jutul.jl +++ b/src/Jutul.jl @@ -41,7 +41,7 @@ module Jutul import ForwardDiff import DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian, AutoForwardDiff - import SparseConnectivityTracer: TracerLocalSparsityDetector, jacobian_sparsity + import SparseConnectivityTracer: TracerLocalSparsityDetector, TracerSparsityDetector, jacobian_sparsity import SparseMatrixColorings: GreedyColoringAlgorithm # Timing diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index b67ee5d04..3ae245d6f 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -51,7 +51,7 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D if J isa Colon J = axes(A, 2) end - Ts = Jutul.ST.ADval{T} + Ts = eltype(A.advec) n = length(I) m = length(J) out = Matrix{Ts}(undef, n, m) @@ -64,7 +64,8 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D end function Base.getindex(A::SparsityTracingWrapper, i::Int, j::Int) - return value(A.data[i, j])*A.advec[j] + return traced_value(A.data[i, j], A, j) + # return value(A.data[i, j])*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D} @@ -72,11 +73,16 @@ function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D} zero_ix = ix - 1 i = (zero_ix ÷ m) + 1 j = mod(zero_ix, m) + 1 - return value(A.data[i, j])*A.advec[j] + return traced_value(A.data[i, j], A, j) + # return value(A.data[i, j])*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D} - return value(A.data[i])*A.advec[i] + return traced_value(A.data[i], A, i) +end + +function traced_value(baseval, A, idx) + return max(value(baseval), 1e-8)*A.advec[idx] end function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing) @@ -121,8 +127,6 @@ function ad_entities(state) end function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[tag].n) - # ne = count_entities(model.domain, k) - @info "??" N tag function x_to_evaluated(X::AbstractVector{T}) where T out = zeros(T, N) eq_buf = zeros(T, n) @@ -139,22 +143,14 @@ function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[ta return out end - S = jacobian_sparsity(x_to_evaluated, ones(N), TracerLocalSparsityDetector()) - - @info S - error("Not implemented yet.") - + dtct = TracerSparsityDetector() + dtct = TracerLocalSparsityDetector() + S = jacobian_sparsity(x_to_evaluated, ones(N), dtct) - out = ST.create_advec(zeros(n)) - J = [Vector{Int64}() for i in 1:N] - for i in 1:N - @inbounds F!(out, mstate, mstate0, i) - # Take the sum over all return values to reduce to scalar. - # This should accumulate the full "entity" pattern if some - # equations have a different stencil. - V = sum(out) - D = ST.deriv(V) - J[i] = D.nzind + J = [Vector{Int64}() for _ in 1:N] + rows, cols, = findnz(S) + for (row, col) in zip(rows, cols) + push!(J[row], col) end return J end From aba0fa387fa79bee7bd6c6b9480d85b4d8f9ae38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 20:30:08 +0100 Subject: [PATCH 4/9] fixes --- src/ad/sparsity.jl | 9 ++++++--- src/equations.jl | 3 ++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index 3ae245d6f..ac334ac44 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -27,7 +27,7 @@ SparsityTracing seeded value with seed equal to the column index (if matrix) or linear index (if vector). """ function SparsityTracingWrapper(x::AbstractArray{T, N}, advec) where {T, N} - size(x)[end] == length(advec) || error("Length of advec must match last dimension of x") + size(x)[end] == length(advec) || error("Length of advec ($(length(advec)) must match last dimension of x (size(x)=$(size(x)))") return SparsityTracingWrapper{Float64, N, T, typeof(advec)}(x, advec) end @@ -126,7 +126,10 @@ function ad_entities(state) return out end -function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[tag].n) +function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N = entities[tag].n) + # n: number of equations per entity + # N: number of entities where the equation lies (output size) + # count_of_tag: number of variables with the given tag (input size) function x_to_evaluated(X::AbstractVector{T}) where T out = zeros(T, N) eq_buf = zeros(T, n) @@ -145,7 +148,7 @@ function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[ta dtct = TracerSparsityDetector() dtct = TracerLocalSparsityDetector() - S = jacobian_sparsity(x_to_evaluated, ones(N), dtct) + S = jacobian_sparsity(x_to_evaluated, ones(count_of_tag), dtct) J = [Vector{Int64}() for _ in 1:N] rows, cols, = findnz(S) diff --git a/src/equations.jl b/src/equations.jl index a293c24d5..8d3718666 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -205,7 +205,8 @@ function create_equation_caches(model, equations_per_entity, number_of_entities, for (e, epack) in entities is_self = e == self_entity self_entity_found = self_entity_found || is_self - @tic "sparsity detection" S = determine_sparsity(F!, equations_per_entity, state, state0, e, entities, number_of_entities) + num_e = count_entities(model.domain, e) + @tic "sparsity detection" S = determine_sparsity(F!, equations_per_entity, state, state0, num_e, e, entities, number_of_entities) if !isnothing(extra_sparsity) # We have some extra sparsity, need to merge that in S_e = extra_sparsity[entity_as_symbol(e)] From 610667743846be2e86b7be8b13cb36b152b2733f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 20:46:24 +0100 Subject: [PATCH 5/9] Update sparsity.jl --- src/ad/sparsity.jl | 53 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index ac334ac44..ad41e9741 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -147,7 +147,7 @@ function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N end dtct = TracerSparsityDetector() - dtct = TracerLocalSparsityDetector() + # dtct = TracerLocalSparsityDetector() S = jacobian_sparsity(x_to_evaluated, ones(count_of_tag), dtct) J = [Vector{Int64}() for _ in 1:N] @@ -155,6 +155,7 @@ function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N for (row, col) in zip(rows, cols) push!(J[row], col) end + @info "Sparsity" tag J return J end @@ -172,23 +173,43 @@ function determine_sparsity_simple(F, model, state, state0 = nothing; variant = end end for (k, v) in entities - ne = count_entities(model.domain, k) - @info "??" ne - mstate = create_mock_state(state, k, entities, subkeys = subkeys) - if isnothing(state0) - f_ad = F(mstate) - else - mstate0 = create_mock_state(state0, k, entities, subkeys = subkeys) - f_ad = F(mstate, mstate0) - end - V = sum(f_ad) - if V isa AbstractFloat || V isa Integer - S = zeros(Int64, 0) - else - D = ST.deriv(V) - S = D.nzind + function trace_entity(X) + mstate = create_mock_state(state, k, X, subkeys = subkeys) + if isnothing(state0) + f_ad = F(mstate) + else + mstate0 = create_mock_state(state0, k, X, subkeys = subkeys) + f_ad = F(mstate, mstate0) + end + return sum(f_ad) end + ne = count_entities(model.domain, k) + dtct = TracerLocalSparsityDetector() + # dtct = TracerSparsityDetector() + js = jacobian_sparsity(trace_entity, ones(ne), dtct) + S = findnz(js)[2] + @info "???" S k v sparsity[k] = S end + + # for (k, v) in entities + # ne = count_entities(model.domain, k) + # @info "??" ne + # mstate = create_mock_state(state, k, X, subkeys = subkeys) + # if isnothing(state0) + # f_ad = F(mstate) + # else + # mstate0 = create_mock_state(state0, k, X, subkeys = subkeys) + # f_ad = F(mstate, mstate0) + # end + # V = sum(f_ad) + # if V isa AbstractFloat || V isa Integer + # S = zeros(Int64, 0) + # else + # D = ST.deriv(V) + # S = D.nzind + # end + # sparsity[k] = S + # end return sparsity end From 9b7f9fb5e36952c630bdee6b2c80e8ba21f6a30b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 20:52:15 +0100 Subject: [PATCH 6/9] Update sparsity.jl --- src/ad/sparsity.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index ad41e9741..18fc3768d 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -82,7 +82,8 @@ function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D} end function traced_value(baseval, A, idx) - return max(value(baseval), 1e-8)*A.advec[idx] + bval = value(baseval) + return sign(bval)*max(abs(bval), 1e-8)*A.advec[idx] end function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing) @@ -141,7 +142,11 @@ function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N # Take the sum over all return values to reduce to scalar. # This should accumulate the full "entity" pattern if some # equations have a different stencil. - out[i] = sum(eq_buf) + v = zero(T) + for j in 1:n + v += abs(eq_buf[j]) + end + out[i] = v end return out end From 70d0c24a44af7f04931b87f829c330ecc76292bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 21:13:59 +0100 Subject: [PATCH 7/9] Update sparsity.jl --- src/ad/sparsity.jl | 35 +++++++++-------------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index 18fc3768d..eb14b9f7b 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -82,8 +82,9 @@ function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D} end function traced_value(baseval, A, idx) - bval = value(baseval) - return sign(bval)*max(abs(bval), 1e-8)*A.advec[idx] + # bval = value(baseval) + return value(baseval)*A.advec[idx] + # return sign(bval)*(abs(bval) + idx*1e-3)*A.advec[idx] end function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing) @@ -96,8 +97,11 @@ function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothi if tag_matches && key_matches # Assign mock value with tracer new_v = SparsityTracingWrapper(v, X_tracer) + elseif eltype(v) <: Real + n = size(v)[end] + new_v = SparsityTracingWrapper(v, ones(n)) else - # Assign mock value as doubles + # Probably not a numeric array, just use value wrapper new_v = as_value(v) end mock_state[k] = new_v @@ -160,7 +164,6 @@ function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N for (row, col) in zip(rows, cols) push!(J[row], col) end - @info "Sparsity" tag J return J end @@ -189,32 +192,12 @@ function determine_sparsity_simple(F, model, state, state0 = nothing; variant = return sum(f_ad) end ne = count_entities(model.domain, k) - dtct = TracerLocalSparsityDetector() - # dtct = TracerSparsityDetector() + # dtct = TracerLocalSparsityDetector() + dtct = TracerSparsityDetector() js = jacobian_sparsity(trace_entity, ones(ne), dtct) S = findnz(js)[2] - @info "???" S k v sparsity[k] = S end - # for (k, v) in entities - # ne = count_entities(model.domain, k) - # @info "??" ne - # mstate = create_mock_state(state, k, X, subkeys = subkeys) - # if isnothing(state0) - # f_ad = F(mstate) - # else - # mstate0 = create_mock_state(state0, k, X, subkeys = subkeys) - # f_ad = F(mstate, mstate0) - # end - # V = sum(f_ad) - # if V isa AbstractFloat || V isa Integer - # S = zeros(Int64, 0) - # else - # D = ST.deriv(V) - # S = D.nzind - # end - # sparsity[k] = S - # end return sparsity end From e9a43425f259435a01e8e6388d0dffcf4a249fee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Thu, 22 Jan 2026 21:53:35 +0100 Subject: [PATCH 8/9] Update core_types.jl --- src/core_types/core_types.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/core_types/core_types.jl b/src/core_types/core_types.jl index 79e9ab5ab..7fe348b3a 100644 --- a/src/core_types/core_types.jl +++ b/src/core_types/core_types.jl @@ -1622,3 +1622,4 @@ function timestepping_is_done(C::EndTimeTerminationCriterion, simulator, states, return now >= C.end_time end +JutulReal = Union{Float64, ForwardDiff.Dual} From f2e69a901905ef6fd8252fa59d372812b857463d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Olav=20M=C3=B8yner?= Date: Sat, 24 Jan 2026 18:42:12 +0100 Subject: [PATCH 9/9] Update flux.jl --- src/conservation/flux.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/conservation/flux.jl b/src/conservation/flux.jl index e06ec092e..b93d02a25 100644 --- a/src/conservation/flux.jl +++ b/src/conservation/flux.jl @@ -335,6 +335,11 @@ end export upw_flux function upw_flux(v, l, r) + # Take both branches - could involve sparsity + return ifelse(v > 0, l, r) +end + +function upw_flux(v::JutulReal, l::JutulReal, r::JutulReal) if v > 0 # Flow l -> r out = l @@ -344,11 +349,7 @@ function upw_flux(v, l, r) return out end -function upw_flux(v, l::T, r::T) where {T<:ST.ADval} - if v > 0 - out = l + r*0 - else - out = r + l*0 - end - return out +function upw_flux(v, l::JutulReal, r::JutulReal) + flag = v > 0 + return flag*l + (1.0 - flag)*r end