Tutorial
Necessary imports
using DifferentiableFrankWolfe: DiffFW, simplex_projection
using ForwardDiff: ForwardDiff
using FrankWolfe: ProbabilitySimplexOracle
using ProximalOperators: ProximalOperators
using Test: @test
using Zygote: Zygote
Constructing the wrapper
f(x, θ) = 0.5 * sum(abs2, x - θ) # minimizing the squared distance...
f_grad1(x, θ) = x - θ
lmo = ProbabilitySimplexOracle(1.0) # ... to the probability simplex
dfw = DiffFW(f, f_grad1, lmo); # ... is equivalent to projecting θ onto the simplex (the identity if θ is already in it)
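For reference, the problem encoded by `f`, `f_grad1` and `lmo` can be written out as below. The symbols $y(\theta)$ and $\Delta$ are notation introduced here for illustration; they are not names from the package.

```latex
\[
y(\theta) = \operatorname*{argmin}_{x \in \Delta} \; \tfrac{1}{2} \lVert x - \theta \rVert_2^2,
\qquad
\Delta = \Bigl\{ x \in \mathbb{R}^n : x \ge 0,\ \textstyle\sum_{i=1}^n x_i = 1 \Bigr\},
\qquad
\nabla_x f(x, \theta) = x - \theta .
\]
```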
Calling the wrapper
x0 = ones(3) ./ 3 # feasible starting point: the center of the simplex
θ = [1.0, 1.5, 0.2] # parameter, here the point to project
3-element Vector{Float64}:
1.0
1.5
0.2
frank_wolfe_kwargs = (; max_iteration=100, epsilon=1e-4) # solver settings: iteration cap and stopping tolerance
y = dfw(θ, x0; frank_wolfe_kwargs...)
3-element Vector{Float64}:
0.25000000000000006
0.75
0.0
# prox returns a tuple (point, function value); keep only the point
true_simplex_projection(x) = ProximalOperators.prox(ProximalOperators.IndSimplex(1.0), x)[1]
true_simplex_projection (generic function with 1 method)
y_true = true_simplex_projection(θ)
@test Vector(y) ≈ Vector(y_true) atol = 1e-3
Test Passed
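As an independent cross-check, the same projection can be computed with the classical sort-and-threshold method. The following is a minimal sketch, not the implementation used by either package; `sort_simplex_projection` and its `radius` keyword are hypothetical names introduced here.

```julia
# Hypothetical helper: Euclidean projection onto {x ≥ 0, sum(x) = radius}
# using the sort-and-threshold method. Not part of any package used above.
function sort_simplex_projection(θ::AbstractVector{<:Real}; radius=1.0)
    u = sort(θ; rev=true)            # coordinates in decreasing order
    css = cumsum(u) .- radius        # shifted cumulative sums
    # size of the support: last k whose coordinate stays positive after shifting
    ρ = findlast(k -> u[k] - css[k] / k > 0, 1:length(u))
    τ = css[ρ] / ρ                   # common shift applied to every coordinate
    return max.(θ .- τ, 0)           # shift, then clip negative entries to zero
end

sort_simplex_projection([1.0, 1.5, 0.2]) # ≈ [0.25, 0.75, 0.0]
```

For the `θ` used in this tutorial the support has size 2 and the shift is 0.75, which reproduces the value of `y_true`.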
Differentiating the wrapper
J_true = ForwardDiff.jacobian(true_simplex_projection, θ)
3×3 Matrix{Float64}:
0.5 -0.5 0.0
-0.5 0.5 0.0
0.0 0.0 0.0
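The block structure of `J_true` follows from a closed form: wherever the projection is differentiable, its Jacobian equals the identity minus the averaging matrix on the coordinates that stay positive, and is zero elsewhere. A minimal sketch, assuming no coordinate sits exactly on the boundary of the support; `simplex_projection_jacobian` is a hypothetical helper, not a function from the packages above.

```julia
# Hypothetical helper: closed-form Jacobian of the simplex projection,
# built from the projected point y (assumes differentiability at that point).
function simplex_projection_jacobian(y::AbstractVector{<:Real})
    n = length(y)
    idx = findall(>(0), y)           # support: strictly positive coordinates
    k = length(idx)
    J = zeros(n, n)
    for i in idx, j in idx
        J[i, j] = (i == j) - 1 / k   # identity minus averaging on the support
    end
    return J
end

simplex_projection_jacobian([0.25, 0.75, 0.0])
```

With the support `{1, 2}` found above, this gives `0.5` on the diagonal and `-0.5` off the diagonal of the leading 2×2 block, matching `J_true`.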
# reverse-mode AD; Zygote.jacobian returns a tuple with one Jacobian per argument
J1 = Zygote.jacobian(_θ -> dfw(_θ, x0; frank_wolfe_kwargs...), θ)[1]
@test J1 ≈ J_true atol = 1e-3
Test Passed
# forward-mode AD through the same wrapper
J2 = ForwardDiff.jacobian(_θ -> dfw(_θ, x0; frank_wolfe_kwargs...), θ)
@test J2 ≈ J_true atol = 1e-3
Test Passed
This page was generated using Literate.jl.