Skip to content

Commit 147398a

Browse files
committed
add tests for set_xydata
1 parent 4715310 commit 147398a

File tree

2 files changed

+88
-14
lines changed

2 files changed

+88
-14
lines changed

src/matplotgl/mesh.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,9 @@ def _maybe_centers_to_edges(coord: np.ndarray, M: int, N: int, axis: str) -> np.
3838
expected_edges = expected_centers + 1
3939

4040
if len(coord) == expected_edges:
41-
# Already edges
4241
return coord
4342

4443
if len(coord) == expected_centers:
45-
# Assume centers, convert to edges
4644
edges = np.zeros(expected_edges, dtype=coord.dtype)
4745
# Midpoints between centers, with extrapolation at ends
4846
edges[1:-1] = 0.5 * (coord[:-1] + coord[1:])
@@ -58,11 +56,9 @@ def _maybe_centers_to_edges(coord: np.ndarray, M: int, N: int, axis: str) -> np.
5856
elif coord.ndim == 2:
5957
m, n = coord.shape
6058
if (m, n) == (M + 1, N + 1):
61-
# Already edges
6259
return coord
6360

6461
if (m, n) == (M, N):
65-
# Assume centers, convert to edges
6662
# Strategy: First pad the coordinate array by extrapolating one
6763
# layer of "ghost" cells around the perimeter, then compute all
6864
# edges as average of 4 surrounding padded centers
@@ -73,13 +69,13 @@ def _maybe_centers_to_edges(coord: np.ndarray, M: int, N: int, axis: str) -> np.
7369
padded[1:-1, 1:-1] = coord
7470

7571
# Extrapolate edges (sides)
76-
# Left edge: coord[:, 0] - (coord[:, 1] - coord[:, 0])
72+
# Left edge
7773
padded[1:-1, 0] = coord[:, 0] - (coord[:, 1] - coord[:, 0])
78-
# Right edge: coord[:, -1] + (coord[:, -1] - coord[:, -2])
74+
# Right edge
7975
padded[1:-1, -1] = coord[:, -1] + (coord[:, -1] - coord[:, -2])
80-
# Bottom edge: coord[0, :] - (coord[1, :] - coord[0, :])
76+
# Bottom edge
8177
padded[0, 1:-1] = coord[0, :] - (coord[1, :] - coord[0, :])
82-
# Top edge: coord[-1, :] + (coord[-1, :] - coord[-2, :])
78+
# Top edge
8379
padded[-1, 1:-1] = coord[-1, :] + (coord[-1, :] - coord[-2, :])
8480

8581
# Extrapolate corners
@@ -264,19 +260,15 @@ def get_xdata(self) -> np.ndarray:
264260

265261
def set_xdata(self, x: np.ndarray):
266262
M, N = self._c.shape
267-
self._x = np.asarray(x)
268-
# Convert centers to edges if necessary
269-
self._x = _maybe_centers_to_edges(self._x, M, N, axis='x')
263+
self._x = _maybe_centers_to_edges(np.asarray(x), M, N, axis='x')
270264
self._update_positions()
271265

272266
def get_ydata(self) -> np.ndarray:
273267
return self._y
274268

275269
def set_ydata(self, y: np.ndarray):
276270
M, N = self._c.shape
277-
self._y = np.asarray(y)
278-
# Convert centers to edges if necessary
279-
self._y = _maybe_centers_to_edges(self._y, M, N, axis='y')
271+
self._y = _maybe_centers_to_edges(np.asarray(y), M, N, axis='y')
280272
self._update_positions()
281273

282274
def set_array(self, c: np.ndarray):

tests/pcolormesh_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,85 @@ def test_pcolormesh_2d_coords_midpoints():
9999

100100
assert len(ax.collections) == 1
101101
assert m._c.shape == (M, N)
102+
103+
104+
def test_pcolormesh_1d_set_xdata():
105+
_, ax = plt.subplots()
106+
107+
M, N = 80, 100
108+
xx = np.linspace(1.0, 5.0, N + 1)
109+
yy = np.linspace(1.0, 5.0, M + 1)
110+
x, y = np.meshgrid(xx[:-1], yy[:-1])
111+
z = 1.0 - (np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x))
112+
113+
m = ax.pcolormesh(xx, yy, z)
114+
115+
new_xx = np.linspace(2.0, 6.0, N + 1)
116+
m.set_xdata(new_xx)
117+
118+
bbox = m.get_bbox()
119+
assert bbox["left"] == new_xx[0]
120+
assert bbox["right"] == new_xx[-1]
121+
122+
123+
def test_pcolormesh_1d_set_ydata():
124+
_, ax = plt.subplots()
125+
126+
M, N = 80, 100
127+
xx = np.linspace(1.0, 5.0, N + 1)
128+
yy = np.linspace(1.0, 5.0, M + 1)
129+
x, y = np.meshgrid(xx[:-1], yy[:-1])
130+
z = 1.0 - (np.sin(x) ** 10 + np.cos(10 + y * x) * np.cos(x))
131+
132+
m = ax.pcolormesh(xx, yy, z)
133+
134+
new_yy = np.linspace(3.0, 7.0, M + 1)
135+
m.set_ydata(new_yy)
136+
137+
bbox = m.get_bbox()
138+
assert bbox["bottom"] == new_yy[0]
139+
assert bbox["top"] == new_yy[-1]
140+
141+
142+
def test_pcolormesh_2d_set_xdata():
143+
_, ax = plt.subplots()
144+
145+
M, N = 80, 100
146+
xx = np.linspace(1.0, 5.0, N + 1)
147+
yy = np.linspace(1.0, 5.0, M + 1)
148+
x, y = np.meshgrid(xx, yy)
149+
z = 1.0 - (
150+
np.sin(x[:-1, :-1]) ** 10
151+
+ np.cos(10 + y[:-1, :-1] * x[:-1, :-1]) * np.cos(x[:-1, :-1])
152+
)
153+
154+
m = ax.pcolormesh(x, y, z)
155+
156+
new_x = x + 3.0
157+
m.set_xdata(new_x)
158+
159+
bbox = m.get_bbox()
160+
assert bbox["left"] == new_x[0, 0]
161+
assert bbox["right"] == new_x[0, -1]
162+
163+
164+
def test_pcolormesh_2d_set_ydata():
165+
_, ax = plt.subplots()
166+
167+
M, N = 80, 100
168+
xx = np.linspace(1.0, 5.0, N + 1)
169+
yy = np.linspace(1.0, 5.0, M + 1)
170+
x, y = np.meshgrid(xx, yy)
171+
z = 1.0 - (
172+
np.sin(x[:-1, :-1]) ** 10
173+
+ np.cos(10 + y[:-1, :-1] * x[:-1, :-1]) * np.cos(x[:-1, :-1])
174+
)
175+
176+
m = ax.pcolormesh(x, y, z)
177+
178+
new_y = y + 4.0
179+
m.set_ydata(new_y)
180+
181+
bbox = m.get_bbox()
182+
assert bbox["bottom"] == new_y[0, 0]
183+
assert bbox["top"] == new_y[-1, 0]

0 commit comments

Comments
 (0)