Skip to content

Commit cc9cdfd

Browse files
committed
feat: add properties to forecast learners
1 parent 92789ca commit cc9cdfd

File tree

6 files changed

+35
-25
lines changed

6 files changed

+35
-25
lines changed

R/LearnerRegrArfima.R

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ LearnerFcstArfima = R6Class("LearnerFcstArfima",
2727
param_set = param_set,
2828
predict_types = c("response", "quantiles"),
2929
feature_types = c("Date", "logical", "integer", "numeric"),
30+
properties = c("univariate", "exogenous", "missings"),
3031
packages = c("mlr3forecast", "forecast"),
3132
label = "ARFIMA",
3233
man = "mlr3forecast::mlr_learners_fcst.arfima"

R/LearnerRegrArima.R

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,7 @@ LearnerFcstArima = R6Class("LearnerFcstArima",
3939
param_set = param_set,
4040
predict_types = c("response", "quantiles"),
4141
feature_types = c("Date", "logical", "integer", "numeric"),
42+
properties = c("univariate", "exogenous", "missings"),
4243
packages = c("mlr3forecast", "forecast"),
4344
label = "ARIMA",
4445
man = "mlr3forecast::mlr_learners_fcst.arima"

R/LearnerRegrAutoArima.R

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
#' @title ARIMA
1+
#' @title Auto ARIMA
22
#'
33
#' @name mlr_learners_fcst.auto_arima
44
#'
@@ -44,6 +44,7 @@ LearnerFcstAutoArima = R6Class("LearnerFcstAutoArima",
4444
param_set = param_set,
4545
predict_types = c("response", "quantiles"),
4646
feature_types = c("Date", "logical", "integer", "numeric"),
47+
properties = c("univariate", "exogenous", "missings"),
4748
packages = c("mlr3forecast", "forecast"),
4849
label = "Auto ARIMA",
4950
man = "mlr3forecast::mlr_learners_fcst.arima"

R/LearnerRegrEts.R

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@ LearnerFcstEts = R6Class("LearnerFcstEts",
4646
param_set = param_set,
4747
predict_types = c("response", "quantiles"),
4848
feature_types = c("Date", "logical", "integer", "numeric"),
49+
properties = c("univariate", "missings"),
4950
packages = c("mlr3forecast", "forecast"),
5051
label = "ETS",
5152
man = "mlr3forecast::mlr_learners_fcst.ets"

R/zzz.R

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ mlr3forecast_learners = new.env()
1212
mlr3forecast_measures = new.env()
1313
mlr3forecast_feature_types = c(dte = "Date")
1414
mlr3forecast_col_roles = "key"
15+
mlr3forecast_learner_properties = c("univariate", "multivariate", "exogenous", "missings")
1516

1617
named_union = function(x, y) set_names(union(x, y), union(names(x), names(y)))
1718

@@ -36,7 +37,12 @@ register_mlr3 = function() {
3637
"fcst", "mlr3forecast", "TaskRegr", "LearnerRegr", "PredictionFcst", "PredictionDataFcst", "MeasureFcst" # nolint
3738
), fill = TRUE), "type")
3839
mlr_reflections$learner_predict_types$fcst = mlr_reflections$learner_predict_types$regr
39-
mlr_reflections$learner_properties$fcst = mlr_reflections$learner_properties$regr
40+
mlr_reflections$learner_properties$fcst = union(
41+
mlr_reflections$learner_properties$regr, mlr3forecast_learner_properties
42+
)
43+
mlr_reflections$learner_properties$regr = union(
44+
mlr_reflections$learner_properties$regr, mlr3forecast_learner_properties
45+
)
4046
mlr_reflections$task_col_roles$fcst = mlr_reflections$task_col_roles$regr
4147
mlr_reflections$task_col_roles$regr = union(
4248
mlr_reflections$task_col_roles$regr, mlr3forecast_col_roles

README.md

Lines changed: 23 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -44,32 +44,32 @@ prediction = flrn$predict_newdata(newdata, task)
4444
prediction
4545
#> <PredictionRegr> for 3 observations:
4646
#> row_ids truth response
47-
#> 1 NA 434.9580
48-
#> 2 NA 438.3391
49-
#> 3 NA 456.1386
47+
#> 1 NA 438.0155
48+
#> 2 NA 439.6429
49+
#> 3 NA 457.8119
5050
prediction = flrn$predict(task, 142:144)
5151
prediction
5252
#> <PredictionRegr> for 3 observations:
5353
#> row_ids truth response
54-
#> 1 461 456.7837
55-
#> 2 390 412.4510
56-
#> 3 432 434.0057
54+
#> 1 461 459.7825
55+
#> 2 390 417.8945
56+
#> 3 432 435.8002
5757
prediction$score(msr("regr.rmse"))
5858
#> regr.rmse
59-
#> 13.23942
59+
#> 16.26886
6060

6161
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)
6262
resampling = rsmp("forecast_holdout", ratio = 0.9)
6363
rr = resample(task, flrn, resampling)
6464
rr$aggregate(msr("regr.rmse"))
6565
#> regr.rmse
66-
#> 47.49655
66+
#> 49.22924
6767

6868
resampling = rsmp("forecast_cv")
6969
rr = resample(task, flrn, resampling)
7070
rr$aggregate(msr("regr.rmse"))
7171
#> regr.rmse
72-
#> 25.05562
72+
#> 25.57887
7373
```
7474

7575
### Multivariate
@@ -89,34 +89,34 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:12)$train(new_task)
8989
prediction = flrn$predict(new_task, 142:144)
9090
prediction$score(msr("regr.rmse"))
9191
#> regr.rmse
92-
#> 13.13675
92+
#> 12.35229
9393

9494
row_ids = new_task$nrow - 0:2
9595
flrn$predict_newdata(new_task$data(rows = row_ids), new_task)
9696
#> <PredictionRegr> for 3 observations:
9797
#> row_ids truth response
98-
#> 1 432 435.4188
99-
#> 2 390 434.2164
100-
#> 3 461 457.9859
98+
#> 1 432 431.1496
99+
#> 2 390 429.3616
100+
#> 3 461 455.1301
101101
newdata = new_task$data(rows = row_ids, cols = new_task$feature_names)
102102
flrn$predict_newdata(newdata, new_task)
103103
#> <PredictionRegr> for 3 observations:
104104
#> row_ids truth response
105-
#> 1 NA 435.4188
106-
#> 2 NA 434.2164
107-
#> 3 NA 457.9859
105+
#> 1 NA 431.1496
106+
#> 2 NA 429.3616
107+
#> 3 NA 455.1301
108108

109109
resampling = rsmp("forecast_holdout", ratio = 0.9)
110110
rr = resample(new_task, flrn, resampling)
111111
rr$aggregate(msr("regr.rmse"))
112112
#> regr.rmse
113-
#> 47.83001
113+
#> 48.2451
114114

115115
resampling = rsmp("forecast_cv")
116116
rr = resample(new_task, flrn, resampling)
117117
rr$aggregate(msr("regr.rmse"))
118118
#> regr.rmse
119-
#> 27.82117
119+
#> 26.80115
120120
```
121121

122122
### mlr3pipelines integration
@@ -131,7 +131,7 @@ glrn = as_learner(graph %>>% flrn)$train(task)
131131
prediction = glrn$predict(task, 142:144)
132132
prediction$score(msr("regr.rmse"))
133133
#> regr.rmse
134-
#> 11.0778
134+
#> 12.39349
135135
```
136136

137137
### Example: Forecasting electricity demand
@@ -205,14 +205,14 @@ flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)$train(task)
205205
prediction = flrn$predict(task, 4460:4464)
206206
prediction$score(msr("regr.rmse"))
207207
#> regr.rmse
208-
#> 25052.84
208+
#> 23543.57
209209

210210
flrn = ForecastLearner$new(lrn("regr.ranger"), 1:3)
211211
resampling = rsmp("forecast_holdout", ratio = 0.9)
212212
rr = resample(task, flrn, resampling)
213213
rr$aggregate(msr("regr.rmse"))
214214
#> regr.rmse
215-
#> 93423.87
215+
#> 92241
216216
```
217217

218218
### Example: Global vs Local Forecasting
@@ -247,7 +247,7 @@ row_ids = tab[year >= 2015, row_id]
247247
prediction = flrn$predict(task, row_ids)
248248
prediction$score(msr("regr.rmse"))
249249
#> regr.rmse
250-
#> 29931.59
250+
#> 32967.19
251251

252252
# global forecasting
253253
task = tsibbledata::aus_livestock |>
@@ -268,7 +268,7 @@ row_ids = tab[year >= 2015 & state == "Western Australia", row_id]
268268
prediction = flrn$predict(task, row_ids)
269269
prediction$score(msr("regr.rmse"))
270270
#> regr.rmse
271-
#> 31607.32
271+
#> 31955.28
272272
```
273273

274274
### Example: generate new data

0 commit comments

Comments (0)