diff --git a/docs/index.rst b/docs/index.rst index f929607..8d3b245 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -69,8 +69,12 @@ iterators .. autofunction:: azip .. autofunction:: azip_longest +exceptions +---------- +.. autofunction:: multi_error_defer_to +.. autofunction:: defer_to_cancelled + miscellaneous ------------- .. autoclass:: TaskStats :show-inheritance: -.. autofunction:: defer_to_cancelled diff --git a/src/trio_util/__init__.py b/src/trio_util/__init__.py index 0a954d5..e608dba 100644 --- a/src/trio_util/__init__.py +++ b/src/trio_util/__init__.py @@ -4,7 +4,7 @@ from ._async_itertools import azip, azip_longest from ._async_value import AsyncValue from ._awaitables import wait_all, wait_any -from ._exceptions import defer_to_cancelled +from ._exceptions import defer_to_cancelled, multi_error_defer_to from ._periodic import periodic from ._repeated_event import UnqueuedRepeatedEvent, MailboxRepeatedEvent from ._task_stats import TaskStats diff --git a/src/trio_util/_exceptions.py b/src/trio_util/_exceptions.py index c00e6f7..90cdc86 100644 --- a/src/trio_util/_exceptions.py +++ b/src/trio_util/_exceptions.py @@ -1,7 +1,8 @@ +from collections import defaultdict from contextlib import _GeneratorContextManager from functools import wraps from inspect import iscoroutinefunction -from typing import Type +from typing import Type, Dict, List import trio @@ -37,7 +38,6 @@ def helper(*args, **kwargs): return helper -@async_friendly_contextmanager def defer_to_cancelled(*args: Type[Exception]): """Context manager which defers MultiError exceptions to Cancelled. @@ -51,6 +51,8 @@ def defer_to_cancelled(*args: Type[Exception]): unhandled exception will occur. Often what is desired in this case is for the Cancelled exception alone to propagate to the cancel scope. + Equivalent to multi_error_defer_to(trio.Cancelled, *args). + :param args: One or more exception types which will defer to trio.Cancelled. By default, all exception types will be filtered. @@ -65,28 +67,84 @@ def defer_to_cancelled(*args: Type[Exception]): except Obstacle: # handle API exception (unless Cancelled raised simultaneously) ... + """ + return multi_error_defer_to(trio.Cancelled, *args) - TODO: Support consolidation of simultaneous user API exceptions - (i.e. MultiError without Cancelled). This would work by prioritized list - of exceptions to defer to. E.g. given:: - - [Cancelled, WheelObstruction, RangeObstruction] - - then:: - Cancelled + RangeObstruction => Cancelled - WheelObstruction + RangeObstruction => WheelObstruction +@async_friendly_contextmanager +def multi_error_defer_to(*privileged_types: Type[BaseException], + propagate_multi_error=True, + strict=True): + """ + Defer a trio.MultiError exception to a single, privileged exception + + In the scope of this context manager, a raised MultiError will be coalesced + into a single exception with the highest privilege if the following + criteria is met: + 1. every exception in the MultiError is an instance of one of the given + privileged types + additionally, by default with strict=True: + 2. there is a single candidate at the highest privilege after grouping + the exceptions by repr(). For example, this test fails if both + ValueError('foo') and ValueError('bar') are the most privileged. + + If the criteria are not met, by default the original MultiError is + propagated. Use propagate_multi_error=False to instead raise a + RuntimeError in these cases. + + Examples: + multi_error_defer_to(trio.Cancelled, MyException) + MultiError([Cancelled(), MyException()]) -> Cancelled() + MultiError([Cancelled(), MyException(), + MultiError([Cancelled(), Cancelled())]]) -> Cancelled() + MultiError([Cancelled(), MyException(), ValueError()]) -> *no change* + MultiError([MyException('foo'), MyException('foo')]) -> MyException('foo') + MultiError([MyException('foo'), MyException('bar')]) -> *no change* + + multi_error_defer_to(MyImportantException, trio.Cancelled, MyBaseException) + # where isinstance(MyDerivedException, MyBaseException) + # and isinstance(MyImportantException, MyBaseException) + MultiError([Cancelled(), MyDerivedException()]) -> Cancelled() + MultiError([MyImportantException(), Cancelled()]) -> MyImportantException() + + :param privileged_types: exception types from highest priority to lowest + :param propagate_multi_error: if false, raise a RuntimeError where a + MultiError would otherwise be leaked + :param strict: propagate MultiError if there are multiple output exceptions + to chose from (i.e. multiple exceptions objects with differing repr() + are instances of the privileged type). When combined with + propagate_multi_error=False, this case will raise a RuntimeError. """ try: yield - except trio.MultiError as e: - exceptions = e.exceptions - if not any(isinstance(exc, trio.Cancelled) for exc in exceptions): - raise - if not args: - raise trio.MultiError.filter( - lambda exc: exc if isinstance(exc, trio.Cancelled) else None, - e) - raise trio.MultiError.filter( - lambda exc: None if isinstance(exc, args) else exc, - e) + except trio.MultiError as root_multi_error: + # flatten the exceptions in the MultiError, grouping by repr() + multi_errors = [root_multi_error] + errors_by_repr = {} # exception_repr -> exception_object + while multi_errors: + multi_error = multi_errors.pop() + for e in multi_error.exceptions: + if isinstance(e, trio.MultiError): + multi_errors.append(e) + continue + if not isinstance(e, privileged_types): + # not in privileged list + if propagate_multi_error: + raise + raise RuntimeError('Unhandled trio.MultiError') + errors_by_repr[repr(e)] = e + # group the resulting errors by index in the privileged type list + # priority_index -> exception_object + errors_by_priority: Dict[int, List[BaseException]] = defaultdict(list) + for e in errors_by_repr.values(): + for priority, privileged_type in enumerate(privileged_types): + if isinstance(e, privileged_type): + errors_by_priority[priority].append(e) + # the error (or one of the errors) of the most privileged type wins + priority_errors = errors_by_priority[min(errors_by_priority)] + if strict and len(priority_errors) > 1: + # multiple unique exception objects at the same priority + if propagate_multi_error: + raise + raise RuntimeError('Unhandled trio.MultiError') + raise priority_errors[0] diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 8cf2ccd..dcbef75 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -1,7 +1,7 @@ import pytest import trio -from trio_util import defer_to_cancelled +from trio_util import defer_to_cancelled, multi_error_defer_to async def test_defer_to_cancelled_simple_exception(): @@ -30,6 +30,15 @@ async def test_defer_to_cancelled_deferred_multiple(): KeyError()]) +async def test_defer_to_cancelled_deferred_nested_multi_error(): + with pytest.raises(trio.Cancelled): + with defer_to_cancelled(ValueError): + raise trio.MultiError([ + ValueError(), + trio.MultiError([trio.Cancelled._create(), trio.Cancelled._create()]) + ]) + + async def test_defer_to_cancelled_not_deferred(): with pytest.raises(trio.MultiError): with defer_to_cancelled(ValueError): @@ -43,3 +52,88 @@ async def foo(): with pytest.raises(trio.Cancelled): await foo() + + +# TODO: parameterize tests + +async def test_multi_error_defer_simple_exception(): + with pytest.raises(ValueError): + with multi_error_defer_to(trio.Cancelled, ValueError): + raise ValueError + + +async def test_multi_error_defer_simple_cancel(): + with trio.move_on_after(1) as cancel_scope: + with multi_error_defer_to(trio.Cancelled, ValueError): + cancel_scope.cancel() + await trio.sleep(0) + + +async def test_multi_error_defer(): + with pytest.raises(trio.Cancelled): + with multi_error_defer_to(trio.Cancelled, ValueError): + raise trio.MultiError([trio.Cancelled._create(), ValueError()]) + + +async def test_multi_error_defer_nested(): + with pytest.raises(trio.Cancelled): + with multi_error_defer_to(trio.Cancelled, ValueError): + raise trio.MultiError([ + ValueError(), + trio.MultiError([trio.Cancelled._create(), trio.Cancelled._create()]) + ]) + + +async def test_multi_error_defer_derived(): + class MyExceptionBase(Exception): + pass + class MyException(MyExceptionBase): + pass + with pytest.raises(MyException): + with multi_error_defer_to(MyExceptionBase, trio.Cancelled): + raise trio.MultiError([trio.Cancelled._create(), MyException()]) + + +async def test_multi_error_defer_deferred_same_repr_strict(): + with pytest.raises(ValueError): + with multi_error_defer_to(ValueError, trio.Cancelled): + raise trio.MultiError([ValueError(), ValueError(), trio.Cancelled._create()]) + + +async def test_multi_error_defer_deferred_different_repr_strict(): + with pytest.raises(trio.MultiError): + with multi_error_defer_to(ValueError, trio.Cancelled): + raise trio.MultiError([ValueError('foo'), ValueError('bar'), trio.Cancelled._create()]) + + +async def test_multi_error_defer_deferred_different_repr_strict_no_propagate(): + with pytest.raises(RuntimeError): + with multi_error_defer_to(ValueError, trio.Cancelled, propagate_multi_error=False): + raise trio.MultiError([ValueError('foo'), ValueError('bar'), trio.Cancelled._create()]) + + +async def test_multi_error_defer_deferred_different_repr_no_strict(): + with pytest.raises(ValueError): + with multi_error_defer_to(ValueError, trio.Cancelled, strict=False): + raise trio.MultiError([ValueError('foo'), ValueError('bar'), trio.Cancelled._create()]) + + +async def test_multi_error_defer_no_match(): + with pytest.raises(trio.MultiError): + with multi_error_defer_to(trio.Cancelled, ValueError): + raise trio.MultiError([trio.Cancelled._create(), KeyError()]) + + +async def test_multi_error_defer_no_match_no_propagate(): + with pytest.raises(RuntimeError): + with multi_error_defer_to(trio.Cancelled, ValueError, propagate_multi_error=False): + raise trio.MultiError([trio.Cancelled._create(), KeyError()]) + + +async def test_multi_error_defer_decorating_async(): + @multi_error_defer_to(trio.Cancelled, ValueError) + async def foo(): + raise trio.MultiError([trio.Cancelled._create(), ValueError()]) + + with pytest.raises(trio.Cancelled): + await foo()