Train_DomainOnly

In this submodule, all of the functions used for training the domain-only models are described in detail.

evaluate()

Function used to evaluate the segmentation performance of a specified network. It runs the network on a specified data loader and calculates the mean loss and accuracy by comparing the predictions with the loader's ground-truth labels.

Params

  • net: (torch.nn.Module) Network used to get the segmentation predictions.
  • validate_loader: (torch.utils.data.DataLoader) Data loader over which the mean loss and accuracy will be calculated.
  • loss_function: (torch.nn.Module) Loss function used during the training of the network.
  • accu_function: (torchmetrics.classification) Function to calculate the mean accuracy of the network on validate_loader. Default is BinaryF1Score().
  • Love: Boolean to indicate if working with the LoveDA dataset.
  • binary_love: Boolean to indicate if working with only one class of the LoveDA dataset.

Outputs

  • metric: (list) List with the mean values of accuracy and loss.

Dependencies used

import numpy as np
import torch
from torchmetrics.classification import BinaryF1Score

from utils import LOVE_resample_fly

Source code


def evaluate(net, validate_loader, loss_function, accu_function = BinaryF1Score(), Love = False, binary_love = False):
    """
    Function to evaluate the performance of a network on a validation data loader.

    Inputs:
    - net: Pytorch network that will be evaluated.
    - validate_loader: Validation (or test) dataset with which the network will be evaluated.
    - loss_function: Loss function used to evaluate the network.
    - accu_function: Accuracy function used to evaluate the network.
    - Love: Boolean to indicate if working with the LoveDA dataset.
    - binary_love: Boolean to indicate if working with only one class of the LoveDA dataset.

    Output:
    - metric: List with the mean accuracy and mean loss calculated for the validation/test dataset.
    """

    net.eval() # Set the model to evaluation mode
    device = next(iter(net.parameters())).device # Get training device ("cuda" or "cpu")

    f1_scores = []
    losses = []

    with torch.no_grad():
        # Iterate over the validation loader to get the mean accuracy and mean loss
        for i, Data in enumerate(validate_loader):

            # The inputs and GT are obtained differently depending on the dataset (LoveDA or our own DS)
            if Love:
                inputs = LOVE_resample_fly(Data['image'])
                GTs = LOVE_resample_fly(Data['mask'])
                if binary_love:
                    GTs = (GTs == 6).long()
            else:
                inputs = Data[0]
                GTs = Data[1]

            inputs = inputs.to(device)
            GTs = GTs.type(torch.long).squeeze().to(device)
            pred = net(inputs)

            f1 = accu_function.to(device)

            # Re-add the batch dimension when squeeze() removed it (batch size of 1)
            if pred.max(1)[1].shape != GTs.shape:
                GTs = GTs[None, :, :]

            loss = loss_function(pred, GTs)/GTs.shape[0]

            f1_score = f1(pred.max(1)[1], GTs)

            f1_scores.append(f1_score.to('cpu').numpy())
            losses.append(loss.to('cpu').numpy())

    metric = [np.mean(f1_scores), np.mean(losses)]

    return metric
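
A minimal usage sketch. The names net and validate_loader are hypothetical placeholders: any trained segmentation network and matching DataLoader (e.g. from get_DataLoaders) work the same way.

# Hypothetical objects: `net` is a trained segmentation network,
# `validate_loader` a matching validation DataLoader.
mean_f1, mean_loss = evaluate(net, validate_loader, torch.nn.CrossEntropyLoss(), BinaryF1Score())
print('Mean F1:', mean_f1, '| Mean loss:', mean_loss)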

training_loop()

Function to train the neural network via backpropagation.

Params

  • network: The network (U-Net) that will be trained.
  • train_loader: DataLoader with the training dataset.
  • val_loader: DataLoader with the validation dataset.
  • learning_rate: Initial learning rate for training the network.
  • momentum: Momentum used during training.
  • number_epochs: Number of training epochs.
  • loss_function: Function to calculate loss.
  • accu_function: Function to calculate accuracy (Default: BinaryF1Score).
  • Love: Boolean to decide between training with the LoveDA dataset or our own dataset.
  • binary_love: Boolean to indicate if working with only one class of the LoveDA dataset.
  • decay: Factor by which the learning rate is multiplied every 10 epochs (see the sketch after this list).
  • bilinear: Boolean to decide the upscaling method (if True bilinear, if False transpose convolution. Default: True).
  • n_channels: Number of input channels (Default 4 [Planet]).
  • n_classes: Number of classes that will be predicted (Default 2 [binary segmentation]).
  • plot: Boolean to decide if the training progress should be plotted or not.
  • seed: Seed that will be used for the generation of random values.
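
The decay schedule referenced above works as in this minimal sketch (values are illustrative; the actual loop is in the source code below): the learning rate is multiplied by decay at every epoch index divisible by 10, including epoch 0.

lr = 1.0      # illustrative initial learning_rate
decay = 0.75
for epoch in range(30):
    # ... one epoch of training runs at the current lr ...
    if epoch % 10 == 0:
        lr *= decay
# Epoch 0 trains at 1.0; epochs 1-10 train at 0.75; epochs 11-20 at 0.5625; ...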

Outputs

  • best_model: F1-score of the best model trained (calculated on the validation dataset).
  • model_saved: The best model trained.
  • spearman: Spearman correlation between epoch and validation F1-score (a high positive value indicates positive learning).
  • no_learning: Boolean indicating that the best validation F1-score was obtained before any training took place.

Dependencies used

import numpy as np
import torch
import matplotlib.pyplot as plt
from collections import deque
from tqdm import tqdm
from scipy import stats
from torchmetrics.classification import BinaryF1Score

from utils import LOVE_resample_fly, get_training_device

Source code

def training_loop(network, train_loader, val_loader, learning_rate, momentum, number_epochs, loss_function, accu_function = BinaryF1Score(), Love = False, binary_love = False, decay = 0.75, bilinear = True, n_channels = 4, n_classes = 2, plot = True, seed = 8):
    """
    Function to train the Neural Network.

    Input:
    - network: The network that will be trained.
    - train_loader: DataLoader with the training dataset.
    - val_loader: DataLoader with the validation dataset.
    - learning_rate: Initial learning rate for training the network.
    - momentum: Momentum used during training.
    - number_epochs: Number of training epochs.
    - loss_function: Function to calculate loss.
    - accu_function: Function to calculate accuracy (Default: BinaryF1Score).
    - Love: Boolean to decide between training with the LoveDA dataset or our own dataset.
    - binary_love: Boolean to indicate if working with only one class of the LoveDA dataset.
    - decay: Factor by which the learning rate is multiplied every 10 epochs.
    - bilinear: Boolean to decide the upscaling method (if True bilinear, if False transpose convolution. Default: True).
    - n_channels: Number of input channels (Default 4 [Planet]).
    - n_classes: Number of classes that will be predicted (Default 2 [binary segmentation]).
    - plot: Boolean to decide if the training progress should be plotted or not.
    - seed: Seed that will be used for the generation of random values.

    Output:
    - best_model: F1-score of the best model trained (calculated on the validation dataset).
    - model_saved: The best model trained.
    - spearman: Spearman correlation between epoch and validation F1-score (a high positive value indicates positive learning).
    - no_learning: Boolean indicating that the best validation F1-score was obtained before any training took place.
    """

    device = get_training_device()

    # Fix all random seeds for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    network.to(device)
    optimizer = torch.optim.SGD(network.parameters(), lr=learning_rate, momentum = momentum, weight_decay=1e-4)

    # Training metrics are computed as a running average over the last len(train_loader) batches
    loss_train = deque(maxlen=len(train_loader))
    accuracy_train = deque(maxlen=len(train_loader))

    val_eps = []
    val_f1s = []
    val_loss = []

    train_eps = []
    train_f1s = []
    train_loss = []

    for epoch in tqdm(range(number_epochs), desc = 'Training model'):

        # Validation phase 0 (before the training phase of this epoch):
        metric_val = evaluate(network, val_loader, loss_function, accu_function, Love, binary_love)

        val_eps.append(epoch)
        val_f1s.append(metric_val[0])
        val_loss.append(metric_val[1])

        # Training phase:
        network.train() # Indicate to the network that we enter training mode

        for i, Data in enumerate(train_loader): # Iterate over the training dataset and do the backward propagation.

            # The inputs and GT are obtained differently depending on the dataset (LoveDA or our own DS)
            if Love:
                inputs = LOVE_resample_fly(Data['image'])
                GTs = LOVE_resample_fly(Data['mask'])
                if binary_love:
                    GTs = (GTs == 6).long()
            else:
                inputs = Data[0]
                GTs = Data[1]

            inputs = inputs.to(device)
            GTs = GTs.type(torch.long).squeeze().to(device)

            # Set the gradients of the model to 0.
            optimizer.zero_grad()
            # Get predictions
            pred = network(inputs)

            # Re-add the batch dimension when squeeze() removed it (batch size of 1)
            if pred.max(1)[1].shape != GTs.shape:
                GTs = GTs[None, :, :]

            loss = loss_function(pred, GTs)

            accu = accu_function.to(device)
            accu_ = accu(pred.max(1)[1], GTs)

            loss.backward()

            optimizer.step()

            loss_train.append(loss.item()/GTs.shape[0])
            accuracy_train.append(accu_.item())

            train_eps.append(epoch + i/len(train_loader))
            train_f1s.append(np.mean(accuracy_train))
            train_loss.append(np.mean(loss_train))

        # Validation phase 1 (after the training phase of this epoch):
        metric_val = evaluate(network, val_loader, loss_function, accu_function, Love, binary_love)
        print(epoch + 1, metric_val)

        val_eps.append(epoch + 1)
        val_f1s.append(metric_val[0])
        val_loss.append(metric_val[1])

        # Keep (and save) the model with the best validation F1-score so far
        if epoch == 0:
            best_model = metric_val[0]
            torch.save(network, 'BestModel.pt')
            model_saved = network
        elif best_model < metric_val[0]:
            best_model = metric_val[0]
            torch.save(network, 'BestModel.pt')
            model_saved = network

        if epoch % 10 == 0:
            # Every 10 epochs, reduce the learning rate by a factor of `decay`
            optimizer.param_groups[0]['lr'] *= decay

    if plot:
        fig, ax = plt.subplots(1, 1, figsize = (7,5))

        ax.plot(train_eps, train_f1s, label = 'Training F1-Score', ls = '--', color = 'r')
        ax.plot(train_eps, train_loss, label = 'Training Loss', ls = '-', color = 'r')

        ax.plot(val_eps, val_f1s, label = 'Validation F1-Score', ls = '--', color = 'b')
        ax.plot(val_eps, val_loss, label = 'Validation Loss', ls = '-', color = 'b')

        # Annotate the best validation F1-score
        ax.text(val_eps[np.argmax(val_f1s)], np.max(val_f1s), str(np.max(val_f1s)))

        ax.set_xlabel("Epoch")

        plt.legend()

        fig.savefig('TrainingLoop.png', dpi = 200)

        plt.close()

    spearman = stats.spearmanr(val_eps, val_f1s)[0]

    # Flag runs where the best validation score was obtained before any training
    no_learning = val_eps[np.argmax(val_f1s)] == 0

    return best_model, model_saved, spearman, no_learning
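
A usage sketch with illustrative hyperparameters (they mirror the values set in run_DomainOnly at the end of this page; the positional UNet arguments are assumed to follow the network_args order documented for train_3fold_DomainOnly below):

# `train_loader` and `val_loader` are assumed to come from get_DataLoaders.
network = UNet(4, 2, True, 16, 4)  # assumed order: n_channels, n_classes, bilinear, starter_channels, up_layer
best_f1, best_net, spearman, no_learning = training_loop(network, train_loader, val_loader,
                                                         learning_rate = 1, momentum = 0.2,
                                                         number_epochs = 30,
                                                         loss_function = FocalLoss(gamma = 2))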

train_3fold_DomainOnly()

Function to run the domain-only training on all three folds of the cross-validation.

Params

  • domain: String with the prefix of the domain to use for training. (Can be either Tanzania or IvoryCoast)
  • DS_args: List with all the arguments related to the dataset itself (e.g. batch_size, transforms, normalization and use of vegetation indices). The argument lists are unpacked positionally, as shown in the sketch after this list.
  • network_args: List with arguments used for the network creation (n_classes, bilinear, starter channels, up_layer).
  • training_loop_args: List with all the arguments needed to run the training loop (for more information check the training_loop function).
  • eval_args: List with arguments to evaluate the trained network on the test dataset.
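
For illustration, a sketch of how these lists map onto positional arguments; the values mirror the defaults set in run_DomainOnly at the end of this page and are not prescriptive.

DS_args = [4, get_transforms(), 'Linear_1_99', True, False, None, None]  # batch_size, transforms, normalization, VI, DA, ...
network_args = [2, True, 16, 4, True, False]                             # n_classes, bilinear, starter_channels, up_layer, attention, resunet
training_args = [1, 0.2, 30, FocalLoss(gamma = 2)]                       # learning_rate, momentum, number_epochs, loss_function
eval_args = [FocalLoss(gamma = 2), BinaryF1Score()]                      # loss_function, accu_function

stats = train_3fold_DomainOnly('Tanzania', DS_args, network_args, training_args, eval_args)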

Outputs

  • stats: List with the mean and standard deviation of the validation and test F1-scores over the three folds.

Dependencies used

import time
import torch
import numpy as np

from Dataset.ReadyToTrain_DS import get_DataLoaders

Source code

def train_3fold_DomainOnly(domain, DS_args, network_args, training_loop_args, eval_args):
    """
    Function to run all Domain Only training for the three folds.

    Input:
    - domain: String with the prefix of the domain to use for training. (Can be either Tanzania or IvoryCoast)
    - DS_args: List with all the arguments related to the dataset itself (e.g. batch_size, transforms, normalization and use of vegetation indices)
    - network_args: List with arguments used for the network creation (n_classes, bilinear, starter channels, up_layer)
    - training_loop_args: List with all the arguments needed to run the training loop (for more information check the training_loop function.)
    - eval_args: List with arguments to evaluate the trained network on the test dataset.

    Output:
    - stats: Mean and standard deviation of the validation and test accuracy values for the domain only training on the three folds.
    """

    folds = 3

    fscore = []

    # For 3-fold cross-validation
    for i in range(folds):

        # Build dataloaders
        print("Creating dataloaders...")
        train_loader, val_loader, test_loader = get_DataLoaders(domain+'Split'+str(i+1), *DS_args)
        print("Dataloaders created.\n")

        n_channels = next(enumerate(train_loader))[1][0].shape[1] # Get band number from actual data
        n_classes = 2

        # Define the network
        network = UNet(n_channels, *network_args)

        # Train the model
        print("Starting training...")
        start = time.time()
        f1_val, network_trained, spearman, no_L = training_loop(network, train_loader, val_loader, *training_loop_args)
        print("Network trained. Took ", round(time.time() - start, 0), 's\n')

        # Keep (and save) the model with the best validation F1-score over the folds
        if i == 0:
            best_network = network_trained
            torch.save(best_network, 'OverallBestModel'+domain+'.pt')
            best_f1 = f1_val
        elif f1_val > best_f1:
            best_network = network_trained
            torch.save(best_network, 'OverallBestModel'+domain+'.pt')
            best_f1 = f1_val

        # Evaluate the model on the test dataset
        f1_test, loss_test = evaluate(network_trained, test_loader, *eval_args)

        print("F1_Validation:", f1_val)
        print("F1_Test: ", f1_test)

        fscore.append([f1_val, f1_test])

    # Mean and standard deviation of the validation and test F1-scores over the three folds
    mean = np.mean(fscore, axis = 0)
    std = np.std(fscore, axis = 0)

    stats = [mean, std]

    return stats

train_LoveDA_DomainOnly()

Function to train the domain only models for the LoveDA dataset.

Params

  • domain: List with the scene parameter for the LoveDA dataset. It can include 'rural' and/or 'urban'.
  • DS_args: List with all the arguments related to the dataset itself (e.g. batch_size, transforms).
  • network_args: List with arguments used for the network creation (n_classes, bilinear, starter channels, up_layer).
  • training_loop_args: List with all the arguments needed to run the training loop (for more information check the training_loop function).

Outputs

  • validation_accuracy: Accuracy score on the validation dataset.
  • network_trained: The trained neural network.

Dependencies used

from Dataset.ReadyToTrain_DS import get_LOVE_DataLoaders

Source code

def train_LoveDA_DomainOnly(domain, DS_args, network_args, training_loop_args):
    """
    Function to train the domain only models for the LoveDA dataset.

    Inputs:
    - domain: List with the scene parameter for the LoveDA dataset. It can include 'rural' and/or 'urban'.
    - DS_args: List with all the arguments related to the dataset itself (e.g. batch_size, transforms)
    - network_args: List with arguments used for the network creation (n_classes, bilinear, starter channels, up_layer)
    - training_loop_args: List with all the arguments needed to run the training loop (for more information check the training_loop function.)

    Outputs:
    - validation_accuracy: Accuracy score on the validation dataset.
    - network_trained: The trained neural network.
    """

    # Get DataLoaders
    train_loader, val_loader, test_loader = get_LOVE_DataLoaders(domain, *DS_args)

    # Get the number of channels from the actual data
    n_channels = next(enumerate(train_loader))[1]['image'].shape[1]

    # Define the network
    network = UNet(n_channels, *network_args)

    # Train the network
    accu_val, network_trained, spearman, no_l = training_loop(network, train_loader, val_loader, *training_loop_args)

    return accu_val, network_trained
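
A usage sketch. The exact contents of DS_args depend on the signature of get_LOVE_DataLoaders, so batch_size and transforms are assumptions here; Love = True and binary_love = True are passed through training_loop_args so that the loop reads the LoveDA batch format.

DS_args = [4, get_transforms()]  # assumed: batch_size, transforms
network_args = [2, True, 16, 4]  # n_classes, bilinear, starter_channels, up_layer
training_args = [1, 0.2, 30, FocalLoss(gamma = 2), BinaryF1Score(), True, True]  # ..., Love = True, binary_love = True

accu_val, net = train_LoveDA_DomainOnly(['rural'], DS_args, network_args, training_args)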

run_DomainOnly()

Aggregating function to perform the whole training routine for one of the domains.

Params

  • domain: String with the name of the domain of interest ('Tanzania' or 'IvoryCoast')

Outputs

  • None. The cross-validation statistics are printed and the per-fold accuracies are plotted with plot_3fold_accuracies().

Source code

def run_DomainOnly(domain = 'IvoryCoast'):
    """
    Function to perform the whole training routine for one of the domains.
    """

    ## Related to DS
    batch_size = 4
    transforms = get_transforms()
    normalization = 'Linear_1_99'
    VI = True
    DA = False

    ## Related to the network
    n_classes = 2
    bilinear = True
    starter_channels = 16
    up_layer = 4
    attention = True
    resunet = False

    ## Related to training and evaluation
    number_epochs = 30
    learning_rate = 1
    momentum = 0.2
    loss_function = FocalLoss(gamma = 2)
    accu_function = BinaryF1Score()
    device = get_training_device()

    DS_args = [batch_size, transforms, normalization, VI, DA, None, None]
    network_args = [n_classes, bilinear, starter_channels, up_layer, attention, resunet]
    training_args = [learning_rate, momentum, number_epochs, loss_function]
    eval_args = [loss_function, accu_function]

    Stats = train_3fold_DomainOnly(domain, DS_args, network_args, training_args, eval_args)

    print(Stats)

    plot_3fold_accuracies(domain, Stats)
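
Usage is a single call; only the domain name changes:

# Run the full domain-only routine for the Tanzania domain
run_DomainOnly(domain = 'Tanzania')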