9
9
# Any modifications or derivative works of this code must retain this
10
10
# copyright notice, and modified files need to carry a notice indicating
11
11
# that they have been altered from the originals.
12
+
13
+ # pylint: disable=unused-argument
14
+
12
15
"""
13
16
Utility for checking equality of data of Qiskit Experiments class which doesn't
14
17
officially implement the equality dunder method.
36
39
from qiskit_experiments .visualization import BaseDrawer
37
40
38
41
39
- @multimethod
40
42
def is_equivalent (
43
+ data1 : Any ,
44
+ data2 : Any ,
45
+ * ,
46
+ strict_type : bool = True ,
47
+ numerical_precision : float = 1e-8 ,
48
+ ) -> bool :
49
+ """Check if two input data are equivalent.
50
+
51
+ This function is used for custom equivalence evaluation only for unittest purpose.
52
+ Some third party class may not preserve equivalence after JSON round-trip with
53
+ Qiskit Experiments JSON Encoder/Decoder, or some Qiskit Experiments class doesn't
54
+ define the equality dunder method intentionally.
55
+
56
+ Args:
57
+ data1: First data to compare.
58
+ data2: Second data to compare.
59
+ strict_type: Set True to enforce type check before comparison. Note that serialization
60
+ and deserialization round-trip may not preserve data type.
61
+ If the data type doesn't matter and only behavioral equivalence is considered,
62
+ e.g. iterator with the same element; tuple vs list,
63
+ you can turn off this flag to relax the constraint for data type.
64
+ numerical_precision: Tolerance of difference between two real numbers.
65
+
66
+ Returns:
67
+ True when two objects are equivalent.
68
+ """
69
+ if strict_type and type (data1 ) is not type (data2 ):
70
+ return False
71
+ evaluated = _is_equivalent_dispatcher (
72
+ data1 ,
73
+ data2 ,
74
+ strict_type = strict_type ,
75
+ numerical_precision = numerical_precision ,
76
+ )
77
+ if not isinstance (evaluated , (bool , np .bool_ )):
78
+ # When either one of input is numpy array type, it may broadcast equality check
79
+ # and return ndarray of dtype=bool. e.g. np.array([]) == 123
80
+ # The input values should not be equal in this case.
81
+ return False
82
+ return evaluated
83
+
84
+
85
+ @multimethod
86
+ def _is_equivalent_dispatcher (
41
87
data1 : object ,
42
88
data2 : object ,
89
+ ** kwargs ,
43
90
):
44
91
"""Equality check finally falls into this function."""
45
92
if data1 is None and data2 is None :
@@ -50,33 +97,28 @@ def is_equivalent(
50
97
return is_equivalent (
51
98
data1 .__dict__ ,
52
99
data2 .__dict__ ,
100
+ ** kwargs ,
53
101
)
54
- evaluated = data1 == data2
55
- if not isinstance (evaluated , bool ):
56
- # When either one of input is numpy array type, it may broadcast equality check
57
- # and return ndarray of dtype=bool. e.g. np.array([]) == 123
58
- # The input values should not be equal in this case.
59
- return False
102
+ return data1 == data2
60
103
61
- # Return the outcome of native equivalence check.
62
- return evaluated
63
104
64
-
65
- @is_equivalent .register
105
+ @_is_equivalent_dispatcher .register
66
106
def _check_dicts (
67
107
data1 : Union [dict , ThreadSafeOrderedDict ],
68
108
data2 : Union [dict , ThreadSafeOrderedDict ],
109
+ ** kwargs ,
69
110
):
70
111
"""Check equality of dictionary which may involve Qiskit Experiments classes."""
71
112
if set (data1 ) != set (data2 ):
72
113
return False
73
- return all (is_equivalent (data1 [k ], data2 [k ]) for k in data1 .keys ())
114
+ return all (is_equivalent (data1 [k ], data2 [k ], ** kwargs ) for k in data1 .keys ())
74
115
75
116
76
- @is_equivalent .register
117
+ @_is_equivalent_dispatcher .register
77
118
def _check_floats (
78
119
data1 : Union [float , np .floating ],
79
120
data2 : Union [float , np .floating ],
121
+ ** kwargs ,
80
122
):
81
123
"""Check equality of float.
82
124
@@ -86,13 +128,18 @@ def _check_floats(
86
128
if np .isnan (data1 ) and np .isnan (data2 ):
87
129
# Special case
88
130
return True
89
- return float (data1 ) == float (data2 )
131
+
132
+ precision = kwargs .get ("numerical_precision" , 0.0 )
133
+ if precision == 0.0 :
134
+ return float (data1 ) == float (data2 )
135
+ return np .isclose (np .abs (data1 - data2 ), 0.0 , atol = precision )
90
136
91
137
92
- @is_equivalent .register
138
+ @_is_equivalent_dispatcher .register
93
139
def _check_integer (
94
140
data1 : Union [int , np .integer ],
95
141
data2 : Union [int , np .integer ],
142
+ ** kwargs ,
96
143
):
97
144
"""Check equality of integer.
98
145
@@ -101,59 +148,65 @@ def _check_integer(
101
148
return int (data1 ) == int (data2 )
102
149
103
150
104
- @is_equivalent .register
151
+ @_is_equivalent_dispatcher .register
105
152
def _check_sequences (
106
153
data1 : Union [list , tuple , np .ndarray , ThreadSafeList ],
107
154
data2 : Union [list , tuple , np .ndarray , ThreadSafeList ],
155
+ ** kwargs ,
108
156
):
109
157
"""Check equality of sequence."""
110
158
if len (data1 ) != len (data2 ):
111
159
return False
112
- return all (is_equivalent (e1 , e2 ) for e1 , e2 in zip (data1 , data2 ))
160
+ return all (is_equivalent (e1 , e2 , ** kwargs ) for e1 , e2 in zip (data1 , data2 ))
113
161
114
162
115
- @is_equivalent .register
163
+ @_is_equivalent_dispatcher .register
116
164
def _check_unordered_sequences (
117
165
data1 : set ,
118
166
data2 : set ,
167
+ ** kwargs ,
119
168
):
120
169
"""Check equality of sequence after sorting."""
121
170
if len (data1 ) != len (data2 ):
122
171
return False
123
- return all (is_equivalent (e1 , e2 ) for e1 , e2 in zip (sorted (data1 ), sorted (data2 )))
172
+ return all (is_equivalent (e1 , e2 , ** kwargs ) for e1 , e2 in zip (sorted (data1 ), sorted (data2 )))
124
173
125
174
126
- @is_equivalent .register
175
+ @_is_equivalent_dispatcher .register
127
176
def _check_ufloats (
128
177
data1 : uncertainties .UFloat ,
129
178
data2 : uncertainties .UFloat ,
179
+ ** kwargs ,
130
180
):
131
181
"""Check equality of UFloat instance. Correlations are ignored."""
132
- return data1 .n == data2 .n and data1 .s == data2 .s
182
+ return is_equivalent ( data1 .n , data2 .n , ** kwargs ) and is_equivalent ( data1 .s , data2 .s , ** kwargs )
133
183
134
184
135
- @is_equivalent .register
185
+ @_is_equivalent_dispatcher .register
136
186
def _check_lmfit_models (
137
187
data1 : Model ,
138
188
data2 : Model ,
189
+ ** kwargs ,
139
190
):
140
191
"""Check equality of LMFIT model."""
141
- return is_equivalent (data1 .dumps (), data2 .dumps ())
192
+ return is_equivalent (data1 .dumps (), data2 .dumps (), ** kwargs )
142
193
143
194
144
- @is_equivalent .register
195
+ @_is_equivalent_dispatcher .register
145
196
def _check_dataprocessing_instances (
146
197
data1 : Union [DataAction , DataProcessor ],
147
198
data2 : Union [DataAction , DataProcessor ],
199
+ ** kwargs ,
148
200
):
149
201
"""Check equality of classes in the data_processing module."""
150
202
return repr (data1 ) == repr (data2 )
151
203
152
204
153
- @is_equivalent .register
205
+ @_is_equivalent_dispatcher .register
154
206
def _check_curvefit_results (
155
207
data1 : CurveFitResult ,
156
208
data2 : CurveFitResult ,
209
+ ** kwargs ,
157
210
):
158
211
"""Check equality of curve fit result."""
159
212
return _check_all_attributes (
@@ -177,13 +230,15 @@ def _check_curvefit_results(
177
230
],
178
231
data1 = data1 ,
179
232
data2 = data2 ,
233
+ ** kwargs ,
180
234
)
181
235
182
236
183
- @is_equivalent .register
237
+ @_is_equivalent_dispatcher .register
184
238
def _check_service_analysis_results (
185
239
data1 : AnalysisResult ,
186
240
data2 : AnalysisResult ,
241
+ ** kwargs ,
187
242
):
188
243
"""Check equality of AnalysisResult class which is payload for experiment service."""
189
244
return _check_all_attributes (
@@ -203,22 +258,25 @@ def _check_service_analysis_results(
203
258
],
204
259
data1 = data1 ,
205
260
data2 = data2 ,
261
+ ** kwargs ,
206
262
)
207
263
208
264
209
- @is_equivalent .register
265
+ @_is_equivalent_dispatcher .register
210
266
def _check_configurable_classes (
211
267
data1 : Union [BaseExperiment , BaseAnalysis , BaseDrawer ],
212
268
data2 : Union [BaseExperiment , BaseAnalysis , BaseDrawer ],
269
+ ** kwargs ,
213
270
):
214
271
"""Check equality of Qiskit Experiments class with config method."""
215
- return is_equivalent (data1 .config (), data2 .config ())
272
+ return is_equivalent (data1 .config (), data2 .config (), ** kwargs )
216
273
217
274
218
- @is_equivalent .register
275
+ @_is_equivalent_dispatcher .register
219
276
def _check_experiment_data (
220
277
data1 : ExperimentData ,
221
278
data2 : ExperimentData ,
279
+ ** kwargs ,
222
280
):
223
281
"""Check equality of ExperimentData."""
224
282
attributes_equiv = _check_all_attributes (
@@ -234,18 +292,22 @@ def _check_experiment_data(
234
292
],
235
293
data1 = data1 ,
236
294
data2 = data2 ,
295
+ ** kwargs ,
237
296
)
238
297
data_equiv = is_equivalent (
239
298
data1 .data (),
240
299
data2 .data (),
300
+ ** kwargs ,
241
301
)
242
302
analysis_results_equiv = is_equivalent (
243
303
data1 ._analysis_results ,
244
304
data2 ._analysis_results ,
305
+ ** kwargs ,
245
306
)
246
307
child_equiv = is_equivalent (
247
308
data1 .child_data (),
248
309
data2 .child_data (),
310
+ ** kwargs ,
249
311
)
250
312
return all ([attributes_equiv , data_equiv , analysis_results_equiv , child_equiv ])
251
313
@@ -254,6 +316,7 @@ def _check_all_attributes(
254
316
attrs : List [str ],
255
317
data1 : Any ,
256
318
data2 : Any ,
319
+ ** kwargs ,
257
320
):
258
321
"""Helper function to check all attributes."""
259
- return all (is_equivalent (getattr (data1 , att ), getattr (data2 , att )) for att in attrs )
322
+ return all (is_equivalent (getattr (data1 , att ), getattr (data2 , att ), ** kwargs ) for att in attrs )
0 commit comments