diff --git a/src/settings.jl b/src/settings.jl index c5a3dcf1..3637e17d 100644 --- a/src/settings.jl +++ b/src/settings.jl @@ -19,7 +19,10 @@ Specify that linear systems `Ax = b` should be solved with a direct method. struct DirectLinearSolver <: AbstractSolver end function (solver::DirectLinearSolver)( - A::Union{AbstractMatrix,Factorization}, _Aᵀ, b::AbstractVector, x0::AbstractVector + A::Union{AbstractMatrix,Factorization,Number}, + _Aᵀ, + b::AbstractVector, + x0::AbstractVector, ) return A \ b end diff --git a/test/systematic.jl b/test/systematic.jl index cbb70055..cfc8abc9 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -19,6 +19,23 @@ using TestItems test_implicit(scen) test_implicit(scen2) end + + # Test for output vector of length 1 + for (linear_solver, backends) in Iterators.product( + [DirectLinearSolver(), IterativeLinearSolver()], + [nothing, (; x=AutoForwardDiff(), y=AutoZygote())], + ) + yield() + scen = Scenario(; + solver=x -> (sqrt.(x), nothing), + conditions=(x, y, z) -> y .^ 2 .- x, + x=[1.0], + implicit_kwargs=(; representation, linear_solver, backends), + ) + scen2 = add_arg_mult(scen) + test_implicit(scen) + test_implicit(scen2) + end end; @testitem "Operator" setup = [TestUtils] begin