Tutorial

Necessary imports

using DifferentiableFrankWolfe: DiffFW, simplex_projection
using ForwardDiff: ForwardDiff
using FrankWolfe: ProbabilitySimplexOracle
using ProximalOperators: ProximalOperators
using Test: @test
using Zygote: Zygote

Constructing the wrapper

f(x, θ) = 0.5 * sum(abs2, x - θ)  # minimizing half the squared distance between x and θ...
f_grad1(x, θ) = x - θ  # (gradient of f with respect to its first argument x)
lmo = ProbabilitySimplexOracle(1.0)  # ... over the probability simplex...
dfw = DiffFW(f, f_grad1, lmo);  # ... is the Euclidean projection of θ onto the simplex (which returns θ itself if θ already lies in it)
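The wrapper receives the objective and, separately, a hand-written gradient. Assuming, as the name f_grad1 suggests, that DiffFW expects the gradient with respect to the first argument x, a cheap sanity check is to compare f_grad1 against central finite differences. The helper fd_grad_x below is hypothetical, written only for this check; it is not part of DifferentiableFrankWolfe.jl.

```julia
# Same objective and hand-written gradient as above (redefined so this block stands alone)
f(x, θ) = 0.5 * sum(abs2, x - θ)
f_grad1(x, θ) = x - θ

# Hypothetical helper: central finite-difference gradient of f with respect to x
function fd_grad_x(f, x, θ; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(length(x))
        e[i] = h                                   # perturb only coordinate i
        g[i] = (f(x + e, θ) - f(x - e, θ)) / (2h)  # central difference
    end
    return g
end

x_test = [0.2, 0.5, 0.3]
θ_test = [1.0, 1.5, 0.2]
g_fd = fd_grad_x(f, x_test, θ_test)
g_exact = f_grad1(x_test, θ_test)
println(maximum(abs.(g_fd - g_exact)))  # should be tiny (floating-point rounding only)
```

For a quadratic objective like this one, the central difference is exact up to rounding, so a visible mismatch would point to a wrong gradient rather than to the step size h.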

Calling the wrapper

x0 = ones(3) ./ 3  # starting point: the barycenter of the simplex
θ = [1.0, 1.5, 0.2]
3-element Vector{Float64}:
 1.0
 1.5
 0.2
frank_wolfe_kwargs = (; max_iteration=100, epsilon=1e-4)  # options for the underlying Frank-Wolfe run
y = dfw(θ, x0; frank_wolfe_kwargs...)
3-element Vector{Float64}:
 0.25000000000000006
 0.75
 0.0
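For intuition about what the call above computes, here is a minimal, textbook Frank-Wolfe loop for the same problem, written from scratch. It is a sketch under simplifying assumptions and not the algorithm DiffFW actually runs (that is delegated to FrankWolfe.jl, which may use another variant and step-size rule): over the probability simplex, the linear minimization oracle returns the vertex e_i with the smallest gradient coordinate, and the step size follows the classical open-loop rule 2/(t+2). The function name vanilla_frank_wolfe is hypothetical.

```julia
# Hypothetical sketch of vanilla Frank-Wolfe over the probability simplex
function vanilla_frank_wolfe(grad, x0::Vector{Float64}; max_iteration::Int=1000)
    x = copy(x0)
    for t in 0:(max_iteration - 1)
        g = grad(x)
        i = argmin(g)               # LMO: the minimizing vertex is e_i with i = argmin_i g_i
        v = zeros(length(x))
        v[i] = 1.0
        γ = 2 / (t + 2)             # open-loop step size; γ = 1 at t = 0
        x = (1 - γ) .* x .+ γ .* v  # a convex combination stays inside the simplex
    end
    return x
end

θ = [1.0, 1.5, 0.2]
x0 = ones(3) ./ 3
x_fw = vanilla_frank_wolfe(x -> x - θ, x0)  # gradient of 0.5 * ||x - θ||^2 with respect to x
println(x_fw)  # approaches the projection [0.25, 0.75, 0.0], but only slowly
```

With this step-size rule the objective gap is only guaranteed to shrink like O(1/t), which is one reason iteration limits and tolerances such as frank_wolfe_kwargs matter in practice.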
true_simplex_projection(x) = ProximalOperators.prox(ProximalOperators.IndSimplex(1.0), x)[1]  # prox returns (y, f(y)); keep only the point y
true_simplex_projection (generic function with 1 method)
y_true = true_simplex_projection(θ)
@test Vector(y) ≈ Vector(y_true) atol = 1e-3
Test Passed
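The reference solution above comes from ProximalOperators.jl. The Euclidean projection onto the simplex also has a well-known sort-based algorithm: sort the coordinates in decreasing order, find the largest k whose sorted entry stays positive after subtracting the threshold τ = (sum of the k largest entries - 1) / k, then clip θ - τ at zero. The sketch below (hypothetical name sort_simplex_projection) implements that algorithm; it is not claimed to be what ProximalOperators.jl or DifferentiableFrankWolfe.jl use internally.

```julia
# Hypothetical helper: exact Euclidean projection onto {x : x .>= 0, sum(x) == radius}
function sort_simplex_projection(θ::AbstractVector{<:Real}; radius::Real=1.0)
    u = sort(θ; rev=true)   # coordinates in decreasing order
    cssv = cumsum(u)        # running sums of the sorted coordinates
    # largest k such that u[k] - (cssv[k] - radius) / k > 0 (k = 1 always qualifies)
    k = findlast(i -> u[i] - (cssv[i] - radius) / i > 0, 1:length(u))
    τ = (cssv[k] - radius) / k  # threshold subtracted from every coordinate
    return max.(θ .- τ, 0)      # clip negative entries to zero
end

println(sort_simplex_projection([1.0, 1.5, 0.2]))  # [0.25, 0.75, 0.0]
```

For θ = [1.0, 1.5, 0.2] the two largest coordinates stay active and τ = (1.5 + 1.0 - 1) / 2 = 0.75, which reproduces the output y printed above.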

Differentiating the wrapper

J_true = ForwardDiff.jacobian(true_simplex_projection, θ)
3×3 Matrix{Float64}:
  0.5  -0.5  0.0
 -0.5   0.5  0.0
  0.0   0.0  0.0
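This matrix can also be checked by hand. Away from points where a coordinate sits exactly at the clipping threshold, the projection acts on its support S (the coordinates kept positive) as y_i = θ_i - τ(θ), with τ = (sum of θ_j over S - 1) / |S|, and is zero elsewhere. Differentiating gives J = diag(s) - s sᵀ / |S|, where s is the 0/1 indicator vector of S. Here S = {1, 2}, hence the 0.5 / -0.5 block. The helper name projection_jacobian is hypothetical; it takes the projected point as input.

```julia
# Hypothetical helper: Jacobian of the Euclidean simplex projection, evaluated at
# a projected point y (valid when no coordinate sits exactly at the threshold)
function projection_jacobian(y::AbstractVector{<:Real})
    s = Float64.(y .> 0)  # indicator of the support S
    k = sum(s)            # |S|
    n = length(y)
    # entry (i, j) of diag(s) - s * s' / |S|
    return [(i == j) * s[i] - s[i] * s[j] / k for i in 1:n, j in 1:n]
end

J_closed = projection_jacobian([0.25, 0.75, 0.0])  # y from the tutorial
println(J_closed)  # should match J_true above
```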
J1 = Zygote.jacobian(_θ -> dfw(_θ, x0; frank_wolfe_kwargs...), θ)[1]  # reverse mode; Zygote returns one Jacobian per argument, so take the first
@test J1 ≈ J_true atol = 1e-3
Test Passed
J2 = ForwardDiff.jacobian(_θ -> dfw(_θ, x0; frank_wolfe_kwargs...), θ)  # forward mode
@test J2 ≈ J_true atol = 1e-3
Test Passed

This page was generated using Literate.jl.