Training on Warcraft Shortest Path
This tutorial demonstrates how to train a decision-focused learning policy on the Warcraft shortest path benchmark using the Perturbed Fenchel-Young Loss Imitation algorithm.
Setup
First, let's load the required packages:
using DecisionFocusedLearningAlgorithms
using DecisionFocusedLearningBenchmarks
using Flux
using MLUtils
using Plots
using StatisticsBenchmark Setup
The Warcraft benchmark involves predicting edge costs in a grid graph for shortest path problems. We'll create a benchmark instance and generate training data:
benchmark = WarcraftBenchmark()
dataset = generate_dataset(benchmark, 50)50-element Vector{DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}}:
DataSample(x=Float32[0.133333 0.133333 … 0.164706 0.235294; 0.12549 0.105882 … 0.168627 0.313726; … ; 0.247059 0.247059 … 0.0196078 0.0156863; 0.384314 0.423529 … 0.0156863 0.0196078;;; 0.309804 0.294118 … 0.341176 0.356863; 0.298039 0.282353 … 0.329412 0.254902; … ; 0.305882 0.266667 … 0.227451 0.219608; 0.239216 0.239216 … 0.219608 0.227451;;; 0.0392157 0.027451 … 0.0509804 0.0705882; 0.0313726 0.0235294 … 0.0470588 0.0235294; … ; 0.0431373 0.027451 … 0.458824 0.454902; 0.00392157 0.0 … 0.454902 0.458824;;;;], θ=Float16[-0.8 -0.8 … -1.2 -0.8; -0.8 -0.8 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.0196078 0.0196078 … 0.423529 0.411765; 0.0156863 0.0156863 … 0.427451 0.443137; … ; 0.0196078 0.0196078 … 0.439216 0.439216; 0.0196078 0.0196078 … 0.431373 0.435294;;; 0.219608 0.215686 … 0.25098 0.239216; 0.219608 0.207843 … 0.25098 0.258824; … ; 0.227451 0.227451 … 0.258824 0.258824; 0.231373 0.227451 … 0.254902 0.258824;;; 0.454902 0.447059 … 0.00784314 0.00392157; 0.454902 0.443137 … 0.00784314 0.0117647; … ; 0.458824 0.462745 … 0.00784314 0.00784314; 0.462745 0.458824 … 0.00392157 0.00784314;;;;], θ=Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -9.2 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.223529 0.0 … 0.160784 0.164706; 0.203922 0.0588235 … 0.156863 0.156863; … ; 0.0235294 0.00392157 … 0.00392157 0.0235294; 0.0352941 0.0117647 … 0.0 0.0;;; 0.407843 0.25098 … 0.32549 0.337255; 0.356863 0.32549 … 0.321569 0.329412; … ; 0.196078 0.215686 … 0.207843 0.235294; 0.215686 0.196078 … 0.231373 0.247059;;; 0.0823529 0.0 … 0.0431373 0.0509804; 0.054902 0.0352941 … 0.0431373 0.0431373; … ; 0.0 0.0 … 0.0 0.00784314; 0.0196078 0.0 … 0.0 0.0;;;;], θ=Float16[-5.3 -5.3 … -1.2 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.431373 0.454902 … 0.0156863 0.0196078; 0.411765 0.368627 … 0.0156863 0.0156863; … ; 0.14902 0.152941 … 0.152941 0.164706; 0.156863 0.137255 … 0.164706 0.164706;;; 0.25098 0.258824 … 0.219608 0.227451; 0.270588 0.278431 … 0.219608 0.219608; … ; 0.313726 0.321569 … 0.321569 0.333333; 0.32549 0.301961 … 0.333333 0.333333;;; 0.00392157 0.00784314 … 0.454902 0.462745; 0.0156863 0.0235294 … 0.454902 0.454902; … ; 0.0392157 0.0431373 … 0.0392157 0.0470588; 0.0431373 0.0313726 … 0.0470588 0.0470588;;;;], θ=Float16[-1.2 -7.7 … -7.7 -7.7; -7.7 -1.2 … -7.7 -7.7; … ; -0.8 -0.8 … -0.8 -5.3; -0.8 -1.2 … -0.8 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.152941 0.317647 … 0.435294 0.403922; 0.376471 0.305882 … 0.431373 0.411765; … ; 0.2 0.121569 … 0.360784 0.411765; 0.0901961 0.172549 … 0.219608 0.447059;;; 0.152941 0.317647 … 0.258824 0.239216; 0.376471 0.305882 … 0.254902 0.243137; … ; 0.2 0.121569 … 0.223529 0.286275; 0.0901961 0.172549 … 0.164706 0.258824;;; 0.152941 0.317647 … 0.00784314 0.0; 0.376471 0.305882 … 0.00392157 0.0; … ; 0.2 0.121569 … 0.027451 0.109804; 0.0901961 0.172549 … 0.0862745 0.0;;;;], θ=Float16[-9.2 -9.2 … -0.8 -1.2; -9.2 -9.2 … -0.8 -1.2; … ; -9.2 -9.2 … -9.2 -1.2; -9.2 -9.2 … -9.2 -1.2], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 1])
DataSample(x=Float32[0.0235294 0.0196078 … 0.164706 0.164706; 0.0156863 0.0156863 … 0.168627 0.164706; … ; 0.168627 0.164706 … 0.431373 0.427451; 0.156863 0.160784 … 0.431373 0.427451;;; 0.227451 0.223529 … 0.333333 0.329412; 0.215686 0.207843 … 0.337255 0.329412; … ; 0.337255 0.337255 … 0.254902 0.254902; 0.329412 0.317647 … 0.254902 0.254902;;; 0.458824 0.454902 … 0.0470588 0.0470588; 0.45098 0.443137 … 0.0470588 0.0470588; … ; 0.0470588 0.0470588 … 0.00392157 0.00392157; 0.0431373 0.0392157 … 0.00392157 0.00392157;;;;], θ=Float16[-7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8; … ; -1.2 -1.2 … -9.2 -9.2; -0.8 -1.2 … -9.2 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.407843 0.415686 … 0.380392 0.286275; 0.407843 0.411765 … 0.376471 0.309804; … ; 0.12549 0.113725 … 0.152941 0.160784; 0.129412 0.129412 … 0.156863 0.152941;;; 0.239216 0.243137 … 0.203922 0.278431; 0.239216 0.243137 … 0.278431 0.258824; … ; 0.298039 0.278431 … 0.317647 0.329412; 0.294118 0.298039 … 0.32549 0.333333;;; 0.00392157 0.00784314 … 0.0 0.0352941; 0.00392157 0.00392157 … 0.0235294 0.0235294; … ; 0.0313726 0.0196078 … 0.0392157 0.0470588; 0.027451 0.0313726 … 0.0431373 0.0470588;;;;], θ=Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2; … ; -0.8 -1.2 … -5.3 -5.3; -0.8 -0.8 … -0.8 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1])
DataSample(x=Float32[0.0156863 0.0156863 … 0.439216 0.427451; 0.0156863 0.0156863 … 0.431373 0.435294; … ; 0.337255 0.364706 … 0.168627 0.196078; 0.262745 0.368627 … 0.227451 0.341176;;; 0.207843 0.2 … 0.258824 0.25098; 0.203922 0.207843 … 0.254902 0.254902; … ; 0.239216 0.27451 … 0.34902 0.376471; 0.317647 0.243137 … 0.360784 0.247059;;; 0.443137 0.435294 … 0.0117647 0.00784314; 0.439216 0.443137 … 0.00784314 0.00784314; … ; 0.0156863 0.0196078 … 0.054902 0.0784314; 0.0470588 0.0117647 … 0.0705882 0.0196078;;;;], θ=Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -1.2; -1.2 -1.2 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.0235294 0.0196078 … 0.443137 0.439216; 0.0156863 0.0156863 … 0.435294 0.443137; … ; 0.396078 0.392157 … 0.439216 0.439216; 0.392157 0.392157 … 0.431373 0.435294;;; 0.235294 0.231373 … 0.258824 0.258824; 0.219608 0.215686 … 0.258824 0.258824; … ; 0.235294 0.231373 … 0.258824 0.258824; 0.231373 0.227451 … 0.254902 0.258824;;; 0.466667 0.462745 … 0.00784314 0.00784314; 0.454902 0.454902 … 0.00784314 0.00784314; … ; 0.0 0.0 … 0.00784314 0.00784314; 0.0 0.0 … 0.00392157 0.00784314;;;;], θ=Float16[-7.7 -7.7 … -7.7 -1.2; -7.7 -7.7 … -7.7 -1.2; … ; -1.2 -1.2 … -1.2 -9.2; -1.2 -1.2 … -1.2 -1.2], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.0 0.0235294 … 0.4 0.392157; 0.0196078 0.00392157 … 0.392157 0.376471; … ; 0.137255 0.137255 … 0.156863 0.156863; 0.12549 0.105882 … 0.168627 0.168627;;; 0.188235 0.321569 … 0.235294 0.227451; 0.270588 0.247059 … 0.231373 0.219608; … ; 0.313726 0.309804 … 0.34902 0.333333; 0.298039 0.282353 … 0.345098 0.337255;;; 0.0 0.00784314 … 0.0 0.0; 0.0156863 0.00392157 … 0.0 0.0; … ; 0.0392157 0.0352941 … 0.0588235 0.0470588; 0.0313726 0.0235294 … 0.0588235 0.0509804;;;;], θ=Float16[-5.3 -5.3 … -1.2 -1.2; -5.3 -5.3 … -1.2 -1.2; … ; -5.3 -5.3 … -1.2 -1.2; -0.8 -0.8 … -0.8 -1.2], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
⋮
DataSample(x=Float32[0.0235294 0.0196078 … 0.427451 0.423529; 0.0156863 0.0156863 … 0.435294 0.443137; … ; 0.0156863 0.0156863 … 0.113725 0.12549; 0.0156863 0.0196078 … 0.129412 0.129412;;; 0.235294 0.231373 … 0.254902 0.25098; 0.219608 0.215686 … 0.258824 0.262745; … ; 0.211765 0.211765 … 0.278431 0.298039; 0.215686 0.223529 … 0.298039 0.294118;;; 0.466667 0.462745 … 0.00392157 0.0; 0.454902 0.454902 … 0.00784314 0.0117647; … ; 0.447059 0.447059 … 0.0196078 0.0313726; 0.447059 0.454902 … 0.0313726 0.027451;;;;], θ=Float16[-7.7 -1.2 … -1.2 -1.2; -7.7 -1.2 … -0.8 -1.2; … ; -7.7 -1.2 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.160784 0.160784 … 0.133333 0.12549; 0.160784 0.156863 … 0.105882 0.113725; … ; 0.364706 0.388235 … 0.435294 0.439216; 0.341176 0.415686 … 0.447059 0.435294;;; 0.321569 0.329412 … 0.294118 0.294118; 0.333333 0.313726 … 0.282353 0.286275; … ; 0.270588 0.27451 … 0.258824 0.258824; 0.254902 0.254902 … 0.254902 0.262745;;; 0.0392157 0.0431373 … 0.027451 0.027451; 0.0470588 0.0352941 … 0.0235294 0.027451; … ; 0.0156863 0.0196078 … 0.00784314 0.00784314; 0.00784314 0.00784314 … 0.00392157 0.00784314;;;;], θ=Float16[-0.8 -5.3 … -0.8 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -0.8 -5.3 … -7.7 -7.7; -0.8 -0.8 … -7.7 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.12549 0.133333 … 0.145098 0.156863; 0.113725 0.105882 … 0.137255 0.152941; … ; 0.0431373 0.0 … 0.172549 0.137255; 0.0745098 0.0117647 … 0.145098 0.164706;;; 0.294118 0.294118 … 0.313726 0.329412; 0.286275 0.282353 … 0.305882 0.317647; … ; 0.247059 0.207843 … 0.333333 0.313726; 0.321569 0.219608 … 0.313726 0.360784;;; 0.027451 0.027451 … 0.0392157 0.0431373; 0.027451 0.0235294 … 0.0352941 0.0352941; … ; 0.0117647 0.0 … 0.0470588 0.0352941; 0.027451 0.00392157 … 0.0352941 0.0627451;;;;], θ=Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -1.2 -0.8; … ; -5.3 -0.8 … -1.2 -1.2; -5.3 -0.8 … -1.2 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.439216 0.443137 … 0.431373 0.415686; 0.427451 0.435294 … 0.427451 0.431373; … ; 0.0196078 0.0156863 … 0.113725 0.12549; 0.0156863 0.0156863 … 0.129412 0.129412;;; 0.254902 0.247059 … 0.254902 0.247059; 0.25098 0.258824 … 0.254902 0.254902; … ; 0.211765 0.203922 … 0.278431 0.294118; 0.203922 0.203922 … 0.298039 0.298039;;; 0.00392157 0.0 … 0.00392157 0.0; 0.0 0.00784314 … 0.00392157 0.00392157; … ; 0.443137 0.439216 … 0.0196078 0.0313726; 0.439216 0.439216 … 0.0313726 0.027451;;;;], θ=Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.12549 0.133333 … 0.160784 0.160784; 0.113725 0.105882 … 0.388235 0.403922; … ; 0.156863 0.164706 … 0.172549 0.286275; 0.14902 0.160784 … 0.266667 0.376471;;; 0.294118 0.294118 … 0.239216 0.188235; 0.286275 0.282353 … 0.247059 0.262745; … ; 0.329412 0.337255 … 0.172549 0.211765; 0.321569 0.321569 … 0.278431 0.211765;;; 0.027451 0.027451 … 0.0941176 0.0509804; 0.027451 0.0235294 … 0.0 0.00392157; … ; 0.0470588 0.0509804 … 0.00392157 0.0117647; 0.0431373 0.0392157 … 0.027451 0.0;;;;], θ=Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.431373 0.415686 … 0.101961 0.168627; 0.454902 0.337255 … 0.219608 0.32549; … ; 0.156863 0.156863 … 0.388235 0.45098; 0.141176 0.129412 … 0.431373 0.411765;;; 0.25098 0.254902 … 0.321569 0.352941; 0.262745 0.301961 … 0.447059 0.517647; … ; 0.32549 0.32549 … 0.262745 0.25098; 0.309804 0.298039 … 0.262745 0.266667;;; 0.00392157 0.00392157 … 0.0470588 0.0823529; 0.00784314 0.0352941 … 0.0941176 0.141176; … ; 0.0431373 0.0431373 … 0.0156863 0.00392157; 0.0352941 0.027451 … 0.00784314 0.0117647;;;;], θ=Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.137255 0.231373 … 0.129412 0.12549; 0.384314 0.262745 … 0.133333 0.12549; … ; 0.34902 0.235294 … 0.392157 0.34902; 0.172549 0.192157 … 0.411765 0.34902;;; 0.137255 0.231373 … 0.294118 0.290196; 0.384314 0.258824 … 0.298039 0.290196; … ; 0.34902 0.235294 … 0.243137 0.262745; 0.172549 0.192157 … 0.258824 0.254902;;; 0.129412 0.235294 … 0.027451 0.027451; 0.384314 0.258824 … 0.027451 0.0235294; … ; 0.34902 0.235294 … 0.00392157 0.0156863; 0.172549 0.192157 … 0.00784314 0.0117647;;;;], θ=Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.188235 0.313726 … 0.423529 0.403922; 0.254902 0.329412 … 0.411765 0.4; … ; 0.411765 0.396078 … 0.121569 0.117647; 0.435294 0.384314 … 0.109804 0.129412;;; 0.188235 0.313726 … 0.25098 0.235294; 0.254902 0.329412 … 0.243137 0.235294; … ; 0.243137 0.231373 … 0.282353 0.286275; 0.254902 0.223529 … 0.286275 0.290196;;; 0.188235 0.313726 … 0.00392157 0.0; 0.254902 0.329412 … 0.00392157 0.0; … ; 0.00784314 0.0 … 0.0235294 0.0235294; 0.0117647 0.0 … 0.0235294 0.0235294;;;;], θ=Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])
DataSample(x=Float32[0.156863 0.14902 … 0.415686 0.364706; 0.14902 0.145098 … 0.313726 0.34902; … ; 0.168627 0.164706 … 0.188235 0.235294; 0.164706 0.160784 … 0.129412 0.156863;;; 0.329412 0.313726 … 0.270588 0.243137; 0.317647 0.309804 … 0.262745 0.270588; … ; 0.333333 0.337255 … 0.188235 0.235294; 0.329412 0.321569 … 0.129412 0.156863;;; 0.0470588 0.0352941 … 0.0117647 0.0117647; 0.0392157 0.0352941 … 0.0196078 0.0235294; … ; 0.0470588 0.0470588 … 0.188235 0.235294; 0.0431373 0.0392157 … 0.129412 0.156863;;;;], θ=Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])Split the dataset into training, validation, and test sets:
train_data, val_data = dataset[1:45], dataset[46:end](DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample(x=Float32[0.133333 0.133333 … 0.164706 0.235294; 0.12549 0.105882 … 0.168627 0.313726; … ; 0.247059 0.247059 … 0.0196078 0.0156863; 0.384314 0.423529 … 0.0156863 0.0196078;;; 0.309804 0.294118 … 0.341176 0.356863; 0.298039 0.282353 … 0.329412 0.254902; … ; 0.305882 0.266667 … 0.227451 0.219608; 0.239216 0.239216 … 0.219608 0.227451;;; 0.0392157 0.027451 … 0.0509804 0.0705882; 0.0313726 0.0235294 … 0.0470588 0.0235294; … ; 0.0431373 0.027451 … 0.458824 0.454902; 0.00392157 0.0 … 0.454902 0.458824;;;;], θ=Float16[-0.8 -0.8 … -1.2 -0.8; -0.8 -0.8 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.0196078 0.0196078 … 0.423529 0.411765; 0.0156863 0.0156863 … 0.427451 0.443137; … ; 0.0196078 0.0196078 … 0.439216 0.439216; 0.0196078 0.0196078 … 0.431373 0.435294;;; 0.219608 0.215686 … 0.25098 0.239216; 0.219608 0.207843 … 0.25098 0.258824; … ; 0.227451 0.227451 … 0.258824 0.258824; 0.231373 0.227451 … 0.254902 0.258824;;; 0.454902 0.447059 … 0.00784314 0.00392157; 0.454902 0.443137 … 0.00784314 0.0117647; … ; 0.458824 0.462745 … 0.00784314 0.00784314; 0.462745 0.458824 … 0.00392157 0.00784314;;;;], θ=Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -9.2 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.223529 0.0 … 0.160784 0.164706; 0.203922 0.0588235 … 0.156863 0.156863; … ; 0.0235294 0.00392157 … 0.00392157 0.0235294; 0.0352941 0.0117647 … 0.0 0.0;;; 0.407843 0.25098 … 0.32549 0.337255; 0.356863 0.32549 … 0.321569 0.329412; … ; 0.196078 0.215686 … 0.207843 0.235294; 0.215686 0.196078 … 0.231373 0.247059;;; 0.0823529 0.0 … 0.0431373 0.0509804; 0.054902 0.0352941 … 0.0431373 0.0431373; … ; 0.0 0.0 … 0.0 0.00784314; 0.0196078 0.0 … 0.0 0.0;;;;], θ=Float16[-5.3 -5.3 … -1.2 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.431373 0.454902 … 0.0156863 0.0196078; 0.411765 0.368627 … 0.0156863 0.0156863; … ; 0.14902 0.152941 … 0.152941 0.164706; 0.156863 0.137255 … 0.164706 0.164706;;; 0.25098 0.258824 … 0.219608 0.227451; 0.270588 0.278431 … 0.219608 0.219608; … ; 0.313726 0.321569 … 0.321569 0.333333; 0.32549 0.301961 … 0.333333 0.333333;;; 0.00392157 0.00784314 … 0.454902 0.462745; 0.0156863 0.0235294 … 0.454902 0.454902; … ; 0.0392157 0.0431373 … 0.0392157 0.0470588; 0.0431373 0.0313726 … 0.0470588 0.0470588;;;;], θ=Float16[-1.2 -7.7 … -7.7 -7.7; -7.7 -1.2 … -7.7 -7.7; … ; -0.8 -0.8 … -0.8 -5.3; -0.8 -1.2 … -0.8 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.152941 0.317647 … 0.435294 0.403922; 0.376471 0.305882 … 0.431373 0.411765; … ; 0.2 0.121569 … 0.360784 0.411765; 0.0901961 0.172549 … 0.219608 0.447059;;; 0.152941 0.317647 … 0.258824 0.239216; 0.376471 0.305882 … 0.254902 0.243137; … ; 0.2 0.121569 … 0.223529 0.286275; 0.0901961 0.172549 … 0.164706 0.258824;;; 0.152941 0.317647 … 0.00784314 0.0; 0.376471 0.305882 … 0.00392157 0.0; … ; 0.2 0.121569 … 0.027451 0.109804; 0.0901961 0.172549 … 0.0862745 0.0;;;;], θ=Float16[-9.2 -9.2 … -0.8 -1.2; -9.2 -9.2 … -0.8 -1.2; … ; -9.2 -9.2 … -9.2 -1.2; -9.2 -9.2 … -9.2 -1.2], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 1]), DataSample(x=Float32[0.0235294 0.0196078 … 0.164706 0.164706; 0.0156863 0.0156863 … 0.168627 0.164706; … ; 0.168627 0.164706 … 0.431373 0.427451; 0.156863 0.160784 … 0.431373 0.427451;;; 0.227451 0.223529 … 0.333333 0.329412; 0.215686 0.207843 … 0.337255 0.329412; … ; 0.337255 0.337255 … 0.254902 0.254902; 0.329412 0.317647 … 0.254902 0.254902;;; 0.458824 0.454902 … 0.0470588 0.0470588; 0.45098 0.443137 … 0.0470588 0.0470588; … ; 0.0470588 0.0470588 … 0.00392157 0.00392157; 0.0431373 0.0392157 … 0.00392157 0.00392157;;;;], θ=Float16[-7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8; … ; -1.2 -1.2 … -9.2 -9.2; -0.8 -1.2 … -9.2 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.407843 0.415686 … 0.380392 0.286275; 0.407843 0.411765 … 0.376471 0.309804; … ; 0.12549 0.113725 … 0.152941 0.160784; 0.129412 0.129412 … 0.156863 0.152941;;; 0.239216 0.243137 … 0.203922 0.278431; 0.239216 0.243137 … 0.278431 0.258824; … ; 0.298039 0.278431 … 0.317647 0.329412; 0.294118 0.298039 … 0.32549 0.333333;;; 0.00392157 0.00784314 … 0.0 0.0352941; 0.00392157 0.00392157 … 0.0235294 0.0235294; … ; 0.0313726 0.0196078 … 0.0392157 0.0470588; 0.027451 0.0313726 … 0.0431373 0.0470588;;;;], θ=Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2; … ; -0.8 -1.2 … -5.3 -5.3; -0.8 -0.8 … -0.8 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1]), DataSample(x=Float32[0.0156863 0.0156863 … 0.439216 0.427451; 0.0156863 0.0156863 … 0.431373 0.435294; … ; 0.337255 0.364706 … 0.168627 0.196078; 0.262745 0.368627 … 0.227451 0.341176;;; 0.207843 0.2 … 0.258824 0.25098; 0.203922 0.207843 … 0.254902 0.254902; … ; 0.239216 0.27451 … 0.34902 0.376471; 0.317647 0.243137 … 0.360784 0.247059;;; 0.443137 0.435294 … 0.0117647 0.00784314; 0.439216 0.443137 … 0.00784314 0.00784314; … ; 0.0156863 0.0196078 … 0.054902 0.0784314; 0.0470588 0.0117647 … 0.0705882 0.0196078;;;;], θ=Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -1.2; -1.2 -1.2 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.0235294 0.0196078 … 0.443137 0.439216; 0.0156863 0.0156863 … 0.435294 0.443137; … ; 0.396078 0.392157 … 0.439216 0.439216; 0.392157 0.392157 … 0.431373 0.435294;;; 0.235294 0.231373 … 0.258824 0.258824; 0.219608 0.215686 … 0.258824 0.258824; … ; 0.235294 0.231373 … 0.258824 0.258824; 0.231373 0.227451 … 0.254902 0.258824;;; 0.466667 0.462745 … 0.00784314 0.00784314; 0.454902 0.454902 … 0.00784314 0.00784314; … ; 0.0 0.0 … 0.00784314 0.00784314; 0.0 0.0 … 0.00392157 0.00784314;;;;], θ=Float16[-7.7 -7.7 … -7.7 -1.2; -7.7 -7.7 … -7.7 -1.2; … ; -1.2 -1.2 … -1.2 -9.2; -1.2 -1.2 … -1.2 -1.2], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.0 0.0235294 … 0.4 0.392157; 0.0196078 0.00392157 … 0.392157 0.376471; … ; 0.137255 0.137255 … 0.156863 0.156863; 0.12549 0.105882 … 0.168627 0.168627;;; 0.188235 0.321569 … 0.235294 0.227451; 0.270588 0.247059 … 0.231373 0.219608; … ; 0.313726 0.309804 … 0.34902 0.333333; 0.298039 0.282353 … 0.345098 0.337255;;; 0.0 0.00784314 … 0.0 0.0; 0.0156863 0.00392157 … 0.0 0.0; … ; 0.0392157 0.0352941 … 0.0588235 0.0470588; 0.0313726 0.0235294 … 0.0588235 0.0509804;;;;], θ=Float16[-5.3 -5.3 … -1.2 -1.2; -5.3 -5.3 … -1.2 -1.2; … ; -5.3 -5.3 … -1.2 -1.2; -0.8 -0.8 … -0.8 -1.2], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]) … DataSample(x=Float32[0.403922 0.431373 … 0.443137 0.435294; 0.435294 0.435294 … 0.427451 0.435294; … ; 0.0235294 0.00392157 … 0.00392157 0.0235294; 0.0352941 0.0117647 … 0.0 0.0;;; 0.235294 0.254902 … 0.258824 0.254902; 0.258824 0.254902 … 0.25098 0.254902; … ; 0.196078 0.215686 … 0.207843 0.235294; 0.215686 0.196078 … 0.231373 0.247059;;; 0.0 0.00784314 … 0.00784314 0.00784314; 0.00784314 0.00784314 … 0.00784314 0.00784314; … ; 0.0 0.0 … 0.0 0.00784314; 0.0196078 0.0 … 0.0 0.0;;;;], θ=Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.407843 0.415686 … 0.380392 0.286275; 0.407843 0.411765 … 0.376471 0.309804; … ; 0.164706 0.164706 … 0.443137 0.427451; 0.152941 0.160784 … 0.423529 0.419608;;; 0.239216 0.243137 … 0.203922 0.278431; 0.239216 0.243137 … 0.278431 0.258824; … ; 0.329412 0.32549 … 0.262745 0.25098; 0.321569 0.329412 … 0.25098 0.247059;;; 0.00392157 0.00784314 … 0.0 0.0352941; 0.00392157 0.00392157 … 0.0235294 0.0235294; … ; 0.0470588 0.0431373 … 0.00784314 0.00392157; 0.0392157 0.0470588 … 0.00392157 0.00784314;;;;], θ=Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -7.7 … -1.2 -1.2; … ; -0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -1.2 -1.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.27451 0.0196078 … 0.00392157 0.141176; 0.137255 0.0627451 … 0.0196078 0.0823529; … ; 0.427451 0.427451 … 0.396078 0.4; 0.427451 0.403922 … 0.392157 0.396078;;; 0.470588 0.231373 … 0.254902 0.376471; 0.403922 0.317647 … 0.231373 0.356863; … ; 0.25098 0.25098 … 0.235294 0.235294; 0.25098 0.235294 … 0.227451 0.231373;;; 0.141176 0.0156863 … 0.00392157 0.0745098; 0.0705882 0.0392157 … 0.00784314 0.0352941; … ; 0.00784314 0.00392157 … 0.0 0.0; 0.00392157 0.00392157 … 0.0 0.0;;;;], θ=Float16[-5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3; … ; -1.2 -0.8 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.145098 0.137255 … 0.129412 0.137255; 0.121569 0.12549 … 0.117647 0.109804; … ; 0.0235294 0.00392157 … 0.0784314 0.0117647; 0.0352941 0.0117647 … 0.0352941 0.0588235;;; 0.309804 0.309804 … 0.294118 0.305882; 0.294118 0.294118 … 0.298039 0.286275; … ; 0.196078 0.215686 … 0.301961 0.211765; 0.215686 0.196078 … 0.282353 0.301961;;; 0.0352941 0.0352941 … 0.027451 0.0313726; 0.027451 0.027451 … 0.0313726 0.027451; … ; 0.0 0.0 … 0.0352941 0.00784314; 0.0196078 0.0 … 0.0156863 0.0352941;;;;], θ=Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -0.8 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.458824 0.282353 … 0.431373 0.415686; 0.356863 0.172549 … 0.427451 0.431373; … ; 0.384314 0.396078 … 0.172549 0.168627; 0.388235 0.392157 … 0.164706 0.172549;;; 0.25098 0.270588 … 0.254902 0.247059; 0.270588 0.196078 … 0.254902 0.254902; … ; 0.223529 0.231373 … 0.352941 0.356863; 0.227451 0.231373 … 0.337255 0.345098;;; 0.00392157 0.0156863 … 0.00392157 0.0; 0.00784314 0.0117647 … 0.00392157 0.00392157; … ; 0.0 0.0 … 0.0588235 0.0627451; 0.0 0.0 … 0.0509804 0.054902;;;;], θ=Float16[-7.7 -1.2 … -7.7 -1.2; -7.7 -1.2 … -1.2 -7.7; … ; -1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.121569 0.231373 … 0.435294 0.407843; 0.121569 0.266667 … 0.435294 0.423529; … ; 0.12549 0.12549 … 0.388235 0.45098; 0.117647 0.105882 … 0.431373 0.411765;;; 0.121569 0.231373 … 0.254902 0.239216; 0.121569 0.266667 … 0.258824 0.25098; … ; 0.301961 0.301961 … 0.262745 0.25098; 0.294118 0.282353 … 0.262745 0.266667;;; 0.121569 0.231373 … 0.00784314 0.0; 0.121569 0.266667 … 0.00392157 0.0; … ; 0.0313726 0.0313726 … 0.0156863 0.00392157; 0.0313726 0.0235294 … 0.00784314 0.0117647;;;;], θ=Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -1.2 -1.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1]), DataSample(x=Float32[0.0235294 0.0196078 … 0.427451 0.423529; 0.0156863 0.0156863 … 0.435294 0.443137; … ; 0.0156863 0.0156863 … 0.113725 0.12549; 0.0156863 0.0196078 … 0.129412 0.129412;;; 0.235294 0.231373 … 0.254902 0.25098; 0.219608 0.215686 … 0.258824 0.262745; … ; 0.211765 0.211765 … 0.278431 0.298039; 0.215686 0.223529 … 0.298039 0.294118;;; 0.466667 0.462745 … 0.00392157 0.0; 0.454902 0.454902 … 0.00784314 0.0117647; … ; 0.447059 0.447059 … 0.0196078 0.0313726; 0.447059 0.454902 … 0.0313726 0.027451;;;;], θ=Float16[-7.7 -1.2 … -1.2 -1.2; -7.7 -1.2 … -0.8 -1.2; … ; -7.7 -1.2 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], y=Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.160784 0.160784 … 0.133333 0.12549; 0.160784 0.156863 … 0.105882 0.113725; … ; 0.364706 0.388235 … 0.435294 0.439216; 0.341176 0.415686 … 0.447059 0.435294;;; 0.321569 0.329412 … 0.294118 0.294118; 0.333333 0.313726 … 0.282353 0.286275; … ; 0.270588 0.27451 … 0.258824 0.258824; 0.254902 0.254902 … 0.254902 0.262745;;; 0.0392157 0.0431373 … 0.027451 0.027451; 0.0470588 0.0352941 … 0.0235294 0.027451; … ; 0.0156863 0.0196078 … 0.00784314 0.00784314; 0.00784314 0.00784314 … 0.00392157 0.00784314;;;;], θ=Float16[-0.8 -5.3 … -0.8 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -0.8 -5.3 … -7.7 -7.7; -0.8 -0.8 … -7.7 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.12549 0.133333 … 0.145098 0.156863; 0.113725 0.105882 … 0.137255 0.152941; … ; 0.0431373 0.0 … 0.172549 0.137255; 0.0745098 0.0117647 … 0.145098 0.164706;;; 0.294118 0.294118 … 0.313726 0.329412; 0.286275 0.282353 … 0.305882 0.317647; … ; 0.247059 0.207843 … 0.333333 0.313726; 0.321569 0.219608 … 0.313726 0.360784;;; 0.027451 0.027451 … 0.0392157 0.0431373; 0.027451 0.0235294 … 0.0352941 0.0352941; … ; 0.0117647 0.0 … 0.0470588 0.0352941; 0.027451 0.00392157 … 0.0352941 0.0627451;;;;], θ=Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -1.2 -0.8; … ; -5.3 -0.8 … -1.2 -1.2; -5.3 -0.8 … -1.2 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.439216 0.443137 … 0.431373 0.415686; 0.427451 0.435294 … 0.427451 0.431373; … ; 0.0196078 0.0156863 … 0.113725 0.12549; 0.0156863 0.0156863 … 0.129412 0.129412;;; 0.254902 0.247059 … 0.254902 0.247059; 0.25098 0.258824 … 0.254902 0.254902; … ; 0.211765 0.203922 … 0.278431 0.294118; 0.203922 0.203922 … 0.298039 0.298039;;; 0.00392157 0.0 … 0.00392157 0.0; 0.0 0.00784314 … 0.00392157 0.00392157; … ; 0.443137 0.439216 … 0.0196078 0.0313726; 0.439216 0.439216 … 0.0313726 0.027451;;;;], θ=Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])], DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample(x=Float32[0.12549 0.133333 … 0.160784 0.160784; 0.113725 0.105882 … 0.388235 0.403922; … ; 0.156863 0.164706 … 0.172549 0.286275; 0.14902 0.160784 … 0.266667 0.376471;;; 0.294118 0.294118 … 0.239216 0.188235; 0.286275 0.282353 … 0.247059 0.262745; … ; 0.329412 0.337255 … 0.172549 0.211765; 0.321569 0.321569 … 0.278431 0.211765;;; 0.027451 0.027451 … 0.0941176 0.0509804; 0.027451 0.0235294 … 0.0 0.00392157; … ; 0.0470588 0.0509804 … 0.00392157 0.0117647; 0.0431373 0.0392157 … 0.027451 0.0;;;;], θ=Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.431373 0.415686 … 0.101961 0.168627; 0.454902 0.337255 … 0.219608 0.32549; … ; 0.156863 0.156863 … 0.388235 0.45098; 0.141176 0.129412 … 0.431373 0.411765;;; 0.25098 0.254902 … 0.321569 0.352941; 0.262745 0.301961 … 0.447059 0.517647; … ; 0.32549 0.32549 … 0.262745 0.25098; 0.309804 0.298039 … 0.262745 0.266667;;; 0.00392157 0.00392157 … 0.0470588 0.0823529; 0.00784314 0.0352941 … 0.0941176 0.141176; … ; 0.0431373 0.0431373 … 0.0156863 0.00392157; 0.0352941 0.027451 … 0.00784314 0.0117647;;;;], θ=Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.137255 0.231373 … 0.129412 0.12549; 0.384314 0.262745 … 0.133333 0.12549; … ; 0.34902 0.235294 … 0.392157 0.34902; 0.172549 0.192157 … 0.411765 0.34902;;; 0.137255 0.231373 … 0.294118 0.290196; 0.384314 0.258824 … 0.298039 0.290196; … ; 0.34902 0.235294 … 0.243137 0.262745; 0.172549 0.192157 … 0.258824 0.254902;;; 0.129412 0.235294 … 0.027451 0.027451; 0.384314 0.258824 … 0.027451 0.0235294; … ; 0.34902 0.235294 … 0.00392157 0.0156863; 0.172549 0.192157 … 0.00784314 0.0117647;;;;], θ=Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.188235 0.313726 … 0.423529 0.403922; 0.254902 0.329412 … 0.411765 0.4; … ; 0.411765 0.396078 … 0.121569 0.117647; 0.435294 0.384314 … 0.109804 0.129412;;; 0.188235 0.313726 … 0.25098 0.235294; 0.254902 0.329412 … 0.243137 0.235294; … ; 0.243137 0.231373 … 0.282353 0.286275; 0.254902 0.223529 … 0.286275 0.290196;;; 0.188235 0.313726 … 0.00392157 0.0; 0.254902 0.329412 … 0.00392157 0.0; … ; 0.00784314 0.0 … 0.0235294 0.0235294; 0.0117647 0.0 … 0.0235294 0.0235294;;;;], θ=Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.156863 0.14902 … 0.415686 0.364706; 0.14902 0.145098 … 0.313726 0.34902; … ; 0.168627 0.164706 … 0.188235 0.235294; 0.164706 0.160784 … 0.129412 0.156863;;; 0.329412 0.313726 … 0.270588 0.243137; 0.317647 0.309804 … 0.262745 0.270588; … ; 0.333333 0.337255 … 0.188235 0.235294; 0.329412 0.321569 … 0.129412 0.156863;;; 0.0470588 0.0352941 … 0.0117647 0.0117647; 0.0392157 0.0352941 … 0.0196078 0.0235294; … ; 0.0470588 0.0470588 … 0.188235 0.235294; 0.0431373 0.0392157 … 0.129412 0.156863;;;;], θ=Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])])Creating a Policy
A DFLPolicy combines a statistical model (neural network) with a combinatorial optimizer. The benchmark provides utilities to generate appropriate models and optimizers:
model = generate_statistical_model(benchmark)
maximizer = generate_maximizer(benchmark; dijkstra=true)
policy = DFLPolicy(model, maximizer)DFLPolicy{Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(NNlib.relu), Vector{Float32}, Float32, Vector{Float32}}, Flux.MaxPool{2, 4}, Flux.Parallel{PartialFunctions.PartialFunction{nothing, nothing, typeof(Metalhead.addact), Tuple{typeof(NNlib.relu)}, @NamedTuple{}}, Tuple{typeof(identity), Flux.Chain{Tuple{Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}, typeof(NNlib.relu), Flux.Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, Flux.BatchNorm{typeof(identity), Vector{Float32}, Float32, Vector{Float32}}}}}}, Flux.AdaptiveMaxPool{4, 2}, typeof(DecisionFocusedLearningBenchmarks.Utils.average_tensor), typeof(DecisionFocusedLearningBenchmarks.Utils.neg_tensor), typeof(DecisionFocusedLearningBenchmarks.Utils.squeeze_last_dims)}}, typeof(DecisionFocusedLearningBenchmarks.Warcraft.dijkstra_maximizer)}(Chain(Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false), BatchNorm(64, relu), MaxPool((3, 3), pad=1, stride=2), Parallel(addact(NNlib.relu, ...), identity, Chain(Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64), relu, Conv((3, 3), 64 => 64, pad=1, bias=false), BatchNorm(64))), AdaptiveMaxPool((12, 12)), average_tensor, neg_tensor, squeeze_last_dims), DecisionFocusedLearningBenchmarks.Warcraft.dijkstra_maximizer)Configuring the Algorithm
We'll use the Perturbed Fenchel-Young Loss Imitation algorithm:
algorithm = PerturbedFenchelYoungLossImitation(;
nb_samples=100, # Number of perturbation samples for gradient estimation
ε=0.2, # Perturbation magnitude
threaded=true, # Use multi-threading for perturbations
training_optimizer=Adam(1e-3), # Flux optimizer with learning rate
seed=42, # Random seed for reproducibility
use_multiplicative_perturbation=true, # Use multiplicative perturbations
)PerturbedFenchelYoungLossImitation{Optimisers.Adam{Float64, Tuple{Float64, Float64}, Float64}, Int64}(100, 0.2, true, Optimisers.Adam(eta=0.001, beta=(0.9, 0.999), epsilon=1.0e-8), 42, true)Setting Up Metrics
We'll track several metrics during training:
Validation loss metric
val_loss_metric = FYLLossMetric(val_data, :validation_loss)FYLLossMetric{Vector{DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}}}(DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample(x=Float32[0.12549 0.133333 … 0.160784 0.160784; 0.113725 0.105882 … 0.388235 0.403922; … ; 0.156863 0.164706 … 0.172549 0.286275; 0.14902 0.160784 … 0.266667 0.376471;;; 0.294118 0.294118 … 0.239216 0.188235; 0.286275 0.282353 … 0.247059 0.262745; … ; 0.329412 0.337255 … 0.172549 0.211765; 0.321569 0.321569 … 0.278431 0.211765;;; 0.027451 0.027451 … 0.0941176 0.0509804; 0.027451 0.0235294 … 0.0 0.00392157; … ; 0.0470588 0.0509804 … 0.00392157 0.0117647; 0.0431373 0.0392157 … 0.027451 0.0;;;;], θ=Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.431373 0.415686 … 0.101961 0.168627; 0.454902 0.337255 … 0.219608 0.32549; … ; 0.156863 0.156863 … 0.388235 0.45098; 0.141176 0.129412 … 0.431373 0.411765;;; 0.25098 0.254902 … 0.321569 0.352941; 0.262745 0.301961 … 0.447059 0.517647; … ; 0.32549 0.32549 … 0.262745 0.25098; 0.309804 0.298039 … 0.262745 0.266667;;; 0.00392157 0.00392157 … 0.0470588 0.0823529; 0.00784314 0.0352941 … 0.0941176 0.141176; … ; 0.0431373 0.0431373 … 0.0156863 0.00392157; 0.0352941 0.027451 … 0.00784314 0.0117647;;;;], θ=Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.137255 0.231373 … 0.129412 0.12549; 0.384314 0.262745 … 0.133333 0.12549; … ; 0.34902 0.235294 … 0.392157 0.34902; 0.172549 0.192157 … 0.411765 0.34902;;; 0.137255 0.231373 … 0.294118 0.290196; 0.384314 0.258824 … 0.298039 0.290196; … ; 0.34902 0.235294 … 0.243137 0.262745; 0.172549 0.192157 … 0.258824 0.254902;;; 0.129412 0.235294 … 0.027451 0.027451; 0.384314 0.258824 … 0.027451 0.0235294; … ; 0.34902 0.235294 … 0.00392157 0.0156863; 0.172549 0.192157 … 0.00784314 0.0117647;;;;], θ=Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.188235 0.313726 … 0.423529 0.403922; 0.254902 0.329412 … 0.411765 0.4; … ; 0.411765 0.396078 … 0.121569 0.117647; 0.435294 0.384314 … 0.109804 0.129412;;; 0.188235 0.313726 … 0.25098 0.235294; 0.254902 0.329412 … 0.243137 0.235294; … ; 0.243137 0.231373 … 0.282353 0.286275; 0.254902 0.223529 … 0.286275 0.290196;;; 0.188235 0.313726 … 0.00392157 0.0; 0.254902 0.329412 … 0.00392157 0.0; … ; 0.00784314 0.0 … 0.0235294 0.0235294; 0.0117647 0.0 … 0.0235294 0.0235294;;;;], θ=Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.156863 0.14902 … 0.415686 0.364706; 0.14902 0.145098 … 0.313726 0.34902; … ; 0.168627 0.164706 … 0.188235 0.235294; 0.164706 0.160784 … 0.129412 0.156863;;; 0.329412 0.313726 … 0.270588 0.243137; 0.317647 0.309804 … 0.262745 0.270588; … ; 0.333333 0.337255 … 0.188235 0.235294; 0.329412 0.321569 … 0.129412 0.156863;;; 0.0470588 0.0352941 … 0.0117647 0.0117647; 0.0392157 0.0352941 … 0.0196078 0.0235294; … ; 0.0470588 0.0470588 … 0.188235 0.235294; 0.0431373 0.0392157 … 0.129412 0.156863;;;;], θ=Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])], LossAccumulator(:validation_loss, 0.0, 0))Validation gap metric
val_gap_metric = FunctionMetric(:val_gap, val_data) do ctx, data
compute_gap(benchmark, data, ctx.policy.statistical_model, ctx.policy.maximizer)
endFunctionMetric{Main.var"#2#3", Vector{DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}}}(Main.var"#2#3"(), :val_gap, DecisionFocusedLearningBenchmarks.Utils.DataSample{@NamedTuple{}, @NamedTuple{}, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample(x=Float32[0.12549 0.133333 … 0.160784 0.160784; 0.113725 0.105882 … 0.388235 0.403922; … ; 0.156863 0.164706 … 0.172549 0.286275; 0.14902 0.160784 … 0.266667 0.376471;;; 0.294118 0.294118 … 0.239216 0.188235; 0.286275 0.282353 … 0.247059 0.262745; … ; 0.329412 0.337255 … 0.172549 0.211765; 0.321569 0.321569 … 0.278431 0.211765;;; 0.027451 0.027451 … 0.0941176 0.0509804; 0.027451 0.0235294 … 0.0 0.00392157; … ; 0.0470588 0.0509804 … 0.00392157 0.0117647; 0.0431373 0.0392157 … 0.027451 0.0;;;;], θ=Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.431373 0.415686 … 0.101961 0.168627; 0.454902 0.337255 … 0.219608 0.32549; … ; 0.156863 0.156863 … 0.388235 0.45098; 0.141176 0.129412 … 0.431373 0.411765;;; 0.25098 0.254902 … 0.321569 0.352941; 0.262745 0.301961 … 0.447059 0.517647; … ; 0.32549 0.32549 … 0.262745 0.25098; 0.309804 0.298039 … 0.262745 0.266667;;; 0.00392157 0.00392157 … 0.0470588 0.0823529; 0.00784314 0.0352941 … 0.0941176 0.141176; … ; 0.0431373 0.0431373 … 0.0156863 0.00392157; 0.0352941 0.027451 … 0.00784314 0.0117647;;;;], θ=Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.137255 0.231373 … 0.129412 0.12549; 0.384314 0.262745 … 0.133333 0.12549; … ; 0.34902 0.235294 … 0.392157 0.34902; 0.172549 0.192157 … 0.411765 0.34902;;; 0.137255 0.231373 … 0.294118 0.290196; 0.384314 0.258824 … 0.298039 0.290196; … ; 0.34902 0.235294 … 0.243137 0.262745; 0.172549 0.192157 … 0.258824 0.254902;;; 0.129412 0.235294 … 0.027451 0.027451; 0.384314 0.258824 … 0.027451 0.0235294; … ; 0.34902 0.235294 … 0.00392157 0.0156863; 0.172549 0.192157 … 0.00784314 0.0117647;;;;], θ=Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], y=Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.188235 0.313726 … 0.423529 0.403922; 0.254902 0.329412 … 0.411765 0.4; … ; 0.411765 0.396078 … 0.121569 0.117647; 0.435294 0.384314 … 0.109804 0.129412;;; 0.188235 0.313726 … 0.25098 0.235294; 0.254902 0.329412 … 0.243137 0.235294; … ; 0.243137 0.231373 … 0.282353 0.286275; 0.254902 0.223529 … 0.286275 0.290196;;; 0.188235 0.313726 … 0.00392157 0.0; 0.254902 0.329412 … 0.00392157 0.0; … ; 0.00784314 0.0 … 0.0235294 0.0235294; 0.0117647 0.0 … 0.0235294 0.0235294;;;;], θ=Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1]), DataSample(x=Float32[0.156863 0.14902 … 0.415686 0.364706; 0.14902 0.145098 … 0.313726 0.34902; … ; 0.168627 0.164706 … 0.188235 0.235294; 0.164706 0.160784 … 0.129412 0.156863;;; 0.329412 0.313726 … 0.270588 0.243137; 0.317647 0.309804 … 0.262745 0.270588; … ; 0.333333 0.337255 … 0.188235 0.235294; 0.329412 0.321569 … 0.129412 0.156863;;; 0.0470588 0.0352941 … 0.0117647 0.0117647; 0.0392157 0.0352941 … 0.0196078 0.0235294; … ; 0.0470588 0.0470588 … 0.188235 0.235294; 0.0431373 0.0392157 … 0.129412 0.156863;;;;], θ=Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], y=Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1])])Training
Now we train the policy:
data_loader = DataLoader(train_data; batchsize=50)
history = train_policy!(
algorithm, policy, data_loader; epochs=50, metrics=(val_loss_metric, val_gap_metric)
)MVHistory{ValueHistories.History}
:training_loss => 51 elements {Int64,Float64}
:val_gap => 51 elements {Int64,Float16}
:validation_loss => 51 elements {Int64,Float64}Results Analysis
Let's examine the training progress:
Extract training history
train_loss_epochs, train_loss_values = get(history, :training_loss)
val_loss_epochs, val_loss_values = get(history, :validation_loss)
val_gap_epochs, val_gap_values = get(history, :val_gap)([0, 1, 2, 3, 4, 5, 6, 7, 8, 9 … 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], Float16[0.775, 0.775, 0.3687, 0.236, 0.236, 0.236, 0.3687, 0.3687, 0.354, 0.354 … 0.08826, 0.08826, 0.08826, 0.1367, 0.1367, 0.1367, 0.09686, 0.09686, 0.09686, 0.09686])Plot training and validation loss
p1 = plot(
train_loss_epochs,
train_loss_values;
label="Training",
xlabel="Epoch",
ylabel="FYL Loss",
title="Training Progress",
linewidth=2,
)
plot!(p1, val_loss_epochs, val_loss_values; label="Validation", linewidth=2)Plot gap evolution
p2 = plot(
val_gap_epochs,
val_gap_values;
label="Validation Gap",
xlabel="Epoch",
ylabel="Gap (Regret)",
title="Decision Quality",
linewidth=2,
)This page was generated using Literate.jl.