Skip to content

Commit 25d16be

Browse files
johanstekristapratico
authored andcommitted
feat(client): support callable api_key (#2588)
Co-authored-by: Krista Pratico <krpratic@microsoft.com>
1 parent 8672413 commit 25d16be

File tree

5 files changed

+188
-25
lines changed

5 files changed

+188
-25
lines changed

src/openai/_client.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import os
6-
from typing import TYPE_CHECKING, Any, Union, Mapping
6+
from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
77
from typing_extensions import Self, override
88

99
import httpx
@@ -25,6 +25,7 @@
2525
get_async_library,
2626
)
2727
from ._compat import cached_property
28+
from ._models import FinalRequestOptions
2829
from ._version import __version__
2930
from ._streaming import Stream as Stream, AsyncStream as AsyncStream
3031
from ._exceptions import OpenAIError, APIStatusError
@@ -96,7 +97,7 @@ class OpenAI(SyncAPIClient):
9697
def __init__(
9798
self,
9899
*,
99-
api_key: str | None = None,
100+
api_key: str | None | Callable[[], str] = None,
100101
organization: str | None = None,
101102
project: str | None = None,
102103
webhook_secret: str | None = None,
@@ -134,7 +135,12 @@ def __init__(
134135
raise OpenAIError(
135136
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
136137
)
137-
self.api_key = api_key
138+
if callable(api_key):
139+
self.api_key = ""
140+
self._api_key_provider: Callable[[], str] | None = api_key
141+
else:
142+
self.api_key = api_key
143+
self._api_key_provider = None
138144

139145
if organization is None:
140146
organization = os.environ.get("OPENAI_ORG_ID")
@@ -295,6 +301,15 @@ def with_streaming_response(self) -> OpenAIWithStreamedResponse:
295301
def qs(self) -> Querystring:
296302
return Querystring(array_format="brackets")
297303

304+
def _refresh_api_key(self) -> None:
305+
if self._api_key_provider:
306+
self.api_key = self._api_key_provider()
307+
308+
@override
309+
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
310+
self._refresh_api_key()
311+
return super()._prepare_options(options)
312+
298313
@property
299314
@override
300315
def auth_headers(self) -> dict[str, str]:
@@ -318,7 +333,7 @@ def default_headers(self) -> dict[str, str | Omit]:
318333
def copy(
319334
self,
320335
*,
321-
api_key: str | None = None,
336+
api_key: str | Callable[[], str] | None = None,
322337
organization: str | None = None,
323338
project: str | None = None,
324339
webhook_secret: str | None = None,
@@ -356,7 +371,7 @@ def copy(
356371

357372
http_client = http_client or self._client
358373
return self.__class__(
359-
api_key=api_key or self.api_key,
374+
api_key=api_key or self._api_key_provider or self.api_key,
360375
organization=organization or self.organization,
361376
project=project or self.project,
362377
webhook_secret=webhook_secret or self.webhook_secret,
@@ -427,7 +442,7 @@ class AsyncOpenAI(AsyncAPIClient):
427442
def __init__(
428443
self,
429444
*,
430-
api_key: str | None = None,
445+
api_key: str | Callable[[], Awaitable[str]] | None = None,
431446
organization: str | None = None,
432447
project: str | None = None,
433448
webhook_secret: str | None = None,
@@ -465,7 +480,12 @@ def __init__(
465480
raise OpenAIError(
466481
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
467482
)
468-
self.api_key = api_key
483+
if callable(api_key):
484+
self.api_key = ""
485+
self._api_key_provider: Callable[[], Awaitable[str]] | None = api_key
486+
else:
487+
self.api_key = api_key
488+
self._api_key_provider = None
469489

470490
if organization is None:
471491
organization = os.environ.get("OPENAI_ORG_ID")
@@ -626,6 +646,15 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
626646
def qs(self) -> Querystring:
627647
return Querystring(array_format="brackets")
628648

649+
async def _refresh_api_key(self) -> None:
650+
if self._api_key_provider:
651+
self.api_key = await self._api_key_provider()
652+
653+
@override
654+
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
655+
await self._refresh_api_key()
656+
return await super()._prepare_options(options)
657+
629658
@property
630659
@override
631660
def auth_headers(self) -> dict[str, str]:
@@ -649,7 +678,7 @@ def default_headers(self) -> dict[str, str | Omit]:
649678
def copy(
650679
self,
651680
*,
652-
api_key: str | None = None,
681+
api_key: str | Callable[[], Awaitable[str]] | None = None,
653682
organization: str | None = None,
654683
project: str | None = None,
655684
webhook_secret: str | None = None,
@@ -687,7 +716,7 @@ def copy(
687716

688717
http_client = http_client or self._client
689718
return self.__class__(
690-
api_key=api_key or self.api_key,
719+
api_key=api_key or self._api_key_provider or self.api_key,
691720
organization=organization or self.organization,
692721
project=project or self.project,
693722
webhook_secret=webhook_secret or self.webhook_secret,

src/openai/lib/azure.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
azure_endpoint: str,
9595
azure_deployment: str | None = None,
9696
api_version: str | None = None,
97-
api_key: str | None = None,
97+
api_key: str | Callable[[], str] | None = None,
9898
azure_ad_token: str | None = None,
9999
azure_ad_token_provider: AzureADTokenProvider | None = None,
100100
organization: str | None = None,
@@ -114,7 +114,7 @@ def __init__(
114114
*,
115115
azure_deployment: str | None = None,
116116
api_version: str | None = None,
117-
api_key: str | None = None,
117+
api_key: str | Callable[[], str] | None = None,
118118
azure_ad_token: str | None = None,
119119
azure_ad_token_provider: AzureADTokenProvider | None = None,
120120
organization: str | None = None,
@@ -134,7 +134,7 @@ def __init__(
134134
*,
135135
base_url: str,
136136
api_version: str | None = None,
137-
api_key: str | None = None,
137+
api_key: str | Callable[[], str] | None = None,
138138
azure_ad_token: str | None = None,
139139
azure_ad_token_provider: AzureADTokenProvider | None = None,
140140
organization: str | None = None,
@@ -154,7 +154,7 @@ def __init__(
154154
api_version: str | None = None,
155155
azure_endpoint: str | None = None,
156156
azure_deployment: str | None = None,
157-
api_key: str | None = None,
157+
api_key: str | Callable[[], str] | None = None,
158158
azure_ad_token: str | None = None,
159159
azure_ad_token_provider: AzureADTokenProvider | None = None,
160160
organization: str | None = None,
@@ -258,7 +258,7 @@ def __init__(
258258
def copy(
259259
self,
260260
*,
261-
api_key: str | None = None,
261+
api_key: str | Callable[[], str] | None = None,
262262
organization: str | None = None,
263263
project: str | None = None,
264264
webhook_secret: str | None = None,
@@ -345,7 +345,7 @@ def _configure_realtime(self, model: str, extra_query: Query) -> tuple[httpx.URL
345345
"api-version": self._api_version,
346346
"deployment": self._azure_deployment or model,
347347
}
348-
if self.api_key != "<missing API key>":
348+
if self.api_key and self.api_key != "<missing API key>":
349349
auth_headers = {"api-key": self.api_key}
350350
else:
351351
token = self._get_azure_ad_token()
@@ -372,7 +372,7 @@ def __init__(
372372
azure_endpoint: str,
373373
azure_deployment: str | None = None,
374374
api_version: str | None = None,
375-
api_key: str | None = None,
375+
api_key: str | Callable[[], Awaitable[str]] | None = None,
376376
azure_ad_token: str | None = None,
377377
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
378378
organization: str | None = None,
@@ -393,7 +393,7 @@ def __init__(
393393
*,
394394
azure_deployment: str | None = None,
395395
api_version: str | None = None,
396-
api_key: str | None = None,
396+
api_key: str | Callable[[], Awaitable[str]] | None = None,
397397
azure_ad_token: str | None = None,
398398
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
399399
organization: str | None = None,
@@ -414,7 +414,7 @@ def __init__(
414414
*,
415415
base_url: str,
416416
api_version: str | None = None,
417-
api_key: str | None = None,
417+
api_key: str | Callable[[], Awaitable[str]] | None = None,
418418
azure_ad_token: str | None = None,
419419
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
420420
organization: str | None = None,
@@ -435,7 +435,7 @@ def __init__(
435435
azure_endpoint: str | None = None,
436436
azure_deployment: str | None = None,
437437
api_version: str | None = None,
438-
api_key: str | None = None,
438+
api_key: str | Callable[[], Awaitable[str]] | None = None,
439439
azure_ad_token: str | None = None,
440440
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
441441
organization: str | None = None,
@@ -539,7 +539,7 @@ def __init__(
539539
def copy(
540540
self,
541541
*,
542-
api_key: str | None = None,
542+
api_key: str | Callable[[], Awaitable[str]] | None = None,
543543
organization: str | None = None,
544544
project: str | None = None,
545545
webhook_secret: str | None = None,
@@ -628,7 +628,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
628628
"api-version": self._api_version,
629629
"deployment": self._azure_deployment or model,
630630
}
631-
if self.api_key != "<missing API key>":
631+
if self.api_key and self.api_key != "<missing API key>":
632632
auth_headers = {"api-key": self.api_key}
633633
else:
634634
token = await self._get_azure_ad_token()

src/openai/resources/beta/realtime/realtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
358358
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
359359

360360
extra_query = self.__extra_query
361+
await self.__client._refresh_api_key()
361362
auth_headers = self.__client.auth_headers
362363
if is_async_azure_client(self.__client):
363364
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection:
540541
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
541542

542543
extra_query = self.__extra_query
544+
self.__client._refresh_api_key()
543545
auth_headers = self.__client.auth_headers
544546
if is_azure_client(self.__client):
545547
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

src/openai/resources/realtime/realtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
326326
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
327327

328328
extra_query = self.__extra_query
329+
await self.__client._refresh_api_key()
329330
auth_headers = self.__client.auth_headers
330331
if is_async_azure_client(self.__client):
331332
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
@@ -507,6 +508,7 @@ def __enter__(self) -> RealtimeConnection:
507508
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc
508509

509510
extra_query = self.__extra_query
511+
self.__client._refresh_api_key()
510512
auth_headers = self.__client.auth_headers
511513
if is_azure_client(self.__client):
512514
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)

0 commit comments

Comments
 (0)