1
1
import numpy as np
2
- from lasy .backend import xp , use_cupy
2
+
3
+ from lasy .backend import use_cupy , xp
3
4
4
5
time_axis_indx = - 1
5
6
@@ -79,8 +80,8 @@ def set_temporal_field(self, field):
79
80
assert field .shape == self .temporal_field .shape
80
81
assert field .dtype == "complex128"
81
82
if use_cupy and type (field ) == np .ndarray :
82
- field = xp .asarray (field ) # Copy to GPU
83
- self .temporal_field [:,:, :] = field
83
+ field = xp .asarray (field ) # Copy to GPU
84
+ self .temporal_field [:, :, :] = field
84
85
self .temporal_field_valid = True
85
86
self .spectral_field_valid = False # Invalidates the spectral field
86
87
@@ -96,8 +97,8 @@ def set_spectral_field(self, field):
96
97
assert field .shape == self .spectral_field .shape
97
98
assert field .dtype == "complex128"
98
99
if use_cupy and type (field ) == np .ndarray :
99
- field = xp .asarray (field ) # Copy to GPU
100
- self .spectral_field [:,:, :] = field
100
+ field = xp .asarray (field ) # Copy to GPU
101
+ self .spectral_field [:, :, :] = field
101
102
self .spectral_field_valid = True
102
103
self .temporal_field_valid = False # Invalidates the temporal field
103
104
0 commit comments