Contextual Stochastic Argmax
Select the best item from a set of n items with stochastic utilities: each scenario draws a different utility vector, but the utilities depend on observable context features. This toy benchmark is designed so that a linear model can exactly recover the mapping from context to expected item utilities, and hence the optimal decision in expectation.
```julia
using DecisionFocusedLearningBenchmarks
using Plots

b = ContextualStochasticArgmaxBenchmark()
```

```
ContextualStochasticArgmaxBenchmark{Matrix{Float32}}(10, 5, Float32[1.4298681 -1.194263 … -0.8662849 1.0613704; 1.1314367 -1.6595181 … -0.8292305 -1.1606448; … ; 0.10203307 0.35854763 … -1.5413508 0.20959873; 0.03215548 0.22359875 … -0.19313496 -0.45557913], 0.1f0)
```

By default, `generate_dataset` returns unlabeled samples (`y = nothing`) for this benchmark, so a `target_policy` must be provided to attach labels. Here we use the anticipative oracle: for each scenario, it returns the item with the highest realized utility, giving one labeled sample per scenario per instance.
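Conceptually, the anticipative oracle needs only the realized scenario: it takes ξ and marks the item with the largest entry. The sketch below illustrates that idea in plain Julia; `anticipative_label` is a hypothetical helper written for illustration, not the package's implementation (which is obtained through `generate_anticipative_solver`).

```julia
# Hypothetical sketch (not the package's code): label a realized scenario ξ
# with a one-hot vector marking the item of highest realized utility.
function anticipative_label(ξ::AbstractVector{T}) where {T<:Real}
    y = zeros(T, length(ξ))   # one entry per item
    y[argmax(ξ)] = one(T)     # mark the best item under this scenario
    return y
end

# Toy usage: three scenarios for n = 4 items, one label per scenario.
scenarios = [Float32[0.2, 1.5, -0.3, 0.9],
             Float32[2.0, 0.1, 0.4, -1.0],
             Float32[0.0, 0.3, 0.8, 0.7]]
labels = [anticipative_label(ξ) for ξ in scenarios]
```

Each scenario gets its own label, so different scenarios of the same instance can point to different items.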
```julia
anticipative = generate_anticipative_solver(b)
# Label every scenario of an instance with its anticipative-optimal item
policy =
    (ctx, scenarios) ->
        [DataSample(ctx; y=anticipative(ξ), extra=(; scenario=ξ)) for ξ in scenarios]
dataset = generate_dataset(b, 20; target_policy=policy, seed=0)
sample = first(dataset)
```

```
DataSample(x=Float32[0.409521, 0.951421, 0.972709, 0.189452, 0.582645, 0.987582, 0.751793, 0.615489, 0.0283986, 0.865268, -1.17424, -0.859058, 0.816649, -1.3611, 1.12722], y=Float32[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], c_base=Float32[0.409521, 0.951421, 0.972709, 0.189452, 0.582645, 0.987582, 0.751793, 0.615489, 0.0283986, 0.865268], x_raw=Float32[-1.17424, -0.859058, 0.816649, -1.3611, 1.12722], scenario=Float32[2.53292, 0.357118, 0.951917, 3.99695, -3.18122, -2.24389, 2.82202, -1.02555, 1.79007, 0.236807])
```

Observable input
At inference time the model observes `x = [c_base; x_raw]`. `plot_context` shows both components: base utilities `c_base` (left) and context features `x_raw` (right):
```julia
plot_context(b, sample)
```

A training sample
Stochastic benchmarks have no single ground-truth label: the optimal item depends on which utility scenario is realized. We label each sample with the anticipative oracle, which returns the best item given the realized scenario ξ.
Each labeled sample contains:
- `x`: feature vector `[c_base; x_raw]` (observable at train and test time)
- `y`: optimal item for the realized scenario ξ (one-hot; anticipative oracle label)
- `extra.scenario`: realized utility vector ξ (available only during training)
Top: feature vector `x`. Bottom: the realized scenario ξ, passed as the utility vector θ, with the anticipative-optimal item highlighted in red:
```julia
plot_sample(b, DataSample(sample; θ=sample.scenario))
```

Untrained policy
A decision-focused learning (DFL) policy chains two components: a statistical model predicting expected item utilities:
```julia
model = generate_statistical_model(b)  # linear map: features → predicted expected utilities
```

```
Dense(15 => 10; bias=false)  # 150 parameters
```

and a maximizer selecting the item with the highest predicted utility:
```julia
maximizer = generate_maximizer(b)  # one-hot argmax
```

```
one_hot_argmax (generic function with 1 method)
```

A randomly initialized policy selects items with no relation to their expected utilities. Top: feature vector `x`. Bottom: predicted utilities θ̂ with the selected item in red:
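```julia
θ_pred = model(sample.x)    # predicted utilities θ̂ (random initial weights)
y_pred = maximizer(θ_pred)  # one-hot vector of the selected item
plot_sample(b, DataSample(sample; θ=θ_pred, y=y_pred))
```

Problem Description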
Overview
In the Contextual Stochastic Argmax benchmark, $n$ items have random utilities that depend on observable context. Per instance:
- $c_\text{base} \sim U[0,1]^n$: base utilities (stored in the context)
- $x_\text{raw} \sim \mathcal{N}(0, I_d)$: observable context features
- Full features: $x = [c_\text{base}; x_\text{raw}] \in \mathbb{R}^{n+d}$
The realized utility (scenario) is drawn as:
\[\xi = c_\text{base} + W \, x_\text{raw} + \varepsilon, \quad \varepsilon \sim \mathcal{N}(0, \sigma^2 I)\]
where $W \in \mathbb{R}^{n \times d}$ is a perturbation matrix that is fixed for the benchmark and unknown to the learner.
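The generative process above can be sketched in a few lines of plain Julia. This is an illustrative re-implementation, not the benchmark's source: the helpers `sample_instance` and `sample_scenario` are hypothetical, and W is drawn from a standard normal only as an assumption, since this page does not describe how the benchmark builds it. Averaging many scenarios of one instance should approach $c_\text{base} + W x_\text{raw}$.

```julia
using Random, Statistics

# Assumed sizes and noise level, matching the defaults listed under Key Parameters.
n, d, σ = 10, 5, 0.1

rng = MersenneTwister(0)
W = randn(rng, n, d)            # assumption: W drawn i.i.d. N(0, 1) for illustration

# Hypothetical helper: draw one instance's observable quantities.
function sample_instance(rng, n, d)
    c_base = rand(rng, n)       # c_base ~ U[0, 1]^n
    x_raw = randn(rng, d)       # x_raw ~ N(0, I_d)
    x = vcat(c_base, x_raw)     # full feature vector x in R^(n+d)
    return c_base, x_raw, x
end

# Hypothetical helper: draw one realized utility vector ξ for an instance.
sample_scenario(rng, c_base, x_raw, W, σ) =
    c_base .+ W * x_raw .+ σ .* randn(rng, length(c_base))

c_base, x_raw, x = sample_instance(rng, n, d)
ξs = [sample_scenario(rng, c_base, x_raw, W, σ) for _ in 1:10_000]

# The empirical mean of the scenarios approximates E[ξ | x] = c_base + W * x_raw.
ξ_mean = mean(ξs)
expected = c_base .+ W * x_raw
```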
The task is to select the item with the highest realized utility:
\[y^* = \mathrm{argmax}(\xi)\]
A linear model $\theta = [I \mid W] \cdot x$ exactly recovers the expected utilities $\mathbb{E}[\xi \mid x]$, so its argmax is the optimal decision in expectation.
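The expectation claim follows from linearity of expectation and $\mathbb{E}[\varepsilon] = 0$, using only the definitions above (given $x$, both $c_\text{base}$ and $x_\text{raw}$ are known):

\[\begin{aligned}
\mathbb{E}[\xi \mid x] &= c_\text{base} + W \, x_\text{raw} + \mathbb{E}[\varepsilon] = c_\text{base} + W \, x_\text{raw} = [I \mid W] \begin{bmatrix} c_\text{base} \\ x_\text{raw} \end{bmatrix} = [I \mid W] \cdot x, \\
\mathrm{argmax}_i \; \mathbb{E}[\xi_i \mid x] &= \mathrm{argmax}(\theta) \quad \text{for } \theta = [I \mid W] \cdot x.
\end{aligned}\]

Because the utility of selecting item $i$ is $\xi_i$, this argmax maximizes expected utility. It need not coincide with the anticipative label $y^* = \mathrm{argmax}(\xi)$ for a particular scenario, since the noise $\varepsilon$ can reorder items whose expected utilities are close.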
Key Parameters
| Parameter | Description | Default |
|---|---|---|
| `n` | Number of items | 10 |
| `d` | Context feature dimension | 5 |
| `noise_std` | Noise standard deviation σ | 0.1 |
Baseline Policies
- SAA (sample average approximation): selects the item with the highest mean utility over the available scenarios.
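A minimal sketch of the SAA baseline, assuming the scenarios of an instance are available as a vector of utility vectors. `saa_decision` is a hypothetical name used for illustration, not a package function.

```julia
using Statistics

# Hypothetical SAA baseline: average the utility vectors of the available
# scenarios, then pick the item with the highest mean utility (one-hot).
function saa_decision(scenarios::AbstractVector{<:AbstractVector{T}}) where {T<:Real}
    mean_utility = mean(scenarios)      # elementwise mean across scenarios
    y = zeros(T, length(mean_utility))
    y[argmax(mean_utility)] = one(T)
    return y
end

scenarios = [[0.2, 1.5, -0.3], [2.0, 0.1, 0.4], [0.0, 0.3, 0.8]]
y_saa = saa_decision(scenarios)
```

In this toy input each scenario's own best item differs (items 2, 1 and 3), yet SAA commits to item 1, whose mean utility is highest; the decision is made without knowing which scenario will be realized.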
DFL Policy
\[\xrightarrow[\text{Features}]{x = [c_\text{base}; x_\text{raw}]} \fbox{Linear model} \xrightarrow{\theta \in \mathbb{R}^n} \fbox{argmax} \xrightarrow{y}\]
Model: `Dense(n+d => n; bias=false)`, which can in principle recover the exact mapping $[I \mid W]$ from training data.
Maximizer: `one_hot_argmax`.
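To make the pipeline concrete, here is a dependency-free sketch of the forward pass with the weights set to the ideal value $[I \mid W]$. Plain matrix algebra stands in for the package's `Dense` layer and `one_hot_argmax`; `one_hot_argmax_sketch` is a hypothetical name, and W is a random placeholder rather than the benchmark's matrix.

```julia
using LinearAlgebra, Random

n, d = 10, 5
rng = MersenneTwister(42)
W = randn(rng, n, d)                    # placeholder for the unknown perturbation matrix

# Ideal weights of a bias-free linear layer with n + d inputs and n outputs.
M = hcat(Matrix{Float64}(I, n, n), W)   # M = [I | W], size n × (n + d)

# Hypothetical stand-in for the one-hot argmax maximizer.
function one_hot_argmax_sketch(θ::AbstractVector{T}) where {T<:Real}
    y = zeros(T, length(θ))
    y[argmax(θ)] = one(T)
    return y
end

c_base, x_raw = rand(rng, n), randn(rng, d)
x = vcat(c_base, x_raw)                 # observable features [c_base; x_raw]

θ = M * x                               # predicted utilities, equal to E[ξ | x]
y = one_hot_argmax_sketch(θ)            # selected item, one-hot
```

Any positive rescaling of `M` selects the same items, so a trained model can reach optimal decisions without reproducing $[I \mid W]$ entry by entry.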
This page was generated using Literate.jl.