22
33import ssl
44import sys
5- from types import TracebackType
6- from typing import AsyncIterable , AsyncIterator , Iterable
5+ import types
6+ import typing
77
88from .._backends .auto import AutoBackend
99from .._backends .base import SOCKET_OPTION , AsyncNetworkBackend
1010from .._exceptions import ConnectionNotAvailable , UnsupportedProtocol
11- from .._models import Origin , Request , Response
11+ from .._models import Origin , Proxy , Request , Response
1212from .._synchronization import AsyncEvent , AsyncShieldCancellation , AsyncThreadLock
1313from .connection import AsyncHTTPConnection
1414from .interfaces import AsyncConnectionInterface , AsyncRequestInterface
@@ -48,6 +48,7 @@ class AsyncConnectionPool(AsyncRequestInterface):
4848 def __init__ (
4949 self ,
5050 ssl_context : ssl .SSLContext | None = None ,
51+ proxy : Proxy | None = None ,
5152 max_connections : int | None = 10 ,
5253 max_keepalive_connections : int | None = None ,
5354 keepalive_expiry : float | None = None ,
@@ -57,7 +58,7 @@ def __init__(
5758 local_address : str | None = None ,
5859 uds : str | None = None ,
5960 network_backend : AsyncNetworkBackend | None = None ,
60- socket_options : Iterable [SOCKET_OPTION ] | None = None ,
61+ socket_options : typing . Iterable [SOCKET_OPTION ] | None = None ,
6162 ) -> None :
6263 """
6364 A connection pool for making HTTP requests.
@@ -89,7 +90,7 @@ def __init__(
8990 in the TCP socket when the connection was established.
9091 """
9192 self ._ssl_context = ssl_context
92-
93+ self . _proxy = proxy
9394 self ._max_connections = (
9495 sys .maxsize if max_connections is None else max_connections
9596 )
@@ -125,6 +126,45 @@ def __init__(
125126 self ._optional_thread_lock = AsyncThreadLock ()
126127
127128 def create_connection (self , origin : Origin ) -> AsyncConnectionInterface :
129+ if self ._proxy is not None :
130+ if self ._proxy .url .scheme in (b"socks5" , b"socks5h" ):
131+ from .socks_proxy import AsyncSocks5Connection
132+
133+ return AsyncSocks5Connection (
134+ proxy_origin = self ._proxy .url .origin ,
135+ proxy_auth = self ._proxy .auth ,
136+ remote_origin = origin ,
137+ ssl_context = self ._ssl_context ,
138+ keepalive_expiry = self ._keepalive_expiry ,
139+ http1 = self ._http1 ,
140+ http2 = self ._http2 ,
141+ network_backend = self ._network_backend ,
142+ )
143+ elif origin .scheme == b"http" :
144+ from .http_proxy import AsyncForwardHTTPConnection
145+
146+ return AsyncForwardHTTPConnection (
147+ proxy_origin = self ._proxy .url .origin ,
148+ proxy_headers = self ._proxy .headers ,
149+ proxy_ssl_context = self ._proxy .ssl_context ,
150+ remote_origin = origin ,
151+ keepalive_expiry = self ._keepalive_expiry ,
152+ network_backend = self ._network_backend ,
153+ )
154+ from .http_proxy import AsyncTunnelHTTPConnection
155+
156+ return AsyncTunnelHTTPConnection (
157+ proxy_origin = self ._proxy .url .origin ,
158+ proxy_headers = self ._proxy .headers ,
159+ proxy_ssl_context = self ._proxy .ssl_context ,
160+ remote_origin = origin ,
161+ ssl_context = self ._ssl_context ,
162+ keepalive_expiry = self ._keepalive_expiry ,
163+ http1 = self ._http1 ,
164+ http2 = self ._http2 ,
165+ network_backend = self ._network_backend ,
166+ )
167+
128168 return AsyncHTTPConnection (
129169 origin = origin ,
130170 ssl_context = self ._ssl_context ,
@@ -217,7 +257,7 @@ async def handle_async_request(self, request: Request) -> Response:
217257
218258 # Return the response. Note that in this case we still have to manage
219259 # the point at which the response is closed.
220- assert isinstance (response .stream , AsyncIterable )
260+ assert isinstance (response .stream , typing . AsyncIterable )
221261 return Response (
222262 status = response .status ,
223263 headers = response .headers ,
@@ -319,7 +359,7 @@ async def __aexit__(
319359 self ,
320360 exc_type : type [BaseException ] | None = None ,
321361 exc_value : BaseException | None = None ,
322- traceback : TracebackType | None = None ,
362+ traceback : types . TracebackType | None = None ,
323363 ) -> None :
324364 await self .aclose ()
325365
@@ -349,7 +389,7 @@ def __repr__(self) -> str:
349389class PoolByteStream :
350390 def __init__ (
351391 self ,
352- stream : AsyncIterable [bytes ],
392+ stream : typing . AsyncIterable [bytes ],
353393 pool_request : AsyncPoolRequest ,
354394 pool : AsyncConnectionPool ,
355395 ) -> None :
@@ -358,7 +398,7 @@ def __init__(
358398 self ._pool = pool
359399 self ._closed = False
360400
361- async def __aiter__ (self ) -> AsyncIterator [bytes ]:
401+ async def __aiter__ (self ) -> typing . AsyncIterator [bytes ]:
362402 try :
363403 async for part in self ._stream :
364404 yield part
0 commit comments