Skip to content

Commit f3a71b9

Browse files
committed
Get field on CPU explicitly for show/write_to_file
1 parent 22960d9 commit f3a71b9

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

lasy/laser.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import numpy as np
12
from .backend import xp
3+
24
from axiprop.lib import PropagatorFFT2, PropagatorResampling
35
from scipy.constants import c
46

@@ -342,11 +344,12 @@ def show(self, **kw):
342344
----------
343345
**kw: additional arguments to be passed to matplotlib's imshow command
344346
"""
345-
temporal_field = self.grid.get_temporal_field()
347+
# Get field on CPU
348+
temporal_field = self.grid.get_temporal_field(to_cpu=True)
346349
if self.dim == "rt":
347350
# Show field in the plane y=0, above and below axis, with proper sign for each mode
348351
E = [
349-
xp.concatenate(
352+
np.concatenate(
350353
((-1.0) ** m * temporal_field[m, ::-1], temporal_field[m])
351354
)
352355
for m in self.grid.azimuthal_modes

lasy/utils/openpmd_output.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def write_to_openpmd_file(
5353
Whether the envelope is converted to normalized vector potential
5454
before writing to file.
5555
"""
56-
array = grid.get_temporal_field()
56+
# Get field on CPU
57+
array = grid.get_temporal_field(to_cpu=True)
5758

5859
# Create file
5960
series = io.Series("{}_%05T.{}".format(file_prefix, file_format), io.Access.create)
@@ -76,7 +77,7 @@ def write_to_openpmd_file(
7677
m.axis_labels = ["t", "r"]
7778

7879
# Store metadata needed to reconstruct the field
79-
m.set_attribute("angularFrequency", 2 * xp.pi * c / wavelength)
80+
m.set_attribute("angularFrequency", 2 * np.pi * c / wavelength)
8081
m.set_attribute("polarization", pol)
8182
if save_as_vector_potential:
8283
m.set_attribute("envelopeField", "normalized_vector_potential")
@@ -91,20 +92,20 @@ def write_to_openpmd_file(
9192
}
9293

9394
if save_as_vector_potential:
94-
array = field_to_vector_potential(grid, 2 * xp.pi * c / wavelength)
95+
array = field_to_vector_potential(grid, 2 * np.pi * c / wavelength)
9596

9697
# Pick the correct field
9798
if dim == "xyt":
9899
# Switch from x,y,t (internal to lasy) to t,y,x (in openPMD file)
99100
# This is because many PIC codes expect x to be the fastest index
100-
data = xp.transpose(array).copy()
101+
data = np.transpose(array).copy()
101102
elif dim == "rt":
102103
# The representation of modes in openPMD
103104
# (see https://github.com/openPMD/openPMD-standard/blob/latest/STANDARD.md#required-attributes-for-each-mesh-record)
104105
# is different than the representation of modes internal to lasy.
105106
# Thus, there is a non-trivial conversion here
106107
ncomp = 2 * grid.n_azimuthal_modes - 1
107-
data = xp.zeros((ncomp, grid.npoints[0], grid.npoints[1]), dtype=array.dtype)
108+
data = np.zeros((ncomp, grid.npoints[0], grid.npoints[1]), dtype=array.dtype)
108109
data[0, :, :] = array[0, :, :]
109110
for mode in range(1, grid.n_azimuthal_modes):
110111
# cos(m*theta) part of the mode
@@ -113,12 +114,12 @@ def write_to_openpmd_file(
113114
data[2 * mode, :, :] = -1.0j * array[mode, :, :] + 1.0j * array[-mode, :, :]
114115
# Switch from m,r,t (internal to lasy) to m,t,r (in openPMD file)
115116
# This is because many PIC codes expect r to be the fastest index
116-
data = xp.transpose(data, axes=[0, 2, 1]).copy()
117+
data = np.transpose(data, axes=[0, 2, 1]).copy()
117118

118119
# Define the dataset
119120
dataset = io.Dataset(data.dtype, data.shape)
120121
env = m[io.Mesh_Record_Component.SCALAR]
121-
env.position = xp.zeros(len(dim), dtype=xp.float64)
122+
env.position = np.zeros(len(dim), dtype=np.float64)
122123
env.reset_dataset(dataset)
123124
env.store_chunk(data)
124125

0 commit comments

Comments
 (0)