diff --git a/qdev_wrappers/dataset/doNd.py b/qdev_wrappers/dataset/doNd.py index 4a96c21a..aadde4e9 100644 --- a/qdev_wrappers/dataset/doNd.py +++ b/qdev_wrappers/dataset/doNd.py @@ -1,4 +1,5 @@ -from typing import Callable, Sequence, Union, Tuple, List, Optional +from contextlib import contextmanager +from typing import Callable, Sequence, Union, Tuple, List, Optional, Iterator import os import time @@ -6,22 +7,84 @@ import matplotlib import matplotlib.pyplot as plt -from qcodes.dataset.measurements import Measurement +from qcodes.dataset.measurements import Measurement, res_type, DataSaver from qcodes.instrument.base import _BaseParameter from qcodes.dataset.plotting import plot_by_id from qcodes import config +ActionsT = Sequence[Callable[[], None]] + +ParamMeasT = Union[_BaseParameter, Callable[[], None]] + AxesTuple = Tuple[matplotlib.axes.Axes, matplotlib.colorbar.Colorbar] AxesTupleList = Tuple[List[matplotlib.axes.Axes], List[Optional[matplotlib.colorbar.Colorbar]]] AxesTupleListWithRunId = Tuple[int, List[matplotlib.axes.Axes], List[Optional[matplotlib.colorbar.Colorbar]]] -number = Union[float, int] -def do0d(*param_meas: Union[_BaseParameter, Callable[[], None]], - write_period: Optional[float] = None, - do_plot: bool = True) -> AxesTupleListWithRunId: +def _process_params_meas(param_meas: ParamMeasT) -> List[res_type]: + output = [] + for parameter in param_meas: + if isinstance(parameter, _BaseParameter): + output.append((parameter, parameter.get())) + elif callable(parameter): + parameter() + return output + + +def _register_parameters( + meas: Measurement, + param_meas: List[ParamMeasT], + setpoints: Optional[List[_BaseParameter]] = None +) -> None: + for parameter in param_meas: + if isinstance(parameter, _BaseParameter): + meas.register_parameter(parameter, + setpoints=setpoints) + + +def _register_actions( + meas: Measurement, + enter_actions: ActionsT, + exit_actions: ActionsT +) -> None: + for action in enter_actions: + # this omits the possibility of passing + # argument to enter and exit actions. + # Do we want that? + meas.add_before_run(action, ()) + for action in exit_actions: + meas.add_after_run(action, ()) + + + +def _set_write_period( + meas: Measurement, + write_period: Optional[float] = None +) -> None: + if write_period is not None: + meas.write_period = write_period + + +@contextmanager +def _catch_keyboard_interrupts() -> Iterator[Callable[[], bool]]: + interrupted = False + def has_been_interrupted(): + nonlocal interrupted + return interrupted + try: + yield has_been_interrupted + except KeyboardInterrupt: + interrupted = True + + + +def do0d( + *param_meas: ParamMeasT, + write_period: Optional[float] = None, + do_plot: bool = True +) -> AxesTupleListWithRunId: """ Perform a measurement of a single parameter. This is probably most useful for an ArrayParamter that already returns an array of data points @@ -38,41 +101,27 @@ def do0d(*param_meas: Union[_BaseParameter, Callable[[], None]], The run_id of the DataSet created """ meas = Measurement() - if write_period is not None: - meas.write_period = write_period - output = [] - - for parameter in param_meas: - meas.register_parameter(parameter) - output.append([parameter, None]) + _register_parameters(meas, param_meas) + _set_write_period(meas, write_period) with meas.run() as datasaver: + datasaver.add_result(*_process_params_meas(param_meas)) + + return _handle_plotting(datasaver, do_plot) - for i, parameter in enumerate(param_meas): - if isinstance(parameter, _BaseParameter): - output[i][1] = parameter.get() - elif callable(parameter): - parameter() - datasaver.add_result(*output) - dataid = datasaver.run_id - if do_plot is True: - ax, cbs = _save_image(datasaver) - else: - ax = None, - cbs = None - return dataid, ax, cbs -def do1d(param_set: _BaseParameter, start: number, stop: number, - num_points: int, delay: number, - *param_meas: Union[_BaseParameter, Callable[[], None]], - enter_actions: Sequence[Callable[[], None]] = (), - exit_actions: Sequence[Callable[[], None]] = (), - write_period: Optional[float] = None, - do_plot: bool = True) \ - -> AxesTupleListWithRunId: +def do1d( + param_set: _BaseParameter, start: float, stop: float, + num_points: int, delay: float, + *param_meas: ParamMeasT, + enter_actions: ActionsT = (), + exit_actions: ActionsT = (), + write_period: Optional[float] = None, + do_plot: bool = True +) -> AxesTupleListWithRunId: """ Perform a 1D scan of ``param_set`` from ``start`` to ``stop`` in ``num_points`` measuring param_meas at each step. In case param_meas is @@ -99,72 +148,38 @@ def do1d(param_set: _BaseParameter, start: number, stop: number, The run_id of the DataSet created """ meas = Measurement() - if write_period is not None: - meas.write_period = write_period - meas.register_parameter( - param_set) # register the first independent parameter - output = [] + _register_parameters(meas, (param_set,)) + _register_parameters(meas, param_meas, setpoints=(param_set,)) + _set_write_period(meas, write_period) + _register_actions(meas, enter_actions, exit_actions) param_set.post_delay = delay - interrupted = False - - for action in enter_actions: - # this omits the posibility of passing - # argument to enter and exit actions. - # Do we want that? - meas.add_before_run(action, ()) - for action in exit_actions: - meas.add_after_run(action, ()) # do1D enforces a simple relationship between measured parameters # and set parameters. For anything more complicated this should be # reimplemented from scratch - for parameter in param_meas: - if isinstance(parameter, _BaseParameter): - meas.register_parameter(parameter, setpoints=(param_set,)) - output.append([parameter, None]) - - try: - with meas.run() as datasaver: - - for set_point in np.linspace(start, stop, num_points): - param_set.set(set_point) - output = [] - for parameter in param_meas: - if isinstance(parameter, _BaseParameter): - output.append((parameter, parameter.get())) - elif callable(parameter): - parameter() - datasaver.add_result((param_set, set_point), - *output) - except KeyboardInterrupt: - interrupted = True - - dataid = datasaver.run_id # convenient to have for plotting - - if do_plot is True: - ax, cbs = _save_image(datasaver) - else: - ax = None, - cbs = None - - if interrupted: - raise KeyboardInterrupt - return dataid, ax, cbs - - -def do2d(param_set1: _BaseParameter, start1: number, stop1: number, - num_points1: int, delay1: number, - param_set2: _BaseParameter, start2: number, stop2: number, - num_points2: int, delay2: number, - *param_meas: Union[_BaseParameter, Callable[[], None]], - set_before_sweep: Optional[bool] = False, - enter_actions: Sequence[Callable[[], None]] = (), - exit_actions: Sequence[Callable[[], None]] = (), - before_inner_actions: Sequence[Callable[[], None]] = (), - after_inner_actions: Sequence[Callable[[], None]] = (), - write_period: Optional[float] = None, - flush_columns: bool = False, - do_plot: bool=True) -> AxesTupleListWithRunId: + with _catch_keyboard_interrupts() as interrupted, meas.run() as datasaver: + for set_point in np.linspace(start, stop, num_points): + param_set.set(set_point) + datasaver.add_result((param_set, set_point), + *_process_params_meas(param_meas)) + return _handle_plotting(datasaver, do_plot, interrupted()) + + +def do2d( + param_set1: _BaseParameter, start1: float, stop1: float, + num_points1: int, delay1: float, + param_set2: _BaseParameter, start2: float, stop2: float, + num_points2: int, delay2: float, + *param_meas: ParamMeasT, + set_before_sweep: Optional[bool] = False, + enter_actions: ActionsT = (), + exit_actions: ActionsT = (), + before_inner_actions: ActionsT = (), + after_inner_actions: ActionsT = (), + write_period: Optional[float] = None, + flush_columns: bool = False, + do_plot: bool=True +) -> AxesTupleListWithRunId: """ Perform a 1D scan of ``param_set1`` from ``start1`` to ``stop1`` in @@ -202,29 +217,16 @@ def do2d(param_set1: _BaseParameter, start1: number, stop1: number, """ meas = Measurement() - if write_period is not None: - meas.write_period = write_period - meas.register_parameter(param_set1) + _register_parameters(meas, (param_set1, param_set2)) + _register_parameters(meas, param_meas, setpoints=(param_set1, param_set2)) + _set_write_period(meas, write_period) + _register_actions(meas, enter_actions, exit_actions) + param_set1.post_delay = delay1 - meas.register_parameter(param_set2) param_set2.post_delay = delay2 - interrupted = False - for action in enter_actions: - # this omits the possibility of passing - # argument to enter and exit actions. - # Do we want that? - meas.add_before_run(action, ()) - - for action in exit_actions: - meas.add_after_run(action, ()) - for parameter in param_meas: - if isinstance(parameter, _BaseParameter): - meas.register_parameter(parameter, - setpoints=(param_set1, param_set2)) - try: - with meas.run() as datasaver: - for set_point1 in np.linspace(start1, stop1, num_points1): + with _catch_keyboard_interrupts() as interrupted, meas.run() as datasaver: + for set_point1 in np.linspace(start1, stop1, num_points1): if set_before_sweep: param_set2.set(start2) @@ -237,67 +239,64 @@ def do2d(param_set1: _BaseParameter, start1: number, stop1: number, pass else: param_set2.set(set_point2) - output = [] - for parameter in param_meas: - if isinstance(parameter, _BaseParameter): - output.append((parameter, parameter.get())) - elif callable(parameter): - parameter() + datasaver.add_result((param_set1, set_point1), (param_set2, set_point2), - *output) + *_process_params_meas(param_meas)) for action in after_inner_actions: action() if flush_columns: datasaver.flush_data_to_database() - except KeyboardInterrupt: - interrupted = True - dataid = datasaver.run_id + return _handle_plotting(datasaver, do_plot, interrupted()) - if do_plot is True: - ax, cbs = _save_image(datasaver) - else: - ax = None, - cbs = None - if interrupted: - raise KeyboardInterrupt - return dataid, ax, cbs -def _save_image(datasaver) -> AxesTupleList: +def _handle_plotting( + datasaver: DataSaver, + do_plot: bool = True, + interrupted: bool = False +) -> AxesTupleList: """ Save the plots created by datasaver as pdf and png Args: datasaver: a measurement datasaver that contains a dataset to be saved as plot. + :param do_plot: """ - plt.ioff() dataid = datasaver.run_id + if do_plot == True: + res = _create_plots(datasaver) + else: + res = dataid, None, None + + if interrupted: + raise KeyboardInterrupt + + return res + + +def _create_plots(datasaver: DataSaver) -> AxesTupleList: + dataid = datasaver.run_id + plt.ioff() start = time.time() axes, cbs = plot_by_id(dataid) stop = time.time() - print(f"plot by id took {stop-start}") - + print(f"plot by id took {stop - start}") mainfolder = config.user.mainfolder experiment_name = datasaver._dataset.exp_name sample_name = datasaver._dataset.sample_name - storage_dir = os.path.join(mainfolder, experiment_name, sample_name) os.makedirs(storage_dir, exist_ok=True) - png_dir = os.path.join(storage_dir, 'png') pdf_dif = os.path.join(storage_dir, 'pdf') - os.makedirs(png_dir, exist_ok=True) os.makedirs(pdf_dif, exist_ok=True) - save_pdf = True save_png = True - for i, ax in enumerate(axes): if save_pdf: full_path = os.path.join(pdf_dif, f'{dataid}_{i}.pdf') @@ -306,4 +305,5 @@ def _save_image(datasaver) -> AxesTupleList: full_path = os.path.join(png_dir, f'{dataid}_{i}.png') ax.figure.savefig(full_path, dpi=500) plt.ion() - return axes, cbs + res = dataid, axes, cbs + return res