2.1. Perturbations

This module provides ways to perturb the input data in order to generate neighborhoods of perturbed samples around it.
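One common way to build such a neighborhood, used by surrogate-model explainers such as LIME, is to switch each segment on or off at random and to replace the switched-off segments with a neutral image. The NumPy sketch below assumes that scheme; it is not necessarily how ReVel samples neighborhoods, and sample_neighborhood is a hypothetical helper that is not part of ReVel.

```python
import numpy as np


def sample_neighborhood(img, segments, neutral, n_samples, rng):
    """Hypothetical helper (not part of ReVel): build perturbed copies of img.

    Each sample keeps every segment (mask value 1) or replaces it with the
    neutral image (mask value 0), chosen uniformly at random.
    """
    labels = np.unique(segments)
    # One row per sample, one column per segment label.
    masks = rng.integers(0, 2, size=(n_samples, labels.size))
    samples = []
    for mask in masks:
        off_labels = labels[mask == 0]
        # (H,W,1) boolean mask, broadcast over the colour channels.
        replace = np.isin(segments, off_labels)[..., None]
        samples.append(np.where(replace, neutral, img))
    return masks, np.stack(samples)


rng = np.random.default_rng(0)
img = rng.random((4, 4, 3))
segments = np.array([[0, 0, 1, 1]] * 2 + [[2, 2, 3, 3]] * 2)  # four 2x2 segments
neutral = np.full(img.shape, img.mean())
masks, samples = sample_neighborhood(img, segments, neutral, n_samples=5, rng=rng)
print(masks.shape, samples.shape)  # (5, 4) (5, 4, 4, 3)
```

The binary masks are kept alongside the images because an explainer typically fits its surrogate model on them.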

ReVel.perturbations.Perturbation(...)

Initialize the abstract Perturbation class.

ReVel.perturbations.SquarePerturbation(dim, ...)

Perturbation that perturbs the input image by a square segmentation.

ReVel.perturbations.QuickshiftPerturbation([...])

Perturbation that perturbs the input image by quickshift unsupervised segmentation.

class ReVel.perturbations.Perturbation(segmentation_fn, fn_neutral_image, num_classes: int, get_input_transform, final_size, **kwargs)
    class Perturbation:
        '''
        Initialize the abstract Perturbation class.
        This class is used to perturb the input image
        This class is used as a base class for the perturbations.
        It assumes that the needed functions are implemented in the child class.

        Parameters
        ---------
        segmentation_fn: function that returns the segmentation of an image
        fn_neutral_image: function that returns the neutral image used to replace the original image when the perturbation is applied
        num_classes: number of classes. It determines the target space
        final_size: desired size of the final image
        '''
        def __init__(self, segmentation_fn, fn_neutral_image, num_classes:int, get_input_transform, final_size,**kwargs):
            self.segmentation_fn = segmentation_fn
            self.fn_neutral_image = fn_neutral_image
            self.after_transform = get_input_transform(final_size)
            self.num_classes = num_classes
            self.final_size = final_size

        def perturbation(self,img,neutral,segments,indexes):
            '''
            This function is used to perturb the input image.
            For this purpose, the function receives the original image,
            the neutral image and the segmentation of the image.
            We change the segmentation indexes to avoid the original image
            and replace them with the neutral image.

            Parameters
            ----------
            img: original image. Dims (H,W,3)
            neutral: neutral image. Dim (H,W,3)
            segments: segmentation of the image. Dims (H,W)
            indexes: indexes of the image that we want to replace with the neutral image. Integer or list of integers

            Returns
            -------
            Perturbed image. Dim (H,W,C)
            '''
            #check if original image and the neutral image has 3 on the last dimension
            if np.array(img).shape[-1] != 3 or np.array(neutral).shape[-1] != 3:
                #if not, raise an error
                raise ValueError("The original image and the neutral image must have 3 on the last dimension")
            if isinstance(indexes,np.ndarray) or isinstance(indexes,list) or isinstance(indexes,tuple):
                conditions = np.array([segments == indx for indx in indexes],dtype=object)
                condition = np.any(conditions,axis=0)
            else:
                condition = (segments == indexes )
            if len(np.array(img).shape) == 3:
                condition = np.expand_dims(condition,-1)
                condition = np.repeat(condition,3,axis=-1)
            return np.where(condition,neutral,img )

        def transform(self,img):
            '''
            This function is used to transform the input image to the input space.

            Parameters
            ----------
            img: original image. Dims (H,W,C) or (H,W)

            Returns
            -------
            Image preprocessed. Dims (H,W,C)
            '''
            if len(np.array(img).shape) == 2:
                img = np.expand_dims(img,-1)
                img = np.repeat(img,3,axis=-1)
            segments = self.segmentation_fn(img)
            neutral = self.fn_neutral_image(np.array(img))
            perturbation = self.perturbation(img,neutral,segments,-1)
            return self.after_transform(perturbation)

        def target_transform(self,target:int):
            '''
            This function is used to transform the target to the target space.
            Usually, it depends on the dataset if we want a one-hot vector or a categorical vector.
            In this case, we transform the target, coded as an int, to a one-hot vector of size self.num_classes.

            Parameters
            ----------
            target: target of the image. Integer

            Returns
            -------
            Target in the target space. Dims (self.num_classes)
            '''
            t = torch.zeros(self.num_classes)
            t[target] = 1
            target = t
            return target

        def __call__(self, img,target):
            '''
            This function is used to transform the input image and the target to the input space.

            Parameters
            ----------
            img: original image to transform. Dims (H,W,C)
            target: target of the image. Integer

            Returns
            -------
            Image and target preprocessed. Dims ((H,W,C), (self.num_classes))
            '''
            t = torch.zeros(self.num_classes)
            t[target] = 1
            return img,t

Initialize the abstract Perturbation class, which is used to perturb the input image.

This class is the base class for the concrete perturbations. It does not segment images itself: subclasses such as SquarePerturbation and QuickshiftPerturbation pass their own segmentation and neutral-image functions to this constructor.

Parameters:
  • segmentation_fn – function that returns the segmentation of an image

  • fn_neutral_image – function that returns the neutral image used to replace the original image when the perturbation is applied

  • num_classes – number of classes. It determines the target space

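  • get_input_transform – factory called once as get_input_transform(final_size); the callable it returns is stored as after_transform and applied to the image at the end of transform()

  • final_size – desired size of the final image

  • **kwargs – accepted but ignored by this constructor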
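The constructor only calls get_input_transform(final_size) and keeps the result. ReVel's own get_input_transform is not shown in this section, so the factory below is hypothetical: it illustrates the pattern with a nearest-neighbour resize written in plain NumPy, not the transform ReVel actually applies.

```python
import numpy as np


def get_input_transform_sketch(final_size):
    """Hypothetical factory (not ReVel's get_input_transform).

    Returns a callable that resizes an (H,W) or (H,W,C) array to final_size
    with nearest-neighbour sampling.
    """
    out_h, out_w = final_size

    def after_transform(img):
        img = np.asarray(img)
        # Source row/column picked for every output row/column.
        rows = np.arange(out_h) * img.shape[0] // out_h
        cols = np.arange(out_w) * img.shape[1] // out_w
        return img[rows[:, None], cols[None, :]]

    return after_transform


# Mirrors `self.after_transform = get_input_transform(final_size)` in __init__.
after_transform = get_input_transform_sketch((8, 8))
print(after_transform(np.zeros((32, 32, 3))).shape)  # (8, 8, 3)
```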

perturbation(img, neutral, segments, indexes)

This function perturbs the input image. It receives the original image, the neutral image and the segmentation of the image, and replaces every pixel whose segment label is in indexes with the corresponding pixel of the neutral image; all other pixels keep their original value.

Parameters:
  • img – original image. Dims (H,W,3)

  • neutral – neutral image. Dims (H,W,3)

  • segments – segmentation of the image. Dims (H,W)

  • indexes – segment labels whose pixels are replaced with the neutral image. An integer, or a list, tuple or NumPy array of integers

Returns:

Perturbed image. Dims (H,W,3)

Raises:

ValueError – if img or neutral does not have 3 on the last dimension.
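The replacement rule can be sketched as a standalone NumPy function. perturb is a hypothetical name, and np.isin is swapped in for the method's list of per-label comparisons; the sketch assumes both images are (H,W,3) arrays.

```python
import numpy as np


def perturb(img, neutral, segments, indexes):
    """Standalone sketch of Perturbation.perturbation (not the ReVel code)."""
    img = np.asarray(img)
    neutral = np.asarray(neutral)
    if img.shape[-1] != 3 or neutral.shape[-1] != 3:
        raise ValueError("The original image and the neutral image must have 3 on the last dimension")
    # Accept a single label or any sequence of labels.
    labels = np.atleast_1d(indexes)
    # (H,W) mask of the segments to replace, broadcast over the 3 channels.
    condition = np.isin(segments, labels)[..., None]
    return np.where(condition, neutral, img)


img = np.arange(2 * 2 * 3).reshape(2, 2, 3)
neutral = np.zeros_like(img)
segments = np.array([[0, 1], [1, 2]])
out = perturb(img, neutral, segments, [1])
print(out[..., 0].tolist())  # [[0, 0], [0, 9]]
```

Pixels (0,1) and (1,0) belong to segment 1, so they take the neutral value 0, while segments 0 and 2 keep the original values.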

target_transform(target: int)

This function transforms the target to the target space. Whether a one-hot or a categorical encoding is wanted usually depends on the dataset; here, the target, coded as an int, is converted into a one-hot vector of size self.num_classes.

Parameters:

target – target of the image. Integer

Returns:

Target in the target space: a torch.Tensor of shape (self.num_classes,) with 1 at position target and 0 elsewhere.
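A NumPy equivalent of the one-hot encoding, for readers without PyTorch at hand. ReVel itself builds the vector with torch.zeros; one_hot is a hypothetical name used only here.

```python
import numpy as np


def one_hot(target: int, num_classes: int) -> np.ndarray:
    """NumPy sketch of target_transform (ReVel returns a torch tensor instead)."""
    t = np.zeros(num_classes, dtype=np.float32)
    t[target] = 1.0  # single 1 at the class index
    return t


print(one_hot(2, 4).tolist())  # [0.0, 0.0, 1.0, 0.0]
```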

transform(img)

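This function transforms the input image to the input space of the model. A grayscale image of shape (H,W) is first expanded to (H,W,3) by repeating its single channel. The image is then segmented, its neutral image is computed, and perturbation() is called with indexes=-1. Since neither segmentation in this module produces the label -1, no pixel is replaced, so the method effectively applies after_transform to the (possibly expanded) original image.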

Parameters:

img – original image. Dims (H,W,C) or (H,W)

Returns:

Preprocessed image: the output of after_transform, whose type and shape depend on the callable returned by get_input_transform.
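The steps of transform() can be sketched without ReVel. transform_sketch is a hypothetical function, and the lambdas stand in for the segmentation, neutral-image and after_transform callables; the sketch shows the grayscale expansion and that indexes=-1 leaves every pixel untouched.

```python
import numpy as np


def transform_sketch(img, segmentation_fn, fn_neutral_image, after_transform):
    """Sketch of the steps in Perturbation.transform (not the ReVel code)."""
    img = np.asarray(img)
    if img.ndim == 2:
        # Grayscale (H,W) -> (H,W,3) by repeating the single channel.
        img = np.repeat(img[..., None], 3, axis=-1)
    segments = segmentation_fn(img)
    neutral = fn_neutral_image(img)
    # Only pixels labelled -1 would be replaced; the segmentation below never uses -1.
    condition = (segments == -1)[..., None]
    return after_transform(np.where(condition, neutral, img))


gray = np.arange(16, dtype=float).reshape(4, 4)
out = transform_sketch(
    gray,
    segmentation_fn=lambda im: np.zeros(im.shape[:2], dtype=int),  # one segment
    fn_neutral_image=lambda im: np.full(im.shape, im.mean()),
    after_transform=lambda im: im,  # identity stands in for get_input_transform(...)
)
print(out.shape)  # (4, 4, 3)
```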

class ReVel.perturbations.SquarePerturbation(dim: int, **kwargs)
    class SquarePerturbation(Perturbation):
        '''
        Perturbation that perturbs the input image by a square segmentation.
        This class is used to perturb the input image

        Parameters
        -----------
        dim: int
            dimension of the square segmentation in which the image is segmented.
        '''
        def __init__(self,dim:int,**kwargs):
            self.dim = dim
            if kwargs.get('get_input_transform',None) is None:
                kwargs['get_input_transform'] = get_input_transform
            super().__init__(segmentation_fn=self.square_segmentation,
                             fn_neutral_image=self.neutral,
                             **kwargs)

        def square_segmentation(self,img):
            '''
            This function returns the segmentation of the image divided on squares of size (H/self.dim,W/self.dim)

            Parameters
            -----------
            img: original image. Dims (H,W,C)

            Returns
            --------
            The segmentation of the image divided on squares of size (H/self.dim,W/self.dim). Each square
            of the matrix is a different integer.
            '''
            img = np.array(img)
            rango = img.shape[0]//self.dim
            segments = np.zeros(shape=img.shape[0:2])
            for i in range(self.dim+1):
                for j in range(self.dim+1):
                    segments[i*rango:(i+1)*rango,j*rango:(j+1)*rango] = i+j*self.dim
            return segments

        def neutral(self,image):
            '''
            This function returns the neutral image.
            The neutral image is an image with all pixels set to the mean value of the image for each channel.

            Parameters
            -----------
            image: original image. Dims (H,W,C)

            Returns
            --------
            The neutral image with same shape as the original image with the mean value of each channel.
            '''
            if isinstance(image,torch.Tensor):
                image = image.numpy()
            return np.zeros(image.shape) + np.mean(image)

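Perturbation that perturbs the input image by a square segmentation: the image is divided into a regular grid of squares, and each square is a segment that can be replaced by the neutral image. If get_input_transform is not given in **kwargs, the get_input_transform function available in the module is used; the remaining keyword arguments, such as num_classes and final_size, are forwarded to Perturbation.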

Parameters:

dim (int) – number of squares along each side of the image; each square is H // dim pixels wide, where H is the image height.

neutral(image)

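This function returns the neutral image: an array with the same shape as the input in which every pixel is set to the mean value of the image. The code computes a single scalar mean over all pixels and channels (np.mean(image)), not a separate mean per channel. A torch.Tensor input is converted to a NumPy array first.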

Parameters:

image – original image. Dims (H,W,C)

Returns:

The neutral image, a NumPy float array with the same shape as the original image.
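The difference between the global mean that neutral() computes and a per-channel mean shows up on an image whose channels differ. In this NumPy sketch, per_channel_neutral is a hypothetical alternative, not something ReVel provides.

```python
import numpy as np

img = np.zeros((2, 2, 3))
img[..., 0] = 1.0  # first channel fully on, the other two off

# What neutral() computes: one scalar mean over all pixels and channels.
global_neutral = np.zeros(img.shape) + np.mean(img)

# Hypothetical per-channel alternative (not what ReVel computes).
per_channel_neutral = np.broadcast_to(img.mean(axis=(0, 1)), img.shape)

print(np.round(global_neutral[0, 0], 3).tolist())  # [0.333, 0.333, 0.333]
print(per_channel_neutral[0, 0].tolist())          # [1.0, 0.0, 0.0]
```

The global mean yields a gray neutral image, whereas a per-channel mean preserves the average colour of the image.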

square_segmentation(img)

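This function returns a segmentation map that divides the image into square cells with sides of rango = H // self.dim pixels. The side length is derived from the height H and used along both axes, so the image is split into an exact self.dim × self.dim grid only when it is square and H is a multiple of self.dim.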

Parameters:

img – original image. Dims (H,W,C)

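Returns:

The segmentation map, a float array of shape (H,W). The cell in row block i and column block j is labelled i + j*self.dim, so for a square image whose side is a multiple of self.dim every cell carries a different integer between 0 and self.dim**2 - 1.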
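The loop can be checked on a small input. The sketch copies it into a standalone function (no ReVel import needed) and adds square_segmentation_vectorized, a hypothetical NumPy equivalent that is valid only when the image is square and its side is a multiple of dim. For a 4×4 image and dim=2, labels increase down the rows first, because the row block index i is the term without the dim factor.

```python
import numpy as np


def square_segmentation(img, dim):
    """Standalone copy of the loop in SquarePerturbation.square_segmentation."""
    img = np.asarray(img)
    rango = img.shape[0] // dim  # cell side, taken from the height
    segments = np.zeros(shape=img.shape[0:2])
    for i in range(dim + 1):
        for j in range(dim + 1):
            segments[i * rango:(i + 1) * rango, j * rango:(j + 1) * rango] = i + j * dim
    return segments


def square_segmentation_vectorized(shape, dim):
    """Hypothetical vectorized equivalent for square images whose side is a multiple of dim."""
    h, w = shape
    rango = h // dim
    rows = np.arange(h)[:, None] // rango  # row block index i
    cols = np.arange(w)[None, :] // rango  # column block index j
    return rows + cols * dim


seg = square_segmentation(np.zeros((4, 4, 3)), dim=2)
print(seg.astype(int).tolist())  # [[0, 0, 2, 2], [0, 0, 2, 2], [1, 1, 3, 3], [1, 1, 3, 3]]
```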

class ReVel.perturbations.QuickshiftPerturbation(kernel_size: float = 4, max_dist: float = 200.0, ratio: float = 0.2, **kwargs)
    class QuickshiftPerturbation(Perturbation):
        '''
        Perturbation that perturbs the input image by quickshift unsupervised segmentation.

        Parameters
        -----------
        kernel_size: float
            size of the kernel used to compute the segmentation with quickshift
        max_dist: float
            maximum distance between points in the segmentation
        ratio: float
            ratio of the segmentation
        '''
        def __init__(self,kernel_size:float=4,max_dist:float=200.0,ratio:float=0.2,**kwargs):
            self.kernel_size = kernel_size
            self.max_dist = max_dist
            self.ratio = ratio
            super().__init__(segmentation_fn = self.quick_shift_segmentation,
                             fn_neutral_image = self.neutral,
                             get_input_transform = get_input_transform,
                             **kwargs)

        def quick_shift_segmentation(self,img):
            '''
            Quickshift segmentation function. First introduced in :cite:`quickshift`.
            It is a wrapper of the skimage.segmentation.quickshift function.

            Parameters
            -----------
            img: original image. Dims (H,W,C)

            Returns
            --------
            The segmentation of the image. Dims (H,W)
            '''
            img = np.array(img)
            if img.shape[-1] != 3:
                raise ValueError("The original image must have 3 on the last dimension")
            segments = quickshift(img, kernel_size=self.kernel_size, max_dist=self.max_dist, ratio=self.ratio)
            return segments

        def neutral(self,image):
            '''
            This function returns the neutral image.
            The neutral image is an image with all pixels set to the mean value of the image for each channel.

            Parameters
            -----------
            image: original image. Dims (H,W,3)

            Returns
            --------
            The neutral image with same shape as the original image with the mean value of each channel.
            '''
            #check if original image has 3 on the last dimension
            if np.array(image).shape[-1] != 3:
                raise ValueError("The original image must have 3 on the last dimension")
            if isinstance(image,torch.Tensor):
                image = image.numpy()
            return np.zeros(image.shape) + np.mean(image)

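Perturbation that perturbs the input image by quickshift unsupervised segmentation. Unlike SquarePerturbation, it always passes the get_input_transform function available in the module to Perturbation, so get_input_transform must not be given in **kwargs (Python would raise a TypeError for the duplicated keyword argument); the remaining keyword arguments, such as num_classes and final_size, are forwarded to Perturbation.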

Parameters:
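  • kernel_size (float) – width of the Gaussian kernel that quickshift uses to smooth the sample density; higher values give fewer segments

  • max_dist (float) – cut-off point for data distances in quickshift; higher values give fewer segments

  • ratio (float) – between 0 and 1; balances colour-space proximity against image-space proximity, with higher values giving more weight to colour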

neutral(image)

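This function returns the neutral image: an array with the same shape as the input in which every pixel is set to the mean value of the image. The code computes a single scalar mean over all pixels and channels (np.mean(image)), not a separate mean per channel.

Parameters:

image – original image. Dims (H,W,3); a ValueError is raised if the last dimension is not 3.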

Returns:

The neutral image, a NumPy float array with the same shape as the original image.

quick_shift_segmentation(img)

Quickshift segmentation function, first introduced in [VS08]. It wraps skimage.segmentation.quickshift, passing the instance's kernel_size, max_dist and ratio.

Parameters:

img – original image. Dims (H,W,3); a ValueError is raised if the last dimension is not 3.

Returns:

The segmentation of the image, an integer label array. Dims (H,W)
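A minimal call to the wrapped scikit-image function with QuickshiftPerturbation's default parameters, assuming scikit-image is installed; the random image is only a placeholder input, and the number of segments found depends on the image.

```python
import numpy as np
from skimage.segmentation import quickshift

rng = np.random.default_rng(0)
img = rng.random((32, 32, 3))  # placeholder RGB image with values in [0, 1]

# Same keyword arguments and defaults as QuickshiftPerturbation.
segments = quickshift(img, kernel_size=4, max_dist=200.0, ratio=0.2)
print(segments.shape)  # (32, 32)
print("segments found:", segments.max() + 1)
```

The labels are consecutive integers starting at 0, so segments.max() + 1 is the number of segments that perturbation() can replace.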