Path-finding on image maps
In this tutorial, we showcase DecisionFocusedLearningBenchmarks.jl capabilities on one of its main benchmarks: the Warcraft benchmark. This benchmark problem is a simple path-finding problem where the goal is to find the shortest path between the top left and bottom right corners of a given image map. The map is represented as a 2D image representing a 12x12 grid, each cell having an unknown travel cost depending on the terrain type.
First, let's load the package and create a benchmark object as follows:
using DecisionFocusedLearningBenchmarks
b = WarcraftBenchmark()
WarcraftBenchmark()
Dataset generation
These benchmark objects behave as generators that can generate various needed elements in order to build an algorithm to tackle the problem. First of all, all benchmarks are capable of generating datasets as needed, using the generate_dataset
method. This method takes as input the benchmark object for which the dataset is to be generated, and a second argument specifying the number of samples to generate:
dataset = generate_dataset(b, 50);
┌ Warning: Checksum not provided, add to the Datadep Registration the following hash line
│ hash = "67fa5e32ac6e567ee41b891309a776e6b57c09a7b6b37a368ac12b5a64a781fb"
└ @ DataDeps ~/.julia/packages/DataDeps/Y2lje/src/verification.jl:44
7-Zip (a) [64] 17.04 : Copyright (c) 1999-2021 Igor Pavlov : 2017-08-28
p7zip Version 17.04 (locale=C.UTF-8,Utf16=on,HugeFiles=on,64 bits,4 CPUs AMD EPYC 7763 64-Core Processor (A00F11),ASM,AES-NI)
Scanning the drive for archives:
1 file, 77643058 bytes (75 MiB)
Extracting archive: /home/runner/.julia/scratchspaces/124859b0-ceae-595e-8997-d05f6a7a8dfe/datadeps/warcraft/data.zip
--
Path = /home/runner/.julia/scratchspaces/124859b0-ceae-595e-8997-d05f6a7a8dfe/datadeps/warcraft/data.zip
Type = zip
Physical Size = 77643058
Everything is Ok
Folders: 3
Files: 15
Size: 336974348
Compressed: 77643058
We obtain a vector of DataSample
objects, containing all needed data for the problem. Subdatasets can be created through regular slicing:
train_dataset, test_dataset = dataset[1:45], dataset[46:50]
(DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.13333334 0.13333334 … 0.16470589 0.23529412; 0.1254902 0.105882354 … 0.16862746 0.3137255; … ; 0.24705882 0.24705882 … 0.019607844 0.015686275; 0.38431373 0.42352942 … 0.015686275 0.019607844;;; 0.30980393 0.29411766 … 0.34117648 0.35686275; 0.29803923 0.28235295 … 0.32941177 0.25490198; … ; 0.30588236 0.26666668 … 0.22745098 0.21960784; 0.23921569 0.23921569 … 0.21960784 0.22745098;;; 0.039215688 0.02745098 … 0.050980393 0.07058824; 0.03137255 0.023529412 … 0.047058824 0.023529412; … ; 0.043137256 0.02745098 … 0.45882353 0.45490196; 0.003921569 0.0 … 0.45490196 0.45882353;;;;], Float16[-0.8 -0.8 … -1.2 -0.8; -0.8 -0.8 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.019607844 0.019607844 … 0.42352942 0.4117647; 0.015686275 0.015686275 … 0.42745098 0.44313726; … ; 0.019607844 0.019607844 … 0.4392157 0.4392157; 0.019607844 0.019607844 … 0.43137255 0.43529412;;; 0.21960784 0.21568628 … 0.2509804 0.23921569; 0.21960784 0.20784314 … 0.2509804 0.25882354; … ; 0.22745098 0.22745098 … 0.25882354 0.25882354; 0.23137255 0.22745098 … 0.25490198 0.25882354;;; 0.45490196 0.44705883 … 0.007843138 0.003921569; 0.45490196 0.44313726 … 0.007843138 0.011764706; … ; 0.45882353 0.4627451 … 0.007843138 0.007843138; 0.4627451 0.45882353 … 0.003921569 0.007843138;;;;], Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -9.2 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.22352941 0.0 … 0.16078432 0.16470589; 0.20392157 0.05882353 … 0.15686275 0.15686275; … ; 0.023529412 0.003921569 … 0.003921569 0.023529412; 0.03529412 0.011764706 … 0.0 0.0;;; 0.40784314 0.2509804 … 0.3254902 0.3372549; 0.35686275 0.3254902 … 0.32156864 0.32941177; … ; 0.19607843 0.21568628 … 0.20784314 0.23529412; 0.21568628 0.19607843 … 0.23137255 0.24705882;;; 0.08235294 0.0 … 0.043137256 0.050980393; 0.05490196 0.03529412 … 0.043137256 0.043137256; … ; 0.0 0.0 … 0.0 0.007843138; 0.019607844 0.0 … 0.0 0.0;;;;], Float16[-5.3 -5.3 … -1.2 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.43137255 0.45490196 … 0.015686275 0.019607844; 0.4117647 0.36862746 … 0.015686275 0.015686275; … ; 0.14901961 0.15294118 … 0.15294118 0.16470589; 0.15686275 0.13725491 … 0.16470589 0.16470589;;; 0.2509804 0.25882354 … 0.21960784 0.22745098; 0.27058825 0.2784314 … 0.21960784 0.21960784; … ; 0.3137255 0.32156864 … 0.32156864 0.33333334; 0.3254902 0.3019608 … 0.33333334 0.33333334;;; 0.003921569 0.007843138 … 0.45490196 0.4627451; 0.015686275 0.023529412 … 0.45490196 0.45490196; … ; 0.039215688 0.043137256 … 0.039215688 0.047058824; 0.043137256 0.03137255 … 0.047058824 0.047058824;;;;], Float16[-1.2 -7.7 … -7.7 -7.7; -7.7 -1.2 … -7.7 -7.7; … ; -0.8 -0.8 … -0.8 -5.3; -0.8 -1.2 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.15294118 0.31764707 … 0.43529412 0.40392157; 0.3764706 0.30588236 … 0.43137255 0.4117647; … ; 0.2 0.12156863 … 0.36078432 0.4117647; 0.09019608 0.17254902 … 0.21960784 0.44705883;;; 0.15294118 0.31764707 … 0.25882354 0.23921569; 0.3764706 0.30588236 … 0.25490198 0.24313726; … ; 0.2 0.12156863 … 0.22352941 0.28627452; 0.09019608 0.17254902 … 0.16470589 0.25882354;;; 0.15294118 0.31764707 … 0.007843138 0.0; 0.3764706 0.30588236 … 0.003921569 0.0; … ; 0.2 0.12156863 … 0.02745098 0.10980392; 0.09019608 0.17254902 … 0.08627451 0.0;;;;], Float16[-9.2 -9.2 … -0.8 -1.2; -9.2 -9.2 … -0.8 -1.2; … ; -9.2 -9.2 … -9.2 -1.2; -9.2 -9.2 … -9.2 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.16470589 0.16470589; 0.015686275 0.015686275 … 0.16862746 0.16470589; … ; 0.16862746 0.16470589 … 0.43137255 0.42745098; 0.15686275 0.16078432 … 0.43137255 0.42745098;;; 0.22745098 0.22352941 … 0.33333334 0.32941177; 0.21568628 0.20784314 … 0.3372549 0.32941177; … ; 0.3372549 0.3372549 … 0.25490198 0.25490198; 0.32941177 0.31764707 … 0.25490198 0.25490198;;; 0.45882353 0.45490196 … 0.047058824 0.047058824; 0.4509804 0.44313726 … 0.047058824 0.047058824; … ; 0.047058824 0.047058824 … 0.003921569 0.003921569; 0.043137256 0.039215688 … 0.003921569 0.003921569;;;;], Float16[-7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8; … ; -1.2 -1.2 … -9.2 -9.2; -0.8 -1.2 … -9.2 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40784314 0.41568628 … 0.38039216 0.28627452; 0.40784314 0.4117647 … 0.3764706 0.30980393; … ; 0.1254902 0.11372549 … 0.15294118 0.16078432; 0.12941177 0.12941177 … 0.15686275 0.15294118;;; 0.23921569 0.24313726 … 0.20392157 0.2784314; 0.23921569 0.24313726 … 0.2784314 0.25882354; … ; 0.29803923 0.2784314 … 0.31764707 0.32941177; 0.29411766 0.29803923 … 0.3254902 0.33333334;;; 0.003921569 0.007843138 … 0.0 0.03529412; 0.003921569 0.003921569 … 0.023529412 0.023529412; … ; 0.03137255 0.019607844 … 0.039215688 0.047058824; 0.02745098 0.03137255 … 0.043137256 0.047058824;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2; … ; -0.8 -1.2 … -5.3 -5.3; -0.8 -0.8 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.015686275 0.015686275 … 0.4392157 0.42745098; 0.015686275 0.015686275 … 0.43137255 0.43529412; … ; 0.3372549 0.3647059 … 0.16862746 0.19607843; 0.2627451 0.36862746 … 0.22745098 0.34117648;;; 0.20784314 0.2 … 0.25882354 0.2509804; 0.20392157 0.20784314 … 0.25490198 0.25490198; … ; 0.23921569 0.27450982 … 0.34901962 0.3764706; 0.31764707 0.24313726 … 0.36078432 0.24705882;;; 0.44313726 0.43529412 … 0.011764706 0.007843138; 0.4392157 0.44313726 … 0.007843138 0.007843138; … ; 0.015686275 0.019607844 … 0.05490196 0.078431375; 0.047058824 0.011764706 … 0.07058824 0.019607844;;;;], Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -1.2; -1.2 -1.2 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.44313726 0.4392157; 0.015686275 0.015686275 … 0.43529412 0.44313726; … ; 0.39607844 0.39215687 … 0.4392157 0.4392157; 0.39215687 0.39215687 … 0.43137255 0.43529412;;; 0.23529412 0.23137255 … 0.25882354 0.25882354; 0.21960784 0.21568628 … 0.25882354 0.25882354; … ; 0.23529412 0.23137255 … 0.25882354 0.25882354; 0.23137255 0.22745098 … 0.25490198 0.25882354;;; 0.46666667 0.4627451 … 0.007843138 0.007843138; 0.45490196 0.45490196 … 0.007843138 0.007843138; … ; 0.0 0.0 … 0.007843138 0.007843138; 0.0 0.0 … 0.003921569 0.007843138;;;;], Float16[-7.7 -7.7 … -7.7 -1.2; -7.7 -7.7 … -7.7 -1.2; … ; -1.2 -1.2 … -1.2 -9.2; -1.2 -1.2 … -1.2 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.0 0.023529412 … 0.4 0.39215687; 0.019607844 0.003921569 … 0.39215687 0.3764706; … ; 0.13725491 0.13725491 … 0.15686275 0.15686275; 0.1254902 0.105882354 … 0.16862746 0.16862746;;; 0.1882353 0.32156864 … 0.23529412 0.22745098; 0.27058825 0.24705882 … 0.23137255 0.21960784; … ; 0.3137255 0.30980393 … 0.34901962 0.33333334; 0.29803923 0.28235295 … 0.34509805 0.3372549;;; 0.0 0.007843138 … 0.0 0.0; 0.015686275 0.003921569 … 0.0 0.0; … ; 0.039215688 0.03529412 … 0.05882353 0.047058824; 0.03137255 0.023529412 … 0.05882353 0.050980393;;;;], Float16[-5.3 -5.3 … -1.2 -1.2; -5.3 -5.3 … -1.2 -1.2; … ; -5.3 -5.3 … -1.2 -1.2; -0.8 -0.8 … -0.8 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing) … DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40392157 0.43137255 … 0.44313726 0.43529412; 0.43529412 0.43529412 … 0.42745098 0.43529412; … ; 0.023529412 0.003921569 … 0.003921569 0.023529412; 0.03529412 0.011764706 … 0.0 0.0;;; 0.23529412 0.25490198 … 0.25882354 0.25490198; 0.25882354 0.25490198 … 0.2509804 0.25490198; … ; 0.19607843 0.21568628 … 0.20784314 0.23529412; 0.21568628 0.19607843 … 0.23137255 0.24705882;;; 0.0 0.007843138 … 0.007843138 0.007843138; 0.007843138 0.007843138 … 0.007843138 0.007843138; … ; 0.0 0.0 … 0.0 0.007843138; 0.019607844 0.0 … 0.0 0.0;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40784314 0.41568628 … 0.38039216 0.28627452; 0.40784314 0.4117647 … 0.3764706 0.30980393; … ; 0.16470589 0.16470589 … 0.44313726 0.42745098; 0.15294118 0.16078432 … 0.42352942 0.41960785;;; 0.23921569 0.24313726 … 0.20392157 0.2784314; 0.23921569 0.24313726 … 0.2784314 0.25882354; … ; 0.32941177 0.3254902 … 0.2627451 0.2509804; 0.32156864 0.32941177 … 0.2509804 0.24705882;;; 0.003921569 0.007843138 … 0.0 0.03529412; 0.003921569 0.003921569 … 0.023529412 0.023529412; … ; 0.047058824 0.043137256 … 0.007843138 0.003921569; 0.039215688 0.047058824 … 0.003921569 0.007843138;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -7.7 … -1.2 -1.2; … ; -0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.27450982 0.019607844 … 0.003921569 0.14117648; 0.13725491 0.0627451 … 0.019607844 0.08235294; … ; 0.42745098 0.42745098 … 0.39607844 0.4; 0.42745098 0.40392157 … 0.39215687 0.39607844;;; 0.47058824 0.23137255 … 0.25490198 0.3764706; 0.40392157 0.31764707 … 0.23137255 0.35686275; … ; 0.2509804 0.2509804 … 0.23529412 0.23529412; 0.2509804 0.23529412 … 0.22745098 0.23137255;;; 0.14117648 0.015686275 … 0.003921569 0.07450981; 0.07058824 0.039215688 … 0.007843138 0.03529412; … ; 0.007843138 0.003921569 … 0.0 0.0; 0.003921569 0.003921569 … 0.0 0.0;;;;], Float16[-5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3; … ; -1.2 -0.8 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.14509805 0.13725491 … 0.12941177 0.13725491; 0.12156863 0.1254902 … 0.11764706 0.10980392; … ; 0.023529412 0.003921569 … 0.078431375 0.011764706; 0.03529412 0.011764706 … 0.03529412 0.05882353;;; 0.30980393 0.30980393 … 0.29411766 0.30588236; 0.29411766 0.29411766 … 0.29803923 0.28627452; … ; 0.19607843 0.21568628 … 0.3019608 0.21176471; 0.21568628 0.19607843 … 0.28235295 0.3019608;;; 0.03529412 0.03529412 … 0.02745098 0.03137255; 0.02745098 0.02745098 … 0.03137255 0.02745098; … ; 0.0 0.0 … 0.03529412 0.007843138; 0.019607844 0.0 … 0.015686275 0.03529412;;;;], Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -0.8 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.45882353 0.28235295 … 0.43137255 0.41568628; 0.35686275 0.17254902 … 0.42745098 0.43137255; … ; 0.38431373 0.39607844 … 0.17254902 0.16862746; 0.3882353 0.39215687 … 0.16470589 0.17254902;;; 0.2509804 0.27058825 … 0.25490198 0.24705882; 0.27058825 0.19607843 … 0.25490198 0.25490198; … ; 0.22352941 0.23137255 … 0.3529412 0.35686275; 0.22745098 0.23137255 … 0.3372549 0.34509805;;; 0.003921569 0.015686275 … 0.003921569 0.0; 0.007843138 0.011764706 … 0.003921569 0.003921569; … ; 0.0 0.0 … 0.05882353 0.0627451; 0.0 0.0 … 0.050980393 0.05490196;;;;], Float16[-7.7 -1.2 … -7.7 -1.2; -7.7 -1.2 … -1.2 -7.7; … ; -1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.12156863 0.23137255 … 0.43529412 0.40784314; 0.12156863 0.26666668 … 0.43529412 0.42352942; … ; 0.1254902 0.1254902 … 0.3882353 0.4509804; 0.11764706 0.105882354 … 0.43137255 0.4117647;;; 0.12156863 0.23137255 … 0.25490198 0.23921569; 0.12156863 0.26666668 … 0.25882354 0.2509804; … ; 0.3019608 0.3019608 … 0.2627451 0.2509804; 0.29411766 0.28235295 … 0.2627451 0.26666668;;; 0.12156863 0.23137255 … 0.007843138 0.0; 0.12156863 0.26666668 … 0.003921569 0.0; … ; 0.03137255 0.03137255 … 0.015686275 0.003921569; 0.03137255 0.023529412 … 0.007843138 0.011764706;;;;], Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.42745098 0.42352942; 0.015686275 0.015686275 … 0.43529412 0.44313726; … ; 0.015686275 0.015686275 … 0.11372549 0.1254902; 0.015686275 0.019607844 … 0.12941177 0.12941177;;; 0.23529412 0.23137255 … 0.25490198 0.2509804; 0.21960784 0.21568628 … 0.25882354 0.2627451; … ; 0.21176471 0.21176471 … 0.2784314 0.29803923; 0.21568628 0.22352941 … 0.29803923 0.29411766;;; 0.46666667 0.4627451 … 0.003921569 0.0; 0.45490196 0.45490196 … 0.007843138 0.011764706; … ; 0.44705883 0.44705883 … 0.019607844 0.03137255; 0.44705883 0.45490196 … 0.03137255 0.02745098;;;;], Float16[-7.7 -1.2 … -1.2 -1.2; -7.7 -1.2 … -0.8 -1.2; … ; -7.7 -1.2 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.16078432 0.16078432 … 0.13333334 0.1254902; 0.16078432 0.15686275 … 0.105882354 0.11372549; … ; 0.3647059 0.3882353 … 0.43529412 0.4392157; 0.34117648 0.41568628 … 0.44705883 0.43529412;;; 0.32156864 0.32941177 … 0.29411766 0.29411766; 0.33333334 0.3137255 … 0.28235295 0.28627452; … ; 0.27058825 0.27450982 … 0.25882354 0.25882354; 0.25490198 0.25490198 … 0.25490198 0.2627451;;; 0.039215688 0.043137256 … 0.02745098 0.02745098; 0.047058824 0.03529412 … 0.023529412 0.02745098; … ; 0.015686275 0.019607844 … 0.007843138 0.007843138; 0.007843138 0.007843138 … 0.003921569 0.007843138;;;;], Float16[-0.8 -5.3 … -0.8 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -0.8 -5.3 … -7.7 -7.7; -0.8 -0.8 … -7.7 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.14509805 0.15686275; 0.11372549 0.105882354 … 0.13725491 0.15294118; … ; 0.043137256 0.0 … 0.17254902 0.13725491; 0.07450981 0.011764706 … 0.14509805 0.16470589;;; 0.29411766 0.29411766 … 0.3137255 0.32941177; 0.28627452 0.28235295 … 0.30588236 0.31764707; … ; 0.24705882 0.20784314 … 0.33333334 0.3137255; 0.32156864 0.21960784 … 0.3137255 0.36078432;;; 0.02745098 0.02745098 … 0.039215688 0.043137256; 0.02745098 0.023529412 … 0.03529412 0.03529412; … ; 0.011764706 0.0 … 0.047058824 0.03529412; 0.02745098 0.003921569 … 0.03529412 0.0627451;;;;], Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -1.2 -0.8; … ; -5.3 -0.8 … -1.2 -1.2; -5.3 -0.8 … -1.2 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.4392157 0.44313726 … 0.43137255 0.41568628; 0.42745098 0.43529412 … 0.42745098 0.43137255; … ; 0.019607844 0.015686275 … 0.11372549 0.1254902; 0.015686275 0.015686275 … 0.12941177 0.12941177;;; 0.25490198 0.24705882 … 0.25490198 0.24705882; 0.2509804 0.25882354 … 0.25490198 0.25490198; … ; 0.21176471 0.20392157 … 0.2784314 0.29411766; 0.20392157 0.20392157 … 0.29803923 0.29803923;;; 0.003921569 0.0 … 0.003921569 0.0; 0.0 0.007843138 … 0.003921569 0.003921569; … ; 0.44313726 0.4392157 … 0.019607844 0.03137255; 0.4392157 0.4392157 … 0.03137255 0.02745098;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)], DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.16078432 0.16078432; 0.11372549 0.105882354 … 0.3882353 0.40392157; … ; 0.15686275 0.16470589 … 0.17254902 0.28627452; 0.14901961 0.16078432 … 0.26666668 0.3764706;;; 0.29411766 0.29411766 … 0.23921569 0.1882353; 0.28627452 0.28235295 … 0.24705882 0.2627451; … ; 0.32941177 0.3372549 … 0.17254902 0.21176471; 0.32156864 0.32156864 … 0.2784314 0.21176471;;; 0.02745098 0.02745098 … 0.09411765 0.050980393; 0.02745098 0.023529412 … 0.0 0.003921569; … ; 0.047058824 0.050980393 … 0.003921569 0.011764706; 0.043137256 0.039215688 … 0.02745098 0.0;;;;], Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.43137255 0.41568628 … 0.101960786 0.16862746; 0.45490196 0.3372549 … 0.21960784 0.3254902; … ; 0.15686275 0.15686275 … 0.3882353 0.4509804; 0.14117648 0.12941177 … 0.43137255 0.4117647;;; 0.2509804 0.25490198 … 0.32156864 0.3529412; 0.2627451 0.3019608 … 0.44705883 0.5176471; … ; 0.3254902 0.3254902 … 0.2627451 0.2509804; 0.30980393 0.29803923 … 0.2627451 0.26666668;;; 0.003921569 0.003921569 … 0.047058824 0.08235294; 0.007843138 0.03529412 … 0.09411765 0.14117648; … ; 0.043137256 0.043137256 … 0.015686275 0.003921569; 0.03529412 0.02745098 … 0.007843138 0.011764706;;;;], Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.13725491 0.23137255 … 0.12941177 0.1254902; 0.38431373 0.2627451 … 0.13333334 0.1254902; … ; 0.34901962 0.23529412 … 0.39215687 0.34901962; 0.17254902 0.19215687 … 0.4117647 0.34901962;;; 0.13725491 0.23137255 … 0.29411766 0.2901961; 0.38431373 0.25882354 … 0.29803923 0.2901961; … ; 0.34901962 0.23529412 … 0.24313726 0.2627451; 0.17254902 0.19215687 … 0.25882354 0.25490198;;; 0.12941177 0.23529412 … 0.02745098 0.02745098; 0.38431373 0.25882354 … 0.02745098 0.023529412; … ; 0.34901962 0.23529412 … 0.003921569 0.015686275; 0.17254902 0.19215687 … 0.007843138 0.011764706;;;;], Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1882353 0.3137255 … 0.42352942 0.40392157; 0.25490198 0.32941177 … 0.4117647 0.4; … ; 0.4117647 0.39607844 … 0.12156863 0.11764706; 0.43529412 0.38431373 … 0.10980392 0.12941177;;; 0.1882353 0.3137255 … 0.2509804 0.23529412; 0.25490198 0.32941177 … 0.24313726 0.23529412; … ; 0.24313726 0.23137255 … 0.28235295 0.28627452; 0.25490198 0.22352941 … 0.28627452 0.2901961;;; 0.1882353 0.3137255 … 0.003921569 0.0; 0.25490198 0.32941177 … 0.003921569 0.0; … ; 0.007843138 0.0 … 0.023529412 0.023529412; 0.011764706 0.0 … 0.023529412 0.023529412;;;;], Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.15686275 0.14901961 … 0.41568628 0.3647059; 0.14901961 0.14509805 … 0.3137255 0.34901962; … ; 0.16862746 0.16470589 … 0.1882353 0.23529412; 0.16470589 0.16078432 … 0.12941177 0.15686275;;; 0.32941177 0.3137255 … 0.27058825 0.24313726; 0.31764707 0.30980393 … 0.2627451 0.27058825; … ; 0.33333334 0.3372549 … 0.1882353 0.23529412; 0.32941177 0.32156864 … 0.12941177 0.15686275;;; 0.047058824 0.03529412 … 0.011764706 0.011764706; 0.039215688 0.03529412 … 0.019607844 0.023529412; … ; 0.047058824 0.047058824 … 0.1882353 0.23529412; 0.043137256 0.039215688 … 0.12941177 0.15686275;;;;], Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)])
And getting an individual sample will return a DataSample
with four fields: x
, instance
, θ
, and y
:
sample = test_dataset[1]
DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.16078432 0.16078432; 0.11372549 0.105882354 … 0.3882353 0.40392157; … ; 0.15686275 0.16470589 … 0.17254902 0.28627452; 0.14901961 0.16078432 … 0.26666668 0.3764706;;; 0.29411766 0.29411766 … 0.23921569 0.1882353; 0.28627452 0.28235295 … 0.24705882 0.2627451; … ; 0.32941177 0.3372549 … 0.17254902 0.21176471; 0.32156864 0.32156864 … 0.2784314 0.21176471;;; 0.02745098 0.02745098 … 0.09411765 0.050980393; 0.02745098 0.023529412 … 0.0 0.003921569; … ; 0.047058824 0.050980393 … 0.003921569 0.011764706; 0.043137256 0.039215688 … 0.02745098 0.0;;;;], Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)
x
correspond to the input features, i.e. the input image (3D array) in the Warcraft benchmark case:
x = sample.x
96×96×3×1 Array{Float32, 4}:
[:, :, 1, 1] =
0.12549 0.133333 0.117647 0.109804 … 0.176471 0.160784 0.160784
0.113725 0.105882 0.117647 0.145098 0.411765 0.388235 0.403922
0.101961 0.137255 0.14902 0.141176 0.439216 0.443137 0.439216
0.121569 0.137255 0.156863 0.137255 0.435294 0.435294 0.435294
0.113725 0.152941 0.156863 0.156863 0.435294 0.439216 0.431373
0.137255 0.145098 0.156863 0.164706 … 0.435294 0.435294 0.435294
0.156863 0.164706 0.160784 0.164706 0.427451 0.431373 0.427451
0.14902 0.160784 0.168627 0.156863 0.454902 0.4 0.364706
0.160784 0.164706 0.160784 0.152941 0.372549 0.427451 0.294118
0.156863 0.156863 0.156863 0.164706 0.286275 0.301961 0.231373
⋮ ⋱ ⋮
0.156863 0.133333 0.156863 0.160784 0.0196078 0.0196078 0.0
0.141176 0.14902 0.141176 0.156863 0.0117647 0.027451 0.0431373
0.117647 0.145098 0.141176 0.168627 0.0 0.105882 0.168627
0.109804 0.160784 0.141176 0.152941 … 0.0509804 0.0235294 0.00784314
0.117647 0.129412 0.152941 0.14902 0.0588235 0.0313726 0.0666667
0.113725 0.152941 0.156863 0.156863 0.0352941 0.113725 0.211765
0.137255 0.145098 0.156863 0.164706 0.0588235 0.196078 0.258824
0.156863 0.164706 0.160784 0.164706 0.133333 0.172549 0.286275
0.14902 0.160784 0.168627 0.156863 … 0.172549 0.266667 0.376471
[:, :, 2, 1] =
0.294118 0.294118 0.290196 0.286275 … 0.231373 0.239216 0.188235
0.286275 0.282353 0.286275 0.305882 0.247059 0.247059 0.262745
0.282353 0.309804 0.313726 0.305882 0.254902 0.254902 0.254902
0.294118 0.305882 0.32549 0.301961 0.258824 0.258824 0.258824
0.286275 0.321569 0.317647 0.32549 0.254902 0.258824 0.254902
0.313726 0.313726 0.329412 0.333333 … 0.254902 0.258824 0.258824
0.329412 0.337255 0.32549 0.337255 0.254902 0.25098 0.25098
0.321569 0.321569 0.337255 0.32549 0.254902 0.266667 0.231373
0.329412 0.337255 0.32549 0.317647 0.262745 0.231373 0.262745
0.32549 0.317647 0.321569 0.333333 0.223529 0.254902 0.337255
⋮ ⋱ ⋮
0.32549 0.298039 0.329412 0.32549 0.211765 0.219608 0.192157
0.317647 0.32549 0.313726 0.321569 0.203922 0.203922 0.184314
0.290196 0.309804 0.309804 0.333333 0.219608 0.2 0.176471
0.286275 0.32549 0.309804 0.321569 … 0.227451 0.172549 0.172549
0.294118 0.301961 0.32549 0.309804 0.180392 0.172549 0.172549
0.286275 0.321569 0.317647 0.32549 0.141176 0.168627 0.2
0.313726 0.313726 0.329412 0.333333 0.145098 0.168627 0.301961
0.329412 0.337255 0.32549 0.337255 0.188235 0.172549 0.211765
0.321569 0.321569 0.337255 0.32549 … 0.235294 0.278431 0.211765
[:, :, 3, 1] =
0.027451 0.027451 0.0235294 … 0.0745098 0.0941176 0.0509804
0.027451 0.0235294 0.0235294 0.0 0.0 0.00392157
0.0235294 0.0352941 0.0352941 0.00784314 0.00784314 0.00392157
0.0313726 0.0313726 0.0431373 0.00784314 0.00784314 0.00784314
0.027451 0.0392157 0.0392157 0.00392157 0.00784314 0.00392157
0.0392157 0.0392157 0.0431373 … 0.00392157 0.00392157 0.00784314
0.0470588 0.0509804 0.0431373 0.00392157 0.00392157 0.00392157
0.0431373 0.0392157 0.0509804 0.00784314 0.0117647 0.00392157
0.0470588 0.0509804 0.0431373 0.0196078 0.0 0.0235294
0.0392157 0.0392157 0.0392157 0.0117647 0.027451 0.0588235
⋮ ⋱ ⋮
0.0431373 0.0313726 0.0470588 0.447059 0.45098 0.439216
0.0392157 0.0431373 0.0392157 0.443137 0.411765 0.356863
0.027451 0.0313726 0.0392157 0.478431 0.290196 0.172549
0.027451 0.0431373 0.0352941 … 0.407843 0.364706 0.407843
0.027451 0.0313726 0.0431373 0.317647 0.360784 0.262745
0.027451 0.0392157 0.0392157 0.294118 0.235294 0.0196078
0.0392157 0.0392157 0.0431373 0.286275 0.00392157 0.0196078
0.0470588 0.0509804 0.0431373 0.239216 0.00392157 0.0117647
0.0431373 0.0392157 0.0509804 … 0.0313726 0.027451 0.0
θ_true
correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:
θ_true = sample.θ_true
12×12 Matrix{Float16}:
-0.8 -0.8 -0.8 -0.8 -0.8 -1.2 -7.7 -1.2 -1.2 -1.2 -1.2 -1.2
-0.8 -0.8 -0.8 -0.8 -0.8 -1.2 -7.7 -1.2 -0.8 -0.8 -0.8 -0.8
-0.8 -0.8 -0.8 -0.8 -0.8 -1.2 -7.7 -1.2 -1.2 -0.8 -5.3 -0.8
-0.8 -0.8 -1.2 -1.2 -1.2 -1.2 -7.7 -1.2 -0.8 -0.8 -5.3 -0.8
-0.8 -0.8 -1.2 -7.7 -7.7 -7.7 -7.7 -1.2 -0.8 -0.8 -0.8 -0.8
-0.8 -1.2 -1.2 -7.7 -7.7 -7.7 -7.7 -1.2 -0.8 -1.2 -0.8 -1.2
-0.8 -0.8 -1.2 -1.2 -1.2 -1.2 -1.2 -1.2 -1.2 -1.2 -1.2 -1.2
-0.8 -1.2 -1.2 -9.2 -1.2 -1.2 -7.7 -7.7 -7.7 -7.7 -7.7 -7.7
-0.8 -0.8 -1.2 -9.2 -1.2 -1.2 -1.2 -1.2 -7.7 -7.7 -7.7 -7.7
-0.8 -0.8 -1.2 -9.2 -9.2 -1.2 -0.8 -1.2 -7.7 -7.7 -7.7 -7.7
-0.8 -0.8 -1.2 -9.2 -9.2 -1.2 -0.8 -1.2 -7.7 -7.7 -7.7 -7.7
-0.8 -0.8 -1.2 -1.2 -1.2 -1.2 -0.8 -1.2 -1.2 -1.2 -7.7 -7.7
y_true
correspond to the optimal shortest path, encoded as a binary matrix:
y_true = sample.y_true
12×12 BitMatrix:
1 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 0 0 0 0 0 0
0 1 0 0 0 0 0 0 0 0 0 0
0 0 1 0 0 0 0 0 0 0 0 0
0 0 0 1 0 0 0 0 0 0 0 0
0 0 0 0 1 0 0 0 0 0 0 0
0 0 0 0 0 1 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0
0 0 0 0 0 0 0 1 0 0 1 0
0 0 0 0 0 0 0 0 1 1 0 1
instance
is not used in this benchmark, therefore set to nothing:
isnothing(sample.instance)
true
For some benchmarks, we provide the following plotting method plot_data
to visualize the data:
plot_data(b, sample)
We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.
Building a pipeline
DecisionFocusedLearningBenchmarks also provides methods to build an hybrid machine learning and combinatorial optimization pipeline for the benchmark. First, the generate_statistical_model
method generates a machine learning predictor to predict cell weights from the input image:
model = generate_statistical_model(b)
Chain(
Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false), # 9_408 parameters
BatchNorm(64, relu), # 128 parameters, plus 128
MaxPool((3, 3), pad=1, stride=2),
Parallel(
PartialFunction(
"",
Metalhead.addact,
(NNlib.relu,),
NamedTuple(),
),
identity,
Chain(
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64), # 128 parameters, plus 128
NNlib.relu,
Conv((3, 3), 64 => 64, pad=1, bias=false), # 36_864 parameters
BatchNorm(64), # 128 parameters, plus 128
),
),
AdaptiveMaxPool((12, 12)),
DecisionFocusedLearningBenchmarks.Utils.average_tensor,
DecisionFocusedLearningBenchmarks.Utils.neg_tensor,
DecisionFocusedLearningBenchmarks.Utils.squeeze_last_dims,
) # Total: 9 trainable arrays, 83_520 parameters,
# plus 6 non-trainable, 384 parameters, summarysize 328.945 KiB.
In the case of the Warcraft benchmark, the model is a convolutional neural network built using the Flux.jl package.
θ = model(x)
12×12 Matrix{Float32}:
-0.721526 -0.718137 -0.717626 … -0.731732 -0.7317 -0.736122
-0.718858 -0.714253 -0.713772 -0.726469 -0.725278 -0.731283
-0.718467 -0.715424 -0.721073 -0.726078 -0.720642 -0.726064
-0.718754 -0.719024 -0.728877 -0.723631 -0.721364 -0.727132
-0.719706 -0.722873 -0.729793 -0.722864 -0.725774 -0.727845
-0.722395 -0.728315 -0.729841 … -0.728819 -0.729245 -0.733978
-0.72076 -0.72767 -0.729481 -0.730372 -0.731709 -0.736237
-0.722009 -0.728514 -0.73406 -0.731785 -0.734136 -0.736757
-0.72086 -0.727749 -0.736777 -0.719128 -0.717915 -0.723152
-0.718745 -0.721057 -0.734956 -0.717407 -0.714833 -0.72134
-0.718924 -0.720848 -0.73593 … -0.724723 -0.720396 -0.722206
-0.72244 -0.724106 -0.737705 -0.741514 -0.738982 -0.729798
Note that the model is not trained yet, and its parameters are randomly initialized.
Finally, the generate_maximizer
method can be used to generate a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:
maximizer = generate_maximizer(b; dijkstra=true)
dijkstra_maximizer (generic function with 1 method)
In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.
y = maximizer(θ)
12×12 Matrix{Int64}:
1 0 0 0 0 0 0 0 0 0 0 0
0 1 0 0 0 0 0 0 0 0 0 0
0 0 1 0 0 0 0 0 0 0 0 0
0 0 0 1 0 0 0 0 0 0 0 0
0 0 0 0 1 0 0 0 0 0 0 0
0 0 0 0 0 1 0 0 0 0 0 0
0 0 0 0 0 0 1 0 0 0 0 0
0 0 0 0 0 0 0 1 0 0 0 0
0 0 0 0 0 0 0 0 1 0 0 0
0 0 0 0 0 0 0 0 0 1 0 0
0 0 0 0 0 0 0 0 0 0 1 0
0 0 0 0 0 0 0 0 0 0 0 1
As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
We can evaluate the current pipeline performance using the optimality gap metric:
starting_gap = compute_gap(b, test_dataset, model, maximizer)
Float16(0.775)
Using a learning algorithm
We can now train the model using the InferOpt.jl package:
using InferOpt
using Flux
using Plots
perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
loss = FenchelYoungLoss(perturbed_maximizer)
starting_gap = compute_gap(b, test_dataset, model, maximizer)
opt_state = Flux.setup(Adam(1e-3), model)
loss_history = Float64[]
for epoch in 1:50
val, grads = Flux.withgradient(model) do m
sum(loss(m(x), y_true) for (; x, y_true) in train_dataset) / length(train_dataset)
end
Flux.update!(opt_state, model, grads[1])
push!(loss_history, val)
end
plot(loss_history; xlabel="Epoch", ylabel="Loss", title="Training loss")
final_gap = compute_gap(b, test_dataset, model, maximizer)
Float16(0.02493)
θ = model(x)
y = maximizer(θ)
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
This page was generated using Literate.jl.