Skip to content

✅ Refactor tests, consolidate into a single test file for multiple variants #1409

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Prev Previous commit
Next Next commit
Jules was unable to complete the task in time. Please review the work…
… done so far and provide feedback for Jules to continue.
  • Loading branch information
google-labs-jules[bot] committed Jun 20, 2025
commit 009b161bb71e3d78e31ff15d2d4e21b4e0457d2b
11 changes: 5 additions & 6 deletions tests/test_advanced/test_decimal/test_tutorial001.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import importlib
import types # Add import for types
import types
from decimal import Decimal
from unittest.mock import MagicMock # Keep MagicMock for type hint, though not strictly necessary for runtime

import pytest
from sqlmodel import create_engine

from ...conftest import PrintMock, needs_py310 # Import PrintMock for type hint
from ...conftest import needs_py310, PrintMock # Import PrintMock for type hint

expected_calls = [
[
Expand Down Expand Up @@ -44,10 +45,8 @@ def get_module(request: pytest.FixtureRequest):
return importlib.import_module(f"docs_src.advanced.decimal.{module_name}")


def test_tutorial(
print_mock: PrintMock, module: types.ModuleType
): # Use PrintMock for type hint and types.ModuleType
def test_tutorial(print_mock: PrintMock, module: types.ModuleType):
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)
module.main()
assert print_mock.calls == expected_calls # Use .calls instead of .mock_calls
assert print_mock.calls == expected_calls # Use .calls instead of .mock_calls
3 changes: 2 additions & 1 deletion tests/test_advanced/test_uuid/test_tutorial001.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import types

import pytest
from dirty_equals import IsUUID
Expand All @@ -19,7 +20,7 @@ def get_module(request: pytest.FixtureRequest):
return importlib.import_module(f"docs_src.advanced.uuid.{module_name}")


def test_tutorial(print_mock: PrintMock, module: type) -> None:
def test_tutorial(print_mock: PrintMock, module: types.ModuleType) -> None:
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)

Expand Down
3 changes: 2 additions & 1 deletion tests/test_advanced/test_uuid/test_tutorial002.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import types

import pytest
from dirty_equals import IsUUID
Expand All @@ -19,7 +20,7 @@ def get_module(request: pytest.FixtureRequest):
return importlib.import_module(f"docs_src.advanced.uuid.{module_name}")


def test_tutorial(print_mock: PrintMock, module: type) -> None:
def test_tutorial(print_mock: PrintMock, module: types.ModuleType) -> None:
module.sqlite_url = "sqlite://"
module.engine = create_engine(module.sqlite_url)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param
mod = importlib.import_module(f"docs_src.tutorial.connect.delete.{module_name}")
mod = importlib.import_module(
f"docs_src.tutorial.connect.delete.{module_name}"
)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
return mod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param
mod = importlib.import_module(f"docs_src.tutorial.connect.insert.{module_name}")
mod = importlib.import_module(
f"docs_src.tutorial.connect.insert.{module_name}"
)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
return mod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param
mod = importlib.import_module(f"docs_src.tutorial.connect.select.{module_name}")
mod = importlib.import_module(
f"docs_src.tutorial.connect.select.{module_name}"
)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
return mod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param
mod = importlib.import_module(f"docs_src.tutorial.connect.select.{module_name}")
mod = importlib.import_module(
f"docs_src.tutorial.connect.select.{module_name}"
)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
return mod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param
mod = importlib.import_module(f"docs_src.tutorial.connect.select.{module_name}")
mod = importlib.import_module(
f"docs_src.tutorial.connect.select.{module_name}"
)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
return mod
Expand Down
10 changes: 5 additions & 5 deletions tests/test_tutorial/test_connect/test_update/test_tutorial001.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
from types import ModuleType
from typing import Any # For clear_sqlmodel type hint
from typing import Any # For clear_sqlmodel type hint

import pytest
from sqlmodel import create_engine
Expand Down Expand Up @@ -60,14 +60,14 @@
)
def get_module(request: pytest.FixtureRequest) -> ModuleType:
module_name = request.param
mod = importlib.import_module(f"docs_src.tutorial.connect.update.{module_name}")
mod = importlib.import_module(
f"docs_src.tutorial.connect.update.{module_name}"
)
mod.sqlite_url = "sqlite://"
mod.engine = create_engine(mod.sqlite_url)
return mod


def test_tutorial(
clear_sqlmodel: Any, print_mock: PrintMock, module: ModuleType
) -> None:
def test_tutorial(clear_sqlmodel: Any, print_mock: PrintMock, module: ModuleType) -> None:
module.main()
assert print_mock.calls == expected_calls
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
from types import ModuleType
from typing import Any # For clear_sqlmodel type hint
from typing import Any # For clear_sqlmodel type hint

import pytest
from sqlalchemy import inspect
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import importlib
from types import ModuleType
from typing import Any # For clear_sqlmodel type hint
from typing import Any # For clear_sqlmodel type hint

import pytest
from sqlalchemy import inspect
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
import importlib
import sys # Add sys import
import sys
from types import ModuleType
from typing import Any, Generator

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session, SQLModel, create_engine # Keep this for session_fixture
from sqlmodel.pool import StaticPool # Keep this for session_fixture
from sqlmodel import Session, SQLModel, create_engine # Keep this for session_fixture
from sqlmodel.pool import StaticPool # Keep this for session_fixture

from ....conftest import needs_py39, needs_py310


# This will be our parametrized fixture providing the versioned 'main' module
@pytest.fixture(
name="module",
Expand All @@ -21,9 +20,7 @@
pytest.param("tutorial001_py310", marks=needs_py310),
],
)
def get_module(
request: pytest.FixtureRequest, clear_sqlmodel: Any
) -> ModuleType: # clear_sqlmodel is autouse
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType: # clear_sqlmodel is autouse
module_name = f"docs_src.tutorial.fastapi.app_testing.{request.param}.main"

# Forcing reload to try to get a fresh state for models
Expand All @@ -33,7 +30,6 @@ def get_module(
module = importlib.import_module(module_name)
return module


@pytest.fixture(name="session", scope="function")
def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
# Store original engine-related attributes from the module
Expand All @@ -43,13 +39,13 @@ def session_fixture(module: ModuleType) -> Generator[Session, None, None]:

# Force module to use a fresh in-memory SQLite DB for this test run
module.sqlite_url = "sqlite://"
module.connect_args = {"check_same_thread": False} # Crucial for FastAPI + SQLite
module.connect_args = {"check_same_thread": False} # Crucial for FastAPI + SQLite

# Re-create the engine in the module to use these new settings
test_engine = create_engine(
module.sqlite_url,
connect_args=module.connect_args,
poolclass=StaticPool, # Recommended for tests
poolclass=StaticPool # Recommended for tests
)
module.engine = test_engine

Expand All @@ -59,9 +55,7 @@ def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
# Fallback if the function isn't named create_db_and_tables
SQLModel.metadata.create_all(module.engine)

with Session(
module.engine
) as session: # Use the module's (now test-configured) engine
with Session(module.engine) as session: # Use the module's (now test-configured) engine
yield session

# Teardown: drop tables from the module's engine
Expand All @@ -74,16 +68,14 @@ def session_fixture(module: ModuleType) -> Generator[Session, None, None]:
module.connect_args = original_connect_args
if original_engine is not None:
module.engine = original_engine
else: # If engine didn't exist, remove the one we created
else: # If engine didn't exist, remove the one we created
if hasattr(module, "engine"):
del module.engine


@pytest.fixture(name="client", scope="function")
def client_fixture(
session: Session, module: ModuleType
) -> Generator[TestClient, None, None]:
def get_session_override() -> Generator[Session, None, None]: # Must be a generator
def client_fixture(session: Session, module: ModuleType) -> Generator[TestClient, None, None]:
def get_session_override() -> Generator[Session, None, None]: # Must be a generator
yield session

module.app.dependency_overrides[module.get_session] = get_session_override
Expand Down Expand Up @@ -148,7 +140,7 @@ def test_read_heroes(session: Session, client: TestClient, module: ModuleType):


def test_read_hero(session: Session, client: TestClient, module: ModuleType):
hero_1 = module.Hero(name="Deadpond", secret_name="Dive Wilson") # Use module.Hero
hero_1 = module.Hero(name="Deadpond", secret_name="Dive Wilson") # Use module.Hero
session.add(hero_1)
session.commit()

Expand All @@ -163,7 +155,7 @@ def test_read_hero(session: Session, client: TestClient, module: ModuleType):


def test_update_hero(session: Session, client: TestClient, module: ModuleType):
hero_1 = module.Hero(name="Deadpond", secret_name="Dive Wilson") # Use module.Hero
hero_1 = module.Hero(name="Deadpond", secret_name="Dive Wilson") # Use module.Hero
session.add(hero_1)
session.commit()

Expand All @@ -178,13 +170,13 @@ def test_update_hero(session: Session, client: TestClient, module: ModuleType):


def test_delete_hero(session: Session, client: TestClient, module: ModuleType):
hero_1 = module.Hero(name="Deadpond", secret_name="Dive Wilson") # Use module.Hero
hero_1 = module.Hero(name="Deadpond", secret_name="Dive Wilson") # Use module.Hero
session.add(hero_1)
session.commit()

response = client.delete(f"/heroes/{hero_1.id}")

hero_in_db = session.get(module.Hero, hero_1.id) # Use module.Hero
hero_in_db = session.get(module.Hero, hero_1.id) # Use module.Hero

assert response.status_code == 200
assert hero_in_db is None
36 changes: 14 additions & 22 deletions tests/test_tutorial/test_fastapi/test_delete/test_tutorial001.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import importlib
import sys
from types import ModuleType
from typing import Any # For clear_sqlmodel type hint
from typing import Any # For clear_sqlmodel type hint

import pytest
from dirty_equals import IsDict
from fastapi.testclient import TestClient
from sqlmodel import SQLModel, create_engine # Import SQLModel for metadata operations
from sqlmodel import SQLModel, create_engine # Import SQLModel for metadata operations
from sqlmodel.pool import StaticPool

from ....conftest import needs_py39, needs_py310
Expand All @@ -22,7 +22,7 @@
],
)
def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleType:
module_name = f"docs_src.tutorial.fastapi.delete.{request.param}" # No .main here
module_name = f"docs_src.tutorial.fastapi.delete.{request.param}" # No .main here
if module_name in sys.modules:
module = importlib.reload(sys.modules[module_name])
else:
Expand All @@ -34,23 +34,19 @@ def get_module(request: pytest.FixtureRequest, clear_sqlmodel: Any) -> ModuleTyp
module.sqlite_url = "sqlite://"
module.engine = create_engine(
module.sqlite_url,
connect_args={"check_same_thread": False}, # connect_args from original main.py
poolclass=StaticPool,
connect_args={"check_same_thread": False}, # connect_args from original main.py
poolclass=StaticPool
)
# Assuming the module has a create_db_and_tables or similar, or uses SQLModel.metadata directly
if hasattr(module, "create_db_and_tables"):
module.create_db_and_tables()
else:
SQLModel.metadata.create_all(
module.engine
) # Fallback, ensure tables are created
SQLModel.metadata.create_all(module.engine) # Fallback, ensure tables are created

return module


def test_tutorial(
clear_sqlmodel: Any, module: ModuleType
): # clear_sqlmodel is autouse but explicit for safety
def test_tutorial(clear_sqlmodel: Any, module: ModuleType): # clear_sqlmodel is autouse but explicit for safety
# The engine and tables are now set up by the 'module' fixture
# The app's dependency overrides for get_session will use module.engine

Expand All @@ -60,7 +56,7 @@ def test_tutorial(
hero2_data = {
"name": "Spider-Boy",
"secret_name": "Pedro Parqueador",
"id": 9000, # Note: ID is part of creation data here
"id": 9000, # Note: ID is part of creation data here
}
hero3_data = {
"name": "Rusty-Man",
Expand All @@ -69,15 +65,13 @@ def test_tutorial(
}
response = client.post("/heroes/", json=hero1_data)
assert response.status_code == 200, response.text
hero1 = response.json() # Get actual ID of hero1
hero1 = response.json() # Get actual ID of hero1
hero1_id = hero1["id"]

response = client.post("/heroes/", json=hero2_data)
assert response.status_code == 200, response.text
hero2 = response.json()
hero2_id = hero2[
"id"
] # This will be the ID assigned by DB, not 9000 if 9000 is not allowed on POST
hero2_id = hero2["id"] # This will be the ID assigned by DB, not 9000 if 9000 is not allowed on POST

response = client.post("/heroes/", json=hero3_data)
assert response.status_code == 200, response.text
Expand All @@ -92,8 +86,8 @@ def test_tutorial(
# For robustness, let's check for a non-existent ID based on actual data.
# If hero2_id is 1, check for 9000. If it's 9000, check for 1 (assuming hero1_id is 1).
non_existent_id_check = 9000
if hero2_id == non_existent_id_check: # if DB somehow used 9000
non_existent_id_check = hero1_id + hero2_id + 100 # just some other ID
if hero2_id == non_existent_id_check: # if DB somehow used 9000
non_existent_id_check = hero1_id + hero2_id + 100 # just some other ID

response = client.get(f"/heroes/{non_existent_id_check}")
assert response.status_code == 404, response.text
Expand All @@ -108,9 +102,7 @@ def test_tutorial(
)
assert response.status_code == 200, response.text

response = client.patch(
f"/heroes/{non_existent_id_check}", json={"name": "Dragon Cube X"}
)
response = client.patch(f"/heroes/{non_existent_id_check}", json={"name": "Dragon Cube X"})
assert response.status_code == 404, response.text

response = client.delete(f"/heroes/{hero2_id}")
Expand All @@ -119,7 +111,7 @@ def test_tutorial(
response = client.get("/heroes/")
assert response.status_code == 200, response.text
data = response.json()
assert len(data) == 2 # After deleting one hero
assert len(data) == 2 # After deleting one hero

response = client.delete(f"/heroes/{non_existent_id_check}")
assert response.status_code == 404, response.text
Expand Down
Loading
Loading