|
5 | 5 | AwsChunkedWrapper, |
6 | 6 | FlexibleChecksumError, |
7 | 7 | _apply_request_header_checksum, |
8 | | - _handle_streaming_response, |
9 | 8 | base64, |
10 | 9 | conditionally_calculate_md5, |
11 | 10 | determine_content_length, |
12 | 11 | logger, |
13 | 12 | ) |
14 | 13 |
|
15 | 14 | from aiobotocore._helpers import resolve_awaitable |
| 15 | +from aiobotocore.response import StreamingBody |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class AioAwsChunkedWrapper(AwsChunkedWrapper): |
@@ -44,6 +44,30 @@ async def __anext__(self): |
44 | 44 | raise StopAsyncIteration() |
45 | 45 |
|
46 | 46 |
|
| 47 | +# unfortunately we can't inherit from botocore's StreamingChecksumBody due to |
| 48 | +# subclassing |
| 49 | +class StreamingChecksumBody(StreamingBody): |
| 50 | + def __init__(self, raw_stream, content_length, checksum, expected): |
| 51 | + super().__init__(raw_stream, content_length) |
| 52 | + self._checksum = checksum |
| 53 | + self._expected = expected |
| 54 | + |
| 55 | + async def read(self, amt=None): |
| 56 | + chunk = await super().read(amt=amt) |
| 57 | + self._checksum.update(chunk) |
| 58 | + if amt is None or (not chunk and amt > 0): |
| 59 | + self._validate_checksum() |
| 60 | + return chunk |
| 61 | + |
| 62 | + def _validate_checksum(self): |
| 63 | + if self._checksum.digest() != base64.b64decode(self._expected): |
| 64 | + error_msg = ( |
| 65 | + f"Expected checksum {self._expected} did not match calculated " |
| 66 | + f"checksum: {self._checksum.b64digest()}" |
| 67 | + ) |
| 68 | + raise FlexibleChecksumError(error_msg=error_msg) |
| 69 | + |
| 70 | + |
47 | 71 | async def handle_checksum_body( |
48 | 72 | http_response, response, context, operation_model |
49 | 73 | ): |
@@ -87,6 +111,17 @@ async def handle_checksum_body( |
87 | 111 | ) |
88 | 112 |
|
89 | 113 |
|
| 114 | +def _handle_streaming_response(http_response, response, algorithm): |
| 115 | + checksum_cls = _CHECKSUM_CLS.get(algorithm) |
| 116 | + header_name = "x-amz-checksum-%s" % algorithm |
| 117 | + return StreamingChecksumBody( |
| 118 | + http_response.raw, |
| 119 | + response["headers"].get("content-length"), |
| 120 | + checksum_cls(), |
| 121 | + response["headers"][header_name], |
| 122 | + ) |
| 123 | + |
| 124 | + |
90 | 125 | async def _handle_bytes_response(http_response, response, algorithm): |
91 | 126 | body = await http_response.content |
92 | 127 | header_name = "x-amz-checksum-%s" % algorithm |
|
0 commit comments