# Copyright Iris contributors
#
# This file is part of Iris and is released under the BSD license.
# See LICENSE in the root of the repository for full licensing details.
"""Module to ensure all calls to the netCDF4 library are thread-safe.
Intention is that no other Iris module should import the netCDF4 module.
"""
from abc import ABC
from threading import Lock
import typing
import netCDF4
import numpy as np
# The single lock used by this module to serialise calls into the netCDF4
# library. Note that threading.Lock is not re-entrant: code already holding
# it must not try to acquire it again.
_GLOBAL_NETCDF4_LOCK = Lock()

# Plain re-exports, which need no locking. They exist so that other modules
# can refer to these names via this module rather than importing netCDF4.
default_fillvals, VLType = netCDF4.default_fillvals, netCDF4.VLType
class _ThreadSafeWrapper(ABC):
    """Hold a netCDF4 object and route access to it through _GLOBAL_NETCDF4_LOCK.

    An instance contains ("wraps") a netCDF4 class instance. Attribute
    access, attribute assignment and item access are forwarded to the
    contained object while holding _GLOBAL_NETCDF4_LOCK, so the wrapper
    exposes the same API as the object itself.

    Containment is used instead of inheritance because netCDF4 classes cannot
    usefully be subclassed or monkeypatched: they are only wrappers for the
    C-layer.
    """

    # Class used to build a new contained object from constructor arguments.
    CONTAINED_CLASS = NotImplemented

    # Attribute names whose presence identifies an object as being of the
    # expected contained type (a duck-typed check).
    _DUCKTYPE_CHECK_PROPERTIES: typing.List[str] = [NotImplemented]

    # Marker attribute, allowing wrappers to be recognised without isinstance
    # (which is awkward in combination with mocking).
    THREAD_SAFE_FLAG = True

    @classmethod
    def is_contained_type(cls, instance):
        """Return True if ``instance`` has every duck-type check attribute."""
        for attr_name in cls._DUCKTYPE_CHECK_PROPERTIES:
            if not hasattr(instance, attr_name):
                return False
        return True

    @classmethod
    def from_existing(cls, instance):
        """Wrap an already-existing object of the contained type."""
        assert cls.is_contained_type(instance)
        return cls(instance)

    def __init__(self, *args, **kwargs):
        """Wrap an existing contained object, or build one from the arguments.

        A single positional argument that passes the duck-type check is
        wrapped as-is. Otherwise all arguments are forwarded to
        CONTAINED_CLASS, which is called while holding the global lock.
        """
        wraps_existing = len(args) == 1 and self.is_contained_type(args[0])
        if not wraps_existing:
            with _GLOBAL_NETCDF4_LOCK:
                instance = self.CONTAINED_CLASS(*args, **kwargs)
        else:
            instance = args[0]
            # A wrapper must never end up containing another wrapper.
            assert not hasattr(instance, "THREAD_SAFE_FLAG")
        self._contained_instance = instance

    def __getattr__(self, item):
        # Python only calls __getattr__ when normal lookup has failed.
        # "_contained_instance" itself is resolved via object.__getattribute__
        # so that, if it has not been set yet, AttributeError is raised.
        if item != "_contained_instance":
            # NOTE(review): the attribute is fetched under the lock, but if
            # it is a method, the caller invokes it after the lock has been
            # released (the ``with`` block ends at ``return``).
            with _GLOBAL_NETCDF4_LOCK:
                return getattr(self._contained_instance, item)
        return object.__getattribute__(self, item)

    def __setattr__(self, key, value):
        if key != "_contained_instance":
            # Every other attribute is stored on the contained object.
            with _GLOBAL_NETCDF4_LOCK:
                setattr(self._contained_instance, key, value)
            return None
        # The reference to the contained object lives on the wrapper itself.
        object.__setattr__(self, key, value)

    def __getitem__(self, item):
        with _GLOBAL_NETCDF4_LOCK:
            result = self._contained_instance.__getitem__(item)
        return result

    def __setitem__(self, key, value):
        with _GLOBAL_NETCDF4_LOCK:
            outcome = self._contained_instance.__setitem__(key, value)
        return outcome
class DimensionWrapper(_ThreadSafeWrapper):
    """Thread-safe stand-in for a netCDF4.Dimension.

    Presents the same API as netCDF4.Dimension; access to the contained
    Dimension acquires _GLOBAL_NETCDF4_LOCK.
    """

    # Objects with an "isunlimited" attribute are accepted as Dimensions.
    _DUCKTYPE_CHECK_PROPERTIES = ["isunlimited"]
    CONTAINED_CLASS = netCDF4.Dimension
class VariableWrapper(_ThreadSafeWrapper):
    """Accessor for a netCDF4.Variable, always acquiring _GLOBAL_NETCDF4_LOCK.

    All API calls should be identical to those for netCDF4.Variable.
    """

    CONTAINED_CLASS = netCDF4.Variable
    # Objects providing both of these attributes are accepted as Variables.
    _DUCKTYPE_CHECK_PROPERTIES = ["dimensions", "dtype"]

    def setncattr(self, *args, **kwargs) -> None:
        """Call netCDF4.Variable.setncattr within _GLOBAL_NETCDF4_LOCK.

        Only defined explicitly in order to get some mocks to work.
        """
        with _GLOBAL_NETCDF4_LOCK:
            return self._contained_instance.setncattr(*args, **kwargs)

    @property
    def dimensions(self) -> typing.Tuple[str, ...]:
        """Call netCDF4.Variable.dimensions within _GLOBAL_NETCDF4_LOCK.

        Only defined explicitly in order to get some mocks to work.
        The dimension names are passed through unchanged (netCDF4 documents
        them as a tuple of strings).
        """
        with _GLOBAL_NETCDF4_LOCK:
            # Plain names, so -- unlike self.get_dims() -- nothing needs
            # wrapping in a DimensionWrapper.
            return self._contained_instance.dimensions

    # All Variable API that returns Dimension(s) is wrapped to instead return
    # DimensionWrapper(s).
    def get_dims(self, *args, **kwargs) -> typing.Tuple[DimensionWrapper, ...]:
        """Call netCDF4.Variable.get_dims() within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Dimensions are replaced with their respective
        DimensionWrappers, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.

        Returns a tuple with one DimensionWrapper per dimension of the
        variable (hence the variable-length ``Tuple[..., ...]`` annotation).
        """
        with _GLOBAL_NETCDF4_LOCK:
            # Materialise the result while the lock is still held.
            dimensions_ = list(self._contained_instance.get_dims(*args, **kwargs))
        return tuple(DimensionWrapper.from_existing(d) for d in dimensions_)
class GroupWrapper(_ThreadSafeWrapper):
    """Accessor for a netCDF4.Group, always acquiring _GLOBAL_NETCDF4_LOCK.

    All API calls should be identical to those for netCDF4.Group.

    Methods returning netCDF4 objects hand back wrappers instead. The
    wrappers are built *before* the lock is released. Wrapping runs a
    duck-type check (``hasattr``) on the raw netCDF4 object. For a
    VariableWrapper, that check reads the Variable's ``dimensions``
    attribute, which VariableWrapper.dimensions deliberately reads only
    under the lock (presumably it queries the netCDF library).
    """

    CONTAINED_CLASS = netCDF4.Group
    # Note: will also accept a whole Dataset object, but that is OK.
    _DUCKTYPE_CHECK_PROPERTIES = ["createVariable"]

    # All Group API that returns Dimension(s) is wrapped to instead return
    # DimensionWrapper(s).
    @property
    def dimensions(self) -> typing.Dict[str, DimensionWrapper]:
        """Call dimensions of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Returns a dict mapping each name to a DimensionWrapper around the
        original netCDF4.Dimension, ensuring that downstream calls are also
        performed within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            dimensions_ = self._contained_instance.dimensions
            return {
                name: DimensionWrapper.from_existing(dim)
                for name, dim in dimensions_.items()
            }

    def createDimension(self, *args, **kwargs) -> DimensionWrapper:
        """Call createDimension() from netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Dimension is replaced with its respective
        DimensionWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            new_dimension = self._contained_instance.createDimension(*args, **kwargs)
            return DimensionWrapper.from_existing(new_dimension)

    # All Group API that returns Variable(s) is wrapped to instead return
    # VariableWrapper(s).
    @property
    def variables(self) -> typing.Dict[str, VariableWrapper]:
        """Call variables of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Returns a dict mapping each name to a VariableWrapper around the
        original netCDF4.Variable, ensuring that downstream calls are also
        performed within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            variables_ = self._contained_instance.variables
            # Wrap while still locked: the VariableWrapper duck check reads
            # each raw Variable's ``dimensions``.
            return {
                name: VariableWrapper.from_existing(var)
                for name, var in variables_.items()
            }

    def createVariable(self, *args, **kwargs) -> VariableWrapper:
        """Call createVariable() from netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        The returned netCDF4.Variable is replaced with its respective
        VariableWrapper, ensuring that downstream calls are also performed
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            new_variable = self._contained_instance.createVariable(*args, **kwargs)
            return VariableWrapper.from_existing(new_variable)

    def get_variables_by_attributes(
        self, *args, **kwargs
    ) -> typing.List[VariableWrapper]:
        """Call get_variables_by_attributes() from netCDF4.Group/Dataset.

        Call get_variables_by_attributes() from netCDF4.Group/Dataset
        within _GLOBAL_NETCDF4_LOCK, returning VariableWrappers.
        The original returned netCDF4.Variables are simply replaced with their
        respective VariableWrappers, ensuring that downstream calls are
        also performed within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            variables_ = self._contained_instance.get_variables_by_attributes(
                *args, **kwargs
            )
            return [VariableWrapper.from_existing(var) for var in variables_]

    # All Group API that returns Group(s) is wrapped to instead return
    # GroupWrapper(s).
    @property
    def groups(self) -> typing.Dict[str, "GroupWrapper"]:
        """Call groups of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Returns a dict mapping each name to a GroupWrapper around the
        original netCDF4.Group, ensuring that downstream calls are also
        performed within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            groups_ = self._contained_instance.groups
            return {
                name: GroupWrapper.from_existing(group)
                for name, group in groups_.items()
            }

    @property
    def parent(self) -> typing.Optional["GroupWrapper"]:
        """Call parent of netCDF4.Group/Dataset within _GLOBAL_NETCDF4_LOCK.

        Returns a GroupWrapper around the parent group, ensuring that
        downstream calls are also performed within _GLOBAL_NETCDF4_LOCK.
        Returns None when the contained object has no parent (the root
        group, i.e. the Dataset itself); previously that case failed the
        duck-type assertion in ``from_existing``.
        """
        with _GLOBAL_NETCDF4_LOCK:
            parent_ = self._contained_instance.parent
            if parent_ is None:
                return None
            return GroupWrapper.from_existing(parent_)

    def createGroup(self, *args, **kwargs) -> "GroupWrapper":
        """Call createGroup() from netCDF4.Group/Dataset.

        Call createGroup() from netCDF4.Group/Dataset within
        _GLOBAL_NETCDF4_LOCK, returning GroupWrapper. The original returned
        netCDF4.Group is simply replaced with its respective GroupWrapper,
        ensuring that downstream calls are also performed within
        _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            new_group = self._contained_instance.createGroup(*args, **kwargs)
            return GroupWrapper.from_existing(new_group)
class DatasetWrapper(GroupWrapper):
    """Thread-safe stand-in for a netCDF4.Dataset.

    Offers the same API as netCDF4.Dataset, with access to the contained
    Dataset acquiring _GLOBAL_NETCDF4_LOCK (as for GroupWrapper).
    """

    # "close" is what distinguishes a Dataset from a plain Group here: it
    # exists on Dataset but not on Group (though a rather weak distinction).
    _DUCKTYPE_CHECK_PROPERTIES = ["createVariable", "close"]
    CONTAINED_CLASS = netCDF4.Dataset

    @classmethod
    def fromcdl(cls, *args, **kwargs):
        """Create a Dataset via netCDF4.Dataset.fromcdl(), under the lock.

        All arguments are passed straight through to
        netCDF4.Dataset.fromcdl(). The resulting netCDF4.Dataset is returned
        wrapped in this class, so that later calls on it are also made
        within _GLOBAL_NETCDF4_LOCK.
        """
        with _GLOBAL_NETCDF4_LOCK:
            raw_dataset = cls.CONTAINED_CLASS.fromcdl(*args, **kwargs)
        wrapped = cls.from_existing(raw_dataset)
        return wrapped
class NetCDFDataProxy:
    """A reference to the data payload of a single NetCDF file variable."""

    __slots__ = ("shape", "dtype", "path", "variable_name", "fill_value")

    def __init__(self, shape, dtype, path, variable_name, fill_value):
        self.shape = shape
        self.dtype = dtype
        self.path = path
        self.variable_name = variable_name
        # Stored for callers; not used by the methods of this class.
        self.fill_value = fill_value

    @property
    def ndim(self):
        """Number of dimensions, i.e. ``len(self.shape)``."""
        return len(self.shape)

    @property
    def dask_meta(self):
        """An empty, fully-masked array with this proxy's ``ndim`` and ``dtype``."""
        return np.ma.array(np.empty((0,) * self.ndim, dtype=self.dtype), mask=True)

    def __getitem__(self, keys):
        """Open the file, read ``variable[keys]``, close the file, return it.

        The result goes through ``np.asanyarray``, so ndarray subclasses
        (such as masked arrays) are passed through rather than converted.
        """
        # Using a DatasetWrapper causes problems with invalid ID's and the
        # netCDF4 library, presumably because __getitem__ gets called so many
        # times by Dask. Use _GLOBAL_NETCDF4_LOCK directly instead.
        with _GLOBAL_NETCDF4_LOCK:
            dataset = netCDF4.Dataset(self.path)
            try:
                variable = dataset.variables[self.variable_name]
                # Get the NetCDF variable data and slice.
                var = variable[keys]
            finally:
                dataset.close()
        return np.asanyarray(var)

    def __repr__(self):
        fmt = (
            "<{self.__class__.__name__} shape={self.shape}"
            " dtype={self.dtype!r} path={self.path!r}"
            " variable_name={self.variable_name!r}>"
        )
        return fmt.format(self=self)

    # With __slots__ there is no instance __dict__, so the pickle state is
    # built from (and restored to) the slot attributes explicitly.
    def __getstate__(self):
        return {attr: getattr(self, attr) for attr in self.__slots__}

    def __setstate__(self, state):
        for key, value in state.items():
            setattr(self, key, value)
class NetCDFWriteProxy:
    """An object mimicking the data access of a netCDF4.Variable.

    The "opposite" of a NetCDFDataProxy : An object mimicking the data access
    of a netCDF4.Variable, but where the data is to be ***written to***.

    It encapsulates the netcdf file and variable which are actually to be
    written to. This opens the file each time, to enable writing the data
    chunk, then closes it.

    TODO: could be improved with a caching scheme, but this just about works.
    """

    def __init__(self, filepath, cf_var, file_write_lock):
        self.path = filepath
        # Only the variable *name* is kept; the file is re-opened per write.
        self.varname = cf_var.name
        # Presumably shared by all workers writing to this file -- only its
        # acquire()/release() methods are used here.
        self.lock = file_write_lock

    def __setitem__(self, keys, array_data):
        """Write ``array_data`` into ``variable[keys]`` of the target file.

        Takes the file-specific lock, then _GLOBAL_NETCDF4_LOCK, opens the
        file in "r+" mode, writes, and closes it. The file-specific lock is
        released on every exit path, including a failure while taking the
        global lock or opening the file.
        """
        # First acquire a file-specific lock for all workers writing to this file.
        self.lock.acquire()
        try:
            # As in NetCDFDataProxy : a DatasetWrapper causes problems with
            # invalid ID's and the netCDF4 library, for so-far unknown reasons.
            # Instead, use _GLOBAL_NETCDF4_LOCK, and netCDF4 _directly_.
            with _GLOBAL_NETCDF4_LOCK:
                dataset = None
                try:
                    dataset = netCDF4.Dataset(self.path, "r+")
                    var = dataset.variables[self.varname]
                    var[keys] = array_data
                finally:
                    # Test against None: only "was the file opened?" matters,
                    # not the Dataset object's truthiness.
                    if dataset is not None:
                        dataset.close()
        finally:
            # *ALWAYS* let go !
            self.lock.release()

    def __repr__(self):
        return f"<{self.__class__.__name__} path={self.path!r} var={self.varname!r}>"