In the paper, we derive lower bounds of the linearized Laplace approximation to the marginal likelihood that enable SGD-based hyperparameter optimization. The corresponding estimators and experiments are available in this repository.
Stochastic Marginal Likelihood Gradients using Neural Tangent Kernels.
Alexander Immer, Tycho F.A. van der Ouderaa, Mark van der Wilk, Gunnar Rätsch, Bernhard Schölkopf.
In proceedings of ICML 2023.
| Existing parametric bounds | NTK-based stochastic bounds |
|---|---|
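
As a toy illustration of how a lower bound of this kind can arise, the sketch below uses Fischer's inequality: for a positive-definite matrix, the log-determinant is at most the sum of the log-determinants of its diagonal blocks. A Laplace marginal likelihood contains a negative log-determinant term, so replacing the full log-determinant with the blockwise sum gives a lower bound on that term. The random Jacobian `J`, the kernel `K`, and the helper `block_logdet_sum` are assumptions made for this example. This is not the estimator implemented in the repository.

```python
import numpy as np


def block_logdet_sum(K, batch_size):
    """Sum of log-determinants of the diagonal blocks of K (hypothetical helper)."""
    total = 0.0
    for start in range(0, K.shape[0], batch_size):
        stop = start + batch_size
        # slogdet returns (sign, log|det|); the blocks are positive definite here
        total += np.linalg.slogdet(K[start:stop, start:stop])[1]
    return total


rng = np.random.default_rng(0)
J = rng.normal(size=(8, 5))   # toy stand-in for a Jacobian: 8 data points, 5 parameters
K = J @ J.T + np.eye(8)       # kernel-like matrix plus identity, positive definite

full = np.linalg.slogdet(K)[1]
blocked = block_logdet_sum(K, batch_size=3)

# Fischer's inequality: full <= blocked, hence -0.5 * full >= -0.5 * blocked
print(-0.5 * full, ">=", -0.5 * blocked)
```

With `batch_size` equal to the number of data points, the blockwise sum coincides with the full log-determinant, so the slack in this toy setting shrinks as the blocks grow.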
We use python>=3.9 and rely on pytorch for the experiments.
The basic dependencies are listed in requirements.txt, but the torch entry might have to be adjusted depending on GPU or CUDA support.
The proposed marginal likelihood estimators are implemented in dependencies/laplace and dependencies/asdl. These are forks of the packages laplace-torch and asdl, modified to support the NTK and lower-bound linearized Laplace marginal likelihood approximations and, in asdl, differentiability.
To install these, change into dependencies/laplace and dependencies/asdl in turn and install each locally with `pip install .`
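
The two local installs can also be scripted. The following is a minimal sketch, assuming it is run from the repository root. The helper names `pip_install_command` and `install_forks` are made up for this example, and by default it only lists the commands it would run.

```python
import subprocess
import sys
from pathlib import Path

# Fork locations as named in this README
FORKS = ["dependencies/laplace", "dependencies/asdl"]


def pip_install_command():
    """`pip install .` invoked through the current Python interpreter."""
    return [sys.executable, "-m", "pip", "install", "."]


def install_forks(repo_root=".", dry_run=True):
    """Return (directory, command) pairs; run them only when dry_run is False."""
    planned = []
    for rel in FORKS:
        cwd = Path(repo_root) / rel
        planned.append((str(cwd), pip_install_command()))
        if not dry_run:
            subprocess.run(pip_install_command(), cwd=cwd, check=True)
    return planned


planned = install_forks(dry_run=True)
for cwd, cmd in planned:
    print(cwd, ":", " ".join(cmd))
```

Calling `install_forks(dry_run=False)` would perform the installs in the active environment.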
The experiments, with the exception of the illustrated bounds, rely on wandb for tracking and collecting results, which might have to be set up separately (see the bottom of the main runner classification_image.py).
The commands to reproduce individual experiments are:
- `scripts/bound_grid_commands.sh` contains commands to compute the slack of bounds for different subset (minibatch) sizes
- `scripts/generate_bound_commands.py > scripts/bound_commands.sh` generates all online visualizations of the bound displayed in the appendix as well as the timing commands displayed in the Pareto figure
- `scripts/generate_cifar_commands.py > scripts/cifar_commands.sh` generates the commands for the CIFAR-10 and -100 table without invariance learning
- `scripts/cifar_commands_lila.sh` contains the commands for CIFAR with invariance learning (lila)
- `scripts/generate_tiny_commands.py > scripts/tiny_commands.sh` generates the commands for the TinyImageNet experiments
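
A command file can also be worked through selectively from Python. The sketch below assumes that each file holds one shell command per line, which may not be true of every script here. The helpers `read_commands` and `run_commands` and the demo file are hypothetical, and the demo uses placeholder `echo` commands rather than real experiment commands.

```python
import subprocess
import tempfile
from pathlib import Path


def read_commands(path):
    """Read non-empty, non-comment lines from a command file."""
    commands = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            commands.append(line)
    return commands


def run_commands(commands, dry_run=True):
    """Return the commands in order; execute each via the shell unless dry_run."""
    handled = []
    for cmd in commands:
        if not dry_run:
            subprocess.run(cmd, shell=True, check=True)
        handled.append(cmd)
    return handled


# Demo on a throwaway file with placeholder commands
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "demo_commands.sh"
    demo.write_text("#!/bin/bash\n\necho first\necho second\n")
    commands = read_commands(demo)

print(run_commands(commands, dry_run=True))
```

Slicing the list, for example `run_commands(commands[:1], dry_run=False)`, runs only a subset of the commands.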
To produce plots, we download the results from wandb, so line 15 in generate_illustration_figures.py needs to be adjusted to the individual wandb account.
The commands in the main function can be used selectively to produce plots and, by default, produce all of them given that all results are present in wandb.
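
For orientation, downloading results through the public wandb API typically looks like the sketch below. It is not taken from generate_illustration_figures.py. The helper `collect_runs` and the project path `my-entity/my-project` are placeholders, and a stand-in object replaces `wandb.Api()` so that the example runs offline.

```python
from types import SimpleNamespace


def collect_runs(api, project_path):
    """Collect name and config of every run in a project (hypothetical helper)."""
    rows = []
    for run in api.runs(project_path):
        row = {"name": run.name}
        row.update(dict(run.config))
        rows.append(row)
    return rows


# Offline stand-in that mimics the parts of wandb.Api() used above
fake_api = SimpleNamespace(
    runs=lambda path: [SimpleNamespace(name="demo-run", config={"seed": 0})]
)
rows = collect_runs(fake_api, "my-entity/my-project")
print(rows)

# With wandb installed and configured, the real call would be:
#   import wandb
#   rows = collect_runs(wandb.Api(), "my-entity/my-project")
```

Passing the API object as an argument keeps the account-specific project path in one place.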