Skip to content

Commit 161273f

Browse files
committed
fix(autograd): include traced keys in HDF5 hash input
1 parent 7b3c66d commit 161273f

File tree

5 files changed

+121
-46
lines changed

5 files changed

+121
-46
lines changed

tests/test_components/autograd/test_autograd.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,6 +1174,21 @@ def objective(*params):
11741174
ag.grad(objective)(params0)
11751175

11761176

1177+
def test_sim_hash_changes_with_traced_keys():
1178+
"""Ensure the model hash accounts for autograd traced paths."""
1179+
1180+
sim_traced = SIM_FULL.copy()
1181+
original_field_map = sim_traced._strip_traced_fields()
1182+
1183+
structures = list(sim_traced.structures)
1184+
structures[0] = structures[0].to_static()
1185+
sim_modified = sim_traced.updated_copy(structures=tuple(structures))
1186+
1187+
modified_field_map = sim_modified._strip_traced_fields()
1188+
assert original_field_map != modified_field_map
1189+
assert sim_traced._hash_self() != sim_modified._hash_self()
1190+
1191+
11771192
def test_sim_traced_override_structures():
11781193
"""Make sure that sims with traced override structures are handled properly."""
11791194

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Typed containers for autograd traced field metadata."""
2+
3+
from __future__ import annotations
4+
5+
import json
6+
from typing import Any, Callable
7+
8+
import pydantic.v1 as pydantic
9+
10+
from tidy3d.components.autograd.types import AutogradFieldMap, dict_ag
11+
from tidy3d.components.base import Tidy3dBaseModel
12+
from tidy3d.components.types import ArrayLike, tidycomplex
13+
14+
15+
class Tracer(Tidy3dBaseModel):
16+
"""Representation of a single traced element within a model."""
17+
18+
path: tuple[Any, ...] = pydantic.Field(
19+
...,
20+
title="Path to the traced object in the model dictionary.",
21+
)
22+
data: float | tidycomplex | ArrayLike = pydantic.Field(..., title="Tracing data")
23+
24+
25+
class FieldMap(Tidy3dBaseModel):
26+
"""Collection of traced elements."""
27+
28+
tracers: tuple[Tracer, ...] = pydantic.Field(
29+
...,
30+
title="Collection of Tracers.",
31+
)
32+
33+
@property
34+
def to_autograd_field_map(self) -> AutogradFieldMap:
35+
"""Convert to ``AutogradFieldMap`` autograd dictionary."""
36+
return dict_ag({tracer.path: tracer.data for tracer in self.tracers})
37+
38+
@classmethod
39+
def from_autograd_field_map(cls, autograd_field_map: AutogradFieldMap) -> FieldMap:
40+
"""Initialize from an ``AutogradFieldMap`` autograd dictionary."""
41+
tracers = []
42+
for path, data in autograd_field_map.items():
43+
tracers.append(Tracer(path=path, data=data))
44+
return cls(tracers=tuple(tracers))
45+
46+
47+
def _encoded_path(path: tuple[Any, ...]) -> str:
48+
"""Return a stable JSON representation for a traced path."""
49+
return json.dumps(list(path), separators=(",", ":"), ensure_ascii=True)
50+
51+
52+
class TracerKeys(Tidy3dBaseModel):
53+
"""Collection of traced field paths."""
54+
55+
keys: tuple[tuple[Any, ...], ...] = pydantic.Field(
56+
...,
57+
title="Collection of tracer keys.",
58+
)
59+
60+
def encoded_keys(self) -> list[str]:
61+
"""Return the JSON-encoded representation of keys."""
62+
return [_encoded_path(path) for path in self.keys]
63+
64+
@classmethod
65+
def from_field_mapping(
66+
cls,
67+
field_mapping: AutogradFieldMap,
68+
*,
69+
sort_key: Callable[[tuple[Any, ...]], str] | None = None,
70+
) -> TracerKeys:
71+
"""Construct keys from an autograd field mapping."""
72+
if sort_key is None:
73+
sort_key = _encoded_path
74+
75+
sorted_paths = tuple(sorted(field_mapping.keys(), key=sort_key))
76+
return cls(keys=sorted_paths)

tidy3d/components/base.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
# If json string is larger than ``MAX_STRING_LENGTH``, split the string when storing in hdf5
3939
MAX_STRING_LENGTH = 1_000_000_000
4040
FORBID_SPECIAL_CHARACTERS = ["/"]
41+
TRACED_FIELD_KEYS_ATTR = "__tidy3d_traced_field_keys__"
4142

4243

4344
def cache(prop):
@@ -767,6 +768,9 @@ def add_data_to_file(data_dict: dict, group_path: str = "") -> None:
767768
add_data_to_file(data_dict=value, group_path=subpath)
768769

769770
add_data_to_file(data_dict=self.dict())
771+
traced_keys_payload = self._serialized_traced_field_keys()
772+
if traced_keys_payload:
773+
f_handle.attrs[TRACED_FIELD_KEYS_ATTR] = traced_keys_payload
770774

771775
@classmethod
772776
def dict_from_hdf5_gz(
@@ -1054,6 +1058,19 @@ def insert_value(x, path: tuple[str, ...], sub_dict: dict):
10541058

10551059
return self.parse_obj(self_dict)
10561060

1061+
def _serialized_traced_field_keys(self) -> Optional[str]:
1062+
"""Return a serialized, order-independent representation of traced field paths."""
1063+
1064+
field_mapping = self._strip_traced_fields()
1065+
if not field_mapping:
1066+
return None
1067+
1068+
# TODO: remove this deferred import once TracerKeys is decoupled from Tidy3dBaseModel.
1069+
from tidy3d.components.autograd.field_map import TracerKeys
1070+
1071+
tracer_keys = TracerKeys.from_field_mapping(field_mapping)
1072+
return tracer_keys.json(separators=(",", ":"), ensure_ascii=True)
1073+
10571074
def to_static(self) -> Tidy3dBaseModel:
10581075
"""Version of object with all autograd-traced fields removed."""
10591076

tidy3d/web/api/autograd/io_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import tempfile
55

66
import tidy3d as td
7+
from tidy3d.components.autograd.field_map import FieldMap, TracerKeys
78
from tidy3d.web.core.s3utils import download_file, upload_file # type: ignore
89

910
from .constants import SIM_FIELDS_KEYS_FILE, SIM_VJP_FILE
10-
from .utils import FieldMap, TracerKeys
1111

1212

1313
def upload_sim_fields_keys(sim_fields_keys: list[tuple], task_id: str, verbose: bool = False):

tidy3d/web/api/autograd/utils.py

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44
import typing
55

66
import numpy as np
7-
import pydantic as pd
87

98
import tidy3d as td
10-
from tidy3d.components.autograd.types import AutogradFieldMap, dict_ag
11-
from tidy3d.components.base import Tidy3dBaseModel
12-
from tidy3d.components.types import ArrayLike, tidycomplex
9+
from tidy3d.components.autograd.field_map import FieldMap, Tracer, TracerKeys
1310

1411
""" E and D field gradient map calculation helpers. """
1512

@@ -81,44 +78,14 @@ def get_field_key(dim: str, fld_data: typing.Union[td.FieldData, td.Permittivity
8178
return fld_1.updated_copy(**field_components)
8279

8380

84-
class Tracer(Tidy3dBaseModel):
85-
"""Class to store a single traced field."""
86-
87-
path: tuple[typing.Any, ...] = pd.Field(
88-
...,
89-
title="Path to the traced object in the model dictionary.",
90-
)
91-
92-
data: typing.Union[float, tidycomplex, ArrayLike] = pd.Field(..., title="Tracing data")
93-
94-
95-
class FieldMap(Tidy3dBaseModel):
96-
"""Class to store a collection of traced fields."""
97-
98-
tracers: tuple[Tracer, ...] = pd.Field(
99-
...,
100-
title="Collection of Tracers.",
101-
)
102-
103-
@property
104-
def to_autograd_field_map(self) -> AutogradFieldMap:
105-
"""Convert to ``AutogradFieldMap`` autograd dictionary."""
106-
return dict_ag({tracer.path: tracer.data for tracer in self.tracers})
107-
108-
@classmethod
109-
def from_autograd_field_map(cls, autograd_field_map) -> FieldMap:
110-
"""Initialize from an ``AutogradFieldMap`` autograd dictionary."""
111-
tracers = []
112-
for path, data in autograd_field_map.items():
113-
tracers.append(Tracer(path=path, data=data))
114-
115-
return cls(tracers=tuple(tracers))
116-
117-
118-
class TracerKeys(Tidy3dBaseModel):
119-
"""Class to store a collection of tracer keys."""
120-
121-
keys: tuple[tuple[typing.Any, ...], ...] = pd.Field(
122-
...,
123-
title="Collection of tracer keys.",
124-
)
81+
__all__ = [
82+
"E_to_D",
83+
"FieldMap",
84+
"Tracer",
85+
"TracerKeys",
86+
"derivative_map_D",
87+
"derivative_map_E",
88+
"derivative_map_H",
89+
"get_derivative_maps",
90+
"multiply_field_data",
91+
]

0 commit comments

Comments
 (0)