Skip to content

Commit addf4ee

Browse files
authored
better error message for nn_head (#369)
1 parent 4600702 commit addf4ee

File tree

2 files changed (+10 lines, -1 line)

R/PipeOpTorchHead.R

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,9 @@ PipeOpTorchHead = R6Class("PipeOpTorchHead",
3636
),
3737
private = list(
3838
# Shape inference for PipeOpTorchHead (diff hunk from R/PipeOpTorchHead.R).
# Requires the first input shape to have exactly two dimensions. The output
# shape keeps that shape's first dimension and sets the second one to
# `get_nout(task)`.
# NOTE(review): the first dimension is presumably the batch dimension, and
# get_nout() presumably returns the number of model outputs for the task
# (e.g. number of classes) -- confirm against the helper's definition.
# `param_vals` is not used in the visible body.
# In this diff, the `-` line (assert_true) is the old check; the `+` lines
# replace it with a stopf() that reports the offending input shape(s).
# NOTE(review): the message formats all of `shapes_in`, not only
# `shapes_in[[1]]`, which is the shape actually being checked.
.shapes_out = function(shapes_in, param_vals, task) {
39-
assert_true(length(shapes_in[[1]]) == 2L)
39+
if (length(shapes_in[[1]]) != 2L) {
40+
stopf("PipeOpTorchHead expects 2D input, but got %s.", shape_to_str(shapes_in))
41+
}
4042
d = get_nout(task)
4143
list(c(shapes_in[[1]][[1]], d))
4244
},

tests/testthat/test_PipeOpTorchHead.R

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,3 +12,10 @@ test_that("PipeOpTorchHead paramtest", {
1212
res = expect_paramset(po_test, torch::nn_linear, exclude = c("out_features", "in_features"))
1313
expect_paramtest(res)
1414
})
15+
16+
17+
# Regression test for the new PipeOpTorchHead shape check: training a graph
# of a lazy-tensor ingress followed by nn_head must raise an error whose
# message matches "expects 2D input".
# NOTE(review): relies on nano_imagenet() producing non-2D (presumably
# image-shaped) lazy tensors so that the check in .shapes_out fires.
test_that("correct error message", {
18+
task = nano_imagenet()
19+
graph = po("torch_ingress_ltnsr") %>>% po("nn_head")
20+
expect_error(graph$train(task), "expects 2D input")
21+
})

0 commit comments

Comments (0)