Skip to content

Commit 726c563

Browse files
committed
refactor: rename transformers and validators for clarity
1 parent 3d4cf63 commit 726c563

File tree

2 files changed

+27
-27
lines changed

2 files changed

+27
-27
lines changed

moderndid/core/preprocess/transformers.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -151,8 +151,8 @@ def transform(self, data: pd.DataFrame, config: BasePreprocessConfig) -> pd.Data
151151
return data
152152

153153

154-
class NeverTreatedHandler(BaseTransformer):
155-
"""Never treated handler."""
154+
class ControlGroupCreator(BaseTransformer):
155+
"""Control group creator."""
156156

157157
def transform(self, data: pd.DataFrame, config: BasePreprocessConfig) -> pd.DataFrame:
158158
"""Transform data."""
@@ -258,8 +258,8 @@ def transform(self, data: pd.DataFrame, config: BasePreprocessConfig) -> pd.Data
258258
return data
259259

260260

261-
class GroupFilter(BaseTransformer):
262-
"""Group filter."""
261+
class EarlyTreatmentGroupFilter(BaseTransformer):
262+
"""Early treatment group filter."""
263263

264264
def transform(self, data: pd.DataFrame, config: BasePreprocessConfig) -> pd.DataFrame:
265265
"""Transform data."""
@@ -370,8 +370,8 @@ def update(data: pd.DataFrame, config: BasePreprocessConfig) -> None:
370370
config.cband = False
371371

372372

373-
class TwoPeriodColumnSelector(BaseTransformer):
374-
"""Two-period column selector."""
373+
class PrePostColumnSelector(BaseTransformer):
374+
"""Pre-post column selector."""
375375

376376
def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> pd.DataFrame:
377377
"""Transform data."""
@@ -397,8 +397,8 @@ def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriod
397397
return data[cols_to_keep].copy()
398398

399399

400-
class TwoPeriodCovariateProcessor(BaseTransformer):
401-
"""Two-period covariate processor."""
400+
class PrePostCovariateProcessor(BaseTransformer):
401+
"""Pre-post covariate processor."""
402402

403403
def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> pd.DataFrame:
404404
"""Transform data."""
@@ -432,8 +432,8 @@ def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriod
432432
return data_processed
433433

434434

435-
class TwoPeriodPanelBalancer(BaseTransformer):
436-
"""Two-period panel balancer."""
435+
class PrePostPanelBalancer(BaseTransformer):
436+
"""Pre-post panel balancer."""
437437

438438
def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> pd.DataFrame:
439439
"""Transform data."""
@@ -450,8 +450,8 @@ def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriod
450450
return data[data[config.idname].isin(ids_to_keep)].copy()
451451

452452

453-
class TwoPeriodTimeInvarianceChecker(BaseTransformer):
454-
"""Two-period time invariance checker."""
453+
class PrePostInvarianceChecker(BaseTransformer):
454+
"""Pre-post invariance checker."""
455455

456456
def transform(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> pd.DataFrame:
457457
"""Transform data."""
@@ -498,7 +498,7 @@ def get_did_pipeline() -> "DataTransformerPipeline":
498498
WeightNormalizer(),
499499
TreatmentEncoder(),
500500
EarlyTreatmentFilter(),
501-
NeverTreatedHandler(),
501+
ControlGroupCreator(),
502502
PanelBalancer(),
503503
RepeatedCrossSectionHandler(),
504504
DataSorter(),
@@ -514,7 +514,7 @@ def get_cont_did_pipeline() -> "DataTransformerPipeline":
514514
MissingDataHandler(),
515515
WeightNormalizer(),
516516
TimePeriodRecoder(),
517-
GroupFilter(),
517+
EarlyTreatmentGroupFilter(),
518518
DoseValidator(),
519519
PanelBalancer(),
520520
DataSorter(),
@@ -526,12 +526,12 @@ def get_two_period_pipeline() -> "DataTransformerPipeline":
526526
"""Get two-period pipeline."""
527527
return DataTransformerPipeline(
528528
[
529-
TwoPeriodColumnSelector(),
529+
PrePostColumnSelector(),
530530
MissingDataHandler(),
531-
TwoPeriodCovariateProcessor(),
531+
PrePostCovariateProcessor(),
532532
WeightNormalizer(),
533-
TwoPeriodPanelBalancer(),
534-
TwoPeriodTimeInvarianceChecker(),
533+
PrePostPanelBalancer(),
534+
PrePostInvarianceChecker(),
535535
]
536536
)
537537

moderndid/core/preprocess/validators.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -264,8 +264,8 @@ def _create_result(errors: list[str] | None = None, warnings: list[str] | None =
264264
return ValidationResult(is_valid=len(errors) == 0, errors=errors, warnings=warnings)
265265

266266

267-
class TwoPeriodColumnValidator(BaseValidator):
268-
"""Two-period column validator."""
267+
class PrePostColumnValidator(BaseValidator):
268+
"""Pre-post column validator."""
269269

270270
def validate(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> ValidationResult:
271271
"""Validate data."""
@@ -313,8 +313,8 @@ def _create_result(errors: list[str] | None = None, warnings: list[str] | None =
313313
return ValidationResult(is_valid=len(errors) == 0, errors=errors, warnings=warnings)
314314

315315

316-
class TwoPeriodValidator(BaseValidator):
317-
"""Two-period validator."""
316+
class PrePostDataValidator(BaseValidator):
317+
"""Pre-post data validator."""
318318

319319
def validate(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> ValidationResult:
320320
"""Validate data."""
@@ -342,8 +342,8 @@ def _create_result(errors: list[str] | None = None, warnings: list[str] | None =
342342
return ValidationResult(is_valid=len(errors) == 0, errors=errors, warnings=warnings)
343343

344344

345-
class TwoPeriodPanelValidator(BaseValidator):
346-
"""Two-period panel validator."""
345+
class PrePostPanelValidator(BaseValidator):
346+
"""Pre-post panel validator."""
347347

348348
def validate(self, data: pd.DataFrame, config: BasePreprocessConfig | TwoPeriodDIDConfig) -> ValidationResult:
349349
"""Validate data."""
@@ -386,9 +386,9 @@ def _get_default_validators(config_type: str = "did") -> list[BaseValidator]:
386386
"""Get default validators."""
387387
if config_type == "two_period":
388388
return [
389-
TwoPeriodColumnValidator(),
390-
TwoPeriodValidator(),
391-
TwoPeriodPanelValidator(),
389+
PrePostColumnValidator(),
390+
PrePostDataValidator(),
391+
PrePostPanelValidator(),
392392
]
393393

394394
common_validators = [

0 commit comments

Comments (0)