Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Introduce HolidayBase methods for proper serialization
  • Loading branch information
KJhellico committed Mar 6, 2025
commit 9b3eb72385529e32692d2e66b7c6f7b49b0505f2
69 changes: 41 additions & 28 deletions holidays/holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,35 +365,8 @@ def __init__(
self.observed = observed
self.subdiv = subdiv
self.weekend_workdays = getattr(self, "weekend_workdays", set())

supported_languages = set(self.supported_languages)
if self._entity_code is not None:
fallback = language not in supported_languages
languages = [language] if language in supported_languages else None
locale_directory = str(Path(__file__).with_name("locale"))

# Add entity native content translations.
entity_translation = translation(
self._entity_code,
fallback=fallback,
languages=languages,
localedir=locale_directory,
)
# Add a fallback if entity has parent translations.
if parent_entity := self.parent_entity:
entity_translation.add_fallback(
translation(
parent_entity.country or parent_entity.market,
fallback=fallback,
languages=languages,
localedir=locale_directory,
)
)
self.tr = entity_translation.gettext
else:
self.tr = gettext

self.years = _normalize_arguments(int, years)
self._init_translation()

# Populate holidays.
for year in self.years:
Expand Down Expand Up @@ -592,6 +565,12 @@ def __getitem__(self, key: DateLike) -> Any:

return dict.__getitem__(self, self.__keytransform__(key))

def __getstate__(self) -> dict[str, Any]:
    """Return the object's state for serialization.

    The result is a shallow copy of the instance ``__dict__`` without the
    ``tr`` translation callable; ``__setstate__`` rebuilds it on load via
    ``_init_translation``. A missing ``tr`` is tolerated, so an instance on
    which it was never set can still be pickled or copied instead of failing
    with ``KeyError``.
    """
    state = self.__dict__.copy()
    # ``tr`` is derived from the language settings, so it is not serialized;
    # drop it if present rather than assuming it always exists.
    state.pop("tr", None)
    return state

def __keytransform__(self, key: DateLike) -> date:
"""Transforms the date from one of the following types:

Expand Down Expand Up @@ -699,6 +678,11 @@ def __setitem__(self, key: DateLike, value: str) -> None:

dict.__setitem__(self, self.__keytransform__(key), value)

def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore the instance from ``state`` during deserialization.

    Every saved attribute is written back into the instance namespace. The
    ``tr`` translation callable is then rebuilt from those attributes,
    because ``__getstate__`` leaves it out of the serialized state.
    """
    namespace = vars(self)
    namespace.update(state)
    # Recreate ``tr`` now that the language/entity settings are back in place.
    self._init_translation()

def __str__(self) -> str:
if self:
return super().__str__()
Expand Down Expand Up @@ -750,6 +734,35 @@ def get_subdivision_aliases(cls) -> dict[str, list]:

return subdivision_aliases

def _init_translation(self) -> None:
    """Set ``self.tr``, the function used to translate holiday names.

    Entities without an ``_entity_code`` use the plain ``gettext`` function.
    Otherwise a gettext catalog named after the entity code is loaded from
    the ``locale`` directory next to this module. If there is a parent
    entity, its catalog (looked up by its country or market code) is chained
    in as a fallback.

    When ``self.language`` is one of ``supported_languages``, exactly that
    language is requested and a missing catalog makes ``translation`` raise
    ``OSError``. Otherwise no language is forced (gettext consults the
    environment) and a missing catalog degrades to ``NullTranslations``.
    """
    supported = set(self.supported_languages)
    entity_code = self._entity_code
    if entity_code is None:
        self.tr = gettext
        return

    is_supported = self.language in supported
    lookup_options = {
        "fallback": not is_supported,
        "languages": [self.language] if is_supported else None,
        "localedir": str(Path(__file__).with_name("locale")),
    }

    # Catalog holding the entity's own (native) holiday name translations.
    entity_translation = translation(entity_code, **lookup_options)

    # Chain the parent entity's catalog so names defined there still resolve.
    parent_entity = self.parent_entity
    if parent_entity:
        parent_code = parent_entity.country or parent_entity.market
        entity_translation.add_fallback(translation(parent_code, **lookup_options))

    self.tr = entity_translation.gettext

def _is_leap_year(self) -> bool:
"""
Returns True if the year is leap. Returns False otherwise.
Expand Down
9 changes: 9 additions & 0 deletions tests/test_holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,15 @@ def test_pickle(self):
self.assertEqual(loaded_holidays, self.hb)
self.assertIn(dt, self.hb)

def test_pickle_localized_entity(self):
    ua = UA(language="uk")
    dt = "2021-01-01"
    # Check the entity under test (not ``self.hb``), which also populates
    # ``dt`` on ``ua`` before it is pickled.
    self.assertIn(dt, ua)

    loaded_ua = pickle.loads(pickle.dumps(ua))
    self.assertEqual(loaded_ua, ua)
    self.assertIn(dt, loaded_ua)

    # A year not populated before pickling forces the unpickled instance to
    # generate holidays again, which presumably goes through the ``tr``
    # function rebuilt on unpickling; the localized name must match the one
    # produced by the original instance.
    new_dt = "2020-01-01"
    self.assertIn(new_dt, loaded_ua)
    self.assertEqual(loaded_ua[new_dt], ua[new_dt])


class TestSpecialHolidays(unittest.TestCase):
def setUp(self):
Expand Down
Loading