Commit 5bb1fb4

Scatter table refactoring (#1319)
### Summary

This PR modifies `ScatterTable`, which was introduced in #1253. The change resolves code issues raised in #1315 and #1243.

### Details and comments

In the original design, `ScatterTable` is tied to the fit models, and its columns contain `model_name` (str) and `model_id` (int). The fit module also allows only three data categories: "processed", "formatted", and "fitted". However, #1243 breaks this assumption: the `StarkRamseyXYAmpScanAnalysis` fitter defines two fit models that are not directly mapped to the result data, because the data fed into the models is synthesized by consuming the input result data. That fitter needs to manage four data categories: "raw", "ramsey" (raw results), "phase" (data synthesized for the fit), and "fitted".

This PR relaxes the tight coupling of data to the fit models. In the above example, "raw" and "ramsey" category data can fill the new fields `name` (formerly `model_name`) and `class_id` (formerly `model_id`) without indicating a particular fit model. Usually, raw category data is just classified according to the `data_subfit_map` definition, and that map doesn't need to match the fit models. The connection to fit models is introduced only in the particular category specified by the new option `fit_category`. This option defaults to "formatted", but the `StarkRamseyXYAmpScanAnalysis` fitter would set "phase" instead. Fit model assignment is thus effectively delayed until the formatter function runs.

The original scatter table was also designed to store all circuit metadata, which causes problems in data formatting, especially when averaging the data over the same x value within a group. Non-numeric data is aggregated with the builtin set operation, which assumes the metadata values are hashable objects; this is not generally true. This PR therefore drops all metadata from the scatter table.

Note that the only metadata fields important for curve analysis are the ones used for model classification (the classifier fields); the remaining fields just decorate the table at the cost of an unnecessary memory footprint. The classifier fields and `name` (`class_id`) carry largely duplicated information, so the `name` and `class_id` fields are enough for end users to reuse the table data for further analysis once it is saved as an artifact.

---------

Co-authored-by: Will Shanks <[email protected]>
1 parent 726f422 commit 5bb1fb4
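The new table layout and the `fit_category` selection can be illustrated with a plain pandas sketch (this is not the actual `ScatterTable` API; the column names come from this PR, while the example rows are made up):

```python
import pandas as pd

# Columns follow the new schema: name/class_id replace model_name/model_id,
# and a category column tags each dataset ("raw", "phase", "fitted", ...).
table = pd.DataFrame(
    [
        [0.1, 0.15, 0.011, "ramsey", 0, "raw"],
        [0.1, 0.48, 0.012, "ramsey", 0, "phase"],
        [0.2, 0.31, 0.014, "ramsey", 0, "raw"],
        [0.2, 0.52, 0.013, "ramsey", 0, "phase"],
    ],
    columns=["xval", "yval", "yerr", "name", "class_id", "category"],
)

# The fitter consumes only the category named by the fit_category option;
# StarkRamseyXYAmpScanAnalysis would set "phase" instead of the default
# "formatted".
fit_category = "phase"
fit_subset = table[table.category == fit_category]
```

Only the selected category is handed to the fit models, so the "raw" rows can keep names that do not map to any fit model.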

File tree: 9 files changed, +230 −198 lines


docs/tutorials/curve_analysis.rst

Lines changed: 76 additions & 35 deletions
@@ -23,33 +23,12 @@ different sets of experiment results. A single experiment can define sub-experim
 consisting of multiple circuits which are tagged with common metadata,
 and curve analysis sorts the experiment results based on the circuit metadata.

-This is an example of showing the abstract data structure of a typical curve analysis experiment:
+This is an example showing the abstract data flow of a typical curve analysis experiment:

-.. jupyter-input::
-    :emphasize-lines: 1,10,19
-
-    "experiment"
-     - circuits[0] (x=x1_A, "series_A")
-     - circuits[1] (x=x1_B, "series_B")
-     - circuits[2] (x=x2_A, "series_A")
-     - circuits[3] (x=x2_B, "series_B")
-     - circuits[4] (x=x3_A, "series_A")
-     - circuits[5] (x=x3_B, "series_B")
-     - ...
-
-    "experiment data"
-     - data[0] (y1_A, "series_A")
-     - data[1] (y1_B, "series_B")
-     - data[2] (y2_A, "series_A")
-     - data[3] (y2_B, "series_B")
-     - data[4] (y3_A, "series_A")
-     - data[5] (y3_B, "series_B")
-     - ...
-
-    "analysis"
-     - "series_A": y_A = f_A(x_A; p0, p1, p2)
-     - "series_B": y_B = f_B(x_B; p0, p1, p2)
-     - fixed parameters {p1: v}
+.. figure:: images/curve_analysis_structure.png
+    :width: 600
+    :align: center
+    :class: no-scaled-link

 Here the experiment runs two subsets of experiments, namely, series A and series B.
 The analysis defines corresponding fit models :math:`f_A(x_A)` and :math:`f_B(x_B)`.
@@ -289,21 +268,78 @@ A developer can override this method to perform initialization of analysis-speci

 Curve analysis calls the :meth:`_run_data_processing` method, where
 the data processor in the analysis option is internally called.
-This consumes input experiment results and creates the :class:`.CurveData` dataclass.
-Then the :meth:`_format_data` method is called with the processed dataset to format it.
+This consumes input experiment results and creates the :class:`.ScatterTable` dataframe.
+This table may look like:
+
+.. code-block::
+
+        xval      yval      yerr  name  class_id  category  shots
+    0    0.1  0.153659  0.011258     A         0       raw   1024
+    1    0.1  0.590732  0.015351     B         1       raw   1024
+    2    0.1  0.315610  0.014510     A         0       raw   1024
+    3    0.1  0.376098  0.015123     B         1       raw   1024
+    4    0.2  0.937073  0.007581     A         0       raw   1024
+    5    0.2  0.323415  0.014604     B         1       raw   1024
+    6    0.2  0.538049  0.015565     A         0       raw   1024
+    7    0.2  0.530244  0.015581     B         1       raw   1024
+    8    0.3  0.143902  0.010958     A         0       raw   1024
+    9    0.3  0.261951  0.013727     B         1       raw   1024
+    10   0.3  0.830732  0.011707     A         0       raw   1024
+    11   0.3  0.874634  0.010338     B         1       raw   1024
+
+where the experiment consists of two subsets, series A and B, and the experiment parameter (xval)
+is scanned from 0.1 to 0.3 in each subset. In this example, the experiment is run twice
+for each condition. The role of each column is as follows:
+
+- ``xval``: Parameter scanned in the experiment. This value must be defined in the circuit metadata.
+- ``yval``: Nominal part of the outcome. The outcome is a quantity such as an expectation value, computed from the experiment result with the data processor.
+- ``yerr``: Standard error of the outcome, which is mainly due to sampling error.
+- ``name``: Unique identifier of the result class. This is defined by the ``data_subfit_map`` option.
+- ``class_id``: Numerical index corresponding to the result class. This number is automatically assigned.
+- ``category``: The attribute of the data set. The "raw" category indicates an output from the data processing.
+- ``shots``: Number of measurement shots used to acquire this result.
+
+3. Formatting
+^^^^^^^^^^^^^
+
+Next, the processed dataset is converted into another format suited for the fitting, and
+every valid result is assigned a class corresponding to a fit model.
 By default, the formatter takes average of the outcomes in the processed dataset
 over the same x values, followed by the sorting in the ascending order of x values.
 This allows the analysis to easily estimate the slope of the curves to
 create algorithmic initial guess of fit parameters.
 A developer can inject extra data processing, for example, filtering, smoothing,
 or elimination of outliers for better fitting.
+The new class_id is given here so that its value corresponds to the fit model object index
+in this analysis class. This index mapping is done based upon the correspondence of
+the data name and the fit model name.
+
+This is done by calling the :meth:`_format_data` method.
+This may return a new scatter table object with additional rows like the following.
+
+.. code-block::

+    12   0.1  0.234634  0.009183     A         0  formatted   2048
+    13   0.2  0.737561  0.008656     A         0  formatted   2048
+    14   0.3  0.487317  0.008018     A         0  formatted   2048
+    15   0.1  0.483415  0.010774     B         1  formatted   2048
+    16   0.2  0.426829  0.010678     B         1  formatted   2048
+    17   0.3  0.568293  0.008592     B         1  formatted   2048
+
+The default :meth:`_format_data` method adds its output data with the category "formatted".
+This category name must also be specified in the analysis option ``fit_category``.
+If overriding this method to do additional processing after the default formatting,
+the ``fit_category`` analysis option can be set to choose a different category name to use to
+select the data to pass to the fitting routine.
+The (x, y) value in each row is passed to the corresponding fit model object
+to compute residual values for the least square optimization.

 3. Fitting
 ^^^^^^^^^^

-Curve analysis calls the :meth:`_run_curve_fit` method, which is the core functionality of the fitting.
-Another method :meth:`_generate_fit_guesses` is internally called to
-prepare the initial guess and parameter boundary with respect to the formatted data.
+Curve analysis calls the :meth:`_run_curve_fit` method with the formatted subset of the scatter table.
+This internally calls :meth:`_generate_fit_guesses` to prepare
+the initial guess and parameter boundary with respect to the formatted dataset.
 Developers usually override this method to provide better initial guesses
 tailored to the defined fit model or type of the associated experiment.
 See :ref:`curve_analysis_init_guess` for more details.
@@ -314,13 +350,18 @@ custom fitting algorithms. This method must return a :class:`.CurveFitResult` da
 ^^^^^^^^^^^^^^^^^^

 Curve analysis runs several postprocessing against the fit outcome.
-It calls :meth:`._create_analysis_results` to create the :class:`.AnalysisResultData` class
+When the fit is successful, it calls :meth:`._create_analysis_results` to create the :class:`.AnalysisResultData` objects
 for the fitting parameters of interest. A developer can inject custom code to
 compute custom quantities based on the raw fit parameters.
 See :ref:`curve_analysis_results` for details.
-Afterwards, figure plotting is handed over to the :doc:`Visualization </tutorials/visualization>` module via
-the :attr:`~.CurveAnalysis.plotter` attribute, and a list of created analysis results and the figure are returned.
-
+Afterwards, fit curves are computed with the fit models and optimal parameters, and the scatter table is
+updated with the computed (x, y) values. This dataset is stored under the "fitted" category.
+
+Finally, the :meth:`._create_figures` method is called with the entire scatter table data
+to initialize the curve plotter instance accessible via the :attr:`~.CurveAnalysis.plotter` attribute.
+The visualization is handed over to the :doc:`Visualization </tutorials/visualization>` module,
+which provides a standardized image format for curve fit results.
+A developer can override this method to draw custom images.

 .. _curve_analysis_init_guess:

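The default formatting step described in this diff (averaging outcomes taken at the same x value, then sorting by ascending x) can be sketched with plain pandas. This is an illustration of the idea with made-up numbers, not the actual `_format_data` implementation:

```python
import numpy as np
import pandas as pd

# Two repeated measurements per x value, as in the "raw" table above.
raw = pd.DataFrame(
    {
        "xval": [0.1, 0.1, 0.2, 0.2],
        "yval": [0.15, 0.31, 0.93, 0.53],
        "yerr": [0.011, 0.014, 0.008, 0.015],
        "name": ["A"] * 4,
        "class_id": [0] * 4,
        "category": ["raw"] * 4,
        "shots": [1024] * 4,
    }
)

rows = []
for (name, class_id, xval), grp in raw.groupby(["name", "class_id", "xval"]):
    # Shots-weighted mean of the y values; propagate the standard errors.
    w = grp.shots / grp.shots.sum()
    rows.append(
        {
            "xval": xval,
            "yval": float((w * grp.yval).sum()),
            "yerr": float(np.sqrt(((w * grp.yerr) ** 2).sum())),
            "name": name,
            "class_id": class_id,
            "category": "formatted",
            "shots": int(grp.shots.sum()),
        }
    )

# Ascending x makes slope-based initial guesses straightforward.
formatted = pd.DataFrame(rows).sort_values("xval").reset_index(drop=True)
```

Note that because the table no longer carries arbitrary circuit metadata, this averaging never has to aggregate non-numeric (possibly unhashable) values.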
docs/tutorials/images/curve_analysis_structure.png: binary image added (10.7 MB; diff not rendered)

qiskit_experiments/curve_analysis/base_curve_analysis.py

Lines changed: 8 additions & 2 deletions
@@ -188,6 +188,7 @@ def _default_options(cls) -> Options:
             lmfit_options (Dict[str, Any]): Options that are passed to the
                 LMFIT minimizer. Acceptable options depend on fit_method.
             x_key (str): Circuit metadata key representing a scanned value.
+            fit_category (str): Name of dataset in the scatter table to fit.
             result_parameters (List[Union[str, ParameterRepr]): Parameters reported in the
                 database as a dedicated entry. This is a list of parameter representation
                 which is either string or ParameterRepr object. If you provide more
@@ -219,6 +220,7 @@ def _default_options(cls) -> Options:
         options.normalization = False
         options.average_method = "shots_weighted"
         options.x_key = "xval"
+        options.fit_category = "formatted"
         options.result_parameters = []
         options.extra = {}
         options.fit_method = "least_squares"
@@ -282,11 +284,13 @@ def set_options(self, **fields):
     def _run_data_processing(
         self,
         raw_data: List[Dict],
+        category: str = "raw",
     ) -> ScatterTable:
         """Perform data processing from the experiment result payload.

         Args:
             raw_data: Payload in the experiment data.
+            category: Category string of the output dataset.

         Returns:
             Processed data that will be sent to the formatter method.
@@ -296,14 +300,16 @@ def _run_data_processing(
     def _format_data(
         self,
         curve_data: ScatterTable,
+        category: str = "formatted",
     ) -> ScatterTable:
-        """Postprocessing for the processed dataset.
+        """Postprocessing for preparing the fitting data.

         Args:
             curve_data: Processed dataset created from experiment results.
+            category: Category string of the output dataset.

         Returns:
-            Formatted data.
+            New scatter table instance including fit data.
         """

     @abstractmethod
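A subclass that fits against a synthesized dataset (as `StarkRamseyXYAmpScanAnalysis` does with its "phase" category) would override this default. Below is a self-contained sketch in which `Options` and `CurveAnalysis` are minimal stand-ins for the real classes; only the `fit_category` name and its "formatted" default come from this diff:

```python
class Options(dict):
    """Tiny attribute-access dict standing in for the real Options class."""

    def __getattr__(self, key):
        return self[key]

    def __setattr__(self, key, value):
        self[key] = value


class CurveAnalysis:
    """Stand-in base class exposing only the option relevant here."""

    @classmethod
    def _default_options(cls) -> Options:
        options = Options()
        options.fit_category = "formatted"  # default added in this PR
        return options


class SynthesizedDataAnalysis(CurveAnalysis):
    """Fitter whose fit data lives under a custom category."""

    @classmethod
    def _default_options(cls) -> Options:
        options = super()._default_options()
        # Fit model assignment is delayed to the formatter, which writes
        # its synthesized output under the "phase" category.
        options.fit_category = "phase"
        return options
```

The fitting routine then selects `table[table.category == self.options.fit_category]`, so no other method needs to know which category holds the fit data.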

qiskit_experiments/curve_analysis/composite_curve_analysis.py

Lines changed: 27 additions & 33 deletions
@@ -230,32 +230,32 @@ def _create_figures(
             A list of figures.
         """
         for analysis in self.analyses():
-            sub_data = curve_data[curve_data.model_name.str.endswith(f"_{analysis.name}")]
-            for model_id, data in list(sub_data.groupby("model_id")):
-                model_name = analysis._models[model_id]._name
+            sub_data = curve_data[curve_data.group == analysis.name]
+            for name, data in list(sub_data.groupby("name")):
+                full_name = f"{name}_{analysis.name}"
                 # Plot raw data scatters
                 if analysis.options.plot_raw_data:
-                    raw_data = data.filter(like="processed", axis="index")
+                    raw_data = data[data.category == "raw"]
                     self.plotter.set_series_data(
-                        series_name=model_name,
+                        series_name=full_name,
                         x=raw_data.xval.to_numpy(),
                         y=raw_data.yval.to_numpy(),
                     )
                 # Plot formatted data scatters
-                formatted_data = data.filter(like="formatted", axis="index")
+                formatted_data = data[data.category == analysis.options.fit_category]
                 self.plotter.set_series_data(
-                    series_name=model_name,
+                    series_name=full_name,
                     x_formatted=formatted_data.xval.to_numpy(),
                     y_formatted=formatted_data.yval.to_numpy(),
                     y_formatted_err=formatted_data.yerr.to_numpy(),
                 )
                 # Plot fit lines
-                line_data = data.filter(like="fitted", axis="index")
+                line_data = data[data.category == "fitted"]
                 if len(line_data) == 0:
                     continue
                 fit_stdev = line_data.yerr.to_numpy()
                 self.plotter.set_series_data(
-                    series_name=model_name,
+                    series_name=full_name,
                     x_interp=line_data.xval.to_numpy(),
                     y_interp=line_data.yval.to_numpy(),
                     y_interp_err=fit_stdev if np.isfinite(fit_stdev).all() else None,
@@ -353,21 +353,16 @@ def _run_analysis(
             metadata = analysis.options.extra.copy()
             metadata["group"] = analysis.name

-            curve_data = analysis._format_data(
-                analysis._run_data_processing(experiment_data.data())
-            )
-            fit_data = analysis._run_curve_fit(curve_data.filter(like="formatted", axis="index"))
+            table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
+            formatted_subset = table[table.category == analysis.options.fit_category]
+            fit_data = analysis._run_curve_fit(formatted_subset)
             fit_dataset[analysis.name] = fit_data

             if fit_data.success:
                 quality = analysis._evaluate_quality(fit_data)
             else:
                 quality = "bad"

-            # After the quality is determined, plot can become a boolean flag for whether
-            # to generate the figure
-            plot_bool = plot == "always" or (plot == "selective" and quality == "bad")
-
             if self.options.return_fit_parameters:
                 # Store fit status overview entry regardless of success.
                 # This is sometime useful when debugging the fitting code.
@@ -382,10 +377,9 @@ def _run_analysis(
             if fit_data.success:
                 # Add fit data to curve data table
                 fit_curves = []
-                formatted = curve_data.filter(like="formatted", axis="index")
-                columns = list(curve_data.columns)
-                for i, sub_data in list(formatted.groupby("model_id")):
-                    name = analysis._models[i]._name
+                columns = list(table.columns)
+                model_names = analysis.model_names()
+                for i, sub_data in list(formatted_subset.groupby("class_id")):
                     xval = sub_data.xval.to_numpy()
                     if len(xval) == 0:
                         # If data is empty, skip drawing this model.
@@ -404,12 +398,10 @@ def _run_analysis(
                     model_fit[:, columns.index("yval")] = unp.nominal_values(yval_fit)
                     if fit_data.covar is not None:
                         model_fit[:, columns.index("yerr")] = unp.std_devs(yval_fit)
-                    model_fit[:, columns.index("model_name")] = name
-                    model_fit[:, columns.index("model_id")] = i
-                curve_data = curve_data.append_list_values(
-                    other=np.vstack(fit_curves),
-                    prefix="fitted",
-                )
+                    model_fit[:, columns.index("name")] = model_names[i]
+                    model_fit[:, columns.index("class_id")] = i
+                    model_fit[:, columns.index("category")] = "fitted"
+                table = table.append_list_values(other=np.vstack(fit_curves))
                 analysis_results.extend(
                     analysis._create_analysis_results(
                         fit_data=fit_data,
@@ -421,18 +413,20 @@ def _run_analysis(
             if self.options.return_data_points:
                 # Add raw data points
                 analysis_results.extend(
-                    analysis._create_curve_data(
-                        curve_data=curve_data.filter(like="formatted", axis="index"),
-                        **metadata,
-                    )
+                    analysis._create_curve_data(curve_data=formatted_subset, **metadata)
                 )

-            curve_data.model_name += f"_{analysis.name}"
-            curve_data_set.append(curve_data)
+            # Add extra column to identify the fit model
+            table["group"] = analysis.name
+            curve_data_set.append(table)

         combined_curve_data = pd.concat(curve_data_set)
         total_quality = self._evaluate_quality(fit_dataset)

+        # After the quality is determined, plot can become a boolean flag for whether
+        # to generate the figure
+        plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")
+
         # Create analysis results by combining all fit data
         if all(fit_data.success for fit_data in fit_dataset.values()):
             composite_results = self._create_analysis_results(

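Because `_run_analysis` now tags each child table with a `group` column before concatenation, per-analysis slicing in `_create_figures` reduces to boolean masks. A plain pandas sketch of that pattern (the child names and values here are made up):

```python
import pandas as pd

# Each child analysis appends its own table after stamping a group label.
child_a = pd.DataFrame(
    {"xval": [0.1, 0.2], "yval": [0.2, 0.7], "category": ["fitted"] * 2}
).assign(group="child_A")
child_b = pd.DataFrame(
    {"xval": [0.1, 0.2], "yval": [0.5, 0.4], "category": ["fitted"] * 2}
).assign(group="child_B")

combined = pd.concat([child_a, child_b])

# _create_figures selects one child's data, then one category, like so:
sub_data = combined[combined.group == "child_A"]
line_data = sub_data[sub_data.category == "fitted"]
```

This replaces the old scheme of mangling `model_name` with a suffix and filtering on index labels, which required string matching to recover a child's rows.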