143 changes: 99 additions & 44 deletions clisops/core/regrid.py

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion clisops/ops/base_operation.py
@@ -9,7 +9,7 @@
 from clisops.utils.common import expand_wildcards
 from clisops.utils.dataset_utils import open_xr_dataset
 from clisops.utils.file_namers import get_file_namer
-from clisops.utils.output_utils import get_output, get_time_slices
+from clisops.utils.output_utils import fix_netcdf_attrs_encoding, get_output, get_time_slices


 class Operation:
@@ -110,6 +110,12 @@ def _remove_str_compression(self, ds):
                         del ds[var].encoding[en]
         return ds

+    def _fix_netcdf_attrs_encoding(self, ds):
+        """Apply output_utils.fix_netcdf_attrs_encoding to xarray.Dataset objects."""
+        if isinstance(ds, xr.Dataset):
+            ds = fix_netcdf_attrs_encoding(ds)
+        return ds
+
     def _cap_deflate_level(self, ds):
         """
         For CMOR3 / CMIP6 it was investigated which netCDF4 deflate_level should be set to optimize
@@ -246,6 +252,8 @@ def process(self) -> list[xr.Dataset | Path]:
             processed_ds = self._remove_str_compression(processed_ds)
             # cap deflate level at 1
             processed_ds = self._cap_deflate_level(processed_ds)
+            # fix string encoding of xarray.Dataset.attrs (incl. variable attrs)
+            processed_ds = self._fix_netcdf_attrs_encoding(processed_ds)

             # Work out how many outputs should be created based on the size
             # of the array. Manage this as a list of time slices.
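Taken together, these three hunks import the new helper, wrap it as a method, and wire it into `process()` as the final sanitization step before outputs are split and written. A hedged sketch of the effect, not part of the diff: `subset` is one concrete `Operation`, and the input file name and keyword values below are assumptions about the public clisops API rather than anything shown in this PR.

```python
# Sketch only: the input path is hypothetical and the subset() keywords are
# assumed from the public clisops API, not taken from this PR.
import xarray as xr

from clisops.ops.subset import subset

ds = xr.open_dataset("tas_day_model_historical.nc")  # hypothetical file
# Simulate a mojibake attribute, e.g. latin-1 bytes read with surrogateescape:
ds.attrs["comment"] = b"gr\xfc\xdfe".decode("utf-8", "surrogateescape")

# process() now runs _fix_netcdf_attrs_encoding() right after capping the
# deflate level, so writing the result no longer trips over surrogates.
outputs = subset(ds=ds, time="2000-01-01/2000-12-31", output_type="netcdf")
```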
20 changes: 20 additions & 0 deletions clisops/utils/common.py
@@ -85,6 +85,7 @@ def require_module(
     module: ModuleType,
     module_name: str,
     min_version: str | None = "0.0.0",
+    unsupported_version_range: list | None = None,
     max_supported_version: str | None = None,
     max_supported_warning: str | None = None,
 ) -> Callable:
@@ -101,6 +102,13 @@
         The name of the module to check.
     min_version : str, optional
         The minimum version of the module required. Defaults to "0.0.0".
+    unsupported_version_range : list of str, optional
+        A two-element list marking a range of unsupported versions: the first
+        element is the first unsupported version, the second is the first
+        version that is supported again.
+        If provided, a warning will be issued if the module version falls
+        within this range: version_0 <= module_version < version_1.
+        Defaults to None, meaning no unsupported-version-range check is performed.
     max_supported_version : str, optional
         The maximum supported version of the module.
         If provided, a warning will be issued if the module version exceeds this.
@@ -131,6 +139,18 @@ def wrapper_func(*args, **kwargs): # numpydoc ignore=GL08
                         f"Package {module_name} version {module.__version__} "
                         f"is greater than the suggested version {max_supported_version}."
                     )
+
+            if unsupported_version_range is not None:
+                if not isinstance(unsupported_version_range, list) or len(unsupported_version_range) != 2:
+                    raise ValueError(
+                        "The unsupported_version_range argument must be a list with two elements of type str, "
+                        "with the elements being the minimum and maximum versions of an unsupported version range."
+                    )
+                if Version(unsupported_version_range[0]) <= Version(
+                    module.__version__
+                ) < Version(unsupported_version_range[1]):
+                    warnings.warn(max_supported_warning)
+
             return func(*args, **kwargs)

         return wrapper_func
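For context, a sketch of how the new argument might be used, presumably to guard the xesmf dependency handled in regrid.py above (that diff is not rendered here). The module, version numbers, and warning text are illustrative assumptions. Note that, as implemented in the hunk above, the range check reuses `max_supported_warning` as its message:

```python
# Sketch only: xesmf, the version bounds, and the warning text are assumptions
# for illustration; none of them are taken from this PR.
import xesmf

from clisops.utils.common import require_module


@require_module(
    xesmf,
    "xesmf",
    min_version="0.8.2",
    # Warns for 0.8.8 <= version < 0.8.10; 0.8.10 is supported again.
    unsupported_version_range=["0.8.8", "0.8.10"],
    max_supported_warning="xesmf 0.8.8/0.8.9 are known to misbehave; please up- or downgrade.",
)
def regrid_with_xesmf(ds, grid):
    ...
```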
15 changes: 11 additions & 4 deletions clisops/utils/dataset_utils.py
@@ -22,7 +22,11 @@


 def get_coord_by_type(
-    ds: xr.DataArray | xr.Dataset, coord_type: str, ignore_aux_coords: bool = True, return_further_matches: bool = False
+    ds: xr.DataArray | xr.Dataset,
+    coord_type: str,
+    ignore_aux_coords: bool = True,
+    return_further_matches: bool = False,
+    warn_if_no_main_variable: bool = True,
 ):
     """
     Return the name of the coordinate that matches the given type.
@@ -34,9 +38,11 @@
     coord_type : str
         Type of coordinate, e.g. 'time', 'level', 'latitude', 'longitude', 'realization'.
     ignore_aux_coords : bool
-        Whether to ignore auxiliary coordinates.
+        Whether to ignore auxiliary coordinates. Default is True.
     return_further_matches : bool
-        Whether to return further matches.
+        Whether to return further matches. Default is False.
+    warn_if_no_main_variable : bool
+        Whether to warn if no main variable can be identified. Default is True.

     Returns
     -------
@@ -62,7 +68,8 @@
     try:
         main_var = get_main_variable(ds)
     except ValueError:
-        warnings.warn(f"No main variable found for dataset '{ds}'.")
+        if warn_if_no_main_variable:
+            warnings.warn(f"No main variable found for dataset '{ds}'.")
         main_var = None

     # Loop through all (potential) coordinates to find all possible matches
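A small usage sketch for the new flag; only `get_coord_by_type` and its parameters come from the diff, the dataset is fabricated:

```python
# Sketch only: a coordinate-only dataset has no identifiable main variable,
# which previously always emitted a warning during coordinate lookup.
import pandas as pd
import xarray as xr

from clisops.utils.dataset_utils import get_coord_by_type

ds = xr.Dataset(coords={"time": pd.date_range("2000-01-01", periods=4)})

# The lookup now stays silent instead of warning:
coord = get_coord_by_type(ds, "time", warn_if_no_main_variable=False)
```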
66 changes: 66 additions & 0 deletions clisops/utils/output_utils.py
@@ -361,6 +361,72 @@ def get_output(ds: xr.Dataset, output_type: str, output_dir: str | Path, namer:
     return output_path


+def _fix_str_encoding(s, encoding="utf-8"):
+    """
+    Helper function to fix string encoding of surrogates.
+
+    Parameters
+    ----------
+    s : str or bytes
+        The string to be fixed. If the input is not of type str or bytes,
+        it is returned as is.
+    encoding : str, optional
+        The encoding to be used. Default is "utf-8".
+
+    Returns
+    -------
+    str
+        The fixed string.
+    """
+    if isinstance(s, bytes):
+        # Decode directly from bytes, replacing undecodable sequences
+        return s.decode(encoding, errors="replace")
+    elif isinstance(s, str):
+        try:
+            s.encode(encoding)  # If this works, no surrogates present
+            return s
+        except UnicodeEncodeError:
+            # Handle surrogate escapes
+            b = s.encode(encoding, "surrogateescape")
+            return b.decode(encoding, errors="replace")
+    return s
+
+
+def fix_netcdf_attrs_encoding(ds, encoding="utf-8"):
+    """
+    Fix strings that contain invalid characters in Dataset attrs to be safe for NetCDF writing.
+
+    Parameters
+    ----------
+    ds : xarray.Dataset
+        The dataset with attrs to be fixed.
+    encoding : str, optional
+        The encoding to be used. Default is "utf-8".
+
+    Returns
+    -------
+    xarray.Dataset
+        The dataset with fixed attrs.
+    """
+    # Work on a shallow copy so the original ds is untouched
+    ds = ds.copy()
+
+    # Fix global attributes
+    for k, v in list(ds.attrs.items()):
+        fixed_v = _fix_str_encoding(v, encoding)
+        if fixed_v is not v:
+            ds.attrs[k] = fixed_v
+
+    # Fix variable attributes
+    for var in ds.variables:
+        for k, v in list(ds[var].attrs.items()):
+            fixed_v = _fix_str_encoding(v, encoding)
+            if fixed_v is not v:
+                ds[var].attrs[k] = fixed_v
+
+    return ds
+
+
 class FileLock:
     """
     Create and release a lockfile.
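To make the failure mode concrete, a hedged round-trip sketch (the attribute value is fabricated): a latin-1 byte such as 0xe9 ("é") is invalid UTF-8, survives reading as a lone surrogate when decoded with surrogateescape, and then cannot be written back to NetCDF; `fix_netcdf_attrs_encoding` replaces the offending code points so the dataset can be written.

```python
# Sketch only: the attribute value is fabricated to contain a lone surrogate,
# which is the situation these helpers repair before NetCDF writing.
import xarray as xr

from clisops.utils.output_utils import fix_netcdf_attrs_encoding

bad = b"M\xe9t\xe9o-France".decode("utf-8", "surrogateescape")
ds = xr.Dataset(attrs={"institution": bad})
# bad.encode("utf-8") would raise UnicodeEncodeError ("surrogates not allowed")

fixed = fix_netcdf_attrs_encoding(ds)
print(fixed.attrs["institution"])  # "M�t�o-France": lossy, but safe to write
```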