Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy rectilinear interpolator #6084

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
b77af1d
lazy interpolation using map_complete_blocks
fnattino Jul 25, 2024
0021cb0
pre-commit fixes
fnattino Jul 25, 2024
18908ae
replace test on interpolation with lazy data
fnattino Aug 27, 2024
47a8599
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Sep 11, 2024
555f3c7
Update lib/iris/analysis/_interpolation.py
fnattino Sep 13, 2024
7a08108
Update lib/iris/analysis/_interpolation.py
fnattino Sep 13, 2024
c453e01
Merge branch 'lazy-rectilinearinterpolator-2' of github.com:fnattino/…
fnattino Sep 13, 2024
3814383
resume local import
fnattino Sep 17, 2024
0c5dc9a
add entry to latest.rst
fnattino Sep 20, 2024
3481e46
add author name to list
fnattino Sep 20, 2024
63714f0
drop duplicated method
fnattino Nov 28, 2024
7aaefd8
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Nov 28, 2024
98143f2
new signature of map_complete_blocks
fnattino Nov 28, 2024
948c75e
update docstrings on lazy data
fnattino Nov 28, 2024
09974f3
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Nov 28, 2024
18c6e7a
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Dec 6, 2024
d190a8b
update userguide with lazy interpolator
fnattino Dec 6, 2024
3b210e4
Merge branch 'main' into lazy-rectilinearinterpolator-2
trexfeathers Dec 18, 2024
47bce0e
the unstructured NN regridder does not support lazy data
fnattino Dec 22, 2024
609c75a
remove caching an interpolator
fnattino Dec 22, 2024
dd14caa
update what's new entry
fnattino Dec 22, 2024
48f47eb
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Dec 22, 2024
2fb8fc3
remove links to docs section about caching interpolators
fnattino Dec 22, 2024
a4774b8
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Jan 7, 2025
fb8cffd
Merge branch 'main' into lazy-rectilinearinterpolator-2
fnattino Jan 17, 2025
6dc3ec9
Merge branch 'main' into lazy-rectilinearinterpolator-2
trexfeathers Jan 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 81 additions & 53 deletions lib/iris/analysis/_interpolation.py
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from numpy.lib.stride_tricks import as_strided
import numpy.ma as ma

from iris._lazy_data import map_complete_blocks
from iris.analysis._scipy_interpolate import _RegularGridInterpolator
fnattino marked this conversation as resolved.
Show resolved Hide resolved
from iris.coords import AuxCoord, DimCoord
import iris.util

Expand Down Expand Up @@ -163,6 +165,15 @@ def snapshot_grid(cube):
return x.copy(), y.copy()


def _interpolated_dtype(dtype, method):
"""Determine the minimum base dtype required by the underlying interpolator."""
if method == "nearest":
result = dtype
else:
result = np.result_type(_DEFAULT_DTYPE, dtype)
return result
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved


class RectilinearInterpolator:
"""Provide support for performing nearest-neighbour or linear interpolation.

Expand Down Expand Up @@ -200,13 +211,8 @@ def __init__(self, src_cube, coords, method, extrapolation_mode):
set to NaN.

"""
# Trigger any deferred loading of the source cube's data and snapshot
# its state to ensure that the interpolator is impervious to external
# changes to the original source cube. The data is loaded to prevent
# the snapshot having lazy data, avoiding the potential for the
# same data to be loaded again and again.
if src_cube.has_lazy_data():
src_cube.data
# Snapshot the cube state to ensure that the interpolator is impervious
# to external changes to the original source cube.
self._src_cube = src_cube.copy()
# Coordinates defining the dimensions to be interpolated.
self._src_coords = [self._src_cube.coord(coord) for coord in coords]
Expand Down Expand Up @@ -277,17 +283,27 @@ def _account_for_inverted(self, data):
data = data[tuple(dim_slices)]
return data

def _interpolate(self, data, interp_points):
@staticmethod
def _interpolate(
data,
src_points,
interp_points,
interp_shape,
method="linear",
extrapolation_mode="nanmask",
):
"""Interpolate a data array over N dimensions.

Create and cache the underlying interpolator instance before invoking
it to perform interpolation over the data at the given coordinate point
values.
Create the interpolator instance before invoking it to perform
interpolation over the data at the given coordinate point values.

Parameters
----------
data : ndarray
A data array, to be interpolated in its first 'N' dimensions.
src_points :
The point values defining the dimensions to be interpolated.
(len(src_points) should be N).
interp_points : ndarray
An array of interpolation coordinate values.
Its shape is (..., N) where N is the number of interpolation
Expand All @@ -296,44 +312,51 @@ def _interpolate(self, data, interp_points):
coordinate, which is mapped to the i'th data dimension.
The other (leading) dimensions index over the different required
sample points.
interp_shape :
The shape of the interpolated array in its first 'N' dimensions
(len(interp_shape) should be N).
method : str
Interpolation method (see :class:`iris.analysis._interpolation.RectilinearInterpolator`).
extrapolation_mode : str
Extrapolation mode (see :class:`iris.analysis._interpolation.RectilinearInterpolator`).

Returns
-------
:class:`np.ndarray`.
Its shape is "points_shape + extra_shape",
Its shape is "interp_shape + extra_shape",
where "extra_shape" is the remaining non-interpolated dimensions of
the data array (i.e. 'data.shape[N:]'), and "points_shape" is the
leading dimensions of interp_points,
(i.e. 'interp_points.shape[:-1]').

the data array (i.e. 'data.shape[N:]').
"""
from iris.analysis._scipy_interpolate import _RegularGridInterpolator

dtype = self._interpolated_dtype(data.dtype)
dtype = _interpolated_dtype(data.dtype, method)
if data.dtype != dtype:
# Perform dtype promotion.
data = data.astype(dtype)

mode = EXTRAPOLATION_MODES[self._mode]
if self._interpolator is None:
# Cache the interpolator instance.
# NB. The constructor of the _RegularGridInterpolator class does
# some unnecessary checks on the fill_value parameter,
# so we set it afterwards instead. Sneaky. ;-)
self._interpolator = _RegularGridInterpolator(
self._src_points,
data,
method=self.method,
bounds_error=mode.bounds_error,
fill_value=None,
)
else:
self._interpolator.values = data
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# Determine the shape of the interpolated result.
ndims_interp = len(interp_shape)
extra_shape = data.shape[ndims_interp:]
final_shape = [*interp_shape, *extra_shape]
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved

mode = EXTRAPOLATION_MODES[extrapolation_mode]
_data = np.ma.getdata(data)
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# NB. The constructor of the _RegularGridInterpolator class does
# some unnecessary checks on the fill_value parameter,
# so we set it afterwards instead. Sneaky. ;-)
interpolator = _RegularGridInterpolator(
src_points,
_data,
method=method,
bounds_error=mode.bounds_error,
fill_value=None,
)
interpolator.fill_value = mode.fill_value
result = interpolator(interp_points)

# We may be re-using a cached interpolator, so ensure the fill
# value is set appropriately for extrapolating data values.
self._interpolator.fill_value = mode.fill_value
result = self._interpolator(interp_points)
# The interpolated result has now shape "points_shape + extra_shape"
# where "points_shape" is the leading dimension of "interp_points"
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# (i.e. 'interp_points.shape[:-1]'). We reshape it to match the shape
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# of the interpolated dimensions.
result = result.reshape(final_shape)
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved

if result.dtype != data.dtype:
# Cast the data dtype to be as expected. Note that, the dtype
Expand All @@ -346,13 +369,11 @@ def _interpolate(self, data, interp_points):
# `data` is not a masked array.
src_mask = np.ma.getmaskarray(data)
# Switch the extrapolation to work with mask values.
self._interpolator.fill_value = mode.mask_fill_value
self._interpolator.values = src_mask
mask_fraction = self._interpolator(interp_points)
interpolator.fill_value = mode.mask_fill_value
interpolator.values = src_mask
mask_fraction = interpolator(interp_points)
new_mask = mask_fraction > 0
if ma.isMaskedArray(data) or np.any(new_mask):
result = np.ma.MaskedArray(result, new_mask)

result = np.ma.MaskedArray(result, new_mask)
fnattino marked this conversation as resolved.
Show resolved Hide resolved
return result

def _resample_coord(self, sample_points, coord, coord_dims):
Expand Down Expand Up @@ -530,7 +551,7 @@ def _points(self, sample_points, data, data_dims=None):
_, src_order = zip(*sorted(dmap.items(), key=operator.itemgetter(0)))

# Prepare the sample points for interpolation and calculate the
# shape of the interpolated result.
# shape of the interpolated dimensions.
interp_points = []
interp_shape = []
for index, points in enumerate(sample_points):
Expand All @@ -539,10 +560,6 @@ def _points(self, sample_points, data, data_dims=None):
interp_points.append(points)
interp_shape.append(points.size)

interp_shape.extend(
length for dim, length in enumerate(data.shape) if dim not in di
)

# Convert the interpolation points into a cross-product array
# with shape (n_cross_points, n_dims)
interp_points = np.asarray([pts for pts in product(*interp_points)])
Expand All @@ -554,9 +571,20 @@ def _points(self, sample_points, data, data_dims=None):
# Transpose data in preparation for interpolation.
data = np.transpose(data, interp_order)

# Interpolate and reshape the data ...
result = self._interpolate(data, interp_points)
result = result.reshape(interp_shape)
# Interpolate the data, merging the chunks in the interpolated
# dimensions.
dims_merge_chunks = [dmap[d] for d in di]
fnattino marked this conversation as resolved.
Show resolved Hide resolved
result = map_complete_blocks(
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
data,
self._interpolate,
dims=dims_merge_chunks,
out_sizes=interp_shape,
src_points=self._src_points,
interp_points=interp_points,
interp_shape=interp_shape,
method=self._method,
extrapolation_mode=self._mode,
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
)

if src_order != dims:
# Restore the interpolated result to the original
Expand Down Expand Up @@ -592,7 +620,7 @@ def __call__(self, sample_points, collapse_scalar=True):

sample_points = _canonical_sample_points(self._src_coords, sample_points)

data = self._src_cube.data
data = self._src_cube.core_data()
# Interpolate the cube payload.
interpolated_data = self._points(sample_points, data)

Expand Down
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -499,24 +499,37 @@ def test_orthogonal_cube_squash(self):
self.assertEqual(result_cube, non_collapsed_cube[0, ...])


class Test___call___real_data(ThreeDimCube):
def test_src_cube_data_loaded(self):
# If the source cube has real data when the interpolator is
# instantiated, then the interpolated result should also have
# real data.
self.assertFalse(self.cube.has_lazy_data())

# Perform interpolation and check the data is real.
interpolator = RectilinearInterpolator(
self.cube, ["latitude"], LINEAR, EXTRAPOLATE
)
res = interpolator([[1.5]])
self.assertFalse(res.has_lazy_data())


class Test___call___lazy_data(ThreeDimCube):
def test_src_cube_data_loaded(self):
trexfeathers marked this conversation as resolved.
Show resolved Hide resolved
# RectilinearInterpolator operates using a snapshot of the source cube.
# If the source cube has lazy data when the interpolator is
# instantiated we want to make sure the source cube's data is
# loaded as a consequence of interpolation to avoid the risk
# of loading it again and again.
# instantiated, then the interpolated result should also have
# lazy data.

# Modify self.cube to have lazy data.
self.cube.data = as_lazy_data(self.data)
self.assertTrue(self.cube.has_lazy_data())

# Perform interpolation and check the data has been loaded.
# Perform interpolation and check the data is lazy..
interpolator = RectilinearInterpolator(
self.cube, ["latitude"], LINEAR, EXTRAPOLATE
)
interpolator([[1.5]])
self.assertFalse(self.cube.has_lazy_data())
res = interpolator([[1.5]])
self.assertTrue(res.has_lazy_data())


class Test___call___time(tests.IrisTest):
Expand Down
Loading