
Support proper custom class reflexive operator applied to xarray objects #9944

Open
Li9htmare opened this issue Jan 13, 2025 · 3 comments

@Li9htmare

Is your feature request related to a problem?

I would like to implement a reflected (reflexive) operator, such as __radd__, on a custom class so that it works when applied to xarray objects.

Here is a demo snippet:

import numpy as np
import xarray as xr


class DemoObj:
    def __add__(self, other):
        print(f'__add__ call: type={other.__class__}, value={other}')
        return other

    def __radd__(self, other):
        print(f'__radd__ call: type={other.__class__}, value={other}')
        return other


obj = DemoObj()
da = xr.DataArray(np.arange(8))

print('#### Test __add__ ####')
obj + da
print('\n')

print('#### Test __radd__ ####')
da + obj

Actual Output:

#### Test __add__ ####
__add__ call: type=<class 'xarray.core.dataarray.DataArray'>, value=<xarray.DataArray (dim_0: 8)>
array([0, 1, 2, 3, 4, 5, 6, 7])
Dimensions without coordinates: dim_0

#### Test __radd__ ####
__radd__ call: type=<class 'int'>, value=0
__radd__ call: type=<class 'int'>, value=1
__radd__ call: type=<class 'int'>, value=2
__radd__ call: type=<class 'int'>, value=3
__radd__ call: type=<class 'int'>, value=4
__radd__ call: type=<class 'int'>, value=5
__radd__ call: type=<class 'int'>, value=6
__radd__ call: type=<class 'int'>, value=7

We can see that __add__ is called once and receives the xr.DataArray object, but __radd__ is called 8 times and receives ints. This causes two problems:

  • Poor performance on large xr.DataArray objects
  • No access to the xr.DataArray coords, which a more realistic use case needs
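For reference, the per-element calls can be reproduced with NumPy alone, which suggests they come from NumPy's handling of the wrapped array rather than from anything xarray-specific. In this minimal sketch (no xarray involved; calls is just a local list for inspection), np.arange(8) + DemoObj() treats the unknown operand as an object-dtype scalar, so each element is converted to a Python int and Python's fallback invokes DemoObj.__radd__ once per element:

```python
import numpy as np

calls = []  # records the type of every operand __radd__ receives


class DemoObj:
    def __radd__(self, other):
        calls.append(type(other))
        return other


# ndarray.__add__ does not defer to DemoObj, so NumPy runs an object-dtype
# loop and Python's int + DemoObj fallback calls __radd__ per element.
result = np.arange(8) + DemoObj()
print(len(calls), result.dtype)
```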

Describe the solution you'd like

I would like a mechanism so that, in the above example, DemoObj.__radd__ is called only once and receives the xr.DataArray instance.

Describe alternatives you've considered

Option 1:

The most naive workaround is to call obj.__radd__(da) directly instead of writing da + obj. This defeats the purpose of implementing the reflected operator and hurts readability.

Option 2:

Since xr.DataArray._binary_op relies on NumPy's operator resolution mechanism under the hood, I could improve the situation by setting __array_ufunc__ = None on my class, e.g.:

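class DemoObj:
    # Opts out of NumPy ufuncs, so ndarray binary ops return NotImplemented
    # and Python falls back to this class's reflected method
    __array_ufunc__ = None

    def __add__(self, other):
        print(f'__add__ call: type={other.__class__}, value={other}')
        return other

    def __radd__(self, other):
        print(f'__radd__ call: type={other.__class__}, value={other}')
        return other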

This makes __radd__ get called once with an np.ndarray instead of 8 times with ints. That solves the performance concern, but it still doesn't cover cases where xr.DataArray.coords is needed.
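The same behaviour can be checked with plain NumPy. A minimal sketch, relying on NumPy's documented rule that ndarray binary operators return NotImplemented when the other operand sets __array_ufunc__ = None; Python then calls DemoObj.__radd__ exactly once with the whole ndarray (calls is a local list used only for inspection):

```python
import numpy as np

calls = []  # records the type of every operand __radd__ receives


class DemoObj:
    # Opting out of ufuncs makes ndarray.__add__ return NotImplemented,
    # so Python falls back to DemoObj.__radd__ with the whole array.
    __array_ufunc__ = None

    def __radd__(self, other):
        calls.append(type(other))
        return other


result = np.arange(8) + DemoObj()
print(calls, result.shape)
```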

Additional context

Given that xr.DataArray._binary_op already returns NotImplemented for a list of classes:
https://github.com/pydata/xarray/blob/v2025.01.1/xarray/core/dataarray.py#L4808-L4809

I'm wondering whether we should do the same for classes that have __array_ufunc__ = None, i.e.:

def _binary_op(
    self: T_DataArray,
    other: Any,
    f: Callable,
    reflexive: bool = False,
) -> T_DataArray:
    if hasattr(other, '__array_ufunc__') and other.__array_ufunc__ is None:
        return NotImplemented
    ...

I'd also be fine with a similar, xarray-specific property if you prefer. I'm happy to make the PR as well once you confirm the mechanism / property name you prefer.
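To make the proposal concrete, here is a self-contained toy sketch. ToyArrayWrapper and its _binary_op are hypothetical stand-ins that only mimic the shape of the proposed check; they are not xarray's implementation. The _MISSING sentinel distinguishes "attribute absent" from "attribute set to None", which is equivalent to the hasattr(...) and ... is None test above.

```python
import operator
from typing import Any, Callable

import numpy as np

_MISSING = object()  # sentinel: attribute absent (vs. explicitly None)


class ToyArrayWrapper:
    """Hypothetical stand-in for an xarray object; not xarray's real code."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def _binary_op(self, other: Any, f: Callable, reflexive: bool = False):
        # Proposed rule: defer to objects that opt out of NumPy ufuncs.
        if getattr(other, "__array_ufunc__", _MISSING) is None:
            return NotImplemented
        other_data = other.data if isinstance(other, ToyArrayWrapper) else other
        result = f(other_data, self.data) if reflexive else f(self.data, other_data)
        return ToyArrayWrapper(result)

    def __add__(self, other):
        return self._binary_op(other, operator.add)

    def __radd__(self, other):
        return self._binary_op(other, operator.add, reflexive=True)


class DemoObj:
    __array_ufunc__ = None

    def __radd__(self, other):
        # Receives the whole wrapper, not its elements
        return ("radd", type(other).__name__)


wrapped = ToyArrayWrapper(np.arange(8))
print(wrapped + DemoObj())  # ('radd', 'ToyArrayWrapper')
```

Because the wrapper returns NotImplemented, Python falls back to DemoObj.__radd__, which receives the wrapper object itself; in the real case that object would carry its coords.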

Many thanks in advance!


welcome bot commented Jan 13, 2025

Thanks for opening your first issue here at xarray! Be sure to follow the issue template!
If you have an idea for a solution, we would really welcome a Pull Request with proposed changes.
See the Contributing Guide for more.
It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better.
Thank you!

@shoyer
Member

shoyer commented Jan 17, 2025

Indeed, currently Xarray very aggressively attempts to take control of all binary arithmetic operations (by applying them to the wrapped .data of the xarray object). I agree that this is definitely not ideal.

Xarray should only attempt to do this for objects whose API works like that of a multi-dimensional array. I see at least two ways to determine this:

  1. As you suggest, we could use __array_ufunc__ = None like NumPy to indicate that an object explicitly does not have an API like NumPy arrays.
  2. Alternatively, we could return NotImplemented except for types that explicitly indicate that they work like NumPy arrays. In principle, this should be the same set of types that are valid when wrapped inside xarray objects, because they implement one of two generations of NumPy compatibility APIs (__array_ufunc__/__array_function__ or __array_namespace__). The current code that checks for compatibility with these objects is:
    hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__")

My inclination would be to try the second solution first (I think it's a little cleaner / more comprehensive) but if that doesn't work I would be OK to fall back to the first one.
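As a quick illustration of what the check quoted above accepts, here is a minimal sketch that wraps it in a standalone predicate; is_duck_array_like is a hypothetical name, not an xarray function. One thing it shows: plain Python scalars such as 1 expose neither attribute, so an opt-in rule based only on this check would presumably need a separate path for scalars.

```python
import numpy as np


def is_duck_array_like(obj) -> bool:
    # Hypothetical helper mirroring the compatibility check quoted above
    return hasattr(obj, "__array_function__") or hasattr(obj, "__array_namespace__")


class DemoObj:
    def __radd__(self, other):
        return other


for candidate in [np.arange(3), 1, DemoObj()]:
    print(type(candidate).__name__, is_duck_array_like(candidate))
```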

@Li9htmare
Author

Thanks for looking into this, Stephen. I agree that the second solution is cleaner / more comprehensive. However, since xarray is a well-known library with a large user base, the second solution is more likely to break existing users' code.

Consider the following example (I admit people are unlikely to do this; it's just to illustrate a potentially problematic scenario):

import numpy as np
import xarray as xr

class DemoObj:
    def __add__(self, other):
        print(f'__add__ call: type={other.__class__}, value={other}')
        if not isinstance(other, np.ndarray):
            return NotImplemented
        return 1 + other


obj = DemoObj()
da = xr.DataArray(np.arange(8))

print('result: ', obj + da)

Output:

__add__ call: type=<class 'xarray.core.dataarray.DataArray'>, value=<xarray.DataArray (dim_0: 8)>
array([0, 1, 2, 3, 4, 5, 6, 7])
Dimensions without coordinates: dim_0

__add__ call: type=<class 'xarray.core.variable.Variable'>, value=<xarray.Variable (dim_0: 8)>
array([0, 1, 2, 3, 4, 5, 6, 7])

__add__ call: type=<class 'numpy.ndarray'>, value=[0 1 2 3 4 5 6 7]

result:  <xarray.DataArray (dim_0: 8)>
array([1, 2, 3, 4, 5, 6, 7, 8])
Dimensions without coordinates: dim_0

It "works" now, but it would fail with a TypeError if xarray started requiring __array_ufunc__ (or another array protocol) on the other operand.
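The failure mode can be sketched without xarray. In this toy version (OptInOnlyWrapper is hypothetical and only imitates an opt-in rule), DemoObj.__add__ declines the wrapper, the wrapper's __radd__ declines DemoObj, and Python raises TypeError when both sides return NotImplemented:

```python
import numpy as np


class DemoObj:
    # Same shape as the example above: only accepts raw ndarrays
    def __add__(self, other):
        if not isinstance(other, np.ndarray):
            return NotImplemented
        return 1 + other


class OptInOnlyWrapper:
    """Hypothetical wrapper that only handles operands exposing array protocols."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __radd__(self, other):
        if not (hasattr(other, "__array_function__") or hasattr(other, "__array_namespace__")):
            return NotImplemented
        return OptInOnlyWrapper(other + self.data)


def try_add(left, right):
    try:
        return left + right
    except TypeError as exc:
        return exc


outcome = try_add(DemoObj(), OptInOnlyWrapper(np.arange(8)))
print(type(outcome).__name__)
```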

Also, the checks may need to be more involved than just __array_function__ and __array_namespace__. Here are a few more examples:

Example 1:

import numpy as np
import xarray as xr
from numpy.lib.mixins import NDArrayOperatorsMixin


class DemoObj(NDArrayOperatorsMixin):
    def __array_ufunc__(
        self,
        ufunc,
        method,
        *inputs,
        **kwargs,
    ):
        # Only logs plain ufunc calls; implicitly returns None for every method
        if method == '__call__':
            print(f'calling {ufunc}: {inputs=}, {kwargs=}')

obj = DemoObj()
da = xr.DataArray(np.arange(8))

print('result: ', obj + da)

for name in ['__array_ufunc__', '__array_function__', '__array_namespace__']:
    print(f'hasattr({name}):', hasattr(obj, name))

Output:

calling <ufunc 'add'>: inputs=(<__main__.DemoObj object at 0x7f02926c9de0>, <xarray.DataArray (dim_0: 8)>
array([0, 1, 2, 3, 4, 5, 6, 7])
Dimensions without coordinates: dim_0), kwargs={}

result:  None

hasattr(__array_ufunc__): True
hasattr(__array_function__): False
hasattr(__array_namespace__): False

Example 2:

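import numpy as np
import xarray as xr


class DemoObj:
    # Accepts the optional dtype/copy arguments NumPy may pass to __array__
    def __array__(self, dtype=None, copy=None):
        return np.array(10, dtype=dtype)

obj = DemoObj()
da = xr.DataArray(np.arange(8))

print('result: ', obj + da)

for name in ['__array_ufunc__', '__array_function__', '__array_namespace__']:
    print(f'hasattr({name}):', hasattr(obj, name))

Output:
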
result:  <xarray.DataArray (dim_0: 8)>
array([10, 11, 12, 13, 14, 15, 16, 17])
Dimensions without coordinates: dim_0

hasattr(__array_ufunc__): False
hasattr(__array_function__): False
hasattr(__array_namespace__): False

Besides, I think approach 1 (checking whether __array_ufunc__ is None) is something xarray should respect even if it ends up adopting approach 2, i.e. the condition would be:

obj_supports_xarray_binary_op = (
    ...
    or (getattr(obj, '__array_ufunc__', None) is not None)
    or ...
)
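One way the combined condition could look, as a minimal sketch: obj_supports_xarray_binary_op is a hypothetical helper, the ... parts are filled in with only the three attributes discussed in this thread, and __array_ufunc__ = None is treated as a veto that wins over everything else. Note that an object like Example 2 (only __array__) and plain scalars are rejected, so the exact set of checks would still need discussion.

```python
import numpy as np
from numpy.lib.mixins import NDArrayOperatorsMixin

_MISSING = object()  # sentinel distinguishing "attribute absent" from "set to None"


def obj_supports_xarray_binary_op(obj) -> bool:
    """Hypothetical helper; not part of xarray's API."""
    array_ufunc = getattr(obj, "__array_ufunc__", _MISSING)
    if array_ufunc is None:
        # Explicit NumPy-style opt-out always wins (approach 1)
        return False
    # Otherwise require one of the array protocols (approach 2)
    return (
        array_ufunc is not _MISSING
        or hasattr(obj, "__array_function__")
        or hasattr(obj, "__array_namespace__")
    )


class OptedOut:
    __array_ufunc__ = None


class UfuncOnly(NDArrayOperatorsMixin):  # shaped like Example 1
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        return NotImplemented


class ArrayOnly:  # shaped like Example 2
    def __array__(self, dtype=None, copy=None):
        return np.array(10, dtype=dtype)


for obj in [np.arange(3), OptedOut(), UfuncOnly(), ArrayOnly(), 1]:
    print(type(obj).__name__, obj_supports_xarray_binary_op(obj))
```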

If you agree with the above, I will make a PR with only the __array_ufunc__ is None check, to give xarray users a way to bypass its binary ops when needed. You and the other xarray maintainers can then work out a further plan if you decide to go with the "explicit opt-in" approach, presumably emitting warnings for a few releases before making the actual change.

Please let me know your thoughts. Many thanks.
