diff --git a/alodataset/coco_base_dataset.py b/alodataset/coco_base_dataset.py
index 890777a4..df9340ce 100644
--- a/alodataset/coco_base_dataset.py
+++ b/alodataset/coco_base_dataset.py
@@ -16,7 +16,6 @@ from typing import Dict, Union
 
 
-
 class CocoBaseDataset(BaseDataset):
     """
     Attributes
@@ -73,6 +72,8 @@ def __init__(
         classes: list = None,
         fix_classes_len: int = None,
         return_multiple_labels: list = None,
+        ignore_classes: list = None,
+        skip_crowd: bool = True,
         **kwargs,
     ):
         super(CocoBaseDataset, self).__init__(name=name, **kwargs)
@@ -80,15 +81,20 @@ def __init__(
             return
         else:
             assert img_folder is not None, "When sample = False, img_folder must be given."
-            assert ann_file is not None or "test" in img_folder, "When sample = False and the test split is not used, ann_file must be given."
-
+            assert (
+                ann_file is not None or "test" in img_folder
+            ), "When sample = False and the test split is not used, ann_file must be given."
 
         # Create properties
         self.img_folder = os.path.join(self.dataset_dir, img_folder)
 
         if "test" in img_folder:
-            #get a list of indices that don't rely on the annotation file
-            self.items = [int(Path(os.path.join(self.img_folder, f)).stem) for f in os.listdir(self.img_folder) if os.path.isfile(os.path.join(self.img_folder, f))]
+            # get a list of indices that don't rely on the annotation file
+            self.items = [
+                int(Path(os.path.join(self.img_folder, f)).stem)
+                for f in os.listdir(self.img_folder)
+                if os.path.isfile(os.path.join(self.img_folder, f))
+            ]
             return
 
         self.coco = COCO(os.path.join(self.dataset_dir, ann_file))
@@ -103,6 +109,12 @@ def __init__(
             self._ids_renamed = classes
 
         self.label_names = label_names
+
+        if classes is not None and ignore_classes is not None:
+            raise Exception("Can't use both `classes` and `ignore_classes` at the same time.")
+        if ignore_classes is not None:
+            classes = [l for l in self.label_names if l not in set(ignore_classes + ["N/A"])]
+
         if classes is not None:
             notclass = [label for label in classes if label not in self.label_names]
             if len(notclass) > 0:  # Ignore all labels not in classes
@@ -122,6 +134,13 @@ def __init__(
                     ids.append(i)
             self.items = ids  # Remove images without bboxes with classes in classes list
 
+        if skip_crowd:
+            # Check each annotation and remove crowded ones.
+ for i in range(len(self.items) - 1, -1, -1): + anns = self.coco.loadAnns(self.coco.getAnnIds(self.items[i])) + if any([ann["iscrowd"] for ann in anns]): + del self.items[i] + # Fix lenght of label_names to a desired `fix_classes_len` if fix_classes_len is not None: self._fix_classes(fix_classes_len) @@ -172,7 +191,12 @@ def _fix_classes(self, new_label_size): def _append_labels(self, element: Union[BoundingBoxes2D, Mask], target): def append_new_labels(element, ltensor, lnames, name): - label_2d = Labels(ltensor.to(torch.float32), labels_names=lnames, names=("N"), encoding="id") + label_2d = Labels( + ltensor.to(torch.float32), + labels_names=lnames, + names=("N"), + encoding="id", + ) element.append_labels(label_2d, name=name) labels = target["labels"] @@ -186,7 +210,10 @@ def append_new_labels(element, ltensor, lnames, name): # Append supercategory labels for ktype in self.label_types: append_new_labels( - element, torch.as_tensor(self.label_types[ktype])[labels], self.label_types_names[ktype], ktype + element, + torch.as_tensor(self.label_types[ktype])[labels], + self.label_types_names[ktype], + ktype, ) # Append specific labels @@ -211,7 +238,11 @@ def _target2aloscene(self, target, frame): # Create and append labels to boxes boxes = BoundingBoxes2D( - target["boxes"], boxes_format="xyxy", absolute=True, frame_size=frame.HW, names=("N", None) + target["boxes"], + boxes_format="xyxy", + absolute=True, + frame_size=frame.HW, + names=("N", None), ) self._append_labels(boxes, target) @@ -241,7 +272,7 @@ def getitem(self, idx): image_id = self.items[idx] if "test" in self.img_folder: - #get the filename from image_id without relying on annotation file + # get the filename from image_id without relying on annotation file return Frame(os.path.join(self.img_folder, f"{str(image_id).zfill(12)}.jpg")) frame = Frame(os.path.join(self.img_folder, self.coco.loadImgs(image_id)[0]["file_name"])) @@ -293,7 +324,6 @@ def convert_coco_poly_to_mask(self, segmentations, height, width): return masks def __call__(self, image, target): - w, h = image.shape[-1], image.shape[-2] image_id = target["image_id"] @@ -356,11 +386,11 @@ def __call__(self, image, target): if __name__ == "__main__": coco_dataset = CocoBaseDataset(sample=False, img_folder="test2017") - #checking if regular getitem works + # checking if regular getitem works frame = coco_dataset[0] frame.get_view().render() - #check if dataloader works + # check if dataloader works for f, frames in enumerate(coco_dataset.train_loader(batch_size=2)): frames = Frame.batch_list(frames) frames.get_view().render() diff --git a/alodataset/coco_detection_dataset.py b/alodataset/coco_detection_dataset.py index 714798cd..cecda0c4 100644 --- a/alodataset/coco_detection_dataset.py +++ b/alodataset/coco_detection_dataset.py @@ -42,6 +42,9 @@ class CocoDetectionDataset(CocoBaseDataset, SplitMixin): Return Labels as a dictionary, with all posible categories found in annotations file, by default False ann_file : str Start from a fixe given annotation file where the path is relative to the `dataset_dir` + skip_crowd : bool, optional + Filter out images with `iscrowd` attribute, by default False + Images with crowd are often mislabeled with missing boxes for some persons. 
    **kwargs : dict
        :mod:`BaseDataset ` optional parameters
@@ -80,6 +83,7 @@ def __init__(
         fix_classes_len: int = None,
         split=Split.TRAIN,
         ann_file=None,
+        skip_crowd: bool = False,
         **kwargs,
     ):
         SplitMixin.__init__(self, split)
@@ -144,6 +148,13 @@ def __init__(
         if fix_classes_len is not None:
             self._fix_classes(fix_classes_len)
 
+        if skip_crowd:
+            # Check each annotation and remove crowded ones.
+            for i in range(len(self.items) - 1, -1, -1):
+                target = self.items[i][2]["segments_info"]
+                if any([seg["iscrowd"] for seg in target]):
+                    del self.items[i]
+
         # Re-calcule encoding label types (+stuff)
         if self.label_types is not None:
             dict_cats = dict()
diff --git a/alodataset/coco_panoptic_dataset.py b/alodataset/coco_panoptic_dataset.py
index 10782713..6abcedf3 100644
--- a/alodataset/coco_panoptic_dataset.py
+++ b/alodataset/coco_panoptic_dataset.py
@@ -50,6 +50,9 @@ class CocoPanopticDataset(BaseDataset, SplitMixin):
     fix_classes_len : int, optional
         Fix to a specific number the number of classes, filling the rest with "N/A" value.
         Use when the number of model outputs does not match with the number of classes in the dataset, by default 250
+    skip_crowd : bool, optional
+        Filter out images that contain annotations with the `iscrowd` attribute, by default False.
+        Crowd images are often mislabeled, with missing person boxes and instance masks.
     **kwargs : dict
         :mod:`BaseDataset ` optional parameters
@@ -61,7 +64,10 @@ class CocoPanopticDataset(BaseDataset, SplitMixin):
     """
 
     SPLIT_FOLDERS = {Split.VAL: "val2017", Split.TRAIN: "train2017"}
-    SPLIT_ANN_FOLDERS = {Split.VAL: "annotations/panoptic_val2017", Split.TRAIN: "annotations/panoptic_train2017"}
+    SPLIT_ANN_FOLDERS = {
+        Split.VAL: "annotations/panoptic_val2017",
+        Split.TRAIN: "annotations/panoptic_train2017",
+    }
     SPLIT_ANN_FILES = {
         Split.VAL: "annotations/panoptic_val2017.json",
         Split.TRAIN: "annotations/panoptic_train2017.json",
@@ -73,9 +79,10 @@ def __init__(
         split=Split.TRAIN,
         return_masks: bool = True,
         classes: list = None,
-        ignore_classes: list=None,
-        fix_classes_len: int = None, # Match with pre-trained weights
-        **kwargs,
+        ignore_classes: list = None,
+        fix_classes_len: int = None,  # Match with pre-trained weights
+        skip_crowd: bool = False,
+        **kwargs,
     ):
         super(CocoPanopticDataset, self).__init__(name=name, split=split, **kwargs)
         if self.sample:
@@ -99,7 +106,7 @@ def __init__(
         if classes is not None:
             if self.label_names is None:
                 raise Exception(
-                    "'classes' attribute not support in datasets without 'categories' as attribute in annotation file"
+                    "'classes' attribute not supported in datasets without 'categories' as attribute in annotation file"
                 )
             notclass = [label for label in classes if label not in self.label_names]
             if len(notclass) > 0:  # Ignore all labels not in classes
@@ -111,7 +118,7 @@ def __init__(
             self._ids_renamed = np.array(self._ids_renamed)
             self.label_names = classes
 
-            # Check each annotation and keep only that have at least 1 element in classes list
+            # Check each annotation and keep only those that have at least 1 element in the classes list
             items = []
             for i, (_, _, ann_info) in enumerate(self.items):
                 target = ann_info["segments_info"]
@@ -137,6 +144,13 @@ def __init__(
                 self.label_types[ltype] = [x for _, x in sorted(zip(idx_sort, self.label_types[ltype]))]
             self._ids_renamed = torch.from_numpy(self._ids_renamed)
 
+        if skip_crowd:
+            # Check each annotation and remove crowded ones.
+ for i in range(len(self.items) - 1, -1, -1): + target = self.items[i][2]["segments_info"] + if any([seg["iscrowd"] for seg in target]): + del self.items[i] + # Fix number of label names if desired if fix_classes_len is not None: if fix_classes_len > len(self.label_names): @@ -269,7 +283,12 @@ def getitem(self, idx): # Make aloscene.frame frame = Frame(img_path) - labels_2d = Labels(labels.to(torch.float32), labels_names=self.label_names, names=("N"), encoding="id") + labels_2d = Labels( + labels.to(torch.float32), + labels_names=self.label_names, + names=("N"), + encoding="id", + ) boxes_2d = BoundingBoxes2D( masks_to_boxes(masks), boxes_format="xyxy", @@ -291,15 +310,15 @@ def getitem(self, idx): if __name__ == "__main__": # coco_seg = CocoPanopticDataset(sample=True) - coco_seg = CocoPanopticDataset() # test - for f, frames in enumerate(coco_seg.train_loader(batch_size=2)): - frames = Frame.batch_list(frames) - labels_set = "category" if isinstance(frames.boxes2d[0].labels, dict) else None - views = [fr.boxes2d.get_view(fr, labels_set=labels_set) for fr in frames] - if frames.segmentation is not None: - views += [fr.segmentation.get_view(fr, labels_set=labels_set) for fr in frames] - frames.get_view(views).render() - # frames.get_view(labels_set=labels_set).render() - - if f > 1: - break + coco_seg = CocoPanopticDataset(classes=["person"], skip_crowd=False) # test + for f, frames in enumerate( + coco_seg.train_loader(batch_size=1, num_workers=0, sampler=torch.utils.data.SequentialSampler(coco_seg)) + ): + if f > 10: + frames = Frame.batch_list(frames) + labels_set = "category" if isinstance(frames.boxes2d[0].labels, dict) else None + views = [fr.boxes2d.get_view(fr, labels_set=labels_set) for fr in frames] + if frames.segmentation is not None: + views += [fr.segmentation.get_view(fr, labels_set=labels_set) for fr in frames] + frames.get_view(views).render() + # frames.get_view(labels_set=labels_set).render() diff --git a/alodataset/transforms.py b/alodataset/transforms.py index 1734f2fc..ccb29368 100644 --- a/alodataset/transforms.py +++ b/alodataset/transforms.py @@ -14,7 +14,12 @@ class AloTransform(object): - def __init__(self, same_on_sequence: bool = True, same_on_frames: bool = False, p: float = 1.0): + def __init__( + self, + same_on_sequence: bool = True, + same_on_frames: bool = False, + p: float = 1.0, + ): """Alo Transform. Each transform in the project should inhert from this class. @@ -55,7 +60,9 @@ def sample_params(self): def set_params(self): raise Exception("Must be implement by a child class") - def __call__(self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwargs): + def __call__( + self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwargs + ): """Iter on the given frame(s) or return the frame. Based on `same_on_sequence` and `same_on_frames` parameters the method will return and call the `sample_params` method at different time. 
@@ -76,7 +83,6 @@ def __call__(self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwa # Go through each image if isinstance(frames, dict): - n_set = {} if same_on_sequence is None or same_on_frames is None: @@ -86,11 +92,9 @@ def __call__(self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwa ) for key in frames: - # Go throguh each element of the sequence # (If needed to apply save the params for each time step if "T" in frames[key].names and same_on_frames and not same_on_sequence: - n_set[key] = [] for t in range(0, frames[key].shape[0]): if t not in seqid2params: @@ -109,8 +113,11 @@ def __call__(self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwa n_set[key] = torch.cat(n_set[key], dim=0) # Different for each element of the sequence, but we don't need to save # the params for each image neither - elif "T" in frames[key].names and not same_on_frames and not same_on_sequence: - + elif ( + "T" in frames[key].names + and not same_on_frames + and not same_on_sequence + ): n_set[key] = [] for t in range(0, frames[key].shape[0]): @@ -124,7 +131,9 @@ def __call__(self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwa n_set[key] = torch.cat(n_set[key], dim=0) # Same on all frames elif same_on_frames: - frame_params = self.sample_params() if frame_params is None else frame_params + frame_params = ( + self.sample_params() if frame_params is None else frame_params + ) # print('same_on_frames.....', frame_params) self.set_params(*frame_params) n_set[key] = self.apply(frames[key], **kwargs) @@ -140,7 +149,9 @@ def __call__(self, frames: Union[Mapping[str, Frame], List[Frame], Frame], **kwa n_frames = [] last_size = None for t in range(0, frames.shape[0]): - frame_params = self.sample_params() if frame_params is None else frame_params + frame_params = ( + self.sample_params() if frame_params is None else frame_params + ) self.set_params(*self.sample_params()) result = self.apply(frames[t], **kwargs) n_frames.append(result.temporal()) @@ -203,7 +214,14 @@ def __repr__(self): class RandomSelect(AloTransform): - def __init__(self, transforms1: AloTransform, transforms2: AloTransform, p: float = 0.5, *args, **kwargs): + def __init__( + self, + transforms1: AloTransform, + transforms2: AloTransform, + p: float = 0.5, + *args, + **kwargs, + ): """Randomly selects between transforms1 and transforms2, with probability p for transforms1 and (1 - p) for transforms2 @@ -216,7 +234,7 @@ def __init__(self, transforms1: AloTransform, transforms2: AloTransform, p: floa """ self.transforms1 = transforms1 self.transforms2 = transforms2 - self.p = p + self.tp = p super().__init__(*args, **kwargs) def sample_params(self): @@ -225,7 +243,11 @@ def sample_params(self): transformation is apply. 
""" self._r = random.random() - return (self._r, self.transforms1.sample_params(), self.transforms2.sample_params()) + return ( + self._r, + self.transforms1.sample_params(), + self.transforms2.sample_params(), + ) def set_params(self, _r, param1, param2): """Given predefined params, set the params on the class""" @@ -241,7 +263,7 @@ def apply(self, frame: Frame): frame: Frame Frame to apply the transformation on """ - if self._r < self.p: + if self._r < self.tp: return self.transforms1(frame) return self.transforms2(frame) @@ -284,7 +306,9 @@ def apply(self, frame: Frame): class RandomSizeCrop(AloTransform): - def __init__(self, min_size: Union[int, float], max_size: Union[int, float], *args, **kwargs): + def __init__( + self, min_size: Union[int, float], max_size: Union[int, float], *args, **kwargs + ): """Randomly crop the frame. The region will be sample so that the width & height of the crop will be between `min_size` & `max_size`. @@ -297,7 +321,9 @@ def __init__(self, min_size: Union[int, float], max_size: Union[int, float], *ar Maximun width and height of the crop. I float, will be use as a percentage """ if type(min_size) != type(max_size): - raise Exception("Both `min_size` and `max_size` but be of the same type (float or int)") + raise Exception( + "Both `min_size` and `max_size` but be of the same type (float or int)" + ) self.min_size = min_size self.max_size = max_size super().__init__(*args, **kwargs) @@ -376,11 +402,12 @@ def set_params(self, pad_left, pad_right, pad_top, pad_bottom): self._pad_bottom = pad_bottom def __call__(self, frame): - print((self._pad_top, self._pad_bottom), (self._pad_left, self._pad_right)) return frame.pad( - offset_y=(self._pad_top, self._pad_bottom), offset_x=(self._pad_left, self._pad_right), pad_boxes=True + offset_y=(self._pad_top, self._pad_bottom), + offset_x=(self._pad_left, self._pad_right), + pad_boxes=True, ) @@ -416,7 +443,9 @@ def set_params(self, pad_left, pad_right, pad_top, pad_bottom): def __call__(self, frame): return frame.pad( - offset_y=(self._pad_top, self._pad_bottom), offset_x=(self._pad_left, self._pad_right), pad_boxes=True + offset_y=(self._pad_top, self._pad_bottom), + offset_x=(self._pad_left, self._pad_right), + pad_boxes=True, ) @@ -597,7 +626,14 @@ def apply(self, frame: Frame): class RealisticNoise(AloTransform): - def __init__(self, gaussian_std: float = 0.02, shot_std: float = 0.05, same_on_sequence=False, *args, **kwargs): + def __init__( + self, + gaussian_std: float = 0.02, + shot_std: float = 0.05, + same_on_sequence=False, + *args, + **kwargs, + ): """Add an approximation of a realistic noise to the image. More precisely, we add a gaussian noise and a shot noise to the image. 
@@ -629,8 +665,12 @@ def set_params(self): def apply(self, frame: Frame): n_frame = frame.norm01() - gaussian_noise = torch.normal(mean=0, std=self.gaussian_std, size=frame.shape, device=frame.device) - shot_noise = torch.normal(mean=0, std=self.shot_std, size=frame.shape, device=frame.device) + gaussian_noise = torch.normal( + mean=0, std=self.gaussian_std, size=frame.shape, device=frame.device + ) + shot_noise = torch.normal( + mean=0, std=self.shot_std, size=frame.shape, device=frame.device + ) noisy_frame = n_frame + n_frame * n_frame * shot_noise + gaussian_noise noisy_frame = torch.clip(noisy_frame, 0, 1) @@ -641,7 +681,14 @@ def apply(self, frame: Frame): class CustomRandomColoring(AloTransform): - def __init__(self, gamma_r=(0.8, 1.2), brightness_r=(0.5, 2.0), colors_r=(0.5, 1.5), *args, **kwargs): + def __init__( + self, + gamma_r=(0.8, 1.2), + brightness_r=(0.5, 2.0), + colors_r=(0.5, 1.5), + *args, + **kwargs, + ): """ Random modification of image colors @@ -669,7 +716,9 @@ def sample_params(self): self.colors = Uniform(colors_min, colors_max).sample(sample_shape=(3,)) def apply(self, frame: Frame): - assert frame.normalization == "01", "frame should be normalized between 0 and 1 before color modification" + assert ( + frame.normalization == "01" + ), "frame should be normalized between 0 and 1 before color modification" frame = frame**self.gamma frame = frame * self.brightness @@ -747,14 +796,21 @@ def __init__( How much to jitter hue. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. """ torchvision.transforms.ColorJitter.__init__( - self, brightness=brightness, contrast=contrast, saturation=saturation, hue=hue + self, + brightness=brightness, + contrast=contrast, + saturation=saturation, + hue=hue, ) AloTransform.__init__(self, *args, **kwargs) def sample_params(self): """Sample a `size` from the list of possible `sizes`""" return torchvision.transforms.ColorJitter.get_params( - brightness=self.brightness, contrast=self.contrast, saturation=self.saturation, hue=self.hue + brightness=self.brightness, + contrast=self.contrast, + saturation=self.saturation, + hue=self.hue, ) def set_params(self, *params): @@ -862,7 +918,10 @@ class RandomDownScaleCrop(Compose): """ def __init__(self, size, preserve_ratio=False, *args, **kwargs): - transforms = [RandomDownScale(size, preserve_ratio, *args, **kwargs), RandomCrop(size, *args, **kwargs)] + transforms = [ + RandomDownScale(size, preserve_ratio, *args, **kwargs), + RandomCrop(size, *args, **kwargs), + ] super().__init__(transforms, *args, **kwargs) @@ -881,7 +940,11 @@ def sample_params(self): def set_params(self, size): self.crop_size = size - def apply(self, frame: Frame, center: Union[Tuple[int, int], Tuple[float, float]] = (0.5, 0.5)): + def apply( + self, + frame: Frame, + center: Union[Tuple[int, int], Tuple[float, float]] = (0.5, 0.5), + ): """ center: Coordinate of cropped image center. This coordinate is tuple of int or tuple of float. Default: (0.5, 0.5) @@ -910,13 +973,14 @@ def apply(self, frame: Frame, center: Union[Tuple[int, int], Tuple[float, float] class RandomFocusBlur(AloTransform): """Randomly introduces motion blur. - + Parameters ---------- max_filter_size : int Max filter size to use, the higher the more blured the image. 
- + """ + def __init__(self, max_filter_size=10, *args, **kwargs): assert isinstance(max_filter_size, int) self.max_filter_size = max_filter_size @@ -937,7 +1001,7 @@ def sample_params(self): def set_params(self, h_size, v_size): self.h_filter_size = h_size self.v_filter_size = v_size - + @torch.no_grad() def apply(self, frame): c, h, w = frame.shape @@ -953,7 +1017,9 @@ def apply(self, frame): frame_ = frame.clone().norm255().batch() frame_ = frame_.rename(None) - frame_ = torch.nn.functional.conv2d(frame_.as_tensor(), filter_v, padding="same", groups=3) + frame_ = torch.nn.functional.conv2d( + frame_.as_tensor(), filter_v, padding="same", groups=3 + ) frame_ = torch.nn.functional.conv2d(frame_, filter_h, padding="same", groups=3) frame_ = frame_.reset_names()[0].norm_as(frame) @@ -963,13 +1029,14 @@ def apply(self, frame): class RandomFocusBlurV2(AloTransform): """Randomly introduces motion blur. - + Parameters ---------- max_filter_size : int Max filter size to use, the higher the more blured the image. """ + def __init__(self, max_filter_size=10, *args, **kwargs): assert isinstance(max_filter_size, int) self.max_filter_size = max_filter_size @@ -987,25 +1054,37 @@ def sample_params(self): def set_params(self, h_size, v_size): self.h_filter_size = h_size self.v_filter_size = v_size - + @staticmethod def h_trans(frame, size): v_left_frames = [frame[:, :, i:] for i in range(1, size // 2 + 1)] - v_left_frames = [torch.nn.functional.pad(x, pad=(0, i + 1), value=0) for i, x in enumerate(v_left_frames)] - + v_left_frames = [ + torch.nn.functional.pad(x, pad=(0, i + 1), value=0) + for i, x in enumerate(v_left_frames) + ] + v_right_frames = [frame[:, :, :-i] for i in range(1, size // 2 + 1)] - v_right_frames = [torch.nn.functional.pad(x, pad=(i + 1, 0), value=0) for i, x in enumerate(v_right_frames)] + v_right_frames = [ + torch.nn.functional.pad(x, pad=(i + 1, 0), value=0) + for i, x in enumerate(v_right_frames) + ] v_frames = [*v_left_frames, frame, *v_right_frames] return v_frames - + @staticmethod def v_trans(frame, size): h_top_frames = [frame[:, i:, :] for i in range(1, size // 2 + 1)] - h_top_frames = [torch.nn.functional.pad(x, pad=(0, 0, 0, i + 1), value=0) for i, x in enumerate(h_top_frames)] - + h_top_frames = [ + torch.nn.functional.pad(x, pad=(0, 0, 0, i + 1), value=0) + for i, x in enumerate(h_top_frames) + ] + h_bot_frames = [frame[:, :-i, :] for i in range(1, size // 2 + 1)] - h_bot_frames = [torch.nn.functional.pad(x, pad=(0, 0, i + 1, 0), value=0) for i, x in enumerate(h_bot_frames)] + h_bot_frames = [ + torch.nn.functional.pad(x, pad=(0, 0, i + 1, 0), value=0) + for i, x in enumerate(h_bot_frames) + ] h_frames = [*h_top_frames, frame, *h_bot_frames] return h_frames @@ -1020,10 +1099,10 @@ def apply(self, frame): v_frame = sum(v_frames) / self.v_filter_size h_frame = sum(h_frames) / self.h_filter_size - + blured = (h_frame + v_frame) / 2 blured = Frame(blured) - + blured = blured.norm_as(frame) blured.__dict__ = frame.__dict__.copy() return blured @@ -1034,22 +1113,34 @@ class RandomFocusBlurV3(RandomFocusBlurV2): def h_trans(frame, size): c, h, _ = frame.shape v_left_frames = [frame[:, :, i:] for i in range(1, size // 2 + 1)] - v_left_frames = [torch.cat([f, torch.zeros((c, h, i + 1))], dim=2) for i, f in enumerate(v_left_frames)] - + v_left_frames = [ + torch.cat([f, torch.zeros((c, h, i + 1))], dim=2) + for i, f in enumerate(v_left_frames) + ] + v_right_frames = [frame[:, :, :-i] for i in range(1, size // 2 + 1)] - v_right_frames = [torch.cat([torch.zeros((c, h, i + 1)), 
f], dim=2) for i, f in enumerate(v_right_frames)] + v_right_frames = [ + torch.cat([torch.zeros((c, h, i + 1)), f], dim=2) + for i, f in enumerate(v_right_frames) + ] v_frames = [*v_left_frames, frame, *v_right_frames] return v_frames - + @staticmethod def v_trans(frame, size): c, _, w = frame.shape h_top_frames = [frame[:, i:, :] for i in range(1, size // 2 + 1)] - h_top_frames = [torch.cat([f, torch.zeros((c, i + 1, w))], dim=1) for i, f in enumerate(h_top_frames)] - + h_top_frames = [ + torch.cat([f, torch.zeros((c, i + 1, w))], dim=1) + for i, f in enumerate(h_top_frames) + ] + h_bot_frames = [frame[:, :-i, :] for i in range(1, size // 2 + 1)] - h_bot_frames = [torch.cat([torch.zeros((c, i + 1, w)), f], dim=1) for i, f in enumerate(h_bot_frames)] + h_bot_frames = [ + torch.cat([torch.zeros((c, i + 1, w)), f], dim=1) + for i, f in enumerate(h_bot_frames) + ] h_frames = [*h_top_frames, frame, *h_bot_frames] return h_frames @@ -1057,7 +1148,7 @@ def v_trans(frame, size): class RandomFlowMotionBlur(AloTransform): """Introduces motion blur from optical flow. - + Idea : Let OpticalFlow : x, y --> x', y' retrive the indexes betwe x, x' and y, y' i.e x -> x1 ... -> x' , y -> y1 ... -> y' @@ -1076,20 +1167,21 @@ class RandomFlowMotionBlur(AloTransform): Motion blur intensity. If this arg is set, the value will not be random anymore. """ + def __init__( - self, - subframes: int = 10, - flow_model=None, - model_kwargs={}, - intensity=None, - **kwargs, - ): + self, + subframes: int = 10, + flow_model=None, + model_kwargs={}, + intensity=None, + **kwargs, + ): if isinstance(intensity, list): assert all([isinstance(x, float) for x in intensity]) assert intensity[0] < intensity[1] assert len(intensity) == 2 - self.intensity = 1. if intensity is None else intensity + self.intensity = 1.0 if intensity is None else intensity self.model_kwargs = model_kwargs self.flow_model = flow_model self.inter_intensity = None @@ -1116,12 +1208,12 @@ def _get_flow_model_kwargs(self, frame1, frame2): frame2 = Frame(frame2).norm_minmax_sym().batch() return {"frame1": frame1, "frame2": frame2, **self.model_kwargs} - + @staticmethod def _adapt_model_output(output): - """Adapts model output to be an optical flow of size [2, H, W] where the first channel + """Adapts model output to be an optical flow of size [2, H, W] where the first channel is the OF over X axis and the second is over Y axis - + Example with alonet/raft/raft ... 
-> """ @@ -1143,16 +1235,20 @@ def apply(self, frame, flow=None, p_frame=None): flow = flow.as_tensor() else: flow_cls = flow.__class__.__name__ - assert isinstance(flow, torch.Tensor), f"Flow must be an instance of torch.Tensor got {flow_cls} instead" + assert isinstance( + flow, torch.Tensor + ), f"Flow must be an instance of torch.Tensor got {flow_cls} instead" # Resize given the blur intensity HW_ = frame.shape[-2:] HW = flow.shape[-2:] if HW != HW_: - flow = torch.nn.functional.interpolate(flow.unsqueeze(0), size=HW_, mode="nearest") + flow = torch.nn.functional.interpolate( + flow.unsqueeze(0), size=HW_, mode="nearest" + ) flow = flow.squeeze() - + flow = flow * self.inter_intensity # XY Coordinates @@ -1164,18 +1260,20 @@ def apply(self, frame, flow=None, p_frame=None): # Map coridinates of intermediate points X -> X, intemediate X points ..., X + X_displacement (same for Y) subcoords = [ [ - (coords[0] - map_coords[0]) * i / self.subframes + coords[0], # X - (coords[0] - map_coords[0]) * i / self.subframes + coords[1]] # Y - for i in range(self.subframes + 1) - ] + (coords[0] - map_coords[0]) * i / self.subframes + coords[0], # X + (coords[0] - map_coords[0]) * i / self.subframes + coords[1], + ] # Y + for i in range(self.subframes + 1) + ] # Round and clamp indexes (float -> int + Occlusion) subcoords = [ [ torch.round(torch.clamp(s[0], min=0, max=HW_[0] - 1)).long(), - torch.round(torch.clamp(s[1], min=0, max=HW_[1] - 1)).long() + torch.round(torch.clamp(s[1], min=0, max=HW_[1] - 1)).long(), ] - for s in subcoords] + for s in subcoords + ] # Frame to indexed intermediate frames frame_ = [frame_[:, subcoord[0], subcoord[1]] for subcoord in subcoords] @@ -1203,18 +1301,19 @@ class RandomCornersMask(AloTransform): ## p_sides = [top, bottom, right, left] """ + def __init__( - self, - max_mask_size: float = 0.2, - p_sides: List = [0.2, 0.2, 0.2, 0.2], - **kwargs, - ): + self, + max_mask_size: float = 0.2, + p_sides: List = [0.2, 0.2, 0.2, 0.2], + **kwargs, + ): assert len(p_sides) == 4 assert isinstance(p_sides, list) assert isinstance(max_mask_size, float) assert max_mask_size >= 0 and max_mask_size < 1 assert all([isinstance(x, float) for x in p_sides]) - + # Random var param self.max_mask_size = max_mask_size self.p_sides = p_sides diff --git a/aloscene/bounding_boxes_2d.py b/aloscene/bounding_boxes_2d.py index 8ced0d10..df8e0d69 100644 --- a/aloscene/bounding_boxes_2d.py +++ b/aloscene/bounding_boxes_2d.py @@ -93,8 +93,12 @@ def __new__( tensor.add_property("padded_size", None) if absolute and frame_size is None: - raise Exception("If the boxes format are absolute, the `frame_size` must be set") - assert frame_size is None or (isinstance(frame_size, tuple) and len(frame_size) == 2) + raise Exception( + "If the boxes format are absolute, the `frame_size` must be set" + ) + assert frame_size is None or ( + isinstance(frame_size, tuple) and len(frame_size) == 2 + ) tensor.add_property("frame_size", frame_size) return tensor @@ -146,7 +150,11 @@ def xcyc(self): # Convert from xyxy to xcyc labels = tensor.drop_children() xcyc_boxes = torch.cat( - [tensor[:, :2] + ((tensor[:, 2:] - tensor[:, :2]) / 2), (tensor[:, 2:] - tensor[:, :2])], dim=1 + [ + tensor[:, :2] + ((tensor[:, 2:] - tensor[:, :2]) / 2), + (tensor[:, 2:] - tensor[:, :2]), + ], + dim=1, ) xcyc_boxes.boxes_format = "xcyc" xcyc_boxes.set_children(labels) @@ -158,7 +166,8 @@ def xcyc(self): tensor = tensor.rename_(None) xcyc_boxes = torch.cat( [ - tensor[:, :2].flip([1]) + ((tensor[:, 2:].flip([1]) - tensor[:, :2].flip([1])) / 
2), + tensor[:, :2].flip([1]) + + ((tensor[:, 2:].flip([1]) - tensor[:, :2].flip([1])) / 2), (tensor[:, 2:].flip([1]) - tensor[:, :2].flip([1])), ], dim=1, @@ -170,7 +179,9 @@ def xcyc(self): tensor.set_children(labels) return xcyc_boxes else: - raise Exception(f"BoundingBoxes2D:Do not know mapping from {tensor.boxes_format} to xcyc") + raise Exception( + f"BoundingBoxes2D:Do not know mapping from {tensor.boxes_format} to xcyc" + ) def xyxy(self): """Get a new BoundingBoxes2D Tensor with boxes following this format: @@ -186,7 +197,10 @@ def xyxy(self): labels = tensor.drop_children() # Convert from xcyc to xyxy n_tensor = torch.cat( - [tensor[:, :2] - (tensor[:, 2:] / 2), tensor[:, :2] + (tensor[:, 2:] / 2)], + [ + tensor[:, :2] - (tensor[:, 2:] / 2), + tensor[:, :2] + (tensor[:, 2:] / 2), + ], dim=1, ) n_tensor.boxes_format = "xyxy" @@ -209,7 +223,9 @@ def xyxy(self): tensor.set_children(labels) return n_tensor else: - raise Exception(f"BoundingBoxes2D:Do not know mapping from {tensor.boxes_format} to xyxy") + raise Exception( + f"BoundingBoxes2D:Do not know mapping from {tensor.boxes_format} to xyxy" + ) def yxyx(self): """Get a new BoundingBoxes2D Tensor with boxes following this format: @@ -255,7 +271,9 @@ def yxyx(self): elif tensor.boxes_format == "yxyx": return tensor else: - raise Exception(f"BoundingBoxes2D:Do not know mapping from {tensor.boxes_format} to yxyx") + raise Exception( + f"BoundingBoxes2D:Do not know mapping from {tensor.boxes_format} to yxyx" + ) def abs_pos(self, frame_size): """Get a new BoundingBoxes2D Tensor with absolute position @@ -274,18 +292,27 @@ def abs_pos(self, frame_size): # Back to relative before to get the absolute pos if tensor.absolute and frame_size != tensor.frame_size: - if tensor.boxes_format == "xcyc" or tensor.boxes_format == "xyxy": tensor = tensor.div( torch.tensor( - [tensor.frame_size[1], tensor.frame_size[0], tensor.frame_size[1], tensor.frame_size[0]], + [ + tensor.frame_size[1], + tensor.frame_size[0], + tensor.frame_size[1], + tensor.frame_size[0], + ], device=tensor.device, ) ) else: tensor = tensor.div( torch.tensor( - [tensor.frame_size[0], tensor.frame_size[1], tensor.frame_size[0], tensor.frame_size[1]], + [ + tensor.frame_size[0], + tensor.frame_size[1], + tensor.frame_size[0], + tensor.frame_size[1], + ], device=tensor.device, ) ) @@ -295,11 +322,17 @@ def abs_pos(self, frame_size): if not tensor.absolute: if tensor.boxes_format == "xcyc" or tensor.boxes_format == "xyxy": tensor = tensor.mul( - torch.tensor([frame_size[1], frame_size[0], frame_size[1], frame_size[0]], device=tensor.device) + torch.tensor( + [frame_size[1], frame_size[0], frame_size[1], frame_size[0]], + device=tensor.device, + ) ) else: tensor = tensor.mul( - torch.tensor([frame_size[0], frame_size[1], frame_size[0], frame_size[1]], device=tensor.device) + torch.tensor( + [frame_size[0], frame_size[1], frame_size[0], frame_size[1]], + device=tensor.device, + ) ) tensor.frame_size = frame_size tensor.absolute = True @@ -324,14 +357,24 @@ def rel_pos(self): if tensor.boxes_format == "xcyc" or tensor.boxes_format == "xyxy": tensor = tensor.div( torch.tensor( - [tensor.frame_size[1], tensor.frame_size[0], tensor.frame_size[1], tensor.frame_size[0]], + [ + tensor.frame_size[1], + tensor.frame_size[0], + tensor.frame_size[1], + tensor.frame_size[0], + ], device=tensor.device, ) ) else: tensor = tensor.div( torch.tensor( - [tensor.frame_size[0], tensor.frame_size[1], tensor.frame_size[0], tensor.frame_size[1]], + [ + tensor.frame_size[0], + tensor.frame_size[1], + 
tensor.frame_size[0], + tensor.frame_size[1], + ], device=tensor.device, ) ) @@ -376,7 +419,9 @@ def _area(self, boxes): boxes = boxes.as_tensor() return (boxes[:, 2] - boxes[:, 0]).mul(boxes[:, 3] - boxes[:, 1]) else: - raise Exception(f"desired boxes_format {boxes.boxes_format} is not handle to compute the area") + raise Exception( + f"desired boxes_format {boxes.boxes_format} is not handle to compute the area" + ) def abs_area(self, frame_size: Union[tuple, None]) -> torch.Tensor: """Get the absolute area of the current boxes. @@ -396,7 +441,9 @@ def abs_area(self, frame_size: Union[tuple, None]) -> torch.Tensor: return self._area(self.clone()) else: if frame_size is None: - raise Exception("Boxes are encoded as relative, the frame size must be given to compute the area.") + raise Exception( + "Boxes are encoded as relative, the frame size must be given to compute the area." + ) return self._area(self.abs_pos(frame_size)) def rel_area(self) -> torch.Tensor: @@ -423,6 +470,7 @@ def area(self) -> torch.Tensor: else: return self.rel_area() + np.random.seed(165742) _GLOBAL_COLOR_SET = np.random.uniform(0, 1, (300, 3)) def get_view( @@ -447,7 +495,9 @@ def get_view( if frame is not None: if len(frame.shape) > 3: - raise Exception(f"Expect image of shape c,h,w. Found image with shape {frame.shape}") + raise Exception( + f"Expect image of shape c,h,w. Found image with shape {frame.shape}" + ) assert isinstance(frame, Frame) else: size = self.frame_size if self.absolute else (300, 300) @@ -462,7 +512,15 @@ def get_view( # Get an imave with values between 0 and 1 frame_size = frame.HW - frame = frame.norm01().cpu().rename(None).permute([1, 2, 0]).detach().contiguous().numpy() + frame = ( + frame.norm01() + .cpu() + .rename(None) + .permute([1, 2, 0]) + .detach() + .contiguous() + .numpy() + ) # Draw bouding boxes # Try to retrieve the associated label ID (if any) @@ -480,7 +538,11 @@ def get_view( raise Exception( f"Trying to display a boxes labels set ({labels_set}) while boxes do not have multiple set of labels" ) - elif labels_set is not None and isinstance(boxes_abs.labels, dict) and labels_set not in boxes_abs.labels: + elif ( + labels_set is not None + and isinstance(boxes_abs.labels, dict) + and labels_set not in boxes_abs.labels + ): raise Exception( f"Trying to display a boxes labels set ({labels_set}) while boxes do not have this set. 
Avaiable set (" + f"{[key for key in boxes_abs.labels]}" @@ -497,7 +559,9 @@ def get_view( if label is not None: if isinstance(label, list): label_sum = sum([int(label_value[b]) for label_value in label]) - color = self._GLOBAL_COLOR_SET[int(label_sum) % len(self._GLOBAL_COLOR_SET)] + color = self._GLOBAL_COLOR_SET[ + int(label_sum) % len(self._GLOBAL_COLOR_SET) + ] text = [] for label_elm in label: text.append( @@ -509,9 +573,24 @@ def get_view( ) text = ", ".join(text) else: - color = self._GLOBAL_COLOR_SET[int(label) % len(self._GLOBAL_COLOR_SET)] - text = label.labels_names[int(label)] if label.labels_names else int(label) - cv2.putText(frame, str(text), (int(x2), int(y1)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA) + color = self._GLOBAL_COLOR_SET[ + int(label) % len(self._GLOBAL_COLOR_SET) + ] + text = ( + label.labels_names[int(label)] + if label.labels_names + else int(label) + ) + cv2.putText( + frame, + str(text), + (int(x2), int(y1)), + cv2.FONT_HERSHEY_SIMPLEX, + 0.5, + color, + 1, + cv2.LINE_AA, + ) cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 3) else: color = (0, 1, 0) @@ -604,13 +683,21 @@ def giou_with(self, boxes2) -> torch.Tensor: # degenerate boxes gives inf / nan results # so do an early check try: - assert (boxes1.as_tensor()[:, 2:] >= boxes1.as_tensor()[:, :2]).all(), f"{boxes1.as_tensor()}" - assert (boxes2.as_tensor()[:, 2:] >= boxes2.as_tensor()[:, :2]).all(), f"{boxes2.as_tensor()}" + assert ( + boxes1.as_tensor()[:, 2:] >= boxes1.as_tensor()[:, :2] + ).all(), f"{boxes1.as_tensor()}" + assert ( + boxes2.as_tensor()[:, 2:] >= boxes2.as_tensor()[:, :2] + ).all(), f"{boxes2.as_tensor()}" except: print("boxes1", boxes1) print("boxes2", boxes2) - assert (boxes1.as_tensor()[:, 2:] >= boxes1.as_tensor()[:, :2]).all(), f"{boxes1.as_tensor()}" - assert (boxes2.as_tensor()[:, 2:] >= boxes2.as_tensor()[:, :2]).all(), f"{boxes2.as_tensor()}" + assert ( + boxes1.as_tensor()[:, 2:] >= boxes1.as_tensor()[:, :2] + ).all(), f"{boxes1.as_tensor()}" + assert ( + boxes2.as_tensor()[:, 2:] >= boxes2.as_tensor()[:, :2] + ).all(), f"{boxes2.as_tensor()}" iou, union = boxes1.iou_with(boxes2, ret_union=True) @@ -700,7 +787,9 @@ def _crop(self, H_crop: tuple, W_crop: tuple, **kwargs): cropped_boxes2d BoundingBoxes2D """ if self.padded_size is not None: - raise Exception("Can't crop when padded size is not Note. Call fit_to_padded_size() first") + raise Exception( + "Can't crop when padded size is not Note. Call fit_to_padded_size() first" + ) absolute = self.absolute frame_size = self.frame_size @@ -730,7 +819,10 @@ def _crop(self, H_crop: tuple, W_crop: tuple, **kwargs): # Put back the instance into the same state as before if absolute: - n_frame_size = ((H_crop[1] - H_crop[0]) * frame_size[0], (W_crop[1] - W_crop[0]) * frame_size[1]) + n_frame_size = ( + (H_crop[1] - H_crop[0]) * frame_size[0], + (W_crop[1] - W_crop[0]) * frame_size[1], + ) cropped_boxes = cropped_boxes.abs_pos(n_frame_size) else: cropped_boxes.frame_size = None @@ -749,7 +841,9 @@ def fit_to_padded_size(self): >>> padded_boxes = boxes.fit_to_padded_size() """ if self.padded_size is None: - raise Exception("Trying to fit to padded size without any previous stored padded_size.") + raise Exception( + "Trying to fit to padded size without any previous stored padded_size." 
+ ) offset_y = (self.padded_size[0][0], self.padded_size[0][1]) offset_x = (self.padded_size[1][0], self.padded_size[1][1]) @@ -758,15 +852,22 @@ def fit_to_padded_size(self): boxes = self.abs_pos((100, 100)).xcyc() h_shift = boxes.frame_size[0] * offset_y[0] w_shift = boxes.frame_size[1] * offset_x[0] - boxes = boxes + torch.as_tensor([[w_shift, h_shift, 0, 0]], device=boxes.device) - boxes.frame_size = (100 * (1.0 + offset_y[0] + offset_y[1]), 100 * (1.0 + offset_x[0] + offset_x[1])) + boxes = boxes + torch.as_tensor( + [[w_shift, h_shift, 0, 0]], device=boxes.device + ) + boxes.frame_size = ( + 100 * (1.0 + offset_y[0] + offset_y[1]), + 100 * (1.0 + offset_x[0] + offset_x[1]), + ) boxes = boxes.get_with_format(self.boxes_format) boxes = boxes.rel_pos() else: boxes = self.xcyc() h_shift = boxes.frame_size[0] * offset_y[0] w_shift = boxes.frame_size[1] * offset_x[0] - boxes = boxes + torch.as_tensor([[w_shift, h_shift, 0, 0]], device=boxes.device) + boxes = boxes + torch.as_tensor( + [[w_shift, h_shift, 0, 0]], device=boxes.device + ) boxes.frame_size = ( boxes.frame_size[0] * (1.0 + offset_y[0] + offset_y[1]), boxes.frame_size[1] * (1.0 + offset_x[0] + offset_x[1]), @@ -800,7 +901,6 @@ def _pad(self, offset_y: tuple, offset_x: tuple, pad_boxes: bool = True, **kwarg n_boxes = self.clone() if n_boxes.padded_size is not None: - if n_boxes.absolute: pr_frame_size = self.frame_size else: @@ -809,22 +909,48 @@ def _pad(self, offset_y: tuple, offset_x: tuple, pad_boxes: bool = True, **kwarg padded_size = n_boxes.padded_size prev_padded_size = ( - ((padded_size[0][0] * pr_frame_size[0]), (padded_size[0][1] * pr_frame_size[0])), - ((padded_size[1][0] * pr_frame_size[1]), (padded_size[1][1] * pr_frame_size[1])), + ( + (padded_size[0][0] * pr_frame_size[0]), + (padded_size[0][1] * pr_frame_size[0]), + ), + ( + (padded_size[1][0] * pr_frame_size[1]), + (padded_size[1][1] * pr_frame_size[1]), + ), ) n_padded_size = ( ( prev_padded_size[0][0] - + offset_y[0] * (prev_padded_size[0][0] + prev_padded_size[0][1] + pr_frame_size[0]), + + offset_y[0] + * ( + prev_padded_size[0][0] + + prev_padded_size[0][1] + + pr_frame_size[0] + ), prev_padded_size[0][1] - + offset_y[1] * (prev_padded_size[0][0] + prev_padded_size[0][1] + pr_frame_size[0]), + + offset_y[1] + * ( + prev_padded_size[0][0] + + prev_padded_size[0][1] + + pr_frame_size[0] + ), ), ( prev_padded_size[1][0] - + offset_x[0] * (prev_padded_size[1][0] + prev_padded_size[1][1] + pr_frame_size[1]), + + offset_x[0] + * ( + prev_padded_size[1][0] + + prev_padded_size[1][1] + + pr_frame_size[1] + ), prev_padded_size[1][1] - + offset_x[1] * (prev_padded_size[1][0] + prev_padded_size[1][1] + pr_frame_size[1]), + + offset_x[1] + * ( + prev_padded_size[1][0] + + prev_padded_size[1][1] + + pr_frame_size[1] + ), ), ) @@ -858,15 +984,22 @@ def _pad(self, offset_y: tuple, offset_x: tuple, pad_boxes: bool = True, **kwarg boxes = self.abs_pos((100, 100)).xcyc() h_shift = boxes.frame_size[0] * offset_y[0] w_shift = boxes.frame_size[1] * offset_x[0] - boxes = boxes + torch.as_tensor([[w_shift, h_shift, 0, 0]], device=boxes.device) - boxes.frame_size = (100 * (1.0 + offset_y[0] + offset_y[1]), 100 * (1.0 + offset_x[0] + offset_x[1])) + boxes = boxes + torch.as_tensor( + [[w_shift, h_shift, 0, 0]], device=boxes.device + ) + boxes.frame_size = ( + 100 * (1.0 + offset_y[0] + offset_y[1]), + 100 * (1.0 + offset_x[0] + offset_x[1]), + ) boxes = boxes.get_with_format(self.boxes_format) boxes = boxes.rel_pos() else: boxes = self.xcyc() h_shift = boxes.frame_size[0] * 
offset_y[0] w_shift = boxes.frame_size[1] * offset_x[0] - boxes = boxes + torch.as_tensor([[w_shift, h_shift, 0, 0]], device=boxes.device) + boxes = boxes + torch.as_tensor( + [[w_shift, h_shift, 0, 0]], device=boxes.device + ) boxes.frame_size = ( boxes.frame_size[0] * (1.0 + offset_y[0] + offset_y[1]), boxes.frame_size[1] * (1.0 + offset_x[0] + offset_x[1]), @@ -898,14 +1031,13 @@ def _spatial_shift(self, shift_y: float, shift_x: float, **kwargs): original_absolute = self.absolute frame_size = self.frame_size - n_boxes = self.clone().rel_pos().xcyc() - - n_boxes += torch.as_tensor([[shift_x, shift_y, 0, 0]]) # , device=self.device) + n_boxes = self.clone().rel_pos().xyxy() - max_size = torch.as_tensor([1, 1, 1, 1], dtype=torch.float32) + n_boxes += torch.as_tensor( + [[shift_x, shift_y, shift_x, shift_y]], device=self.device + ) - n_boxes = torch.min(n_boxes.rename(None), max_size) - n_boxes = n_boxes.clamp(min=0) + n_boxes = n_boxes.clamp(0, 1) n_boxes = n_boxes.reset_names() # Filter to keep only boxes with area > 0 area = n_boxes.area() diff --git a/aloscene/frame.py b/aloscene/frame.py index 082815e6..e40a96fe 100644 --- a/aloscene/frame.py +++ b/aloscene/frame.py @@ -510,7 +510,7 @@ def mean_std_norm(self, mean, std, name) -> Frame: tensor = self mean_tensor, std_tensor = self._get_mean_std_tensor( - tensor.shape, tensor.names, (mean,std), device=tensor.device + tensor.shape, tensor.names, (mean, std), device=tensor.device ) if tensor.normalization == "01": tensor = tensor - mean_tensor @@ -568,7 +568,7 @@ def _pad(self, offset_y: tuple, offset_x: tuple, **kwargs): ------- padded tensor """ - pad_values = {"01": 0, "255": 0, "minmax_sym": -1} + pad_values = {"01": 0.5, "255": 127, "minmax_sym": 0} if self.normalization in pad_values: pad_value = pad_values[self.normalization] return super()._pad(offset_y, offset_x, fill=pad_value, **kwargs) @@ -630,32 +630,36 @@ def _spatial_shift(self, shift_y: float, shift_x: float, **kwargs): frame_data = n_frame.as_tensor() - permute_idx = list(range(0, len(self.shape))) - last_current_idx = permute_idx[-1] - permute_idx[-1] = permute_idx[self.names.index("C")] - permute_idx[self.names.index("C")] = last_current_idx - - n_frame_mean = frame_data.permute(permute_idx) - n_frame_mean = n_frame_mean.flatten(end_dim=-2) - n_frame_mean = torch.mean(n_frame_mean, dim=0) - n_shape = [1] * len(self.shape) - n_shape[self.names.index("C")] = 3 - n_frame_mean = n_frame_mean.view(tuple(n_shape)) + pad_values = {"01": 0.5, "255": 127, "minmax_sym": 0} + if self.normalization in pad_values: + pad_value = pad_values[self.normalization] + else: + permute_idx = list(range(0, len(self.shape))) + last_current_idx = permute_idx[-1] + permute_idx[-1] = permute_idx[self.names.index("C")] + permute_idx[self.names.index("C")] = last_current_idx + + pad_value = frame_data.permute(permute_idx) + pad_value = pad_value.flatten(end_dim=-2) + pad_value = torch.mean(pad_value, dim=0) + n_shape = [1] * len(self.shape) + n_shape[self.names.index("C")] = 3 + pad_value = pad_value.view(tuple(n_shape)) frame_data = torch.roll(frame_data, x_shift, dims=self.names.index("W")) # Fillup the shifted area with the mean if x_shift >= 1: - frame_data[self.get_slices({"W": slice(0, x_shift)})] = n_frame_mean + frame_data[self.get_slices({"W": slice(0, x_shift)})] = pad_value elif x_shift <= -1: - frame_data[self.get_slices({"W": slice(x_shift, -1)})] = n_frame_mean + frame_data[self.get_slices({"W": slice(x_shift, None)})] = pad_value # error frame_data = torch.roll(frame_data, y_shift, 
dims=self.names.index("H"))
 
         if y_shift >= 1:
-            frame_data[self.get_slices({"H": slice(0, y_shift)})] = n_frame_mean
+            frame_data[self.get_slices({"H": slice(0, y_shift)})] = pad_value
         elif y_shift <= -1:
-            frame_data[self.get_slices({"H": slice(y_shift, -1)})] = n_frame_mean
+            frame_data[self.get_slices({"H": slice(y_shift, None)})] = pad_value
 
         n_frame.data = frame_data
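
Usage sketch (not part of the patch): how the new `ignore_classes` and `skip_crowd` options introduced above are meant to be used. The folder and annotation paths are placeholders and assume a standard COCO layout under `dataset_dir`.

from alodataset.coco_base_dataset import CocoBaseDataset
from alodataset.coco_panoptic_dataset import CocoPanopticDataset

# Keep every label except "person"; `skip_crowd` defaults to True for
# CocoBaseDataset, so images whose annotations contain an `iscrowd` entry
# are dropped as well.
coco_val = CocoBaseDataset(
    img_folder="val2017",                           # placeholder path
    ann_file="annotations/instances_val2017.json",  # placeholder path
    ignore_classes=["person"],
)

# For the detection and panoptic datasets the filter is opt-in (default False).
panoptic_train = CocoPanopticDataset(classes=["person"], skip_crowd=True)

frame = coco_val[0]
frame.get_view().render()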
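A minimal sketch of the crowd-filtering step added in the three datasets. `drop_crowd_items` and `load_anns` are illustrative names only; in the patch the lookup is `coco.loadAnns(...)` for the base dataset and the `"segments_info"` entries for the detection and panoptic items.

def drop_crowd_items(items, load_anns):
    # Walk the list backwards so `del` does not shift the indices still to visit.
    for i in range(len(items) - 1, -1, -1):
        anns = load_anns(items[i])
        if any(ann["iscrowd"] for ann in anns):
            del items[i]
    return items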
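Illustrative sketch of the revised `BoundingBoxes2D._spatial_shift` behaviour: boxes are shifted in relative xyxy coordinates, clamped to the frame, and zero-area boxes are filtered out afterwards. The tensors below are example values, not taken from the patch.

import torch

boxes_xyxy = torch.tensor([[0.10, 0.10, 0.30, 0.30],
                           [0.85, 0.40, 0.95, 0.60]])
shift_x, shift_y = 0.20, 0.0

shifted = boxes_xyxy + torch.tensor([[shift_x, shift_y, shift_x, shift_y]])
shifted = shifted.clamp(0, 1)

# The second box is pushed entirely outside the frame: after clamping its width
# is 0, so the area filter removes it and only the first box is kept.
areas = (shifted[:, 2] - shifted[:, 0]) * (shifted[:, 3] - shifted[:, 1])
kept = shifted[areas > 0]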