diff --git a/Project.toml b/Project.toml index 63743d10..0b0d4c35 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek"] -version = "0.8.0" +version = "0.8.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/faq.md b/docs/src/faq.md index 3e7e4edf..bec4b2bf 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -46,9 +46,6 @@ Say your forward mapping takes multiple inputs and returns multiple outputs, suc The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single a `ComponentVector`, and do the same for all the outputs. See the examples for a demonstration. -!!! warning - The default linear operator representation does not support ComponentArrays.jl: you need to select `representation=OperatorRepresentation{:LinearMaps}()` in the [`ImplicitFunction`](@ref) constructor for it to work. - !!! warning You may run into issues trying to differentiate through the `ComponentVector` constructor. For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`. diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index dd43ddd2..a1cf4e89 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -42,13 +42,9 @@ function conditions_components(x::ComponentVector, y::ComponentVector, _z) return c end; -# And build your implicit function like so, switching the operator representation to avoid errors with ComponentArrays. +# And build your implicit function like so: -implicit_components = ImplicitFunction( - forward_components, - conditions_components; - representation=OperatorRepresentation{:LinearMaps}(), -); +implicit_components = ImplicitFunction(forward_components, conditions_components); # Now we're good to go. diff --git a/src/execution.jl b/src/execution.jl index 1190bd69..3d5d00e3 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -61,7 +61,7 @@ function build_A_aux( end function build_A_aux( - ::OperatorRepresentation{package,symmetric,hermitian}, + ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type}, implicit, x, y, @@ -69,7 +69,7 @@ function build_A_aux( c, args...; suggested_backend, -) where {package,symmetric,hermitian} +) where {package,symmetric,hermitian,posdef,keep_input_type} T = Base.promote_eltype(x, y, c) (; conditions, backends, prep_A) = implicit actual_backend = isnothing(backends) ? suggested_backend : backends.y @@ -89,7 +89,13 @@ function build_A_aux( prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts) if package == :LinearOperators return LinearOperator( - T, length(c), length(y), symmetric, hermitian, prod!; S=typeof(dy_vec) + T, + length(c), + length(y), + symmetric, + hermitian, + prod!; + S=keep_input_type ? typeof(dy_vec) : Vector{T}, ) elseif package == :LinearMaps return FunctionMap{T}( @@ -99,6 +105,7 @@ function build_A_aux( ismutating=true, issymmetric=symmetric, ishermitian=hermitian, + isposdef=posdef, ) end end @@ -136,7 +143,7 @@ function build_Aᵀ_aux( end function build_Aᵀ_aux( - ::OperatorRepresentation{package,symmetric,hermitian}, + ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type}, implicit, x, y, @@ -144,7 +151,7 @@ function build_Aᵀ_aux( c, args...; suggested_backend, -) where {package,symmetric,hermitian} +) where {package,symmetric,hermitian,posdef,keep_input_type} T = Base.promote_eltype(x, y, c) (; conditions, backends, prep_Aᵀ) = implicit actual_backend = isnothing(backends) ? suggested_backend : backends.y @@ -164,7 +171,13 @@ function build_Aᵀ_aux( prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts) if package == :LinearOperators return LinearOperator( - T, length(y), length(c), symmetric, hermitian, prod!; S=typeof(dc_vec) + T, + length(y), + length(c), + symmetric, + hermitian, + prod!; + S=keep_input_type ? typeof(dc_vec) : Vector{T}, ) elseif package == :LinearMaps return FunctionMap{T}( @@ -174,6 +187,7 @@ function build_Aᵀ_aux( ismutating=true, issymmetric=symmetric, ishermitian=hermitian, + isposdef=posdef, ) end end diff --git a/src/settings.jl b/src/settings.jl index ca791114..5c4ba95a 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -12,7 +12,7 @@ Callable object that can solve linear systems `Ax = b` and `AX = B` in the same The type parameter `package` can be either: -- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default) +- `:Krylov` to use the solver `gmres` or `block_gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default) - `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl) Keyword arguments are passed on to the respective solver. @@ -94,36 +94,46 @@ Specify that the matrix `A` involved in the implicit function theorem should be # Constructors - OperatorRepresentation(; symmetric=false, hermitian=false) - OperatorRepresentation{package}(; symmetric=false, hermitian=false) + OperatorRepresentation(; + symmetric=false, hermitian=false, posdef=false, keep_input_type=false + ) + OperatorRepresentation{package}(; + symmetric=false, hermitian=false, posdef=false, keep_input_type=false + ) The type parameter `package` can be either: - `:LinearOperators` to use a wrapper from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (the default) - `:LinearMaps` to use a wrapper from [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl) -The keyword arguments `symmetric` and `hermitian` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, in case you can prove them. +The keyword arguments `symmetric`, `hermitian` and `posdef` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, which are useful to the solver in case you can prove them. + +The keyword argument `keep_input_type` dictates whether to force the linear operator to work with the provided input type, or fall back on a default. # See also - [`ImplicitFunction`](@ref) - [`MatrixRepresentation`](@ref) """ -struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresentation +struct OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type} <: + AbstractRepresentation function OperatorRepresentation{package}(; - symmetric::Bool=false, hermitian::Bool=false + symmetric::Bool=false, + hermitian::Bool=false, + posdef::Bool=false, + keep_input_type::Bool=false, ) where {package} @assert package in [:LinearOperators, :LinearMaps] - return new{package,symmetric,hermitian}() + return new{package,symmetric,hermitian,posdef,keep_input_type}() end end function Base.show( - io::IO, ::OperatorRepresentation{package,symmetric,hermitian} -) where {package,symmetric,hermitian} + io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type} +) where {package,symmetric,hermitian,posdef,keep_input_type} return print( io, - "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian)", + "OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef, keep_input_type=$keep_input_type)", ) end