Skip to content

Commit 0d32988

Browse files
committed
server/discount: fix attributeerror when trying to update field not applicable to current discount
1 parent 9f6520b commit 0d32988

File tree

2 files changed

+65
-14
lines changed

2 files changed

+65
-14
lines changed

server/polar/discount/service.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,14 @@ async def update(
247247
discount.discount_products.append(DiscountProduct(product=product))
248248

249249
updated_fields = set()
250+
exclude = {"products"}
251+
if isinstance(discount, DiscountFixed):
252+
exclude.add("basis_points")
253+
else:
254+
exclude.add("amount")
255+
exclude.add("currency")
250256
for attr, value in discount_update.model_dump(
251-
exclude_unset=True, exclude={"products"}, by_alias=True
257+
exclude_unset=True, exclude=exclude, by_alias=True
252258
).items():
253259
if value != getattr(discount, attr):
254260
setattr(discount, attr, value)

server/tests/discount/test_service.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@
2828
Product,
2929
UserOrganization,
3030
)
31-
from polar.models.discount import DiscountDuration, DiscountType
31+
from polar.models.discount import (
32+
DiscountDuration,
33+
DiscountFixed,
34+
DiscountPercentage,
35+
DiscountType,
36+
)
3237
from polar.postgres import AsyncSession
3338
from tests.fixtures.database import SaveFixture
3439
from tests.fixtures.random_objects import create_checkout, create_discount
@@ -185,8 +190,34 @@ async def test_update_forbidden_field_with_redemptions(
185190
),
186191
)
187192

193+
@pytest.mark.parametrize(
194+
"type,payload",
195+
[
196+
(
197+
DiscountType.percentage,
198+
DiscountUpdate(
199+
basis_points=2000,
200+
# Make sure passing "currency" doesn't cause AttributeError
201+
# on percentage discounts
202+
currency="usd",
203+
),
204+
),
205+
(
206+
DiscountType.fixed,
207+
DiscountUpdate(
208+
amount=2000,
209+
currency="usd",
210+
# Make sure passing "basis_points" doesn't cause AttributeError
211+
# on percentage discounts
212+
basis_points=2000,
213+
),
214+
),
215+
],
216+
)
188217
async def test_update_sensitive_fields(
189218
self,
219+
type: DiscountType,
220+
payload: DiscountUpdate,
190221
stripe_service_mock: MagicMock,
191222
save_fixture: SaveFixture,
192223
session: AsyncSession,
@@ -196,24 +227,38 @@ async def test_update_sensitive_fields(
196227
id="NEW_STRIPE_COUPON_ID"
197228
)
198229

199-
discount = await create_discount(
200-
save_fixture,
201-
type=DiscountType.percentage,
202-
basis_points=1000,
203-
duration=DiscountDuration.once,
204-
organization=organization,
205-
starts_at=utc_now() - timedelta(days=1),
206-
ends_at=utc_now() + timedelta(days=1),
207-
)
230+
discount: Discount
231+
if type == DiscountType.percentage:
232+
discount = await create_discount(
233+
save_fixture,
234+
type=DiscountType.percentage,
235+
basis_points=1000,
236+
duration=DiscountDuration.once,
237+
organization=organization,
238+
)
239+
else:
240+
discount = await create_discount(
241+
save_fixture,
242+
type=DiscountType.fixed,
243+
amount=1000,
244+
currency="usd",
245+
duration=DiscountDuration.once,
246+
organization=organization,
247+
)
208248
old_stripe_coupon_id = discount.stripe_coupon_id
209249

210250
updated_ends_at = utc_now() + timedelta(days=2)
251+
payload.ends_at = updated_ends_at
211252
updated_discount = await discount_service.update(
212-
session,
213-
discount,
214-
discount_update=DiscountUpdate(ends_at=updated_ends_at),
253+
session, discount, discount_update=payload
215254
)
216255

256+
if isinstance(updated_discount, DiscountPercentage):
257+
assert updated_discount.basis_points == 2000
258+
elif isinstance(updated_discount, DiscountFixed):
259+
assert updated_discount.amount == 2000
260+
assert updated_discount.currency == "usd"
261+
217262
assert updated_discount.ends_at == updated_ends_at
218263
assert updated_discount.stripe_coupon_id == "NEW_STRIPE_COUPON_ID"
219264

0 commit comments

Comments
 (0)