Path-finding on image maps

In this tutorial, we showcase DecisionFocusedLearningBenchmarks.jl capabilities on one of its main benchmarks: the Warcraft benchmark. This benchmark problem is a simple path-finding problem where the goal is to find the shortest path between the top left and bottom right corners of a given image map. The map is represented as a 2D image representing a 12x12 grid, each cell having an unknown travel cost depending on the terrain type.

First, let's load the package and create a benchmark object as follows:

using DecisionFocusedLearningBenchmarks
b = WarcraftBenchmark()
WarcraftBenchmark()

Dataset generation

These benchmark objects behave as generators that can generate various needed elements in order to build an algorithm to tackle the problem. First of all, all benchmarks are capable of generating datasets as needed, using the generate_dataset method. This method takes as input the benchmark object for which the dataset is to be generated, and a second argument specifying the number of samples to generate:

dataset = generate_dataset(b, 50);
┌ Warning: Checksum not provided, add to the Datadep Registration the following hash line
│   hash = "67fa5e32ac6e567ee41b891309a776e6b57c09a7b6b37a368ac12b5a64a781fb"
└ @ DataDeps ~/.julia/packages/DataDeps/Y2lje/src/verification.jl:44

7-Zip (a) [64] 17.04 : Copyright (c) 1999-2021 Igor Pavlov : 2017-08-28
p7zip Version 17.04 (locale=C.UTF-8,Utf16=on,HugeFiles=on,64 bits,4 CPUs AMD EPYC 7763 64-Core Processor                 (A00F11),ASM,AES-NI)

Scanning the drive for archives:
1 file, 77643058 bytes (75 MiB)

Extracting archive: /home/runner/.julia/scratchspaces/124859b0-ceae-595e-8997-d05f6a7a8dfe/datadeps/warcraft/data.zip
--
Path = /home/runner/.julia/scratchspaces/124859b0-ceae-595e-8997-d05f6a7a8dfe/datadeps/warcraft/data.zip
Type = zip
Physical Size = 77643058

Everything is Ok

Folders: 3
Files: 15
Size:       336974348
Compressed: 77643058

We obtain a vector of DataSample objects, containing all needed data for the problem. Subdatasets can be created through regular slicing:

train_dataset, test_dataset = dataset[1:45], dataset[46:50]
(DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.13333334 0.13333334 … 0.16470589 0.23529412; 0.1254902 0.105882354 … 0.16862746 0.3137255; … ; 0.24705882 0.24705882 … 0.019607844 0.015686275; 0.38431373 0.42352942 … 0.015686275 0.019607844;;; 0.30980393 0.29411766 … 0.34117648 0.35686275; 0.29803923 0.28235295 … 0.32941177 0.25490198; … ; 0.30588236 0.26666668 … 0.22745098 0.21960784; 0.23921569 0.23921569 … 0.21960784 0.22745098;;; 0.039215688 0.02745098 … 0.050980393 0.07058824; 0.03137255 0.023529412 … 0.047058824 0.023529412; … ; 0.043137256 0.02745098 … 0.45882353 0.45490196; 0.003921569 0.0 … 0.45490196 0.45882353;;;;], Float16[-0.8 -0.8 … -1.2 -0.8; -0.8 -0.8 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.019607844 0.019607844 … 0.42352942 0.4117647; 0.015686275 0.015686275 … 0.42745098 0.44313726; … ; 0.019607844 0.019607844 … 0.4392157 0.4392157; 0.019607844 0.019607844 … 0.43137255 0.43529412;;; 0.21960784 0.21568628 … 0.2509804 0.23921569; 0.21960784 0.20784314 … 0.2509804 0.25882354; … ; 0.22745098 0.22745098 … 0.25882354 0.25882354; 0.23137255 0.22745098 … 0.25490198 0.25882354;;; 0.45490196 0.44705883 … 0.007843138 0.003921569; 0.45490196 0.44313726 … 0.007843138 0.011764706; … ; 0.45882353 0.4627451 … 0.007843138 0.007843138; 0.4627451 0.45882353 … 0.003921569 0.007843138;;;;], Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -9.2 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.22352941 0.0 … 0.16078432 0.16470589; 0.20392157 0.05882353 … 0.15686275 0.15686275; … ; 0.023529412 0.003921569 … 0.003921569 0.023529412; 0.03529412 0.011764706 … 0.0 0.0;;; 0.40784314 0.2509804 … 0.3254902 0.3372549; 0.35686275 0.3254902 … 0.32156864 0.32941177; … ; 0.19607843 0.21568628 … 0.20784314 0.23529412; 0.21568628 0.19607843 … 0.23137255 0.24705882;;; 0.08235294 0.0 … 0.043137256 0.050980393; 0.05490196 0.03529412 … 0.043137256 0.043137256; … ; 0.0 0.0 … 0.0 0.007843138; 0.019607844 0.0 … 0.0 0.0;;;;], Float16[-5.3 -5.3 … -1.2 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.43137255 0.45490196 … 0.015686275 0.019607844; 0.4117647 0.36862746 … 0.015686275 0.015686275; … ; 0.14901961 0.15294118 … 0.15294118 0.16470589; 0.15686275 0.13725491 … 0.16470589 0.16470589;;; 0.2509804 0.25882354 … 0.21960784 0.22745098; 0.27058825 0.2784314 … 0.21960784 0.21960784; … ; 0.3137255 0.32156864 … 0.32156864 0.33333334; 0.3254902 0.3019608 … 0.33333334 0.33333334;;; 0.003921569 0.007843138 … 0.45490196 0.4627451; 0.015686275 0.023529412 … 0.45490196 0.45490196; … ; 0.039215688 0.043137256 … 0.039215688 0.047058824; 0.043137256 0.03137255 … 0.047058824 0.047058824;;;;], Float16[-1.2 -7.7 … -7.7 -7.7; -7.7 -1.2 … -7.7 -7.7; … ; -0.8 -0.8 … -0.8 -5.3; -0.8 -1.2 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.15294118 0.31764707 … 0.43529412 0.40392157; 0.3764706 0.30588236 … 0.43137255 0.4117647; … ; 0.2 0.12156863 … 0.36078432 0.4117647; 0.09019608 0.17254902 … 0.21960784 0.44705883;;; 0.15294118 0.31764707 … 0.25882354 0.23921569; 0.3764706 0.30588236 … 0.25490198 0.24313726; … ; 0.2 0.12156863 … 0.22352941 0.28627452; 0.09019608 0.17254902 … 0.16470589 0.25882354;;; 0.15294118 0.31764707 … 0.007843138 0.0; 0.3764706 0.30588236 … 0.003921569 0.0; … ; 0.2 0.12156863 … 0.02745098 0.10980392; 0.09019608 0.17254902 … 0.08627451 0.0;;;;], Float16[-9.2 -9.2 … -0.8 -1.2; -9.2 -9.2 … -0.8 -1.2; … ; -9.2 -9.2 … -9.2 -1.2; -9.2 -9.2 … -9.2 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.16470589 0.16470589; 0.015686275 0.015686275 … 0.16862746 0.16470589; … ; 0.16862746 0.16470589 … 0.43137255 0.42745098; 0.15686275 0.16078432 … 0.43137255 0.42745098;;; 0.22745098 0.22352941 … 0.33333334 0.32941177; 0.21568628 0.20784314 … 0.3372549 0.32941177; … ; 0.3372549 0.3372549 … 0.25490198 0.25490198; 0.32941177 0.31764707 … 0.25490198 0.25490198;;; 0.45882353 0.45490196 … 0.047058824 0.047058824; 0.4509804 0.44313726 … 0.047058824 0.047058824; … ; 0.047058824 0.047058824 … 0.003921569 0.003921569; 0.043137256 0.039215688 … 0.003921569 0.003921569;;;;], Float16[-7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8; … ; -1.2 -1.2 … -9.2 -9.2; -0.8 -1.2 … -9.2 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40784314 0.41568628 … 0.38039216 0.28627452; 0.40784314 0.4117647 … 0.3764706 0.30980393; … ; 0.1254902 0.11372549 … 0.15294118 0.16078432; 0.12941177 0.12941177 … 0.15686275 0.15294118;;; 0.23921569 0.24313726 … 0.20392157 0.2784314; 0.23921569 0.24313726 … 0.2784314 0.25882354; … ; 0.29803923 0.2784314 … 0.31764707 0.32941177; 0.29411766 0.29803923 … 0.3254902 0.33333334;;; 0.003921569 0.007843138 … 0.0 0.03529412; 0.003921569 0.003921569 … 0.023529412 0.023529412; … ; 0.03137255 0.019607844 … 0.039215688 0.047058824; 0.02745098 0.03137255 … 0.043137256 0.047058824;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2; … ; -0.8 -1.2 … -5.3 -5.3; -0.8 -0.8 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.015686275 0.015686275 … 0.4392157 0.42745098; 0.015686275 0.015686275 … 0.43137255 0.43529412; … ; 0.3372549 0.3647059 … 0.16862746 0.19607843; 0.2627451 0.36862746 … 0.22745098 0.34117648;;; 0.20784314 0.2 … 0.25882354 0.2509804; 0.20392157 0.20784314 … 0.25490198 0.25490198; … ; 0.23921569 0.27450982 … 0.34901962 0.3764706; 0.31764707 0.24313726 … 0.36078432 0.24705882;;; 0.44313726 0.43529412 … 0.011764706 0.007843138; 0.4392157 0.44313726 … 0.007843138 0.007843138; … ; 0.015686275 0.019607844 … 0.05490196 0.078431375; 0.047058824 0.011764706 … 0.07058824 0.019607844;;;;], Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -1.2; -1.2 -1.2 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.44313726 0.4392157; 0.015686275 0.015686275 … 0.43529412 0.44313726; … ; 0.39607844 0.39215687 … 0.4392157 0.4392157; 0.39215687 0.39215687 … 0.43137255 0.43529412;;; 0.23529412 0.23137255 … 0.25882354 0.25882354; 0.21960784 0.21568628 … 0.25882354 0.25882354; … ; 0.23529412 0.23137255 … 0.25882354 0.25882354; 0.23137255 0.22745098 … 0.25490198 0.25882354;;; 0.46666667 0.4627451 … 0.007843138 0.007843138; 0.45490196 0.45490196 … 0.007843138 0.007843138; … ; 0.0 0.0 … 0.007843138 0.007843138; 0.0 0.0 … 0.003921569 0.007843138;;;;], Float16[-7.7 -7.7 … -7.7 -1.2; -7.7 -7.7 … -7.7 -1.2; … ; -1.2 -1.2 … -1.2 -9.2; -1.2 -1.2 … -1.2 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.0 0.023529412 … 0.4 0.39215687; 0.019607844 0.003921569 … 0.39215687 0.3764706; … ; 0.13725491 0.13725491 … 0.15686275 0.15686275; 0.1254902 0.105882354 … 0.16862746 0.16862746;;; 0.1882353 0.32156864 … 0.23529412 0.22745098; 0.27058825 0.24705882 … 0.23137255 0.21960784; … ; 0.3137255 0.30980393 … 0.34901962 0.33333334; 0.29803923 0.28235295 … 0.34509805 0.3372549;;; 0.0 0.007843138 … 0.0 0.0; 0.015686275 0.003921569 … 0.0 0.0; … ; 0.039215688 0.03529412 … 0.05882353 0.047058824; 0.03137255 0.023529412 … 0.05882353 0.050980393;;;;], Float16[-5.3 -5.3 … -1.2 -1.2; -5.3 -5.3 … -1.2 -1.2; … ; -5.3 -5.3 … -1.2 -1.2; -0.8 -0.8 … -0.8 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)  …  DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40392157 0.43137255 … 0.44313726 0.43529412; 0.43529412 0.43529412 … 0.42745098 0.43529412; … ; 0.023529412 0.003921569 … 0.003921569 0.023529412; 0.03529412 0.011764706 … 0.0 0.0;;; 0.23529412 0.25490198 … 0.25882354 0.25490198; 0.25882354 0.25490198 … 0.2509804 0.25490198; … ; 0.19607843 0.21568628 … 0.20784314 0.23529412; 0.21568628 0.19607843 … 0.23137255 0.24705882;;; 0.0 0.007843138 … 0.007843138 0.007843138; 0.007843138 0.007843138 … 0.007843138 0.007843138; … ; 0.0 0.0 … 0.0 0.007843138; 0.019607844 0.0 … 0.0 0.0;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40784314 0.41568628 … 0.38039216 0.28627452; 0.40784314 0.4117647 … 0.3764706 0.30980393; … ; 0.16470589 0.16470589 … 0.44313726 0.42745098; 0.15294118 0.16078432 … 0.42352942 0.41960785;;; 0.23921569 0.24313726 … 0.20392157 0.2784314; 0.23921569 0.24313726 … 0.2784314 0.25882354; … ; 0.32941177 0.3254902 … 0.2627451 0.2509804; 0.32156864 0.32941177 … 0.2509804 0.24705882;;; 0.003921569 0.007843138 … 0.0 0.03529412; 0.003921569 0.003921569 … 0.023529412 0.023529412; … ; 0.047058824 0.043137256 … 0.007843138 0.003921569; 0.039215688 0.047058824 … 0.003921569 0.007843138;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -7.7 … -1.2 -1.2; … ; -0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.27450982 0.019607844 … 0.003921569 0.14117648; 0.13725491 0.0627451 … 0.019607844 0.08235294; … ; 0.42745098 0.42745098 … 0.39607844 0.4; 0.42745098 0.40392157 … 0.39215687 0.39607844;;; 0.47058824 0.23137255 … 0.25490198 0.3764706; 0.40392157 0.31764707 … 0.23137255 0.35686275; … ; 0.2509804 0.2509804 … 0.23529412 0.23529412; 0.2509804 0.23529412 … 0.22745098 0.23137255;;; 0.14117648 0.015686275 … 0.003921569 0.07450981; 0.07058824 0.039215688 … 0.007843138 0.03529412; … ; 0.007843138 0.003921569 … 0.0 0.0; 0.003921569 0.003921569 … 0.0 0.0;;;;], Float16[-5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3; … ; -1.2 -0.8 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.14509805 0.13725491 … 0.12941177 0.13725491; 0.12156863 0.1254902 … 0.11764706 0.10980392; … ; 0.023529412 0.003921569 … 0.078431375 0.011764706; 0.03529412 0.011764706 … 0.03529412 0.05882353;;; 0.30980393 0.30980393 … 0.29411766 0.30588236; 0.29411766 0.29411766 … 0.29803923 0.28627452; … ; 0.19607843 0.21568628 … 0.3019608 0.21176471; 0.21568628 0.19607843 … 0.28235295 0.3019608;;; 0.03529412 0.03529412 … 0.02745098 0.03137255; 0.02745098 0.02745098 … 0.03137255 0.02745098; … ; 0.0 0.0 … 0.03529412 0.007843138; 0.019607844 0.0 … 0.015686275 0.03529412;;;;], Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -0.8 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.45882353 0.28235295 … 0.43137255 0.41568628; 0.35686275 0.17254902 … 0.42745098 0.43137255; … ; 0.38431373 0.39607844 … 0.17254902 0.16862746; 0.3882353 0.39215687 … 0.16470589 0.17254902;;; 0.2509804 0.27058825 … 0.25490198 0.24705882; 0.27058825 0.19607843 … 0.25490198 0.25490198; … ; 0.22352941 0.23137255 … 0.3529412 0.35686275; 0.22745098 0.23137255 … 0.3372549 0.34509805;;; 0.003921569 0.015686275 … 0.003921569 0.0; 0.007843138 0.011764706 … 0.003921569 0.003921569; … ; 0.0 0.0 … 0.05882353 0.0627451; 0.0 0.0 … 0.050980393 0.05490196;;;;], Float16[-7.7 -1.2 … -7.7 -1.2; -7.7 -1.2 … -1.2 -7.7; … ; -1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.12156863 0.23137255 … 0.43529412 0.40784314; 0.12156863 0.26666668 … 0.43529412 0.42352942; … ; 0.1254902 0.1254902 … 0.3882353 0.4509804; 0.11764706 0.105882354 … 0.43137255 0.4117647;;; 0.12156863 0.23137255 … 0.25490198 0.23921569; 0.12156863 0.26666668 … 0.25882354 0.2509804; … ; 0.3019608 0.3019608 … 0.2627451 0.2509804; 0.29411766 0.28235295 … 0.2627451 0.26666668;;; 0.12156863 0.23137255 … 0.007843138 0.0; 0.12156863 0.26666668 … 0.003921569 0.0; … ; 0.03137255 0.03137255 … 0.015686275 0.003921569; 0.03137255 0.023529412 … 0.007843138 0.011764706;;;;], Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.42745098 0.42352942; 0.015686275 0.015686275 … 0.43529412 0.44313726; … ; 0.015686275 0.015686275 … 0.11372549 0.1254902; 0.015686275 0.019607844 … 0.12941177 0.12941177;;; 0.23529412 0.23137255 … 0.25490198 0.2509804; 0.21960784 0.21568628 … 0.25882354 0.2627451; … ; 0.21176471 0.21176471 … 0.2784314 0.29803923; 0.21568628 0.22352941 … 0.29803923 0.29411766;;; 0.46666667 0.4627451 … 0.003921569 0.0; 0.45490196 0.45490196 … 0.007843138 0.011764706; … ; 0.44705883 0.44705883 … 0.019607844 0.03137255; 0.44705883 0.45490196 … 0.03137255 0.02745098;;;;], Float16[-7.7 -1.2 … -1.2 -1.2; -7.7 -1.2 … -0.8 -1.2; … ; -7.7 -1.2 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.16078432 0.16078432 … 0.13333334 0.1254902; 0.16078432 0.15686275 … 0.105882354 0.11372549; … ; 0.3647059 0.3882353 … 0.43529412 0.4392157; 0.34117648 0.41568628 … 0.44705883 0.43529412;;; 0.32156864 0.32941177 … 0.29411766 0.29411766; 0.33333334 0.3137255 … 0.28235295 0.28627452; … ; 0.27058825 0.27450982 … 0.25882354 0.25882354; 0.25490198 0.25490198 … 0.25490198 0.2627451;;; 0.039215688 0.043137256 … 0.02745098 0.02745098; 0.047058824 0.03529412 … 0.023529412 0.02745098; … ; 0.015686275 0.019607844 … 0.007843138 0.007843138; 0.007843138 0.007843138 … 0.003921569 0.007843138;;;;], Float16[-0.8 -5.3 … -0.8 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -0.8 -5.3 … -7.7 -7.7; -0.8 -0.8 … -7.7 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.14509805 0.15686275; 0.11372549 0.105882354 … 0.13725491 0.15294118; … ; 0.043137256 0.0 … 0.17254902 0.13725491; 0.07450981 0.011764706 … 0.14509805 0.16470589;;; 0.29411766 0.29411766 … 0.3137255 0.32941177; 0.28627452 0.28235295 … 0.30588236 0.31764707; … ; 0.24705882 0.20784314 … 0.33333334 0.3137255; 0.32156864 0.21960784 … 0.3137255 0.36078432;;; 0.02745098 0.02745098 … 0.039215688 0.043137256; 0.02745098 0.023529412 … 0.03529412 0.03529412; … ; 0.011764706 0.0 … 0.047058824 0.03529412; 0.02745098 0.003921569 … 0.03529412 0.0627451;;;;], Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -1.2 -0.8; … ; -5.3 -0.8 … -1.2 -1.2; -5.3 -0.8 … -1.2 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.4392157 0.44313726 … 0.43137255 0.41568628; 0.42745098 0.43529412 … 0.42745098 0.43137255; … ; 0.019607844 0.015686275 … 0.11372549 0.1254902; 0.015686275 0.015686275 … 0.12941177 0.12941177;;; 0.25490198 0.24705882 … 0.25490198 0.24705882; 0.2509804 0.25882354 … 0.25490198 0.25490198; … ; 0.21176471 0.20392157 … 0.2784314 0.29411766; 0.20392157 0.20392157 … 0.29803923 0.29803923;;; 0.003921569 0.0 … 0.003921569 0.0; 0.0 0.007843138 … 0.003921569 0.003921569; … ; 0.44313726 0.4392157 … 0.019607844 0.03137255; 0.4392157 0.4392157 … 0.03137255 0.02745098;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)], DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.16078432 0.16078432; 0.11372549 0.105882354 … 0.3882353 0.40392157; … ; 0.15686275 0.16470589 … 0.17254902 0.28627452; 0.14901961 0.16078432 … 0.26666668 0.3764706;;; 0.29411766 0.29411766 … 0.23921569 0.1882353; 0.28627452 0.28235295 … 0.24705882 0.2627451; … ; 0.32941177 0.3372549 … 0.17254902 0.21176471; 0.32156864 0.32156864 … 0.2784314 0.21176471;;; 0.02745098 0.02745098 … 0.09411765 0.050980393; 0.02745098 0.023529412 … 0.0 0.003921569; … ; 0.047058824 0.050980393 … 0.003921569 0.011764706; 0.043137256 0.039215688 … 0.02745098 0.0;;;;], Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.43137255 0.41568628 … 0.101960786 0.16862746; 0.45490196 0.3372549 … 0.21960784 0.3254902; … ; 0.15686275 0.15686275 … 0.3882353 0.4509804; 0.14117648 0.12941177 … 0.43137255 0.4117647;;; 0.2509804 0.25490198 … 0.32156864 0.3529412; 0.2627451 0.3019608 … 0.44705883 0.5176471; … ; 0.3254902 0.3254902 … 0.2627451 0.2509804; 0.30980393 0.29803923 … 0.2627451 0.26666668;;; 0.003921569 0.003921569 … 0.047058824 0.08235294; 0.007843138 0.03529412 … 0.09411765 0.14117648; … ; 0.043137256 0.043137256 … 0.015686275 0.003921569; 0.03529412 0.02745098 … 0.007843138 0.011764706;;;;], Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.13725491 0.23137255 … 0.12941177 0.1254902; 0.38431373 0.2627451 … 0.13333334 0.1254902; … ; 0.34901962 0.23529412 … 0.39215687 0.34901962; 0.17254902 0.19215687 … 0.4117647 0.34901962;;; 0.13725491 0.23137255 … 0.29411766 0.2901961; 0.38431373 0.25882354 … 0.29803923 0.2901961; … ; 0.34901962 0.23529412 … 0.24313726 0.2627451; 0.17254902 0.19215687 … 0.25882354 0.25490198;;; 0.12941177 0.23529412 … 0.02745098 0.02745098; 0.38431373 0.25882354 … 0.02745098 0.023529412; … ; 0.34901962 0.23529412 … 0.003921569 0.015686275; 0.17254902 0.19215687 … 0.007843138 0.011764706;;;;], Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1882353 0.3137255 … 0.42352942 0.40392157; 0.25490198 0.32941177 … 0.4117647 0.4; … ; 0.4117647 0.39607844 … 0.12156863 0.11764706; 0.43529412 0.38431373 … 0.10980392 0.12941177;;; 0.1882353 0.3137255 … 0.2509804 0.23529412; 0.25490198 0.32941177 … 0.24313726 0.23529412; … ; 0.24313726 0.23137255 … 0.28235295 0.28627452; 0.25490198 0.22352941 … 0.28627452 0.2901961;;; 0.1882353 0.3137255 … 0.003921569 0.0; 0.25490198 0.32941177 … 0.003921569 0.0; … ; 0.007843138 0.0 … 0.023529412 0.023529412; 0.011764706 0.0 … 0.023529412 0.023529412;;;;], Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.15686275 0.14901961 … 0.41568628 0.3647059; 0.14901961 0.14509805 … 0.3137255 0.34901962; … ; 0.16862746 0.16470589 … 0.1882353 0.23529412; 0.16470589 0.16078432 … 0.12941177 0.15686275;;; 0.32941177 0.3137255 … 0.27058825 0.24313726; 0.31764707 0.30980393 … 0.2627451 0.27058825; … ; 0.33333334 0.3372549 … 0.1882353 0.23529412; 0.32941177 0.32156864 … 0.12941177 0.15686275;;; 0.047058824 0.03529412 … 0.011764706 0.011764706; 0.039215688 0.03529412 … 0.019607844 0.023529412; … ; 0.047058824 0.047058824 … 0.1882353 0.23529412; 0.043137256 0.039215688 … 0.12941177 0.15686275;;;;], Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)])

And getting an individual sample will return a DataSample with four fields: x, instance, θ, and y:

sample = test_dataset[1]
DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.16078432 0.16078432; 0.11372549 0.105882354 … 0.3882353 0.40392157; … ; 0.15686275 0.16470589 … 0.17254902 0.28627452; 0.14901961 0.16078432 … 0.26666668 0.3764706;;; 0.29411766 0.29411766 … 0.23921569 0.1882353; 0.28627452 0.28235295 … 0.24705882 0.2627451; … ; 0.32941177 0.3372549 … 0.17254902 0.21176471; 0.32156864 0.32156864 … 0.2784314 0.21176471;;; 0.02745098 0.02745098 … 0.09411765 0.050980393; 0.02745098 0.023529412 … 0.0 0.003921569; … ; 0.047058824 0.050980393 … 0.003921569 0.011764706; 0.043137256 0.039215688 … 0.02745098 0.0;;;;], Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)

x correspond to the input features, i.e. the input image (3D array) in the Warcraft benchmark case:

x = sample.x
96×96×3×1 Array{Float32, 4}:
[:, :, 1, 1] =
 0.12549   0.133333  0.117647  0.109804  …  0.176471   0.160784   0.160784
 0.113725  0.105882  0.117647  0.145098     0.411765   0.388235   0.403922
 0.101961  0.137255  0.14902   0.141176     0.439216   0.443137   0.439216
 0.121569  0.137255  0.156863  0.137255     0.435294   0.435294   0.435294
 0.113725  0.152941  0.156863  0.156863     0.435294   0.439216   0.431373
 0.137255  0.145098  0.156863  0.164706  …  0.435294   0.435294   0.435294
 0.156863  0.164706  0.160784  0.164706     0.427451   0.431373   0.427451
 0.14902   0.160784  0.168627  0.156863     0.454902   0.4        0.364706
 0.160784  0.164706  0.160784  0.152941     0.372549   0.427451   0.294118
 0.156863  0.156863  0.156863  0.164706     0.286275   0.301961   0.231373
 ⋮                                       ⋱                        ⋮
 0.156863  0.133333  0.156863  0.160784     0.0196078  0.0196078  0.0
 0.141176  0.14902   0.141176  0.156863     0.0117647  0.027451   0.0431373
 0.117647  0.145098  0.141176  0.168627     0.0        0.105882   0.168627
 0.109804  0.160784  0.141176  0.152941  …  0.0509804  0.0235294  0.00784314
 0.117647  0.129412  0.152941  0.14902      0.0588235  0.0313726  0.0666667
 0.113725  0.152941  0.156863  0.156863     0.0352941  0.113725   0.211765
 0.137255  0.145098  0.156863  0.164706     0.0588235  0.196078   0.258824
 0.156863  0.164706  0.160784  0.164706     0.133333   0.172549   0.286275
 0.14902   0.160784  0.168627  0.156863  …  0.172549   0.266667   0.376471

[:, :, 2, 1] =
 0.294118  0.294118  0.290196  0.286275  …  0.231373  0.239216  0.188235
 0.286275  0.282353  0.286275  0.305882     0.247059  0.247059  0.262745
 0.282353  0.309804  0.313726  0.305882     0.254902  0.254902  0.254902
 0.294118  0.305882  0.32549   0.301961     0.258824  0.258824  0.258824
 0.286275  0.321569  0.317647  0.32549      0.254902  0.258824  0.254902
 0.313726  0.313726  0.329412  0.333333  …  0.254902  0.258824  0.258824
 0.329412  0.337255  0.32549   0.337255     0.254902  0.25098   0.25098
 0.321569  0.321569  0.337255  0.32549      0.254902  0.266667  0.231373
 0.329412  0.337255  0.32549   0.317647     0.262745  0.231373  0.262745
 0.32549   0.317647  0.321569  0.333333     0.223529  0.254902  0.337255
 ⋮                                       ⋱                      ⋮
 0.32549   0.298039  0.329412  0.32549      0.211765  0.219608  0.192157
 0.317647  0.32549   0.313726  0.321569     0.203922  0.203922  0.184314
 0.290196  0.309804  0.309804  0.333333     0.219608  0.2       0.176471
 0.286275  0.32549   0.309804  0.321569  …  0.227451  0.172549  0.172549
 0.294118  0.301961  0.32549   0.309804     0.180392  0.172549  0.172549
 0.286275  0.321569  0.317647  0.32549      0.141176  0.168627  0.2
 0.313726  0.313726  0.329412  0.333333     0.145098  0.168627  0.301961
 0.329412  0.337255  0.32549   0.337255     0.188235  0.172549  0.211765
 0.321569  0.321569  0.337255  0.32549   …  0.235294  0.278431  0.211765

[:, :, 3, 1] =
 0.027451   0.027451   0.0235294  …  0.0745098   0.0941176   0.0509804
 0.027451   0.0235294  0.0235294     0.0         0.0         0.00392157
 0.0235294  0.0352941  0.0352941     0.00784314  0.00784314  0.00392157
 0.0313726  0.0313726  0.0431373     0.00784314  0.00784314  0.00784314
 0.027451   0.0392157  0.0392157     0.00392157  0.00784314  0.00392157
 0.0392157  0.0392157  0.0431373  …  0.00392157  0.00392157  0.00784314
 0.0470588  0.0509804  0.0431373     0.00392157  0.00392157  0.00392157
 0.0431373  0.0392157  0.0509804     0.00784314  0.0117647   0.00392157
 0.0470588  0.0509804  0.0431373     0.0196078   0.0         0.0235294
 0.0392157  0.0392157  0.0392157     0.0117647   0.027451    0.0588235
 ⋮                                ⋱                          ⋮
 0.0431373  0.0313726  0.0470588     0.447059    0.45098     0.439216
 0.0392157  0.0431373  0.0392157     0.443137    0.411765    0.356863
 0.027451   0.0313726  0.0392157     0.478431    0.290196    0.172549
 0.027451   0.0431373  0.0352941  …  0.407843    0.364706    0.407843
 0.027451   0.0313726  0.0431373     0.317647    0.360784    0.262745
 0.027451   0.0392157  0.0392157     0.294118    0.235294    0.0196078
 0.0392157  0.0392157  0.0431373     0.286275    0.00392157  0.0196078
 0.0470588  0.0509804  0.0431373     0.239216    0.00392157  0.0117647
 0.0431373  0.0392157  0.0509804  …  0.0313726   0.027451    0.0

θ_true correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:

θ_true = sample.θ_true
12×12 Matrix{Float16}:
 -0.8  -0.8  -0.8  -0.8  -0.8  -1.2  -7.7  -1.2  -1.2  -1.2  -1.2  -1.2
 -0.8  -0.8  -0.8  -0.8  -0.8  -1.2  -7.7  -1.2  -0.8  -0.8  -0.8  -0.8
 -0.8  -0.8  -0.8  -0.8  -0.8  -1.2  -7.7  -1.2  -1.2  -0.8  -5.3  -0.8
 -0.8  -0.8  -1.2  -1.2  -1.2  -1.2  -7.7  -1.2  -0.8  -0.8  -5.3  -0.8
 -0.8  -0.8  -1.2  -7.7  -7.7  -7.7  -7.7  -1.2  -0.8  -0.8  -0.8  -0.8
 -0.8  -1.2  -1.2  -7.7  -7.7  -7.7  -7.7  -1.2  -0.8  -1.2  -0.8  -1.2
 -0.8  -0.8  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2
 -0.8  -1.2  -1.2  -9.2  -1.2  -1.2  -7.7  -7.7  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -9.2  -1.2  -1.2  -1.2  -1.2  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -9.2  -9.2  -1.2  -0.8  -1.2  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -9.2  -9.2  -1.2  -0.8  -1.2  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -1.2  -1.2  -1.2  -0.8  -1.2  -1.2  -1.2  -7.7  -7.7

y_true correspond to the optimal shortest path, encoded as a binary matrix:

y_true = sample.y_true
12×12 BitMatrix:
 1  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0
 0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  1  0  0  0  0  0  0  0
 0  0  0  0  0  1  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  1  0
 0  0  0  0  0  0  0  0  1  1  0  1

instance is not used in this benchmark, therefore set to nothing:

isnothing(sample.instance)
true

For some benchmarks, we provide the following plotting method plot_data to visualize the data:

plot_data(b, sample)
Example block output

We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.

Building a pipeline

DecisionFocusedLearningBenchmarks also provides methods to build an hybrid machine learning and combinatorial optimization pipeline for the benchmark. First, the generate_statistical_model method generates a machine learning predictor to predict cell weights from the input image:

model = generate_statistical_model(b)
Chain(
  Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
  BatchNorm(64, relu),                  # 128 parameters, plus 128
  MaxPool((3, 3), pad=1, stride=2),
  Parallel(
    PartialFunction(
      "",
      Metalhead.addact,
      (NNlib.relu,),
      NamedTuple(),
    ),
    identity,
    Chain(
      Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
      BatchNorm(64),                    # 128 parameters, plus 128
      NNlib.relu,
      Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
      BatchNorm(64),                    # 128 parameters, plus 128
    ),
  ),
  AdaptiveMaxPool((12, 12)),
  DecisionFocusedLearningBenchmarks.Utils.average_tensor,
  DecisionFocusedLearningBenchmarks.Utils.neg_tensor,
  DecisionFocusedLearningBenchmarks.Utils.squeeze_last_dims,
)         # Total: 9 trainable arrays, 83_520 parameters,
          # plus 6 non-trainable, 384 parameters, summarysize 328.945 KiB.

In the case of the Warcraft benchmark, the model is a convolutional neural network built using the Flux.jl package.

θ = model(x)
12×12 Matrix{Float32}:
 -0.721526  -0.718137  -0.717626  …  -0.731732  -0.7317    -0.736122
 -0.718858  -0.714253  -0.713772     -0.726469  -0.725278  -0.731283
 -0.718467  -0.715424  -0.721073     -0.726078  -0.720642  -0.726064
 -0.718754  -0.719024  -0.728877     -0.723631  -0.721364  -0.727132
 -0.719706  -0.722873  -0.729793     -0.722864  -0.725774  -0.727845
 -0.722395  -0.728315  -0.729841  …  -0.728819  -0.729245  -0.733978
 -0.72076   -0.72767   -0.729481     -0.730372  -0.731709  -0.736237
 -0.722009  -0.728514  -0.73406      -0.731785  -0.734136  -0.736757
 -0.72086   -0.727749  -0.736777     -0.719128  -0.717915  -0.723152
 -0.718745  -0.721057  -0.734956     -0.717407  -0.714833  -0.72134
 -0.718924  -0.720848  -0.73593   …  -0.724723  -0.720396  -0.722206
 -0.72244   -0.724106  -0.737705     -0.741514  -0.738982  -0.729798

Note that the model is not trained yet, and its parameters are randomly initialized.

Finally, the generate_maximizer method can be used to generate a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:

maximizer = generate_maximizer(b; dijkstra=true)
dijkstra_maximizer (generic function with 1 method)

In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.

y = maximizer(θ)
12×12 Matrix{Int64}:
 1  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0
 0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  1  0  0  0  0  0  0  0
 0  0  0  0  0  1  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  0  0
 0  0  0  0  0  0  0  0  1  0  0  0
 0  0  0  0  0  0  0  0  0  1  0  0
 0  0  0  0  0  0  0  0  0  0  1  0
 0  0  0  0  0  0  0  0  0  0  0  1

As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.

plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
Example block output

We can evaluate the current pipeline performance using the optimality gap metric:

starting_gap = compute_gap(b, test_dataset, model, maximizer)
Float16(0.775)

Using a learning algorithm

We can now train the model using the InferOpt.jl package:

using InferOpt
using Flux
using Plots

perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
loss = FenchelYoungLoss(perturbed_maximizer)

starting_gap = compute_gap(b, test_dataset, model, maximizer)

opt_state = Flux.setup(Adam(1e-3), model)
loss_history = Float64[]
for epoch in 1:50
    val, grads = Flux.withgradient(model) do m
        sum(loss(m(x), y_true) for (; x, y_true) in train_dataset) / length(train_dataset)
    end
    Flux.update!(opt_state, model, grads[1])
    push!(loss_history, val)
end

plot(loss_history; xlabel="Epoch", ylabel="Loss", title="Training loss")
Example block output
final_gap = compute_gap(b, test_dataset, model, maximizer)
Float16(0.02493)
θ = model(x)
y = maximizer(θ)
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
Example block output

This page was generated using Literate.jl.