# Base client (reconstructed from a diff-viewer paste; duplicated removal/addition lines collapsed)
from __future__ import annotations

import sys
import json
import time
import uuid
import email
import asyncio
import inspect
import logging
import platform
import warnings
import email.utils
from types import TracebackType
from random import random
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Type,
    Union,
    Generic,
    Mapping,
    TypeVar,
    Iterable,
    Iterator,
    Optional,
    Generator,
    AsyncIterator,
    cast,
    overload,
)

from typing_extensions import Literal, override, get_origin

import anyio
import httpx
import distro
import pydantic
from httpx import URL, Limits
from pydantic import PrivateAttr

from . import _exceptions
from ._qs import Querystring
from ._files import to_httpx_files, async_to_httpx_files
from ._types import (
    NOT_GIVEN,
    Body,
    Omit,
    Query,
    Headers,
    Timeout,
    NotGiven,
    ResponseT,
    Transport,
    AnyMapping,
    PostParser,
    ProxiesTypes,
    RequestFiles,
    HttpxSendArgs,
    AsyncTransport,
    RequestOptions,
    HttpxRequestFiles,
    ModelBuilderProtocol,
)
from ._utils import SensitiveHeadersFilter, is_dict, is_list, asyncify, is_given, lru_cache, is_mapping
from ._compat import model_copy, model_dump
from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type
from ._response import (
    APIResponse,
    BaseAPIResponse,
    AsyncAPIResponse,
    extract_response_type,
)
from ._constants import (
    DEFAULT_TIMEOUT,
    MAX_RETRY_DELAY,
    DEFAULT_MAX_RETRIES,
    INITIAL_RETRY_DELAY,
    RAW_RESPONSE_HEADER,
    OVERRIDE_CAST_TO_HEADER,
    DEFAULT_CONNECTION_LIMITS,
)
from ._streaming import Stream, SSEDecoder, AsyncStream, SSEBytesDecoder
from ._exceptions import (
    APIStatusError,
    APITimeoutError,
    APIConnectionError,
    APIResponseValidationError,
)
from ._legacy_response import LegacyAPIResponse
# Module-level logger; the SensitiveHeadersFilter redacts secrets (e.g. auth
# headers) from debug output before it reaches log handlers.
log: logging.Logger = logging.getLogger(__name__)
log.addFilter(SensitiveHeadersFilter())

# TODO: make base page type vars covariant
SyncPageT = TypeVar("SyncPageT", bound="BaseSyncPage[Any]")
AsyncPageT = TypeVar("AsyncPageT", bound="BaseAsyncPage[Any]")

_T = TypeVar("_T")
_T_co = TypeVar("_T_co", covariant=True)

_StreamT = TypeVar("_StreamT", bound=Stream[Any])
_AsyncStreamT = TypeVar("_AsyncStreamT", bound=AsyncStream[Any])
if TYPE_CHECKING:
    from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
else:
    try:
        from httpx._config import DEFAULT_TIMEOUT_CONFIG as HTTPX_DEFAULT_TIMEOUT
    except ImportError:
        # `httpx._config` is a private module, so fall back to a hard-coded copy
        # of its default if the import path ever changes.
        # taken from https://github.com/encode/httpx/blob/3ba5fe0d7ac70222590e759c31442b1cab263791/httpx/_config.py#L366
        HTTPX_DEFAULT_TIMEOUT = Timeout(5.0)
class PageInfo:
    """Stores the necessary information to build the request to retrieve the next page.

    Either `url` or `params` must be set.
    """

    url: URL | NotGiven
    params: Query | NotGiven

    @overload
    def __init__(
        self,
        *,
        url: URL,
    ) -> None: ...

    @overload
    def __init__(
        self,
        *,
        params: Query,
    ) -> None: ...

    def __init__(
        self,
        *,
        url: URL | NotGiven = NOT_GIVEN,
        params: Query | NotGiven = NOT_GIVEN,
    ) -> None:
        self.url = url
        self.params = params

    @override
    def __repr__(self) -> str:
        if self.url:
            return f"{self.__class__.__name__}(url={self.url})"
        return f"{self.__class__.__name__}(params={self.params})"
class BasePage(GenericModel, Generic[_T]):
    """
    Defines the core interface for pagination.

    Type Args:
        ModelT: The pydantic model that represents an item in the response.

    Methods:
        has_next_page(): Check if there is another page available
        next_page_info(): Get the necessary information to make a request for the next page
    """

    _options: FinalRequestOptions = PrivateAttr()
    _model: Type[_T] = PrivateAttr()

    def has_next_page(self) -> bool:
        """Return True if another page can be fetched after this one."""
        items = self._get_page_items()
        if not items:
            # An empty page means there is nothing further to fetch.
            return False
        return self.next_page_info() is not None

    def next_page_info(self) -> Optional[PageInfo]: ...

    def _get_page_items(self) -> Iterable[_T]:  # type: ignore[empty-body]
        ...

    def _params_from_url(self, url: URL) -> httpx.QueryParams:
        # TODO: do we have to preprocess params here?
        return httpx.QueryParams(cast(Any, self._options.params)).merge(url.params)

    def _info_to_options(self, info: PageInfo) -> FinalRequestOptions:
        """Build the request options for the next page from a `PageInfo`."""
        options = model_copy(self._options)
        options._strip_raw_response_header()

        if not isinstance(info.params, NotGiven):
            options.params = {**options.params, **info.params}
            return options

        if not isinstance(info.url, NotGiven):
            params = self._params_from_url(info.url)
            url = info.url.copy_with(params=params)
            options.params = dict(url.params)
            options.url = str(url)
            return options

        raise ValueError("Unexpected PageInfo state")
class BaseSyncPage(BasePage[_T], Generic[_T]):
    _client: SyncAPIClient = pydantic.PrivateAttr()

    def _set_private_attributes(
        self,
        client: SyncAPIClient,
        model: Type[_T],
        options: FinalRequestOptions,
    ) -> None:
        self._model = model
        self._client = client
        self._options = options

    # Pydantic uses a custom `__iter__` method to support casting BaseModels
    # to dictionaries. e.g. dict(model).
    # As we want to support `for item in page`, this is inherently incompatible
    # with the default pydantic behaviour. It is not possible to support both
    # use cases at once. Fortunately, this is not a big deal as all other pydantic
    # methods should continue to work as expected as there is an alternative method
    # to cast a model to a dictionary, model.dict(), which is used internally
    # by pydantic.
    def __iter__(self) -> Iterator[_T]:  # type: ignore
        for page in self.iter_pages():
            for item in page._get_page_items():
                yield item

    def iter_pages(self: SyncPageT) -> Iterator[SyncPageT]:
        """Yield this page and then each following page, fetching lazily."""
        page = self
        while True:
            yield page
            if page.has_next_page():
                page = page.get_next_page()
            else:
                return

    def get_next_page(self: SyncPageT) -> SyncPageT:
        """Fetch the next page; raises if no next page exists."""
        info = self.next_page_info()
        if not info:
            raise RuntimeError(
                "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
            )

        options = self._info_to_options(info)
        return self._client._request_api_list(self._model, page=self.__class__, options=options)
class AsyncPaginator(Generic[_T, AsyncPageT]):
    """Awaitable wrapper around an async page request; also supports `async for`."""

    def __init__(
        self,
        client: AsyncAPIClient,
        options: FinalRequestOptions,
        page_cls: Type[AsyncPageT],
        model: Type[_T],
    ) -> None:
        self._model = model
        self._client = client
        self._options = options
        self._page_cls = page_cls

    def __await__(self) -> Generator[Any, None, AsyncPageT]:
        return self._get_page().__await__()

    async def _get_page(self) -> AsyncPageT:
        def _parser(resp: AsyncPageT) -> AsyncPageT:
            # Attach the client/model/options so the page can fetch successors.
            resp._set_private_attributes(
                model=self._model,
                options=self._options,
                client=self._client,
            )
            return resp

        self._options.post_parser = _parser

        return await self._client.request(self._page_cls, self._options)

    async def __aiter__(self) -> AsyncIterator[_T]:
        # https://github.com/microsoft/pyright/issues/3464
        page = cast(
            AsyncPageT,
            await self,  # type: ignore
        )
        async for item in page:
            yield item
class BaseAsyncPage(BasePage[_T], Generic[_T]):
    _client: AsyncAPIClient = pydantic.PrivateAttr()

    def _set_private_attributes(
        self,
        model: Type[_T],
        client: AsyncAPIClient,
        options: FinalRequestOptions,
    ) -> None:
        self._model = model
        self._client = client
        self._options = options

    async def __aiter__(self) -> AsyncIterator[_T]:
        async for page in self.iter_pages():
            for item in page._get_page_items():
                yield item

    async def iter_pages(self: AsyncPageT) -> AsyncIterator[AsyncPageT]:
        """Yield this page and then each following page, fetching lazily."""
        page = self
        while True:
            yield page
            if page.has_next_page():
                page = await page.get_next_page()
            else:
                return

    async def get_next_page(self: AsyncPageT) -> AsyncPageT:
        """Fetch the next page; raises if no next page exists."""
        info = self.next_page_info()
        if not info:
            raise RuntimeError(
                "No next page expected; please check `.has_next_page()` before calling `.get_next_page()`."
            )

        options = self._info_to_options(info)
        return await self._client._request_api_list(self._model, page=self.__class__, options=options)
_HttpxClientT = TypeVar("_HttpxClientT", bound=Union[httpx.Client, httpx.AsyncClient])
_DefaultStreamT = TypeVar("_DefaultStreamT", bound=Union[Stream[Any], AsyncStream[Any]])
class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
    # Underlying httpx client; sync or async depending on the subclass.
    _client: _HttpxClientT
    _version: str
    _base_url: URL
    max_retries: int
    timeout: Union[float, Timeout, None]
    _limits: httpx.Limits
    _proxies: ProxiesTypes | None
    _transport: Transport | AsyncTransport | None
    _strict_response_validation: bool
    # Name of the idempotency header, if the API supports one.
    _idempotency_header: str | None
    _default_stream_cls: type[_DefaultStreamT] | None = None
def __init__(
def __init__(
self,
self,
*,
*,
version: str,
version: str,
base_url: str | URL,
base_url: str | URL,
_strict_response_validation: bool,
_strict_response_validation: bool,
max_retries: int = DEFAULT_MAX_RETRIES,
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
timeout: float | Timeout | None = DEFAULT_TIMEOUT,
limits: httpx.Limits,
limits: httpx.Limits,
transport: Transport | AsyncTransport | None,
transport: Transport | AsyncTransport | None,
proxies: ProxiesTypes | None,
proxies: ProxiesTypes | None,
custom_headers: Mapping[str, str] | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
) -> None:
self._version = version
self._version = version
self._base_url = self._enforce_trailing_slash(URL(base_url))
self._base_url = self._enforce_trailing_slash(URL(base_url))
self.max_retries = max_retries
self.max_retries = max_retries
self.timeout = timeout
self.timeout = timeout
self._limits = limits
self._limits = limits
self._proxies = proxies
self._proxies = proxies
self._transport = transport
self._transport = transport
self._custom_headers = custom_headers or {}
self._custom_headers = custom_headers or {}
self._custom_query = custom_query or {}
self._custom_query = custom_query or {}
self._strict_response_validation = _strict_response_validation
self._strict_response_validation = _strict_response_validation
self._idempotency_header = None
self._idempotency_header = None
self._platform: Platform | None = None
self._platform: Platform | None = None
if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
if max_retries is None: # pyright: ignore[reportUnnecessaryComparison]
raise TypeError(
raise TypeError(
"max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `krutrim_cloud.DEFAULT_MAX_RETRIES`"
"max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`"
)
)
def _enforce_trailing_slash(self, url: URL) -> URL:
def _enforce_trailing_slash(self, url: URL) -> URL:
if url.raw_path.endswith(b"/"):
if url.raw_path.endswith(b"/"):
return url
return url
return url.copy_with(raw_path=url.raw_path + b"/")
return url.copy_with(raw_path=url.raw_path + b"/")
def _make_status_error_from_response(
def _make_status_error_from_response(
self,
self,
response: httpx.Response,
response: httpx.Response,
) -> APIStatusError:
) -> APIStatusError:
if response.is_closed and not response.is_stream_consumed:
if response.is_closed and not response.is_stream_consumed:
# We can't read the response body as it has been closed
# We can't read the response body as it has been closed
# before it was read. This can happen if an event hook
# before it was read. This can happen if an event hook
# raises a status error.
# raises a status error.
body = None
body = None
err_msg = f"Error code: {response.status_code}"
err_msg = f"Error code: {response.status_code}"
else:
else:
err_text = response.text.strip()
err_text = response.text.strip()
body = err_text
body = err_text
try:
try:
body = json.loads(err_text)
body = json.loads(err_text)
err_msg = f"Error code: {response.status_code} - {body}"
err_msg = f"Error code: {response.status_code} - {body}"
except Exception:
except Exception:
err_msg = err_text or f"Error code: {response.status_code}"
err_msg = err_text or f"Error code: {response.status_code}"
return self._make_status_error(err_msg, body=body, response=response)
return self._make_status_error(err_msg, body=body, response=response)
def _make_status_error(
def _make_status_error(
self,
self,
err_msg: str,
err_msg: str,
*,
*,
body: object,
body: object,
response: httpx.Response,
response: httpx.Response,
) -> _exceptions.APIStatusError:
) -> _exceptions.APIStatusError:
raise NotImplementedError()
raise NotImplementedError()
def _remaining_retries(
def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0) -> httpx.Headers:
self,
remaining_retries: Optional[int],
options: FinalRequestOptions,
) -> int:
return remaining_retries if remaining_retries is not None else options.get_max_retries(self.max_retries)
def _build_headers(self, options: FinalRequestOptions) -> httpx.Headers:
custom_headers = options.headers or {}
custom_headers = options.headers or {}
headers_dict = _merge_mappings(self.default_headers, custom_headers)
headers_dict = _merge_mappings(self.default_headers, custom_headers)
self._validate_headers(headers_dict, custom_headers)
self._validate_headers(headers_dict, custom_headers)
# headers are case-insensitive while dictionaries are not.
# headers are case-insensitive while dictionaries are not.
headers = httpx.Headers(headers_dict)
headers = httpx.Headers(headers_dict)
idempotency_header = self._idempotency_header
idempotency_header = self._idempotency_header
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()
# Don't set the retry count header if it was already set or removed by the caller. We check
# `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
headers["x-stainless-retry-count"] = str(retries_taken)
return headers
return headers
def _prepare_url(self, url: str) -> URL:
def _prepare_url(self, url: str) -> URL:
"""
"""
Merge a URL argument together with any 'base_url' on the client,
Merge a URL argument together with any 'base_url' on the client,
to create the URL used for the outgoing request.
to create the URL used for the outgoing request.
"""
"""
# Copied from httpx's `_merge_url` method.
# Copied from httpx's `_merge_url` method.
merge_url = URL(url)
merge_url = URL(url)
if merge_url.is_relative_url:
if merge_url.is_relative_url:
merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/")
merge_raw_path = self.base_url.raw_path + merge_url.raw_path.lstrip(b"/")
return self.base_url.copy_with(raw_path=merge_raw_path)
return self.base_url.copy_with(raw_path=merge_raw_path)
return merge_url
return merge_url
def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
def _make_sse_decoder(self) -> SSEDecoder | SSEBytesDecoder:
return SSEDecoder()
return SSEDecoder()
def _build_request(
def _build_request(
self,
self,
options: FinalRequestOptions,
options: FinalRequestOptions,
*,
retries_taken: int = 0,
) -> httpx.Request:
) -> httpx.Request:
if log.isEnabledFor(logging.DEBUG):
if log.isEnabledFor(logging.DEBUG):
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
log.debug("Request options: %s", model_dump(options, exclude_unset=True))
kwargs: dict[str, Any] = {}
kwargs: dict[str, Any] = {}
json_data = options.json_data
json_data = options.json_data
if options.extra_json is not None:
if options.extra_json is not None:
if json_data is None:
if json_data is None:
json_data = cast(Body, options.extra_json)
json_data = cast(Body, options.extra_json)
elif is_mapping(json_data):
elif is_mapping(json_data):
json_data = _merge_mappings(json_data, options.extra_json)
json_data = _merge_mappings(json_data, options.extra_json)
else:
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
headers = self._build_headers(options)
headers = self._build_headers(options, retries_taken=retries_taken)
params = _merge_mappings(self.default_query, options.params)
params = _merge_mappings(self.default_query, options.params)
content_type = headers.get("Content-Type")
content_type = headers.get("Content-Type")
files = options.files
files = options.files
# If the given Content-Type header is multipart/form-data then it
# If the given Content-Type header is multipart/form-data then it
# has to be removed so that httpx can generate the header with
# has to be removed so that httpx can generate the header with
# additional information for us as it has to be in this form
# additional information for us as it has to be in this form
# for the server to be able to correctly parse the request:
# for the server to be able to correctly parse the request:
# multipart/form-data; boundary=---abc--
# multipart/form-data; boundary=---abc--
if content_type is not None and content_type.startswith("multipart/form-data"):
if content_type is not None and content_type.startswith("multipart/form-data"):
if "boundary" not in content_type:
if "boundary" not in content_type:
# only remove the header if the boundary hasn't been explicitly set
# only remove the header if the boundary hasn't been explicitly set
# as the caller doesn't want httpx to come up with their own boundary
# as the caller doesn't want httpx to come up with their own boundary
headers.pop("Content-Type")
headers.pop("Content-Type")
# As we are now sending multipart/form-data instead of application/json
# As we are now sending multipart/form-data instead of application/json
# we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding
# we need to tell httpx to use it, https://www.python-httpx.org/advanced/clients/#multipart-file-encoding
if json_data:
if json_data:
if not is_dict(json_data):
if not is_dict(json_data):
raise TypeError(
raise TypeError(
f"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead."
f"Expected query input to be a dictionary for multipart requests but got {type(json_data)} instead."
)
)
kwargs["data"] = self._serialize_multipartform(json_data)
kwargs["data"] = self._serialize_multipartform(json_data)
# httpx determines whether or not to send a "multipart/form-data"
# httpx determines whether or not to send a "multipart/form-data"
# request based on the truthiness of the "files" argument.
# request based on the truthiness of the "files" argument.
# This gets around that issue by generating a dict value that
# This gets around that issue by generating a dict value that
# evaluates to true.
# evaluates to true.
#
#
# https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186
# https://github.com/encode/httpx/discussions/2399#discussioncomment-3814186
if not files:
if not files:
files = cast(HttpxRequestFiles, ForceMultipartDict())
files = cast(HttpxRequestFiles, ForceMultipartDict())
prepared_url = self._prepare_url(options.url)
if "_" in prepared_url.host:
# work around https://github.com/encode/httpx/discussions/2880
kwargs["extensions"] = {"sni_hostname": prepared_url.host.replace("_", "-")}
# TODO: report this error to httpx
# TODO: report this error to httpx
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
headers=headers,
headers=headers,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
method=options.method,
method=options.method,
url=self._prepare_url(options.url),
url=prepared_url,
# the `Query` type that we use is incompatible with qs'
# the `Query` type that we use is incompatible with qs'
# `Params` type as it needs to be typed as `Mapping[str, object]`
# `Params` type as it needs to be typed as `Mapping[str, object]`
# so that passing a `TypedDict` doesn't cause an error.
# so that passing a `TypedDict` doesn't cause an error.
# https://github.com/microsoft/pyright/issues/3526#event-6715453066
# https://github.com/microsoft/pyright/issues/3526#event-6715453066
params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
params=self.qs.stringify(cast(Mapping[str, Any], params)) if params else None,
json=json_data,
json=json_data,
files=files,
files=files,
**kwargs,
**kwargs,
)
)
def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, object]:
items = self.qs.stringify_items(
items = self.qs.stringify_items(
# TODO: type ignore is required as stringify_items is well typed but we can't be
# TODO: type ignore is required as stringify_items is well typed but we can't be
# well typed without heavy validation.
# well typed without heavy validation.
data, # type: ignore
data, # type: ignore
array_format="brackets",
array_format="brackets",
)
)
serialized: dict[str, object] = {}
serialized: dict[str, object] = {}
for key, value in items:
for key, value in items:
existing = serialized.get(key)
existing = serialized.get(key)
if not existing:
if not existing:
serialized[key] = value
serialized[key] = value
continue
continue
# If a value has already been set for this key then that
# If a value has already been set for this key then that
# means we're sending data like `array[]=[1, 2, 3]` and we
# means we're sending data like `array[]=[1, 2, 3]` and we
# need to tell httpx that we want to send multiple values with
# need to tell httpx that we want to send multiple values with
# the same key which is done by using a list or a tuple.
# the same key which is done by using a list or a tuple.
#
#
# Note: 2d arrays should never result in the same key at both
# Note: 2d arrays should never result in the same key at both
# levels so it's safe to assume that if the value is a list,
# levels so it's safe to assume that if the value is a list,
# it was because we changed it to be a list.
# it was because we changed it to be a list.
if is_list(existing):
if is_list(existing):
existing.append(value)
existing.append(value)
else:
else:
serialized[key] = [existing, value]
serialized[key] = [existing, value]
return serialized
return serialized
def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]:
def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]:
if not is_given(options.headers):
if not is_given(options.headers):
return cast_to
return cast_to
# make a copy of the headers so we don't mutate user-input
# make a copy of the headers so we don't mutate user-input
headers = dict(options.headers)
headers = dict(options.headers)
# we internally support defining a temporary header to override the
# we internally support defining a temporary header to override the
# default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response`
# default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response`
# see _response.py for implementation details
# see _response.py for implementation details
override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN)
override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN)
if is_given(override_cast_to):
if is_given(override_cast_to):
options.headers = headers
options.headers = headers
return cast(Type[ResponseT], override_cast_to)
return cast(Type[ResponseT], override_cast_to)
return cast_to
return cast_to
def _should_stream_response_body(self, request: httpx.Request) -> bool:
def _should_stream_response_body(self, request: httpx.Request) -> bool:
return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return]
return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return]
def _process_response_data(
def _process_response_data(
self,
self,
*,
*,
data: object,
data: object,
cast_to: type[ResponseT],
cast_to: type[ResponseT],
response: httpx.Response,
response: httpx.Response,
) -> ResponseT:
) -> ResponseT:
if data is None:
if data is None:
return cast(ResponseT, None)
return cast(ResponseT, None)
if cast_to is object:
if cast_to is object:
return cast(ResponseT, data)
return cast(ResponseT, data)
try:
try:
if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
return cast(ResponseT, cast_to.build(response=response, data=data))
return cast(ResponseT, cast_to.build(response=response, data=data))
if self._strict_response_validation:
if self._strict_response_validation:
return cast(ResponseT, validate_type(type_=cast_to, value=data))
return cast(ResponseT, validate_type(type_=cast_to, value=data))
return cast(ResponseT, construct_type(type_=cast_to, value=data))
return cast(ResponseT, construct_type(type_=cast_to, value=data))
except pydantic.ValidationError as err:
except pydantic.ValidationError as err:
raise APIResponseValidationError(response=response, body=data) from err
raise APIResponseValidationError(response=response, body=data) from err
@property
@property
def qs(self) -> Querystring:
def qs(self) -> Querystring:
return Querystring()
return Querystring()
@property
@property
def custom_auth(self) -> httpx.Auth | None:
def custom_auth(self) -> httpx.Auth | None:
return None
return None
@property
@property
def auth_headers(self) -> dict[str, str]:
def auth_headers(self) -> dict[str, str]:
return {}
return {}
@property
@property
def default_headers(self) -> dict[str, str | Omit]:
def default_headers(self) -> dict[str, str | Omit]:
return {
return {
"Accept": "application/json",
"Accept": "application/json",
"Content-Type": "application/json",
"Content-Type": "application/json",
"User-Agent": self.user_agent,
"User-Agent": self.user_agent,
**self.platform_headers(),
**self.platform_headers(),
**self.auth_headers,
**self.auth_headers,
**self._custom_headers,
**self._custom_headers,
}
}
@property
@property
def default_query(self) -> dict[str, object]:
def default_query(self) -> dict[str, object]:
return {
return {
**self._custom_query,
**self._custom_query,
}
}
def _validate_headers(
def _validate_headers(
self,
self,
headers: Headers, # noqa: ARG002
headers: Headers, # noqa: ARG002
custom_headers: Headers, # noqa: ARG002
custom_headers: Headers, # noqa: ARG002
) -> None:
) -> None:
"""Validate the given default headers and custom headers.
"""Validate the given default headers and custom headers.
Does nothing by default.
Does nothing by default.
"""
"""
return
return
@property
@property
def user_agent(self) -> str:
def user_agent(self) -> str:
return f"{self.__class__.__name__}/Python {self._version}"
return f"{self.__class__.__name__}/Python {self._version}"
@property
@property
def base_url(self) -> URL:
def base_url(self) -> URL:
return self._base_url
return self._base_url
@base_url.setter
@base_url.setter
def base_url(self, url: URL | str) -> None:
def base_url(self, url: URL | str) -> None:
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
self._base_url = self._enforce_trailing_slash(url if isinstance(url, URL) else URL(url))
def platform_headers(self) -> Dict[str, str]:
def platform_headers(self) -> Dict[str, str]:
# the actual implementation is in a separate `lru_cache` decorated
# the actual implementation is in a separate `lru_cache` decorated
# function because adding `lru_cache` to methods will leak memory
# function because adding `lru_cache` to methods will leak memory
# https://github.com/python/cpython/issues/88476
# https://github.com/python/cpython/issues/88476
return platform_headers(self._version, platform=self._platform)
return platform_headers(self._version, platform=self._platform)
def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
"""Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.
"""Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.
About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax
See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax
"""
"""
if response_headers is None:
if response_headers is None:
return None
return None
# First, try the non-standard `retry-after-ms` header for milliseconds,
# First, try the non-standard `retry-after-ms` header for milliseconds,
# which is more precise than integer-seconds `retry-after`
# which is more precise than integer-seconds `retry-after`
try:
try:
retry_ms_header = response_headers.get("retry-after-ms", None)
retry_ms_header = response_headers.get("retry-after-ms", None)
return float(retry_ms_header) / 1000
return float(retry_ms_header) / 1000
except (TypeError, ValueError):
except (TypeError, ValueError):
pass # Ignore TypeError and ValueError as not mandatory to get retry-after-ms
pass
except Exception:
raise Exception("Error occurred in parsing retry-after-ms from response_headers")
# Next, try parsing `retry-after` header as seconds (allowing nonstandard floats).
# Next, try parsing `retry-after` header as seconds (allowing nonstandard floats).
retry_header = response_headers.get("retry-after")
retry_header = response_headers.get("retry-after")
try:
try:
# note: the spec indicates that this should only ever be an integer
# note: the spec indicates that this should only ever be an integer
# but if someone sends a float there's no reason for us to not respect it
# but if someone sends a float there's no reason for us to not respect it
return float(retry_header)
return float(retry_header)
except (TypeError, ValueError):
except (TypeError, ValueError):
pass # Ignore TypeError and ValueError as not mandatory to get retry-after
pass
except Exception:
raise Exception("Error occurred in parsing retry-after from response_headers")
# Last, try parsing `retry-after` as a date.
# Last, try parsing `retry-after` as a date.
retry_date_tuple = email.utils.parsedate_tz(retry_header)
retry_date_tuple = email.utils.parsedate_tz(retry_header)
if retry_date_tuple is None:
if retry_date_tuple is None:
return None
return None
retry_date = email.utils.mktime_tz(retry_date_tuple)
retry_date = email.utils.mktime_tz(retry_date_tuple)
return float(retry_date - time.time())
return float(retry_date - time.time())
def _calculate_retry_timeout(
def _calculate_retry_timeout(
self,
self,
remaining_retries: int,
remaining_retries: int,
options: FinalRequestOptions,
options: FinalRequestOptions,
response_headers: Optional[httpx.Headers] = None,
response_headers: Optional[httpx.Headers] = None,
) -> float:
) -> float:
max_retries = options.get_max_retries(self.max_retries)
max_retries = options.get_max_retries(self.max_retries)
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
retry_after = self._parse_retry_after_header(response_headers)
retry_after = self._parse_retry_after_header(response_headers)
if retry_after is not None and 0 < retry_after <= 60:
if retry_after is not None and 0 < retry_after <= 60:
return retry_after
return retry_after
nb_retries = max_retries - remaining_retries
# Also cap retry count to 1000 to avoid any potential overflows with `pow`
nb_retries = min(max_retries - remaining_retries, 1000)
# Apply exponential backoff, but not more than the max.
# Apply exponential backoff, but not more than the max.
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
# Apply some jitter, plus-or-minus half a second.
# Apply some jitter, plus-or-minus half a second.
jitter = 1 - 0.25 * random()
jitter = 1 - 0.25 * random()
timeout = sleep_seconds * jitter
timeout = sleep_seconds * jitter
return timeout if timeout >= 0 else 0
return timeout if timeout >= 0 else 0
def _should_retry(self, response: httpx.Response) -> bool:
def _should_retry(self, response: httpx.Respo
# Note: this is not a standard header
should_retry_header = response.headers.get("x-should-retry")
# If the server explicitly says whether or not to retry, obey.
if should_retry_header == "true":
log.debug("Retrying as header `x-should-retry` is set to `true`")
return True
if should_retry_header == "false":
log.debug("Not retrying as header `x-should-retry` is set to `false`")