Skip to content

Commit df64840

Browse files
KJhellico and arkid15r authored
Introduce HolidayBase methods for proper serialization (#2333)
Signed-off-by: ~Jhellico <[email protected]>
Co-authored-by: Arkadii Yakovets <[email protected]>
1 parent 9abd47f commit df64840

File tree

2 files changed

+53
-28
lines changed

2 files changed

+53
-28
lines changed

holidays/holiday_base.py

Lines changed: 43 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -365,36 +365,11 @@ def __init__(
365365
self.observed = observed
366366
self.subdiv = subdiv
367367
self.weekend_workdays = getattr(self, "weekend_workdays", set())
368-
369-
supported_languages = set(self.supported_languages)
370-
if self._entity_code is not None:
371-
fallback = language not in supported_languages
372-
languages = [language] if language in supported_languages else None
373-
locale_directory = str(Path(__file__).with_name("locale"))
374-
375-
# Add entity native content translations.
376-
entity_translation = translation(
377-
self._entity_code,
378-
fallback=fallback,
379-
languages=languages,
380-
localedir=locale_directory,
381-
)
382-
# Add a fallback if entity has parent translations.
383-
if parent_entity := self.parent_entity:
384-
entity_translation.add_fallback(
385-
translation(
386-
parent_entity.country or parent_entity.market,
387-
fallback=fallback,
388-
languages=languages,
389-
localedir=locale_directory,
390-
)
391-
)
392-
self.tr = entity_translation.gettext
393-
else:
394-
self.tr = gettext
395-
396368
self.years = _normalize_arguments(int, years)
397369

370+
# Configure l10n related attributes.
371+
self._init_translation()
372+
398373
# Populate holidays.
399374
for year in self.years:
400375
self._populate(year)
@@ -592,6 +567,12 @@ def __getitem__(self, key: DateLike) -> Any:
592567

593568
return dict.__getitem__(self, self.__keytransform__(key))
594569

570+
def __getstate__(self) -> dict[str, Any]:
    """Return a picklable snapshot of this instance's attributes.

    The snapshot is a shallow copy of the instance ``__dict__`` with the
    ``tr`` translation callable left out; ``__setstate__`` recreates it via
    ``_init_translation()`` when the object is unpickled. The live instance
    itself is not modified.
    """
    return {name: value for name, value in vars(self).items() if name != "tr"}
575+
595576
def __keytransform__(self, key: DateLike) -> date:
596577
"""Transforms the date from one of the following types:
597578
@@ -699,6 +680,11 @@ def __setitem__(self, key: DateLike, value: str) -> None:
699680

700681
dict.__setitem__(self, self.__keytransform__(key), value)
701682

683+
def __setstate__(self, state: dict[str, Any]) -> None:
    """Rebuild the instance from a pickled state mapping.

    Every saved attribute is copied back into the instance ``__dict__``;
    the ``tr`` translation callable is then regenerated by calling
    ``_init_translation()``.
    """
    instance_attrs = vars(self)
    instance_attrs.update(state)
    # The pickled state does not carry ``tr``, so recreate it from the
    # restored language/entity attributes.
    self._init_translation()
687+
702688
def __str__(self) -> str:
703689
if self:
704690
return super().__str__()
@@ -750,6 +736,35 @@ def get_subdivision_aliases(cls) -> dict[str, list]:
750736

751737
return subdivision_aliases
752738

739+
def _init_translation(self) -> None:
    """Set ``self.tr`` to the translation function for ``self.language``.

    When the entity has no ``_entity_code``, the module-level ``gettext``
    function is used as-is. Otherwise the ``<_entity_code>`` message catalog
    is loaded from the ``locale`` directory located next to this module:

    * if ``self.language`` is one of ``supported_languages``, that language
      is requested explicitly and a missing catalog raises
      (``fallback=False``);
    * otherwise no language list is passed, so ``gettext.translation``
      picks languages from the environment, and a missing catalog yields a
      ``NullTranslations`` object (``fallback=True``).

    If ``parent_entity`` is set, the parent's catalog (looked up by its
    ``country`` code, or its ``market`` code when ``country`` is falsy) is
    loaded with the same options and chained as a fallback. Called from
    ``__init__`` and again by ``__setstate__`` after unpickling.
    """
    supported_languages = set(self.supported_languages)
    if self._entity_code is None:
        self.tr = gettext
        return

    use_explicit_language = self.language in supported_languages
    locale_directory = str(Path(__file__).with_name("locale"))

    def load_catalog(domain):
        # Shared options for the entity catalog and its parent fallback.
        return translation(
            domain,
            fallback=not use_explicit_language,
            languages=[self.language] if use_explicit_language else None,
            localedir=locale_directory,
        )

    # Native translations for this entity.
    entity_catalog = load_catalog(self._entity_code)

    # Chain the parent entity's translations, if any, as a fallback.
    parent_entity = self.parent_entity
    if parent_entity:
        entity_catalog.add_fallback(
            load_catalog(parent_entity.country or parent_entity.market)
        )

    self.tr = entity_catalog.gettext
767+
753768
def _is_leap_year(self) -> bool:
754769
"""
755770
Returns True if the year is leap. Returns False otherwise.

tests/test_holiday_base.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -867,6 +867,16 @@ def test_pickle(self):
867867
self.assertEqual(loaded_holidays, self.hb)
868868
self.assertIn(dt, self.hb)
869869

870+
def test_pickle_localized_entity(self):
    """Round-trip localized entities through pickle.

    ``tr`` is dropped from the pickled state and rebuilt on load, so each
    language setting (including the default, ``None``) must survive a
    dump/load cycle with identical holidays and holiday names.
    """
    dt = "2021-01-01"
    for lang in ("uk", "en_US", None):
        with self.subTest(language=lang):
            ua = UA(language=lang)
            # Check the freshly built entity itself (not the unrelated
            # ``self.hb`` fixture) before pickling it.
            self.assertIn(dt, ua)

            loaded_ua = pickle.loads(pickle.dumps(ua))
            self.assertEqual(loaded_ua, ua)
            self.assertIn(dt, loaded_ua)
            # The localized holiday name must be preserved as well.
            self.assertEqual(loaded_ua[dt], ua[dt])
879+
870880

871881
class TestSpecialHolidays(unittest.TestCase):
872882
def setUp(self):

0 commit comments

Comments (0)