Path-finding on image maps

In this tutorial, we showcase DecisionFocusedLearningBenchmarks.jl capabilities on one of its main benchmarks: the Warcraft benchmark. This benchmark problem is a simple path-finding problem where the goal is to find the shortest path between the top left and bottom right corners of a given image map. The map is represented as a 2D image representing a 12x12 grid, each cell having an unknown travel cost depending on the terrain type.

First, let's load the package and create a benchmark object as follows:

using DecisionFocusedLearningBenchmarks
b = WarcraftBenchmark()
WarcraftBenchmark()

Dataset generation

These benchmark objects behave as generators that can generate various needed elements in order to build an algorithm to tackle the problem. First of all, all benchmarks are capable of generating datasets as needed, using the generate_dataset method. This method takes as input the benchmark object for which the dataset is to be generated, and a second argument specifying the number of samples to generate:

dataset = generate_dataset(b, 50);

We obtain a vector of DataSample objects, containing all needed data for the problem. Subdatasets can be created through regular slicing:

train_dataset, test_dataset = dataset[1:45], dataset[46:50]
(DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.13333334 0.13333334 … 0.16470589 0.23529412; 0.1254902 0.105882354 … 0.16862746 0.3137255; … ; 0.24705882 0.24705882 … 0.019607844 0.015686275; 0.38431373 0.42352942 … 0.015686275 0.019607844;;; 0.30980393 0.29411766 … 0.34117648 0.35686275; 0.29803923 0.28235295 … 0.32941177 0.25490198; … ; 0.30588236 0.26666668 … 0.22745098 0.21960784; 0.23921569 0.23921569 … 0.21960784 0.22745098;;; 0.039215688 0.02745098 … 0.050980393 0.07058824; 0.03137255 0.023529412 … 0.047058824 0.023529412; … ; 0.043137256 0.02745098 … 0.45882353 0.45490196; 0.003921569 0.0 … 0.45490196 0.45882353;;;;], Float16[-0.8 -0.8 … -1.2 -0.8; -0.8 -0.8 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.019607844 0.019607844 … 0.42352942 0.4117647; 0.015686275 0.015686275 … 0.42745098 0.44313726; … ; 0.019607844 0.019607844 … 0.4392157 0.4392157; 0.019607844 0.019607844 … 0.43137255 0.43529412;;; 0.21960784 0.21568628 … 0.2509804 0.23921569; 0.21960784 0.20784314 … 0.2509804 0.25882354; … ; 0.22745098 0.22745098 … 0.25882354 0.25882354; 0.23137255 0.22745098 … 0.25490198 0.25882354;;; 0.45490196 0.44705883 … 0.007843138 0.003921569; 0.45490196 0.44313726 … 0.007843138 0.011764706; … ; 0.45882353 0.4627451 … 0.007843138 0.007843138; 0.4627451 0.45882353 … 0.003921569 0.007843138;;;;], Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -9.2 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.22352941 0.0 … 0.16078432 0.16470589; 0.20392157 0.05882353 … 0.15686275 0.15686275; … ; 0.023529412 0.003921569 … 0.003921569 0.023529412; 0.03529412 0.011764706 … 0.0 0.0;;; 0.40784314 0.2509804 … 0.3254902 0.3372549; 0.35686275 0.3254902 … 0.32156864 0.32941177; … ; 0.19607843 0.21568628 … 0.20784314 0.23529412; 0.21568628 0.19607843 … 0.23137255 0.24705882;;; 0.08235294 0.0 … 0.043137256 0.050980393; 0.05490196 0.03529412 … 0.043137256 0.043137256; … ; 0.0 0.0 … 0.0 0.007843138; 0.019607844 0.0 … 0.0 0.0;;;;], Float16[-5.3 -5.3 … -1.2 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.43137255 0.45490196 … 0.015686275 0.019607844; 0.4117647 0.36862746 … 0.015686275 0.015686275; … ; 0.14901961 0.15294118 … 0.15294118 0.16470589; 0.15686275 0.13725491 … 0.16470589 0.16470589;;; 0.2509804 0.25882354 … 0.21960784 0.22745098; 0.27058825 0.2784314 … 0.21960784 0.21960784; … ; 0.3137255 0.32156864 … 0.32156864 0.33333334; 0.3254902 0.3019608 … 0.33333334 0.33333334;;; 0.003921569 0.007843138 … 0.45490196 0.4627451; 0.015686275 0.023529412 … 0.45490196 0.45490196; … ; 0.039215688 0.043137256 … 0.039215688 0.047058824; 0.043137256 0.03137255 … 0.047058824 0.047058824;;;;], Float16[-1.2 -7.7 … -7.7 -7.7; -7.7 -1.2 … -7.7 -7.7; … ; -0.8 -0.8 … -0.8 -5.3; -0.8 -1.2 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.15294118 0.31764707 … 0.43529412 0.40392157; 0.3764706 0.30588236 … 0.43137255 0.4117647; … ; 0.2 0.12156863 … 0.36078432 0.4117647; 0.09019608 0.17254902 … 0.21960784 0.44705883;;; 0.15294118 0.31764707 … 0.25882354 0.23921569; 0.3764706 0.30588236 … 0.25490198 0.24313726; … ; 0.2 0.12156863 … 0.22352941 0.28627452; 0.09019608 0.17254902 … 0.16470589 0.25882354;;; 0.15294118 0.31764707 … 0.007843138 0.0; 0.3764706 0.30588236 … 0.003921569 0.0; … ; 0.2 0.12156863 … 0.02745098 0.10980392; 0.09019608 0.17254902 … 0.08627451 0.0;;;;], Float16[-9.2 -9.2 … -0.8 -1.2; -9.2 -9.2 … -0.8 -1.2; … ; -9.2 -9.2 … -9.2 -1.2; -9.2 -9.2 … -9.2 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 1; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.16470589 0.16470589; 0.015686275 0.015686275 … 0.16862746 0.16470589; … ; 0.16862746 0.16470589 … 0.43137255 0.42745098; 0.15686275 0.16078432 … 0.43137255 0.42745098;;; 0.22745098 0.22352941 … 0.33333334 0.32941177; 0.21568628 0.20784314 … 0.3372549 0.32941177; … ; 0.3372549 0.3372549 … 0.25490198 0.25490198; 0.32941177 0.31764707 … 0.25490198 0.25490198;;; 0.45882353 0.45490196 … 0.047058824 0.047058824; 0.4509804 0.44313726 … 0.047058824 0.047058824; … ; 0.047058824 0.047058824 … 0.003921569 0.003921569; 0.043137256 0.039215688 … 0.003921569 0.003921569;;;;], Float16[-7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8; … ; -1.2 -1.2 … -9.2 -9.2; -0.8 -1.2 … -9.2 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40784314 0.41568628 … 0.38039216 0.28627452; 0.40784314 0.4117647 … 0.3764706 0.30980393; … ; 0.1254902 0.11372549 … 0.15294118 0.16078432; 0.12941177 0.12941177 … 0.15686275 0.15294118;;; 0.23921569 0.24313726 … 0.20392157 0.2784314; 0.23921569 0.24313726 … 0.2784314 0.25882354; … ; 0.29803923 0.2784314 … 0.31764707 0.32941177; 0.29411766 0.29803923 … 0.3254902 0.33333334;;; 0.003921569 0.007843138 … 0.0 0.03529412; 0.003921569 0.003921569 … 0.023529412 0.023529412; … ; 0.03137255 0.019607844 … 0.039215688 0.047058824; 0.02745098 0.03137255 … 0.043137256 0.047058824;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2; … ; -0.8 -1.2 … -5.3 -5.3; -0.8 -0.8 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.015686275 0.015686275 … 0.4392157 0.42745098; 0.015686275 0.015686275 … 0.43137255 0.43529412; … ; 0.3372549 0.3647059 … 0.16862746 0.19607843; 0.2627451 0.36862746 … 0.22745098 0.34117648;;; 0.20784314 0.2 … 0.25882354 0.2509804; 0.20392157 0.20784314 … 0.25490198 0.25490198; … ; 0.23921569 0.27450982 … 0.34901962 0.3764706; 0.31764707 0.24313726 … 0.36078432 0.24705882;;; 0.44313726 0.43529412 … 0.011764706 0.007843138; 0.4392157 0.44313726 … 0.007843138 0.007843138; … ; 0.015686275 0.019607844 … 0.05490196 0.078431375; 0.047058824 0.011764706 … 0.07058824 0.019607844;;;;], Float16[-7.7 -7.7 … -1.2 -1.2; -7.7 -7.7 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -1.2; -1.2 -1.2 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.44313726 0.4392157; 0.015686275 0.015686275 … 0.43529412 0.44313726; … ; 0.39607844 0.39215687 … 0.4392157 0.4392157; 0.39215687 0.39215687 … 0.43137255 0.43529412;;; 0.23529412 0.23137255 … 0.25882354 0.25882354; 0.21960784 0.21568628 … 0.25882354 0.25882354; … ; 0.23529412 0.23137255 … 0.25882354 0.25882354; 0.23137255 0.22745098 … 0.25490198 0.25882354;;; 0.46666667 0.4627451 … 0.007843138 0.007843138; 0.45490196 0.45490196 … 0.007843138 0.007843138; … ; 0.0 0.0 … 0.007843138 0.007843138; 0.0 0.0 … 0.003921569 0.007843138;;;;], Float16[-7.7 -7.7 … -7.7 -1.2; -7.7 -7.7 … -7.7 -1.2; … ; -1.2 -1.2 … -1.2 -9.2; -1.2 -1.2 … -1.2 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.0 0.023529412 … 0.4 0.39215687; 0.019607844 0.003921569 … 0.39215687 0.3764706; … ; 0.13725491 0.13725491 … 0.15686275 0.15686275; 0.1254902 0.105882354 … 0.16862746 0.16862746;;; 0.1882353 0.32156864 … 0.23529412 0.22745098; 0.27058825 0.24705882 … 0.23137255 0.21960784; … ; 0.3137255 0.30980393 … 0.34901962 0.33333334; 0.29803923 0.28235295 … 0.34509805 0.3372549;;; 0.0 0.007843138 … 0.0 0.0; 0.015686275 0.003921569 … 0.0 0.0; … ; 0.039215688 0.03529412 … 0.05882353 0.047058824; 0.03137255 0.023529412 … 0.05882353 0.050980393;;;;], Float16[-5.3 -5.3 … -1.2 -1.2; -5.3 -5.3 … -1.2 -1.2; … ; -5.3 -5.3 … -1.2 -1.2; -0.8 -0.8 … -0.8 -1.2], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)  …  DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40392157 0.43137255 … 0.44313726 0.43529412; 0.43529412 0.43529412 … 0.42745098 0.43529412; … ; 0.023529412 0.003921569 … 0.003921569 0.023529412; 0.03529412 0.011764706 … 0.0 0.0;;; 0.23529412 0.25490198 … 0.25882354 0.25490198; 0.25882354 0.25490198 … 0.2509804 0.25490198; … ; 0.19607843 0.21568628 … 0.20784314 0.23529412; 0.21568628 0.19607843 … 0.23137255 0.24705882;;; 0.0 0.007843138 … 0.007843138 0.007843138; 0.007843138 0.007843138 … 0.007843138 0.007843138; … ; 0.0 0.0 … 0.0 0.007843138; 0.019607844 0.0 … 0.0 0.0;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.40784314 0.41568628 … 0.38039216 0.28627452; 0.40784314 0.4117647 … 0.3764706 0.30980393; … ; 0.16470589 0.16470589 … 0.44313726 0.42745098; 0.15294118 0.16078432 … 0.42352942 0.41960785;;; 0.23921569 0.24313726 … 0.20392157 0.2784314; 0.23921569 0.24313726 … 0.2784314 0.25882354; … ; 0.32941177 0.3254902 … 0.2627451 0.2509804; 0.32156864 0.32941177 … 0.2509804 0.24705882;;; 0.003921569 0.007843138 … 0.0 0.03529412; 0.003921569 0.003921569 … 0.023529412 0.023529412; … ; 0.047058824 0.043137256 … 0.007843138 0.003921569; 0.039215688 0.047058824 … 0.003921569 0.007843138;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -7.7 … -1.2 -1.2; … ; -0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.27450982 0.019607844 … 0.003921569 0.14117648; 0.13725491 0.0627451 … 0.019607844 0.08235294; … ; 0.42745098 0.42745098 … 0.39607844 0.4; 0.42745098 0.40392157 … 0.39215687 0.39607844;;; 0.47058824 0.23137255 … 0.25490198 0.3764706; 0.40392157 0.31764707 … 0.23137255 0.35686275; … ; 0.2509804 0.2509804 … 0.23529412 0.23529412; 0.2509804 0.23529412 … 0.22745098 0.23137255;;; 0.14117648 0.015686275 … 0.003921569 0.07450981; 0.07058824 0.039215688 … 0.007843138 0.03529412; … ; 0.007843138 0.003921569 … 0.0 0.0; 0.003921569 0.003921569 … 0.0 0.0;;;;], Float16[-5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3; … ; -1.2 -0.8 … -1.2 -1.2; -1.2 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.14509805 0.13725491 … 0.12941177 0.13725491; 0.12156863 0.1254902 … 0.11764706 0.10980392; … ; 0.023529412 0.003921569 … 0.078431375 0.011764706; 0.03529412 0.011764706 … 0.03529412 0.05882353;;; 0.30980393 0.30980393 … 0.29411766 0.30588236; 0.29411766 0.29411766 … 0.29803923 0.28627452; … ; 0.19607843 0.21568628 … 0.3019608 0.21176471; 0.21568628 0.19607843 … 0.28235295 0.3019608;;; 0.03529412 0.03529412 … 0.02745098 0.03137255; 0.02745098 0.02745098 … 0.03137255 0.02745098; … ; 0.0 0.0 … 0.03529412 0.007843138; 0.019607844 0.0 … 0.015686275 0.03529412;;;;], Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -0.8 -0.8; … ; -5.3 -5.3 … -5.3 -5.3; -5.3 -5.3 … -5.3 -5.3], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.45882353 0.28235295 … 0.43137255 0.41568628; 0.35686275 0.17254902 … 0.42745098 0.43137255; … ; 0.38431373 0.39607844 … 0.17254902 0.16862746; 0.3882353 0.39215687 … 0.16470589 0.17254902;;; 0.2509804 0.27058825 … 0.25490198 0.24705882; 0.27058825 0.19607843 … 0.25490198 0.25490198; … ; 0.22352941 0.23137255 … 0.3529412 0.35686275; 0.22745098 0.23137255 … 0.3372549 0.34509805;;; 0.003921569 0.015686275 … 0.003921569 0.0; 0.007843138 0.011764706 … 0.003921569 0.003921569; … ; 0.0 0.0 … 0.05882353 0.0627451; 0.0 0.0 … 0.050980393 0.05490196;;;;], Float16[-7.7 -1.2 … -7.7 -1.2; -7.7 -1.2 … -1.2 -7.7; … ; -1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.12156863 0.23137255 … 0.43529412 0.40784314; 0.12156863 0.26666668 … 0.43529412 0.42352942; … ; 0.1254902 0.1254902 … 0.3882353 0.4509804; 0.11764706 0.105882354 … 0.43137255 0.4117647;;; 0.12156863 0.23137255 … 0.25490198 0.23921569; 0.12156863 0.26666668 … 0.25882354 0.2509804; … ; 0.3019608 0.3019608 … 0.2627451 0.2509804; 0.29411766 0.28235295 … 0.2627451 0.26666668;;; 0.12156863 0.23137255 … 0.007843138 0.0; 0.12156863 0.26666668 … 0.003921569 0.0; … ; 0.03137255 0.03137255 … 0.015686275 0.003921569; 0.03137255 0.023529412 … 0.007843138 0.011764706;;;;], Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -1.2 -1.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 0 0; 0 0 … 1 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.023529412 0.019607844 … 0.42745098 0.42352942; 0.015686275 0.015686275 … 0.43529412 0.44313726; … ; 0.015686275 0.015686275 … 0.11372549 0.1254902; 0.015686275 0.019607844 … 0.12941177 0.12941177;;; 0.23529412 0.23137255 … 0.25490198 0.2509804; 0.21960784 0.21568628 … 0.25882354 0.2627451; … ; 0.21176471 0.21176471 … 0.2784314 0.29803923; 0.21568628 0.22352941 … 0.29803923 0.29411766;;; 0.46666667 0.4627451 … 0.003921569 0.0; 0.45490196 0.45490196 … 0.007843138 0.011764706; … ; 0.44705883 0.44705883 … 0.019607844 0.03137255; 0.44705883 0.45490196 … 0.03137255 0.02745098;;;;], Float16[-7.7 -1.2 … -1.2 -1.2; -7.7 -1.2 … -0.8 -1.2; … ; -7.7 -1.2 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], Bool[1 0 … 0 0; 0 1 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.16078432 0.16078432 … 0.13333334 0.1254902; 0.16078432 0.15686275 … 0.105882354 0.11372549; … ; 0.3647059 0.3882353 … 0.43529412 0.4392157; 0.34117648 0.41568628 … 0.44705883 0.43529412;;; 0.32156864 0.32941177 … 0.29411766 0.29411766; 0.33333334 0.3137255 … 0.28235295 0.28627452; … ; 0.27058825 0.27450982 … 0.25882354 0.25882354; 0.25490198 0.25490198 … 0.25490198 0.2627451;;; 0.039215688 0.043137256 … 0.02745098 0.02745098; 0.047058824 0.03529412 … 0.023529412 0.02745098; … ; 0.015686275 0.019607844 … 0.007843138 0.007843138; 0.007843138 0.007843138 … 0.003921569 0.007843138;;;;], Float16[-0.8 -5.3 … -0.8 -0.8; -5.3 -5.3 … -1.2 -0.8; … ; -0.8 -5.3 … -7.7 -7.7; -0.8 -0.8 … -7.7 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.14509805 0.15686275; 0.11372549 0.105882354 … 0.13725491 0.15294118; … ; 0.043137256 0.0 … 0.17254902 0.13725491; 0.07450981 0.011764706 … 0.14509805 0.16470589;;; 0.29411766 0.29411766 … 0.3137255 0.32941177; 0.28627452 0.28235295 … 0.30588236 0.31764707; … ; 0.24705882 0.20784314 … 0.33333334 0.3137255; 0.32156864 0.21960784 … 0.3137255 0.36078432;;; 0.02745098 0.02745098 … 0.039215688 0.043137256; 0.02745098 0.023529412 … 0.03529412 0.03529412; … ; 0.011764706 0.0 … 0.047058824 0.03529412; 0.02745098 0.003921569 … 0.03529412 0.0627451;;;;], Float16[-0.8 -0.8 … -0.8 -0.8; -0.8 -0.8 … -1.2 -0.8; … ; -5.3 -0.8 … -1.2 -1.2; -5.3 -0.8 … -1.2 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.4392157 0.44313726 … 0.43137255 0.41568628; 0.42745098 0.43529412 … 0.42745098 0.43137255; … ; 0.019607844 0.015686275 … 0.11372549 0.1254902; 0.015686275 0.015686275 … 0.12941177 0.12941177;;; 0.25490198 0.24705882 … 0.25490198 0.24705882; 0.2509804 0.25882354 … 0.25490198 0.25490198; … ; 0.21176471 0.20392157 … 0.2784314 0.29411766; 0.20392157 0.20392157 … 0.29803923 0.29803923;;; 0.003921569 0.0 … 0.003921569 0.0; 0.0 0.007843138 … 0.003921569 0.003921569; … ; 0.44313726 0.4392157 … 0.019607844 0.03137255; 0.4392157 0.4392157 … 0.03137255 0.02745098;;;;], Float16[-1.2 -1.2 … -1.2 -1.2; -1.2 -1.2 … -1.2 -1.2; … ; -7.7 -7.7 … -0.8 -0.8; -7.7 -7.7 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)], DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}[DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.16078432 0.16078432; 0.11372549 0.105882354 … 0.3882353 0.40392157; … ; 0.15686275 0.16470589 … 0.17254902 0.28627452; 0.14901961 0.16078432 … 0.26666668 0.3764706;;; 0.29411766 0.29411766 … 0.23921569 0.1882353; 0.28627452 0.28235295 … 0.24705882 0.2627451; … ; 0.32941177 0.3372549 … 0.17254902 0.21176471; 0.32156864 0.32156864 … 0.2784314 0.21176471;;; 0.02745098 0.02745098 … 0.09411765 0.050980393; 0.02745098 0.023529412 … 0.0 0.003921569; … ; 0.047058824 0.050980393 … 0.003921569 0.011764706; 0.043137256 0.039215688 … 0.02745098 0.0;;;;], Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.43137255 0.41568628 … 0.101960786 0.16862746; 0.45490196 0.3372549 … 0.21960784 0.3254902; … ; 0.15686275 0.15686275 … 0.3882353 0.4509804; 0.14117648 0.12941177 … 0.43137255 0.4117647;;; 0.2509804 0.25490198 … 0.32156864 0.3529412; 0.2627451 0.3019608 … 0.44705883 0.5176471; … ; 0.3254902 0.3254902 … 0.2627451 0.2509804; 0.30980393 0.29803923 … 0.2627451 0.26666668;;; 0.003921569 0.003921569 … 0.047058824 0.08235294; 0.007843138 0.03529412 … 0.09411765 0.14117648; … ; 0.043137256 0.043137256 … 0.015686275 0.003921569; 0.03529412 0.02745098 … 0.007843138 0.011764706;;;;], Float16[-0.8 -0.8 … -5.3 -5.3; -0.8 -5.3 … -5.3 -5.3; … ; -0.8 -1.2 … -0.8 -1.2; -0.8 -0.8 … -0.8 -1.2], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.13725491 0.23137255 … 0.12941177 0.1254902; 0.38431373 0.2627451 … 0.13333334 0.1254902; … ; 0.34901962 0.23529412 … 0.39215687 0.34901962; 0.17254902 0.19215687 … 0.4117647 0.34901962;;; 0.13725491 0.23137255 … 0.29411766 0.2901961; 0.38431373 0.25882354 … 0.29803923 0.2901961; … ; 0.34901962 0.23529412 … 0.24313726 0.2627451; 0.17254902 0.19215687 … 0.25882354 0.25490198;;; 0.12941177 0.23529412 … 0.02745098 0.02745098; 0.38431373 0.25882354 … 0.02745098 0.023529412; … ; 0.34901962 0.23529412 … 0.003921569 0.015686275; 0.17254902 0.19215687 … 0.007843138 0.011764706;;;;], Float16[-9.2 -9.2 … -0.8 -0.8; -9.2 -9.2 … -0.8 -0.8; … ; -9.2 -1.2 … -0.8 -0.8; -9.2 -1.2 … -0.8 -0.8], Bool[1 1 … 0 0; 0 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1882353 0.3137255 … 0.42352942 0.40392157; 0.25490198 0.32941177 … 0.4117647 0.4; … ; 0.4117647 0.39607844 … 0.12156863 0.11764706; 0.43529412 0.38431373 … 0.10980392 0.12941177;;; 0.1882353 0.3137255 … 0.2509804 0.23529412; 0.25490198 0.32941177 … 0.24313726 0.23529412; … ; 0.24313726 0.23137255 … 0.28235295 0.28627452; 0.25490198 0.22352941 … 0.28627452 0.2901961;;; 0.1882353 0.3137255 … 0.003921569 0.0; 0.25490198 0.32941177 … 0.003921569 0.0; … ; 0.007843138 0.0 … 0.023529412 0.023529412; 0.011764706 0.0 … 0.023529412 0.023529412;;;;], Float16[-9.2 -9.2 … -1.2 -1.2; -9.2 -9.2 … -1.2 -1.2; … ; -1.2 -1.2 … -0.8 -0.8; -1.2 -1.2 … -0.8 -0.8], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing), DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.15686275 0.14901961 … 0.41568628 0.3647059; 0.14901961 0.14509805 … 0.3137255 0.34901962; … ; 0.16862746 0.16470589 … 0.1882353 0.23529412; 0.16470589 0.16078432 … 0.12941177 0.15686275;;; 0.32941177 0.3137255 … 0.27058825 0.24313726; 0.31764707 0.30980393 … 0.2627451 0.27058825; … ; 0.33333334 0.3372549 … 0.1882353 0.23529412; 0.32941177 0.32156864 … 0.12941177 0.15686275;;; 0.047058824 0.03529412 … 0.011764706 0.011764706; 0.039215688 0.03529412 … 0.019607844 0.023529412; … ; 0.047058824 0.047058824 … 0.1882353 0.23529412; 0.043137256 0.039215688 … 0.12941177 0.15686275;;;;], Float16[-0.8 -1.2 … -9.2 -1.2; -0.8 -1.2 … -9.2 -1.2; … ; -0.8 -1.2 … -9.2 -9.2; -0.8 -0.8 … -9.2 -9.2], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)])

And getting an individual sample will return a DataSample with four fields: x, instance, θ, and y:

sample = test_dataset[1]
DataSample{Nothing, Array{Float32, 4}, BitMatrix, Matrix{Float16}}(Float32[0.1254902 0.13333334 … 0.16078432 0.16078432; 0.11372549 0.105882354 … 0.3882353 0.40392157; … ; 0.15686275 0.16470589 … 0.17254902 0.28627452; 0.14901961 0.16078432 … 0.26666668 0.3764706;;; 0.29411766 0.29411766 … 0.23921569 0.1882353; 0.28627452 0.28235295 … 0.24705882 0.2627451; … ; 0.32941177 0.3372549 … 0.17254902 0.21176471; 0.32156864 0.32156864 … 0.2784314 0.21176471;;; 0.02745098 0.02745098 … 0.09411765 0.050980393; 0.02745098 0.023529412 … 0.0 0.003921569; … ; 0.047058824 0.050980393 … 0.003921569 0.011764706; 0.043137256 0.039215688 … 0.02745098 0.0;;;;], Float16[-0.8 -0.8 … -1.2 -1.2; -0.8 -0.8 … -0.8 -0.8; … ; -0.8 -0.8 … -7.7 -7.7; -0.8 -0.8 … -7.7 -7.7], Bool[1 0 … 0 0; 1 0 … 0 0; … ; 0 0 … 1 0; 0 0 … 0 1], nothing)

x correspond to the input features, i.e. the input image (3D array) in the Warcraft benchmark case:

x = sample.x
96×96×3×1 Array{Float32, 4}:
[:, :, 1, 1] =
 0.12549   0.133333  0.117647  0.109804  …  0.176471   0.160784   0.160784
 0.113725  0.105882  0.117647  0.145098     0.411765   0.388235   0.403922
 0.101961  0.137255  0.14902   0.141176     0.439216   0.443137   0.439216
 0.121569  0.137255  0.156863  0.137255     0.435294   0.435294   0.435294
 0.113725  0.152941  0.156863  0.156863     0.435294   0.439216   0.431373
 0.137255  0.145098  0.156863  0.164706  …  0.435294   0.435294   0.435294
 0.156863  0.164706  0.160784  0.164706     0.427451   0.431373   0.427451
 0.14902   0.160784  0.168627  0.156863     0.454902   0.4        0.364706
 0.160784  0.164706  0.160784  0.152941     0.372549   0.427451   0.294118
 0.156863  0.156863  0.156863  0.164706     0.286275   0.301961   0.231373
 ⋮                                       ⋱                        ⋮
 0.156863  0.133333  0.156863  0.160784     0.0196078  0.0196078  0.0
 0.141176  0.14902   0.141176  0.156863     0.0117647  0.027451   0.0431373
 0.117647  0.145098  0.141176  0.168627     0.0        0.105882   0.168627
 0.109804  0.160784  0.141176  0.152941  …  0.0509804  0.0235294  0.00784314
 0.117647  0.129412  0.152941  0.14902      0.0588235  0.0313726  0.0666667
 0.113725  0.152941  0.156863  0.156863     0.0352941  0.113725   0.211765
 0.137255  0.145098  0.156863  0.164706     0.0588235  0.196078   0.258824
 0.156863  0.164706  0.160784  0.164706     0.133333   0.172549   0.286275
 0.14902   0.160784  0.168627  0.156863  …  0.172549   0.266667   0.376471

[:, :, 2, 1] =
 0.294118  0.294118  0.290196  0.286275  …  0.231373  0.239216  0.188235
 0.286275  0.282353  0.286275  0.305882     0.247059  0.247059  0.262745
 0.282353  0.309804  0.313726  0.305882     0.254902  0.254902  0.254902
 0.294118  0.305882  0.32549   0.301961     0.258824  0.258824  0.258824
 0.286275  0.321569  0.317647  0.32549      0.254902  0.258824  0.254902
 0.313726  0.313726  0.329412  0.333333  …  0.254902  0.258824  0.258824
 0.329412  0.337255  0.32549   0.337255     0.254902  0.25098   0.25098
 0.321569  0.321569  0.337255  0.32549      0.254902  0.266667  0.231373
 0.329412  0.337255  0.32549   0.317647     0.262745  0.231373  0.262745
 0.32549   0.317647  0.321569  0.333333     0.223529  0.254902  0.337255
 ⋮                                       ⋱                      ⋮
 0.32549   0.298039  0.329412  0.32549      0.211765  0.219608  0.192157
 0.317647  0.32549   0.313726  0.321569     0.203922  0.203922  0.184314
 0.290196  0.309804  0.309804  0.333333     0.219608  0.2       0.176471
 0.286275  0.32549   0.309804  0.321569  …  0.227451  0.172549  0.172549
 0.294118  0.301961  0.32549   0.309804     0.180392  0.172549  0.172549
 0.286275  0.321569  0.317647  0.32549      0.141176  0.168627  0.2
 0.313726  0.313726  0.329412  0.333333     0.145098  0.168627  0.301961
 0.329412  0.337255  0.32549   0.337255     0.188235  0.172549  0.211765
 0.321569  0.321569  0.337255  0.32549   …  0.235294  0.278431  0.211765

[:, :, 3, 1] =
 0.027451   0.027451   0.0235294  …  0.0745098   0.0941176   0.0509804
 0.027451   0.0235294  0.0235294     0.0         0.0         0.00392157
 0.0235294  0.0352941  0.0352941     0.00784314  0.00784314  0.00392157
 0.0313726  0.0313726  0.0431373     0.00784314  0.00784314  0.00784314
 0.027451   0.0392157  0.0392157     0.00392157  0.00784314  0.00392157
 0.0392157  0.0392157  0.0431373  …  0.00392157  0.00392157  0.00784314
 0.0470588  0.0509804  0.0431373     0.00392157  0.00392157  0.00392157
 0.0431373  0.0392157  0.0509804     0.00784314  0.0117647   0.00392157
 0.0470588  0.0509804  0.0431373     0.0196078   0.0         0.0235294
 0.0392157  0.0392157  0.0392157     0.0117647   0.027451    0.0588235
 ⋮                                ⋱                          ⋮
 0.0431373  0.0313726  0.0470588     0.447059    0.45098     0.439216
 0.0392157  0.0431373  0.0392157     0.443137    0.411765    0.356863
 0.027451   0.0313726  0.0392157     0.478431    0.290196    0.172549
 0.027451   0.0431373  0.0352941  …  0.407843    0.364706    0.407843
 0.027451   0.0313726  0.0431373     0.317647    0.360784    0.262745
 0.027451   0.0392157  0.0392157     0.294118    0.235294    0.0196078
 0.0392157  0.0392157  0.0431373     0.286275    0.00392157  0.0196078
 0.0470588  0.0509804  0.0431373     0.239216    0.00392157  0.0117647
 0.0431373  0.0392157  0.0509804  …  0.0313726   0.027451    0.0

θ_true correspond to the true unknown terrain weights. We use the opposite of the true weights in order to formulate the optimization problem as a maximization problem:

θ_true = sample.θ_true
12×12 Matrix{Float16}:
 -0.8  -0.8  -0.8  -0.8  -0.8  -1.2  -7.7  -1.2  -1.2  -1.2  -1.2  -1.2
 -0.8  -0.8  -0.8  -0.8  -0.8  -1.2  -7.7  -1.2  -0.8  -0.8  -0.8  -0.8
 -0.8  -0.8  -0.8  -0.8  -0.8  -1.2  -7.7  -1.2  -1.2  -0.8  -5.3  -0.8
 -0.8  -0.8  -1.2  -1.2  -1.2  -1.2  -7.7  -1.2  -0.8  -0.8  -5.3  -0.8
 -0.8  -0.8  -1.2  -7.7  -7.7  -7.7  -7.7  -1.2  -0.8  -0.8  -0.8  -0.8
 -0.8  -1.2  -1.2  -7.7  -7.7  -7.7  -7.7  -1.2  -0.8  -1.2  -0.8  -1.2
 -0.8  -0.8  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2  -1.2
 -0.8  -1.2  -1.2  -9.2  -1.2  -1.2  -7.7  -7.7  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -9.2  -1.2  -1.2  -1.2  -1.2  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -9.2  -9.2  -1.2  -0.8  -1.2  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -9.2  -9.2  -1.2  -0.8  -1.2  -7.7  -7.7  -7.7  -7.7
 -0.8  -0.8  -1.2  -1.2  -1.2  -1.2  -0.8  -1.2  -1.2  -1.2  -7.7  -7.7

y_true correspond to the optimal shortest path, encoded as a binary matrix:

y_true = sample.y_true
12×12 BitMatrix:
 1  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  0
 1  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0
 0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  1  0  0  0  0  0  0  0
 0  0  0  0  0  1  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  1  0
 0  0  0  0  0  0  0  0  1  1  0  1

instance is not used in this benchmark, therefore set to nothing:

isnothing(sample.instance)
true

For some benchmarks, we provide the following plotting method plot_data to visualize the data:

plot_data(b, sample)
Example block output

We can see here the terrain image, the true terrain weights, and the true shortest path avoiding the high cost cells.

Building a pipeline

DecisionFocusedLearningBenchmarks also provides methods to build an hybrid machine learning and combinatorial optimization pipeline for the benchmark. First, the generate_statistical_model method generates a machine learning predictor to predict cell weights from the input image:

model = generate_statistical_model(b)
Chain(
  Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
  BatchNorm(64, relu),                  # 128 parameters, plus 128
  MaxPool((3, 3), pad=1, stride=2),
  Parallel(
    PartialFunction(
      "",
      Metalhead.addact,
      (NNlib.relu,),
      NamedTuple(),
    ),
    identity,
    Chain(
      Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
      BatchNorm(64),                    # 128 parameters, plus 128
      NNlib.relu,
      Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
      BatchNorm(64),                    # 128 parameters, plus 128
    ),
  ),
  AdaptiveMaxPool((12, 12)),
  DecisionFocusedLearningBenchmarks.Utils.average_tensor,
  DecisionFocusedLearningBenchmarks.Utils.neg_tensor,
  DecisionFocusedLearningBenchmarks.Utils.squeeze_last_dims,
)         # Total: 9 trainable arrays, 83_520 parameters,
          # plus 6 non-trainable, 384 parameters, summarysize 328.945 KiB.

In the case of the Warcraft benchmark, the model is a convolutional neural network built using the Flux.jl package.

θ = model(x)
12×12 Matrix{Float32}:
 -0.721526  -0.718137  -0.717626  …  -0.731732  -0.7317    -0.736122
 -0.718858  -0.714253  -0.713772     -0.726469  -0.725278  -0.731283
 -0.718467  -0.715424  -0.721073     -0.726078  -0.720642  -0.726064
 -0.718754  -0.719024  -0.728877     -0.723631  -0.721364  -0.727132
 -0.719706  -0.722873  -0.729793     -0.722864  -0.725774  -0.727845
 -0.722395  -0.728315  -0.729841  …  -0.728819  -0.729245  -0.733978
 -0.72076   -0.72767   -0.729481     -0.730372  -0.731709  -0.736237
 -0.722009  -0.728514  -0.73406      -0.731785  -0.734136  -0.736757
 -0.72086   -0.727749  -0.736777     -0.719128  -0.717915  -0.723152
 -0.718745  -0.721057  -0.734956     -0.717407  -0.714833  -0.72134
 -0.718924  -0.720848  -0.73593   …  -0.724723  -0.720396  -0.722206
 -0.72244   -0.724106  -0.737705     -0.741514  -0.738982  -0.729798

Note that the model is not trained yet, and its parameters are randomly initialized.

Finally, the generate_maximizer method can be used to generate a combinatorial optimization algorithm that takes the predicted cell weights as input and returns the corresponding shortest path:

maximizer = generate_maximizer(b; dijkstra=true)
dijkstra_maximizer (generic function with 1 method)

In the case o fthe Warcraft benchmark, the method has an additional keyword argument to chose the algorithm to use: Dijkstra's algorithm or Bellman-Ford algorithm.

y = maximizer(θ)
12×12 Matrix{Int64}:
 1  0  0  0  0  0  0  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0
 0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  1  0  0  0  0  0  0  0
 0  0  0  0  0  1  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  0  0  0
 0  0  0  0  0  0  0  0  1  0  0  0
 0  0  0  0  0  0  0  0  0  1  0  0
 0  0  0  0  0  0  0  0  0  0  1  0
 0  0  0  0  0  0  0  0  0  0  0  1

As we can see, currently the pipeline predicts random noise as cell weights, and therefore the maximizer returns a straight line path.

plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
Example block output

We can evaluate the current pipeline performance using the optimality gap metric:

starting_gap = compute_gap(b, test_dataset, model, maximizer)
Float16(0.775)

Using a learning algorithm

We can now train the model using the InferOpt.jl package:

using InferOpt
using Flux
using Plots

perturbed_maximizer = PerturbedMultiplicative(maximizer; ε=0.2, nb_samples=100)
loss = FenchelYoungLoss(perturbed_maximizer)

starting_gap = compute_gap(b, test_dataset, model, maximizer)

opt_state = Flux.setup(Adam(1e-3), model)
loss_history = Float64[]
for epoch in 1:50
    val, grads = Flux.withgradient(model) do m
        sum(loss(m(x), y_true) for (; x, y_true) in train_dataset) / length(train_dataset)
    end
    Flux.update!(opt_state, model, grads[1])
    push!(loss_history, val)
end

plot(loss_history; xlabel="Epoch", ylabel="Loss", title="Training loss")
Example block output
final_gap = compute_gap(b, test_dataset, model, maximizer)
Float16(0.0)
θ = model(x)
y = maximizer(θ)
plot_data(b, DataSample(; x, θ_true=θ, y_true=y))
Example block output

This page was generated using Literate.jl.