FAQ
Supported autodiff backends
To differentiate through an `ImplicitFunction`, the following backends are supported.
| Backend | Forward mode | Reverse mode |
|---|---|---|
| ForwardDiff.jl | yes | - |
| ChainRules.jl-compatible | no | yes |
| Enzyme.jl | yes | soon |
By default, the conditions are differentiated with the same "outer" backend that is used to differentiate the `ImplicitFunction` itself. However, this can be switched to any other "inner" backend compatible with DifferentiationInterface.jl (i.e. a subtype of `ADTypes.AbstractADType`).
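For concreteness, here is a minimal sketch (loosely following the package's introductory square-root example) of building an `ImplicitFunction` and differentiating it with ForwardDiff.jl. Constructor options, including how to select the inner backend, may vary between package versions, so check the `ImplicitFunction` docstring.

```julia
using ForwardDiff, ImplicitDifferentiation

forward(x) = sqrt.(x)              # forward mapping: "solves" y .^ 2 == x
conditions(x, y) = y .^ 2 .- x     # residual that vanishes at the solution

implicit = ImplicitFunction(forward, conditions)

x = rand(3)
implicit(x)                        # ≈ sqrt.(x)
ForwardDiff.jacobian(implicit, x)  # ≈ Diagonal(0.5 ./ sqrt.(x))
```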
Input and output types
Vectors
Functions that eat or spit out arbitrary vectors are supported, as long as the forward mapping and conditions return vectors of the same size.
If you deal with small vectors (say, less than 100 elements), consider using StaticArrays.jl for increased performance.
Arrays
Functions that eat or spit out matrices and higher-order tensors are not supported. You can use `vec` and `reshape` to convert to and from vectors.
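For example, a problem whose natural unknown is a matrix can be wrapped so that the implicit function only ever sees flat vectors. In the sketch below, `forward_mat` and `conditions_mat` are hypothetical matrix-valued functions; `vec` and `reshape` do the actual conversion.

```julia
# Hypothetical matrix-valued problem: Y = sqrt.(X) elementwise for an m×n matrix X.
forward_mat(X) = sqrt.(X)
conditions_mat(X, Y) = Y .^ 2 .- X

m, n = 2, 3

# Vector-in, vector-out wrappers suitable for an ImplicitFunction.
forward(x_vec) = vec(forward_mat(reshape(x_vec, m, n)))
conditions(x_vec, y_vec) = vec(conditions_mat(reshape(x_vec, m, n), reshape(y_vec, m, n)))
```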
Scalars
Functions that eat or spit out a single number are not supported. The forward mapping and the conditions need vectors: instead of returning `val`, you should return `[val]` (a 1-element `Vector`), or better yet, wrap it in a static vector: `SVector(val)`.
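For instance, a scalar fixed-point equation such as y = cos(x * y) can be written with 1-element static vectors (the function bodies below are a made-up illustration):

```julia
using StaticArrays

# Hypothetical scalar problem: find y such that y == cos(x * y).
# Both functions take and return 1-element vectors instead of plain numbers.
function forward(x::AbstractVector)
    y = 1.0
    for _ in 1:100  # crude fixed-point iteration
        y = cos(only(x) * y)
    end
    return SVector(y)
end

conditions(x, y) = SVector(only(y) - cos(only(x) * only(y)))
```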
Sparse arrays
Sparse arrays are not supported out of the box and might yield incorrect values!
If your use case involves sparse arrays, it is best to differentiate with respect to the dense vector of values and only construct the sparse array inside the `forward` and `conditions` functions.
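As a sketch, suppose the forward mapping solves a sparse linear system A(x) * y = b, where `x` is the dense vector of nonzero values of A and the sparsity pattern is fixed. The pattern `Is`, `Js` and the right-hand side `b` below are made up; note that differentiating the conditions through the `sparse` constructor may still require care depending on the inner backend.

```julia
using SparseArrays

# Hypothetical fixed sparsity pattern and right-hand side.
Is, Js = [1, 2, 3, 1], [1, 2, 3, 3]   # row and column indices of the nonzeros
b = [1.0, 2.0, 3.0]

# x is the dense vector of nonzero values; the sparse matrix only exists inside.
forward(x) = sparse(Is, Js, x, 3, 3) \ b
conditions(x, y) = sparse(Is, Js, x, 3, 3) * y .- b
```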
Number of inputs and outputs
Most of the documentation is written for the simple case where the forward mapping is `x -> y`, i.e. one input and one output. What can you do to handle multiple inputs or outputs? Well, it depends on whether you want their derivatives or not.
| | Derivatives needed | Derivatives not needed |
|---|---|---|
| Multiple inputs | Make `x` a `ComponentVector` | Supply `args` and `kwargs` to `forward` |
| Multiple outputs | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct `z` |
We now detail each of these options.
Multiple inputs or outputs | Derivatives needed
Say your forward mapping takes multiple inputs and returns multiple outputs, such that you want derivatives for all of them.
The trick is to leverage ComponentArrays.jl to wrap all the inputs inside a single `ComponentVector`, and to do the same for all the outputs. See the examples for a demonstration.
You may run into issues trying to differentiate through the `ComponentVector` constructor. For instance, Zygote.jl will throw `ERROR: Mutating arrays is not supported`. Check out this issue for a dirty workaround involving custom chain rules for the constructor.
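As a sketch, two named inputs `a` and `b` can be packed into a single `ComponentVector` and unpacked by name inside the functions (the problem itself is made up):

```julia
using ComponentArrays

# Hypothetical problem with two named inputs: solve y such that y .^ 2 == a .* b.
forward(x::ComponentVector) = sqrt.(x.a .* x.b)
conditions(x::ComponentVector, y) = y .^ 2 .- x.a .* x.b

x = ComponentVector(a=[1.0, 2.0], b=[3.0, 4.0])
forward(x)  # 2-element vector
```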
Multiple inputs | Derivatives not needed
If your forward mapping (or conditions) takes multiple inputs but you don't care about their derivatives, you can add further positional and keyword arguments beyond `x`. It is important to make sure that the forward mapping and the conditions accept the same set of arguments, even if each of these functions only uses a subset of them.
```julia
forward(x, arg1, arg2; kwarg1, kwarg2) = y
conditions(x, y, arg1, arg2; kwarg1, kwarg2) = c
```
All of the positional and keyword arguments apart from `x` will get zero tangents during differentiation of the implicit function.
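For instance, an iteration budget and a tolerance fit naturally in this pattern (the solver below is a made-up Newton iteration); note that `conditions` accepts the same extra arguments even though it ignores them:

```julia
# Hypothetical: solve y .^ 2 == x by Newton iteration, with hyperparameters
# passed as extra arguments that receive zero tangents.
function forward(x, maxiter; tol)
    y = copy(x)
    for _ in 1:maxiter
        y = y .- (y .^ 2 .- x) ./ (2 .* y)
        maximum(abs, y .^ 2 .- x) < tol && break
    end
    return y
end

conditions(x, y, maxiter; tol) = y .^ 2 .- x  # same signature, extra arguments unused
```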
Multiple outputs | Derivatives not needed
The last and trickiest situation is when your forward mapping returns multiple outputs but you only care about some of their derivatives. Then, you need to group the objects you don't want to differentiate into a "byproduct" `z`, returned alongside the actual output `y`. This way, derivatives of `z` will not be computed: the byproduct is considered constant during differentiation.
The signatures of your functions will need to be slightly different from the previous cases:
```julia
forward(x, arg1, arg2; kwarg1, kwarg2) = (y, z)
conditions(x, y, z, arg1, arg2; kwarg1, kwarg2) = c
```
See the examples for a demonstration.
This is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the conditions. In that case, you may want to write the conditions differentiation rules yourself. A more advanced application is given by DifferentiableFrankWolfe.jl.
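As a sketch, the byproduct can carry solver diagnostics or reusable objects; the made-up example below returns the iteration count of a Newton solve as `z`:

```julia
# Hypothetical: solve y .^ 2 == x by Newton iteration and return the iteration
# count as a byproduct z, which is treated as constant during differentiation.
function forward(x; tol=1e-10)
    y = copy(x)
    iterations = 0
    while maximum(abs, y .^ 2 .- x) > tol
        y = y .- (y .^ 2 .- x) ./ (2 .* y)
        iterations += 1
    end
    return y, iterations
end

conditions(x, y, z; tol=1e-10) = y .^ 2 .- x  # z is available but never differentiated
```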
Modeling tips
Writing conditions
We recommend that the conditions themselves do not involve calls to autodiff, even when they describe a gradient. Otherwise, you will need to make sure that nested autodiff works well in your case (i.e. that the "outer" backend can differentiate through the "inner" backend). For instance, if you're differentiating your implicit function (and hence your conditions) in reverse mode with Zygote.jl, you may want to use ForwardDiff.jl to compute the gradients inside the conditions.
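For example, if `y` is defined as a minimizer of an inner objective f(x, ·), the conditions are the gradient of f with respect to y. Writing that gradient in closed form is preferable, but if you compute it with autodiff, doing so with ForwardDiff.jl keeps the conditions easy to differentiate again with a reverse-mode outer backend such as Zygote.jl. A sketch, with a made-up objective:

```julia
using ForwardDiff

# Hypothetical inner objective: y is defined as a minimizer of f(x, ·).
f(x, y) = 0.5 * sum((y .- x) .^ 4)

# Optimality conditions ∇_y f(x, y) = 0, computed with ForwardDiff rather than
# with the reverse-mode outer backend, to avoid reverse-over-reverse nesting.
conditions(x, y) = ForwardDiff.gradient(ỹ -> f(x, ỹ), y)
```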
Dealing with constraints
To express constrained optimization problems as implicit functions, you might need differentiable projections or proximal operators to write the optimality conditions. See Efficient and modular implicit differentiation for precise formulations.
If these operators are too complicated to code yourself, look for existing implementations in the Julia ecosystem.
An alternative is differentiating through the KKT conditions, which is exactly what DiffOpt.jl does for JuMP models.
Memoization
In some cases, performance might be increased by using memoization to prevent redundant calls to `forward`. For instance, this is relevant when computing large Jacobians with forward-mode differentiation, where the computation happens in chunks. Packages such as Memoize.jl and Memoization.jl are useful for defining a memoized version of `forward`:
```julia
using Memoize

# Cache results in a Dict so repeated calls with the same arguments reuse the stored value.
@memoize Dict forward(x, args...; kwargs...) = y
```