Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek"]
version = "0.8.0"
version = "0.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
3 changes: 0 additions & 3 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ Say your forward mapping takes multiple inputs and returns multiple outputs, suc
The trick is to leverage [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl) to wrap all the inputs inside a single a `ComponentVector`, and do the same for all the outputs.
See the examples for a demonstration.

!!! warning
The default linear operator representation does not support ComponentArrays.jl: you need to select `representation=OperatorRepresentation{:LinearMaps}()` in the [`ImplicitFunction`](@ref) constructor for it to work.

!!! warning
You may run into issues trying to differentiate through the `ComponentVector` constructor.
For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`.
Expand Down
8 changes: 2 additions & 6 deletions examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,9 @@ function conditions_components(x::ComponentVector, y::ComponentVector, _z)
return c
end;

# And build your implicit function like so, switching the operator representation to avoid errors with ComponentArrays.
# And build your implicit function like so:

implicit_components = ImplicitFunction(
forward_components,
conditions_components;
representation=OperatorRepresentation{:LinearMaps}(),
);
implicit_components = ImplicitFunction(forward_components, conditions_components);

# Now we're good to go.

Expand Down
26 changes: 20 additions & 6 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ function build_A_aux(
end

function build_A_aux(
::OperatorRepresentation{package,symmetric,hermitian},
::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type},
implicit,
x,
y,
z,
c,
args...;
suggested_backend,
) where {package,symmetric,hermitian}
) where {package,symmetric,hermitian,posdef,keep_input_type}
T = Base.promote_eltype(x, y, c)
(; conditions, backends, prep_A) = implicit
actual_backend = isnothing(backends) ? suggested_backend : backends.y
Expand All @@ -89,7 +89,13 @@ function build_A_aux(
prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts)
if package == :LinearOperators
return LinearOperator(
T, length(c), length(y), symmetric, hermitian, prod!; S=typeof(dy_vec)
T,
length(c),
length(y),
symmetric,
hermitian,
prod!;
S=keep_input_type ? typeof(dy_vec) : Vector{T},
)
elseif package == :LinearMaps
return FunctionMap{T}(
Expand All @@ -99,6 +105,7 @@ function build_A_aux(
ismutating=true,
issymmetric=symmetric,
ishermitian=hermitian,
isposdef=posdef,
)
end
end
Expand Down Expand Up @@ -136,15 +143,15 @@ function build_Aᵀ_aux(
end

function build_Aᵀ_aux(
::OperatorRepresentation{package,symmetric,hermitian},
::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type},
implicit,
x,
y,
z,
c,
args...;
suggested_backend,
) where {package,symmetric,hermitian}
) where {package,symmetric,hermitian,posdef,keep_input_type}
T = Base.promote_eltype(x, y, c)
(; conditions, backends, prep_Aᵀ) = implicit
actual_backend = isnothing(backends) ? suggested_backend : backends.y
Expand All @@ -164,7 +171,13 @@ function build_Aᵀ_aux(
prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts)
if package == :LinearOperators
return LinearOperator(
T, length(y), length(c), symmetric, hermitian, prod!; S=typeof(dc_vec)
T,
length(y),
length(c),
symmetric,
hermitian,
prod!;
S=keep_input_type ? typeof(dc_vec) : Vector{T},
)
elseif package == :LinearMaps
return FunctionMap{T}(
Expand All @@ -174,6 +187,7 @@ function build_Aᵀ_aux(
ismutating=true,
issymmetric=symmetric,
ishermitian=hermitian,
isposdef=posdef,
)
end
end
Expand Down
30 changes: 20 additions & 10 deletions src/settings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Callable object that can solve linear systems `Ax = b` and `AX = B` in the same

The type parameter `package` can be either:

- `:Krylov` to use the solver `gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default)
- `:Krylov` to use the solver `gmres` or `block_gmres` from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) (the default)
- `:IterativeSolvers` to use the solver `gmres` from [IterativeSolvers.jl](https://github.com/JuliaLinearAlgebra/IterativeSolvers.jl)

Keyword arguments are passed on to the respective solver.
Expand Down Expand Up @@ -94,36 +94,46 @@ Specify that the matrix `A` involved in the implicit function theorem should be

# Constructors

OperatorRepresentation(; symmetric=false, hermitian=false)
OperatorRepresentation{package}(; symmetric=false, hermitian=false)
OperatorRepresentation(;
symmetric=false, hermitian=false, posdef=false, keep_input_type=false
)
OperatorRepresentation{package}(;
symmetric=false, hermitian=false, posdef=false, keep_input_type=false
)

The type parameter `package` can be either:

- `:LinearOperators` to use a wrapper from [LinearOperators.jl](https://github.com/JuliaSmoothOptimizers/LinearOperators.jl) (the default)
- `:LinearMaps` to use a wrapper from [LinearMaps.jl](https://github.com/JuliaLinearAlgebra/LinearMaps.jl)

The keyword arguments `symmetric` and `hermitian` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, in case you can prove them.
The keyword arguments `symmetric`, `hermitian` and `posdef` give additional properties of the Jacobian of the `conditions` with respect to the solution `y`, which are useful to the solver in case you can prove them.

The keyword argument `keep_input_type` dictates whether to force the linear operator to work with the provided input type, or fall back on a default.

# See also

- [`ImplicitFunction`](@ref)
- [`MatrixRepresentation`](@ref)
"""
struct OperatorRepresentation{package,symmetric,hermitian} <: AbstractRepresentation
struct OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type} <:
AbstractRepresentation
function OperatorRepresentation{package}(;
symmetric::Bool=false, hermitian::Bool=false
symmetric::Bool=false,
hermitian::Bool=false,
posdef::Bool=false,
keep_input_type::Bool=false,
) where {package}
@assert package in [:LinearOperators, :LinearMaps]
return new{package,symmetric,hermitian}()
return new{package,symmetric,hermitian,posdef,keep_input_type}()
end
end

function Base.show(
io::IO, ::OperatorRepresentation{package,symmetric,hermitian}
) where {package,symmetric,hermitian}
io::IO, ::OperatorRepresentation{package,symmetric,hermitian,posdef,keep_input_type}
) where {package,symmetric,hermitian,posdef,keep_input_type}
return print(
io,
"OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian)",
"OperatorRepresentation{$(repr(package))}(; symmetric=$symmetric, hermitian=$hermitian, posdef=$posdef, keep_input_type=$keep_input_type)",
)
end

Expand Down
Loading