Tricks

This page demonstrates several features that may come in handy in more advanced use cases.

using ComponentArrays
using ForwardDiff
using ImplicitDifferentiation
using Krylov
using LinearAlgebra
using Zygote

ComponentArrays

Use ComponentArrays.jl when you need derivatives with respect to multiple inputs or outputs.

"""
    forward_components_aux(a, b, m)

Closed-form solver for the toy system: return the pair `(m .* sqrt.(a), sqrt.(b))`.
"""
function forward_components_aux(a::AbstractVector, b::AbstractVector, m::Number)
    scaled_roots = m .* sqrt.(a)
    plain_roots = sqrt.(b)
    return scaled_roots, plain_roots
end

"""
    conditions_components_aux(a, b, m, d, e)

Residuals of the toy system: `(d ./ m) .^ 2 .- a` and `e .^ 2 .- b`.
Both vanish at the solution returned by `forward_components_aux(a, b, m)`.
"""
function conditions_components_aux(a, b, m, d, e)
    residual_d = @. (d / m)^2 - a
    residual_e = @. e^2 - b
    return residual_d, residual_e
end;

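You can use `ComponentVector` from ComponentArrays.jl as intermediate storage. Pack all inputs (and all outputs) into a single named vector, then unpack the fields inside the forward and condition functions.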

"""
    forward_components(x::ComponentVector)

Forward solver on packed inputs. Unpacks `x.a`, `x.b` and `x.m`, solves with
`forward_components_aux`, and returns `(y, nothing)`. Here `y` is a `ComponentVector`
with fields `d` and `e`. The second value is the (empty) auxiliary output `z`, which
`conditions_components` receives as `_z` and ignores.
"""
function forward_components(x::ComponentVector)
    d, e = forward_components_aux(x.a, x.b, x.m)
    return ComponentVector(; d, e), nothing
end

"""
    conditions_components(x::ComponentVector, y::ComponentVector, _z)

Evaluate the conditions on packed inputs `x` (fields `a`, `b`, `m`) and packed outputs
`y` (fields `d`, `e`). Returns the residuals as a `ComponentVector` with fields `c_d`
and `c_e`. The auxiliary argument `_z` is unused.
"""
function conditions_components(x::ComponentVector, y::ComponentVector, _z)
    residuals = conditions_components_aux(x.a, x.b, x.m, y.d, y.e)
    return ComponentVector(; c_d=residuals[1], c_e=residuals[2])
end;

Then build your implicit function like so:

# Combine the forward solver and its conditions into one callable implicit function.
implicit_components = ImplicitFunction(
    forward_components,
    conditions_components,
);

Now we're good to go.

a = [1.0, 2.0]
b = [3.0, 4.0, 5.0]
m = 6.0
x = ComponentVector(; a, b, m)
implicit_components(x)
((d = [6.0, 8.485281374238571], e = [1.7320508075688772, 2.0, 2.23606797749979]), nothing)

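And it works with both ForwardDiff.jl and Zygote.jl. Each Jacobian has 5 rows (the 2 entries of `d` followed by the 3 entries of `e`). It has 6 columns (the 2 entries of `a`, the 3 entries of `b`, then `m`).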

# Differentiate only the first output `y`; the auxiliary `nothing` is dropped.
ForwardDiff.jacobian(v -> first(implicit_components(v)), x)
5×6 Matrix{Float64}:
 3.0  0.0      0.0       0.0   0.0       1.0
 0.0  2.12132  0.0       0.0   0.0       1.41421
 0.0  0.0      0.288675  0.0   0.0       0.0
 0.0  0.0      0.0       0.25  0.0       0.0
 0.0  0.0      0.0       0.0   0.223607  0.0
first(Zygote.jacobian(v -> first(implicit_components(v)), x))
5×6 Matrix{Float64}:
 3.0  0.0      0.0       0.0   0.0       1.0
 0.0  2.12132  0.0       0.0   0.0       1.41421
 0.0  0.0      0.288675  0.0   0.0       0.0
 0.0  0.0      0.0       0.25  0.0       0.0
 0.0  0.0      0.0       0.0   0.223607  0.0
"""
    full_pipeline(a, b, m)

Call `implicit_components` on unpacked inputs. Packs `a`, `b` and `m` into a
`ComponentVector`, discards the auxiliary second output, and returns the tuple
`(y.d, y.e)`. Presumably meant as a plain-array entry point, e.g. for
differentiating with respect to `a`, `b` and `m` separately (not shown on this page).
"""
function full_pipeline(a, b, m)
    sol, _ = implicit_components(ComponentVector(; a, b, m))
    return (sol.d, sol.e)
end;

This page was generated using Literate.jl.