Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add/complete type annotations #193

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ lint:
$(PYTHON) -m pylint trio_websocket/ tests/ autobahn/ examples/

typecheck:
$(PYTHON) -m mypy --explicit-package-bases trio_websocket tests autobahn examples
$(PYTHON) -m mypy

publish:
rm -fr build dist .egg trio_websocket.egg-info
Expand Down
17 changes: 10 additions & 7 deletions autobahn/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@
logger = logging.getLogger('client')


async def get_case_count(url):
async def get_case_count(url: str) -> int:
url = url + '/getCaseCount'
async with open_websocket_url(url) as conn:
case_count = await conn.get_message()
logger.info('Case count=%s', case_count)
return int(case_count)


async def get_case_info(url, case):
async def get_case_info(url: str, case: str) -> object:
url = f'{url}/getCaseInfo?case={case}'
async with open_websocket_url(url) as conn:
return json.loads(await conn.get_message())


async def run_case(url, case):
async def run_case(url: str, case: str) -> None:
url = f'{url}/runCase?case={case}&agent={AGENT}'
try:
async with open_websocket_url(url, max_message_size=MAX_MESSAGE_SIZE) as conn:
Expand All @@ -42,15 +42,15 @@ async def run_case(url, case):
pass


async def update_reports(url):
async def update_reports(url: str) -> None:
url = f'{url}/updateReports?agent={AGENT}'
async with open_websocket_url(url) as conn:
# This command runs as soon as we connect to it, so we don't need to
# send any messages.
pass


async def run_tests(args):
async def run_tests(args: argparse.Namespace) -> None:
logger = logging.getLogger('trio-websocket')
if args.debug_cases:
# Don't fetch case count when debugging a subset of test cases. It adds
Expand All @@ -62,7 +62,10 @@ async def run_tests(args):
test_cases = list(range(1, case_count + 1))
exception_cases = []
for case in test_cases:
case_id = (await get_case_info(args.url, case))['id']
result = await get_case_info(args.url, case)
assert isinstance(result, dict)
case_id = result['id']
assert isinstance(case_id, int)
if case_count:
logger.info("Running test case %s (%d of %d)", case_id, case, case_count)
else:
Expand All @@ -82,7 +85,7 @@ async def run_tests(args):
sys.exit(1)


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Autobahn client for'
' trio-websocket')
Expand Down
6 changes: 3 additions & 3 deletions autobahn/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
connection_count = 0


async def main():
async def main() -> None:
''' Main entry point. '''
logger.info('Starting websocket server on ws://%s:%d', BIND_IP, BIND_PORT)
await serve_websocket(handler, BIND_IP, BIND_PORT, ssl_context=None,
max_message_size=MAX_MESSAGE_SIZE)


async def handler(request: WebSocketRequest):
async def handler(request: WebSocketRequest) -> None:
''' Reverse incoming websocket messages and send them back. '''
global connection_count # pylint: disable=global-statement
connection_count += 1
Expand All @@ -46,7 +46,7 @@ async def handler(request: WebSocketRequest):
logger.exception(' runtime exception handling connection #%d', connection_count)


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Autobahn server for'
' trio-websocket')
Expand Down
30 changes: 19 additions & 11 deletions examples/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,23 @@
import ssl
import sys
import urllib.parse
from typing import NoReturn

import trio
from trio_websocket import open_websocket_url, ConnectionClosed, HandshakeError
from trio_websocket import (
open_websocket_url,
ConnectionClosed,
HandshakeError,
WebSocketConnection,
CloseReason,
)


logging.basicConfig(level=logging.DEBUG)
here = pathlib.Path(__file__).parent


def commands():
def commands() -> None:
''' Print the supported commands. '''
print('Commands: ')
print('send <MESSAGE> -> send message')
Expand All @@ -29,7 +36,7 @@ def commands():
print()


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Example trio-websocket client')
parser.add_argument('--heartbeat', action='store_true',
Expand All @@ -38,7 +45,7 @@ def parse_args():
return parser.parse_args()


async def main(args):
async def main(args: argparse.Namespace) -> bool:
''' Main entry point, returning False in the case of logged error. '''
if urllib.parse.urlsplit(args.url).scheme == 'wss':
# Configure SSL context to handle our self-signed certificate. Most
Expand All @@ -59,9 +66,10 @@ async def main(args):
except HandshakeError as e:
logging.error('Connection attempt failed: %s', e)
return False
return True


async def handle_connection(ws, use_heartbeat):
async def handle_connection(ws: WebSocketConnection, use_heartbeat: bool) -> None:
''' Handle the connection. '''
logging.debug('Connected!')
try:
Expand All @@ -71,11 +79,12 @@ async def handle_connection(ws, use_heartbeat):
nursery.start_soon(get_commands, ws)
nursery.start_soon(get_messages, ws)
except ConnectionClosed as cc:
assert isinstance(cc.reason, CloseReason)
reason = '<no reason>' if cc.reason.reason is None else f'"{cc.reason.reason}"'
print(f'Closed: {cc.reason.code}/{cc.reason.name} {reason}')


async def heartbeat(ws, timeout, interval):
async def heartbeat(ws: WebSocketConnection, timeout: float, interval: float) -> NoReturn:
'''
Send periodic pings on WebSocket ``ws``.

Expand All @@ -99,11 +108,10 @@ async def heartbeat(ws, timeout, interval):
await trio.sleep(interval)


async def get_commands(ws):
async def get_commands(ws: WebSocketConnection) -> None:
''' In a loop: get a command from the user and execute it. '''
while True:
cmd = await trio.to_thread.run_sync(input, 'cmd> ',
cancellable=True)
cmd = await trio.to_thread.run_sync(input, 'cmd> ')
if cmd.startswith('ping'):
payload = cmd[5:].encode('utf8') or None
await ws.ping(payload)
Expand All @@ -123,11 +131,11 @@ async def get_commands(ws):
await trio.sleep(0.25)


async def get_messages(ws):
async def get_messages(ws: WebSocketConnection) -> None:
''' In a loop: get a WebSocket message and print it out. '''
while True:
message = await ws.get_message()
print(f'message: {message}')
print(f'message: {message!r}')


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion examples/generate-cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import trustme

def main():
def main() -> None:
here = pathlib.Path(__file__).parent
ca_path = here / 'fake.ca.pem'
server_path = here / 'fake.server.pem'
Expand Down
8 changes: 4 additions & 4 deletions examples/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@
import ssl

import trio
from trio_websocket import serve_websocket, ConnectionClosed
from trio_websocket import serve_websocket, ConnectionClosed, WebSocketRequest


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
here = pathlib.Path(__file__).parent


def parse_args():
def parse_args() -> argparse.Namespace:
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Example trio-websocket client')
parser.add_argument('--ssl', action='store_true', help='Use SSL')
Expand All @@ -32,7 +32,7 @@ def parse_args():
return parser.parse_args()


async def main(args):
async def main(args: argparse.Namespace) -> None:
''' Main entry point. '''
logging.info('Starting websocket server…')
if args.ssl:
Expand All @@ -48,7 +48,7 @@ async def main(args):
await serve_websocket(handler, host, args.port, ssl_context)


async def handler(request):
async def handler(request: WebSocketRequest) -> None:
''' Reverse incoming websocket messages and send them back. '''
logging.info('Handler starting on path "%s"', request.path)
ws = await request.accept()
Expand Down
13 changes: 13 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[tool.mypy]
explicit_package_bases = true
files = ["trio_websocket", "tests", "autobahn", "examples"]
show_column_numbers = true
show_error_codes = true
show_traceback = true
disallow_any_decorated = true
disallow_any_unimported = true
ignore_missing_imports = true
local_partial_types = true
no_implicit_optional = true
strict = true
warn_unreachable = true
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
'Programming Language :: Python :: 3.12',
'Programming Language :: Python :: Implementation :: CPython',
'Programming Language :: Python :: Implementation :: PyPy',
'Typing :: Typed',
],
python_requires=">=3.8",
keywords='websocket client server trio',
packages=find_packages(exclude=['docs', 'examples', 'tests']),
package_data={"trio-websocket": ["py.typed"]},
install_requires=[
'exceptiongroup; python_version<"3.11"',
'trio>=0.11',
Expand Down
Loading
Loading