Brief description of the submodule

This submodule contains all of the code used to turn the downloaded datasets into ready-to-use training data.


calculate_percentiles

Function to calculate the 0.01 and 0.99 quantiles (1st and 99th percentiles) of each band of the Planet images. These values are later used to normalize the dataset.


Parameters:

  • img_folder: (str) Path of the folder containing the images.
  • samples: (int) Number of images sampled to calculate the percentiles (default 400). For computational reasons, not all images are considered.

Returns:

  • vals: (numpy.ndarray) Array of shape (2, 4) holding the 1% and 99% quantiles of each of the four bands, averaged over the sampled images.
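
The averaging step can be illustrated on synthetic data. The sketch below uses only NumPy and random arrays in place of the Planet rasters; the helper name mean_band_quantiles and the (bands, height, width) layout are assumptions made for this illustration and are not part of the submodule.

```python
import numpy as np

def mean_band_quantiles(images, q=(0.01, 0.99)):
    """Average the per-band quantiles of a list of (bands, H, W) arrays."""
    total = np.zeros((len(q), images[0].shape[0]))
    for img in images:
        # Quantiles over the two spatial axes -> shape (len(q), bands)
        total += np.quantile(img, q, axis=(1, 2))
    return total / len(images)

rng = np.random.default_rng(0)
# Ten synthetic 4-band tiles standing in for the Planet images
tiles = [rng.uniform(0, 3000, size=(4, 64, 64)) for _ in range(10)]
vals = mean_band_quantiles(tiles)
print(vals.shape)  # (2, 4): row 0 holds the 1% values, row 1 the 99% values
```

Averaging per-image quantiles is cheaper than computing quantiles over the pooled pixels of all images, but the two results are not identical in general.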

Dependencies used

import os
import random
import numpy as np
import rioxarray

Source code

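def calculate_percentiles(img_folder, samples = 400):
    # Image tiles contain 'StudyArea' in their file name
    imgs = [fn for fn in os.listdir(img_folder) if 'StudyArea' in fn]

    # Never ask for more images than are available
    img_sample = random.sample(imgs, min(samples, len(imgs)))
    quantiles = np.zeros((2,4))

    for i in img_sample:
        # Per-band 1% and 99% quantiles over the spatial dimensions
        quantiles += rioxarray.open_rasterio(os.path.join(img_folder, i)).quantile((0.01, 0.99), dim = ('x','y')).values

    # Average the quantiles over the sampled images
    vals = quantiles/len(img_sample)

    return vals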


get_DataLoaders

Function to get the training, validation and test torch.utils.data.DataLoader objects (or the underlying datasets) for a specific dataset. The images are served by the Img_Dataset class.


Parameters:

  • dir: (str) Directory with the name of the data to be used.
  • batch_size: (int) Size of the batches used for training.
  • transform: (torchvision.transforms.v2.Compose) torch composition of transforms used for image augmentation. It is applied to the training split only.
  • normalization: (str) Type of normalization used (should be 'Linear_1_99').
  • VI: (boolean) Boolean indicating if NDVI and NDWI are also used in training.
  • only_get_DS: (boolean) Boolean for only getting datasets instead of data loaders.
  • train_split_size: (float) Fraction of the train split to be loaded, between 0 and 1 (especially useful for hyperparameter tuning). If None, all splits are loaded in full.
  • val_split_size: (float) Fraction of the validation and test splits to be loaded, between 0 and 1. Defaults to train_split_size when only train_split_size is given.


Returns:

Either the data loaders:

  • train_loader: Training torch data loader.
  • val_loader: Validation torch data loader.
  • test_loader: Test torch data loader.

or, when only_get_DS is True, the datasets:

  • train_DS: Training torch dataset.
  • val_DS: Validation torch dataset.
  • test_DS: Test torch dataset.
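
The fractional split can be tried in isolation. The following is a minimal sketch, assuming a PyTorch version in which random_split accepts fractions (1.13 or later); the toy TensorDataset and the helper name take_fraction are illustrative and do not appear in the submodule.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset of 100 samples standing in for one split of the real data
dataset = TensorDataset(torch.arange(100))

def take_fraction(ds, fraction, seed=8):
    """Keep a reproducible random fraction of ds and discard the rest."""
    kept, _ = random_split(
        ds, [fraction, 1 - fraction],
        generator=torch.Generator().manual_seed(seed),
    )
    return kept

subset_a = take_fraction(dataset, 0.25)
subset_b = take_fraction(dataset, 0.25)

print(len(subset_a))                         # 25
print(subset_a.indices == subset_b.indices)  # True: same seed, same samples
```

Because a freshly seeded generator is used for every call, repeated runs select the same samples, which keeps hyperparameter-tuning trials comparable.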

Dependencies used

import torch
from torch.utils.data import random_split

Source code

def get_DataLoaders(dir, batch_size, transform, normalization, VI, only_get_DS = False, train_split_size = None, val_split_size = None):
    # Augmentations are applied to the training split only
    train_DS = Img_Dataset(dir, transform, norm = normalization, VI=VI)
    val_DS = Img_Dataset(dir, split = 'Validation', norm = normalization, VI=VI)
    test_DS = Img_Dataset(dir, split = 'Test', norm = normalization, VI=VI)

    # Optionally keep only a reproducible random fraction of each split
    if train_split_size is not None:
        if val_split_size is None:
            val_split_size = train_split_size

        train_DS, _ = random_split(train_DS, [train_split_size, 1-train_split_size], generator=torch.Generator().manual_seed(8))
        val_DS, _ = random_split(val_DS, [val_split_size, 1-val_split_size], generator=torch.Generator().manual_seed(8))
        test_DS, _ = random_split(test_DS, [val_split_size, 1-val_split_size], generator=torch.Generator().manual_seed(8))

    if only_get_DS:
        return train_DS, val_DS, test_DS

    # Only the training loader shuffles its samples
    train_loader = torch.utils.data.DataLoader(train_DS, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_DS, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_DS, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


get_LOVE_DataLoaders

Function to get the loaders for the LoveDA dataset, which is retrieved using torchgeo.

Size of the dataset:



Parameters:

  • domain: (list) List with the scene parameter for the LoveDA dataset. It can include 'rural' and/or 'urban' (default: both).
  • batch_size: (int) Number of images per batch (default 4).
  • transforms: Image augmentations that will be applied.
  • only_get_DS: (boolean) Boolean for only getting datasets instead of data loaders.
  • train_split_size: (float) Fraction of images from the training split that will be considered, between 0 and 1. If None, all splits are loaded in full.
  • val_split_size: (float) Fraction of images from the validation and test splits that will be considered, between 0 and 1. Defaults to train_split_size when only train_split_size is given.

Returns:

  • train_loader: Training torch LoveDA data loader.
  • val_loader: Validation torch LoveDA data loader.
  • test_loader: Test torch LoveDA data loader.

When only_get_DS is True, the three datasets are returned instead of the loaders.

Dependencies used

import torch
from torchgeo.datasets import LoveDA
from torch.utils.data import random_split

Source code

def get_LOVE_DataLoaders(domain = ['urban', 'rural'], batch_size = 4, transforms = None, only_get_DS = False, train_split_size = None, val_split_size = None):
    # The dataset is downloaded into the 'LoveDA' folder if it is not there yet
    if transforms is not None:
        train_DS = LoveDA('LoveDA', split = 'train', scene = domain, download = True, transforms = transforms)
    else:
        train_DS = LoveDA('LoveDA', split = 'train', scene = domain, download = True)

    # The same transforms are also passed to the test and validation splits
    test_DS = LoveDA('LoveDA', split = 'test', scene = domain, download = True, transforms = transforms)
    val_DS = LoveDA('LoveDA', split = 'val', scene = domain, download = True, transforms = transforms)

    # Optionally keep only a reproducible random fraction of each split
    if train_split_size is not None:
        if val_split_size is None:
            val_split_size = train_split_size
        train_DS, _ = random_split(train_DS, [train_split_size, 1-train_split_size], generator=torch.Generator().manual_seed(8))
        val_DS, _ = random_split(val_DS, [val_split_size, 1-val_split_size], generator=torch.Generator().manual_seed(8))
        test_DS, _ = random_split(test_DS, [val_split_size, 1-val_split_size], generator=torch.Generator().manual_seed(8))

    if only_get_DS:
        return train_DS, val_DS, test_DS

    # Only the training loader shuffles its samples
    train_loader = torch.utils.data.DataLoader(train_DS, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_DS, batch_size=batch_size, shuffle=False)
    test_loader = torch.utils.data.DataLoader(test_DS, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader


Img_Dataset

Class to manage the Cashew dataset. The Cashew dataset consists of 256x256 Planet NICFI images with 4 bands (B, G, R, NIR).

Size of the dataset:



The images in the dataset are normalized linearly, band by band, using the 1st and 99th percentile values. A nice explanation of image normalization can be found on this Medium post.

The equation used for the normalization is presented below:

$$IMG_{normalized} = \frac{IMG - Perc_{1\%}}{Perc_{99\%} - Perc_{1\%}}$$
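
As a small worked example, the equation can be applied to a toy image with NumPy broadcasting. This is a sketch only: the function name linear_1_99 and the quantile values are made up for illustration and are not the values used for the Cashew dataset.

```python
import numpy as np

def linear_1_99(img, quantiles):
    """Normalize a (H, W, bands) image with per-band 1% and 99% values.

    quantiles has shape (2, bands): row 0 is the 1% value, row 1 the 99% value.
    """
    p1, p99 = quantiles[0], quantiles[1]
    # Broadcasting applies each band's pair of values along the last axis
    return (img - p1) / (p99 - p1)

# Illustrative quantiles for 4 bands (made-up numbers)
quantiles = np.array([[100.0, 200.0, 300.0, 400.0],
                      [1100.0, 1200.0, 1300.0, 1400.0]])
# A 2x2 image whose bands are constant: 100, 700, 1300 and 1900
img = np.stack([np.full((2, 2), v) for v in (100.0, 700.0, 1300.0, 1900.0)], axis=-1)

out = linear_1_99(img, quantiles)
print(out[0, 0])  # band values 0, 0.5, 1 and 1.5
```

Pixels below the 1% value or above the 99% value map outside the [0, 1] interval, as the last band shows; the formula itself does not clip them.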

Vegetation indices

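Two additional channels can be added to the tensor: the Normalized Difference Vegetation Index (NDVI)

$$NDVI = \frac{NIR - R}{NIR + R}$$

and the Normalized Difference Water Index (NDWI)

$$NDWI = \frac{G - NIR}{G + NIR}$$

Both indices are computed from the raw band values, before the normalization is applied.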
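
A minimal NumPy sketch of the two indices, assuming the band order B, G, R, NIR stated above. The function name vegetation_indices and the small eps term in the denominators are additions for this illustration; the dataset class divides without it.

```python
import numpy as np

def vegetation_indices(img, eps=1e-6):
    """Return (ndvi, ndwi) for a (H, W, 4) image with bands ordered B, G, R, NIR."""
    green, red, nir = img[:, :, 1], img[:, :, 2], img[:, :, 3]
    # eps keeps the division defined where both bands are zero
    # (an addition for this sketch, not something the dataset class does)
    ndvi = (nir - red) / (nir + red + eps)
    ndwi = (green - nir) / (green + nir + eps)
    return ndvi, ndwi

# One vegetated pixel (high NIR) and one water-like pixel (low NIR)
img = np.array([[[300.0, 500.0, 400.0, 2000.0],
                 [300.0, 600.0, 300.0, 100.0]]], dtype=np.float32)
ndvi, ndwi = vegetation_indices(img)
print(ndvi.round(2), ndwi.round(2))
```

For non-negative band values both indices lie between -1 and 1, with NDVI high over vegetation and NDWI high over water.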



Attributes:

  • self.img_folder: (str) Name of the folder in which the images are stored.
  • self.transform: (torchvision.transforms.v2.Compose) torch composition of transforms used for image augmentation.
  • self.split: (str) Split of the dataset to be retrieved. Can be 'Train', 'Validation' or 'Test'.
  • self.norm: (str) Type of normalization used. Only 'Linear_1_99' is allowed right now.
  • self.VI: (boolean) Boolean indicating if NDVI and NDWI are also used in training.



__len__

Method to calculate the number of images in the dataset.


plot_imgs

Method to plot a specific image of the dataset. It receives idx, the index of the image, and VIs, a boolean that decides whether the vegetation indices of the image are plotted as well.


__getitem__

Method to get the tensors (image and ground truth) for a specific index (idx).
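
The file-naming convention these methods rely on can be read off the source code: each sample is a pair of files, one containing 'StudyArea' and one containing 'GT', that share a five-digit index. The sketch below reproduces the counting and the path construction on empty placeholder files in a temporary folder. The helper names and the 'Tanzania' prefix between 'Cropped' and the split name are assumptions for this illustration.

```python
import os
import tempfile

def sample_paths(folder, country, split, idx):
    """Build the image and ground-truth paths for one sample."""
    stem = os.path.join(folder, 'Cropped' + country + split)
    return (stem + 'StudyArea_{:05d}.tif'.format(idx),
            stem + 'GT_{:05d}.tif'.format(idx))

def count_samples(folder, split):
    """Count the samples of a split: every sample contributes two files."""
    return sum(split in fn for fn in os.listdir(folder)) // 2

with tempfile.TemporaryDirectory() as folder:
    # Create three Train pairs and one Test pair as empty placeholder files
    for split, n in (('Train', 3), ('Test', 1)):
        for idx in range(n):
            for path in sample_paths(folder, 'Tanzania', split, idx):
                open(path, 'w').close()
    n_train = count_samples(folder, 'Train')
    n_test = count_samples(folder, 'Test')

print(n_train, n_test)  # 3 1
```

Counting by substring means that any unrelated file whose name contains the split name would also be counted, so the folder is expected to hold only dataset files.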

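Dependencies used

The imports for this class are not listed here. The ones below are inferred from the names used in the source code (io, T, plt, Dataset) and should be checked against the repository.

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.tv_tensors
import torchvision.transforms as T
from skimage import io
from torch.utils.data import Dataset

Source code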

class Img_Dataset(Dataset):
    """Class to manage the cashew dataset."""
    def __init__(self, img_folder, transform = None, split = 'Train', norm = 'Linear_1_99', VI = True, recalculate_perc = False):
        self.img_folder = img_folder
        self.transform = transform
        self.split = split
        self.norm = norm
        self.VI = VI

        # Depending on the domain the images will have different attributes (country and quantiles).
        # The attribute name 'country' is assumed from the comment above; it is missing in this listing.
        # quant_TNZ and quant_CIV are expected to be precomputed module-level arrays of shape (2, 4).
        if 'Tanzania' in self.img_folder:
            self.country = 'Tanzania'

            if recalculate_perc:
                self.quant_TNZ = calculate_percentiles(img_folder)
            else:
                self.quant_TNZ = quant_TNZ
        else:
            self.country = 'IvoryCoast'

            if recalculate_perc:
                self.quant_CIV = calculate_percentiles(img_folder)
            else:
                self.quant_CIV = quant_CIV

    def __len__(self):
        """Method to calculate the number of images in the dataset."""
        # Every sample consists of two files (image and ground truth)
        return sum([self.split in i for i in os.listdir(self.img_folder)])//2

    def plot_imgs(self, idx, VIs = False):
        """
        Method to plot a specific image of the dataset.

        - self: The dataset class and its attributes.
        - idx: index of the image that will be plotted.
        - VIs: Boolean describing if vegetation indices should be plotted
        """
        im, g = self.__getitem__(idx)

        if VIs:
            fig, ax = plt.subplots(2,2,figsize = (12,12))

            ax[0,0].set_title('Planet image')
            ax[0,1].set_title('Cashew crops GT')

            # Channels 4 and 5 hold NDVI and NDWI when the dataset was built with VI = True
            VIs = im[4:6]

            # plt.get_cmap(name, 5) returns the colormap resampled to 5 colours
            g1=ax[1,0].imshow(VIs[0], cmap = plt.get_cmap('RdYlGn', 5), vmin = 0, vmax = 1)
            g2=ax[1,1].imshow(VIs[1], cmap = plt.get_cmap('Blues_r', 5), vmin = 0, vmax = 1)
        else:
            fig, ax = plt.subplots(1,2,figsize = (12,6))

            ax[0].set_title('Planet image')
            ax[1].set_title('Cashew crops GT')

    def __getitem__(self, idx):
        """Method to get the tensors (image and ground truth) for a specific image."""

        conversion = T.ToTensor()

        img = io.imread(fname = self.img_folder + '/Cropped' + self.country + self.split + 'StudyArea_{:05d}'.format(idx) + '.tif').astype(np.float32)

        # Vegetation indices are computed on the raw bands (B, G, R, NIR), before normalization
        if self.VI:
            if self.norm == 'Linear_1_99':
                ndvi = (img[:,:,3] - img[:,:,2])/(img[:,:,3] + img[:,:,2])
                ndwi = (img[:,:,1] - img[:,:,3])/(img[:,:,3] + img[:,:,1])

        # Linear normalization of each band with its 1% and 99% quantiles
        if self.norm == 'Linear_1_99':
            for i in range(img.shape[-1]):
                if 'Tanz' in self.img_folder:
                    img[:,:,i] = (img[:,:,i] - self.quant_TNZ[0,i])/(self.quant_TNZ[1,i] - self.quant_TNZ[0,i])
                elif 'Ivor' in self.img_folder:
                    img[:,:,i] = (img[:,:,i] - self.quant_CIV[0,i])/(self.quant_CIV[1,i] - self.quant_CIV[0,i])

        # Append NDVI and NDWI as two extra channels
        if self.VI:
            ndvi = np.expand_dims(ndvi, axis = 2)
            ndwi = np.expand_dims(ndwi, axis = 2)
            img = np.concatenate((img, ndvi, ndwi), axis = 2)

        # (H, W, C) array -> (C, H, W) float tensor
        img = conversion(img).float()

        img = torchvision.tv_tensors.Image(img)

        GT = io.imread(fname = self.img_folder + '/Cropped' + self.country + self.split + 'GT_{:05d}'.format(idx) + '.tif').astype(np.float32)

        # The ground truth is flipped along its height axis
        GT = torch.flip(conversion(GT), dims = (1,))

        GT = torchvision.tv_tensors.Image(GT)

        # Image and ground truth go through the transform together
        if self.transform is not None:
            GT, img = self.transform(GT, img)

        return img, GT