Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ end
function build_A_aux(
::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend
)
(; conditions, backends) = implicit
(; conditions, linear_solver, backends) = implicit
(; prep_A) = prep
actual_backend = isnothing(backends) ? suggested_backend : backends.y
contexts = (Constant(x), Constant(z), map(Constant, args)...)
Expand All @@ -56,7 +56,11 @@ function build_A_aux(
else
A = jacobian(f, prep_A, actual_backend, y, contexts...)
end
return factorize(A)
if linear_solver isa DirectLinearSolver
return factorize(A)
else
return A
end
end

function build_A_aux(
Expand Down Expand Up @@ -100,7 +104,7 @@ end
function build_Aᵀ_aux(
::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend
)
(; conditions, backends) = implicit
(; conditions, linear_solver, backends) = implicit
(; prep_Aᵀ) = prep
actual_backend = isnothing(backends) ? suggested_backend : backends.y
contexts = (Constant(x), Constant(z), map(Constant, args)...)
Expand All @@ -110,7 +114,11 @@ function build_Aᵀ_aux(
else
Aᵀ = transpose(jacobian(f, prep_Aᵀ, actual_backend, y, contexts...))
end
return factorize(Aᵀ)
if linear_solver isa DirectLinearSolver
return factorize(Aᵀ)
else
return Aᵀ
end
end

function build_Aᵀ_aux(
Expand Down
11 changes: 7 additions & 4 deletions test/preparation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@
using ADTypes
using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode
using ForwardDiff: ForwardDiff
using LinearAlgebra: Factorization, TransposeFactorization
using Zygote: Zygote
using Test

const GenericMatrix = Union{AbstractMatrix,Factorization,TransposeFactorization}

solver(x) = sqrt.(x), nothing
conditions(x, y, z) = y .^ 2 .- x

Expand Down Expand Up @@ -41,8 +44,8 @@
@test prep.prep_Aᵀ === nothing
@test prep.prep_B !== nothing
@test prep.prep_Bᵀ === nothing
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
@test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP
@test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP
end
Expand All @@ -53,8 +56,8 @@
@test prep.prep_Aᵀ !== nothing
@test prep.prep_B === nothing
@test prep.prep_Bᵀ !== nothing
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix
@test build_A(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
@test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix
@test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP
@test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP
end
Expand Down
27 changes: 13 additions & 14 deletions test/systematic.jl
Original file line number Diff line number Diff line change
@@ -1,45 +1,44 @@
using TestItems

@testitem "Direct" setup = [TestUtils] begin
@testitem "Matrix" setup = [TestUtils] begin
using ADTypes, .TestUtils
for (backends, x) in
Iterators.product([nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3)])
representation = MatrixRepresentation()
for (linear_solver, backends, x) in Iterators.product(
[DirectLinearSolver(), IterativeLinearSolver()],
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
[float.(1:3)],
)
yield()
scen = Scenario(;
solver=default_solver,
conditions=default_conditions,
x=x,
implicit_kwargs=(;
representation=MatrixRepresentation(),
linear_solver=DirectLinearSolver(),
backends,
),
implicit_kwargs=(; representation, linear_solver, backends),
)
scen2 = add_arg_mult(scen)
test_implicit(scen)
test_implicit(scen2)
end
end;

@testitem "Iterative" setup = [TestUtils] begin
@testitem "Operator" setup = [TestUtils] begin
using ADTypes, .TestUtils
for (backends, linear_solver, x) in Iterators.product(
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
representation = OperatorRepresentation()
for (linear_solver, backends, x) in Iterators.product(
[
IterativeLinearSolver(),
IterativeLinearSolver(; rtol=1e-8),
IterativeLinearSolver(; issymmetric=true, isposdef=true),
],
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
[float.(1:3), reshape(float.(1:6), 3, 2)],
)
yield()
scen = Scenario(;
solver=default_solver,
conditions=default_conditions,
x=x,
implicit_kwargs=(;
representation=OperatorRepresentation(), linear_solver, backends
),
implicit_kwargs=(; representation, linear_solver, backends),
)
scen2 = add_arg_mult(scen)
test_implicit(scen; type_stability=VERSION >= v"1.11")
Expand Down
Loading