Description
The solution to question 4 in the mlr3book fails (https://github.com/mlr-org/mlr3book/blob/97ce18dced94d7589e61d2b66d96a4c3c6e3b081/book/chapters/appendices/solutions.qmd#L2156-L2173).
```r
# development versions of mlr3 and mlr3fairness (paths from the reporter's setup)
devtools::load_all("../mlr3")
devtools::load_all(".")
set.seed(8)

# load data
tsk_adult_train = tsk("adult_train")
tsk_adult_test = tsk("adult_test")
tsk_adult_train

# train and predict
learner = lrn("classif.rpart")
learner$train(tsk_adult_train)
prediction = learner$predict(tsk_adult_test)

# set protected attribute
tsk_adult_train$set_col_roles("race", add_to = "pta")
tsk_adult_test$set_col_roles("race", add_to = "pta")

# create groupwise metrics
msr_3 = groupwise_metrics(msr("classif.fomr"), tsk_adult_train)
unname(sapply(msr_3, function(x) x$id))

# subset test task
adult_subset = tsk_adult_test$clone()
df = adult_subset$data()
# select row ids (not positions), since filter() expects row ids
rows = adult_subset$row_ids[df$race %in% c("Black", "White") & df$sex %in% "Female"]
adult_subset$filter(rows)
adult_subset$set_col_roles("race", add_to = "pta")

# adult_subset contains 4731 rows,
# tsk_adult_test and prediction contain 15315 rows.
# get_pta() calls task$data(rows = prediction$row_ids) (helpers.R, line 132),
# which requests these row ids regardless of the filter on the task;
# in mlr3 1.0.0 this is an error.
#
# MeasureSubgroup assumes that get_pta() returns 15315 rows,
# see groups[, row_ids := prediction$row_ids] in MeasureSubgroup.R, line 75.
# Do we need both tasks?
prediction$score(msr_3, task = adult_subset)
```

mlr3 1.0.0 checks that `task$data(rows)` only accesses active (unfiltered) rows. It looks to me as if MeasureSubgroup needs the unfiltered task for at least some substeps. Does the measure perhaps need both tsk_adult_test and adult_subset?
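To make the length mismatch concrete, here is a minimal base-R sketch that does not use mlr3 at all. The integer vectors are hypothetical stand-ins for `prediction$row_ids` (15315 ids) and the active row ids of `adult_subset` (4731 ids), and `data_for_rows()` is a hypothetical helper that only mimics the mlr3 1.0.0 check on `task$data(rows)`; it is not the real `Task$data()` implementation. Intersecting the prediction's row ids with the task's active row ids is one possible alignment, assumed here for illustration only.

```r
set.seed(8)

# Hypothetical stand-ins (plain integer vectors, not mlr3 objects):
# row ids stored in the prediction made on the full test task ...
prediction_row_ids = seq_len(15315L)
# ... and the active row ids left after filtering the cloned task.
task_row_ids = sort(sample(prediction_row_ids, 4731L))

# Hypothetical mimic of the mlr3 1.0.0 check: requesting rows that are not
# active in the task raises an error; otherwise the requested ids are returned.
data_for_rows = function(active_rows, rows) {
  inactive = setdiff(rows, active_rows)
  if (length(inactive) > 0L) {
    stop(sprintf("%d requested row ids are not active in the task", length(inactive)))
  }
  rows
}

# Requesting every prediction row id reproduces the failure mode.
res = tryCatch(
  data_for_rows(task_row_ids, prediction_row_ids),
  error = function(e) conditionMessage(e)
)
print(res)

# One possible alignment: keep only prediction row ids that are active in the
# task, so the pta lookup and the prediction have the same length.
aligned_ids = intersect(prediction_row_ids, task_row_ids)
pta_rows = data_for_rows(task_row_ids, aligned_ids)
stopifnot(length(pta_rows) == length(aligned_ids))
```

At the mlr3 level, a comparable (untested) alignment would be to predict only on the subset, e.g. `learner$predict(tsk_adult_test, row_ids = adult_subset$row_ids)` or `learner$predict(adult_subset)`, so that `prediction$row_ids` matches the rows `get_pta()` requests. Whether MeasureSubgroup should instead accept both tasks remains the open question here.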