Skip to content

Commit 791c567

Browse files
committed
refactor(tests): share snapshot utils
1 parent 8107db8 commit 791c567

File tree

4 files changed

+118
-105
lines changed

4 files changed

+118
-105
lines changed

tests/lib/chat/test_completions.py

Lines changed: 17 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
from __future__ import annotations
22

3-
import os
4-
import json
53
from enum import Enum
6-
from typing import Any, List, Callable, Optional, Awaitable
4+
from typing import List, Optional
75
from typing_extensions import Literal, TypeVar
86

9-
import httpx
107
import pytest
118
from respx import MockRouter
129
from pydantic import Field, BaseModel
@@ -17,8 +14,9 @@
1714
from openai._utils import assert_signatures_in_sync
1815
from openai._compat import PYDANTIC_V2
1916

20-
from ._utils import print_obj, get_snapshot_value
17+
from ..utils import print_obj
2118
from ...conftest import base_url
19+
from ..snapshots import make_snapshot_request, make_async_snapshot_request
2220
from ..schema_types.query import Query
2321

2422
_T = TypeVar("_T")
@@ -32,7 +30,7 @@
3230

3331
@pytest.mark.respx(base_url=base_url)
3432
def test_parse_nothing(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
35-
completion = _make_snapshot_request(
33+
completion = make_snapshot_request(
3634
lambda c: c.chat.completions.parse(
3735
model="gpt-4o-2024-08-06",
3836
messages=[
@@ -100,7 +98,7 @@ class Location(BaseModel):
10098
temperature: float
10199
units: Literal["c", "f"]
102100

103-
completion = _make_snapshot_request(
101+
completion = make_snapshot_request(
104102
lambda c: c.chat.completions.parse(
105103
model="gpt-4o-2024-08-06",
106104
messages=[
@@ -170,7 +168,7 @@ class Location(BaseModel):
170168
temperature: float
171169
units: Optional[Literal["c", "f"]] = None
172170

173-
completion = _make_snapshot_request(
171+
completion = make_snapshot_request(
174172
lambda c: c.chat.completions.parse(
175173
model="gpt-4o-2024-08-06",
176174
messages=[
@@ -247,7 +245,7 @@ class ColorDetection(BaseModel):
247245
if not PYDANTIC_V2:
248246
ColorDetection.update_forward_refs(**locals()) # type: ignore
249247

250-
completion = _make_snapshot_request(
248+
completion = make_snapshot_request(
251249
lambda c: c.chat.completions.parse(
252250
model="gpt-4o-2024-08-06",
253251
messages=[
@@ -292,7 +290,7 @@ class Location(BaseModel):
292290
temperature: float
293291
units: Literal["c", "f"]
294292

295-
completion = _make_snapshot_request(
293+
completion = make_snapshot_request(
296294
lambda c: c.chat.completions.parse(
297295
model="gpt-4o-2024-08-06",
298296
messages=[
@@ -375,7 +373,7 @@ class CalendarEvent:
375373
date: str
376374
participants: List[str]
377375

378-
completion = _make_snapshot_request(
376+
completion = make_snapshot_request(
379377
lambda c: c.chat.completions.parse(
380378
model="gpt-4o-2024-08-06",
381379
messages=[
@@ -436,7 +434,7 @@ class CalendarEvent:
436434

437435
@pytest.mark.respx(base_url=base_url)
438436
def test_pydantic_tool_model_all_types(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
439-
completion = _make_snapshot_request(
437+
completion = make_snapshot_request(
440438
lambda c: c.chat.completions.parse(
441439
model="gpt-4o-2024-08-06",
442440
messages=[
@@ -521,7 +519,7 @@ class Location(BaseModel):
521519
units: Literal["c", "f"]
522520

523521
with pytest.raises(openai.LengthFinishReasonError):
524-
_make_snapshot_request(
522+
make_snapshot_request(
525523
lambda c: c.chat.completions.parse(
526524
model="gpt-4o-2024-08-06",
527525
messages=[
@@ -548,7 +546,7 @@ class Location(BaseModel):
548546
temperature: float
549547
units: Literal["c", "f"]
550548

551-
completion = _make_snapshot_request(
549+
completion = make_snapshot_request(
552550
lambda c: c.chat.completions.parse(
553551
model="gpt-4o-2024-08-06",
554552
messages=[
@@ -596,7 +594,7 @@ class GetWeatherArgs(BaseModel):
596594
country: str
597595
units: Literal["c", "f"] = "c"
598596

599-
completion = _make_snapshot_request(
597+
completion = make_snapshot_request(
600598
lambda c: c.chat.completions.parse(
601599
model="gpt-4o-2024-08-06",
602600
messages=[
@@ -662,7 +660,7 @@ class GetStockPrice(BaseModel):
662660
ticker: str
663661
exchange: str
664662

665-
completion = _make_snapshot_request(
663+
completion = make_snapshot_request(
666664
lambda c: c.chat.completions.parse(
667665
model="gpt-4o-2024-08-06",
668666
messages=[
@@ -733,7 +731,7 @@ class GetStockPrice(BaseModel):
733731

734732
@pytest.mark.respx(base_url=base_url)
735733
def test_parse_strict_tools(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
736-
completion = _make_snapshot_request(
734+
completion = make_snapshot_request(
737735
lambda c: c.chat.completions.parse(
738736
model="gpt-4o-2024-08-06",
739737
messages=[
@@ -830,7 +828,7 @@ class Location(BaseModel):
830828
temperature: float
831829
units: Literal["c", "f"]
832830

833-
response = _make_snapshot_request(
831+
response = make_snapshot_request(
834832
lambda c: c.chat.completions.with_raw_response.parse(
835833
model="gpt-4o-2024-08-06",
836834
messages=[
@@ -906,7 +904,7 @@ class Location(BaseModel):
906904
temperature: float
907905
units: Literal["c", "f"]
908906

909-
response = await _make_async_snapshot_request(
907+
response = await make_async_snapshot_request(
910908
lambda c: c.chat.completions.with_raw_response.parse(
911909
model="gpt-4o-2024-08-06",
912910
messages=[
@@ -981,87 +979,3 @@ def test_parse_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpe
981979
checking_client.chat.completions.parse,
982980
exclude_params={"response_format", "stream"},
983981
)
984-
985-
986-
def _make_snapshot_request(
987-
func: Callable[[OpenAI], _T],
988-
*,
989-
content_snapshot: Any,
990-
respx_mock: MockRouter,
991-
mock_client: OpenAI,
992-
) -> _T:
993-
live = os.environ.get("OPENAI_LIVE") == "1"
994-
if live:
995-
996-
def _on_response(response: httpx.Response) -> None:
997-
# update the content snapshot
998-
assert json.dumps(json.loads(response.read())) == content_snapshot
999-
1000-
respx_mock.stop()
1001-
1002-
client = OpenAI(
1003-
http_client=httpx.Client(
1004-
event_hooks={
1005-
"response": [_on_response],
1006-
}
1007-
)
1008-
)
1009-
else:
1010-
respx_mock.post("/chat/completions").mock(
1011-
return_value=httpx.Response(
1012-
200,
1013-
content=get_snapshot_value(content_snapshot),
1014-
headers={"content-type": "application/json"},
1015-
)
1016-
)
1017-
1018-
client = mock_client
1019-
1020-
result = func(client)
1021-
1022-
if live:
1023-
client.close()
1024-
1025-
return result
1026-
1027-
1028-
async def _make_async_snapshot_request(
1029-
func: Callable[[AsyncOpenAI], Awaitable[_T]],
1030-
*,
1031-
content_snapshot: Any,
1032-
respx_mock: MockRouter,
1033-
mock_client: AsyncOpenAI,
1034-
) -> _T:
1035-
live = os.environ.get("OPENAI_LIVE") == "1"
1036-
if live:
1037-
1038-
async def _on_response(response: httpx.Response) -> None:
1039-
# update the content snapshot
1040-
assert json.dumps(json.loads(await response.aread())) == content_snapshot
1041-
1042-
respx_mock.stop()
1043-
1044-
client = AsyncOpenAI(
1045-
http_client=httpx.AsyncClient(
1046-
event_hooks={
1047-
"response": [_on_response],
1048-
}
1049-
)
1050-
)
1051-
else:
1052-
respx_mock.post("/chat/completions").mock(
1053-
return_value=httpx.Response(
1054-
200,
1055-
content=get_snapshot_value(content_snapshot),
1056-
headers={"content-type": "application/json"},
1057-
)
1058-
)
1059-
1060-
client = mock_client
1061-
1062-
result = await func(client)
1063-
1064-
if live:
1065-
await client.close()
1066-
1067-
return result

tests/lib/chat/test_completions_streaming.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
)
3131
from openai.lib._parsing._completions import ResponseFormatT
3232

33-
from ._utils import print_obj, get_snapshot_value
33+
from ..utils import print_obj, get_snapshot_value
3434
from ...conftest import base_url
3535

3636
_T = TypeVar("_T")

tests/lib/snapshots.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import json
5+
from typing import Any, Callable, Awaitable
6+
from typing_extensions import TypeVar
7+
8+
import httpx
9+
from respx import MockRouter
10+
11+
from openai import OpenAI, AsyncOpenAI
12+
13+
from .utils import get_snapshot_value
14+
15+
_T = TypeVar("_T")
16+
17+
18+
def make_snapshot_request(
19+
func: Callable[[OpenAI], _T],
20+
*,
21+
content_snapshot: Any,
22+
respx_mock: MockRouter,
23+
mock_client: OpenAI,
24+
) -> _T:
25+
live = os.environ.get("OPENAI_LIVE") == "1"
26+
if live:
27+
28+
def _on_response(response: httpx.Response) -> None:
29+
# update the content snapshot
30+
assert json.dumps(json.loads(response.read())) == content_snapshot
31+
32+
respx_mock.stop()
33+
34+
client = OpenAI(
35+
http_client=httpx.Client(
36+
event_hooks={
37+
"response": [_on_response],
38+
}
39+
)
40+
)
41+
else:
42+
respx_mock.post("/chat/completions").mock(
43+
return_value=httpx.Response(
44+
200,
45+
content=get_snapshot_value(content_snapshot),
46+
headers={"content-type": "application/json"},
47+
)
48+
)
49+
50+
client = mock_client
51+
52+
result = func(client)
53+
54+
if live:
55+
client.close()
56+
57+
return result
58+
59+
60+
async def make_async_snapshot_request(
61+
func: Callable[[AsyncOpenAI], Awaitable[_T]],
62+
*,
63+
content_snapshot: Any,
64+
respx_mock: MockRouter,
65+
mock_client: AsyncOpenAI,
66+
) -> _T:
67+
live = os.environ.get("OPENAI_LIVE") == "1"
68+
if live:
69+
70+
async def _on_response(response: httpx.Response) -> None:
71+
# update the content snapshot
72+
assert json.dumps(json.loads(await response.aread())) == content_snapshot
73+
74+
respx_mock.stop()
75+
76+
client = AsyncOpenAI(
77+
http_client=httpx.AsyncClient(
78+
event_hooks={
79+
"response": [_on_response],
80+
}
81+
)
82+
)
83+
else:
84+
respx_mock.post("/chat/completions").mock(
85+
return_value=httpx.Response(
86+
200,
87+
content=get_snapshot_value(content_snapshot),
88+
headers={"content-type": "application/json"},
89+
)
90+
)
91+
92+
client = mock_client
93+
94+
result = await func(client)
95+
96+
if live:
97+
await client.close()
98+
99+
return result

tests/lib/chat/_utils.py renamed to tests/lib/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88
import pydantic
99

10-
from ...utils import rich_print_str
10+
from ..utils import rich_print_str
1111

1212
ReprArgs: TypeAlias = "Iterable[tuple[str | None, Any]]"
1313

0 commit comments

Comments
 (0)