Skip to content

Commit aff5f76

Browse files
committed
refactor: rename lag arg to lags
1 parent 2a037f2 commit aff5f76

File tree

6 files changed

+59
-59
lines changed

6 files changed

+59
-59
lines changed

R/ForecastLearner.R

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -8,18 +8,18 @@ ForecastLearner = R6::R6Class("ForecastLearner",
88
#' The learner
99
learner = NULL,
1010

11-
#' @field lag (`integer()`)\cr
12-
#' The lag
13-
lag = NULL,
11+
#' @field lags (`integer()`)\cr
12+
#' The lags
13+
lags = NULL,
1414

1515
#' @description
1616
#' Creates a new instance of this [R6][R6::R6Class] class.
1717
#' @param task ([Task])\cr
1818
#' @param learner ([Learner])\cr
19-
#' @param lag (`integer(1)`)\cr
20-
initialize = function(learner, lag) {
19+
#' @param lags (`integer(1)`)\cr
20+
initialize = function(learner, lags) {
2121
self$learner = assert_learner(as_learner(learner, clone = TRUE))
22-
self$lag = assert_integerish(lag, lower = 1L, any.missing = FALSE, coerce = TRUE)
22+
self$lags = assert_integerish(lags, lower = 1L, any.missing = FALSE, coerce = TRUE)
2323

2424
super$initialize(
2525
id = learner$id,
@@ -80,14 +80,14 @@ ForecastLearner = R6::R6Class("ForecastLearner",
8080
},
8181

8282
.lag_transform = function(dt, target) {
83-
lag = self$lag
84-
nms = sprintf("%s_lag_%i", target, lag)
83+
lags = self$lags
84+
nms = sprintf("%s_lag_%i", target, lags)
8585
dt = copy(dt)
8686
key_cols = private$.task$col_roles$key
8787
if (length(key_cols) > 0L) {
88-
dt[, (nms) := shift(.SD, lag), by = key_cols, .SDcols = target]
88+
dt[, (nms) := shift(.SD, lags), by = key_cols, .SDcols = target]
8989
} else {
90-
dt[, (nms) := shift(.SD, lag), .SDcols = target]
90+
dt[, (nms) := shift(.SD, lags), .SDcols = target]
9191
}
9292
dt
9393
},

R/PipeOpFcstLag.R

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
#' @section Parameters:
88
#' The parameters are the parameters inherited from [mlr3pipelines::PipeOpTaskPreproc],
99
#' as well as the following parameters:
10-
#' * `lag` :: `integer()`\cr
10+
#' * `lags` :: `integer()`\cr
1111
#' The lags to create.
1212
#'
1313
#' @export
@@ -24,9 +24,9 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
2424
#' otherwise be set during construction. Default `list()`.
2525
initialize = function(id = "fcst.lags", param_vals = list()) {
2626
param_set = ps(
27-
lag = p_uty(tags = c("train", "predict"), custom_check = check_integerish)
27+
lags = p_uty(tags = c("train", "predict"), custom_check = check_integerish)
2828
)
29-
param_set$set_values(lag = 1L)
29+
param_set$set_values(lags = 1L)
3030

3131
super$initialize(
3232
id = id,
@@ -41,37 +41,37 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
4141
private = list(
4242
.train_task = function(task) {
4343
pv = self$param_set$get_values(tags = "train")
44-
lag = pv$lag
44+
lags = pv$lags
4545
target = task$target_names
4646
key_cols = task$col_roles$key
4747
order_cols = task$col_roles$order
4848
dt = task$data()
49-
self$state = list(dt = dt[(.N - max(lag)):.N])
50-
nms = sprintf("%s_lag_%i", target, lag)
49+
self$state = list(dt = dt[(.N - max(lags)):.N])
50+
nms = sprintf("%s_lag_%i", target, lags)
5151
if (length(key_cols) > 0L) {
5252
setorderv(dt, c(key_cols, order_cols))
53-
dt[, (nms) := shift(.SD, lag), by = key_cols, .SDcols = target]
53+
dt[, (nms) := shift(.SD, lags), by = key_cols, .SDcols = target]
5454
} else {
5555
setorderv(dt, order_cols)
56-
dt[, (nms) := shift(.SD, lag), .SDcols = target]
56+
dt[, (nms) := shift(.SD, lags), .SDcols = target]
5757
}
5858
task$select(task$feature_names)$cbind(dt)
5959
},
6060

6161
.predict_task = function(task) {
6262
pv = self$param_set$get_values(tags = "predict")
63-
lag = pv$lag
63+
lags = pv$lags
6464
target = task$target_names
6565
key_cols = task$col_roles$key
6666
order_cols = task$col_roles$order
6767
dt = rbind(self$state$dt, task$data())
68-
nms = sprintf("%s_lag_%i", target, lag)
68+
nms = sprintf("%s_lag_%i", target, lags)
6969
if (length(key_cols) > 0L) {
7070
setorderv(dt, c(key_cols, order_cols))
71-
dt[, (nms) := shift(.SD, lag), by = key_cols, .SDcols = target]
71+
dt[, (nms) := shift(.SD, lags), by = key_cols, .SDcols = target]
7272
} else {
7373
setorderv(dt, order_cols)
74-
dt[, (nms) := shift(.SD, lag), .SDcols = target]
74+
dt[, (nms) := shift(.SD, lags), .SDcols = target]
7575
}
7676
dt = dt[(.N - task$nrow + 1L):.N]
7777
task$select(task$feature_names)$cbind(dt)
@@ -110,10 +110,10 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
110110
# this wouldn't allow sorting since we don't get the task here,
111111
# as well as getting the target name
112112
pv = self$param_set$get_values()
113-
lag = pv$lag
114-
nms = sprintf("target_lag_%i", lag)
113+
lags = pv$lags
114+
nms = sprintf("target_lag_%i", lags)
115115
dt[, target := target]
116-
dt[, (nms) := shift(.SD, lag), .SDcols = "target"]
116+
dt[, (nms) := shift(.SD, lags), .SDcols = "target"]
117117
dt[, target := NULL]
118118
dt
119119
},
@@ -125,4 +125,4 @@ PipeOpFcstLag = R6Class("PipeOpFcstLag",
125125
)
126126

127127
#' @include zzz.R
128-
register_po("fcst.lag", PipeOpFcstLag)
128+
register_po("fcst.lags", PipeOpFcstLag)

README.Rmd

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -254,12 +254,12 @@ library(mlr3learners)
254254
library(mlr3pipelines)
255255
256256
task = tsk("airpassengers")
257-
pop = po("fcst.lag", lag = 1:12)
257+
pop = po("fcst.lag", lags = 1:12)
258258
new_task = pop$train(list(task))[[1L]]
259259
new_task$data()
260260
261261
task = tsk("airpassengers")
262-
graph = po("fcst.lag", lag = 1:12) %>>%
262+
graph = po("fcst.lag", lags = 1:12) %>>%
263263
ppl("convert_types", "Date", "POSIXct") %>>%
264264
po("datefeatures",
265265
param_vals = list(
@@ -295,7 +295,7 @@ trafo = po("targetmutate",
295295
)
296296
)
297297
298-
graph = po("fcst.lag", lag = 1:12) %>>%
298+
graph = po("fcst.lag", lags = 1:12) %>>%
299299
ppl("convert_types", "Date", "POSIXct") %>>%
300300
po("datefeatures",
301301
param_vals = list(
@@ -315,7 +315,7 @@ prediction$score(msr("regr.rmse"))
315315
```
316316

317317
```{r, eval = FALSE}
318-
graph = po("fcst.lag", lag = 1:12) %>>%
318+
graph = po("fcst.lag", lags = 1:12) %>>%
319319
ppl("convert_types", "Date", "POSIXct") %>>%
320320
po("datefeatures",
321321
param_vals = list(
@@ -328,7 +328,7 @@ graph = po("fcst.lag", lag = 1:12) %>>%
328328
task = tsk("airpassengers")
329329
flrn = ForecastRecursiveLearner$new(lrn("regr.ranger"))
330330
glrn = as_learner(graph %>>% flrn)
331-
trafo = po("fcst.targetdiff", lag = 12L)
331+
trafo = po("fcst.targetdiff", lags = 12L)
332332
pipeline = ppl("targettrafo", graph = glrn, trafo_pipeop = trafo)
333333
glrn = as_learner(pipeline)$train(task)
334334
prediction = glrn$predict(task, 142:144)

README.md

Lines changed: 23 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -99,32 +99,32 @@ prediction = flrn$predict_newdata(newdata, task)
9999
prediction
100100
#> <PredictionRegr> for 3 observations:
101101
#> row_ids truth response
102-
#> 1 NA 436.1867
103-
#> 2 NA 437.4089
104-
#> 3 NA 456.8410
102+
#> 1 NA 438.6738
103+
#> 2 NA 438.2207
104+
#> 3 NA 457.2237
105105
prediction = flrn$predict(task, 142:144)
106106
prediction
107107
#> <PredictionRegr> for 3 observations:
108108
#> row_ids truth response
109-
#> 1 461 459.1495
110-
#> 2 390 414.8433
111-
#> 3 432 430.2693
109+
#> 1 461 456.8032
110+
#> 2 390 412.9617
111+
#> 3 432 432.0672
112112
prediction$score(msr("regr.rmse"))
113113
#> regr.rmse
114-
#> 14.41767
114+
#> 13.4766
115115

116116
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)
117117
resampling = rsmp("forecast_holdout", ratio = 0.9)
118118
rr = resample(task, flrn, resampling)
119119
rr$aggregate(msr("regr.rmse"))
120120
#> regr.rmse
121-
#> 48.97126
121+
#> 48.4789
122122

123123
resampling = rsmp("forecast_cv")
124124
rr = resample(task, flrn, resampling)
125125
rr$aggregate(msr("regr.rmse"))
126126
#> regr.rmse
127-
#> 25.19211
127+
#> 25.08963
128128
```
129129

130130
Or with some feature engineering using mlr3pipelines:
@@ -146,7 +146,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
146146
prediction = glrn$predict(task, 142:144)
147147
prediction$score(msr("regr.rmse"))
148148
#> regr.rmse
149-
#> 15.58057
149+
#> 14.22429
150150
```
151151

152152
### Example: forecasting electricity demand
@@ -176,13 +176,13 @@ prediction = glrn$predict_newdata(newdata, task)
176176
prediction
177177
#> <PredictionRegr> for 14 observations:
178178
#> row_ids truth response
179-
#> 1 NA 187595.7
180-
#> 2 NA 196608.6
181-
#> 3 NA 189152.0
179+
#> 1 NA 189375.9
180+
#> 2 NA 199550.0
181+
#> 3 NA 188647.1
182182
#> --- --- ---
183-
#> 12 NA 222400.3
184-
#> 13 NA 226494.8
185-
#> 14 NA 226568.4
183+
#> 12 NA 221192.0
184+
#> 13 NA 225456.5
185+
#> 14 NA 227090.1
186186
```
187187

188188
### Example: global forecasting (longitudinal data)
@@ -220,14 +220,14 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
220220
prediction = flrn$predict(task, 4460:4464)
221221
prediction$score(msr("regr.rmse"))
222222
#> regr.rmse
223-
#> 22604.48
223+
#> 22055.26
224224

225225
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)
226226
resampling = rsmp("forecast_holdout", ratio = 0.9)
227227
rr = resample(task, flrn, resampling)
228228
rr$aggregate(msr("regr.rmse"))
229229
#> regr.rmse
230-
#> 92125.26
230+
#> 92992
231231
```
232232

233233
### Example: global vs local forecasting
@@ -293,12 +293,12 @@ library(mlr3learners)
293293
library(mlr3pipelines)
294294

295295
task = tsk("airpassengers")
296-
pop = po("fcst.lag", lag = 1:12)
296+
pop = po("fcst.lag", lags = 1:12)
297297
new_task = pop$train(list(task))[[1L]]
298298
new_task$data()
299299

300300
task = tsk("airpassengers")
301-
graph = po("fcst.lag", lag = 1:12) %>>%
301+
graph = po("fcst.lag", lags = 1:12) %>>%
302302
ppl("convert_types", "Date", "POSIXct") %>>%
303303
po("datefeatures",
304304
param_vals = list(
@@ -338,7 +338,7 @@ trafo = po("targetmutate",
338338
)
339339
)
340340

341-
graph = po("fcst.lag", lag = 1:12) %>>%
341+
graph = po("fcst.lag", lags = 1:12) %>>%
342342
ppl("convert_types", "Date", "POSIXct") %>>%
343343
po("datefeatures",
344344
param_vals = list(
@@ -358,7 +358,7 @@ prediction$score(msr("regr.rmse"))
358358
```
359359

360360
``` r
361-
graph = po("fcst.lag", lag = 1:12) %>>%
361+
graph = po("fcst.lag", lags = 1:12) %>>%
362362
ppl("convert_types", "Date", "POSIXct") %>>%
363363
po("datefeatures",
364364
param_vals = list(
@@ -371,7 +371,7 @@ graph = po("fcst.lag", lag = 1:12) %>>%
371371
task = tsk("airpassengers")
372372
flrn = ForecastRecursiveLearner$new(lrn("regr.ranger"))
373373
glrn = as_learner(graph %>>% flrn)
374-
trafo = po("fcst.targetdiff", lag = 12L)
374+
trafo = po("fcst.targetdiff", lags = 12L)
375375
pipeline = ppl("targettrafo", graph = glrn, trafo_pipeop = trafo)
376376
glrn = as_learner(pipeline)$train(task)
377377
prediction = glrn$predict(task, 142:144)

man/ForecastLearner.Rd

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/mlr_pipeops_fcst.lag.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)