Commit f5c2bab

Drop __getattr__ from ThreadSafeDataFrame. It was implemented to provide convenient access to the internal dataframe container, so that a ThreadSafeDataFrame could be used much like a pandas DataFrame (pandas indeed implements many convenient methods for manipulating data). However, this is dangerous in a multithreaded environment: the lock is released as soon as the acquired item (which may be a method that manipulates data) is returned, so its holder can still mutate the internal container asynchronously, outside the lock. To avoid this problem, __getattr__ is dropped and all necessary operations are implemented as explicit methods that hold the reentrant lock; the sketch below illustrates the hazard.
1 parent ac29818 commit f5c2bab
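
The hazard is easiest to see in a minimal sketch (illustrative class names only, not the library's real API): with a delegating __getattr__, the lock guards only the attribute lookup, so the bound method it hands back mutates the shared DataFrame after the lock has already been released.

import threading

import pandas as pd


class UnsafeTable:
    """Illustration of the dropped pattern, not the real ThreadSafeDataFrame."""

    def __init__(self):
        self._lock = threading.RLock()
        self._container = pd.DataFrame(columns=["value"])

    def __getattr__(self, item):
        # The lock is held only while the attribute is looked up.
        with self._lock:
            return getattr(self._container, item)
        # By the time the caller invokes the returned bound method,
        # e.g. table.drop("x", inplace=True), the lock is gone and the
        # mutation races with other threads.


class SafeTable(UnsafeTable):
    """Illustration of the approach taken here: explicit locked methods."""

    def drop_entry(self, index: str):
        # The whole check-and-mutate sequence runs under the reentrant lock.
        with self._lock:
            if index not in self._container.index:
                raise ValueError(f"Table index {index} doesn't exist in this table.")
            self._container.drop(index, inplace=True)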

File tree

5 files changed (+134, -57 lines)


qiskit_experiments/database_service/utils.py

Lines changed: 44 additions & 11 deletions
@@ -376,17 +376,58 @@ def container(
             return container[self._default_columns()]
         return container
 
+    def drop_entry(
+        self,
+        index: str,
+    ):
+        """Drop entry from the dataframe.
+
+        Args:
+            index: Name of entry to drop.
+
+        Raises:
+            ValueError: When index is not in this table.
+        """
+        with self._lock:
+            if index not in self._container.index:
+                raise ValueError(f"Table index {index} doesn't exist in this table.")
+            self._container.drop(index, inplace=True)
+
+    def get_entry(
+        self,
+        index: str,
+    ) -> pd.Series:
+        """Get entry from the dataframe.
+
+        Args:
+            index: Name of entry to acquire.
+
+        Returns:
+            Pandas Series of acquired entry. This doesn't mutate the table.
+
+        Raises:
+            ValueError: When index is not in this table.
+        """
+        with self._lock:
+            if index not in self._container.index:
+                raise ValueError(f"Table index {index} doesn't exist in this table.")
+
+            return self._container.loc[index]
+
     def add_entry(
         self,
         index: str,
         **kwargs,
-    ):
+    ) -> pd.Series:
         """Add new entry to the dataframe.
 
         Args:
             index: Name of this entry. Must be unique in this table.
             kwargs: Description of new entry to register.
 
+        Returns:
+            Pandas Series of added entry. This doesn't mutate the table.
+
         Raises:
             ValueError: When index is not unique in this table.
         """
@@ -406,22 +447,14 @@ def add_entry(
         index = str(index)
         self._container.loc[index] = list(template.values())
 
+        return self._container.iloc[-1]
+
     def _repr_html_(self) -> Union[str, None]:
         """Return HTML representation of this dataframe."""
         with self._lock:
             # Remove underscored columns.
             return self._container._repr_html_()
 
-    def __getattr__(self, item):
-        lock = object.__getattribute__(self, "_lock")
-
-        with lock:
-            # Lock when access to container's member.
-            container = object.__getattribute__(self, "_container")
-            if hasattr(container, item):
-                return getattr(container, item)
-        raise AttributeError(f"'ThreadSafeDataFrame' object has no attribute '{item}'")
-
     def __json_encode__(self) -> Dict[str, Any]:
         with self._lock:
             return {
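
For reference, a short usage sketch of the new accessors (the MyTable subclass below is hypothetical; ThreadSafeDataFrame expects subclasses to declare their default columns, as in the diff above):

from typing import List

from qiskit_experiments.database_service.utils import ThreadSafeDataFrame


class MyTable(ThreadSafeDataFrame):
    """Hypothetical three-column table, for illustration only."""

    @classmethod
    def _default_columns(cls) -> List[str]:
        return ["value1", "value2", "value3"]


table = MyTable()

# Every call below holds the reentrant lock for its full duration.
added = table.add_entry(index="x", value1=0.0, value2=1.0, value3=2.0)  # returns the new row as a Series
entry = table.get_entry("x")   # read-only access, doesn't mutate the table
table.drop_entry("x")          # raises ValueError if "x" is not in the table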

qiskit_experiments/framework/analysis_result_table.py

Lines changed: 45 additions & 33 deletions
@@ -17,6 +17,8 @@
 import warnings
 from typing import List, Union, Optional
 
+import pandas as pd
+
 from qiskit_experiments.database_service.utils import ThreadSafeDataFrame
 
 LOG = logging.getLogger(__name__)
@@ -52,6 +54,11 @@ def _default_columns(cls) -> List[str]:
             "created_time",
         ]
 
+    def result_ids(self) -> List[str]:
+        """Return all result IDs in this table."""
+        with self._lock:
+            return self._container["result_id"].to_list()
+
     def filter_columns(self, columns: Union[str, List[str]]) -> List[str]:
         """Filter columns names available in this table.
@@ -68,35 +75,37 @@ def filter_columns(self, columns: Union[str, List[str]]) -> List[str]:
         Raises:
             ValueError: When column is given in string which doesn't match with any builtin group.
         """
-        if columns == "all":
-            return self._columns
-        if columns == "default":
-            return [
-                "name",
-                "experiment",
-                "components",
-                "value",
-                "quality",
-                "backend",
-                "run_time",
-            ] + self._extra
-        if columns == "minimal":
-            return [
-                "name",
-                "components",
-                "value",
-                "quality",
-            ] + self._extra
-        if not isinstance(columns, str):
-            out = []
-            for column in columns:
-                if column in self._columns:
-                    out.append(column)
-                else:
-                    warnings.warn(
-                        f"Specified column name {column} does not exist in this table.", UserWarning
-                    )
-            return out
+        with self._lock:
+            if columns == "all":
+                return self._columns
+            if columns == "default":
+                return [
+                    "name",
+                    "experiment",
+                    "components",
+                    "value",
+                    "quality",
+                    "backend",
+                    "run_time",
+                ] + self._extra
+            if columns == "minimal":
+                return [
+                    "name",
+                    "components",
+                    "value",
+                    "quality",
+                ] + self._extra
+            if not isinstance(columns, str):
+                out = []
+                for column in columns:
+                    if column in self._columns:
+                        out.append(column)
+                    else:
+                        warnings.warn(
+                            f"Specified column name {column} does not exist in this table.",
+                            UserWarning,
+                        )
+                return out
         raise ValueError(
             f"Column group {columns} is not valid name. Use either 'all', 'default', 'minimal'."
         )
@@ -106,16 +115,19 @@ def add_entry(
         self,
         result_id: Optional[str] = None,
         **kwargs,
-    ):
+    ) -> pd.Series:
         """Add new entry to the table.
 
         Args:
             result_id: Result ID. Automatically generated when not provided.
                 This must be valid hexadecimal UUID string.
             kwargs: Description of new entry to register.
+
+        Returns:
+            Pandas Series of added entry. This doesn't mutate the table.
         """
         if result_id:
-            with self.lock:
+            with self._lock:
                 if result_id[:8] in self._container.index:
                     raise ValueError(
                         f"The short ID of the result_id '{result_id[:8]}' already exists in the "
@@ -129,15 +141,15 @@ def add_entry(
         # This mechanism is similar with the github commit hash.
         short_index = result_id[:8]
 
-        super().add_entry(
+        return super().add_entry(
             index=short_index,
             result_id=result_id,
             **kwargs,
         )
 
     def _unique_table_index(self):
         """Generate unique UUID which is unique in the table with first 8 characters."""
-        with self.lock:
+        with self._lock:
             n = 0
             while n < 1000:
                 tmp_id = uuid.uuid4().hex
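
A brief sketch of how the table-level additions are meant to be consumed (the result ID is made up, and the import path is inferred from the file location above):

from qiskit_experiments.framework.analysis_result_table import AnalysisResultTable

table = AnalysisResultTable()

# add_entry now returns the stored row as a pandas Series, so callers
# no longer need to reach into the container with iloc[-1] afterwards.
series = table.add_entry(result_id="9a0bdec8c0104ef7bb7db84939717a6b", value=0.123)
print(series.value)        # 0.123

# result_ids() reads the IDs under the lock without exposing the container.
print(table.result_ids())  # ['9a0bdec8c0104ef7bb7db84939717a6b']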

qiskit_experiments/framework/experiment_data.py

Lines changed: 4 additions & 6 deletions
@@ -686,7 +686,7 @@ def _set_hgp_from_provider(self, provider):
     def _clear_results(self):
         """Delete all currently stored analysis results and figures"""
         # Schedule existing analysis results for deletion next save call
-        self._deleted_analysis_results.extend(list(self._analysis_results["result_id"]))
+        self._deleted_analysis_results.extend(list(self._analysis_results.result_ids()))
         self._analysis_results.clear()
         # Schedule existing figures for deletion next save call
         for key in self._figures.keys():
@@ -1375,7 +1375,7 @@ def add_analysis_results(
         tags = tags or []
         backend = backend or self.backend_name
 
-        self._analysis_results.add_entry(
+        series = self._analysis_results.add_entry(
            result_id=result_id,
            name=name,
            value=value,
@@ -1391,7 +1391,7 @@
         )
         if self.auto_save:
             service_result = _series_to_service_result(
-                series=self._analysis_results.iloc[-1],
+                series=series,
                 service=self._service,
                 auto_save=False,
             )
@@ -1426,7 +1426,7 @@ def delete_analysis_result(
                 "Try another key that can uniquely determine entry to delete."
             )
 
-        self._analysis_results.drop(to_delete.name, inplace=True)
+        self._analysis_results.drop_entry(str(to_delete.name))
         if self._service and self.auto_save:
             with service_exception_to_warning():
                 self.service.delete_analysis_result(result_id=to_delete.result_id)
@@ -1697,8 +1697,6 @@ def save(
                     json_encoder=self._json_encoder,
                     max_workers=max_workers,
                 )
-                for result in self._analysis_results.values():
-                    result._created_in_db = True
             except Exception as ex:  # pylint: disable=broad-except
                 # Don't automatically fail the experiment just because its data cannot be saved.
                 LOG.error("Unable to save the experiment data: %s", traceback.format_exc())

test/extended_equality.py

Lines changed: 4 additions & 0 deletions
@@ -281,6 +281,10 @@ def _check_dataframes(
     **kwargs,
 ):
     """Check equality of data frame which may involve Qiskit Experiments class value."""
+    if isinstance(data1, ThreadSafeDataFrame):
+        data1 = data1.container(collapse_extra=False)
+    if isinstance(data2, ThreadSafeDataFrame):
+        data2 = data2.container(collapse_extra=False)
     return is_equivalent(
         data1.to_dict(orient="index"),
         data2.to_dict(orient="index"),
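
Without __getattr__, a ThreadSafeDataFrame no longer forwards calls like to_dict to the inner DataFrame, so the equality helper has to unwrap the wrapper explicitly, as the diff above does. A standalone sketch of the same idea (a duck-typed check is used here instead of the real isinstance test):

def to_plain_dataframe(data):
    """Return the underlying pandas DataFrame if data is a thread-safe wrapper."""
    # Stand-in for isinstance(data, ThreadSafeDataFrame).
    if hasattr(data, "container"):
        return data.container(collapse_extra=False)
    return data


def dataframes_equal(data1, data2) -> bool:
    """Compare two (possibly wrapped) dataframes by their row dictionaries."""
    df1 = to_plain_dataframe(data1)
    df2 = to_plain_dataframe(data2)
    return df1.to_dict(orient="index") == df2.to_dict(orient="index")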

test/framework/test_data_table.py

Lines changed: 37 additions & 7 deletions
@@ -14,6 +14,7 @@
 
 from test.base import QiskitExperimentsTestCase
 
+import uuid
 import numpy as np
 import pandas as pd
 
@@ -58,28 +59,33 @@ def test_raises_initializing_with_wrong_table(self):
             # columns doesn't match with default_columns
             TestBaseTable.TestTable(wrong_table)
 
+    def test_get_entry(self):
+        """Test getting an entry from the table."""
+        table = TestBaseTable.TestTable({"x": [1.0, 2.0, 3.0]})
+        self.assertListEqual(table.get_entry("x").to_list(), [1.0, 2.0, 3.0])
+
     def test_add_entry(self):
         """Test adding data with default keys to table."""
         table = TestBaseTable.TestTable()
         table.add_entry(index="x", value1=0.0, value2=1.0, value3=2.0)
 
-        self.assertListEqual(table.loc["x"].to_list(), [0.0, 1.0, 2.0])
+        self.assertListEqual(table.get_entry("x").to_list(), [0.0, 1.0, 2.0])
 
     def test_add_entry_with_missing_key(self):
         """Test adding entry with partly specified keys."""
         table = TestBaseTable.TestTable()
         table.add_entry(index="x", value1=0.0, value3=2.0)
 
         # NaN value cannot be compared with assert
-        np.testing.assert_equal(table.loc["x"].to_list(), [0.0, float("nan"), 2.0])
+        np.testing.assert_equal(table.get_entry("x").to_list(), [0.0, float("nan"), 2.0])
 
     def test_add_entry_with_new_key(self):
         """Test adding data with new keys to table."""
         table = TestBaseTable.TestTable()
         table.add_entry(index="x", value1=0.0, value2=1.0, value3=2.0, extra=3.0)
 
         self.assertListEqual(table.get_columns(), ["value1", "value2", "value3", "extra"])
-        self.assertListEqual(table.loc["x"].to_list(), [0.0, 1.0, 2.0, 3.0])
+        self.assertListEqual(table.get_entry("x").to_list(), [0.0, 1.0, 2.0, 3.0])
 
     def test_add_entry_with_new_key_with_existing_entry(self):
         """Test adding new key will expand existing entry."""
@@ -88,10 +94,24 @@ def test_add_entry_with_new_key_with_existing_entry(self):
         table.add_entry(index="y", value1=0.0, value2=1.0, value3=2.0, extra=3.0)
 
         self.assertListEqual(table.get_columns(), ["value1", "value2", "value3", "extra"])
-        self.assertListEqual(table.loc["y"].to_list(), [0.0, 1.0, 2.0, 3.0])
+        self.assertListEqual(table.get_entry("y").to_list(), [0.0, 1.0, 2.0, 3.0])
 
         # NaN value cannot be compared with assert
-        np.testing.assert_equal(table.loc["x"].to_list(), [0.0, 1.0, 2.0, float("nan")])
+        np.testing.assert_equal(table.get_entry("x").to_list(), [0.0, 1.0, 2.0, float("nan")])
+
+    def test_drop_entry(self):
+        """Test drop entry from the table."""
+        table = TestBaseTable.TestTable()
+        table.add_entry(index="x", value1=0.0, value2=1.0, value3=2.0)
+        table.drop_entry("x")
+
+        self.assertEqual(len(table), 0)
+
+    def test_drop_non_existing_entry(self):
+        """Test dropping non-existing entry raises ValueError."""
+        table = TestBaseTable.TestTable()
+        with self.assertRaises(ValueError):
+            table.drop_entry("x")
 
     def test_return_only_default_columns(self):
         """Test extra entry is correctly recognized."""
@@ -131,7 +151,7 @@ def test_container_is_immutable(self):
         self.assertListEqual(dataframe.loc["x"].to_list(), [100, 0.2, 0.3])
 
         # Original object in the experiment payload is preserved
-        self.assertListEqual(table.loc["x"].to_list(), [0.1, 0.2, 0.3])
+        self.assertListEqual(table.get_entry("x").to_list(), [0.1, 0.2, 0.3])
 
     def test_round_trip(self):
         """Test JSON roundtrip serialization with the experiment encoder."""
@@ -149,7 +169,7 @@ def test_add_entry_with_result_id(self):
         """Test adding entry with result_id. Index is created by truncating long string."""
         table = AnalysisResultTable()
         table.add_entry(result_id="9a0bdec8c0104ef7bb7db84939717a6b", value=0.123)
-        self.assertEqual(table.loc["9a0bdec8"].value, 0.123)
+        self.assertEqual(table.get_entry("9a0bdec8").value, 0.123)
 
     def test_extra_column_name_is_always_returned(self):
         """Test extra column names are always returned in filtered column names."""
@@ -165,6 +185,16 @@ def test_extra_column_name_is_always_returned(self):
         all_columns = table.filter_columns("all")
         self.assertTrue("extra" in all_columns)
 
+    def test_listing_result_id(self):
+        """Test returning result IDs of all stored entries."""
+        table = AnalysisResultTable()
+
+        ref_ids = [uuid.uuid4().hex for _ in range(10)]
+        for ref_id in ref_ids:
+            table.add_entry(result_id=ref_id, value=0)
+
+        self.assertListEqual(table.result_ids(), ref_ids)
+
     def test_no_overlap_result_id(self):
         """Test automatically prepare unique result IDs for sufficient number of entries."""
         table = AnalysisResultTable()
