1
- from lasy .backend import xp
1
+ import numpy as np
2
+ from lasy .backend import xp , use_cupy
2
3
3
4
time_axis_indx = - 1
4
5
@@ -77,7 +78,9 @@ def set_temporal_field(self, field):
77
78
"""
78
79
assert field .shape == self .temporal_field .shape
79
80
assert field .dtype == "complex128"
80
- self .temporal_field [:, :, :] = field
81
+ if use_cupy and type (field ) == np .ndarray :
82
+ field = xp .asarray (field ) # Copy to GPU
83
+ self .temporal_field [:,:,:] = field
81
84
self .temporal_field_valid = True
82
85
self .spectral_field_valid = False # Invalidates the spectral field
83
86
@@ -92,11 +95,13 @@ def set_spectral_field(self, field):
92
95
"""
93
96
assert field .shape == self .spectral_field .shape
94
97
assert field .dtype == "complex128"
95
- self .spectral_field [:, :, :] = field
98
+ if use_cupy and type (field ) == np .ndarray :
99
+ field = xp .asarray (field ) # Copy to GPU
100
+ self .spectral_field [:,:,:] = field
96
101
self .spectral_field_valid = True
97
102
self .temporal_field_valid = False # Invalidates the temporal field
98
103
99
- def get_temporal_field (self ):
104
+ def get_temporal_field (self , to_cpu = False ):
100
105
"""
101
106
Return a copy of the temporal field.
102
107
@@ -108,37 +113,41 @@ def get_temporal_field(self):
108
113
field : ndarray of complexs
109
114
The temporal field.
110
115
"""
111
- # We return a copy, so that the user cannot modify
112
- # the original field, unless get_temporal_field is called
113
- if self .temporal_field_valid :
114
- return self .temporal_field .copy ()
115
- elif self .spectral_field_valid :
116
+ if not self .temporal_field_valid :
116
117
self .spectral2temporal_fft ()
117
- return self .temporal_field .copy ()
118
+ # Return a copy of the field, either on CPU or GPU, so that the user
119
+ # cannot modify the original field, unless set_spectral_field is called
120
+ if to_cpu and use_cupy :
121
+ return xp .asnumpy (self .temporal_field )
118
122
else :
119
- raise ValueError ( "Both temporal and spectral fields are invalid" )
123
+ return self . temporal_field . copy ( )
120
124
121
- def get_spectral_field (self ):
125
+ def get_spectral_field (self , to_cpu = False ):
122
126
"""
123
127
Return a copy of the spectral field.
124
128
125
129
(Modifying the returned object will not modify the original field stored
126
130
in the Grid object ; one must use set_spectral_field to do so.)
127
131
132
+ Parameters
133
+ ----------
134
+ to_cpu : bool
135
+ If True, the returned field is always returned as a numpy array on CPU
136
+ (even when the lasy backend is cupy)
137
+
128
138
Returns
129
139
-------
130
140
field : ndarray of complexs
131
141
The spectral field.
132
142
"""
133
- # We return a copy, so that the user cannot modify
134
- # the original field, unless set_spectral_field is called
135
- if self .spectral_field_valid :
136
- return self .spectral_field .copy ()
137
- elif self .temporal_field_valid :
143
+ if not self .spectral_field_valid :
138
144
self .temporal2spectral_fft ()
139
- return self .spectral_field .copy ()
145
+ # Return a copy of the field, either on CPU or GPU, so that the user
146
+ # cannot modify the original field, unless set_spectral_field is called
147
+ if to_cpu and use_cupy :
148
+ return xp .asnumpy (self .spectral_field )
140
149
else :
141
- raise ValueError ( "Both temporal and spectral fields are invalid" )
150
+ return self . spectral_field . copy ( )
142
151
143
152
def temporal2spectral_fft (self ):
144
153
"""
0 commit comments