Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Jutul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ module Jutul
import ForwardDiff

import DifferentiationInterface: AutoSparse, prepare_jacobian, jacobian, AutoForwardDiff
import SparseConnectivityTracer: TracerLocalSparsityDetector
import SparseConnectivityTracer: TracerLocalSparsityDetector, TracerSparsityDetector, jacobian_sparsity
import SparseMatrixColorings: GreedyColoringAlgorithm

# Timing
Expand Down
130 changes: 69 additions & 61 deletions src/ad/sparsity.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
# function unpack_tag(v::Type{ForwardDiff.Dual{T, F, N}}, t::Symbol = :entity) where {T, F, N}
# @info "???"
# if v isa Tuple
# if t == :entity
# out = T[2]
# elseif t == :model
# out = T[1]
# else
# out = T
# end
# else
# out = T
# end
# return unpack_tag(out)
# end

function unpack_tag(v::Type{ForwardDiff.Dual{T, F, N}}) where {T, F, N}
return unpack_tag(T)
end
Expand All @@ -29,8 +13,9 @@ end
unpack_tag(A::AbstractArray, arg...) = unpack_tag(eltype(A), arg...)
unpack_tag(::Any, arg...) = nothing

struct SparsityTracingWrapper{T, N, D} <: AbstractArray{T, N}
struct SparsityTracingWrapper{T, N, D, AD<:AbstractVector} <: AbstractArray{T, N}
data::Array{D, N}
advec::AD
end

"""
Expand All @@ -41,8 +26,9 @@ produces outputs that have the same value as the wrapped type, but contains a
SparsityTracing seeded value with seed equal to the column index (if matrix) or
linear index (if vector).
"""
function SparsityTracingWrapper(x::AbstractArray{T, N}) where {T, N}
return SparsityTracingWrapper{Float64, N, T}(x)
function SparsityTracingWrapper(x::AbstractArray{T, N}, advec) where {T, N}
size(x)[end] == length(advec) || error("Length of advec ($(length(advec)) must match last dimension of x (size(x)=$(size(x)))")
return SparsityTracingWrapper{Float64, N, T, typeof(advec)}(x, advec)
end

Base.parent(A::SparsityTracingWrapper) = A.data
Expand All @@ -65,7 +51,7 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D
if J isa Colon
J = axes(A, 2)
end
Ts = Jutul.ST.ADval{T}
Ts = eltype(A.advec)
n = length(I)
m = length(J)
out = Matrix{Ts}(undef, n, m)
Expand All @@ -78,42 +64,44 @@ function Base.getindex(A::SparsityTracingWrapper{T, D, <:Any}, I, J) where {T, D
end

function Base.getindex(A::SparsityTracingWrapper, i::Int, j::Int)
return as_tracer(A.data[i, j], j)
return traced_value(A.data[i, j], A, j)
# return value(A.data[i, j])*A.advec[j]
end

function Base.getindex(A::SparsityTracingWrapper{T, 2, D}, ix::Int) where {T, D}
n, m = size(A)
m = size(A)[2]
zero_ix = ix - 1
i = (zero_ix ÷ m) + 1
j = mod(zero_ix, m) + 1
return as_tracer(A.data[i, j], j)
return traced_value(A.data[i, j], A, j)
# return value(A.data[i, j])*A.advec[j]
end

function Base.getindex(A::SparsityTracingWrapper{T, 1, D}, i::Int) where {T, D}
return as_tracer(A.data[i], i)
end

function as_tracer(x::Real, leaf_init)
return Jutul.ST.ADval{Float64}(value(x), Jutul.ST.DerivLeaf(leaf_init))
return traced_value(A.data[i], A, i)
end

function as_tracer(x, leaf_init)
return x
function traced_value(baseval, A, idx)
# bval = value(baseval)
return value(baseval)*A.advec[idx]
# return sign(bval)*(abs(bval) + idx*1e-3)*A.advec[idx]
end

function create_mock_state(state, tag, entities = ad_entities(state); subkeys = nothing)
function create_mock_state(state, tag, X_tracer::AbstractVector; subkeys = nothing)
no_provided_keys = isnothing(subkeys)
mock_state = JutulStorage()
n = entities[tag].n
for k in keys(state)
v = state[k]
tag_matches = unpack_tag(v) == tag
key_matches = no_provided_keys || (k in subkeys)
if tag_matches && key_matches
# Assign mock value with tracer
new_v = SparsityTracingWrapper(v)
new_v = SparsityTracingWrapper(v, X_tracer)
elseif eltype(v) <: Real
n = size(v)[end]
new_v = SparsityTracingWrapper(v, ones(n))
else
# Assign mock value as doubles
# Probably not a numeric array, just use value wrapper
new_v = as_value(v)
end
mock_state[k] = new_v
Expand Down Expand Up @@ -143,20 +131,38 @@ function ad_entities(state)
return out
end

function determine_sparsity(F!, n, state, state0, tag, entities, N = entities[tag].n)
mstate = create_mock_state(state, tag, entities)
mstate0 = create_mock_state(state0, tag, entities)

out = ST.create_advec(zeros(n))
J = [Vector{Int64}() for i in 1:N]
for i in 1:N
@inbounds F!(out, mstate, mstate0, i)
# Take the sum over all return values to reduce to scalar.
# This should accumulate the full "entity" pattern if some
# equations have a different stencil.
V = sum(out)
D = ST.deriv(V)
J[i] = D.nzind
function determine_sparsity(F!, n, state, state0, count_of_tag, tag, entities, N = entities[tag].n)
# n: number of equations per entity
# N: number of entities where the equation lies (output size)
# count_of_tag: number of variables with the given tag (input size)
function x_to_evaluated(X::AbstractVector{T}) where T
out = zeros(T, N)
eq_buf = zeros(T, n)
mstate = create_mock_state(state, tag, X)
mstate0 = create_mock_state(state0, tag, X)

for i in 1:N
@inbounds F!(eq_buf, mstate, mstate0, i)
# Take the sum over all return values to reduce to scalar.
# This should accumulate the full "entity" pattern if some
# equations have a different stencil.
v = zero(T)
for j in 1:n
v += abs(eq_buf[j])
end
out[i] = v
end
return out
end

dtct = TracerSparsityDetector()
# dtct = TracerLocalSparsityDetector()
S = jacobian_sparsity(x_to_evaluated, ones(count_of_tag), dtct)

J = [Vector{Int64}() for _ in 1:N]
rows, cols, = findnz(S)
for (row, col) in zip(rows, cols)
push!(J[row], col)
end
return J
end
Expand All @@ -175,21 +181,23 @@ function determine_sparsity_simple(F, model, state, state0 = nothing; variant =
end
end
for (k, v) in entities
mstate = create_mock_state(state, k, entities, subkeys = subkeys)
if isnothing(state0)
f_ad = F(mstate)
else
mstate0 = create_mock_state(state0, k, entities, subkeys = subkeys)
f_ad = F(mstate, mstate0)
end
V = sum(f_ad)
if V isa AbstractFloat || V isa Integer
S = zeros(Int64, 0)
else
D = ST.deriv(V)
S = D.nzind
function trace_entity(X)
mstate = create_mock_state(state, k, X, subkeys = subkeys)
if isnothing(state0)
f_ad = F(mstate)
else
mstate0 = create_mock_state(state0, k, X, subkeys = subkeys)
f_ad = F(mstate, mstate0)
end
return sum(f_ad)
end
ne = count_entities(model.domain, k)
# dtct = TracerLocalSparsityDetector()
dtct = TracerSparsityDetector()
js = jacobian_sparsity(trace_entity, ones(ne), dtct)
S = findnz(js)[2]
sparsity[k] = S
end

return sparsity
end
15 changes: 8 additions & 7 deletions src/conservation/flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,11 @@ end

export upw_flux
function upw_flux(v, l, r)
# Take both branches - could involve sparsity
return ifelse(v > 0, l, r)
end

function upw_flux(v::JutulReal, l::JutulReal, r::JutulReal)
if v > 0
# Flow l -> r
out = l
Expand All @@ -344,11 +349,7 @@ function upw_flux(v, l, r)
return out
end

function upw_flux(v, l::T, r::T) where {T<:ST.ADval}
if v > 0
out = l + r*0
else
out = r + l*0
end
return out
function upw_flux(v, l::JutulReal, r::JutulReal)
flag = v > 0
return flag*l + (1.0 - flag)*r
end
1 change: 1 addition & 0 deletions src/core_types/core_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1622,3 +1622,4 @@ function timestepping_is_done(C::EndTimeTerminationCriterion, simulator, states,
return now >= C.end_time
end

JutulReal = Union{Float64, ForwardDiff.Dual}
3 changes: 2 additions & 1 deletion src/equations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@ function create_equation_caches(model, equations_per_entity, number_of_entities,
for (e, epack) in entities
is_self = e == self_entity
self_entity_found = self_entity_found || is_self
@tic "sparsity detection" S = determine_sparsity(F!, equations_per_entity, state, state0, e, entities, number_of_entities)
num_e = count_entities(model.domain, e)
@tic "sparsity detection" S = determine_sparsity(F!, equations_per_entity, state, state0, num_e, e, entities, number_of_entities)
if !isnothing(extra_sparsity)
# We have some extra sparsity, need to merge that in
S_e = extra_sparsity[entity_as_symbol(e)]
Expand Down
17 changes: 8 additions & 9 deletions test/sparsity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,27 @@ using Jutul, Test
test_mat = zeros(m, n)
test_vec = zeros(n)

# Use NaN to see what values get tagged for testing purposes
internal_vec = collect(-1:-1:-n)

for i in 1:n
test_vec[i] = i
for j in 1:m
test_mat[j, i] = i + (j-1)*n
end
end

vec_st = Jutul.SparsityTracingWrapper(test_vec)
vec_st = Jutul.SparsityTracingWrapper(test_vec, internal_vec)
for i in 1:n
v = vec_st[i]
@test v isa Jutul.ST.ADval
@test v.derivnode.index == i
@test v.val == test_vec[i]
@test v == test_vec[i]*-i
end

mat_st = Jutul.SparsityTracingWrapper(test_mat)
mat_st = Jutul.SparsityTracingWrapper(test_mat, internal_vec)
for i in 1:n
for j in 1:m
v = mat_st[j, i]
@test v isa Jutul.ST.ADval
@test v.derivnode.index == i
@test v.val == test_mat[j, i]
@test v == test_mat[j, i]*-i
end
end

Expand All @@ -49,7 +48,7 @@ using Jutul, Test
end
end
end

##
@testset "ad_tags" begin
v = allocate_array_ad(1, diag_pos = 1, tag = Cells())
@test Jutul.value(v[1]) isa Float64
Expand Down
Loading