Skip to content

Commit

Permalink
Add type overloads for recv and recv_streaming.
Browse files Browse the repository at this point in the history
Fix #1578.
  • Loading branch information
aaugustin committed Jan 23, 2025
1 parent 8f12d8f commit 4e30662
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 4 deletions.
26 changes: 25 additions & 1 deletion src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import uuid
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterable, Mapping
from types import TracebackType
from typing import Any, cast
from typing import Any, Literal, cast, overload

from ..exceptions import (
ConcurrencyError,
Expand Down Expand Up @@ -243,6 +243,15 @@ async def __aiter__(self) -> AsyncIterator[Data]:
except ConnectionClosedOK:
return

@overload
async def recv(self, decode: Literal[True] = True) -> str: ...

@overload
async def recv(self, decode: Literal[False] = False) -> bytes: ...

@overload
async def recv(self, decode: bool | None = None) -> Data: ...

async def recv(self, decode: bool | None = None) -> Data:
"""
Receive the next message.
Expand Down Expand Up @@ -312,6 +321,21 @@ async def recv(self, decode: bool | None = None) -> Data:
await asyncio.shield(self.connection_lost_waiter)
raise self.protocol.close_exc from self.recv_exc

@overload
async def recv_streaming(
self, decode: Literal[True] = True
) -> AsyncIterator[str]: ...

@overload
async def recv_streaming(
self, decode: Literal[False] = False
) -> AsyncIterator[bytes]: ...

@overload
async def recv_streaming(
self, decode: bool | None = None
) -> AsyncIterator[Data]: ...

async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]:
"""
Receive the next message frame by frame.
Expand Down
22 changes: 21 additions & 1 deletion src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import codecs
import collections
from collections.abc import AsyncIterator, Iterable
from typing import Any, Callable, Generic, TypeVar
from typing import Any, Callable, Generic, Literal, TypeVar, overload

from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
Expand Down Expand Up @@ -116,6 +116,15 @@ def __init__( # pragma: no cover
# This flag marks the end of the connection.
self.closed = False

@overload
async def get(self, decode: Literal[True] = True) -> str: ...

@overload
async def get(self, decode: Literal[False] = False) -> bytes: ...

@overload
async def get(self, decode: bool | None = None) -> Data: ...

async def get(self, decode: bool | None = None) -> Data:
"""
Read the next message.
Expand Down Expand Up @@ -176,6 +185,17 @@ async def get(self, decode: bool | None = None) -> Data:
else:
return data

@overload
async def get_iter(self, decode: Literal[True] = True) -> AsyncIterator[str]: ...

@overload
async def get_iter(
self, decode: Literal[False] = False
) -> AsyncIterator[bytes]: ...

@overload
async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ...

async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
"""
Stream the next message.
Expand Down
26 changes: 25 additions & 1 deletion src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uuid
from collections.abc import Iterable, Iterator, Mapping
from types import TracebackType
from typing import Any
from typing import Any, Literal, overload

from ..exceptions import (
ConcurrencyError,
Expand Down Expand Up @@ -241,6 +241,21 @@ def __iter__(self) -> Iterator[Data]:
except ConnectionClosedOK:
return

@overload
def recv(
self, timeout: float | None = None, decode: Literal[True] = True
) -> str: ...

@overload
def recv(
self, timeout: float | None = None, decode: Literal[False] = False
) -> bytes: ...

@overload
def recv(
self, timeout: float | None = None, decode: bool | None = None
) -> Data: ...

def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Receive the next message.
Expand Down Expand Up @@ -311,6 +326,15 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data
self.recv_events_thread.join()
raise self.protocol.close_exc from self.recv_exc

@overload
def recv_streaming(self, decode: Literal[True] = True) -> Iterator[str]: ...

@overload
def recv_streaming(self, decode: Literal[False] = False) -> Iterator[bytes]: ...

@overload
def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: ...

def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]:
"""
Receive the next message frame by frame.
Expand Down
24 changes: 23 additions & 1 deletion src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import codecs
import queue
import threading
from typing import Any, Callable, Iterable, Iterator
from typing import Any, Callable, Iterable, Iterator, Literal, overload

from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
Expand Down Expand Up @@ -110,6 +110,19 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
for frame in queued: # pragma: no cover
self.frames.put(frame)

@overload
def get(
self, timeout: float | None = None, decode: Literal[True] = True
) -> str: ...

@overload
def get(
self, timeout: float | None = None, decode: Literal[False] = False
) -> bytes: ...

@overload
def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: ...

def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
"""
Read the next message.
Expand Down Expand Up @@ -181,6 +194,15 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
else:
return data

@overload
def get_iter(self, decode: Literal[True] = True) -> Iterator[str]: ...

@overload
def get_iter(self, decode: Literal[False] = False) -> Iterator[bytes]: ...

@overload
def get_iter(self, decode: bool | None = None) -> Iterator[Data]: ...

def get_iter(self, decode: bool | None = None) -> Iterator[Data]:
"""
Stream the next message.
Expand Down

0 comments on commit 4e30662

Please sign in to comment.