Skip to content

Commit

Permalink
Improve previous commit.
Browse files Browse the repository at this point in the history
* Require fullmatch instead of match — this avoids a vulnerability.
* Shorten code and tweak to match my preferred style.
* Add changelog.
  • Loading branch information
aaugustin committed Jan 19, 2025
1 parent 7e617b2 commit 7de24bd
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 34 deletions.
6 changes: 6 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ notice.

*In development*

New features
............

* Added support for regular expressions in the ``origins`` argument of
:func:`~asyncio.server.serve`.

Bug fixes
.........

Expand Down
9 changes: 5 additions & 4 deletions src/websockets/asyncio/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,11 @@ def handler(websocket):
See :meth:`~asyncio.loop.create_server` for details.
port: TCP port the server listens on.
See :meth:`~asyncio.loop.create_server` for details.
origins: Acceptable values of the ``Origin`` header, including regular
expressions, for defending against Cross-Site WebSocket Hijacking
attacks. Include :obj:`None` in the list if the lack of an origin
is acceptable.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Values can be
:class:`str` to test for an exact match or regular expressions
compiled by :func:`re.compile` to test against a pattern. Include
:obj:`None` in the list if the lack of an origin is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
Expand Down
25 changes: 12 additions & 13 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,11 @@ class ServerProtocol(Protocol):
Sans-I/O implementation of a WebSocket server connection.
Args:
origins: Acceptable values of the ``Origin`` header, including regular
expressions; include :obj:`None` in the list if the lack of an origin
is acceptable. This is useful for defending against Cross-Site WebSocket
origins: Acceptable values of the ``Origin`` header. Values can be
:class:`str` to test for an exact match or regular expressions
compiled by :func:`re.compile` to test against a pattern. Include
:obj:`None` in the list if the lack of an origin is acceptable.
This is useful for defending against Cross-Site WebSocket
Hijacking attacks.
extensions: List of supported extensions, in order in which they
should be tried.
Expand Down Expand Up @@ -310,17 +312,14 @@ def process_origin(self, headers: Headers) -> Origin | None:
if origin is not None:
origin = cast(Origin, origin)
if self.origins is not None:
valid = False
for acceptable_origin_or_regex in self.origins:
if isinstance(acceptable_origin_or_regex, re.Pattern):
# `str(origin)` is needed for compatibility
# between `Pattern.match(string=...)` and `origin`.
valid = acceptable_origin_or_regex.match(str(origin)) is not None
else:
valid = acceptable_origin_or_regex == origin
if valid:
for origin_or_regex in self.origins:
if origin_or_regex == origin or (
isinstance(origin_or_regex, re.Pattern)
and origin is not None
and origin_or_regex.fullmatch(origin) is not None
):
break
if not valid:
else:
raise InvalidOrigin(origin)
return origin

Expand Down
9 changes: 5 additions & 4 deletions src/websockets/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,11 @@ def handler(websocket):
You may call :func:`socket.create_server` to create a suitable TCP
socket.
ssl: Configuration for enabling TLS on the connection.
origins: Acceptable values of the ``Origin`` header, including regular
expressions, for defending against Cross-Site WebSocket Hijacking
attacks. Include :obj:`None` in the list if the lack of an origin
is acceptable.
origins: Acceptable values of the ``Origin`` header, for defending
against Cross-Site WebSocket Hijacking attacks. Values can be
:class:`str` to test for an exact match or regular expressions
compiled by :func:`re.compile` to test against a pattern. Include
:obj:`None` in the list if the lack of an origin is acceptable.
extensions: List of supported extensions, in order in which they
should be negotiated and run.
subprotocols: List of supported subprotocols, in order of decreasing
Expand Down
37 changes: 24 additions & 13 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ def test_supported_origin(self):
self.assertEqual(server.origin, "https://other.example.com")

def test_unsupported_origin(self):
"""Handshake succeeds when checking origins and the origin is unsupported."""
"""Handshake fails when checking origins and the origin is unsupported."""
server = ServerProtocol(
origins=["https://example.com", "https://other.example.com"]
)
Expand All @@ -624,13 +624,10 @@ def test_unsupported_origin(self):
"invalid Origin header: https://original.example.com",
)

def test_supported_origin_by_regex(self):
"""
Handshake succeeds when checking origins and the origin is supported
by a regular expression.
"""
def test_supported_origin_regex(self):
"""Handshake succeeds when checking origins and the origin is supported."""
server = ServerProtocol(
origins=["https://example.com", re.compile(r"https://other.*")]
origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")]
)
request = make_request()
request.headers["Origin"] = "https://other.example.com"
Expand All @@ -640,13 +637,10 @@ def test_supported_origin_by_regex(self):
self.assertHandshakeSuccess(server)
self.assertEqual(server.origin, "https://other.example.com")

def test_unsupported_origin_by_regex(self):
"""
Handshake succeeds when checking origins and the origin is unsupported
by a regular expression.
"""
def test_unsupported_origin_regex(self):
"""Handshake fails when checking origins and the origin is unsupported."""
server = ServerProtocol(
origins=["https://example.com", re.compile(r"https://other.*")]
origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")]
)
request = make_request()
request.headers["Origin"] = "https://original.example.com"
Expand All @@ -660,6 +654,23 @@ def test_unsupported_origin_by_regex(self):
"invalid Origin header: https://original.example.com",
)

def test_partial_match_origin_regex(self):
"""Handshake fails when checking origins and the origin a partial match."""
server = ServerProtocol(
origins=[re.compile(r"https://(?!original)[a-z]+\.example\.com")]
)
request = make_request()
request.headers["Origin"] = "https://other.example.com.hacked"
response = server.accept(request)
server.send_response(response)

self.assertEqual(response.status_code, 403)
self.assertHandshakeError(
server,
InvalidOrigin,
"invalid Origin header: https://other.example.com.hacked",
)

def test_no_origin_accepted(self):
"""Handshake succeeds when the lack of an origin is accepted."""
server = ServerProtocol(origins=[None])
Expand Down

0 comments on commit 7de24bd

Please sign in to comment.