From 0720b6e04aaba5b26e7da4f55f7f72641ec23206 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 11 Jun 2025 09:50:59 +0200 Subject: [PATCH 1/5] test!: better tests --- docs/src/api.md | 6 +- src/ImplicitDifferentiation.jl | 3 +- src/implicit_function.jl | 57 ++------ src/settings.jl | 55 +++----- test/examples.jl | 8 +- test/formalities.jl | 4 + test/preparation.jl | 62 +++++++++ test/runtests.jl | 6 + test/systematic.jl | 122 ++++++++++------- test/utils.jl | 235 ++++++++++++++++++--------------- 10 files changed, 313 insertions(+), 245 deletions(-) create mode 100644 test/preparation.jl diff --git a/docs/src/api.md b/docs/src/api.md index 8eb6ba27..ce95c90e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -19,13 +19,9 @@ ImplicitFunction ### Settings ```@docs -IterativeLinearSolver MatrixRepresentation OperatorRepresentation -NoPreparation -ForwardPreparation -ReversePreparation -BothPreparation +IterativeLinearSolver ``` ## Internals diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 0e044d48..c7ea12a3 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -33,9 +33,8 @@ include("preparation.jl") include("implicit_function.jl") include("execution.jl") -export IterativeLinearSolver export MatrixRepresentation, OperatorRepresentation -export NoPreparation, ForwardPreparation, ReversePreparation, BothPreparation +export IterativeLinearSolver export ImplicitFunction end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index faeaa626..531078c6 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -22,7 +22,7 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ conditions; representation=OperatorRepresentation(), linear_solver=IterativeLinearSolver(), - backend=nothing, + backends=nothing, preparation=nothing, input_example=nothing, ) @@ -30,19 +30,16 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂ ## Positional arguments - `solver`: a callable returning `(x, args...) -> (y, z)` where `z` is an arbitrary byproduct of the solve. Both `x` and `y` must be subtypes of `AbstractArray`, while `z` and `args` can be anything. -- `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation +- `conditions`: a callable returning a vector of optimality conditions `(x, y, z, args...) -> c`, must be compatible with automatic differentiation. ## Keyword arguments -- `representation`: either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref) -- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b)` (single solve) and one for `(A, B)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function. -- `backend::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either - - `nothing`, which means that the external autodiff system will be used - - a single object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) - - a named tuple `(; x, y)` of objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl) +- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented, either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref). +- `linear_solver`: a callable to solve linear systems with two required methods, one for `(A, b::AbstractVector)` (single solve) and one for `(A, B::AbstractMatrix)` (batched solve). It defaults to [`IterativeLinearSolver`](@ref) but can also be the built-in `\\`, or a user-provided function. +- `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either, `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl). - `preparation`: either `nothing` or a mode object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl): `ADTypes.ForwardMode()`, `ADTypes.ReverseMode()` or `ADTypes.ForwardOrReverseMode()`. - `input_example`: either `nothing` or a tuple `(x, args...)` used to prepare differentiation. -- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types. Relaxing this to `strict=Val(false)` can prove necessary when working with custom array types like ComponentArrays.jl, which are not always compatible with iterative linear solvers. +- `strict::Val=Val(true)`: whether or not to enforce a strict match in [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) between the preparation and the execution types. """ struct ImplicitFunction{ F, @@ -51,7 +48,6 @@ struct ImplicitFunction{ R<:AbstractRepresentation, B<:Union{ Nothing, # - AbstractADType, NamedTuple{(:x, :y),<:Tuple{AbstractADType,AbstractADType}}, }, P<:Union{Nothing,AbstractMode}, @@ -90,32 +86,15 @@ function ImplicitFunction( prep_B = nothing prep_Bᵀ = nothing else - real_backends = backends isa AbstractADType ? (; x=backends, y=backends) : backends x, args = first(input_example), Base.tail(input_example) y, z = solver(x, args...) c = conditions(x, y, z, args...) if preparation isa Union{ForwardMode,ForwardOrReverseMode} prep_A = prepare_A( - representation, - x, - y, - z, - c, - args...; - conditions, - backend=real_backends.y, - strict, + representation, x, y, z, c, args...; conditions, backend=backends.y, strict ) prep_B = prepare_B( - representation, - x, - y, - z, - c, - args...; - conditions, - backend=real_backends.x, - strict, + representation, x, y, z, c, args...; conditions, backend=backends.x, strict ) else prep_A = nothing @@ -123,26 +102,10 @@ function ImplicitFunction( end if preparation isa Union{ReverseMode,ForwardOrReverseMode} prep_Aᵀ = prepare_Aᵀ( - representation, - x, - y, - z, - c, - args...; - conditions, - backend=real_backends.y, - strict, + representation, x, y, z, c, args...; conditions, backend=backends.y, strict ) prep_Bᵀ = prepare_Bᵀ( - representation, - x, - y, - z, - c, - args...; - conditions, - backend=real_backends.x, - strict, + representation, x, y, z, c, args...; conditions, backend=backends.x, strict ) else prep_Aᵀ = nothing diff --git a/src/settings.jl b/src/settings.jl index dbfb876c..ca791114 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -7,11 +7,12 @@ Callable object that can solve linear systems `Ax = b` and `AX = B` in the same # Constructor + IterativeLinearSolver(; kwargs...) IterativeLinearSolver{package}(; kwargs...) The type parameter `package` can be either: -- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) +- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default) - `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl) Keyword arguments are passed on to the respective solver. @@ -34,7 +35,15 @@ struct IterativeLinearSolver{package,K} end end -IterativeLinearSolver() = IterativeLinearSolver{:Krylov}() +function Base.show(io::IO, linear_solver::IterativeLinearSolver{package}) where {package} + print(io, "IterativeLinearSolver{$(repr(package))}(; ") + for (k, v) in pairs(linear_solver.kwargs) + print(io, "$k=$v, ") + end + return print(io, ")") +end + +IterativeLinearSolver(; kwargs...) = IterativeLinearSolver{:Krylov}(; kwargs...) function (solver::IterativeLinearSolver{:Krylov})(A, b::AbstractVector) x, stats = Krylov.gmres(A, b; solver.kwargs...) @@ -85,6 +94,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be # Constructors + OperatorRepresentation(; symmetric=false, hermitian=false) OperatorRepresentation{package}(; symmetric=false, hermitian=false) The type parameter `package` can be either: @@ -108,38 +118,15 @@ struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresenta end end -OperatorRepresentation() = OperatorRepresentation{:LinearOperators}() - -## Preparation - -abstract type AbstractPreparation end - -""" - ForwardPreparation - -Specify that the derivatives of the conditions should be prepared for subsequent forward-mode differentiation of the implicit function. -""" -struct ForwardPreparation <: AbstractPreparation end - -""" - ReversePreparation - -Specify that the derivatives of the conditions should be prepared for subsequent reverse-mode differentiation of the implicit function. -""" -struct ReversePreparation <: AbstractPreparation end - -""" - BothPreparation - -Specify that the derivatives of the conditions should be prepared for subsequent forward- or reverse-mode differentiation of the implicit function. -""" -struct BothPreparation <: AbstractPreparation end - -""" - NoPreparation +function Base.show( + io::IO, ::OperatorRepresentation{package,symmetric,hermitian} +) where {package,symmetric,hermitian} + return print( + io, + "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian)", + ) +end -Specify that the derivatives of the conditions should not be prepared for subsequent differentiation of the implicit function. -""" -struct NoPreparation <: AbstractPreparation end +OperatorRepresentation(; kwargs...) = OperatorRepresentation{:LinearOperators}(; kwargs...) function chainrules_suggested_backend end diff --git a/test/examples.jl b/test/examples.jl index 117b5e78..d62129ba 100644 --- a/test/examples.jl +++ b/test/examples.jl @@ -1,15 +1,15 @@ -@testitem "intro" begin +@testitem "Intro" begin include(joinpath(dirname(@__DIR__), "examples", "0_intro.jl")) end -@testitem "basic" begin +@testitem "Basic" begin include(joinpath(dirname(@__DIR__), "examples", "1_basic.jl")) end -@testitem "advanced" begin +@testitem "Advanced" begin include(joinpath(dirname(@__DIR__), "examples", "2_advanced.jl")) end -@testitem "tricks" begin +@testitem "Tricks" begin include(joinpath(dirname(@__DIR__), "examples", "3_tricks.jl")) end diff --git a/test/formalities.jl b/test/formalities.jl index 09e8b2f2..16770760 100644 --- a/test/formalities.jl +++ b/test/formalities.jl @@ -6,16 +6,19 @@ using TestItems using Zygote: Zygote Aqua.test_all(ImplicitDifferentiation; ambiguities=false, undocumented_names=true) end + @testitem "Formatting" begin using JuliaFormatter @test format(ImplicitDifferentiation; verbose=false, overwrite=false) end + @testitem "Static checking" begin using JET using ForwardDiff: ForwardDiff using Zygote: Zygote JET.test_package(ImplicitDifferentiation; target_defined_modules=true) end + @testitem "Imports" begin using ExplicitImports using ForwardDiff: ForwardDiff @@ -27,6 +30,7 @@ end @test check_all_qualified_accesses_via_owners(ImplicitDifferentiation) === nothing @test check_no_self_qualified_accesses(ImplicitDifferentiation) === nothing end + @testitem "Doctests" begin using Documenter Documenter.DocMeta.setdocmeta!( diff --git a/test/preparation.jl b/test/preparation.jl new file mode 100644 index 00000000..d09f95fa --- /dev/null +++ b/test/preparation.jl @@ -0,0 +1,62 @@ +@testitem "Preparation" begin + using ImplicitDifferentiation + using ADTypes: AutoForwardDiff, ForwardOrReverseMode, ForwardMode, ReverseMode + using ForwardDiff: ForwardDiff + using Zygote: Zygote + using Test + + solver(x) = sqrt.(x), nothing + conditions(x, y, z) = y .^ 2 .- x + x = rand(5) + input_example = (x,) + + @testset "None" begin + implicit_nones = ImplicitFunction(solver, conditions) + @test implicit_none.prep_A === nothing + @test implicit_none.prep_Aᵀ === nothing + @test implicit_none.prep_B === nothing + @test implicit_none.prep_Bᵀ === nothing + end + + @testset "ForwardMode" begin + implicit_forward = ImplicitFunction( + solver, + conditions; + preparation=ForwardMode(), + backends=(; x=AutoForwardDiff(), y=AutoForwardDiff()), + input_example, + ) + @test implicit_forward.prep_A !== nothing + @test implicit_forward.prep_Aᵀ === nothing + @test implicit_forward.prep_B !== nothing + @test implicit_forward.prep_Bᵀ === nothing + end + + @testset "ReverseMode" begin + implicit_reverse = ImplicitFunction( + solver, + conditions; + preparation=ReverseMode(), + backends=(; x=AutoZygote(), y=AutoZygote()), + input_example, + ) + @test implicit_reverse.prep_A === nothing + @test implicit_reverse.prep_Aᵀ !== nothing + @test implicit_reverse.prep_B === nothing + @test implicit_reverse.prep_Bᵀ !== nothing + end + + @testset "Both" begin + implicit_both = ImplicitFunction( + solver, + conditions; + preparation=ForwardOrReverseMode(), + backends=(; x=AutoForwardDiff(), y=AutoZygote()), + input_example, + ) + @test implicit_both.prep_A !== nothing + @test implicit_both.prep_Aᵀ !== nothing + @test implicit_both.prep_B !== nothing + @test implicit_both.prep_Bᵀ !== nothing + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b9e874db..47e07c9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,9 @@ using TestItemRunner +@testmodule TestUtils begin + include("utils.jl") + export Scenario, test_implicit, add_arg_mult + export default_solver, default_conditions +end + @run_package_tests diff --git a/test/systematic.jl b/test/systematic.jl index 71b5a4a1..add1f344 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,54 +1,80 @@ using TestItems -@testitem "Systematic tests" begin - using ADTypes - using ADTypes: ForwardMode, ReverseMode - using ForwardDiff: ForwardDiff - using ImplicitDifferentiation - using Test - using Zygote: Zygote, ZygoteRuleConfig - using FiniteDiff: FiniteDiff - - include("utils.jl") - - ## Parameter combinations - - representation_linear_solver_candidates = [ - (MatrixRepresentation(), \), # - (OperatorRepresentation{:LinearOperators}(), IterativeLinearSolver{:Krylov}()), - (OperatorRepresentation{:LinearMaps}(), IterativeLinearSolver{:IterativeSolvers}()), - ] - backend_candidates = [nothing, AutoForwardDiff(), AutoZygote()] - preparation_candidates = [nothing, ForwardMode(), ReverseMode()] - x_candidates = [float.(1:6), reshape(float.(1:12), 6, 2)] +@testitem "Matrix" setup = [TestUtils] begin + using ADTypes, .TestUtils + for (backends, preparation, x) in Iterators.product( + [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + [nothing, ADTypes.ForwardOrReverseMode()], + [float.(1:3), reshape(float.(1:6), 3, 2)], + ) + yield() + scen = Scenario(; + solver=default_solver, + conditions=default_conditions, + x=x, + implicit_kwargs=(; + representation=MatrixRepresentation(), + linear_solver=\, + backends, + preparation, + input_example=(x,), + ), + ) + scen2 = add_arg_mult(scen) + @info "$scen" + test_implicit(scen) + test_implicit(scen2) + end +end; - ## Test loop +@testitem "Krylov" setup = [TestUtils] begin + using ADTypes, .TestUtils + for (backends, preparation, x) in Iterators.product( + [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + [nothing, ADTypes.ForwardOrReverseMode()], + [float.(1:3), reshape(float.(1:6), 3, 2)], + ) + yield() + scen = Scenario(; + solver=default_solver, + conditions=default_conditions, + x=x, + implicit_kwargs=(; + representation=OperatorRepresentation{:LinearOperators}(), + linear_solver=IterativeLinearSolver{:Krylov}(), + backends, + preparation, + input_example=(x,), + ), + ) + @info "$scen" + scen2 = add_arg_mult(scen) + test_implicit(scen) + test_implicit(scen2) + end +end; - @testset for (representation, linear_solver) in representation_linear_solver_candidates - for (backend, preparation, x) in - Iterators.product(backend_candidates, preparation_candidates, x_candidates) - x_type = typeof(x) - @info "Testing" linear_solver backend representation preparation x_type - if (representation isa OperatorRepresentation && linear_solver == \) - continue - end - outer_backends = [AutoForwardDiff(), AutoZygote()] - x = Float64.(1:6) - @testset "$((; linear_solver, backend, preparation, x_type))" begin - test_implicit( - outer_backends, - x; - representation, - backends=isnothing(backend) ? nothing : (; x=backend, y=backend), - preparation, - linear_solver, - strict=if linear_solver isa IterativeLinearSolver{:IterativeSolvers} - Val(false) - else - Val(true) - end, - ) - end - end +@testitem "IterativeSolvers" setup = [TestUtils] begin + using ADTypes, .TestUtils + for (backends, preparation, x) in Iterators.product( + [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + [nothing, ADTypes.ForwardOrReverseMode()], + [float.(1:3), reshape(float.(1:6), 3, 2)], + ) + yield() + scen = Scenario(; + solver=default_solver, + conditions=default_conditions, + x=x, + implicit_kwargs=(; + representation=OperatorRepresentation{:LinearMap}(), + linear_solver=IterativeLinearSolver{:IterativeSolvers}(), + backends, + preparation, + input_example=(x,), + ), + ) + @info "$scen" + test_implicit(scen) end end; diff --git a/test/utils.jl b/test/utils.jl index 5dd30d9b..e95a1c87 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,16 +1,37 @@ using ADTypes using ChainRulesCore using ChainRulesTestUtils -using DifferentiationInterface: DifferentiationInterface +import DifferentiationInterface as DI using ForwardDiff: ForwardDiff import ImplicitDifferentiation as ID using ImplicitDifferentiation: ImplicitFunction using JET using LinearAlgebra +using Random: rand! using Test using Zygote: Zygote, ZygoteRuleConfig -## +@kwdef struct Scenario{S,C,X,A,K} + solver::S + conditions::C + x::X + args::A = () + implicit_kwargs::K = (;) +end + +function Base.show(io::IO, scen::Scenario) + print( + io, + "Scenario(; solver=$(scen.solver), conditions=$(scen.conditions), x::$(typeof(scen.x))", + ) + if !isempty(scen.args) + print(io, ", args::$(typeof(scen.args))") + end + if !isempty(scen.implicit_kwargs) + print(io, ", implicit_kwargs::$(typeof(scen.implicit_kwargs))") + end + return print(io, ")") +end function identity_break_autodiff(x::AbstractArray{R}) where {R} float(first(x)) # break ForwardDiff @@ -18,134 +39,138 @@ function identity_break_autodiff(x::AbstractArray{R}) where {R} result = try throw(copy(x)) catch y - y + y # presumably break Enzyme end return result end -mysqrt(x::AbstractArray) = identity_break_autodiff(sqrt.(x)) - -## Various signatures - -function make_implicit_sqrt_byproduct(x; kwargs...) - forward(x) = 1 .* vcat(mysqrt(x), -mysqrt(x)), 1 - conditions(x, y, z) = abs2.(y ./ z) .- vcat(x, x) - input_example = (copy(x),) - implicit = ImplicitFunction(forward, conditions; input_example, kwargs...) - return implicit +struct NonDifferentiable{S} + solver::S end -function make_implicit_sqrt_args(x; kwargs...) - forward(x, p) = p .* vcat(mysqrt(x), -mysqrt(x)), nothing - conditions(x, y, z, p) = abs2.(y ./ p) .- vcat(x, x) - input_example = (copy(x), 2) - implicit = ImplicitFunction(forward, conditions; input_example, kwargs...) - return implicit -end +(nd::NonDifferentiable)(x, args...) = nd.solver(identity_break_autodiff(x), args...) -function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt_byproduct(x; kwargs...) - imf2 = make_implicit_sqrt_args(x; kwargs...) - - y_true = vcat(mysqrt(x), -mysqrt(x)) - y1, z1 = imf1(x) - y2, z2 = imf2(x, 3) - - @testset "Primal value" begin - @test y1 ≈ y_true - @test y2 ≈ 3y_true - @test z1 == 1 - @test z2 === nothing +function add_arg_mult(scen::Scenario, a=3) + @assert isempty(scen.args) + function solver_with_arg_mult(x, a) + y, z = scen.solver(x) + return y .* a, z end + function conditions_with_arg_mult(x, y, z, a) + return scen.conditions(x, y ./ a, z) + end + implicit_kwargs_with_arg_mult = NamedTuple( + Dict(k => if k == :input_example + (only(v), a) + else + v + end for (k, v) in pairs(scen.implicit_kwargs)) + ) + + return Scenario(; + solver=solver_with_arg_mult, + conditions=conditions_with_arg_mult, + x=scen.x, + args=(a,), + implicit_kwargs=implicit_kwargs_with_arg_mult, + ) end -tag(::AbstractArray{<:ForwardDiff.Dual{T}}) where {T} = T +function test_implicit_call(scen::Scenario) + implicit = ImplicitFunction( + NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... + ) + y, z = implicit(scen.x, scen.args...) + y_true, z_true = scen.solver(scen.x, scen.args...) -function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt_byproduct(x; kwargs...) - imf2 = make_implicit_sqrt_args(x; kwargs...) - - y_true = vcat(mysqrt(x), -mysqrt(x)) - dx = similar(x) - dx .= 2 * one(T) - x_and_dx = ForwardDiff.Dual.(x, dx) - - y_and_dy1, z1 = imf1(x_and_dx) - y_and_dy2, z2 = imf2(x_and_dx, 3) - - @testset "Dual numbers" begin - @test ForwardDiff.value.(y_and_dy1) ≈ y_true - @test ForwardDiff.value.(y_and_dy2) ≈ 3y_true - @test ForwardDiff.extract_derivative(tag(y_and_dy1), y_and_dy1) ≈ - 2 .* inv.(2 .* vcat(sqrt.(x), -sqrt.(x))) - @test ForwardDiff.extract_derivative(tag(y_and_dy2), y_and_dy2) ≈ - 3 .* 2 .* inv.(2 .* vcat(sqrt.(x), -sqrt.(x))) - @test z1 == 1 - @test z2 === nothing + @testset "Call" begin + @test y ≈ y_true + @test z == z_true end end -function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt_byproduct(x; kwargs...) - imf2 = make_implicit_sqrt_args(x; kwargs...) - - y_true = vcat(mysqrt(x), -mysqrt(x)) - dy = zero(y_true) - dy[1:(end ÷ 2)] .= one(eltype(y_true)) - dz = nothing - - (y1, z1), pb1 = rrule(rc, imf1, x) - (y2, z2), pb2 = rrule(rc, imf2, x, 3) - - dimf1, dx1 = pb1((dy, dz)) - dimf2, dx2, dp2 = pb2((dy, dz)) - - @testset "Pullbacks" begin - @test y1 ≈ y_true - @test y2 ≈ 3y_true - @test z1 == 1 - @test z2 === nothing +tag(::AbstractArray{<:ForwardDiff.Dual{T}}) where {T} = T - @test dimf1 isa NoTangent - @test dimf2 isa NoTangent +function test_implicit_duals(scen::Scenario) + implicit = ImplicitFunction( + NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... + ) + + dx = similar(scen.x) + rand!(dx) + x_and_dx = ForwardDiff.Dual.(scen.x, dx) + + y_and_dy, z = implicit(x_and_dx, scen.args...) + T = tag(y_and_dy) + y = ForwardDiff.value.(y_and_dy) + dy = ForwardDiff.extract_derivative.(T, y_and_dy) + + y_true, z_true = scen.solver(scen.x, scen.args...) + dy_true = DI.pushforward( + first ∘ scen.solver, + AutoForwardDiff(), + scen.x, + (dx,), + map(DI.Constant, scen.args)..., + )[1] - @test dx2 ≈ 3 .* dx1 - @test dp2 isa ChainRulesCore.NotImplemented + @testset "Duals" begin + @test y ≈ y_true + @test dy ≈ dy_true + @test z == z_true end end -## High-level tests per backend +function test_implicit_rrule(scen::Scenario) + implicit = ImplicitFunction( + NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... + ) + y_true, z_true = scen.solver(scen.x, scen.args...) -function test_implicit_backend( - outer_backend::ADTypes.AbstractADType, x::AbstractArray{T}; kwargs... -) where {T} - imf1 = make_implicit_sqrt_byproduct(x; kwargs...) - imf2 = make_implicit_sqrt_args(x; kwargs...) + dy = similar(y_true) + rand!(dy) + dz = NoTangent() + (y, z), pb = rrule(ZygoteRuleConfig(), implicit, scen.x, scen.args...) + dimpl, dx = pb((dy, dz)) - J1 = DifferentiationInterface.jacobian(first ∘ imf1, outer_backend, x) - J2 = DifferentiationInterface.jacobian(_x -> (first ∘ imf2)(_x, 3), outer_backend, x) + dx_true = DI.pullback( + first ∘ scen.solver, AutoZygote(), scen.x, (dy,), map(DI.Constant, scen.args)... + )[1] - J_true = ForwardDiff.jacobian(_x -> vcat(sqrt.(_x), -sqrt.(_x)), x) - - @testset "Exact Jacobian" begin - @test J1 ≈ J_true - @test J2 ≈ 3 .* J_true + @testset "ChainRule" begin + @test y ≈ y_true + @test z == z_true + @test dimpl isa NoTangent + @test dx ≈ dx_true end - return nothing end -function test_implicit(outer_backends, x; kwargs...) - @testset "Call" begin - test_implicit_call(x; kwargs...) +function test_implicit_jacobian(scen::Scenario, outer_backend::AbstractADType) + implicit = ImplicitFunction( + NonDifferentiable(scen.solver), scen.conditions; scen.implicit_kwargs... + ) + jac = DI.jacobian( + first ∘ implicit, outer_backend, scen.x, map(DI.Constant, scen.args)... + ) + jac_true = DI.jacobian( + first ∘ scen.solver, outer_backend, scen.x, map(DI.Constant, scen.args)... + ) + + @testset "Jacobian - $outer_backend" begin + @test jac ≈ jac_true end - @testset "Duals" begin - test_implicit_duals(x; kwargs...) - end - @testset "ChainRule" begin - test_implicit_rrule(ZygoteRuleConfig(), x; kwargs...) - end - @testset "Jacobian - $outer_backend" for outer_backend in outer_backends - test_implicit_backend(outer_backend, x; kwargs...) +end + +function test_implicit(scen::Scenario, outer_backends=[AutoForwardDiff(), AutoZygote()]) + @testset "$scen" begin + test_implicit_call(scen) + test_implicit_duals(scen) + test_implicit_rrule(scen) + for outer_backend in outer_backends + test_implicit_jacobian(scen, outer_backend) + end end - return nothing end + +default_solver(x) = vcat(sqrt.(x .+ 2), -sqrt.(x)), 2 +default_conditions(x, y, z) = abs2.(y) .- vcat(x .+ z, x) From dac0fd059e1aa9448e581ee241e8c4e793d1e6b8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 11 Jun 2025 10:09:20 +0200 Subject: [PATCH 2/5] Fix --- test/systematic.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/systematic.jl b/test/systematic.jl index add1f344..568d7d50 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -21,7 +21,6 @@ using TestItems ), ) scen2 = add_arg_mult(scen) - @info "$scen" test_implicit(scen) test_implicit(scen2) end @@ -47,7 +46,6 @@ end; input_example=(x,), ), ) - @info "$scen" scen2 = add_arg_mult(scen) test_implicit(scen) test_implicit(scen2) @@ -67,14 +65,13 @@ end; conditions=default_conditions, x=x, implicit_kwargs=(; - representation=OperatorRepresentation{:LinearMap}(), + representation=OperatorRepresentation{:LinearMaps}(), linear_solver=IterativeLinearSolver{:IterativeSolvers}(), backends, preparation, input_example=(x,), ), ) - @info "$scen" test_implicit(scen) end end; From dbf08f5006c4d29e0615ec798f178801be802d39 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 11 Jun 2025 10:30:15 +0200 Subject: [PATCH 3/5] Fix --- test/preparation.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/preparation.jl b/test/preparation.jl index d09f95fa..296a38c5 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -1,6 +1,7 @@ @testitem "Preparation" begin using ImplicitDifferentiation - using ADTypes: AutoForwardDiff, ForwardOrReverseMode, ForwardMode, ReverseMode + using ADTypes + using ADTypes: ForwardOrReverseMode, ForwardMode, ReverseMode using ForwardDiff: ForwardDiff using Zygote: Zygote using Test From 020f6e73092064f844a70dfdb5eecd429b0226b4 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:12:06 +0200 Subject: [PATCH 4/5] Fix --- test/preparation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/preparation.jl b/test/preparation.jl index 296a38c5..b4b2e2c9 100644 --- a/test/preparation.jl +++ b/test/preparation.jl @@ -12,7 +12,7 @@ input_example = (x,) @testset "None" begin - implicit_nones = ImplicitFunction(solver, conditions) + implicit_none = ImplicitFunction(solver, conditions) @test implicit_none.prep_A === nothing @test implicit_none.prep_Aᵀ === nothing @test implicit_none.prep_B === nothing From f34e63459a6f1f4921221337f844d430e6369b1e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 11 Jun 2025 11:55:27 +0200 Subject: [PATCH 5/5] Coverage --- test/printing.jl | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 test/printing.jl diff --git a/test/printing.jl b/test/printing.jl new file mode 100644 index 00000000..4e03d04e --- /dev/null +++ b/test/printing.jl @@ -0,0 +1,6 @@ +using TestItems + +@testitem "Settings" begin + @test startswith(string(OperatorRepresentation()), "Operator") + @test startswith(string(IterativeLinearSolver(; atol=1e-5)), "Iterative") +end