Source code for lit_ecology_classifier.data.tardataset

import io
import json
import logging
import os
import pprint
import random
import tarfile
from collections import defaultdict
from typing import Any

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.v2 import AugMix, Compose, Normalize, RandomHorizontalFlip, RandomRotation, Resize, ToDtype, ToImage

from ..helpers.helpers import define_priority_classes, define_rest_classes
from typing import Any


class TarImageDataset(Dataset):
    """
    A Dataset subclass for accessing image data stored in a tar archive.

    Images are indexed once at construction time and decoded lazily in
    ``__getitem__``. Optional Test Time Augmentation (TTA) returns several
    fixed rotations of each image instead of a single tensor.

    Attributes:
        tar_path (str): Path to the tar file containing image data.
        class_map (dict): Mapping from class name to label.
        priority_classes (list): When non-empty, class names missing from
            ``class_map`` are mapped to label 0 instead of raising.
        rest_classes (list): When non-empty and ``train`` is set, only samples
            whose parent directory name is in this list are kept.
        train (bool): Selects the training transforms and makes
            ``__getitem__`` return ``(image, label)`` instead of ``image``.
        TTA (bool): Whether Test Time Augmentation is applied.
        image_infos (list): ``tarfile.TarInfo`` entries of the indexed images.
    """

    # Dotted suffixes so that names merely ending in "jpg"/"png" (without an
    # extension separator) are not mistaken for images.
    _IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, tar_path: str, class_map: dict, priority_classes: list, rest_classes: list, TTA: bool = False, train: bool = False):
        """
        Initializes the dataset and indexes the images found in the archive.

        Args:
            tar_path (str): The file path to the tar archive containing the images.
            class_map (dict): Mapping from class name to label.
            priority_classes (list): Priority class names; when non-empty,
                unknown class names are mapped to label 0.
            rest_classes (list): Class names to keep when training; an empty
                list (or None) disables the filtering.
            TTA (bool): A flag to enable Test Time Augmentation.
            train (bool): A flag to indicate if the dataset is used for training.
        """
        self.tar_path = tar_path
        self.TTA = TTA
        self.class_map = class_map
        self.train = train
        self.priority_classes = priority_classes
        self.rest_classes = rest_classes

        # Transformation sequences for training and validation/testing
        self._define_transforms()

        # Load image information from the tar file
        self.image_infos = self._load_image_infos()

        # Truthiness (rather than ``!= []``) so that None also means "no filter".
        if rest_classes and train:
            self._filter_rest_classes()

    def _filter_rest_classes(self) -> None:
        """
        Removes samples whose class is not in ``rest_classes`` from the dataset.

        The class of a sample is the name of the directory that directly
        contains the image inside the archive.
        """
        logging.info(f"Filtering dataset to keep only classes in {self.rest_classes}")
        # Build the lookup once instead of scanning the list for every sample.
        wanted = set(self.rest_classes)
        self.image_infos = [
            info
            for info in self.image_infos
            if os.path.basename(os.path.dirname(info.name)) in wanted
        ]
        logging.info(f"Filtered dataset to {len(self.image_infos)} samples.")

    def _define_transforms(self) -> None:
        """
        Builds the training and validation pipelines and, if TTA is enabled,
        the fixed rotations keyed by their angle in degrees (as strings).
        """
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # ImageNet mean and std
        self.train_transforms = Compose(
            [
                ToImage(),
                RandomHorizontalFlip(),
                RandomRotation(180),
                AugMix(),
                Resize((224, 224)),
                ToDtype(torch.float32, scale=True),
                Normalize(mean, std),
            ]
        )
        self.val_transforms = Compose(
            [
                ToImage(),
                Resize((224, 224)),
                ToDtype(torch.float32, scale=True),
                Normalize(mean, std),
            ]
        )
        if self.TTA:
            # Each range is degenerate (min == max), so the "random" rotation is
            # deterministic. The range must be passed as a single tuple: a
            # second positional argument would be taken as the interpolation.
            self.rotations = {
                "0": Compose([RandomRotation((0, 0))]),
                "90": Compose([RandomRotation((90, 90))]),
                "180": Compose([RandomRotation((180, 180))]),
                "270": Compose([RandomRotation((270, 270))]),
            }

    def __len__(self) -> int:
        """
        Returns the total number of images in the dataset.

        Returns:
            int: The total number of images.
        """
        return len(self.image_infos)

    def __getitem__(self, idx: int) -> Any:
        """
        Retrieves an image (and, when training, its label) by index.

        Args:
            idx (int): The index of the image.

        Returns:
            ``(image, label)`` when ``train`` is set, otherwise ``image`` alone.
            With TTA enabled, ``image`` is a dict mapping the rotation angle
            (``"0"``, ``"90"``, ``"180"``, ``"270"``) to the transformed image;
            otherwise it is a single transformed image.
        """
        image_info = self.image_infos[idx]
        # The archive is opened per call, so no file handle is kept on the instance.
        with tarfile.open(self.tar_path, "r") as tar:
            image_file = tar.extractfile(image_info)
            image = Image.open(io.BytesIO(image_file.read())).convert("RGB")

        # TTA takes precedence over the train/validation choice of transforms.
        if self.TTA:
            image = {rot: self.val_transforms(self.rotations[rot](image)) for rot in self.rotations}
        elif self.train:
            image = self.train_transforms(image)
        else:
            image = self.val_transforms(image)

        if self.train:
            label = self.get_label_from_filename(image_info.name)
            return image, label
        return image

    def _load_image_infos(self) -> list:
        """
        Loads image information from the tar file.

        Returns:
            list: The ``tarfile.TarInfo`` members that are regular files with a
            ``.jpg``, ``.jpeg`` or ``.png`` extension (case-insensitive), in
            archive order.
        """
        with tarfile.open(self.tar_path, "r") as tar:
            return [
                member
                for member in tar.getmembers()
                if member.isfile() and member.name.lower().endswith(self._IMAGE_EXTENSIONS)
            ]

    def get_label_from_filename(self, filename):
        """
        Extracts the label from a given filename.

        The class name is the second ``/``-separated component of the path,
        i.e. paths are expected to look like ``<root>/<class>/<file>``.

        Args:
            filename (str): The filename from which to extract the label.

        Returns:
            The ``class_map`` entry for the class. When ``priority_classes`` is
            not an empty list, classes missing from ``class_map`` yield 0.

        Raises:
            IndexError: If ``filename`` contains no ``/``.
            KeyError: If ``priority_classes`` is empty and the class is not in
                ``class_map``.
        """
        label = filename.split("/")[1]
        if self.priority_classes != []:
            label = self.class_map.get(label, 0)
        else:
            label = self.class_map[label]
        return label

    def shuffle(self):
        """
        Shuffles the list of image information in place to randomize data access.
        """
        random.shuffle(self.image_infos)