|
1 | 1 | import asyncio |
| 2 | +import contextlib |
2 | 3 | import io |
3 | 4 | import os |
4 | 5 | import socket |
@@ -54,8 +55,11 @@ def __init__( |
54 | 55 | proxies_config=None, |
55 | 56 | connector_args=None, |
56 | 57 | ): |
| 58 | + self._exit_stack = contextlib.AsyncExitStack() |
| 59 | + |
57 | 60 | # TODO: handle socket_options |
58 | | - self._session: Optional[aiohttp.ClientSession] = None |
| 61 | + # keep track of sessions by proxy url (if any) |
| 62 | + self._sessions: Dict[Optional[str], aiohttp.ClientSession] = {} |
59 | 63 | self._verify = verify |
60 | 64 | self._proxy_config = ProxyConfiguration( |
61 | 65 | proxies=proxies, proxies_settings=proxies_config |
@@ -93,53 +97,17 @@ def __init__( |
93 | 97 | # it also pools by host so we don't need a manager, and can pass proxy via |
94 | 98 | # request so don't need proxy manager |
95 | 99 |
|
96 | | - ssl_context = None |
97 | | - if bool(verify): |
98 | | - if proxies: |
99 | | - proxies_settings = self._proxy_config.settings |
100 | | - ssl_context = self._setup_proxy_ssl_context(proxies_settings) |
101 | | - # TODO: add support for |
102 | | - # proxies_settings.get('proxy_use_forwarding_for_https') |
103 | | - else: |
104 | | - ssl_context = self._get_ssl_context() |
105 | | - |
106 | | - # inline self._setup_ssl_cert |
107 | | - ca_certs = get_cert_path(verify) |
108 | | - if ca_certs: |
109 | | - ssl_context.load_verify_locations(ca_certs, None, None) |
110 | | - |
111 | | - self._create_connector = lambda: aiohttp.TCPConnector( |
112 | | - limit=max_pool_connections, |
113 | | - verify_ssl=bool(verify), |
114 | | - ssl=ssl_context, |
115 | | - **self._connector_args |
116 | | - ) |
117 | | - self._connector = None |
118 | | - |
119 | 100 | async def __aenter__(self): |
120 | | - assert not self._session and not self._connector |
| 101 | + assert not self._sessions |
121 | 102 |
|
122 | | - self._connector = self._create_connector() |
123 | | - |
124 | | - self._session = aiohttp.ClientSession( |
125 | | - connector=self._connector, |
126 | | - timeout=self._timeout, |
127 | | - skip_auto_headers={'CONTENT-TYPE'}, |
128 | | - auto_decompress=False, |
129 | | - ) |
130 | 103 | return self |
131 | 104 |
|
132 | 105 | async def __aexit__(self, exc_type, exc_val, exc_tb): |
133 | | - if self._session: |
134 | | - await self._session.__aexit__(exc_type, exc_val, exc_tb) |
135 | | - self._session = None |
136 | | - self._connector = None |
| 106 | + self._sessions.clear() |
| 107 | + await self._exit_stack.aclose() |
137 | 108 |
|
138 | 109 | def _get_ssl_context(self): |
139 | | - ssl_context = create_urllib3_context() |
140 | | - if self._cert_file: |
141 | | - ssl_context.load_cert_chain(self._cert_file, self._key_file) |
142 | | - return ssl_context |
| 110 | + return create_urllib3_context() |
143 | 111 |
|
144 | 112 | def _setup_proxy_ssl_context(self, proxy_url): |
145 | 113 | proxies_settings = self._proxy_config.settings |
@@ -167,6 +135,58 @@ def _setup_proxy_ssl_context(self, proxy_url): |
167 | 135 | except (OSError, LocationParseError) as e: |
168 | 136 | raise InvalidProxiesConfigError(error=e) |
169 | 137 |
|
| 138 | + def _chunked(self, headers): |
| 139 | + transfer_encoding = headers.get('Transfer-Encoding', '') |
| 140 | + if chunked := transfer_encoding.lower() == 'chunked': |
| 141 | + # aiohttp wants chunking as a param, and not a header |
| 142 | + del headers['Transfer-Encoding'] |
| 143 | + return chunked or None |
| 144 | + |
| 145 | + def _create_connector(self, proxy_url): |
| 146 | + ssl_context = None |
| 147 | + if bool(self._verify): |
| 148 | + if proxy_url: |
| 149 | + ssl_context = self._setup_proxy_ssl_context(proxy_url) |
| 150 | + # TODO: add support for |
| 151 | + # proxies_settings.get('proxy_use_forwarding_for_https') |
| 152 | + else: |
| 153 | + ssl_context = self._get_ssl_context() |
| 154 | + |
| 155 | + if ssl_context: |
| 156 | + if self._cert_file: |
| 157 | + ssl_context.load_cert_chain( |
| 158 | + self._cert_file, |
| 159 | + self._key_file, |
| 160 | + ) |
| 161 | + |
| 162 | + # inline self._setup_ssl_cert |
| 163 | + ca_certs = get_cert_path(self._verify) |
| 164 | + if ca_certs: |
| 165 | + ssl_context.load_verify_locations(ca_certs, None, None) |
| 166 | + |
| 167 | + return aiohttp.TCPConnector( |
| 168 | + limit=self._max_pool_connections, |
| 169 | + verify_ssl=bool(self._verify), |
| 170 | + ssl=ssl_context, |
| 171 | + **self._connector_args, |
| 172 | + ) |
| 173 | + |
| 174 | + async def _get_session(self, proxy_url): |
| 175 | + if not (session := self._sessions.get(proxy_url)): |
| 176 | + connector = self._create_connector(proxy_url) |
| 177 | + self._sessions[ |
| 178 | + proxy_url |
| 179 | + ] = session = await self._exit_stack.enter_async_context( |
| 180 | + aiohttp.ClientSession( |
| 181 | + connector=connector, |
| 182 | + timeout=self._timeout, |
| 183 | + skip_auto_headers={'CONTENT-TYPE'}, |
| 184 | + auto_decompress=False, |
| 185 | + ), |
| 186 | + ) |
| 187 | + |
| 188 | + return session |
| 189 | + |
170 | 190 | async def close(self): |
171 | 191 | await self.__aexit__(None, None, None) |
172 | 192 |
|
@@ -195,20 +215,15 @@ async def send(self, request): |
195 | 215 | # https://github.com/boto/botocore/issues/1255 |
196 | 216 | headers_['Accept-Encoding'] = 'identity' |
197 | 217 |
|
198 | | - chunked = None |
199 | | - if headers_.get('Transfer-Encoding', '').lower() == 'chunked': |
200 | | - # aiohttp wants chunking as a param, and not a header |
201 | | - headers_.pop('Transfer-Encoding', '') |
202 | | - chunked = True |
203 | | - |
204 | 218 | if isinstance(data, io.IOBase): |
205 | 219 | data = _IOBaseWrapper(data) |
206 | 220 |
|
207 | 221 | url = URL(url, encoded=True) |
208 | | - response = await self._session.request( |
| 222 | + session = await self._get_session(proxy_url) |
| 223 | + response = await session.request( |
209 | 224 | request.method, |
210 | 225 | url=url, |
211 | | - chunked=chunked, |
| 226 | + chunked=self._chunked(headers_), |
212 | 227 | headers=headers_, |
213 | 228 | data=data, |
214 | 229 | proxy=proxy_url, |
|
0 commit comments