diff --git a/server/polar/discount/service.py b/server/polar/discount/service.py index a1bc54bef1..f385c939d3 100644 --- a/server/polar/discount/service.py +++ b/server/polar/discount/service.py @@ -247,8 +247,14 @@ async def update( discount.discount_products.append(DiscountProduct(product=product)) updated_fields = set() + exclude = {"products"} + if isinstance(discount, DiscountFixed): + exclude.add("basis_points") + else: + exclude.add("amount") + exclude.add("currency") for attr, value in discount_update.model_dump( - exclude_unset=True, exclude={"products"}, by_alias=True + exclude_unset=True, exclude=exclude, by_alias=True ).items(): if value != getattr(discount, attr): setattr(discount, attr, value) diff --git a/server/tests/discount/test_service.py b/server/tests/discount/test_service.py index d806c6c7dd..78ca1e7929 100644 --- a/server/tests/discount/test_service.py +++ b/server/tests/discount/test_service.py @@ -28,7 +28,12 @@ Product, UserOrganization, ) -from polar.models.discount import DiscountDuration, DiscountType +from polar.models.discount import ( + DiscountDuration, + DiscountFixed, + DiscountPercentage, + DiscountType, +) from polar.postgres import AsyncSession from tests.fixtures.database import SaveFixture from tests.fixtures.random_objects import create_checkout, create_discount @@ -185,8 +190,34 @@ async def test_update_forbidden_field_with_redemptions( ), ) + @pytest.mark.parametrize( + "type,payload", + [ + ( + DiscountType.percentage, + DiscountUpdate( + basis_points=2000, + # Make sure passing "currency" doesn't cause AttributeError + # on percentage discounts + currency="usd", + ), + ), + ( + DiscountType.fixed, + DiscountUpdate( + amount=2000, + currency="usd", + # Make sure passing "basis_points" doesn't cause AttributeError + # on percentage discounts + basis_points=2000, + ), + ), + ], + ) async def test_update_sensitive_fields( self, + type: DiscountType, + payload: DiscountUpdate, stripe_service_mock: MagicMock, save_fixture: SaveFixture, session: AsyncSession, @@ -196,24 +227,38 @@ async def test_update_sensitive_fields( id="NEW_STRIPE_COUPON_ID" ) - discount = await create_discount( - save_fixture, - type=DiscountType.percentage, - basis_points=1000, - duration=DiscountDuration.once, - organization=organization, - starts_at=utc_now() - timedelta(days=1), - ends_at=utc_now() + timedelta(days=1), - ) + discount: Discount + if type == DiscountType.percentage: + discount = await create_discount( + save_fixture, + type=DiscountType.percentage, + basis_points=1000, + duration=DiscountDuration.once, + organization=organization, + ) + else: + discount = await create_discount( + save_fixture, + type=DiscountType.fixed, + amount=1000, + currency="usd", + duration=DiscountDuration.once, + organization=organization, + ) old_stripe_coupon_id = discount.stripe_coupon_id updated_ends_at = utc_now() + timedelta(days=2) + payload.ends_at = updated_ends_at updated_discount = await discount_service.update( - session, - discount, - discount_update=DiscountUpdate(ends_at=updated_ends_at), + session, discount, discount_update=payload ) + if isinstance(updated_discount, DiscountPercentage): + assert updated_discount.basis_points == 2000 + elif isinstance(updated_discount, DiscountFixed): + assert updated_discount.amount == 2000 + assert updated_discount.currency == "usd" + assert updated_discount.ends_at == updated_ends_at assert updated_discount.stripe_coupon_id == "NEW_STRIPE_COUPON_ID"