@@ -201,8 +201,7 @@ def test_variable_bern() -> None:
201201 assert 'theta' in bern_mle .metadata .stan_vars
202202 assert bern_mle .metadata .stan_vars ['theta' ].dimensions == ()
203203 theta = bern_mle .stan_variable (var = 'theta' )
204- # TODO(2.0): remove before `or` clause
205- assert isinstance (theta , float ) or theta .shape == ()
204+ assert theta .shape == ()
206205 with pytest .raises (ValueError ):
207206 bern_mle .stan_variable (var = 'eta' )
208207 with pytest .raises (ValueError ):
@@ -234,17 +233,15 @@ def test_variables_3d() -> None:
234233 var_beta = multidim_mle .stan_variable (var = 'beta' )
235234 assert var_beta .shape == (2 ,)
236235 var_frac_60 = multidim_mle .stan_variable (var = 'frac_60' )
237- # TODO(2.0): remove before `or` clause
238- assert isinstance (var_frac_60 , float ) or var_frac_60 .shape == ()
236+ assert var_frac_60 .shape == ()
239237 vars = multidim_mle .stan_variables ()
240238 assert len (vars ) == len (multidim_mle .metadata .stan_vars )
241239 assert 'y_rep' in vars
242240 assert vars ['y_rep' ].shape == (5 , 4 , 3 )
243241 assert 'beta' in vars
244242 assert vars ['beta' ].shape == (2 ,)
245243 assert 'frac_60' in vars
246- # TODO(2.0): remove before `or` clause
247- assert isinstance (vars ['frac_60' ], float ) or vars ['frac_60' ].shape == ()
244+ assert vars ['frac_60' ].shape == ()
248245
249246 multidim_mle_iters = multidim_model .optimize (
250247 data = jdata ,
@@ -565,8 +562,7 @@ def test_single_row_csv() -> None:
565562 model = CmdStanModel (stan_file = stan )
566563 mle = model .optimize ()
567564 theta = mle .stan_variable ('theta' )
568- # TODO(2.0): remove before `or` clause
569- assert isinstance (theta , float ) or theta .shape == ()
565+ assert theta .shape == ()
570566 z_as_ndarray = mle .stan_variable (var = "z" )
571567 assert z_as_ndarray .shape == (4 , 3 )
572568 for i in range (4 ):
@@ -627,8 +623,7 @@ def test_attrs() -> None:
627623
628624 assert fit .a == 4.5
629625 assert fit .b .shape == (3 ,)
630- # TODO(2.0) remove before `or` clause
631- assert isinstance (fit .theta , float ) or fit .theta .shape == ()
626+ assert fit .theta .shape == ()
632627
633628 assert fit .stan_variable ('thin' ) == 3.5
634629
@@ -673,7 +668,7 @@ def test_serialization() -> None:
673668 )
674669
675670
676- def test_optimize_create_inits ():
671+ def test_optimize_create_inits () -> None :
677672 stan = os .path .join (DATAFILES_PATH , 'bernoulli.stan' )
678673 bern_model = CmdStanModel (stan_file = stan )
679674 jdata = os .path .join (DATAFILES_PATH , 'bernoulli.data.json' )
@@ -686,7 +681,7 @@ def test_optimize_create_inits():
686681 assert len (inits ) == 1
687682
688683
689- def test_optimize_init_sampling ():
684+ def test_optimize_init_sampling () -> None :
690685 stan = os .path .join (DATAFILES_PATH , 'logistic.stan' )
691686 logistic_model = CmdStanModel (stan_file = stan )
692687 logistic_data = os .path .join (DATAFILES_PATH , 'logistic.data.R' )
0 commit comments