Skip to content

Commit 22960d9

Browse files
committed
Make setters and getters GPU aware
1 parent d92a196 commit 22960d9

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

lasy/utils/grid.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from lasy.backend import xp
1+
import numpy as np
2+
from lasy.backend import xp, use_cupy
23

34
time_axis_indx = -1
45

@@ -77,7 +78,9 @@ def set_temporal_field(self, field):
7778
"""
7879
assert field.shape == self.temporal_field.shape
7980
assert field.dtype == "complex128"
80-
self.temporal_field[:, :, :] = field
81+
if use_cupy and type(field) == np.ndarray:
82+
field = xp.asarray(field) # Copy to GPU
83+
self.temporal_field[:,:,:] = field
8184
self.temporal_field_valid = True
8285
self.spectral_field_valid = False # Invalidates the spectral field
8386

@@ -92,11 +95,13 @@ def set_spectral_field(self, field):
9295
"""
9396
assert field.shape == self.spectral_field.shape
9497
assert field.dtype == "complex128"
95-
self.spectral_field[:, :, :] = field
98+
if use_cupy and type(field) == np.ndarray:
99+
field = xp.asarray(field) # Copy to GPU
100+
self.spectral_field[:,:,:] = field
96101
self.spectral_field_valid = True
97102
self.temporal_field_valid = False # Invalidates the temporal field
98103

99-
def get_temporal_field(self):
104+
def get_temporal_field(self, to_cpu=False):
100105
"""
101106
Return a copy of the temporal field.
102107
@@ -108,37 +113,41 @@ def get_temporal_field(self):
108113
field : ndarray of complexs
109114
The temporal field.
110115
"""
111-
# We return a copy, so that the user cannot modify
112-
# the original field, unless get_temporal_field is called
113-
if self.temporal_field_valid:
114-
return self.temporal_field.copy()
115-
elif self.spectral_field_valid:
116+
if not self.temporal_field_valid:
116117
self.spectral2temporal_fft()
117-
return self.temporal_field.copy()
118+
# Return a copy of the field, either on CPU or GPU, so that the user
119+
# cannot modify the original field, unless set_spectral_field is called
120+
if to_cpu and use_cupy:
121+
return xp.asnumpy(self.temporal_field)
118122
else:
119-
raise ValueError("Both temporal and spectral fields are invalid")
123+
return self.temporal_field.copy()
120124

121-
def get_spectral_field(self):
125+
def get_spectral_field(self, to_cpu=False):
122126
"""
123127
Return a copy of the spectral field.
124128
125129
(Modifying the returned object will not modify the original field stored
126130
in the Grid object ; one must use set_spectral_field to do so.)
127131
132+
Parameters
133+
----------
134+
to_cpu : bool
135+
If True, the returned field is always returned as a numpy array on CPU
136+
(even when the lasy backend is cupy)
137+
128138
Returns
129139
-------
130140
field : ndarray of complexs
131141
The spectral field.
132142
"""
133-
# We return a copy, so that the user cannot modify
134-
# the original field, unless set_spectral_field is called
135-
if self.spectral_field_valid:
136-
return self.spectral_field.copy()
137-
elif self.temporal_field_valid:
143+
if not self.spectral_field_valid:
138144
self.temporal2spectral_fft()
139-
return self.spectral_field.copy()
145+
# Return a copy of the field, either on CPU or GPU, so that the user
146+
# cannot modify the original field, unless set_spectral_field is called
147+
if to_cpu and use_cupy:
148+
return xp.asnumpy(self.spectral_field)
140149
else:
141-
raise ValueError("Both temporal and spectral fields are invalid")
150+
return self.spectral_field.copy()
142151

143152
def temporal2spectral_fft(self):
144153
"""

0 commit comments

Comments
 (0)