Source code for lit_ecology_classifier.helpers.helpers

import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn

import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, ModelSummary, StochasticWeightAveraging
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from torch import nn
from torch.autograd import Variable
import tarfile
import os

class FocalLoss(nn.Module):
    """
    Focal loss computed from raw class logits.

    Per sample the loss is ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is
    the softmax probability of the target class, optionally multiplied by a
    per-class weight looked up in ``alpha``.

    Args:
        gamma (float): Focusing exponent. With ``gamma=0`` the modulating
            factor is 1 and the loss is the (alpha-weighted) cross-entropy.
        alpha (float | int | list | torch.Tensor | None): Per-class weights.
            A scalar ``a`` is expanded to ``[a, 1 - a]``; a list is converted
            to a tensor; ``None`` disables class weighting.
        size_average (bool): Return the mean over samples if True, otherwise
            the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, list):
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """
        Compute the focal loss.

        Args:
            input (torch.Tensor): Logits with the class dimension at index 1.
                Inputs with more than two dimensions are flattened so that
                every trailing position becomes its own sample.
            target (torch.Tensor): Integer class indices, one per sample
                (flattened to match the flattened input).

        Returns:
            torch.Tensor: Scalar loss (mean or sum, see ``size_average``).
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        # reshape (not view) so non-contiguous targets are accepted as well.
        target = target.reshape(-1, 1)

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # The modulating factor is deliberately detached: gradients flow only
        # through log(p_t), exactly as in the previous Variable/.data version.
        pt = logpt.detach().exp()

        if self.alpha is not None:
            # Keep the class weights on the same device/dtype as the logits.
            if self.alpha.dtype != input.dtype or self.alpha.device != input.device:
                self.alpha = self.alpha.to(device=input.device, dtype=input.dtype)
            at = self.alpha.gather(0, target.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean() if self.size_average else loss.sum()
def output_results(outpath, im_names, labels, scores, priority_classes=False, rest_classes=False, tar_file=False):
    """
    Write one line per image with its predicted label and score to a text file.

    The file is ``{outpath}/predictions_lit_ecology_classifier[_priority][_rest].txt``
    and is overwritten if it already exists.

    Args:
        outpath (str): Output directory path.
        im_names (list): Image identifiers. When ``tar_file`` is true, each
            entry's ``.name`` attribute is written instead of the entry itself.
        labels (list | array-like): Predicted labels. Objects exposing
            ``tolist()`` (e.g. NumPy arrays, tensors) are converted first;
            plain sequences are used as they are.
        scores (iterable): Prediction scores, one per image.
        priority_classes (bool): Append ``_priority`` to the file name.
        rest_classes (bool): Append ``_rest`` to the file name.
        tar_file (bool): Treat ``im_names`` entries as objects with a ``.name``.
    """
    # The docstring has always advertised list input, but an unconditional
    # ``labels.tolist()`` raised AttributeError for plain lists.
    if hasattr(labels, "tolist"):
        labels = labels.tolist()

    base_filename = (
        f"{outpath}/predictions_lit_ecology_classifier"
        + ("_priority" if priority_classes else "")
        + ("_rest" if rest_classes else "")
    )
    file_path = f"{base_filename}.txt"

    if tar_file:
        im_names = [img.name for img in im_names]

    lines = [f"{img}------------------ {label}/{score}\n" for img, label, score in zip(im_names, labels, scores)]
    with open(file_path, "w") as f:
        f.writelines(lines)
def gmean(input_x, dim):
    """
    Geometric mean of a tensor along one dimension.

    Evaluated as ``exp(mean(log(x)))`` over ``dim``.

    Args:
        input_x (torch.Tensor): Input tensor.
        dim (int): Dimension to reduce.

    Returns:
        torch.Tensor: Geometric mean along ``dim``.
    """
    mean_of_logs = torch.log(input_x).mean(dim=dim)
    return mean_of_logs.exp()
def plot_confusion_matrix(all_labels, all_preds, class_names):
    """
    Plot and return confusion matrices (absolute and normalized).

    The normalized matrix uses ``normalize="pred"``, i.e. it is normalized
    over the predicted-label axis.

    Args:
        all_labels (torch.Tensor): True labels.
        all_preds (torch.Tensor): Predicted labels.
        class_names (list): Class names; entry ``i`` labels class index ``i``.

    Returns:
        tuple: (figure for absolute confusion matrix, figure for normalized confusion matrix)
    """
    # A bare ``import sklearn`` is not guaranteed to expose the ``metrics``
    # submodule on every scikit-learn version, so import it explicitly here.
    from sklearn import metrics as sk_metrics

    class_indices = np.arange(len(class_names))
    y_true = all_labels.cpu()
    y_pred = all_preds.cpu()

    # Passing ``labels`` pins both matrices to len(class_names) x len(class_names),
    # so no size-mismatch handling against ``class_names`` is needed afterwards.
    confusion_matrix = sk_metrics.confusion_matrix(y_true, y_pred, labels=class_indices)
    confusion_matrix_norm = sk_metrics.confusion_matrix(y_true, y_pred, normalize="pred", labels=class_indices)

    fig, ax = plt.subplots(figsize=(20, 20))
    fig2, ax2 = plt.subplots(figsize=(20, 20))

    cm_display = sk_metrics.ConfusionMatrixDisplay(confusion_matrix, display_labels=class_names)
    cm_display_norm = sk_metrics.ConfusionMatrixDisplay(confusion_matrix_norm, display_labels=class_names)

    cmap = cvd_colormap()
    cm_display.plot(cmap=cmap, ax=ax, xticks_rotation=90)
    cm_display_norm.plot(cmap=cmap, ax=ax2, xticks_rotation=90)

    fig.tight_layout()
    fig2.tight_layout()
    return fig, fig2
def cvd_colormap():
    """
    Build a colormap accessible for people with color vision deficiency (CVD).

    Returns:
        LinearSegmentedColormap: Colormap named ``'CustomMap'`` with 255 levels,
        interpolated between nine evenly spaced color stops.
    """
    stops = [0.0000, 0.1250, 0.2500, 0.3750, 0.5000, 0.6250, 0.7500, 0.8750, 1.0000]
    channels = {
        'red': [0.2082, 0.0592, 0.0780, 0.0232, 0.1802, 0.5301, 0.8186, 0.9956, 0.9764],
        'green': [0.1664, 0.3599, 0.5041, 0.6419, 0.7178, 0.7492, 0.7328, 0.7862, 0.9832],
        'blue': [0.5293, 0.8684, 0.8385, 0.7914, 0.6425, 0.4662, 0.3499, 0.1968, 0.0539],
    }

    # Each segment entry is (position, value_below, value_above); the two
    # values are equal, so every channel is continuous at each stop.
    cdict = {
        channel: [(stop, level, level) for stop, level in zip(stops, levels)]
        for channel, levels in channels.items()
    }

    return LinearSegmentedColormap('CustomMap', segmentdata=cdict, N=255)
class CosineWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
    """
    Learning rate scheduler with cosine annealing and linear warmup.

    Args:
        optimizer (torch.optim.Optimizer): Wrapped optimizer.
        warmup (int): Number of warmup steps. ``0`` disables warmup.
        max_iters (int): Total number of iterations; must be non-zero because
            it is used as a divisor.

    Methods:
        get_lr: Compute the learning rate at the current step.
        get_lr_factor: Compute the learning rate factor at the current step.
    """

    def __init__(self, optimizer, warmup, max_iters):
        # Both attributes must exist before the base constructor runs, because
        # it performs an initial step that calls get_lr().
        self.warmup = warmup
        self.max_num_iters = max_iters
        super().__init__(optimizer)

    def get_lr(self):
        """Return each base learning rate scaled by the current factor."""
        lr_factor = self.get_lr_factor(epoch=self.last_epoch)
        return [base_lr * lr_factor for base_lr in self.base_lrs]

    def get_lr_factor(self, epoch):
        """
        Compute the multiplicative learning-rate factor for a given step.

        Args:
            epoch (int): Current step index.

        Returns:
            float: Cosine factor ``0.5 * (1 + cos(pi * epoch / max_iters))``,
            scaled by ``max_iters / epoch`` once ``epoch >= max_iters`` and by
            ``epoch / warmup`` while ``epoch <= warmup``.
        """
        lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters))
        if epoch >= self.max_num_iters:
            lr_factor *= self.max_num_iters / epoch
        # Guard against warmup == 0, which previously raised ZeroDivisionError
        # during construction (the initial step evaluates this at epoch 0).
        if self.warmup > 0 and epoch <= self.warmup:
            lr_factor *= epoch * 1.0 / self.warmup
        # Return a plain Python float rather than a NumPy scalar.
        return float(lr_factor)
def define_priority_classes(priority_classes):
    """
    Build a class map for priority classes plus a catch-all ``"rest"`` class.

    Args:
        priority_classes (list): Priority class names.

    Returns:
        dict: Each priority class mapped to its 1-based position in
        ``priority_classes``, and ``"rest"`` mapped to 0.
    """
    class_map = {}
    for index, class_name in enumerate(priority_classes, start=1):
        class_map[class_name] = index
    # Index 0 is reserved for everything that is not a priority class.
    class_map["rest"] = 0
    return class_map
def define_rest_classes(priority_classes):
    """
    Build a class map that assigns each class name its 0-based position.

    Args:
        priority_classes (list): Class names to enumerate.

    Returns:
        dict: Class name mapped to its index in ``priority_classes``. For a
        repeated name the last occurrence's index is kept.
    """
    class_map = {}
    for index, class_name in enumerate(priority_classes):
        class_map[class_name] = index
    return class_map
def plot_score_distributions(all_scores, all_preds, class_names, true_label):
    """
    Plot, per class, the distribution of that class's predicted score.

    Every class gets one panel showing two histograms of column ``i`` of
    ``all_scores``: samples whose true label is ``i`` ("signal", crimson
    outline, right axis) and all other samples ("rest", sky blue, left axis).
    Both axes use a log scale.

    Args:
        all_scores (torch.Tensor): Scores indexed as ``[sample, class]``.
        all_preds (torch.Tensor): Predicted class indices. Not used by the
            plot; kept so existing callers do not need to change.
        class_names (list): List of class names.
        true_label (torch.Tensor): True class index of every sample.

    Returns:
        matplotlib.figure.Figure: A single figure with one panel per class,
        laid out four panels per row.
    """
    all_scores = all_scores.cpu().numpy()
    true_label = true_label.cpu().numpy()

    n_cols = 4
    # Ceiling division so an exact multiple of four gets no empty extra row;
    # at least one row keeps plt.subplots valid for an empty class list.
    n_rows = max(1, -(-len(class_names) // n_cols))
    # Five inches per row. The earlier ``len // 4 * 5 + 1`` bound tighter than
    # intended and produced a 1-inch-tall figure for fewer than four classes.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows))
    ax = ax.flatten()

    bins = np.linspace(0, 1, 30)
    for i, class_name in enumerate(class_names):
        is_signal = true_label == i
        sig_scores = all_scores[is_signal][:, i]
        bkg_scores = all_scores[~is_signal][:, i]

        ax[i].hist(bkg_scores, bins=bins, color="skyblue", edgecolor="black")
        ax[i].set_ylabel("Rest Counts", color="skyblue")
        ax[i].set_yscale("log")

        # Signal goes on a twin axis so its counts get their own scale.
        y_axis = ax[i].twinx()
        y_axis.hist(sig_scores, bins=bins, color="crimson", histtype="step", edgecolor="crimson")
        ax[i].set_title(f"{class_name}")
        ax[i].set_xlabel("Predicted Probability")
        y_axis.set_ylabel("Signal Counts", color="crimson")
        y_axis.set_yscale("log")

    # Hide grid cells that have no class assigned.
    for unused_axis in ax[len(class_names):]:
        unused_axis.set_visible(False)

    fig.tight_layout()
    return fig
def TTA_collate_fn(batch: list, train=False):
    """
    Collate function for test time augmentation (TTA).

    Args:
        batch (list): Sequence of samples. With ``train=True`` each sample is
            a ``(rotated_images, label)`` pair; otherwise each sample is the
            ``rotated_images`` mapping itself. ``rotated_images`` maps the
            keys ``"0"``, ``"90"``, ``"180"`` and ``"270"`` to image tensors.
        train (bool): Whether samples carry labels.

    Returns:
        dict | tuple: A dict mapping each rotation key to the stacked image
        tensors of the batch. With ``train=True`` a tuple of that dict and a
        tensor holding the labels.
    """
    rotations = ["0", "90", "180", "270"]

    # Separate the image mappings from the labels in a single pass.
    if train:
        samples = []
        labels = []
        for rotated_images, label in batch:
            samples.append(rotated_images)
            labels.append(label)
    else:
        samples = list(batch)

    batch_images = {rot: torch.stack([sample[rot] for sample in samples]) for rot in rotations}

    if train:
        return batch_images, torch.tensor(labels)
    return batch_images
def plot_loss_acc(logger):
    """
    Plot training/validation loss and accuracy from the logger's metrics file.

    Reads ``{save_dir}/{name}/version_{version}/metrics.csv``, which must
    contain the columns ``step``, ``train_loss``, ``val_loss``, ``train_acc``
    and ``val_acc``.

    Args:
        logger (Logger): Object providing ``save_dir``, ``name`` and ``version``.

    Saves:
        loss_accuracy.png: Loss and accuracy over steps, written next to the
        metrics file.
    """
    run_dir = f"{logger.save_dir}/{logger.name}/version_{logger.version}"
    metrics = pd.read_csv(f"{run_dir}/metrics.csv")

    step = metrics["step"]
    train_loss = metrics["train_loss"]
    val_loss = metrics["val_loss"]
    train_acc = metrics["train_acc"]
    val_acc = metrics["val_acc"]

    # Select the rows where each loss was actually logged (non-NaN).
    # NOTE(review): accuracies are filtered with the matching *loss* mask, as
    # before; this presumes loss and accuracy share rows - confirm.
    has_train = train_loss.notna()
    has_val = val_loss.notna()

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))

    ax[0].plot(step[has_train], train_loss[has_train], label="Training Loss", color="skyblue")
    ax[0].plot(step[has_val], val_loss[has_val], label="Validation Loss", color="crimson")
    ax[0].set_xlabel("Step")
    ax[0].set_ylabel("Loss")
    ax[0].set_title("Loss vs Steps")
    ax[0].legend()

    ax[1].plot(step[has_train], train_acc[has_train], label="Training Accuracy", color="skyblue")
    ax[1].plot(step[has_val], val_acc[has_val], label="Validation Accuracy", color="crimson")
    ax[1].set_xlabel("Step")
    ax[1].set_ylabel("Accuracy")
    ax[1].set_title("Accuracy vs Steps")
    ax[1].legend()

    fig.tight_layout()
    # Save through the figure itself instead of pyplot's implicit current
    # figure, then close it so repeated calls do not accumulate open figures.
    fig.savefig(f"{run_dir}/loss_accuracy.png")
    plt.close(fig)
def setup_callbacks(priority_classes, ckpt_name):
    """
    Build the list of callbacks for the training process.

    Without priority classes, checkpoints are ranked by ``val_acc``; with
    priority classes they are ranked by ``val_precision`` and the file name
    additionally carries ``val_false_positives``. The five best checkpoints
    (highest monitored value) are kept.

    Args:
        priority_classes (list): List of priority classes; only whether it is
            empty matters here.
        ckpt_name (str): Prefix of the checkpoint file name.

    Returns:
        list: ``[ModelCheckpoint, ModelSummary]``.
    """
    if len(priority_classes) == 0:
        filename = ckpt_name + "-{epoch:02d}-{val_acc:.4f}"
        monitor = "val_acc"
    else:
        filename = ckpt_name + "-{epoch:02d}-{val_acc:.4f}-{val_false_positives:.4f}"
        monitor = "val_precision"

    checkpoint = ModelCheckpoint(filename=filename, monitor=monitor, mode="max", save_top_k=5)
    return [checkpoint, ModelSummary()]
def plot_reduced_classes(model, priority_classes):
    """
    Plot the confusion matrix for the reduced class set.

    Every prediction and target is mapped to its priority class, or to
    ``"rest"`` (index 0) when its class name is not a priority class.

    Args:
        model (LightningModule): The trained model. Must provide
            ``test_step_predictions`` and ``test_step_targets`` (sequences of
            tensors), ``inverted_class_map`` (class index -> class name) and
            ``outpath``.
        priority_classes (list): List of priority classes.

    Saves:
        reduced_confusion_matrix.png: A confusion matrix of the reduced classes.
        reduced_confusion_matrix_norm.png: A normalized confusion matrix of the reduced classes.
    """
    reduced_class_map = {name: idx + 1 for idx, name in enumerate(priority_classes)}
    reduced_class_map["rest"] = 0

    # Names ordered by reduced index, so entry i labels row/column i. Passing
    # the index->name dict itself made the tick labels its integer keys, in
    # insertion order (1..n, 0), i.e. misaligned with the matrix.
    inv_reduced_class_map = {idx: name for name, idx in reduced_class_map.items()}
    reduced_class_names = [inv_reduced_class_map[idx] for idx in sorted(inv_reduced_class_map)]

    preds = torch.cat(model.test_step_predictions)
    true_labels = torch.cat(model.test_step_targets)

    reduced_preds = []
    reduced_labels = []
    for pred, true in zip(preds, true_labels):
        pred_name = model.inverted_class_map[pred.item()]
        true_name = model.inverted_class_map[true.item()]
        # Anything that is not a priority class collapses into "rest" (0).
        reduced_preds.append(reduced_class_map.get(pred_name, 0))
        reduced_labels.append(reduced_class_map.get(true_name, 0))

    all_preds = torch.tensor(reduced_preds)
    all_labels = torch.tensor(reduced_labels)

    fig, fig2 = plot_confusion_matrix(all_labels, all_preds, reduced_class_names)
    fig.savefig(f"{model.outpath}/reduced_confusion_matrix.png")
    fig2.savefig(f"{model.outpath}/reduced_confusion_matrix_norm.png")
    # The figures are not returned, so release them once written to disk.
    plt.close(fig)
    plt.close(fig2)
def setup_classmap(datapath="", priority_classes=None, rest_classes=None):
    """
    Build the class-name -> index map for a run.

    Precedence: a non-empty ``priority_classes`` wins, then a non-empty
    ``rest_classes``; otherwise the map is derived from the folders found at
    ``datapath``.

    Args:
        datapath (str): Path to a tar file or directory with the images. Only
            used when neither class list is given.
        priority_classes (list | None): Priority class names; ``None`` or
            empty means "not given".
        rest_classes (list | None): Rest class names; ``None`` or empty means
            "not given".

    Returns:
        dict: Mapping from class name to class index.
    """
    # ``None`` defaults replace the former mutable ``[]`` defaults; empty and
    # missing lists are treated the same.
    if priority_classes:
        logging.info(f"Priority classes not None. Loading priority classes from {priority_classes}")
        logging.info(f"Priority classes set to: {priority_classes}")
        class_map = define_priority_classes(priority_classes)
    elif rest_classes:
        logging.info(f"rest classes not None. Defining clas map from {rest_classes}")
        class_map = define_rest_classes(rest_classes)
    else:
        logging.info(" Extracting class map from tar file.")
        class_map = _extract_class_map(datapath)

    return class_map
def _extract_class_map(tar_or_dir_path):
    """
    Derive the class map from the folder layout of a tar file or directory.

    A class is the name of any folder that directly contains at least one
    file whose lower-cased name ends with ``jpg``, ``jpeg`` or ``png``. Only
    the folder's own name is used, so equally named folders at different
    locations count as one class.

    Arguments:
        tar_or_dir_path: str
            Path to the tar file or directory containing the images.

    Returns:
        dict
            A dictionary mapping class names (sorted alphabetically) to indices.

    Raises:
        ValueError: If the path is an existing file that is not a tar archive.
    """
    logging.info("Extracting class map.")
    image_suffixes = ("jpg", "jpeg", "png")
    class_names = set()

    # Test for a directory first: tarfile.is_tarfile() tries to open the path
    # and raises for directories, which made the directory branch unreachable.
    if os.path.isdir(tar_or_dir_path):
        logging.info("Detected directory.")
        for root, _, files in os.walk(tar_or_dir_path):
            if any(file.lower().endswith(image_suffixes) for file in files):
                class_names.add(os.path.basename(root))
    elif tarfile.is_tarfile(tar_or_dir_path):
        logging.info("Detected tar file.")
        with tarfile.open(tar_or_dir_path, "r") as tar:
            # One pass suffices: only the names of image-bearing folders are needed.
            for member in tar.getmembers():
                if member.isfile() and member.name.lower().endswith(image_suffixes):
                    class_names.add(os.path.basename(os.path.dirname(member.name)))
    else:
        raise ValueError("Provided path is neither a valid tar file nor a directory.")

    # Sort so the index assignment is deterministic.
    sorted_class_names = sorted(class_names)
    logging.info(f"Found {len(sorted_class_names)} classes.")
    return {class_name: idx for idx, class_name in enumerate(sorted_class_names)}