diff --git a/src/Jutul.jl b/src/Jutul.jl index 569d59468..a713840e9 100644 --- a/src/Jutul.jl +++ b/src/Jutul.jl @@ -41,7 +41,7 @@ module Jutul import ForwardDiff import DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian, AutoForwardDiff - import SparseConnectivityTracer: TracerLocalSparsityDetector + import SparseConnectivityTracer: TracerLocalSparsityDetector, TracerSparsityDetector, jacobian_sparsity import SparseMatrixColorings: GreedyColoringAlgorithm # Timing diff --git a/src/ad/sparsity.jl b/src/ad/sparsity.jl index 231e391a5..eb14b9f7b 100644 --- a/src/ad/sparsity.jl +++ b/src/ad/sparsity.jl @@ -1,19 +1,3 @@ -# function unpack_tag(v::Type{ForwardDiff.Dual{T, F, N}}, t::Symbol = :entity) where {T, F, N} -# @info "???" -# if v isa Tuple -# if t == :entity -# out = T[2] -# elseif t == :model -# out = T[1] -# else -# out = T -# end -# else -# out = T -# end -# return unpack_tag(out) -# end - function unpack_tag(v::Type{ForwardDiff.Dual{T, F, N}}) where {T, F, N} return unpack_tag(T) end @@ -29,8 +13,9 @@ end unpack_tag(A::AbstractArray, arg...) = unpack_tag(eltype(A), arg...) unpack_tag(::Any, arg...) = nothing -struct SparsityTracingWrapper{T, N, D} <: AbstractArray{T, N} +struct SparsityTracingWrapper{T, N, D, AD<:AbstractVector} <: AbstractArray{T, N} data::Array{D, N} + advec::AD end """ @@ -41,8 +26,9 @@ produces outputs that have the same value as the wrapped type, but contains a SparsityTracing seeded value with seed equal to the column index (if matrix) or linear index (if vector). """ -function SparsityTracingWrapper(x::AbstractArray{T, N}) where {T, N} - return SparsityTracingWrapper{Float64, N, T}(x) +function SparsityTracingWrapper(x::AbstractArray{T, N}, advec) where {T, N} + size(x)[end] == length(advec) || error("Length of advec ($(length(advec)) must match last dimension of x (size(x)=$(size(x)))") + return SparsityTracingWrapper{Float64, N, T, typeof(advec)}(x, advec) end Base.parent(A::SparsityTracingWrapper) = A.data @@ -65,7 +51,7 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D if J isa Colon J = axes(A, 2) end - Ts = Jutul.ST.ADval{T} + Ts = eltype(A.advec) n = length(I) m = length(J) out = Matrix{Ts}(undef, n, m) @@ -78,42 +64,44 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D end function Base.getindex(A::SparsityTracingWrapper, i::Int, j::Int) - return as_tracer(A.data[i, j], j) + return traced_value(A.data[i, j], A, j) + # return value(A.data[i, j])*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D} - n, m = size(A) + m = size(A)[2] zero_ix = ix - 1 i = (zero_ix รท m) + 1 j = mod(zero_ix, m) + 1 - return as_tracer(A.data[i, j], j) + return traced_value(A.data[i, j], A, j) + # return value(A.data[i, j])*A.advec[j] end function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D} - return as_tracer(A.data[i], i) -end - -function as_tracer(x::Real, leaf_init) - return Jutul.ST.ADval{Float64}(value(x), Jutul.ST.DerivLeaf(leaf_init)) + return traced_value(A.data[i], A, i) end -function as_tracer(x, leaf_init) - return x +function traced_value(baseval, A, idx) + # bval = value(baseval) + return value(baseval)*A.advec[idx] + # return sign(bval)*(abs(bval) + idx*1e-3)*A.advec[idx] end -function create_mock_state(state, tag, entities = ad_entities(state); subkeys = nothing) +function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing) no_provided_keys = isnothing(subkeys) mock_state = JutulStorage() - n = entities[tag].n for k in keys(state) v = state[k] tag_matches = unpack_tag(v) == tag key_matches = no_provided_keys || (k in subkeys) if tag_matches && key_matches # Assign mock value with tracer - new_v = SparsityTracingWrapper(v) + new_v = SparsityTracingWrapper(v, X_tracer) + elseif eltype(v) <: Real + n = size(v)[end] + new_v = SparsityTracingWrapper(v, ones(n)) else - # Assign mock value as doubles + # Probably not a numeric array, just use value wrapper new_v = as_value(v) end mock_state[k] = new_v @@ -143,20 +131,38 @@ function ad_entities(state) return out end -function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[tag].n) - mstate = create_mock_state(state, tag, entities) - mstate0 = create_mock_state(state0, tag, entities) - - out = ST.create_advec(zeros(n)) - J = [Vector{Int64}() for i in 1:N] - for i in 1:N - @inbounds F!(out, mstate, mstate0, i) - # Take the sum over all return values to reduce to scalar. - # This should accumulate the full "entity" pattern if some - # equations have a different stencil. - V = sum(out) - D = ST.deriv(V) - J[i] = D.nzind +function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N = entities[tag].n) + # n: number of equations per entity + # N: number of entities where the equation lies (output size) + # count_of_tag: number of variables with the given tag (input size) + function x_to_evaluated(X::AbstractVector{T}) where T + out = zeros(T, N) + eq_buf = zeros(T, n) + mstate = create_mock_state(state, tag, X) + mstate0 = create_mock_state(state0, tag, X) + + for i in 1:N + @inbounds F!(eq_buf, mstate, mstate0, i) + # Take the sum over all return values to reduce to scalar. + # This should accumulate the full "entity" pattern if some + # equations have a different stencil. + v = zero(T) + for j in 1:n + v += abs(eq_buf[j]) + end + out[i] = v + end + return out + end + + dtct = TracerSparsityDetector() + # dtct = TracerLocalSparsityDetector() + S = jacobian_sparsity(x_to_evaluated, ones(count_of_tag), dtct) + + J = [Vector{Int64}() for _ in 1:N] + rows, cols, = findnz(S) + for (row, col) in zip(rows, cols) + push!(J[row], col) end return J end @@ -175,21 +181,23 @@ function determine_sparsity_simple(F, model, state, state0 = nothing; variant = end end for (k, v) in entities - mstate = create_mock_state(state, k, entities, subkeys = subkeys) - if isnothing(state0) - f_ad = F(mstate) - else - mstate0 = create_mock_state(state0, k, entities, subkeys = subkeys) - f_ad = F(mstate, mstate0) - end - V = sum(f_ad) - if V isa AbstractFloat || V isa Integer - S = zeros(Int64, 0) - else - D = ST.deriv(V) - S = D.nzind + function trace_entity(X) + mstate = create_mock_state(state, k, X, subkeys = subkeys) + if isnothing(state0) + f_ad = F(mstate) + else + mstate0 = create_mock_state(state0, k, X, subkeys = subkeys) + f_ad = F(mstate, mstate0) + end + return sum(f_ad) end + ne = count_entities(model.domain, k) + # dtct = TracerLocalSparsityDetector() + dtct = TracerSparsityDetector() + js = jacobian_sparsity(trace_entity, ones(ne), dtct) + S = findnz(js)[2] sparsity[k] = S end + return sparsity end diff --git a/src/conservation/flux.jl b/src/conservation/flux.jl index e06ec092e..b93d02a25 100644 --- a/src/conservation/flux.jl +++ b/src/conservation/flux.jl @@ -335,6 +335,11 @@ end export upw_flux function upw_flux(v, l, r) + # Take both branches - could involve sparsity + return ifelse(v > 0, l, r) +end + +function upw_flux(v::JutulReal, l::JutulReal, r::JutulReal) if v > 0 # Flow l -> r out = l @@ -344,11 +349,7 @@ function upw_flux(v, l, r) return out end -function upw_flux(v, l::T, r::T) where {T<:ST.ADval} - if v > 0 - out = l + r*0 - else - out = r + l*0 - end - return out +function upw_flux(v, l::JutulReal, r::JutulReal) + flag = v > 0 + return flag*l + (1.0 - flag)*r end diff --git a/src/core_types/core_types.jl b/src/core_types/core_types.jl index 79e9ab5ab..7fe348b3a 100644 --- a/src/core_types/core_types.jl +++ b/src/core_types/core_types.jl @@ -1622,3 +1622,4 @@ function timestepping_is_done(C::EndTimeTerminationCriterion, simulator, states, return now >= C.end_time end +JutulReal = Union{Float64, ForwardDiff.Dual} diff --git a/src/equations.jl b/src/equations.jl index a293c24d5..8d3718666 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -205,7 +205,8 @@ function create_equation_caches(model, equations_per_entity, number_of_entities, for (e, epack) in entities is_self = e == self_entity self_entity_found = self_entity_found || is_self - @tic "sparsity detection" S = determine_sparsity(F!, equations_per_entity, state, state0, e, entities, number_of_entities) + num_e = count_entities(model.domain, e) + @tic "sparsity detection" S = determine_sparsity(F!, equations_per_entity, state, state0, num_e, e, entities, number_of_entities) if !isnothing(extra_sparsity) # We have some extra sparsity, need to merge that in S_e = extra_sparsity[entity_as_symbol(e)] diff --git a/test/sparsity.jl b/test/sparsity.jl index da379fb93..f5f5e1833 100644 --- a/test/sparsity.jl +++ b/test/sparsity.jl @@ -6,6 +6,9 @@ using Jutul, Test test_mat = zeros(m, n) test_vec = zeros(n) + # Use NaN to see what values get tagged for testing purposes + internal_vec = collect(-1:-1:-n) + for i in 1:n test_vec[i] = i for j in 1:m @@ -13,21 +16,17 @@ using Jutul, Test end end - vec_st = Jutul.SparsityTracingWrapper(test_vec) + vec_st = Jutul.SparsityTracingWrapper(test_vec, internal_vec) for i in 1:n v = vec_st[i] - @test v isa Jutul.ST.ADval - @test v.derivnode.index == i - @test v.val == test_vec[i] + @test v == test_vec[i]*-i end - mat_st = Jutul.SparsityTracingWrapper(test_mat) + mat_st = Jutul.SparsityTracingWrapper(test_mat, internal_vec) for i in 1:n for j in 1:m v = mat_st[j, i] - @test v isa Jutul.ST.ADval - @test v.derivnode.index == i - @test v.val == test_mat[j, i] + @test v == test_mat[j, i]*-i end end @@ -49,7 +48,7 @@ using Jutul, Test end end end - +## @testset "ad_tags" begin v = allocate_array_ad(1, diag_pos = 1, tag = Cells()) @test Jutul.value(v[1]) isa Float64