#' @title Deep Learning Tuning Spaces from Yandex's RTDL
#'
#' @name mlr_tuning_spaces_rtdl
#'
#' @description
#' Tuning spaces for deep neural network architectures from the `r cite_bib("gorishniy2021revisiting")` article.
#'
#' These tuning spaces require optimizers that have a `weight_decay` parameter, such as AdamW or any of the other optimizers built into `mlr3torch`.
#'
#' When the article suggests multiple ranges for a given hyperparameter, these tuning spaces choose the widest range.
#'
#' The FT-Transformer tuning space disables weight decay for all bias parameters, matching the implementation provided by the authors in the rtdl-revisiting-models package.
#' However, this may differ from the exact experimental setup described in the article.
#'
#' For the FT-Transformer, if training is unstable, consider a combination of standardizing features, using an adaptive optimizer (e.g. Adam), reducing the learning rate,
#' and using a learning rate scheduler.
#'
#' @source
#' `r format_bib("gorishniy2021revisiting")`
#'
#' @aliases
#' mlr_tuning_spaces_classif.mlp.rtdl
#' mlr_tuning_spaces_classif.tab_resnet.rtdl
#' mlr_tuning_spaces_classif.ft_transformer.rtdl
#' mlr_tuning_spaces_regr.mlp.rtdl
#' mlr_tuning_spaces_regr.tab_resnet.rtdl
#' mlr_tuning_spaces_regr.ft_transformer.rtdl
#'
#' @section MLP tuning space:
#' `r rd_info(lts("classif.mlp.rtdl"))`
#'
#' @section Tabular ResNet tuning space:
#' `r rd_info(lts("classif.tab_resnet.rtdl"))`
#'
#' @section FT-Transformer tuning space:
#' `r rd_info(lts("classif.ft_transformer.rtdl"))`
#'
#' In the FT-Transformer, the validation-related parameters must still be set manually, via e.g. `lts("regr.ft_transformer.rtdl")$get_learner(validate = 0.2, measures_valid = msr("regr.rmse"))`.
#'
#' @include mlr_tuning_spaces.R
NULL
42+
# mlp
# MLP search space from Gorishniy et al. (2021); where the article lists
# several ranges for a hyperparameter, the widest one is used.
# `patience` is fixed (early stopping), `epochs` is tuned internally.
vals = list(
  n_layers = to_tune(1, 16),
  neurons = to_tune(levels = 1:1024),
  p = to_tune(0, 0.5),
  opt.lr = to_tune(1e-5, 1e-2, logscale = TRUE),
  opt.weight_decay = to_tune(1e-6, 1e-3, logscale = TRUE),
  epochs = to_tune(lower = 1L, upper = 100L, internal = TRUE),
  patience = 17L
)
53+
# Register the classification MLP tuning space (string ids cleaned of
# stray whitespace so lts("classif.mlp.rtdl") resolves correctly).
add_tuning_space(
  id = "classif.mlp.rtdl",
  values = vals,
  tags = c("gorishniy2021", "classification"),
  learner = "classif.mlp",
  package = "mlr3torch",
  label = "Classification MLP with RTDL"
)
62+
# Register the regression MLP tuning space (same search space as the
# classification variant, applied to the regr.mlp learner).
add_tuning_space(
  id = "regr.mlp.rtdl",
  values = vals,
  tags = c("gorishniy2021", "regression"),
  learner = "regr.mlp",
  package = "mlr3torch",
  label = "Regression MLP with RTDL"
)
71+
# resnet
# Tabular ResNet search space from Gorishniy et al. (2021): block count and
# width, hidden-layer multiplier, the two dropout rates, and AdamW-style
# optimizer hyperparameters (requires an optimizer with `weight_decay`).
vals = list(
  n_blocks = to_tune(1, 16),
  d_block = to_tune(64, 1024),
  d_hidden_multiplier = to_tune(1, 4),
  dropout1 = to_tune(0, 0.5),
  dropout2 = to_tune(0, 0.5),
  opt.lr = to_tune(1e-5, 1e-2, logscale = TRUE),
  opt.weight_decay = to_tune(1e-6, 1e-3, logscale = TRUE),
  epochs = to_tune(lower = 1L, upper = 100L, internal = TRUE),
  patience = 17L
)
84+
# Register the classification Tabular ResNet tuning space.
add_tuning_space(
  id = "classif.tab_resnet.rtdl",
  values = vals,
  tags = c("gorishniy2021", "classification"),
  learner = "classif.tab_resnet",
  package = "mlr3torch",
  label = "Classification Tabular ResNet with RTDL"
)
93+
# Register the regression Tabular ResNet tuning space.
add_tuning_space(
  id = "regr.tab_resnet.rtdl",
  values = vals,
  tags = c("gorishniy2021", "regression"),
  learner = "regr.tab_resnet",
  package = "mlr3torch",
  label = "Regression Tabular ResNet with RTDL"
)
102+
# Decide whether a single parameter (identified by its name) should be
# excluded from weight decay.
#
# @param name (`character(1)`) a parameter name from the network's state.
# @return (`logical(1)`) TRUE if the name matches any no-weight-decay pattern.
no_wd = function(name) {
  # Normalization-layer parameters and all biases are exempt from weight
  # decay; "bias" also matches the input projection bias of the attention
  # heads, mirroring the reference rtdl implementation.
  no_wd_params = c("_normalization", "bias")

  # fixed = TRUE: patterns are literal substrings, not regular expressions.
  any(vapply(no_wd_params, function(pattern) grepl(pattern, name, fixed = TRUE), logical(1L)))
}
109+
# Split an FT-Transformer's parameters into two optimizer parameter groups:
# one that receives weight decay and one (normalization layers, biases,
# tokenizer, final normalization layer) with weight decay disabled.
#
# @param parameters (named `list()`) the network's parameters, with names of
#   the form "<module number in the module list>.<...>".
# @return (`list()`) two optimizer param groups, the second with weight_decay = 0.
rtdl_param_groups = function(parameters) {
  # NOTE: the split pattern must be a literal "." so that the module number
  # can be extracted as the second name component.
  split_param_names = strsplit(names(parameters), ".", fixed = TRUE)
  nums_in_module_list = vapply(split_param_names, function(x) as.integer(x[2L]), integer(1L))

  # The tokenizer consists of every module preceding the CLS token, which
  # sits directly before the first transformer block; locate that block via
  # its ffn normalization layer.
  ffn_norm_idx = grepl("ffn_normalization", names(parameters), fixed = TRUE)
  first_ffn_norm_num_in_module_list = as.integer(split_param_names[ffn_norm_idx][[1L]][2L])
  cls_num_in_module_list = first_ffn_norm_num_in_module_list - 1L
  tokenizer_idx = nums_in_module_list < cls_num_in_module_list

  # The last normalization layer is unnamed, so we need to find it based on
  # its position in the module list (two modules before the final one).
  last_module_num_in_module_list = as.integer(split_param_names[[length(split_param_names)]][2L])
  last_norm_num_in_module_list = last_module_num_in_module_list - 2L
  last_norm_idx = nums_in_module_list == last_norm_num_in_module_list

  no_wd_idx = map_lgl(names(parameters), no_wd) | tokenizer_idx | last_norm_idx
  no_wd_group = parameters[no_wd_idx]
  main_group = parameters[!no_wd_idx]

  list(
    list(params = main_group),
    list(params = no_wd_group, weight_decay = 0)
  )
}
134+
# ft_transformer
# FT-Transformer search space from Gorishniy et al. (2021).
# d_token is tuned in multiples of 8 (8..512) so the token dimension stays
# divisible by the fixed 8 attention heads. The learning-rate range is
# narrower than for the MLP/ResNet, per the article. Weight decay is
# disabled for normalization layers, biases and the tokenizer via
# rtdl_param_groups (defined above).
vals = list(
  n_blocks = to_tune(1, 6),
  d_token = to_tune(p_int(8L, 64L, trafo = function(x) 8L * x)),
  attention_n_heads = 8L,
  residual_dropout = to_tune(0, 0.2),
  attention_dropout = to_tune(0, 0.5),
  ffn_dropout = to_tune(0, 0.5),
  ffn_d_hidden_multiplier = to_tune(2 / 3, 8 / 3),
  opt.lr = to_tune(1e-5, 1e-4, logscale = TRUE),
  opt.weight_decay = to_tune(1e-6, 1e-3, logscale = TRUE),
  opt.param_groups = rtdl_param_groups,
  epochs = to_tune(lower = 1L, upper = 100L, internal = TRUE),
  patience = 17L
)
150+
# Register the classification FT-Transformer tuning space.
add_tuning_space(
  id = "classif.ft_transformer.rtdl",
  values = vals,
  tags = c("gorishniy2021", "classification"),
  learner = "classif.ft_transformer",
  package = "mlr3torch",
  label = "Classification FT-Transformer with RTDL"
)
159+
# Register the regression FT-Transformer tuning space.
add_tuning_space(
  id = "regr.ft_transformer.rtdl",
  values = vals,
  tags = c("gorishniy2021", "regression"),
  learner = "regr.ft_transformer",
  package = "mlr3torch",
  label = "Regression FT-Transformer with RTDL"
)