Skip to content

Commit 1329b33

Browse files
authored
Merge pull request #245 from doubledare704/fix-issue-199-relationship-fields
🔧Fix for get_multi returns boolean values for pydantic relationship fields, if a select schema is provided #199
2 parents 622e5ea + 8067710 commit 1329b33

File tree

2 files changed

+214
-1
lines changed

2 files changed

+214
-1
lines changed

fastcrud/crud/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def _extract_matching_columns_from_schema(
7979

8080
if schema:
8181
for field in schema.model_fields.keys():
82-
if hasattr(model_or_alias, field):
82+
if hasattr(model_or_alias, field) and field not in mapper.relationships:
8383
column = getattr(model_or_alias, field)
8484
if prefix is not None or use_temporary_prefix:
8585
column_label = (
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
"""
2+
Test for GitHub issue #199: get_multi returns boolean values for pydantic relationship fields
3+
when a select schema is provided.
4+
5+
This test reproduces the issue where including relationship fields in a Pydantic schema
6+
causes get_multi to execute a cartesian product between tables, resulting in boolean values
7+
instead of the expected data.
8+
"""
9+
10+
import pytest
11+
from typing import Optional
12+
from pydantic import BaseModel, ConfigDict
13+
from fastcrud.crud.fast_crud import FastCRUD
14+
from ...sqlalchemy.conftest import ModelTest, TierModel
15+
16+
17+
class TierReadSchemaIssue199(BaseModel):
18+
"""Schema for tier data in issue #199 tests."""
19+
name: str
20+
21+
22+
class ReadSchemaWithRelationship(BaseModel):
23+
"""Schema that includes a relationship field to test issue #199 fix."""
24+
model_config = ConfigDict(extra="forbid")
25+
name: str
26+
tier_id: int
27+
tier: Optional[TierReadSchemaIssue199] = None
28+
29+
30+
@pytest.fixture(scope="function")
31+
def issue_199_test_data() -> list[dict]:
32+
"""Test data specific to issue #199 reproduction."""
33+
return [
34+
{"id": 1, "name": "Charlie", "tier_id": 1},
35+
{"id": 2, "name": "Alice", "tier_id": 2},
36+
{"id": 3, "name": "Bob", "tier_id": 1},
37+
{"id": 4, "name": "David", "tier_id": 2},
38+
{"id": 5, "name": "Eve", "tier_id": 1},
39+
{"id": 6, "name": "Frank", "tier_id": 2},
40+
{"id": 7, "name": "Grace", "tier_id": 1},
41+
{"id": 8, "name": "Hannah", "tier_id": 2},
42+
{"id": 9, "name": "Ivan", "tier_id": 1},
43+
{"id": 10, "name": "Judy", "tier_id": 2},
44+
{"id": 11, "name": "Alice", "tier_id": 1},
45+
]
46+
47+
48+
@pytest.fixture(scope="function")
49+
def issue_199_tier_data() -> list[dict]:
50+
"""Tier data specific to issue #199 reproduction."""
51+
return [{"id": 1, "name": "Premium"}, {"id": 2, "name": "Basic"}]
52+
53+
54+
@pytest.mark.asyncio
55+
async def test_get_multi_excludes_relationship_fields_from_select(
56+
async_session, issue_199_test_data, issue_199_tier_data
57+
):
58+
"""
59+
Test that reproduces issue #199: get_multi should not include relationship fields
60+
in the SELECT statement when they are present in the schema_to_select.
61+
62+
This test verifies that:
63+
1. No cartesian product is created
64+
2. No boolean values are returned for relationship fields
65+
3. The relationship field is simply excluded from the result
66+
"""
67+
# Setup test data
68+
for tier_item in issue_199_tier_data:
69+
async_session.add(TierModel(**tier_item))
70+
await async_session.commit()
71+
72+
for user_item in issue_199_test_data:
73+
async_session.add(ModelTest(**user_item))
74+
await async_session.commit()
75+
76+
crud = FastCRUD(ModelTest)
77+
78+
# This should NOT create a cartesian product or return boolean values
79+
result = await crud.get_multi(
80+
db=async_session,
81+
schema_to_select=ReadSchemaWithRelationship
82+
)
83+
84+
# Verify the result structure
85+
assert "data" in result
86+
assert "total_count" in result
87+
assert result["total_count"] == len(issue_199_test_data)
88+
89+
# Verify no cartesian product was created (should have 11 records, not 22)
90+
assert len(result["data"]) == len(issue_199_test_data)
91+
92+
# Verify that each record has the expected structure
93+
for item in result["data"]:
94+
assert "name" in item
95+
assert "tier_id" in item
96+
# The relationship field should either be excluded or None, but NOT a boolean
97+
if "tier" in item:
98+
assert item["tier"] is None or isinstance(item["tier"], dict)
99+
assert not isinstance(item["tier"], bool)
100+
101+
# Verify specific data integrity
102+
names_in_result = [item["name"] for item in result["data"]]
103+
expected_names = [item["name"] for item in issue_199_test_data]
104+
assert sorted(names_in_result) == sorted(expected_names)
105+
106+
107+
@pytest.mark.asyncio
108+
async def test_get_multi_return_as_model_with_relationship_fields(
109+
async_session, issue_199_test_data, issue_199_tier_data
110+
):
111+
"""
112+
Test that get_multi with return_as_model=True works correctly when the schema
113+
contains relationship fields.
114+
"""
115+
# Setup test data
116+
for tier_item in issue_199_tier_data:
117+
async_session.add(TierModel(**tier_item))
118+
await async_session.commit()
119+
120+
for user_item in issue_199_test_data:
121+
async_session.add(ModelTest(**user_item))
122+
await async_session.commit()
123+
124+
crud = FastCRUD(ModelTest)
125+
126+
# This should work without validation errors
127+
result = await crud.get_multi(
128+
db=async_session,
129+
schema_to_select=ReadSchemaWithRelationship,
130+
return_as_model=True
131+
)
132+
133+
# Verify the result structure
134+
assert "data" in result
135+
assert "total_count" in result
136+
assert result["total_count"] == len(issue_199_test_data)
137+
assert len(result["data"]) == len(issue_199_test_data)
138+
139+
# Verify that all items are instances of the schema
140+
for item in result["data"]:
141+
assert isinstance(item, ReadSchemaWithRelationship)
142+
assert hasattr(item, "name")
143+
assert hasattr(item, "tier_id")
144+
# The tier field should be None since it's not properly joined
145+
assert item.tier is None
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_get_joined_functionality_unaffected_by_fix(
150+
async_session, issue_199_test_data, issue_199_tier_data
151+
):
152+
"""
153+
Test that get_joined still works correctly and can properly populate relationship fields.
154+
This ensures our fix doesn't break the intended functionality for joins.
155+
"""
156+
# Setup test data
157+
for tier_item in issue_199_tier_data:
158+
async_session.add(TierModel(**tier_item))
159+
await async_session.commit()
160+
161+
for user_item in issue_199_test_data:
162+
async_session.add(ModelTest(**user_item))
163+
await async_session.commit()
164+
165+
crud = FastCRUD(ModelTest)
166+
167+
# This should work correctly with proper joins
168+
result = await crud.get_joined(
169+
db=async_session,
170+
join_model=TierModel,
171+
join_prefix="tier_",
172+
schema_to_select=ReadSchemaWithRelationship,
173+
join_schema_to_select=TierReadSchemaIssue199,
174+
nest_joins=True
175+
)
176+
177+
# Verify the result has the tier information properly nested
178+
assert "name" in result
179+
assert "tier_id" in result
180+
assert "tier" in result
181+
assert isinstance(result["tier"], dict)
182+
assert "name" in result["tier"]
183+
184+
185+
@pytest.mark.asyncio
186+
async def test_relationship_field_exclusion_prevents_cartesian_product(
187+
async_session, issue_199_test_data, issue_199_tier_data
188+
):
189+
"""
190+
Test that specifically verifies the fix prevents cartesian products
191+
by checking the generated SQL doesn't include relationship tables.
192+
"""
193+
# Setup test data
194+
for tier_item in issue_199_tier_data:
195+
async_session.add(TierModel(**tier_item))
196+
await async_session.commit()
197+
198+
for user_item in issue_199_test_data:
199+
async_session.add(ModelTest(**user_item))
200+
await async_session.commit()
201+
202+
crud = FastCRUD(ModelTest)
203+
204+
# Generate the select statement to verify it doesn't include tier table
205+
stmt = await crud.select(schema_to_select=ReadSchemaWithRelationship)
206+
207+
# Convert statement to string to check it doesn't contain tier table
208+
stmt_str = str(stmt.compile(compile_kwargs={"literal_binds": True}))
209+
210+
# The statement should only select from the test table, not include tier
211+
assert "FROM test" in stmt_str
212+
assert "FROM test, tier" not in stmt_str # No cartesian product
213+
assert "tier.id = test.tier_id" not in stmt_str # No relationship field in SELECT

0 commit comments

Comments
 (0)