diff --git a/qiskit_experiments/framework/base_experiment.py b/qiskit_experiments/framework/base_experiment.py index 6c0eb21748..a8e755c912 100644 --- a/qiskit_experiments/framework/base_experiment.py +++ b/qiskit_experiments/framework/base_experiment.py @@ -16,7 +16,8 @@ from abc import ABC, abstractmethod import copy from collections import OrderedDict -from typing import Sequence, Optional, Tuple, List, Dict, Union +from typing import Sequence, Optional, Tuple, List, Dict, Union, Hashable +from functools import wraps import warnings from qiskit import transpile, QuantumCircuit @@ -31,6 +32,55 @@ from qiskit_experiments.framework.configs import ExperimentConfig +def cached_method(method): + """Decorator to cache the return value of a BaseExperiment method. + + This stores the output of a method in the experiment object instance + in a `_cache` dict attribute. Note that the value is cached only on + the object instance method name, not any values of its arguments. + + The cache can be cleared by calling :meth:`.BaseExperiment.cache_clear`. + """ + + @wraps(method) + def wrapped_method(self, *args, **kwargs): + name = f"{type(self).__name__}.{method.__name__}" + + # making a tuple from the options value. + options_dict = vars(self.experiment_options) + cache_key = tuple(options_dict.values()) + tuple([name]) + for key, val in options_dict.items(): + if isinstance(val, list): # pylint: disable=isinstance-second-argument-not-valid-type + val = tuple(val) + options_dict[key] = val + cache_key = tuple(options_dict.values()) + tuple([name]) + if isinstance( # pylint: disable=isinstance-second-argument-not-valid-type + val, Hashable + ): + continue + # if one of the values in option isn't hashable, we raise a warning and we use the name as + # the key of the cached circuit + warnings.warn( + f"The value of the option {key!r} is not hashable. This can make the cached " + f"transpiled circuit to not match the options." + ) + cache_key = (name,) + break + + # Check for cached value + cached = self._cache.get(cache_key, None) + if cached is not None: + return cached + + # Call method and cache output + cached = method(self, *args, **kwargs) + self._cache[cache_key] = cached + + return cached + + return wrapped_method + + class BaseExperiment(ABC, StoreInitArgs): """Abstract base class for experiments.""" @@ -55,6 +105,9 @@ def __init__( # Experiment identification metadata self._type = experiment_type if experiment_type else type(self).__name__ + # Initialize cache + self._cache = {} + # Circuit parameters self._num_qubits = len(qubits) self._physical_qubits = tuple(qubits) @@ -364,6 +417,7 @@ def circuits(self) -> List[QuantumCircuit]: # values for any explicit experiment options that affect circuit # generation + @cached_method def _transpiled_circuits(self) -> List[QuantumCircuit]: """Return a list of experiment circuits, transpiled. @@ -382,7 +436,6 @@ def _transpiled_circuits(self) -> List[QuantumCircuit]: DeprecationWarning, ) self._postprocess_transpiled_circuits(transpiled) # pylint: disable=no-member - return transpiled @classmethod @@ -409,6 +462,7 @@ def set_experiment_options(self, **fields): Raises: AttributeError: If the field passed in is not a supported options """ + self.cache_clear() for field in fields: if not hasattr(self._experiment_options, field): raise AttributeError( @@ -439,6 +493,7 @@ def set_transpile_options(self, **fields): Raises: QiskitError: if `initial_layout` is one of the fields. """ + self.cache_clear() if "initial_layout" in fields: raise QiskitError( "Initial layout cannot be specified as a transpile option" @@ -502,6 +557,10 @@ def set_analysis_options(self, **fields): ) self.analysis.options.update_options(**fields) + def cache_clear(self): + """Clear all cached method outputs.""" + self._cache = {} + def _metadata(self) -> Dict[str, any]: """Return experiment metadata for ExperimentData. diff --git a/qiskit_experiments/framework/composite/batch_experiment.py b/qiskit_experiments/framework/composite/batch_experiment.py index 534427335c..fc79daff20 100644 --- a/qiskit_experiments/framework/composite/batch_experiment.py +++ b/qiskit_experiments/framework/composite/batch_experiment.py @@ -19,7 +19,8 @@ from qiskit import QuantumCircuit from qiskit.providers.backend import Backend -from .composite_experiment import CompositeExperiment, BaseExperiment +from qiskit_experiments.framework.base_experiment import BaseExperiment, cached_method +from .composite_experiment import CompositeExperiment from .composite_analysis import CompositeAnalysis @@ -81,6 +82,7 @@ def __init__( def circuits(self): return self._batch_circuits(to_transpile=False) + @cached_method def _transpiled_circuits(self): return self._batch_circuits(to_transpile=True) diff --git a/qiskit_experiments/library/characterization/tphi.py b/qiskit_experiments/library/characterization/tphi.py index 828901c18b..46d2a7a416 100644 --- a/qiskit_experiments/library/characterization/tphi.py +++ b/qiskit_experiments/library/characterization/tphi.py @@ -16,8 +16,8 @@ from typing import List, Optional, Union import numpy as np -from qiskit import QiskitError from qiskit.providers import Backend +from qiskit_experiments.framework import Options from qiskit_experiments.framework.composite.batch_experiment import BatchExperiment from qiskit_experiments.library.characterization import ( T1, @@ -51,6 +51,14 @@ class Tphi(BatchExperiment): :doc:`/tutorials/tphi_characterization` """ + @classmethod + def _default_experiment_options(cls): + return Options( + delays_t1=None, + delays_t2=None, + osc_freq=0.0, + ) + def set_experiment_options(self, **fields): """Set the experiment options. Args: @@ -59,16 +67,14 @@ def set_experiment_options(self, **fields): Raises: QiskitError : Error for invalid input option. """ + super().set_experiment_options(**fields) # propagate options to the sub-experiments. - for key in fields: - if key == "delays_t1": - self.component_experiment(0).set_experiment_options(delays=fields["delays_t1"]) - elif key == "delays_t2": - self.component_experiment(1).set_experiment_options(delays=fields["delays_t2"]) - elif key == "osc_freq": - self.component_experiment(1).set_experiment_options(osc_freq=fields["osc_freq"]) - else: - raise QiskitError(f"Tphi experiment does not support option {key}") + if "delays_t1" in fields: + self.component_experiment(0).set_experiment_options(delays=fields["delays_t1"]) + if "delays_t2" in fields: + self.component_experiment(1).set_experiment_options(delays=fields["delays_t2"]) + if "osc_freq" in fields: + self.component_experiment(1).set_experiment_options(osc_freq=fields["osc_freq"]) def __init__( self, @@ -99,4 +105,4 @@ def __init__( # Create batch experiment super().__init__([exp_t1, exp_t2], backend=backend, analysis=analysis) - self.set_experiment_options(delays_t1=delays_t1, delays_t2=delays_t2) + self.set_experiment_options(delays_t1=delays_t1, delays_t2=delays_t2, osc_freq=osc_freq) diff --git a/qiskit_experiments/library/randomized_benchmarking/rb_experiment.py b/qiskit_experiments/library/randomized_benchmarking/rb_experiment.py index 0423666cb4..1f6bba6de7 100644 --- a/qiskit_experiments/library/randomized_benchmarking/rb_experiment.py +++ b/qiskit_experiments/library/randomized_benchmarking/rb_experiment.py @@ -25,7 +25,7 @@ from qiskit.quantum_info import Clifford from qiskit.providers.backend import Backend -from qiskit_experiments.framework import BaseExperiment, Options +from qiskit_experiments.framework.base_experiment import BaseExperiment, Options, cached_method from qiskit_experiments.framework.restless_mixin import RestlessMixin from .rb_analysis import RBAnalysis from .clifford_utils import CliffordUtils @@ -211,6 +211,7 @@ def _generate_circuit( circuits.append(rb_circ) return circuits + @cached_method def _transpiled_circuits(self) -> List[QuantumCircuit]: """Return a list of experiment circuits, transpiled.""" transpiled = super()._transpiled_circuits() diff --git a/releasenotes/notes/cached-method-87b5d878f585ca92.yaml b/releasenotes/notes/cached-method-87b5d878f585ca92.yaml new file mode 100644 index 0000000000..429d8643c1 --- /dev/null +++ b/releasenotes/notes/cached-method-87b5d878f585ca92.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + Adds caching of transpiled circuit generation to :class:`.BaseExperiment` + so that repeated calls of :class:`~.BaseExperiment.run` will avoid + repeated circuit generation and transpilation costs if no experiment options + are changed between run calls. + + Changing experiment or transpilation options with the + :meth:`~.BaseExperiment.set_experiment_options` or + :meth:`~.BaseExperiment.set_transpilation_options` will clear the + cached circuits. The cache can also be manually cleared by calling the + :meth:`~.BaseExperiment.cache_clear` method. diff --git a/test/randomized_benchmarking/test_randomized_benchmarking.py b/test/randomized_benchmarking/test_randomized_benchmarking.py index 64826c46d4..a6265def7a 100644 --- a/test/randomized_benchmarking/test_randomized_benchmarking.py +++ b/test/randomized_benchmarking/test_randomized_benchmarking.py @@ -210,6 +210,30 @@ def test_return_same_circuit(self): self.assertEqual(circs1[1].decompose(), circs2[1].decompose()) self.assertEqual(circs1[2].decompose(), circs2[2].decompose()) + def test_experiment_cache(self): + """Test experiment transpiled circuit cache""" + exp0 = rb.StandardRB( + qubits=(0, 1), + lengths=[10, 20, 30], + seed=123, + backend=self.backend, + ) + exp0.set_transpile_options(**self.transpiler_options) + + # calling a method with '@cached_method' decorator + exp0_transpiled_circ = exp0._transpiled_circuits() + + # calling the method again returns cached circuit + exp0_transpiled_cache = exp0._transpiled_circuits() + + self.assertEqual(exp0_transpiled_circ[0].decompose(), exp0_transpiled_cache[0].decompose()) + self.assertEqual(exp0_transpiled_circ[1].decompose(), exp0_transpiled_cache[1].decompose()) + self.assertEqual(exp0_transpiled_circ[2].decompose(), exp0_transpiled_cache[2].decompose()) + + # Checking that the cache is cleared when setting options + exp0.set_experiment_options(lengths=[10, 20, 30, 40]) + self.assertEqual(exp0._cache, {}) + def test_full_sampling(self): """Test if full sampling generates different circuits.""" exp1 = rb.StandardRB( @@ -357,6 +381,29 @@ def test_two_qubit(self): epc_expected = 3 / 4 * self.p2q self.assertAlmostEqual(epc.value.n, epc_expected, delta=0.1 * epc_expected) + def test_interleaved_cache(self): + """Test two qubit IRB.""" + exp = rb.InterleavedRB( + interleaved_element=CXGate(), + qubits=(0, 1), + lengths=list(range(1, 30, 3)), + seed=123, + backend=self.backend, + ) + exp.set_transpile_options(**self.transpiler_options) + + # calling a method with '@cached_method' decorator + exp_transpiled_circ = exp._transpiled_circuits() + + # calling the method again returns cached circuit + exp_transpiled_cache = exp._transpiled_circuits() + for circ, cached_circ in zip(exp_transpiled_circ, exp_transpiled_cache): + self.assertEqual(circ.decompose(), cached_circ.decompose()) + + # Checking that the cache is cleared when setting options + exp.set_experiment_options(lengths=[10, 20, 30, 40]) + self.assertEqual(exp._cache, {}) + def test_non_clifford_interleaved_element(self): """Verifies trying to run interleaved RB with non Clifford element throws an exception""" qubits = 1 diff --git a/test/test_framework.py b/test/test_framework.py index 70590a2a09..b667104732 100644 --- a/test/test_framework.py +++ b/test/test_framework.py @@ -18,6 +18,7 @@ from qiskit import QuantumCircuit from qiskit_experiments.framework import ExperimentData +from qiskit_experiments.framework.base_experiment import cached_method from qiskit_experiments.test.fake_backend import FakeBackend @@ -117,3 +118,37 @@ def test_analysis_runtime_opts(self): target_opts["figure_names"] = None self.assertEqual(analysis.options.__dict__, target_opts) + + def test_cached_method(self): + """Test cached method decorator""" + + class Experiment(FakeExperiment): + """Test experiment""" + + @cached_method + def custom_method(self): + """Cached method""" + return [1, 2, 3] + + exp = Experiment([0]) + value1 = exp.custom_method() + value2 = exp.custom_method() + self.assertIn("Experiment.custom_method", exp._cache) + self.assertTrue(value1 is value2) + + def test_cached_transpiled_circuits(self): + """Test transpiled circuits are cached""" + exp = FakeExperiment([0]) + value1 = exp._transpiled_circuits() + value2 = exp._transpiled_circuits() + self.assertIn("FakeExperiment._transpiled_circuits", exp._cache) + self.assertTrue(value1 is value2) + + def test_cache_clear(self): + """Test cache_clear method""" + exp = FakeExperiment([0]) + value1 = exp._transpiled_circuits() + exp.cache_clear() + self.assertNotIn("FakeExperiment._transpiled_circuits", exp._cache) + value2 = exp._transpiled_circuits() + self.assertFalse(value1 is value2)