28
28
Product ,
29
29
UserOrganization ,
30
30
)
31
- from polar .models .discount import DiscountDuration , DiscountType
31
+ from polar .models .discount import (
32
+ DiscountDuration ,
33
+ DiscountFixed ,
34
+ DiscountPercentage ,
35
+ DiscountType ,
36
+ )
32
37
from polar .postgres import AsyncSession
33
38
from tests .fixtures .database import SaveFixture
34
39
from tests .fixtures .random_objects import create_checkout , create_discount
@@ -185,8 +190,34 @@ async def test_update_forbidden_field_with_redemptions(
185
190
),
186
191
)
187
192
193
+ @pytest .mark .parametrize (
194
+ "type,payload" ,
195
+ [
196
+ (
197
+ DiscountType .percentage ,
198
+ DiscountUpdate (
199
+ basis_points = 2000 ,
200
+ # Make sure passing "currency" doesn't cause AttributeError
201
+ # on percentage discounts
202
+ currency = "usd" ,
203
+ ),
204
+ ),
205
+ (
206
+ DiscountType .fixed ,
207
+ DiscountUpdate (
208
+ amount = 2000 ,
209
+ currency = "usd" ,
210
+ # Make sure passing "basis_points" doesn't cause AttributeError
211
+ # on percentage discounts
212
+ basis_points = 2000 ,
213
+ ),
214
+ ),
215
+ ],
216
+ )
188
217
async def test_update_sensitive_fields (
189
218
self ,
219
+ type : DiscountType ,
220
+ payload : DiscountUpdate ,
190
221
stripe_service_mock : MagicMock ,
191
222
save_fixture : SaveFixture ,
192
223
session : AsyncSession ,
@@ -196,24 +227,37 @@ async def test_update_sensitive_fields(
196
227
id = "NEW_STRIPE_COUPON_ID"
197
228
)
198
229
199
- discount = await create_discount (
200
- save_fixture ,
201
- type = DiscountType .percentage ,
202
- basis_points = 1000 ,
203
- duration = DiscountDuration .once ,
204
- organization = organization ,
205
- starts_at = utc_now () - timedelta (days = 1 ),
206
- ends_at = utc_now () + timedelta (days = 1 ),
207
- )
230
+ if type == DiscountType .percentage :
231
+ discount = await create_discount (
232
+ save_fixture ,
233
+ type = DiscountType .percentage ,
234
+ basis_points = 1000 ,
235
+ duration = DiscountDuration .once ,
236
+ organization = organization ,
237
+ )
238
+ else :
239
+ discount = await create_discount (
240
+ save_fixture ,
241
+ type = DiscountType .fixed ,
242
+ amount = 1000 ,
243
+ currency = "usd" ,
244
+ duration = DiscountDuration .once ,
245
+ organization = organization ,
246
+ )
208
247
old_stripe_coupon_id = discount .stripe_coupon_id
209
248
210
249
updated_ends_at = utc_now () + timedelta (days = 2 )
250
+ payload .ends_at = updated_ends_at
211
251
updated_discount = await discount_service .update (
212
- session ,
213
- discount ,
214
- discount_update = DiscountUpdate (ends_at = updated_ends_at ),
252
+ session , discount , discount_update = payload
215
253
)
216
254
255
+ if isinstance (updated_discount , DiscountPercentage ):
256
+ assert updated_discount .basis_points == 2000
257
+ elif isinstance (updated_discount , DiscountFixed ):
258
+ assert updated_discount .amount == 2000
259
+ assert updated_discount .currency == "usd"
260
+
217
261
assert updated_discount .ends_at == updated_ends_at
218
262
assert updated_discount .stripe_coupon_id == "NEW_STRIPE_COUPON_ID"
219
263
0 commit comments