Skip to content

Commit 495d7bd

Browse files
committed
feat: add partition implementation
1 parent 47870b9 commit 495d7bd

16 files changed

+71
-40
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ URL: https://mlr3forecast.mlr-org.com,
1010
https://github.com/mlr-org/mlr3forecast
1111
BugReports: https://github.com/mlr-org/mlr3forecast/issues
1212
Depends:
13-
mlr3 (>= 1.1.0),
13+
mlr3 (>= 1.1.0.9000),
1414
R (>= 3.3.0)
1515
Imports:
1616
backports,
@@ -38,6 +38,7 @@ Suggests:
3838
xts,
3939
zoo
4040
Remotes:
41+
mlr-org/mlr3,
4142
mlr-org/mlr3pipelines
4243
Config/testthat/edition: 3
4344
Encoding: UTF-8
@@ -76,6 +77,7 @@ Collate:
7677
'assertions.R'
7778
'autoplot.R'
7879
'bibentries.R'
80+
'partition.R'
7981
'reexports.R'
8082
'tsf.R'
8183
'utils.R'

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ S3method(as_task_fcst,zoo)
1111
S3method(as_tasks_fcst,default)
1212
S3method(as_tasks_fcst,list)
1313
S3method(autoplot,TaskFcst)
14+
S3method(partition,TaskFcst)
1415
S3method(plot,TaskFcst)
1516
export(ForecastLearner)
1617
export(ForecastLearnerManual)

R/ResamplingFcstHoldout.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,12 @@ ResamplingFcstHoldout = R6Class(
9090
nr = max(n_obs + n, 0L)
9191
}
9292

93-
order_cols = task$col_roles$order
94-
key_cols = task$col_roles$key
93+
col_roles = task$col_roles
94+
order_cols = col_roles$order
95+
key_cols = col_roles$key
9596
has_key_cols = length(key_cols) > 0L
9697
dt = task$backend$data(rows = ids, cols = c(task$backend$primary_key, order_cols, key_cols))
98+
9799
if (has_key_cols) {
98100
setnames(dt, "..row_id", "row_id")
99101
setorderv(dt, c(key_cols, order_cols))
@@ -106,7 +108,10 @@ ResamplingFcstHoldout = R6Class(
106108
} else {
107109
setnames(dt, c("row_id", "order"))
108110
setorderv(dt, "order")
109-
list(train = dt[1:nr, row_id], test = dt[(nr + 1L):.N, row_id])
111+
list(
112+
train = dt[1:nr, row_id],
113+
test = if (nrow(dt) > nr) dt[(nr + 1L):.N, row_id] else integer()
114+
)
110115
}
111116
},
112117

R/partition.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Partition a forecasting task into train / test / validation row ids.
#
# Both stages use the "fcst.holdout" resampling instead of a random split, so
# the task's row order is respected. In the non-keyed case that resampling
# sorts by the order column and gives the earliest rows to its training set.
#
# task:  (`TaskFcst`) the task to split. It is deep-cloned before its row
#        roles are changed, so the caller's task is left untouched.
# ratio: (`numeric(1)` or `numeric(2)`, non-negative, `sum(ratio) < 1`)
#        share of rows for the training set and, optionally, the test set.
#        With two values, the remaining `1 - sum(ratio)` share forms the
#        validation set. With one value, the test set takes all remaining
#        rows (the second holdout then runs with ratio 1), so the validation
#        set is expected to be empty.
#
# Returns a named list of row ids: `train`, `test`, `validation`.
#' @export
partition.TaskFcst = function(task, ratio = 0.67) {
  # Validate up front: elements beyond the second used to be silently ignored,
  # and an `NA` surfaced as "missing value where TRUE/FALSE needed".
  if (!is.numeric(ratio) || !(length(ratio) %in% 1:2) || anyNA(ratio) || any(ratio < 0)) {
    stopf("'ratio' must be a numeric vector of length 1 or 2 with non-negative, non-missing values")
  }
  if (sum(ratio) >= 1) {
    stopf("Sum of 'ratio' must be smaller than 1")
  }

  # Clone because the second stage restricts `row_roles$use` on the task.
  task = task$clone(deep = TRUE)

  # With a single ratio, the test share is everything not used for training.
  if (length(ratio) == 1L) {
    ratio[2L] = 1 - ratio
  }

  # Stage 1: split off the training rows.
  r1 = rsmp("fcst.holdout", ratio = ratio[1L])$instantiate(task)

  # Stage 2: split the remaining rows into test and validation. The test
  # share is rescaled relative to what is left after stage 1.
  task$row_roles$use = r1$test_set(1L)
  r2 = rsmp("fcst.holdout", ratio = ratio[2L] / (1 - ratio[1L]))$instantiate(task)

  list(
    train = r1$train_set(1L),
    test = r2$train_set(1L),
    validation = r2$test_set(1L)
  )
}

man-roxygen/example.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
#' task = tsk("airpassengers")
88
#'
99
#' # Create train and test set
10-
#' resampling = rsmp("fcst.holdout", ratio = 0.7)$instantiate(task)
10+
#' ids = partition(task)
1111
#'
1212
#' # Train the learner on the training ids
13-
#' learner$train(task, row_ids = resampling$train_set(1))
13+
#' learner$train(task, row_ids = ids$train)
1414
#'
1515
#' # Print the model
1616
#' print(learner$model)
@@ -19,7 +19,7 @@
1919
#' if ("importance" %in% learner$properties) print(learner$importance)
2020
#'
2121
#' # Make predictions for the test rows
22-
#' predictions = learner$predict(task, row_ids = resampling$test_set(1))
22+
#' predictions = learner$predict(task, row_ids = ids$test)
2323
#'
2424
#' # Score the predictions
2525
#' predictions$score()

man/mlr_learners_fcst.adam.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_fcst.arfima.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_fcst.arima.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_fcst.auto_adam.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_learners_fcst.auto_arima.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)