Skip to content

Commit 9a98f4c

Browse files
stefsmeetsjanosh
andauthored
Fix issues with labels (#3169)
* Add label support to (Periodic)Neighbor * Default label to `species_string` * allow empty string as label * test None, '' and tuple Site labels * optimade url use {response_fields=!s} --------- Co-authored-by: Janosh Riebesell <[email protected]>
1 parent 80aa781 commit 9a98f4c

File tree

6 files changed

+70
-50
lines changed

6 files changed

+70
-50
lines changed

pymatgen/core/sites.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
self._species: Composition = species # type: ignore
7272
self.coords: np.ndarray = coords # type: ignore
7373
self.properties: dict = properties or {}
74-
self.label = label
74+
self.label = label if label else self.species_string
7575

7676
def __getattr__(self, attr):
7777
# overriding getattr doesn't play nicely with pickle, so we can't use self._properties
@@ -213,7 +213,7 @@ def __contains__(self, el):
213213
def __repr__(self):
214214
name = self.species_string
215215

216-
if self.label:
216+
if self.label != name:
217217
name = f"{self.label} ({name})"
218218

219219
return f"Site: {name} ({self.coords[0]:.4f}, {self.coords[1]:.4f}, {self.coords[2]:.4f})"
@@ -256,6 +256,9 @@ def as_dict(self) -> dict:
256256
}
257257
if self.properties:
258258
dct["properties"] = self.properties
259+
260+
dct["label"] = self.label
261+
259262
return dct
260263

261264
@classmethod
@@ -274,7 +277,8 @@ def from_dict(cls, dct: dict) -> Site:
274277
if props is not None:
275278
for key in props:
276279
props[key] = json.loads(json.dumps(props[key], cls=MontyEncoder), cls=MontyDecoder)
277-
return cls(atoms_n_occu, dct["xyz"], properties=props)
280+
label = dct.get("label")
281+
return cls(atoms_n_occu, dct["xyz"], properties=props, label=label)
278282

279283

280284
class PeriodicSite(Site, MSONable):
@@ -341,7 +345,7 @@ def __init__(
341345
self._species: Composition = species # type: ignore
342346
self._coords: np.ndarray | None = None
343347
self.properties: dict = properties or {}
344-
self.label = label
348+
self.label = label if label else self.species_string
345349

346350
def __hash__(self) -> int:
347351
"""
@@ -551,7 +555,7 @@ def distance(self, other: PeriodicSite, jimage: ArrayLike | None = None):
551555
def __repr__(self):
552556
name = self.species_string
553557

554-
if self.label:
558+
if self.label != name:
555559
name = f"{self.label} ({name})"
556560

557561
x, y, z = self.coords
@@ -585,7 +589,6 @@ def as_dict(self, verbosity: int = 0) -> dict:
585589

586590
if verbosity > 0:
587591
dct["xyz"] = [float(c) for c in self.coords]
588-
dct["label"] = self.species_string
589592

590593
dct["properties"] = self.properties
591594
dct["label"] = self.label

pymatgen/core/structure.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,19 +69,22 @@ def __init__(
6969
properties: dict | None = None,
7070
nn_distance: float = 0.0,
7171
index: int = 0,
72+
label: str | None = None,
7273
):
7374
"""
7475
:param species: Same as Site
7576
:param coords: Same as Site, but must be fractional.
7677
:param properties: Same as Site
7778
:param nn_distance: Distance to some other Site.
7879
:param index: Index within structure.
80+
:param label: Label for the site. Defaults to None.
7981
"""
8082
self.coords = coords
8183
self._species = species
8284
self.properties = properties or {}
8385
self.nn_distance = nn_distance
8486
self.index = index
87+
self.label = label if label is not None else self.species_string
8588

8689
def __len__(self) -> Literal[3]:
8790
"""Make neighbor Tuple-like to retain backwards compatibility."""
@@ -134,6 +137,7 @@ def __init__(
134137
nn_distance: float = 0.0,
135138
index: int = 0,
136139
image: tuple = (0, 0, 0),
140+
label: str | None = None,
137141
):
138142
"""
139143
Args:
@@ -144,6 +148,7 @@ def __init__(
144148
nn_distance (float, optional): Distance to some other Site.. Defaults to 0.0.
145149
index (int, optional): Index within structure.. Defaults to 0.
146150
image (tuple, optional): PeriodicImage. Defaults to (0, 0, 0).
151+
label (str, optional): Label for the site. Defaults to None.
147152
"""
148153
self._lattice = lattice
149154
self._frac_coords = coords
@@ -152,6 +157,7 @@ def __init__(
152157
self.nn_distance = nn_distance
153158
self.index = index
154159
self.image = image
160+
self.label = label if label is not None else self.species_string
155161

156162
@property # type: ignore
157163
def coords(self) -> np.ndarray: # type: ignore
@@ -310,7 +316,7 @@ def site_properties(self) -> dict[str, Sequence]:
310316
return props
311317

312318
@property
313-
def labels(self) -> list[str | None]:
319+
def labels(self) -> list[str]:
314320
"""Return site labels as a list."""
315321
return [site.label for site in self]
316322

@@ -2874,7 +2880,7 @@ def __init__(
28742880
spin_multiplicity: int | None = None,
28752881
validate_proximity: bool = False,
28762882
site_properties: dict | None = None,
2877-
labels: list[str | None] | None = None,
2883+
labels: Sequence[str | None] | None = None,
28782884
charge_spin_check: bool = True,
28792885
) -> None:
28802886
"""
@@ -4292,7 +4298,7 @@ def __init__(
42924298
spin_multiplicity: int | None = None,
42934299
validate_proximity: bool = False,
42944300
site_properties: dict | None = None,
4295-
labels: list[str | None] | None = None,
4301+
labels: Sequence[str | None] | None = None,
42964302
charge_spin_check: bool = True,
42974303
) -> None:
42984304
"""

pymatgen/core/tests/test_sites.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,34 @@ def test_properties(self):
3636
assert self.propertied_site.properties["charge"] == 4.2
3737

3838
def test_to_from_dict(self):
39-
d = self.disordered_site.as_dict()
40-
site = Site.from_dict(d)
39+
dct = self.disordered_site.as_dict()
40+
site = Site.from_dict(dct)
4141
assert site == self.disordered_site
4242
assert site != self.ordered_site
43-
d = self.propertied_site.as_dict()
44-
site = Site.from_dict(d)
43+
dct = self.propertied_site.as_dict()
44+
site = Site.from_dict(dct)
4545
assert site.properties["magmom"] == 5.1
4646
assert site.properties["charge"] == 4.2
47-
d = self.propertied_magmom_vec_site.as_dict()
48-
site = Site.from_dict(d)
47+
dct = self.propertied_magmom_vec_site.as_dict()
48+
site = Site.from_dict(dct)
4949
assert site.properties["magmom"] == Magmom([2.6, 2.6, 3.5])
5050
assert site.properties["charge"] == 4.2
51-
d = self.dummy_site.as_dict()
52-
site = Site.from_dict(d)
51+
dct = self.dummy_site.as_dict()
52+
site = Site.from_dict(dct)
5353
assert site.species == self.dummy_site.species
5454

5555
def test_hash(self):
5656
assert hash(self.ordered_site) == 26
5757
assert hash(self.disordered_site) == 51
5858

59-
def test_cmp(self):
59+
def test_gt_lt(self):
6060
assert self.ordered_site > self.disordered_site
61+
assert self.disordered_site < self.ordered_site
6162

6263
def test_distance(self):
63-
osite = self.ordered_site
64-
assert np.linalg.norm([0.25, 0.35, 0.45]) == osite.distance_from_point([0, 0, 0])
65-
assert osite.distance(self.disordered_site) == 0
64+
ord_site = self.ordered_site
65+
assert np.linalg.norm([0.25, 0.35, 0.45]) == ord_site.distance_from_point([0, 0, 0])
66+
assert ord_site.distance(self.disordered_site) == 0
6667

6768
def test_pickle(self):
6869
o = pickle.dumps(self.propertied_site)
@@ -105,7 +106,7 @@ def test_properties(self):
105106
assert self.site.y == 3.5
106107
assert self.site.z == 4.5
107108
assert self.site.is_ordered
108-
assert self.site.label is None
109+
assert self.site.label == "Fe"
109110
assert not self.site2.is_ordered
110111
assert self.propertied_site.properties["magmom"] == 5.1
111112
assert self.propertied_site.properties["charge"] == 4.2
@@ -170,20 +171,23 @@ def test_equality_with_label(self):
170171
assert self.labeled_site == site
171172

172173
def test_as_from_dict(self):
173-
d = self.site2.as_dict()
174-
site = PeriodicSite.from_dict(d)
174+
dct = self.site2.as_dict()
175+
site = PeriodicSite.from_dict(dct)
175176
assert site == self.site2
176177
assert site != self.site
177-
assert site.label == self.site.label
178-
d = self.propertied_site.as_dict()
178+
assert site.label == self.site2.label
179+
180+
dct = self.propertied_site.as_dict()
179181
site3 = PeriodicSite({"Si": 0.5, "Fe": 0.5}, [0, 0, 0], self.lattice)
180-
d = site3.as_dict()
181-
site = PeriodicSite.from_dict(d)
182+
dct = site3.as_dict()
183+
site = PeriodicSite.from_dict(dct)
182184
assert site.species == site3.species
185+
assert site.label == site3.label
183186

184-
d = self.dummy_site.as_dict()
185-
site = PeriodicSite.from_dict(d)
187+
dct = self.dummy_site.as_dict()
188+
site = PeriodicSite.from_dict(dct)
186189
assert site.species == self.dummy_site.species
190+
assert site.label == self.dummy_site.label
187191

188192
def test_to_unit_cell(self):
189193
site = PeriodicSite("Fe", np.array([1.25, 2.35, 4.46]), self.lattice)

pymatgen/core/tests/test_structure.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,15 @@
1919
from pymatgen.core.lattice import Lattice
2020
from pymatgen.core.operations import SymmOp
2121
from pymatgen.core.periodic_table import Element, Species
22-
from pymatgen.core.structure import IMolecule, IStructure, Molecule, PeriodicNeighbor, Structure, StructureError
22+
from pymatgen.core.structure import (
23+
IMolecule,
24+
IStructure,
25+
Molecule,
26+
Neighbor,
27+
PeriodicNeighbor,
28+
Structure,
29+
StructureError,
30+
)
2331
from pymatgen.electronic_structure.core import Magmom
2432
from pymatgen.io.ase import AseAtomsAdaptor
2533
from pymatgen.util.testing import PymatgenTest
@@ -45,16 +53,21 @@ def test_msonable(self):
4553
nn = json.loads(str_, cls=MontyDecoder)
4654
assert isinstance(nn[0], PeriodicNeighbor)
4755

56+
def test_neighbor_labels(self):
57+
comp = Composition("C")
58+
for label in (None, "", "str label", ("tuple", "label")):
59+
neighbor = Neighbor(comp, (0, 0, 0), label=label)
60+
assert neighbor.label == label if label is not None else str(comp)
61+
62+
p_neighbor = PeriodicNeighbor(comp, (0, 0, 0), (10, 10, 10), label=label)
63+
assert p_neighbor.label == label if label is not None else str(comp)
64+
4865

4966
class IStructureTest(PymatgenTest):
5067
def setUp(self):
5168
coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
5269
self.lattice = Lattice(
53-
[
54-
[3.8401979337, 0.00, 0.00],
55-
[1.9200989668, 3.3257101909, 0.00],
56-
[0.00, -2.2171384943, 3.1355090603],
57-
]
70+
[[3.8401979337, 0, 0], [1.9200989668, 3.3257101909, 0], [0, -2.2171384943, 3.1355090603]]
5871
)
5972
self.struct = IStructure(self.lattice, ["Si"] * 2, coords)
6073
assert len(self.struct) == 2, "Wrong number of sites in structure!"
@@ -69,11 +82,7 @@ def setUp(self):
6982
self.labeled_structure = IStructure(self.lattice, ["Si"] * 2, coords, labels=["Si1", "Si2"])
7083

7184
self.lattice_pbc = Lattice(
72-
[
73-
[3.8401979337, 0.00, 0.00],
74-
[1.9200989668, 3.3257101909, 0.00],
75-
[0.00, -2.2171384943, 3.1355090603],
76-
],
85+
[[3.8401979337, 0, 0], [1.9200989668, 3.3257101909, 0], [0, -2.2171384943, 3.1355090603]],
7786
pbc=(True, True, False),
7887
)
7988

@@ -171,7 +180,7 @@ def test_fractional_occupations(self):
171180

172181
def test_labeled_structure(self):
173182
assert self.labeled_structure.labels == ["Si1", "Si2"]
174-
assert self.struct.labels == [None, None]
183+
assert self.struct.labels == ["Si", "Si"]
175184

176185
def test_get_distance(self):
177186
assert self.struct.get_distance(0, 1) == approx(2.35, abs=1e-2), "Distance calculated wrongly!"
@@ -318,7 +327,7 @@ def test_interpolate(self):
318327
assert interpolated_structs[0].lattice == inter_struct.lattice
319328
assert_array_equal(interpolated_structs[1][1].frac_coords, [0.625, 0.5, 0.625])
320329

321-
bad_lattice = [[1, 0.00, 0.00], [0, 1, 0.00], [0.00, 0, 1]]
330+
bad_lattice = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
322331
struct2 = IStructure(bad_lattice, ["Si"] * 2, coords2)
323332
with pytest.raises(ValueError, match="Structures with different lattices"):
324333
struct.interpolate(struct2)
@@ -764,9 +773,7 @@ def setUp(self):
764773
coords = []
765774
coords.append([0, 0, 0])
766775
coords.append([0.75, 0.5, 0.75])
767-
lattice = Lattice(
768-
[[3.8401979337, 0.00, 0.00], [1.9200989668, 3.3257101909, 0.00], [0.00, -2.2171384943, 3.1355090603]]
769-
)
776+
lattice = Lattice([[3.8401979337, 0, 0], [1.9200989668, 3.3257101909, 0], [0, -2.2171384943, 3.1355090603]])
770777
self.structure = Structure(lattice, ["Si", "Si"], coords)
771778
self.cu_structure = Structure(lattice, ["Cu", "Cu"], coords)
772779
self.disordered = Structure.from_spacegroup("Im-3m", Lattice.cubic(3), [Composition("Fe0.5Mn0.5")], [[0, 0, 0]])
@@ -1157,7 +1164,7 @@ def test_from_magnetic_spacegroup(self):
11571164
"P4_2'/mnm'",
11581165
Lattice.tetragonal(4.87, 3.30),
11591166
["Mn", "F"],
1160-
[[0, 0, 0], [0.30, 0.30, 0.00]],
1167+
[[0, 0, 0], [0.30, 0.30, 0]],
11611168
{"magmom": [4, 0]},
11621169
)
11631170

@@ -1173,7 +1180,7 @@ def test_from_magnetic_spacegroup(self):
11731180
["La", "Mn", "O", "O"],
11741181
[
11751182
[0.05, 0.25, 0.99],
1176-
[0.00, 0.00, 0.50],
1183+
[0, 0, 0.50],
11771184
[0.48, 0.25, 0.08],
11781185
[0.31, 0.04, 0.72],
11791186
],

pymatgen/ext/optimade.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def get_snls_with_filter(
310310
response_fields = self._handle_response_fields(additional_response_fields)
311311

312312
for identifier, resource in self.resources.items():
313-
url = join(resource, f"v1/structures?filter={optimade_filter}&response_fields={response_fields}")
313+
url = join(resource, f"v1/structures?filter={optimade_filter}&{response_fields=!s}")
314314

315315
try:
316316
json = self._get_json(url)

test_files/.pytest-split-durations

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,7 @@
944944
"pymatgen/core/tests/test_sites.py::PeriodicSiteTest::test_repr": 0.0002163760073017329,
945945
"pymatgen/core/tests/test_sites.py::PeriodicSiteTest::test_setters": 0.0005762910150224343,
946946
"pymatgen/core/tests/test_sites.py::PeriodicSiteTest::test_to_unit_cell": 0.0004528739955276251,
947-
"pymatgen/core/tests/test_sites.py::SiteTest::test_cmp": 0.00041866599349305034,
947+
"pymatgen/core/tests/test_sites.py::SiteTest::test_gt_lt": 0.00041866599349305034,
948948
"pymatgen/core/tests/test_sites.py::SiteTest::test_distance": 0.0005102920113131404,
949949
"pymatgen/core/tests/test_sites.py::SiteTest::test_hash": 0.00025291601195931435,
950950
"pymatgen/core/tests/test_sites.py::SiteTest::test_pickle": 0.00037470700044650584,

0 commit comments

Comments
 (0)