Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions tsai/data/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,26 @@ def __init__(self, items, tfms=None, splits=None, split_idx=None, types=None, do
def subset(self, i, **kwargs): return type(self)(self.items, splits=self.splits[i], split_idx=i, do_setup=False, types=self.types,
**kwargs)
def __getitem__(self, it):
if hasattr(self.items, 'oindex'): return self.items.oindex[self._splits[it]]
else: return self.items[self._splits[it]]
# if hasattr(self.items, 'oindex'): return self.items.oindex[self._splits[it]]
# else: return self.items[self._splits[it]]
# changed the old code for our issue.
# 1) compute the raw index or slice from self._splits
idx = self._splits[it]

# 2) if it's a NumPy scalar (e.g. dtype=np.int8), turn it into a plain Python int
if isinstance(idx, np.generic):
idx = idx.item()
# 3) if it's a NumPy array (e.g. dtype=int8 or int32), turn it into a Python list
elif isinstance(idx, np.ndarray):
idx = idx.tolist()

# 4) finally, index self.items (or self.items.oindex) with that pure‐Python index
if hasattr(self.items, 'oindex'):
return self.items.oindex[idx]
else:
return self.items[idx]


def __len__(self): return len(self._splits)
def __repr__(self):
if hasattr(self.items, "shape"):
Expand Down Expand Up @@ -486,12 +504,20 @@ def __init__(self, X=None, y=None, items=None, sel_vars=None, sel_steps=None, tf
self.tls = L(lt(item, t, **kwargs) for lt,item,t in zip(lts, items, self.tfms))
# if len(self.tls) > 0 and len(self.tls[0]) > 0:
# self.typs = [type(tl.items[0]) if isinstance(tl.items[0], torch.Tensor) else self.typs[i] for i,tl in enumerate(self.tls)]
# if self.inplace and (tfms is None or tfms == [None] * len(self.tls)):
# for tl,typ in zip(self.tls, self.typs):
# tl.items = typ(tl.items)
#replacing the above with the following safer version because torch.as_tensor can't automatically infer the dtype from a numpy.int64 array without a valid PyTorch tensor type
import torch

if self.inplace and (tfms is None or tfms == [None] * len(self.tls)):
for tl,typ in zip(self.tls, self.typs):
tl.items = typ(tl.items)
for tl in self.tls:
if isinstance(tl.items, np.ndarray):
tl.items = torch.as_tensor(tl.items, dtype=torch.float32)

self.ptls = self.tls
self.no_tfm = True
else:
else:
self.ptls = L([typ(stack(tl[:]))[...,self.sel_vars, self.sel_steps] if (i==0 and self.multi_index) else typ(stack(tl[:])) \
for i,(tl,typ) in enumerate(zip(self.tls,self.typs))]) if inplace else self.tls
self.no_tfm = False
Expand Down Expand Up @@ -653,7 +679,9 @@ def create_batch(self, b):
if hasattr(self, "split_idxs"):
self.input_idxs = self.split_idxs[b]
else: self.input_idxs = self.idxs
return self.dataset[b]
# return self.dataset[b]
return self.dataset[[int(i) for i in b]]


def create_item(self, s):
if self.indexed: return self.dataset[s or 0]
Expand Down