
Commit 9147e5b

feat(dl_tuning_spaces): Add tuning spaces from Gorishniy 2021 "Revisiting" article (#61)
* init
* readme
* Formatting
* looks ok
* correct mlr3torch package version
* decrement version: should fail
* manually add rtdl to _pkgdown.yml
* uncomment helper
* as.data.table
1 parent 9371c47 commit 9147e5b

10 files changed, +315 -30 lines changed

DESCRIPTION

Lines changed: 4 additions & 1 deletion
@@ -35,7 +35,9 @@ Suggests:
     ranger (>= 0.12.1),
     rpart (>= 4.1-15),
     testthat (>= 3.0.0),
-    xgboost (>= 1.4.1.1)
+    xgboost (>= 1.4.1.1),
+    torch (>= 0.15.0),
+    mlr3torch (>= 0.3)
 Config/testthat/edition: 3
 Encoding: UTF-8
 Roxygen: list(markdown = TRUE)
@@ -48,4 +50,5 @@ Collate:
     'tuning_spaces_default.R'
     'tuning_spaces_rbv1.R'
     'tuning_spaces_rbv2.R'
+    'tuning_spaces_rtdl.R'
     'zzz.R'

NEWS.md

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 # mlr3tuningspaces (development version)
+* feat: Added tuning spaces for deep neural networks from the Gorishniy, Rubachev, Khrulkov, Babenko (2021) article.
 
 # mlr3tuningspaces 0.6.0
 

R/TuningSpace.R

Lines changed: 30 additions & 9 deletions
@@ -169,20 +169,37 @@ add_tuning_space = function(id, values, tags, learner, package = character(), la
   mlr_tuning_spaces$add(id, tuning_space)
 }
 
+# rd_format_misc = function(lower, upper) {
+#   if (is.na(lower) || is.na(upper) || is.null(lower) || is.null(upper)) return("-")
+
+#   str = sprintf("%s%s, %s%s",
+#     if (is.finite(lower)) "[" else "(",
+#     if (is.finite(lower)) c(lower, lower) else c("-\\infty", "-Inf"),
+#     if (is.finite(upper)) c(upper, upper) else c("\\infty", "Inf"),
+#     if (is.finite(upper)) "]" else ")")
+#   paste0("\\eqn{", str[1L], "}{", str[2L], "}")
+# }
+
 #' @export
 rd_info.TuningSpace = function(obj, ...) { # nolint
   require_namespaces(obj$package)
   ps = lrn(obj$learner)$param_set
   x = c("",
     imap_chr(obj$values, function(space, name) {
-      switch(ps$params[name, , on = "id"]$cls,
-        "ParamLgl" = sprintf("* %s \\[%s\\]", name, as_short_string(space$content$levels[[1]])),
-        "ParamFct" = sprintf("* %s \\[%s\\]", name, rd_format_string(space$content$levels[[1]])),
-        {lower = c(space$content$param$lower, space$content$lower) # one is NULL
-        upper = c(space$content$upper, space$content$param$upper)
-        logscale = if (is.null(space$content$logscale) || !space$content$logscale) character(1) else "Logscale"
-        sprintf("* %s %s %s", name, rd_format_range(lower, upper), logscale)}
-      )
+      if (is.atomic(space)) {
+        sprintf("* %s %s", name, space)
+      } else if ("TuneToken" %nin% class(space)) {
+        sprintf("* %s -", name)
+      } else {
+        switch(ps$params[name, , on = "id"]$cls,
+          "ParamLgl" = sprintf("* %s \\[%s\\]", name, as_short_string(space$content$levels[[1]])),
+          "ParamFct" = sprintf("* %s \\[%s\\]", name, rd_format_string(space$content$levels[[1]])),
+          {lower = c(space$content$param$lower, space$content$lower) # one is NULL
+          upper = c(space$content$upper, space$content$param$upper)
+          logscale = if (is.null(space$content$logscale) || !space$content$logscale) character(1) else "Logscale"
+          sprintf("* %s %s %s", name, rd_format_range(lower, upper), logscale)}
+        )
+      }
     })
   )
   paste(x, collapse = "\n")
@@ -199,6 +216,10 @@ as.data.table.TuningSpace = function(x, ...) {
     if (test_class(value, "ObjectTuneToken")) {
       # old paradox: value$content$param
       as.data.table(value$content$param %??% value$content)[, c("lower", "upper", "levels")]
+    } else if (is.atomic(value)) {
+      data.table(lower = NA, upper = NA, levels = NA, logscale = FALSE)
+    } else if (is.function(value)) {
+      data.table(lower = NA, upper = NA, levels = NA, logscale = FALSE)
     } else {
       as.data.table(value$content)
     }
@@ -207,4 +228,4 @@ as.data.table.TuningSpace = function(x, ...) {
   setcolorder(tab, intersect(c("id", "lower", "upper", "levels", "logscale"), names(tab)))
   if ("logscale" %in% names(tab)) tab[is.na(get("logscale")), "logscale" := FALSE]
   tab[]
-}
+}
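
The new branches exist because the RTDL spaces below mix three kinds of values, not just tune tokens. A minimal illustration of what each branch receives (names borrowed from the RTDL spaces added in this commit; assumes `paradox` is attached):

library(paradox)

values = list(
  patience = 17L,                                     # atomic constant: rd_info prints it verbatim
  opt.param_groups = function(parameters) parameters, # plain function (no TuneToken class): printed as "-"
  opt.lr = to_tune(1e-5, 1e-2, logscale = TRUE)       # TuneToken: printed with its range and "Logscale"
)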

R/bibentries.R

Lines changed: 9 additions & 1 deletion
@@ -31,5 +31,13 @@ bibentries = c(
     booktitle = "Proceedings of the 7th ICML Workshop on Automated Machine Learning (AutoML 2020)",
     date = "2020",
     url = "https://www.automl.org/wp-content/uploads/2020/07/AutoML_2020_paper_63.pdf"
+  ),
+
+  gorishniy2021revisiting = bibentry("article",
+    title = "Revisiting Deep Learning Models for Tabular Data",
+    author = "Yury Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko",
+    journal = "arXiv",
+    volume = "2106.11959",
+    year = "2021"
   )
-)
+)

R/tuning_spaces_rtdl.R

Lines changed: 167 additions & 0 deletions
@@ -0,0 +1,167 @@
+#' @title Deep Learning Tuning Spaces from Yandex's RTDL
+#'
+#' @name mlr_tuning_spaces_rtdl
+#'
+#' @description
+#' Tuning spaces for deep neural network architectures from the `r cite_bib("gorishniy2021revisiting")` article.
+#'
+#' These tuning spaces require optimizers that have a `weight_decay` parameter, such as AdamW or any of the other optimizers built into `mlr3torch`.
+#'
+#' When the article suggests multiple ranges for a given hyperparameter, these tuning spaces choose the widest range.
+#'
+#' The FT-Transformer tuning space disables weight decay for all bias parameters, matching the implementation provided by the authors in the rtdl-revisiting-models package.
+#' However, this differs from the experiments described in the article.
+#'
+#' For the FT-Transformer, if training is unstable, consider a combination of standardizing features, using an adaptive optimizer (e.g. Adam), reducing the learning rate,
+#' and using a learning rate scheduler.
+#'
+#' @source
+#' `r format_bib("gorishniy2021revisiting")`
+#'
+#' @aliases
+#' mlr_tuning_spaces_classif.mlp.rtdl
+#' mlr_tuning_spaces_classif.tab_resnet.rtdl
+#' mlr_tuning_spaces_classif.ft_transformer.rtdl
+#' mlr_tuning_spaces_regr.mlp.rtdl
+#' mlr_tuning_spaces_regr.tab_resnet.rtdl
+#' mlr_tuning_spaces_regr.ft_transformer.rtdl
+#'
+#' @section MLP tuning space:
+#' `r rd_info(lts("classif.mlp.rtdl"))`
+#'
+#' @section Tabular ResNet tuning space:
+#' `r rd_info(lts("classif.tab_resnet.rtdl"))`
+#'
+#' @section FT-Transformer tuning space:
+#' `r rd_info(lts("classif.ft_transformer.rtdl"))`
+#'
+#' In the FT-Transformer, the validation-related parameters must still be set manually, e.g. via `lts("regr.ft_transformer.rtdl")$get_learner(validate = 0.2, measures_valid = msr("regr.rmse"))`.
+#'
+#' @include mlr_tuning_spaces.R
+NULL
+
+# mlp
+vals = list(
+  n_layers = to_tune(1, 16),
+  neurons = to_tune(levels = 1:1024),
+  p = to_tune(0, 0.5),
+  opt.lr = to_tune(1e-5, 1e-2, logscale = TRUE),
+  opt.weight_decay = to_tune(1e-6, 1e-3, logscale = TRUE),
+  epochs = to_tune(lower = 1L, upper = 100L, internal = TRUE),
+  patience = 17L
+)
+
+add_tuning_space(
+  id = "classif.mlp.rtdl",
+  values = vals,
+  tags = c("gorishniy2021", "classification"),
+  learner = "classif.mlp",
+  package = "mlr3torch",
+  label = "Classification MLP with RTDL"
+)
+
+add_tuning_space(
+  id = "regr.mlp.rtdl",
+  values = vals,
+  tags = c("gorishniy2021", "regression"),
+  learner = "regr.mlp",
+  package = "mlr3torch",
+  label = "Regression MLP with RTDL"
+)
+
+# resnet
+vals = list(
+  n_blocks = to_tune(1, 16),
+  d_block = to_tune(64, 1024),
+  d_hidden_multiplier = to_tune(1, 4),
+  dropout1 = to_tune(0, 0.5),
+  dropout2 = to_tune(0, 0.5),
+  opt.lr = to_tune(1e-5, 1e-2, logscale = TRUE),
+  opt.weight_decay = to_tune(1e-6, 1e-3, logscale = TRUE),
+  epochs = to_tune(lower = 1L, upper = 100L, internal = TRUE),
+  patience = 17L
+)
+
+add_tuning_space(
+  id = "classif.tab_resnet.rtdl",
+  values = vals,
+  tags = c("gorishniy2021", "classification"),
+  learner = "classif.tab_resnet",
+  package = "mlr3torch",
+  label = "Classification Tabular ResNet with RTDL"
+)
+
+add_tuning_space(
+  id = "regr.tab_resnet.rtdl",
+  values = vals,
+  tags = c("gorishniy2021", "regression"),
+  learner = "regr.tab_resnet",
+  package = "mlr3torch",
+  label = "Regression Tabular ResNet with RTDL"
+)
+
+no_wd = function(name) {
+  # this will also disable weight decay for the input projection bias of the attention heads
+  no_wd_params = c("_normalization", "bias")
+
+  return(any(map_lgl(no_wd_params, function(pattern) grepl(pattern, name, fixed = TRUE))))
+}
+
+rtdl_param_groups = function(parameters) {
+  split_param_names = strsplit(names(parameters), ".", fixed = TRUE)
+
+  ffn_norm_idx = grepl("ffn_normalization", names(parameters), fixed = TRUE)
+  first_ffn_norm_num_in_module_list = as.integer(split_param_names[ffn_norm_idx][[1]][2])
+  cls_num_in_module_list = first_ffn_norm_num_in_module_list - 1
+  nums_in_module_list = sapply(split_param_names, function(x) as.integer(x[2]))
+  tokenizer_idx = nums_in_module_list < cls_num_in_module_list
+
+  # the last normalization layer is unnamed, so we need to find it based on its position in the module list
+  last_module_num_in_module_list = as.integer(split_param_names[[length(split_param_names)]][2])
+  last_norm_num_in_module_list = last_module_num_in_module_list - 2
+  last_norm_idx = nums_in_module_list == last_norm_num_in_module_list
+
+  no_wd_idx = map_lgl(names(parameters), no_wd) | tokenizer_idx | last_norm_idx
+  no_wd_group = parameters[no_wd_idx]
+
+  main_group = parameters[!no_wd_idx]
+
+  list(
+    list(params = main_group),
+    list(params = no_wd_group, weight_decay = 0)
+  )
+}
+
+# ft_transformer
+vals = list(
+  n_blocks = to_tune(1, 6),
+  d_token = to_tune(p_int(8L, 64L, trafo = function(x) 8L * x)),
+  attention_n_heads = 8L,
+  residual_dropout = to_tune(0, 0.2),
+  attention_dropout = to_tune(0, 0.5),
+  ffn_dropout = to_tune(0, 0.5),
+  ffn_d_hidden_multiplier = to_tune(2 / 3, 8 / 3),
+  opt.lr = to_tune(1e-5, 1e-4, logscale = TRUE),
+  opt.weight_decay = to_tune(1e-6, 1e-3, logscale = TRUE),
+  opt.param_groups = rtdl_param_groups,
+  epochs = to_tune(lower = 1L, upper = 100L, internal = TRUE),
+  patience = 17L
+)
+
+add_tuning_space(
+  id = "classif.ft_transformer.rtdl",
+  values = vals,
+  tags = c("gorishniy2021", "classification"),
+  learner = "classif.ft_transformer",
+  package = "mlr3torch",
+  label = "Classification FT-Transformer with RTDL"
+)
+
+add_tuning_space(
+  id = "regr.ft_transformer.rtdl",
+  values = vals,
+  tags = c("gorishniy2021", "regression"),
+  learner = "regr.ft_transformer",
+  package = "mlr3torch",
+  label = "Regression FT-Transformer with RTDL"
+)
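
A hedged usage sketch of the new spaces (assuming this branch and its `torch`/`mlr3torch` suggests are installed; the `get_learner()` call mirrors the example from the roxygen docs above):

library(mlr3)
library(mlr3torch)
library(mlr3tuningspaces)

# retrieve the FT-Transformer space and materialize its learner with internal
# validation enabled, as the documentation above requires for tuning epochs
learner = lts("regr.ft_transformer.rtdl")$get_learner(
  validate = 0.2,
  measures_valid = msr("regr.rmse")
)
learner$param_set$values$patience # constant from the space, should be 17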

README.md

Lines changed: 21 additions & 18 deletions
@@ -17,26 +17,29 @@ Status](https://www.r-pkg.org/badges/version-ago/mlr3tuningspaces)](https://cran
 optimization in the [mlr3](https://github.com/mlr-org/mlr3/) ecosystem.
 It features ready-to-use search spaces for many popular machine learning
 algorithms. The search spaces are from scientific articles and work for
-a wide range of data sets. Currently, we offer tuning spaces from three
+a wide range of data sets. Currently, we offer tuning spaces from four
 publications.
 
-| Publication                          | Learner | n Hyperparameter |
-|--------------------------------------|---------|------------------|
-| Bischl et al. (2023)                 | glmnet  | 2                |
-|                                      | ranger  | 4                |
-|                                      | rpart   | 3                |
-|                                      | svm     | 4                |
-|                                      | xgboost | 8                |
-| Kuehn et al. (2018)                  | glmnet  | 2                |
-|                                      | ranger  | 8                |
-|                                      | rpart   | 4                |
-|                                      | svm     | 5                |
-|                                      | xgboost | 13               |
-| Binder, Pfisterer, and Bischl (2020) | glmnet  | 2                |
-|                                      | ranger  | 6                |
-|                                      | rpart   | 4                |
-|                                      | svm     | 4                |
-|                                      | xgboost | 10               |
+| Publication                                       | Learner        | n Hyperparameter |
+|---------------------------------------------------|----------------|------------------|
+| Bischl et al. (2023)                              | glmnet         | 2                |
+|                                                   | ranger         | 4                |
+|                                                   | rpart          | 3                |
+|                                                   | svm            | 4                |
+|                                                   | xgboost        | 8                |
+| Kuehn et al. (2018)                               | glmnet         | 2                |
+|                                                   | ranger         | 8                |
+|                                                   | rpart          | 4                |
+|                                                   | svm            | 5                |
+|                                                   | xgboost        | 13               |
+| Binder, Pfisterer, and Bischl (2020)              | glmnet         | 2                |
+|                                                   | ranger         | 6                |
+|                                                   | rpart          | 4                |
+|                                                   | svm            | 4                |
+|                                                   | xgboost        | 10               |
+| Gorishniy, Rubachev, Khrulkov, and Babenko (2021) | mlp            | 7                |
+|                                                   | tab_resnet     | 9                |
+|                                                   | ft_transformer | 12               |
 
 ## Resources
 
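
As the README now advertises four publications, the new spaces are looked up like the existing ones; a quick sketch (assuming the package is attached):

library(mlr3tuningspaces)

# keys of the spaces added by this commit
grep("rtdl", mlr_tuning_spaces$keys(), value = TRUE)

# tabular view of one space, using the as.data.table() method extended above
as.data.table(lts("classif.tab_resnet.rtdl"))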

man/mlr_tuning_spaces_rtdl.Rd

Lines changed: 77 additions & 0 deletions
Generated file; diff not rendered.
