Tutorial
Necessary imports
using DifferentiableFrankWolfe: DiffFW, simplex_projection
using ForwardDiff: ForwardDiff
using FrankWolfe: UnitSimplexOracle
using Test: @test
using Zygote: Zygote
Constructing the wrapper
f(x, θ) = 0.5 * sum(abs2, x - θ) # minimizing the squared distance...
f_grad1(x, θ) = x - θ
lmo = UnitSimplexOracle(1.0) # ... to the unit simplex {x ≥ 0, sum(x) ≤ 1}...
dfw = DiffFW(f, f_grad1, lmo); # ... is a simplex projection (identical to the probability simplex projection whenever sum(max.(θ, 0)) ≥ 1)
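The sketch below gives some intuition for what `dfw` solves in its forward pass. It runs vanilla Frank-Wolfe on the same objective `f(x, θ) = 0.5 * sum(abs2, x - θ)` over the unit simplex {x ≥ 0, sum(x) ≤ 1}, which is the set represented by FrankWolfe.jl's `UnitSimplexOracle(1.0)`. It is a minimal, standard-library-only illustration. The names `unit_simplex_lmo` and `frank_wolfe_sketch` are hypothetical. The package delegates to FrankWolfe.jl, whose algorithm variant and step-size rule may differ from this loop. The keywords mirror `max_iteration` and `epsilon` from `frank_wolfe_kwargs`, with `epsilon` assumed here to be a threshold on the Frank-Wolfe gap.

```julia
using LinearAlgebra: dot

# Hypothetical LMO for the unit simplex {x ≥ 0, sum(x) ≤ 1}: returns the vertex
# (0 or a basis vector e_i) that minimizes ⟨g, v⟩.
function unit_simplex_lmo(g::AbstractVector{<:Real})
    v = zeros(float(eltype(g)), length(g))
    i = argmin(g)
    if g[i] < 0          # e_i beats the vertex 0 only if its slope is negative
        v[i] = 1
    end
    return v
end

# Hypothetical vanilla Frank-Wolfe for f(x, θ) = 0.5 * sum(abs2, x - θ).
# Because f is quadratic, the line search along x - γ(x - v) has a closed form.
function frank_wolfe_sketch(θ::AbstractVector{<:Real}; max_iteration=10_000, epsilon=1e-8)
    x = zeros(float(eltype(θ)), length(θ))  # start at the vertex 0
    for _ in 1:max_iteration
        g = x - θ                    # gradient with respect to x, as in f_grad1
        v = unit_simplex_lmo(g)
        d = x - v
        gap = dot(g, d)              # Frank-Wolfe gap, an upper bound on f(x) - min f
        gap <= epsilon && break      # d ≠ 0 past this point, since gap > 0
        γ = clamp(gap / dot(d, d), 0, 1)  # exact line search, kept inside the segment
        x -= γ * d                   # x ← (1 - γ) x + γ v stays feasible
    end
    return x
end

x_example = frank_wolfe_sketch([0.9, 0.7, 0.2, 0.5])
```

Every iterate is a convex combination of the few vertices the oracle has returned, so the result tends to be sparse. This is consistent with `y` being stored as a sparse vector in the next section.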
Calling the wrapper
θ = rand(10)
# 10-element Vector{Float64}:
#  0.9325545708397265
#  0.7662143208981307
#  0.6655424605378584
#  0.6423693908827878
#  0.5597701877353234
#  0.7028894953006231
#  0.6828459623254999
#  0.20720555151910935
#  0.8613616683664914
#  0.5259840470394308
frank_wolfe_kwargs = (; max_iteration=100, epsilon=1e-4)
y, stats = dfw(θ, frank_wolfe_kwargs)
y
# 10-element SparseArrays.SparseVector{Float64, Int64} with 7 stored entries:
#   [1] = 0.324803
#   [2] = 0.158503
#   [3] = 0.0578799
#   [4] = 0.0346624
#   [6] = 0.0952681
#   [7] = 0.0751824
#   [9] = 0.253701
y_true = simplex_projection(θ)
@test Vector(y) ≈ Vector(y_true) atol = 1e-3
# Test Passed
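The reference `y_true` comes from `simplex_projection`. For comparison, the sketch below implements the classical sort-based Euclidean projection onto the probability simplex {x ≥ 0, sum(x) = 1}. It is not necessarily how `simplex_projection` is implemented, and `probability_simplex_projection` is a hypothetical name. The projection has the form max.(θ .- τ, 0) for a scalar threshold τ, so every coordinate with θ_i ≤ τ is exactly zero. This is consistent with the 7 stored entries of `y` above.

```julia
# Hypothetical sort-based projection onto the probability simplex.
function probability_simplex_projection(θ::AbstractVector{<:Real})
    u = sort(θ; rev=true)            # coordinates in decreasing order
    css = cumsum(u)
    # Size of the support: the largest k with u[k] > (css[k] - 1) / k
    # (k = 1 always qualifies, so findlast never returns nothing).
    k = findlast(i -> u[i] > (css[i] - 1) / i, 1:length(u))
    τ = (css[k] - 1) / k             # threshold subtracted from every coordinate
    return max.(θ .- τ, 0)           # entries at or below τ are clipped to zero
end

y_example = probability_simplex_projection([0.9, 0.7, 0.2, 0.5])
```

On this input the threshold is τ = 1.1 / 3, so the third coordinate (0.2) is clipped. The remaining ones are shifted down by τ and sum to 1.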
Differentiating the wrapper
J1 = Zygote.jacobian(_θ -> dfw(_θ, frank_wolfe_kwargs)[1], θ)[1]
J1_true = Zygote.jacobian(simplex_projection, θ)[1]
@test J1 ≈ J1_true atol = 1e-3
# Test Passed
J2 = ForwardDiff.jacobian(_θ -> dfw(_θ, frank_wolfe_kwargs)[1], θ)
J2_true = ForwardDiff.jacobian(simplex_projection, θ)
@test J2 ≈ J2_true atol = 1e-3
# Test Passed
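Both reference Jacobians are obtained by automatic differentiation of `simplex_projection`. Assume it is the Euclidean projection onto the probability simplex, as its agreement with the Frank-Wolfe output suggests. Then the Jacobian also has a closed form. On the support S of y, y_i = θ_i - τ with τ = (Σ_{j∈S} θ_j - 1) / |S|, so ∂y/∂θ = Diag(s) - s sᵀ / |S|, where s is the 0/1 indicator of S. The sketch below builds this matrix from y alone. `simplex_projection_jacobian` is a hypothetical name, and the formula only holds where the support does not change under small perturbations of θ.

```julia
using LinearAlgebra: Diagonal

# Hypothetical closed-form Jacobian ∂y/∂θ of the probability simplex projection,
# expressed through its output y (valid while the support of y is locally constant).
function simplex_projection_jacobian(y::AbstractVector{<:Real})
    s = float.(y .> 0)                    # s_i = 1 on the support, 0 elsewhere
    return Diagonal(s) - s * s' / sum(s)  # Diag(s) - s sᵀ / |S|
end

# Projection of θ = [0.9, 0.7, 0.2, 0.5], whose support is {1, 2, 4}
y_support_example = [8/15, 1/3, 0.0, 2/15]
J_example = simplex_projection_jacobian(y_support_example)
```

Each row of this matrix sums to zero, reflecting that sum(y) = 1 however θ moves. Rows and columns outside the support are zero, so coordinates clipped to zero do not respond to small changes in θ.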
This page was generated using Literate.jl.