Skip to content

Commit e995217

Browse files
committed
feat: task train draft for lag pipeop
1 parent e0ef76c commit e995217

File tree

4 files changed

+111
-40
lines changed

4 files changed

+111
-40
lines changed

R/ForecastLearner.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ ForecastLearner = R6::R6Class("ForecastLearner",
8181

8282
.lag_transform = function(dt, target) {
8383
lag = self$lag
84-
nms = sprintf("%s_lag_%s", target, lag)
84+
nms = sprintf("%s_lag_%i", target, lag)
8585
dt = copy(dt)
86-
key_coles = private$.task$col_roles$key
87-
if (length(key_coles) > 0L) {
88-
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key_coles, .SDcols = target]
86+
key_cols = private$.task$col_roles$key
87+
if (length(key_cols) > 0L) {
88+
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key_cols, .SDcols = target]
8989
} else {
9090
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
9191
}

R/PipeOpTargetLags.R

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ PipeOpLags = R6Class("PipeOpLags",
33
public = list(
44
#' @description Initializes a new instance of this class.
55
#' @param id (`character(1)`)\cr
6-
#' Identifier of resulting object, default `"fda.cor"`.
6+
#' Identifier of resulting object, default `"fcst.lags"`.
77
#' @param param_vals (named `list()`)\cr
88
#' List of hyperparameter settings, overwriting the hyperparameter settings that would
99
#' otherwise be set during construction. Default `list()`.
10-
initialize = function(id = "fcsts.lags", param_vals = list()) {
10+
initialize = function(id = "fcst.lags", param_vals = list()) {
1111
param_set = ps(
1212
lag = p_uty(tags = c("train", "predict"), custom_check = check_integerish)
1313
)
14+
param_set$set_values(lag = 1L)
1415

1516
super$initialize(
1617
id = id,
@@ -21,14 +22,74 @@ PipeOpLags = R6Class("PipeOpLags",
2122
)
2223
}
2324
),
25+
2426
private = list(
27+
.train_task = function(task) {
28+
pv = self$param_set$get_values()
29+
lag = pv$lag
30+
target = task$target_names
31+
key_cols = task$col_roles$key
32+
order_cols = task$col_roles$order
33+
dt = task$data()
34+
nms = sprintf("%s_lag_%i", target, lag)
35+
self$state = list(dt = dt)
36+
if (length(key_cols) > 0L) {
37+
setorderv(dt, c(key_cols, order_cols))
38+
dt[, (nms) := shift(.SD, n = lag, type = "lag"), by = key_cols, .SDcols = target]
39+
} else {
40+
setorderv(dt, order_cols)
41+
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = target]
42+
}
43+
task$select(task$feature_names)$cbind(dt)
44+
},
45+
46+
.predict_task = function(task) {
47+
.NotYetImplemented()
48+
},
49+
50+
.predict_task_original = function(task) {
51+
cols = self$state$dt_columns
52+
if (!length(cols)) {
53+
return(task)
54+
}
55+
dt = task$data(cols = cols)
56+
dt = as.data.table(private$.predict_dt(dt, task$levels(cols)))
57+
task$select(setdiff(task$feature_names, cols))$cbind(dt)
58+
},
59+
60+
.train_task_original = function(task) {
61+
dt_columns = private$.select_cols(task)
62+
cols = dt_columns
63+
if (!length(cols)) {
64+
self$state = list(dt_columns = dt_columns)
65+
return(task)
66+
}
67+
dt = task$data(cols = cols)
68+
69+
dt = if (test_r6(task, classes = "TaskSupervised")) {
70+
as.data.table(private$.train_dt(dt, task$levels(cols), task$truth()))
71+
} else {
72+
as.data.table(private$.train_dt(dt, task$levels(cols)))
73+
}
74+
75+
self$state$dt_columns = dt_columns
76+
task$select(setdiff(task$feature_names, cols))$cbind(dt)
77+
},
2578

2679
.train_dt = function(dt, levels, target) {
27-
browser()
80+
# NOTE: `.train_dt()` does not receive the task, so it can neither sort by
81+
# the task's order columns nor look up the target's actual name.
82+
pv = self$param_set$get_values()
83+
lag = pv$lag
84+
nms = sprintf("target_lag_%i", lag)
85+
dt[, target := target]
86+
dt[, (nms) := shift(.SD, n = lag, type = "lag"), .SDcols = "target"]
87+
dt[, target := NULL]
88+
dt
2889
},
2990

3091
.predict_dt = function(dt, levels) {
31-
..NotYetImplemented()
92+
.NotYetImplemented()
3293
}
3394
)
3495
)

README.Rmd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,10 +339,10 @@ learner$predict_newdata(newdata, task)
339339

340340
### Custom PipeOps
341341

342-
```{r, eval = FALSE}
342+
```{r}
343343
library(mlr3pipelines)
344344
345345
task = tsk("airpassengers")
346-
pop = po("fcst.lags")
346+
pop = po("fcst.lags", lag = 1:12)
347347
pop$train(list(task))[[1L]]
348348
```

README.md

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,32 +44,32 @@ prediction = flrn$predict_newdata(newdata, task)
4444
prediction
4545
#> <PredictionRegr> for 3 observations:
4646
#> row_ids truth response
47-
#> 1 NA 436.4899
48-
#> 2 NA 436.6391
49-
#> 3 NA 456.0920
47+
#> 1 NA 433.6001
48+
#> 2 NA 438.1410
49+
#> 3 NA 457.1800
5050
prediction = flrn$predict(task, 142:144)
5151
prediction
5252
#> <PredictionRegr> for 3 observations:
5353
#> row_ids truth response
54-
#> 1 461 456.6918
55-
#> 2 390 411.1894
56-
#> 3 432 431.1121
54+
#> 1 461 456.5852
55+
#> 2 390 411.2524
56+
#> 3 432 431.9528
5757
prediction$score(msr("regr.rmse"))
5858
#> regr.rmse
59-
#> 12.49451
59+
#> 12.53208
6060

6161
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)
6262
resampling = rsmp("forecast_holdout", ratio = 0.9)
6363
rr = resample(task, flrn, resampling)
6464
rr$aggregate(msr("regr.rmse"))
6565
#> regr.rmse
66-
#> 48.87653
66+
#> 47.88555
6767

6868
resampling = rsmp("forecast_cv")
6969
rr = resample(task, flrn, resampling)
7070
rr$aggregate(msr("regr.rmse"))
7171
#> regr.rmse
72-
#> 25.25769
72+
#> 24.16737
7373
```
7474

7575
### Multivariate
@@ -89,34 +89,34 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)$train(new_task)
8989
prediction = flrn$predict(new_task, 142:144)
9090
prediction$score(msr("regr.rmse"))
9191
#> regr.rmse
92-
#> 13.44279
92+
#> 13.08595
9393

9494
row_ids = new_task$nrow - 0:2
9595
flrn$predict_newdata(new_task$data(rows = row_ids), new_task)
9696
#> <PredictionRegr> for 3 observations:
9797
#> row_ids truth response
98-
#> 1 432 434.0366
99-
#> 2 390 436.9707
100-
#> 3 461 458.7455
98+
#> 1 432 433.6868
99+
#> 2 390 430.1164
100+
#> 3 461 453.4341
101101
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
102102
flrn$predict_newdata(newdata, new_task)
103103
#> <PredictionRegr> for 3 observations:
104104
#> row_ids truth response
105-
#> 1 NA 434.0366
106-
#> 2 NA 436.9707
107-
#> 3 NA 458.7455
105+
#> 1 NA 433.6868
106+
#> 2 NA 430.1164
107+
#> 3 NA 453.4341
108108

109109
resampling = rsmp("forecast_holdout", ratio = 0.9)
110110
rr = resample(new_task, flrn, resampling)
111111
rr$aggregate(msr("regr.rmse"))
112112
#> regr.rmse
113-
#> 50.14024
113+
#> 51.17934
114114

115115
resampling = rsmp("forecast_cv")
116116
rr = resample(new_task, flrn, resampling)
117117
rr$aggregate(msr("regr.rmse"))
118118
#> regr.rmse
119-
#> 26.23039
119+
#> 27.53512
120120
```
121121

122122
### mlr3pipelines integration
@@ -131,7 +131,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
131131
prediction = glrn$predict(task, 142:144)
132132
prediction$score(msr("regr.rmse"))
133133
#> regr.rmse
134-
#> 13.82398
134+
#> 16.0287
135135
```
136136

137137
### Example: Forecasting electricity demand
@@ -174,13 +174,13 @@ prediction = glrn$predict_newdata(newdata, task)
174174
prediction
175175
#> <PredictionRegr> for 14 observations:
176176
#> row_ids truth response
177-
#> 1 NA 186.6940
178-
#> 2 NA 190.8129
179-
#> 3 NA 183.0273
177+
#> 1 NA 186.9874
178+
#> 2 NA 191.3284
179+
#> 3 NA 183.5836
180180
#> --- --- ---
181-
#> 12 NA 214.4948
182-
#> 13 NA 218.4061
183-
#> 14 NA 220.0571
181+
#> 12 NA 216.9396
182+
#> 13 NA 221.4096
183+
#> 14 NA 222.3596
184184
```
185185

186186
### Global Forecasting
@@ -213,14 +213,14 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
213213
prediction = flrn$predict(task, 4460:4464)
214214
prediction$score(msr("regr.rmse"))
215215
#> regr.rmse
216-
#> 22058.4
216+
#> 22494.87
217217

218218
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)
219219
resampling = rsmp("forecast_holdout", ratio = 0.9)
220220
rr = resample(task, flrn, resampling)
221221
rr$aggregate(msr("regr.rmse"))
222222
#> regr.rmse
223-
#> 94136.08
223+
#> 91483.84
224224
```
225225

226226
### Example: Global vs Local Forecasting
@@ -255,7 +255,7 @@ row_ids = tab[year >= 2015, row_id]
255255
prediction = flrn$predict(task, row_ids)
256256
prediction$score(msr("regr.rmse"))
257257
#> regr.rmse
258-
#> 33009.95
258+
#> 32875.1
259259

260260
# global forecasting
261261
task = tsibbledata::aus_livestock |>
@@ -276,7 +276,7 @@ row_ids = tab[year >= 2015 & state == "Western Australia", row_id]
276276
prediction = flrn$predict(task, row_ids)
277277
prediction$score(msr("regr.rmse"))
278278
#> regr.rmse
279-
#> 30965.86
279+
#> 31399.84
280280
```
281281

282282
### Example: generate new data
@@ -470,6 +470,16 @@ learner$predict_newdata(newdata, task)
470470
library(mlr3pipelines)
471471

472472
task = tsk("airpassengers")
473-
pop = po("fcst.lags")
473+
pop = po("fcst.lags", lag = 1:12)
474474
pop$train(list(task))[[1L]]
475+
#> <TaskRegr:airpassengers> (144 x 14): Monthly Airline Passenger Numbers 1949-1960
476+
#> * Target: passengers
477+
#> * Properties: ordered
478+
#> * Features (13):
479+
#> - dbl (12): passengers_lag_1, passengers_lag_10, passengers_lag_11,
480+
#> passengers_lag_12, passengers_lag_2, passengers_lag_3,
481+
#> passengers_lag_4, passengers_lag_5, passengers_lag_6,
482+
#> passengers_lag_7, passengers_lag_8, passengers_lag_9
483+
#> - dte (1): date
484+
#> * Order by: date
475485
```

0 commit comments

Comments
 (0)