@@ -32,25 +32,25 @@ def tensor(data, requires_grad=False, dtype=float32, device="cpu"):
3232 return Tensor (data , requires_grad = requires_grad , dtype = dtype , device = device )
3333
3434
35- def ones (* shape , dtype = None , requires_grad = True , device = "cpu" ):
35+ def ones (* shape , dtype = None , requires_grad = False , device = "cpu" ):
3636 shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
3737
3838 return Tensor (np .ones (shape , dtype = dtype ), requires_grad = requires_grad , device = device )
3939
4040
41- def zeros (* shape , dtype = None , requires_grad = True , device = "cpu" ):
41+ def zeros (* shape , dtype = None , requires_grad = False , device = "cpu" ):
4242 shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
4343
4444 return Tensor (np .zeros (shape , dtype = dtype ), requires_grad = requires_grad , device = device )
4545
4646
47- def rand (* shape , dtype = None , requires_grad = True , device = "cpu" ):
47+ def rand (* shape , dtype = None , requires_grad = False , device = "cpu" ):
4848 shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
4949
5050 return Tensor (np .random .rand (* shape ).astype (dtype ), requires_grad = requires_grad , device = device )
5151
5252
53- def randn (* shape , dtype = None , requires_grad = True , device = "cpu" ):
53+ def randn (* shape , dtype = None , requires_grad = False , device = "cpu" ):
5454 shape = tuple (* shape ) if all (isinstance (arg , (list , tuple )) for arg in shape ) else shape
5555
5656 return Tensor (
@@ -60,7 +60,7 @@ def randn(*shape, dtype=None, requires_grad=True, device="cpu"):
6060 )
6161
6262
63- def arange (start = 0 , end = None , step = 1 , dtype = None , requires_grad = True , device = "cpu" ):
63+ def arange (start = 0 , end = None , step = 1 , dtype = None , requires_grad = False , device = "cpu" ):
6464 if end is None :
6565 start , end = 0 , start
6666 return Tensor (
@@ -70,11 +70,11 @@ def arange(start=0, end=None, step=1, dtype=None, requires_grad=True, device="cp
7070 )
7171
7272
73- def ones_like (tensor , dtype = None , requires_grad = True , device = "cpu" ):
73+ def ones_like (tensor , dtype = None , requires_grad = False , device = "cpu" ):
7474 return Tensor (np .ones_like (tensor .data , dtype ), requires_grad = requires_grad , device = device )
7575
7676
77- def zeros_like (tensor , dtype = None , requires_grad = True , device = "cpu" ):
77+ def zeros_like (tensor , dtype = None , requires_grad = False , device = "cpu" ):
7878 return Tensor (np .zeros_like (tensor .data , dtype ), requires_grad = requires_grad , device = device )
7979
8080
@@ -195,6 +195,7 @@ def flip(x, axis):
195195 return x .flip (axis = axis )
196196
197197def where (condition , x , y ):
198+ x = tensor (x , device = condition .device ) if not isinstance (x , Tensor ) else x
198199 return x .where (condition , y )
199200
200201def equal (x , y ):
0 commit comments