Tutorial

Necessary imports

using DifferentiableFrankWolfe: DiffFW, simplex_projection
using ForwardDiff: ForwardDiff
using FrankWolfe: UnitSimplexOracle
using Test: @test
using Zygote: Zygote

Constructing the wrapper

f(x, θ) = 0.5 * sum(abs2, x - θ)  # minimizing the squared distance...
f_grad1(x, θ) = x - θ  # (gradient of f with respect to x)
lmo = UnitSimplexOracle(1.0)  # ... to the unit simplex {x ≥ 0, sum(x) ≤ 1}
dfw = DiffFW(f, f_grad1, lmo);  # ... equals the probability simplex projection when sum(max.(θ, 0)) ≥ 1
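`DiffFW` receives the gradient `f_grad1` separately from `f`, so the two have to agree. A quick way to catch a mismatch is a central finite-difference check on a small input. The helper `fd_grad_x` below is a hypothetical illustration, not part of DifferentiableFrankWolfe.jl.

```julia
# Hypothetical helper: central finite differences of f(x, θ) with respect to x
function fd_grad_x(f, x, θ; h=1e-6)
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h
        g[i] = (f(x + e, θ) - f(x - e, θ)) / (2h)
    end
    return g
end

f(x, θ) = 0.5 * sum(abs2, x - θ)
f_grad1(x, θ) = x - θ

x = [0.2, 0.3, 0.5]
θ = [1.0, -0.5, 0.25]
err = maximum(abs, fd_grad_x(f, x, θ) - f_grad1(x, θ))
println("max deviation: ", err)  # tiny, since f is quadratic in x
```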

Calling the wrapper

θ = rand(10)
10-element Vector{Float64}:
 0.300030460072818
 0.6872777012755744
 0.8568514886465858
 0.323903091725261
 0.8786419965139313
 0.0916362858029951
 0.6389548522730417
 0.29089474608332
 0.4779956250484795
 0.798441118556476
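frank_wolfe_kwargs = (max_iteration=100, epsilon=1e-4)  # options passed to the Frank-Wolfe solver
y = dfw(θ; frank_wolfe_kwargs)  # shorthand for dfw(θ; frank_wolfe_kwargs=frank_wolfe_kwargs)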
10-element SparseArrays.SparseVector{Float64, Int64} with 5 stored entries:
  [2 ]  =  0.115253
  [3 ]  =  0.284826
  [5 ]  =  0.306574
  [7 ]  =  0.0669019
  [10]  =  0.226445
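The `frank_wolfe_kwargs` above cap the number of iterations and set a stopping tolerance. To make those two knobs concrete, here is a from-scratch sketch of vanilla Frank-Wolfe over the probability simplex. It is not FrankWolfe.jl's implementation: the helpers `simplex_lmo` and `vanilla_frank_wolfe` are hypothetical, the `2 / (t + 2)` step size is an assumption, and using `epsilon` as a threshold on the Frank-Wolfe dual gap is this sketch's own choice. Each iterate is a convex combination of the vertices visited so far, which is consistent with the sparse output above.

```julia
using LinearAlgebra: dot

# Linear minimization oracle for the probability simplex (hypothetical helper):
# the minimizer of <g, v> over the simplex is the vertex e_i with i = argmin(g)
function simplex_lmo(g)
    v = zeros(length(g))
    v[argmin(g)] = 1.0
    return v
end

# Vanilla Frank-Wolfe with the classic 2/(t+2) step size (an assumption).
# Stops once the dual gap <grad, x - v> drops below epsilon,
# or after max_iteration steps.
function vanilla_frank_wolfe(grad, x0; max_iteration=100, epsilon=1e-4)
    x = copy(x0)
    for t in 0:(max_iteration - 1)
        g = grad(x)
        v = simplex_lmo(g)
        gap = dot(g, x - v)
        gap <= epsilon && break
        γ = 2 / (t + 2)
        x = (1 - γ) * x + γ * v  # convex combination keeps x in the simplex
    end
    return x
end

θ = [0.2, 0.9, 0.6, 0.1]
grad = z -> z - θ          # gradient of 0.5 * sum(abs2, z - θ)
x0 = [1.0, 0.0, 0.0, 0.0]  # start from the vertex e_1
x = vanilla_frank_wolfe(grad, x0; max_iteration=1000, epsilon=1e-6)
println(x)  # close to the exact projection [0.0, 0.65, 0.35, 0.0]
```

The open-loop step size needs no line search, but it converges only at an O(1/t) rate, which is why the tutorial compares against the exact projection with a loose `atol`.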
y_true = simplex_projection(θ)
@test Vector(y) ≈ Vector(y_true) atol = 1e-3  # y is sparse, so compare dense copies
Test Passed
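For reference, `simplex_projection` computes the Euclidean projection onto the probability simplex. One standard way to compute that projection exactly is the sort-and-threshold algorithm sketched below. This is an independent illustration, not necessarily how DifferentiableFrankWolfe.jl implements `simplex_projection`, and `sort_simplex_projection` is a hypothetical name.

```julia
# Euclidean projection onto {x : x .>= 0, sum(x) == 1} by sort-and-threshold
function sort_simplex_projection(θ::AbstractVector{<:Real})
    u = sort(θ; rev=true)  # entries in decreasing order
    c = cumsum(u)
    # largest k such that u[k] stays above the candidate threshold (c[k] - 1) / k
    k = findlast(i -> u[i] - (c[i] - 1) / i > 0, eachindex(u))
    τ = (c[k] - 1) / k     # threshold subtracted from every entry
    return max.(θ .- τ, 0)
end

println(sort_simplex_projection([0.2, 0.9, 0.6, 0.1]))  # [0.0, 0.65, 0.35, 0.0] up to rounding
println(sort_simplex_projection([0.5, 0.5, 0.5]))       # each entry ≈ 1/3
```

Applied to the θ printed earlier, this sketch keeps the same support as `y` (indices 2, 3, 5, 7 and 10).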

Differentiating the wrapper

J1 = Zygote.jacobian(_θ -> dfw(_θ; frank_wolfe_kwargs), θ)[1]  # reverse mode; [1] selects the Jacobian w.r.t. θ
J1_true = Zygote.jacobian(simplex_projection, θ)[1]
@test J1 ≈ J1_true atol = 1e-3
Test Passed
J2 = ForwardDiff.jacobian(_θ -> dfw(_θ; frank_wolfe_kwargs), θ)  # forward mode
J2_true = ForwardDiff.jacobian(simplex_projection, θ)
@test J2 ≈ J2_true atol = 1e-3
Test Passed
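Both Jacobian checks compare against `simplex_projection`, whose Jacobian has a simple closed form (the same one as sparsemax): if S is the support of the projection and no θ_i sits exactly at the threshold, then J = diag(1_S) - 1_S 1_Sᵀ / |S|. The sketch below builds that matrix from the support of the output `y`. `support_jacobian` is a hypothetical helper, and the vector `p` is copied by hand from the rounded output above, so it only serves to identify the support.

```julia
using LinearAlgebra: Diagonal

# Hypothetical helper: Jacobian of the simplex projection at a point whose
# projection is p; valid where the support of p does not change locally
function support_jacobian(p::AbstractVector{<:Real})
    s = Float64.(p .> 0)  # 0/1 indicator of the support S
    k = sum(s)            # |S|
    return Diagonal(s) - s * s' / k
end

# Rounded output y from above (support {2, 3, 5, 7, 10})
p = [0.0, 0.115253, 0.284826, 0.0, 0.306574, 0.0, 0.0669019, 0.0, 0.0, 0.226445]
J = support_jacobian(p)
println(J[2, 2], " ", J[2, 3])  # 1 - 1/5 on the diagonal, -1/5 between support entries
```

Because J is symmetric and idempotent, it acts as an orthogonal projector onto perturbations that stay on the support and keep the total mass fixed, which is why its rows sum to zero.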

This page was generated using Literate.jl.