Skip to content

Compatibility with ComponentArrays #152

@gdalle

Description

@gdalle

Opening this issue to keep track of the problem reported in #150.

@benjaminfaber, can you check what the new bug is with your MWE (minimal working example) below?

using ImplicitDifferentiation
using Enzyme
using ComponentArrays

"""
    forward_components_aux(a, b, m)

Compute the forward solution for the component example.

Returns the tuple `(d, e)`, where `d = m .* sqrt.(a)` and `e = sqrt.(b)`.
Both are computed element-wise. A negative entry in `a` or `b` throws a
`DomainError` from `sqrt`.
"""
function forward_components_aux(a::AbstractVector, b::AbstractVector, m::Number)
    scaled_roots = @. m * sqrt(a)
    plain_roots = sqrt.(b)
    return (scaled_roots, plain_roots)
end

"""
    conditions_components_aux(a, b, m, d, e)

Evaluate the implicit conditions (residuals) of the component example.

Returns `(c_d, c_e)`, where:

- `c_d = (d ./ m) .^ 2 .- a`
- `c_e = e .^ 2 .- b`

Both residuals are zero when `d = m .* sqrt.(a)` and `e = sqrt.(b)`, which is
the output of `forward_components_aux`.
"""
function conditions_components_aux(a, b, m, d, e)
    residual_d = @. (d / m)^2 - a
    residual_e = @. e^2 - b
    return (residual_d, residual_e)
end;

"""
    forward_components(x::ComponentVector)

Run `forward_components_aux` on the components `x.a`, `x.b` and `x.m`.

Returns the results packed into a `ComponentVector` with fields `d` and `e`.
"""
function forward_components(x::ComponentVector)
    roots = forward_components_aux(x.a, x.b, x.m)
    return ComponentVector(; d=roots[1], e=roots[2])
end

"""
    conditions_components(x::ComponentVector, y::ComponentVector)

Evaluate `conditions_components_aux` on the inputs `x` (fields `a`, `b`, `m`)
and the candidate solution `y` (fields `d`, `e`).

Returns the residuals as a `ComponentVector` with fields `c_d` and `c_e`.
"""
function conditions_components(x::ComponentVector, y::ComponentVector)
    residuals = conditions_components_aux(x.a, x.b, x.m, y.d, y.e)
    return ComponentVector(; c_d=first(residuals), c_e=last(residuals))
end;

# Wrap the forward solver together with its residual conditions as an implicit function.
# Presumably ImplicitDifferentiation differentiates `forward_components` via the
# conditions rather than through its body — TODO confirm against the package docs.
implicit_components = ImplicitFunction(forward_components, conditions_components);

# Input point: `a` (length 2), `b` (length 3) and the scalar `m`, packed into one
# ComponentVector with 6 entries in total.
a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0
x = ComponentVector(; a=a, b=b, m=m)
# Build one tangent per entry of `x`: the i-th tangent is one-hot, with a 1 at
# linear position i, so the batched forward pass seeds every input direction.
dx_zero = Enzyme.make_zero(x)
dx = Vector{typeof(x)}(undef, length(x))

# `dx_zero` is a reused scratch buffer: reset it, set the i-th entry, then store a
# copy so that each tangent owns independent storage. After the loop, `dx_zero`
# is left one-hot at the last position.
for i in eachindex(dx)
    fill!(dx_zero, 0.)
    dx_zero[i] = 1.
    dx[i] = copy(dx_zero)
end
# Convert the batch of tangents to a tuple before handing it to BatchDuplicated.
dx = tuple(dx...)

# Batched forward-mode autodiff of the implicit function. This is the call that
# triggers the ComponentArrays incompatibility tracked in this issue.
Enzyme.autodiff(Enzyme.Forward, implicit_components, Enzyme.BatchDuplicated, Enzyme.BatchDuplicated(x, dx))

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions