diff --git a/openai-java-core/src/main/kotlin/com/openai/core/handlers/ErrorHandler.kt b/openai-java-core/src/main/kotlin/com/openai/core/handlers/ErrorHandler.kt index 6d6ee571c..00a517b34 100644 --- a/openai-java-core/src/main/kotlin/com/openai/core/handlers/ErrorHandler.kt +++ b/openai-java-core/src/main/kotlin/com/openai/core/handlers/ErrorHandler.kt @@ -44,50 +44,45 @@ internal fun errorHandler( errorBodyHandler: Handler> ): Handler = object : Handler { - override fun handle(response: HttpResponse): HttpResponse = - when (val statusCode = response.statusCode()) { - in 200..299 -> response - 400 -> - throw BadRequestException.builder() - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - 401 -> - throw UnauthorizedException.builder() - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - 403 -> - throw PermissionDeniedException.builder() - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - 404 -> - throw NotFoundException.builder() - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - 422 -> - throw UnprocessableEntityException.builder() - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - 429 -> - throw RateLimitException.builder() - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - in 500..599 -> - throw InternalServerException.builder() - .statusCode(statusCode) - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() - else -> - throw UnexpectedStatusCodeException.builder() - .statusCode(statusCode) - .headers(response.headers()) - .error(errorBodyHandler.handle(response)) - .build() + override fun handle(response: HttpResponse): HttpResponse { + val statusCode = response.statusCode() + if (statusCode in 200..299) { + return response } + + return response.use { + val headers = it.headers() + val error = errorBodyHandler.handle(it) + + when (statusCode) { + 400 -> throw BadRequestException.builder().headers(headers).error(error).build() + 401 -> + throw UnauthorizedException.builder().headers(headers).error(error).build() + 403 -> + throw PermissionDeniedException.builder() + .headers(headers) + .error(error) + .build() + 404 -> throw NotFoundException.builder().headers(headers).error(error).build() + 422 -> + throw UnprocessableEntityException.builder() + .headers(headers) + .error(error) + .build() + 429 -> throw RateLimitException.builder().headers(headers).error(error).build() + in 500..599 -> + throw InternalServerException.builder() + .statusCode(statusCode) + .headers(headers) + .error(error) + .build() + else -> + throw UnexpectedStatusCodeException.builder() + .statusCode(statusCode) + .headers(headers) + .error(error) + .build() + } + } + } } diff --git a/openai-java-core/src/test/kotlin/com/openai/core/handlers/ErrorHandlerTest.kt b/openai-java-core/src/test/kotlin/com/openai/core/handlers/ErrorHandlerTest.kt new file mode 100644 index 000000000..8e4e4867c --- /dev/null +++ b/openai-java-core/src/test/kotlin/com/openai/core/handlers/ErrorHandlerTest.kt @@ -0,0 +1,149 @@ +// File generated from our OpenAPI spec by Stainless. + +package com.openai.core.handlers + +import com.openai.core.JsonField +import com.openai.core.JsonValue +import com.openai.core.http.Headers +import com.openai.core.http.HttpResponse +import com.openai.core.http.HttpResponse.Handler +import com.openai.core.jsonMapper +import com.openai.errors.BadRequestException +import com.openai.errors.InternalServerException +import com.openai.errors.UnexpectedStatusCodeException +import com.openai.models.ErrorObject +import java.io.InputStream +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows + +internal class ErrorHandlerTest { + + @Test + fun `should not close response if status is successful`() { + val response = RecordingHttpResponse(statusCode = 204) + val handler = + errorHandler( + object : Handler> { + override fun handle(response: HttpResponse): JsonField { + error("Error body handler should not be called for successful responses") + } + } + ) + + assertThat(handler.handle(response)).isSameAs(response) + assertThat(response.closed).isFalse() + } + + @Test + fun `should close response if bad request exception is thrown`() { + val response = RecordingHttpResponse(statusCode = 400) + + val e = assertThrows { defaultErrorHandler().handle(response) } + + assertThat(response.closed).isTrue() + assertThat(e.statusCode()).isEqualTo(400) + assertThat(e.headers().values(HEADER_NAME)).containsExactly(HEADER_VALUE) + assertThat(e.body()).isEqualTo(ERROR_BODY) + } + + @Test + fun `should close response if internal server exception is thrown`() { + val response = RecordingHttpResponse(statusCode = 503) + + val e = assertThrows { defaultErrorHandler().handle(response) } + + assertThat(response.closed).isTrue() + assertThat(e.statusCode()).isEqualTo(503) + assertThat(e.headers().values(HEADER_NAME)).containsExactly(HEADER_VALUE) + assertThat(e.body()).isEqualTo(ERROR_BODY) + } + + @Test + fun `should close response if unexpected status exception is thrown`() { + val response = RecordingHttpResponse(statusCode = 999) + + val e = + assertThrows { defaultErrorHandler().handle(response) } + + assertThat(response.closed).isTrue() + assertThat(e.statusCode()).isEqualTo(999) + assertThat(e.headers().values(HEADER_NAME)).containsExactly(HEADER_VALUE) + assertThat(e.body()).isEqualTo(ERROR_BODY) + } + + @Test + fun `should close response if error body handler throws`() { + val response = RecordingHttpResponse(statusCode = 400) + val expected = IllegalStateException("boom") + val handler = + errorHandler( + object : Handler> { + override fun handle(response: HttpResponse): JsonField { + throw expected + } + } + ) + + val e = assertThrows { handler.handle(response) } + + assertThat(response.closed).isTrue() + assertThat(e).isSameAs(expected) + } + + private fun defaultErrorHandler(): Handler = + errorHandler(errorBodyHandler(jsonMapper())) + + private class RecordingHttpResponse( + private val statusCode: Int, + private val bodyBytes: ByteArray = ERROR_JSON_BYTES, + private val headers: Headers = Headers.builder().put(HEADER_NAME, HEADER_VALUE).build(), + ) : HttpResponse { + + var closed: Boolean = false + private set + + override fun statusCode(): Int = statusCode + + override fun headers(): Headers = headers + + override fun body(): InputStream = bodyBytes.inputStream() + + override fun close() { + closed = true + } + } + + companion object { + + private const val HEADER_NAME = "Error-Header" + + private const val HEADER_VALUE = "42" + + private val ERROR_BODY: JsonValue = + JsonValue.from( + mapOf( + "code" to "code", + "message" to "message", + "param" to "param", + "type" to "type", + ) + ) + + private val ERROR_JSON_BYTES: ByteArray = + jsonMapper() + .writeValueAsBytes( + JsonValue.from( + mapOf( + "error" to + mapOf( + "code" to "code", + "message" to "message", + "param" to "param", + "type" to "type", + ) + ) + ) + ) + } +}