Commit f861ee5

fix: Correct hash implementations for PipeOps (#372)
1 parent 575bc85 commit f861ee5
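
Why this change is needed: a PipeOp's hash and phash must change whenever its behavior changes. Several of the classes touched here store construction arguments in private fields (the dimensionality .d of the convolution and pooling operators, .min_dim/.max_dim of batch norm, .innum of merge), and these fields previously did not feed into the hash, so two operators that behave differently could hash identically. The fix implements the .additional_phash_input() hook in each affected class so that these fields are folded into the hash input.

Below is a minimal, self-contained sketch of that mechanism using plain R6; the class names and the phash_input() helper are illustrative stand-ins, not this package's actual API:

library(R6)

# Toy base class: the hash input is the class name plus whatever a subclass
# returns from its .additional_phash_input() hook.
ToyPipeOp = R6Class("ToyPipeOp",
  public = list(
    phash_input = function() {
      c(list(class(self)[[1L]]), private$.additional_phash_input())
    }
  ),
  private = list(
    .additional_phash_input = function() list()
  )
)

# A subclass whose behavior depends on a construction argument d, mirroring
# private$.d in the diffs below: it must expose d to the hash computation.
ToyConv = R6Class("ToyConv",
  inherit = ToyPipeOp,
  public = list(
    initialize = function(d) {
      private$.d = d
    }
  ),
  private = list(
    .d = NULL,
    .additional_phash_input = function() list(private$.d)
  )
)

# Two operators that behave differently now produce different hash inputs:
identical(ToyConv$new(1L)$phash_input(), ToyConv$new(2L)$phash_input())  # FALSE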

8 files changed: +34 -1 lines changed


R/PipeOpTorchAdaptiveAvgPool.R

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ PipeOpTorchAdaptiveAvgPool = R6Class("PipeOpTorchAdaptiveAvgPool",
     }
   ),
   private = list(
+    .additional_phash_input = function() {
+      list(private$.d)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       list(adaptive_avg_output_shape(
         shape_in = shapes_in[[1]],

R/PipeOpTorchAvgPool.R

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ PipeOpTorchAvgPool = R6Class("PipeOpTorchAvgPool",
     }
   ),
   private = list(
+    .additional_phash_input = function() {
+      list(private$.d)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       list(avg_output_shape(
         shape_in = shapes_in[[1]],

R/PipeOpTorchBatchNorm.R

Lines changed: 3 additions & 1 deletion
@@ -30,6 +30,9 @@ PipeOpTorchBatchNorm = R6Class("PipeOpTorchBatchNorm",
   private = list(
     .min_dim = NULL,
     .max_dim = NULL,
+    .additional_phash_input = function() {
+      list(private$.min_dim, private$.max_dim)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       list(assert_numeric(shapes_in[[1]], min.len = private$.min_dim, max.len = private$.max_dim))
     },
@@ -63,7 +66,6 @@ PipeOpTorchBatchNorm = R6Class("PipeOpTorchBatchNorm",
 #' @template pipeop_torch
 #' @template pipeop_torch_example
 #'
-#'
 #' @export
 PipeOpTorchBatchNorm1D = R6Class("PipeOpTorchBatchNorm1D", inherit = PipeOpTorchBatchNorm,
   public = list(

R/PipeOpTorchConv.R

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ PipeOpTorchConv = R6Class("PipeOpTorchConv",
     }
   ),
   private = list(
+    .additional_phash_input = function() {
+      list(private$.d)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       list(conv_output_shape(
         shape_in = shapes_in[[1]],

R/PipeOpTorchConvTranspose.R

Lines changed: 3 additions & 0 deletions
@@ -35,6 +35,9 @@ PipeOpTorchConvTranspose = R6Class("PipeOpTorchConvTranspose",
     }
   ),
   private = list(
+    .additional_phash_input = function() {
+      list(private$.d)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       list(conv_transpose_output_shape(
         shape_in = shapes_in[[1]],

R/PipeOpTorchMaxPool.R

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ PipeOpTorchMaxPool = R6Class("PipeOpTorchMaxPool",
     }
   ),
   private = list(
+    .additional_phash_input = function() {
+      list(d = private$.d)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       res = list(max_output_shape(
         shape_in = shapes_in[[1]],

R/PipeOpTorchMerge.R

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ PipeOpTorchMerge = R6Class("PipeOpTorchMerge",
   ),
   private = list(
     .innum = NULL,
+    .additional_phash_input = function() {
+      list(private$.innum)
+    },
     .shapes_out = function(shapes_in, param_vals, task) {
       # note that this slightly deviates from the actual broadcasting rules implemented by torch, i.e. we don't fill
       # up missing dimension with 1s because the first dimension is usually the batch dimension.
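
For context on the comment in this hunk: torch's broadcasting rules left-pad the shorter shape with 1s, which would silently invent a leading dimension; since the first dimension is usually the batch dimension, the shape inference here deliberately refrains from that. A quick illustration of the torch rule being avoided, using the torch R package:

library(torch)

a = torch_randn(2, 3)  # e.g. a batch of 2 observations
b = torch_randn(3)

# torch pads b's shape to (1, 3) before broadcasting, so the sum has shape
# (2, 3); the .shapes_out() above does not apply this padding to shape metadata.
dim(a + b)  # 2 3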

tests/testthat/helper_autotest.R

Lines changed: 13 additions & 0 deletions
@@ -28,6 +28,10 @@ expect_pipeop_torch = function(graph, id, task, module_class = id, exclude_args
   result = graph$train(task)
   md = result[[1]]

+  # PipeOp overwrites hash and phash if necessary
+  testthat::expect_warning(po_test$hash, regexp = NA)
+  testthat::expect_warning(po_test$phash, regexp = NA)
+
   modulegraph = md$graph
   po_module = modulegraph$pipeops[[id]]
   if (is.null(po_module$module)) {
@@ -325,6 +329,10 @@ expect_pipeop_torch_preprocess = function(obj, shapes_in, exclude = character(0)
   if (is.null(seed)) {
     seed = sample.int(100000, 1)
   }
+  # overwrites .additional_phash_input where necessary
+  testthat::expect_warning(obj$hash, regexp = NA)
+  testthat::expect_warning(obj$phash, regexp = NA)
+
   expect_pipeop(obj)
   expect_class(obj, "PipeOpTaskPreprocTorch")
   # a) Check that all parameters but stages have tags train and predict (this should hold in basically all cases)
@@ -423,6 +431,11 @@ expect_pipeop_torch_preprocess = function(obj, shapes_in, exclude = character(0)

 expect_learner_torch = function(learner, task, check_man = TRUE, check_id = TRUE) {
   checkmate::expect_class(learner, "LearnerTorch")
+
+  # overwrites .additional_phash_input where necessary
+  testthat::expect_warning(learner$hash, regexp = NA)
+  testthat::expect_warning(learner$phash, regexp = NA)
+
   get("expect_learner", envir = .GlobalEnv)(learner)
   # state cloning is tested separately
   learner1 = learner
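
A note on the testthat idiom used in these helpers: expect_warning(expr, regexp = NA) asserts that evaluating expr emits no warning at all, so these checks pass only when accessing $hash and $phash is silent. A standalone illustration (both functions are made up for demonstration):

library(testthat)

quiet = function() 42
noisy = function() {
  warning("boom")
  42
}

expect_warning(quiet(), regexp = NA)  # passes: no warning is emitted
expect_warning(noisy(), regexp = NA)  # fails: a warning is emitted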
