Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide request_class, response_class for httpx.Client #3199

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class BaseClient:
def __init__(
self,
*,
request_class: type[Request] = Request,
response_class: type[Response] = Response,
auth: AuthTypes | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
Expand All @@ -178,6 +180,9 @@ def __init__(
) -> None:
event_hooks = {} if event_hooks is None else event_hooks

self._request_class = request_class
self._response_class = response_class

self._base_url = self._enforce_trailing_slash(URL(base_url))

self._auth = self._build_auth(auth)
Expand All @@ -195,6 +200,14 @@ def __init__(
self._default_encoding = default_encoding
self._state = ClientState.UNOPENED

@property
def request_class(self) -> type[Request]:
return self._request_class

@property
def response_class(self) -> type[Response]:
return self._response_class

@property
def is_closed(self) -> bool:
"""
Expand Down Expand Up @@ -356,7 +369,7 @@ def build_request(
else Timeout(timeout)
)
extensions = dict(**extensions, timeout=timeout.as_dict())
return Request(
return self.request_class(
method,
url,
content=content,
Expand Down Expand Up @@ -463,7 +476,7 @@ def _build_redirect_request(self, request: Request, response: Response) -> Reque
headers = self._redirect_headers(request, url, method)
stream = self._redirect_stream(request, method)
cookies = Cookies(self.cookies)
return Request(
return self.request_class(
method=method,
url=url,
headers=headers,
Expand Down Expand Up @@ -629,6 +642,8 @@ class Client(BaseClient):
def __init__(
self,
*,
request_class: type[Request] = Request,
response_class: type[Response] = Response,
auth: AuthTypes | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
Expand All @@ -652,6 +667,8 @@ def __init__(
default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
) -> None:
super().__init__(
request_class=request_class,
response_class=response_class,
auth=auth,
params=params,
headers=headers,
Expand Down Expand Up @@ -748,6 +765,7 @@ def _init_transport(
http2=http2,
limits=limits,
trust_env=trust_env,
response_class=self.response_class,
)

def _init_proxy_transport(
Expand All @@ -768,6 +786,7 @@ def _init_proxy_transport(
limits=limits,
trust_env=trust_env,
proxy=proxy,
response_class=self.response_class,
)

def _transport_for_url(self, url: URL) -> BaseTransport:
Expand Down Expand Up @@ -1376,6 +1395,8 @@ class AsyncClient(BaseClient):
def __init__(
self,
*,
request_class: type[Request] = Request,
response_class: type[Response] = Response,
auth: AuthTypes | None = None,
params: QueryParamTypes | None = None,
headers: HeaderTypes | None = None,
Expand All @@ -1399,6 +1420,8 @@ def __init__(
default_encoding: str | typing.Callable[[bytes], str] = "utf-8",
) -> None:
super().__init__(
request_class=request_class,
response_class=response_class,
auth=auth,
params=params,
headers=headers,
Expand Down Expand Up @@ -1495,6 +1518,7 @@ def _init_transport(
http2=http2,
limits=limits,
trust_env=trust_env,
response_class=self.response_class,
)

def _init_proxy_transport(
Expand All @@ -1515,6 +1539,7 @@ def _init_proxy_transport(
limits=limits,
trust_env=trust_env,
proxy=proxy,
response_class=self.response_class,
)

def _transport_for_url(self, url: URL) -> AsyncBaseTransport:
Expand Down
5 changes: 4 additions & 1 deletion httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def __setstate__(self, state: dict[str, typing.Any]) -> None:
self.stream = UnattachedStream()


_ResponseT = typing.TypeVar("_ResponseT", bound="Response")


class Response:
def __init__(
self,
Expand Down Expand Up @@ -725,7 +728,7 @@ def has_redirect_location(self) -> bool:
and "Location" in self.headers
)

def raise_for_status(self) -> Response:
def raise_for_status(self: _ResponseT) -> _ResponseT:
"""
Raise the `HTTPStatusError` if one occurred.
"""
Expand Down
10 changes: 8 additions & 2 deletions httpx/_transports/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def __init__(
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
response_class: type[Response] = Response,
) -> None:
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
Expand Down Expand Up @@ -201,6 +202,8 @@ def __init__(
f" but got {proxy.url.scheme!r}."
)

self._response_class = response_class

def __enter__(self: T) -> T: # Use generics for subclass support.
self._pool.__enter__()
return self
Expand Down Expand Up @@ -237,7 +240,7 @@ def handle_request(

assert isinstance(resp.stream, typing.Iterable)

return Response(
return self._response_class(
status_code=resp.status,
headers=resp.headers,
stream=ResponseStream(resp.stream),
Expand Down Expand Up @@ -276,6 +279,7 @@ def __init__(
local_address: str | None = None,
retries: int = 0,
socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
response_class: type[Response] = Response,
) -> None:
ssl_context = create_ssl_context(verify=verify, cert=cert, trust_env=trust_env)
proxy = Proxy(url=proxy) if isinstance(proxy, (str, URL)) else proxy
Expand Down Expand Up @@ -342,6 +346,8 @@ def __init__(
" but got {proxy.url.scheme!r}."
)

self._response_class = response_class

async def __aenter__(self: A) -> A: # Use generics for subclass support.
await self._pool.__aenter__()
return self
Expand Down Expand Up @@ -378,7 +384,7 @@ async def handle_async_request(

assert isinstance(resp.stream, typing.AsyncIterable)

return Response(
return self._response_class(
status_code=resp.status,
headers=resp.headers,
stream=AsyncResponseStream(resp.stream),
Expand Down
60 changes: 60 additions & 0 deletions tests/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,3 +460,63 @@ def cp1252_but_no_content_type(request):
assert response.reason_phrase == "OK"
assert response.encoding == "ISO-8859-1"
assert response.text == text


def test_client_request_class():
class Request(httpx.Request):
def __init__(self, *args, **kwargs):
kwargs["content"] = "foobar"
super().__init__(*args, **kwargs)

class Client(httpx.Client):
request_class = Request

class AsyncClient(httpx.AsyncClient):
request_class = Request

request = Client().build_request("GET", "http://www.example.com/")
assert isinstance(request, Request)
assert request.content == b"foobar"

request = AsyncClient().build_request("GET", "http://www.example.com/")
assert isinstance(request, Request)
assert request.content == b"foobar"

with httpx.Client(request_class=Request) as client:
request = client.build_request("GET", "http://www.example.com/")
assert isinstance(request, Request)
assert request.content == b"foobar"


@pytest.mark.anyio
async def test_client_response_class(server):
class Response(httpx.Response):
def iter_bytes(self, chunk_size: int | None = None) -> typing.Iterator[bytes]:
yield b"foobar"

class Client(httpx.Client):
response_class = Response

class AsyncResponse(httpx.Response):
async def aiter_bytes(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
yield b"foobar"

class AsyncClient(httpx.AsyncClient):
response_class = AsyncResponse

with Client() as client:
response = client.get(server.url)
assert isinstance(response, Response)
assert response.read() == b"foobar"

async with AsyncClient() as async_client:
response = await async_client.get(server.url)
assert isinstance(response, AsyncResponse)
assert await response.aread() == b"foobar"

with httpx.Client(response_class=Response) as httpx_client:
response = httpx_client.get(server.url)
assert isinstance(response, Response)
assert response.read() == b"foobar"