Skip to content

Commit e8addaa

Browse files
authored
feat(callbacks): unfreeze parameters (#303)
1 parent 79894bc commit e8addaa

25 files changed

+592
-0
lines changed

.Rbuildignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,4 @@ data-raw
3434
^doc$
3535
^Meta$
3636
^CRAN-SUBMISSION$
37+
^paper$

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,7 @@ inst/doc
1414
/doc/
1515
/Meta/
1616
CRAN-SUBMISSION
17+
paper/data
18+
.idea/
19+
.vsc/
1720
paper/data

DESCRIPTION

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ Collate:
8787
'CallbackSetHistory.R'
8888
'CallbackSetProgress.R'
8989
'CallbackSetTB.R'
90+
'CallbackSetUnfreeze.R'
9091
'ContextTorch.R'
9192
'DataBackendLazy.R'
9293
'utils.R'
@@ -125,6 +126,7 @@ Collate:
125126
'PipeOpTorchOptimizer.R'
126127
'PipeOpTorchReshape.R'
127128
'PipeOpTorchSoftmax.R'
129+
'Select.R'
128130
'TaskClassif_lazy_iris.R'
129131
'TaskClassif_melanoma.R'
130132
'TaskClassif_mnist.R'

NAMESPACE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ S3method(materialize,data.frame)
3737
S3method(materialize,lazy_tensor)
3838
S3method(materialize,list)
3939
S3method(print,ModelDescriptor)
40+
S3method(print,Select)
4041
S3method(print,TorchIngressToken)
4142
S3method(print,lazy_tensor)
4243
S3method(rep,lazy_tensor)
@@ -64,6 +65,7 @@ export(CallbackSetCheckpoint)
6465
export(CallbackSetHistory)
6566
export(CallbackSetProgress)
6667
export(CallbackSetTB)
68+
export(CallbackSetUnfreeze)
6769
export(ContextTorch)
6870
export(DataBackendLazy)
6971
export(DataDescriptor)
@@ -176,6 +178,11 @@ export(nn_squeeze)
176178
export(nn_unsqueeze)
177179
export(pipeop_preproc_torch)
178180
export(replace_head)
181+
export(select_all)
182+
export(select_grep)
183+
export(select_invert)
184+
export(select_name)
185+
export(select_none)
179186
export(t_clbk)
180187
export(t_clbks)
181188
export(t_loss)

R/CallbackSetUnfreeze.R

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#' @title Unfreezing Weights Callback
#'
#' @name mlr_callback_set.unfreeze
#'
#' @description
#' Unfreeze some weights (parameters of the network) after some number of steps or epochs.
#' At the beginning of training, only the weights selected by `starting_weights` are trainable
#' (`requires_grad` is `TRUE`); all other weights are frozen. Afterwards, the weights selected in a row
#' of `unfreeze` are made trainable once the epoch or batch given in that row is reached.
#'
#' @param starting_weights (`Select`)\cr
#'   A `Select` denoting the weights that are trainable from the start.
#' @param unfreeze (`data.table`)\cr
#'   A `data.table` with a column `weights` (a list column of `Select`s) and a column `epoch` or `batch`.
#'   The selector indicates which parameters to unfreeze, while the `epoch` or `batch` column indicates when to do so.
#'   Batches are counted cumulatively over all epochs.
#'
#' @family Callback
#' @export
#' @include CallbackSet.R
CallbackSetUnfreeze = R6Class("CallbackSetUnfreeze",
  inherit = CallbackSet,
  lock_objects = FALSE,
  public = list(
    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    initialize = function(starting_weights, unfreeze) {
      self$starting_weights = starting_weights
      self$unfreeze = unfreeze
      # The schedule is batch-based if a `batch` column exists, otherwise epoch-based.
      private$.batchwise = "batch" %in% names(self$unfreeze)
    },
    #' @description
    #' Sets the starting weights: weights selected by `starting_weights` are made trainable,
    #' all remaining weights are frozen.
    on_begin = function() {
      params = self$ctx$network$parameters
      param_names = names(params)

      trainable_weights = self$starting_weights(param_names)
      walk(params[trainable_weights], function(param) param$requires_grad_(TRUE))
      # Everything not selected by `starting_weights` is frozen.
      frozen_weights = select_invert(self$starting_weights)(param_names)
      walk(params[frozen_weights], function(param) param$requires_grad_(FALSE))

      # Log a single comma-separated line (previously the raw vector was passed to paste0(),
      # yielding one message string per weight).
      trainable_weights_str = paste(trainable_weights, collapse = ", ")
      lg$info(paste0("Training the following weights at the start: ", trainable_weights_str))
    },
    #' @description
    #' Unfreezes weights if the training is at the correct epoch
    on_epoch_begin = function() {
      if (!private$.batchwise) {
        private$.unfreeze_at("epoch", self$ctx$epoch)
      }
    },
    #' @description
    #' Unfreezes weights if the training is at the correct batch
    on_batch_begin = function() {
      if (private$.batchwise) {
        # Convert the within-epoch step into a batch index over all epochs
        # (assumes ctx$step counts batches within the current epoch).
        batch_num = (self$ctx$epoch - 1) * length(self$ctx$loader_train) + self$ctx$step
        private$.unfreeze_at("batch", batch_num)
      }
    }
  ),
  private = list(
    # TRUE if the schedule in `unfreeze` is given per batch, FALSE if per epoch.
    .batchwise = NULL,
    # Unfreeze the weights scheduled for `value` in column `unit` ("epoch" or "batch"), if any.
    # Logs a warning when the scheduled Select matches no parameter.
    .unfreeze_at = function(unit, value) {
      times = self$unfreeze[[unit]]
      if (!(value %in% times)) {
        return(invisible(NULL))
      }
      # `check_unfreeze_dt()` forbids duplicated times, so the first match is the only one.
      select = self$unfreeze[["weights"]][[which(times == value)[1L]]]
      weights = select(names(self$ctx$network$parameters))
      if (!length(weights)) {
        lg$warn(paste0("No weights unfrozen at ", unit, " ", value, " , check the specification of the Selector"))
        return(invisible(NULL))
      }
      walk(self$ctx$network$parameters[weights], function(param) param$requires_grad_(TRUE))
      weights_str = paste(weights, collapse = ", ")
      lg$info(paste0("Unfreezing at ", unit, " ", value, ": ", weights_str))
      invisible(NULL)
    }
  )
)
75+
76+
#' @include TorchCallback.R
mlr3torch_callbacks$add("unfreeze", function() {
  # Both hyperparameters are required at train time: `starting_weights` must be a
  # single `Select`, and `unfreeze` is validated by check_unfreeze_dt().
  unfreeze_param_set = ps(
    starting_weights = p_uty(
      custom_check = function(input) check_class(input, "Select"),
      tags = c("train", "required")
    ),
    unfreeze = p_uty(
      custom_check = check_unfreeze_dt,
      tags = c("train", "required")
    )
  )
  TorchCallback$new(
    callback_generator = CallbackSetUnfreeze,
    param_set = unfreeze_param_set,
    id = "unfreeze",
    label = "Unfreeze",
    man = "mlr3torch::mlr_callback_set.unfreeze"
  )
})
95+
96+
# Custom check for the `unfreeze` hyperparameter of the unfreeze callback.
# Returns TRUE if `x` is valid, otherwise a string describing the problem.
# Valid inputs are NULL, an empty data.table, or a data.table with a list column
# `weights` of `Select`s and exactly one of the columns `epoch` or `batch`.
check_unfreeze_dt = function(x) {
  if (is.null(x) || (is.data.table(x) && nrow(x) == 0)) {
    return(TRUE)
  }
  if (!test_class(x, "data.table")) {
    return("`unfreeze` must be a data.table()")
  }
  if (!test_names(names(x), must.include = "weights")) {
    return("Must contain 2 columns: `weights` and (epoch or batch)")
  }
  if (!xor("epoch" %in% names(x), "batch" %in% names(x))) {
    return("Exactly one of the columns must be named 'epoch' or 'batch'")
  }
  xs = x[["epoch"]] %??% x[["batch"]]
  # The callback compares these values against ctx$epoch / a batch index that presumably
  # start at 1, so 0 or NA could never trigger an unfreeze; reject them instead of
  # silently ignoring them.
  if (!test_integerish(xs, lower = 1L, any.missing = FALSE) || anyDuplicated(xs) > 0L) {
    return("Column batch/epoch must be a positive integerish vector without missing values or duplicates.")
  }
  if (!test_list(x$weights)) {
    return("The `weights` column should be a list")
  }
  if (some(x$weights, function(input) !test_class(input, classes = "Select"))) {
    return("The `weights` column should be a list of Selects")
  }
  return(TRUE)
}

R/Select.R

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
#' @title Selector Functions for Character Vectors
#'
#' @name Select
#'
#' @description
#' A [`Select`] is a function that subsets a character vector, e.g. the names of the parameters of a network.
#' Selects are used by the callback `CallbackSetUnfreeze` to select parameters to freeze or unfreeze during training.
NULL

# Turn `fun` into a `Select`: store the printable representation
# `sprintf(description, ...)` in the attribute "repr" and set the class to
# c("Select", "function").
make_select = function(fun, description, ...) {
  attr(fun, "repr") = sprintf(description, ...)
  class(fun) = c("Select", "function")
  fun
}
15+
16+
#' @describeIn Select `select_all` selects all elements
#' @export
#' @examples
#' select_all()(c("a", "b"))
select_all = function() {
  # Identity selection: every name is kept, in its original order.
  keep_all = function(param_names) param_names
  make_select(keep_all, "select_all()")
}
25+
26+
#' @describeIn Select `select_none` selects no elements
#' @export
#' @examples
#' select_none()(c("a", "b"))
select_none = function() {
  # Ignores its input and always returns an empty character vector.
  keep_none = function(param_names) character(0)
  make_select(keep_none, "select_none()")
}
35+
36+
#' @describeIn Select `select_grep` selects elements with names matching a regular expression
#' @param pattern See `grep()`
#' @param ignore.case See `grep()`
#' @param perl See `grep()`
#' @param fixed See `grep()`
#' @export
#' @examples
#' select_grep("b$")(c("ab", "ac"))
select_grep = function(pattern, ignore.case = FALSE, perl = FALSE, fixed = FALSE) {
  # grep() only uses the first element of `pattern` (and errors on an empty one),
  # so require exactly one string.
  assert_string(pattern)
  assert_flag(ignore.case)
  assert_flag(perl)
  assert_flag(fixed)
  # Only non-default flags appear in the printed representation.
  str_ignore_case = if (ignore.case) ", ignore.case = TRUE" else ""
  str_perl = if (perl) ", perl = TRUE" else ""
  str_fixed = if (fixed) ", fixed = TRUE" else ""
  make_select(function(param_names) {
    grep(pattern, param_names, ignore.case = ignore.case, perl = perl, fixed = fixed, value = TRUE)
  }, "select_grep(%s%s%s%s)", char_repr(pattern), str_ignore_case, str_perl, str_fixed)
}
56+
57+
#' @describeIn Select `select_name` selects elements with names matching the given names
#' @param param_names The names of the parameters that you want to select
#' @param assert_present Whether to check that `param_names` is a subset of the full vector of names
#' @export
#' @examples
#' select_name("a")(c("a", "b"))
select_name = function(param_names, assert_present = TRUE) {
  assert_character(param_names, any.missing = FALSE)
  assert_flag(assert_present)
  # `assert_present` defaults to TRUE, so only the non-default value has to be
  # shown for the representation to describe this selector correctly.
  str_assert_present = if (assert_present) "" else ", assert_present = FALSE"
  make_select(function(full_names) {
    if (assert_present) {
      assert_subset(param_names, full_names)
    }
    # Result follows the order of `full_names`, not of `param_names`.
    intersect(full_names, param_names)
  }, "select_name(%s%s)", char_repr(param_names), str_assert_present)
}
74+
75+
#' @describeIn Select `select_invert` selects the elements NOT selected by the given selector
#' @param select A `Select` (or any function mapping a character vector to a subset of it)
#' @export
#' @examples
#' select_invert(select_all())(c("a", "b"))
select_invert = function(select) {
  assert_function(select)
  # Complement of `select` with respect to the names supplied at call time.
  complement = function(full_names) {
    excluded = select(full_names)
    setdiff(full_names, excluded)
  }
  make_select(complement, "select_invert(%s)", select_repr(select))
}
86+
87+
# Adapted from mlr3pipelines.
# Source-code representation of a character vector, e.g.
#   character(0) --> 'character(0)'
#   letters[1]   --> '"a"'
#   letters[1:2] --> 'c("a", "b")'
char_repr = function(x) {
  quoted = str_collapse(x, sep = ", ", quote = '"')
  n = length(x)
  if (n == 0) {
    return("character(0)")
  }
  if (n == 1) {
    return(quoted)
  }
  sprintf("c(%s)", quoted)
}
101+
102+
# Adapted from mlr3pipelines.
# Representation of a function that may or may not be a `Select`: the "repr"
# attribute is used when it is a single string; otherwise the function is
# deparsed and its lines are joined with newlines.
select_repr = function(select) {
  repr = attr(select, "repr")
  if (!test_string(repr)) {
    return(str_collapse(deparse(select), sep = "\n"))
  }
  repr
}
113+
114+
# Adapted from mlr3pipelines.
# Prints the "repr" attribute of a `Select` followed by a newline and returns
# the object invisibly; R6 objects are delegated to the next print method.
#' @export
print.Select = function(x, ...) {
  if (!inherits(x, "R6")) {
    cat(paste0(attr(x, "repr"), "\n"))
    return(invisible(x))
  }
  NextMethod("print")
}

attic/try-CallbackSetUnfreeze.R

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
devtools::load_all()

task = tsk("iris")

mlp = lrn("classif.mlp",
  epochs = 10, batch_size = 150, neurons = c(100, 200, 300)
)

# sela = select_all()
# sela(names(mlp$network$parameters))

mlp$train(task)

# Unfreeze every parameter of module "9", whatever parameters that module has.
for (param in mlp$model$network$modules[["9"]]$parameters) {
  param$requires_grad_(TRUE)
}


# construct a NN as a graph
module_1 = nn_linear(in_features = 3, out_features = 4, bias = TRUE)
activation = nn_sigmoid()
module_2 = nn_linear(4, 3, bias = TRUE)
softmax = nn_softmax(2)

# All modules are wrapped by the same PipeOp ("module"); only the id differs.
po_module_1 = po("module", id = "module_1", module = module_1)
po_activation = po("module", id = "activation", module = activation)
po_module_2 = po("module", id = "module_2", module = module_2)
po_softmax = po("module", id = "softmax", module = softmax)

module_graph = po_module_1 %>>%
  po_activation %>>%
  po_module_2 %>>%
  po_softmax

module_graph$plot(html = TRUE)

module_graph

attic/try-Select.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
devtools::load_all()

n_epochs = 10

task = tsk("iris")

mlp = lrn("classif.mlp",
  epochs = n_epochs, batch_size = 150, neurons = c(100, 200, 300)
)
mlp$train(task)

# The selectors below are all applied to the parameter names of the trained network.
param_names = names(mlp$network$parameters)
param_names

sela = select_all()
sela(param_names)

selg = select_grep("weight")
selg(param_names)

seln = select_name("0.weight")
seln(param_names)

seli = select_invert(select_name("0.weight"))
seli(param_names)

selnone = select_none()
selnone(param_names)
27+

0 commit comments

Comments
 (0)