Skip to content

Commit d1f3df2

Browse files
committed
feat(resample): feature flag for sorting based on the order column
1 parent 94f8e48 commit d1f3df2

File tree

3 files changed

+42
-41
lines changed

3 files changed

+42
-41
lines changed

R/ResamplingForecastCV.R

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,6 @@ ResamplingForecastCV = R6Class("ResamplingForecastCV",
8282

8383
private = list(
8484
.sample = function(ids, task, ...) {
85-
# if (length(task$col_roles$order) == 0L) {
86-
# stopf(
87-
# "Resampling '%s' requires an ordered task, but Task '%s' has no order.",
88-
# self$id, task$id
89-
# )
90-
# }
91-
9285
pars = self$param_set$get_values()
9386
ids = sort(ids)
9487
train_end = ids[ids <= (max(ids) - pars$horizon) & ids >= pars$window_size]

R/ResamplingForecastHoldout.R

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,6 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
6868

6969
private = list(
7070
.sample = function(ids, task, ...) {
71-
# if (length(task$col_roles$order) == 0L) {
72-
# stopf(
73-
# "Resampling '%s' requires an ordered task, but Task '%s' has no order.",
74-
# self$id, task$id
75-
# )
76-
# }
77-
7871
pars = self$param_set$get_values()
7972
ratio = pars$ratio
8073
n = pars$n
@@ -91,8 +84,23 @@ ResamplingForecastHoldout = R6Class("ResamplingForecastHoldout",
9184
} else {
9285
nr = max(n_obs + n, 0L)
9386
}
94-
ii = ids[1:nr]
95-
list(train = ii, test = ids[(nr + 1L):n_obs])
87+
88+
if (TRUE) {
89+
ids = sort(ids)
90+
ii = ids[1:nr]
91+
list(train = ii, test = ids[(nr + 1L):n_obs])
92+
} else {
93+
# check when this is even needed
94+
order = row_id = NULL
95+
order_cols = private$.col_roles$order
96+
tab = task$backend$data(rows = ids, cols = c(task$backend$primary_key, order_cols))
97+
setnames(tab, c("row_id", "order"))
98+
setorder(tab, order)
99+
list(
100+
train = tab[1:nr, row_id],
101+
test = tab[(nr + 1L):n_obs, row_id]
102+
)
103+
}
96104
},
97105

98106
.get_train = function(i) {

README.md

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,32 @@ prediction = ff$predict_newdata(newdata, task)
4545
prediction
4646
#> <PredictionRegr> for 3 observations:
4747
#> row_ids truth response
48-
#> 1 NA 446.9409
49-
#> 2 NA 477.9439
50-
#> 3 NA 480.5694
48+
#> 1 NA 448.8710
49+
#> 2 NA 475.2456
50+
#> 3 NA 480.5179
5151
prediction = ff$predict(task, 142:144)
5252
prediction
5353
#> <PredictionRegr> for 3 observations:
5454
#> row_ids truth response
55-
#> 1 461 459.4145
56-
#> 2 390 411.2457
57-
#> 3 432 400.4514
55+
#> 1 461 456.4968
56+
#> 2 390 411.1712
57+
#> 3 432 393.9585
5858
prediction$score(measure)
5959
#> regr.rmse
60-
#> 21.97883
60+
#> 25.26957
6161

6262
ff = Forecaster$new(lrn("regr.ranger"), 1:3)
6363
resampling = rsmp("forecast_holdout", ratio = 0.8)
6464
rr = resample(task, ff, resampling)
6565
rr$aggregate(measure)
6666
#> regr.rmse
67-
#> 105.0997
67+
#> 105.8215
6868

6969
resampling = rsmp("forecast_cv")
7070
rr = resample(task, ff, resampling)
7171
rr$aggregate(measure)
7272
#> regr.rmse
73-
#> 54.93903
73+
#> 54.28352
7474
```
7575

7676
### Multivariate
@@ -90,34 +90,34 @@ ff = Forecaster$new(lrn("regr.ranger"), 1:3)$train(new_task)
9090
prediction = ff$predict(new_task, 142:144)
9191
prediction$score(measure)
9292
#> regr.rmse
93-
#> 17.55705
93+
#> 17.0878
9494

9595
row_ids = new_task$nrow - 0:2
9696
ff$predict_newdata(new_task$data(rows = row_ids), new_task)
9797
#> <PredictionRegr> for 3 observations:
9898
#> row_ids truth response
99-
#> 1 432 405.2216
100-
#> 2 390 388.3066
101-
#> 3 461 385.6412
99+
#> 1 432 405.5814
100+
#> 2 390 388.3657
101+
#> 3 461 390.9778
102102
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
103103
ff$predict_newdata(newdata, new_task)
104104
#> <PredictionRegr> for 3 observations:
105105
#> row_ids truth response
106-
#> 1 NA 405.2216
107-
#> 2 NA 388.3066
108-
#> 3 NA 385.6412
106+
#> 1 NA 405.5814
107+
#> 2 NA 388.3657
108+
#> 3 NA 390.9778
109109

110110
resampling = rsmp("forecast_holdout", ratio = 0.8)
111111
rr = resample(new_task, ff, resampling)
112112
rr$aggregate(measure)
113113
#> regr.rmse
114-
#> 82.35283
114+
#> 81.91252
115115

116116
resampling = rsmp("forecast_cv")
117117
rr = resample(new_task, ff, resampling)
118118
rr$aggregate(measure)
119119
#> regr.rmse
120-
#> 45.54337
120+
#> 41.87113
121121
```
122122

123123
### mlr3pipelines integration
@@ -128,7 +128,7 @@ glrn = as_learner(graph %>>% ff)$train(task)
128128
prediction = glrn$predict(task, 142:144)
129129
prediction$score(measure)
130130
#> regr.rmse
131-
#> 34.29322
131+
#> 33.74039
132132
```
133133

134134
### Example: Forecasting electricity demand
@@ -166,11 +166,11 @@ prediction = glrn$predict_newdata(newdata, task)
166166
prediction
167167
#> <PredictionRegr> for 14 observations:
168168
#> row_ids truth response
169-
#> 1 NA 187.9399
170-
#> 2 NA 190.5695
171-
#> 3 NA 184.2617
169+
#> 1 NA 187.6208
170+
#> 2 NA 191.8121
171+
#> 3 NA 183.6753
172172
#> --- --- ---
173-
#> 12 NA 214.6350
174-
#> 13 NA 218.8392
175-
#> 14 NA 221.4170
173+
#> 12 NA 213.8759
174+
#> 13 NA 218.4198
175+
#> 14 NA 218.8139
176176
```

0 commit comments

Comments
 (0)