Skip to content

Commit a472cc2

Browse files
committed
server/discount: fix attributeerror when trying to update field not applicable to current discount
1 parent 70b30bd commit a472cc2

File tree

2 files changed

+64
-14
lines changed

2 files changed

+64
-14
lines changed

server/polar/discount/service.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,14 @@ async def update(
247247
discount.discount_products.append(DiscountProduct(product=product))
248248

249249
updated_fields = set()
250+
exclude = {"products"}
251+
if isinstance(discount, DiscountFixed):
252+
exclude.add("basis_points")
253+
else:
254+
exclude.add("amount")
255+
exclude.add("currency")
250256
for attr, value in discount_update.model_dump(
251-
exclude_unset=True, exclude={"products"}, by_alias=True
257+
exclude_unset=True, exclude=exclude, by_alias=True
252258
).items():
253259
if value != getattr(discount, attr):
254260
setattr(discount, attr, value)

server/tests/discount/test_service.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
Product,
2929
UserOrganization,
3030
)
31-
from polar.models.discount import DiscountDuration, DiscountType
31+
from polar.models.discount import (
32+
DiscountDuration,
33+
DiscountFixed,
34+
DiscountPercentage,
35+
DiscountType,
36+
)
3237
from polar.postgres import AsyncSession
3338
from tests.fixtures.database import SaveFixture
3439
from tests.fixtures.random_objects import create_checkout, create_discount
@@ -185,8 +190,34 @@ async def test_update_forbidden_field_with_redemptions(
185190
),
186191
)
187192

193+
@pytest.mark.parametrize(
194+
"type,payload",
195+
[
196+
(
197+
DiscountType.percentage,
198+
DiscountUpdate(
199+
basis_points=2000,
200+
# Make sure passing "currency" doesn't cause AttributeError
201+
# on percentage discounts
202+
currency="usd",
203+
),
204+
),
205+
(
206+
DiscountType.fixed,
207+
DiscountUpdate(
208+
amount=2000,
209+
currency="usd",
210+
# Make sure passing "basis_points" doesn't cause AttributeError
211+
# on percentage discounts
212+
basis_points=2000,
213+
),
214+
),
215+
],
216+
)
188217
async def test_update_sensitive_fields(
189218
self,
219+
type: DiscountType,
220+
payload: DiscountUpdate,
190221
stripe_service_mock: MagicMock,
191222
save_fixture: SaveFixture,
192223
session: AsyncSession,
@@ -196,24 +227,37 @@ async def test_update_sensitive_fields(
196227
id="NEW_STRIPE_COUPON_ID"
197228
)
198229

199-
discount = await create_discount(
200-
save_fixture,
201-
type=DiscountType.percentage,
202-
basis_points=1000,
203-
duration=DiscountDuration.once,
204-
organization=organization,
205-
starts_at=utc_now() - timedelta(days=1),
206-
ends_at=utc_now() + timedelta(days=1),
207-
)
230+
if type == DiscountType.percentage:
231+
discount = await create_discount(
232+
save_fixture,
233+
type=DiscountType.percentage,
234+
basis_points=1000,
235+
duration=DiscountDuration.once,
236+
organization=organization,
237+
)
238+
else:
239+
discount = await create_discount(
240+
save_fixture,
241+
type=DiscountType.fixed,
242+
amount=1000,
243+
currency="usd",
244+
duration=DiscountDuration.once,
245+
organization=organization,
246+
)
208247
old_stripe_coupon_id = discount.stripe_coupon_id
209248

210249
updated_ends_at = utc_now() + timedelta(days=2)
250+
payload.ends_at = updated_ends_at
211251
updated_discount = await discount_service.update(
212-
session,
213-
discount,
214-
discount_update=DiscountUpdate(ends_at=updated_ends_at),
252+
session, discount, discount_update=payload
215253
)
216254

255+
if isinstance(updated_discount, DiscountPercentage):
256+
assert updated_discount.basis_points == 2000
257+
elif isinstance(updated_discount, DiscountFixed):
258+
assert updated_discount.amount == 2000
259+
assert updated_discount.currency == "usd"
260+
217261
assert updated_discount.ends_at == updated_ends_at
218262
assert updated_discount.stripe_coupon_id == "NEW_STRIPE_COUPON_ID"
219263

0 commit comments

Comments
 (0)