Skip to content

Commit a528a47

Browse files
committed
feat: allow no of obs in holdout resampling
1 parent d6c22d8 commit a528a47

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

R/ResamplingForecastHoldout.R

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,10 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
3939
#' @description
4040
#' Creates a new instance of this [R6][R6::R6Class] class.
4141
initialize = function() {
42-
param_set = ps(ratio = p_dbl(0, 1, tags = "required"))
43-
param_set$set_values(ratio = 0.8)
42+
param_set = ps(
43+
ratio = p_dbl(0, 1),
44+
n = p_int()
45+
)
4446

4547
super$initialize(
4648
id = "forecast_holdout",
@@ -63,7 +65,17 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
6365
.sample = function(ids, ...) {
  # Split `ids` in their given order into a leading training part and a
  # trailing test part. The size of the training part (`nr`) is derived from
  # exactly one of the parameters `ratio` and `n`:
  #   * `ratio`: `nr` is `round(length(ids) * ratio)`.
  #   * `n > 0`: `nr` is `n`, capped at `length(ids)`.
  #   * `n <= 0`: `nr` is `length(ids) + n`, floored at 0, i.e. the last
  #     `abs(n)` ids form the test set (`n = 0` yields an empty test set).
  # Returns a list with elements `train` and `test`.
  pars = self$param_set$get_values()
  n = length(ids)
  has_ratio = !is.null(pars$ratio)
  has_n = !is.null(pars$n)

  # Exactly one of the two parameters must be set: both or neither is an error.
  if (has_ratio == has_n) {
    stopf("Either parameter `ratio` (x)or `n` must be provided.")
  }

  if (has_ratio) {
    nr = round(n * pars$ratio)
  } else if (pars$n > 0L) {
    nr = min(n, pars$n)
  } else {
    nr = max(n + pars$n, 0L)
  }

  # Use seq_len() instead of `1:nr` / `(nr + 1L):n`: the colon operator counts
  # downwards when `nr` is 0 or equals `n`, which would put `ids[1]` into an
  # intended-empty training set or `NA` plus the last id into an
  # intended-empty test set.
  list(
    train = ids[seq_len(nr)],
    test = ids[nr + seq_len(n - nr)]
  )
},

tests/testthat/test_ResamplingForecastHoldout.R

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,26 @@ test_that("forecast_holdout basic properties", {
99
expect_error(resampling$train_set(2L))
1010
expect_error(resampling$test_set(2L))
1111
expect_false(resampling$duplicated_ids)
12+
13+
resampling = rsmp("forecast_holdout", ratio = 0.5)$instantiate(task)
14+
expect_length(resampling$train_set(1L), task$nrow / 2)
15+
expect_length(resampling$test_set(1L), task$nrow / 2)
16+
17+
resampling = rsmp("forecast_holdout", n = 10L)$instantiate(task)
18+
expect_length(resampling$train_set(1L), 10L)
19+
expect_length(resampling$test_set(1L), task$nrow - 10L)
20+
21+
resampling = rsmp("forecast_holdout", n = -10L)$instantiate(task)
22+
expect_length(resampling$train_set(1L), task$nrow - 10L)
23+
expect_length(resampling$test_set(1L), 10L)
1224
})
1325

1426
test_that("forecast_holdout works", {
1527
skip_if_not_installed("tsbox")
1628
dt = tsbox::ts_dt(AirPassengers)
1729
dt[, time := NULL]
1830
task = as_task_regr(dt, target = "value")
19-
resampling = rsmp("forecast_holdout")
31+
resampling = rsmp("forecast_holdout", ratio = 0.8)
2032
resampling$instantiate(task)
2133
expect_identical(resampling$train_set(1L), 1:115)
2234
expect_identical(resampling$test_set(1L), 116:144)

0 commit comments

Comments
 (0)