diff --git a/src/execution.jl b/src/execution.jl index 39393830..f14324a2 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -46,7 +46,7 @@ end function build_A_aux( ::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend ) - (; conditions, backends) = implicit + (; conditions, linear_solver, backends) = implicit (; prep_A) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) @@ -56,7 +56,11 @@ function build_A_aux( else A = jacobian(f, prep_A, actual_backend, y, contexts...) end - return factorize(A) + if linear_solver isa DirectLinearSolver + return factorize(A) + else + return A + end end function build_A_aux( @@ -100,7 +104,7 @@ end function build_Aᵀ_aux( ::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend ) - (; conditions, backends) = implicit + (; conditions, linear_solver, backends) = implicit (; prep_Aᵀ) = prep actual_backend = isnothing(backends) ? suggested_backend : backends.y contexts = (Constant(x), Constant(z), map(Constant, args)...) @@ -110,7 +114,11 @@ function build_Aᵀ_aux( else Aᵀ = transpose(jacobian(f, prep_Aᵀ, actual_backend, y, contexts...)) end - return factorize(Aᵀ) + if linear_solver isa DirectLinearSolver + return factorize(Aᵀ) + else + return Aᵀ + end end function build_Aᵀ_aux( diff --git a/test/preparation.jl b/test/preparation.jl index d898b7f9..65a667c1 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -5,9 +5,12 @@ using ADTypes using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode using ForwardDiff: ForwardDiff + using LinearAlgebra: Factorization, TransposeFactorization using Zygote: Zygote using Test + const GenericMatrix = Union{AbstractMatrix,Factorization,TransposeFactorization} + solver(x) = sqrt.(x), nothing conditions(x, y, z) = y .^ 2 .- x @@ -41,8 +44,8 @@ @test prep.prep_Aᵀ === nothing @test prep.prep_B !== nothing @test prep.prep_Bᵀ === nothing - @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix - @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix + @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix @test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP @test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP end @@ -53,8 +56,8 @@ @test prep.prep_Aᵀ !== nothing @test prep.prep_B === nothing @test prep.prep_Bᵀ !== nothing - @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix - @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa AbstractMatrix + @test build_A(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix + @test build_Aᵀ(implicit, prep, x, y, z, c; suggested_backend) isa GenericMatrix @test build_B(implicit, prep, x, y, z, c; suggested_backend) isa JVP @test build_Bᵀ(implicit, prep, x, y, z, c; suggested_backend) isa VJP end diff --git a/test/systematic.jl b/test/systematic.jl index 799b67cf..3d3f6576 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,19 +1,19 @@ using TestItems -@testitem "Direct" setup = [TestUtils] begin +@testitem "Matrix" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, x) in - Iterators.product([nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3)]) + representation = MatrixRepresentation() + for (linear_solver, backends, x) in Iterators.product( + [DirectLinearSolver(), IterativeLinearSolver()], + [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + [float.(1:3)], + ) yield() scen = Scenario(; solver=default_solver, conditions=default_conditions, x=x, - implicit_kwargs=(; - representation=MatrixRepresentation(), - linear_solver=DirectLinearSolver(), - backends, - ), + implicit_kwargs=(; representation, linear_solver, backends), ) scen2 = add_arg_mult(scen) test_implicit(scen) @@ -21,15 +21,16 @@ using TestItems end end; -@testitem "Iterative" setup = [TestUtils] begin +@testitem "Operator" setup = [TestUtils] begin using ADTypes, .TestUtils - for (backends, linear_solver, x) in Iterators.product( - [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + representation = OperatorRepresentation() + for (linear_solver, backends, x) in Iterators.product( [ IterativeLinearSolver(), IterativeLinearSolver(; rtol=1e-8), IterativeLinearSolver(; issymmetric=true, isposdef=true), ], + [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], [float.(1:3), reshape(float.(1:6), 3, 2)], ) yield() @@ -37,9 +38,7 @@ end; solver=default_solver, conditions=default_conditions, x=x, - implicit_kwargs=(; - representation=OperatorRepresentation(), linear_solver, backends - ), + implicit_kwargs=(; representation, linear_solver, backends), ) scen2 = add_arg_mult(scen) test_implicit(scen; type_stability=VERSION >= v"1.11")