lit_ecology_classifier.models package

Submodules

lit_ecology_classifier.models.model module

class lit_ecology_classifier.models.model.LitClassifier(**hparams)

Bases: LightningModule

TTA(batch)

Perform Test Time Augmentation (TTA) on the input batch.

Parameters:
  • batch (tuple) – Input batch containing images and labels.

Returns:

The geometric average of the class probabilities over the TTA predictions, and the true labels if the batch is a list containing them as its second entry, else None.

Return type:

tuple(torch.Tensor, torch.Tensor or None)
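The TTA reduction combines the per-augmentation predictions with a geometric average. The following is a minimal, framework-free sketch of that reduction for a single image, assuming each augmented view yields a list of class probabilities. The function name `geometric_mean_probs` and the flip views are illustrative only; they are not taken from the package, which operates on torch tensors.

```python
import math
from typing import List, Sequence


def geometric_mean_probs(tta_probs: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical helper: per-class geometric mean over TTA views.

    tta_probs[k][c] is the probability of class c under augmentation k.
    """
    n_views = len(tta_probs)
    n_classes = len(tta_probs[0])
    eps = 1e-12  # guard against log(0) for classes with zero probability
    # exp(mean(log p)) is the geometric mean of p across views.
    return [
        math.exp(sum(math.log(max(view[c], eps)) for view in tta_probs) / n_views)
        for c in range(n_classes)
    ]


views = [
    [0.7, 0.2, 0.1],  # original image
    [0.6, 0.3, 0.1],  # e.g. horizontal flip (assumed augmentation)
    [0.8, 0.1, 0.1],  # e.g. vertical flip (assumed augmentation)
]
avg = geometric_mean_probs(views)
pred = max(range(len(avg)), key=avg.__getitem__)  # predicted class index
```

Unlike an arithmetic mean, the geometric mean penalises a class that any single view considers unlikely. The result is not renormalised here, so it need not sum to 1; the argmax is unaffected either way.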

configure_optimizers()

Configure optimizers and learning rate schedulers.

Returns:

A list of optimizers and a list of schedulers.

Return type:

tuple(list, list)

forward(x)

Forward pass through the model.

Parameters:
  • x (torch.Tensor) – Input tensor.

Returns:

Model output.

Return type:

torch.Tensor

load_datamodule(datamodule)

Load the data module into the model.

Parameters:
  • datamodule (LightningDataModule) – Data module to load.

on_fit_end() → None

If the model is not using wandb, plot the loss and accuracy curves at the end of training and save them in the output folder.

on_predict_epoch_end() → None

Hook called at the end of the prediction epoch. Saves the predicted labels to a text file in the output folder.
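A sketch of what saving the predicted labels could look like, written without Lightning. The file name `predictions.txt`, the `filename,label` line format, and the helper `save_predictions` are all hypothetical; the actual hook may use a different layout.

```python
import os
import tempfile
from typing import Sequence


def save_predictions(filenames: Sequence[str], labels: Sequence[str], out_dir: str) -> str:
    """Hypothetical helper: write one 'filename,label' line per image, return the path."""
    os.makedirs(out_dir, exist_ok=True)  # create the output folder if missing
    path = os.path.join(out_dir, "predictions.txt")  # assumed file name
    with open(path, "w", encoding="utf-8") as fh:
        for name, label in zip(filenames, labels):
            fh.write(f"{name},{label}\n")
    return path


# Usage: write to a temporary folder and read the file back.
with tempfile.TemporaryDirectory() as tmp:
    path = save_predictions(
        ["img_001.jpg", "img_002.jpg"], ["class_a", "class_b"], os.path.join(tmp, "Output")
    )
    with open(path, encoding="utf-8") as fh:
        lines = fh.read().splitlines()
```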

on_predict_start() → None

Hook for the start of the inference phase.

on_test_epoch_end()

Aggregate the outputs of test_step and log the confusion matrix at the end of the test epoch.

on_test_epoch_start() → None

Hook to be called at the start of the test epoch. Sets up empty lists to store the predicted class probabilities and filenames.
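The test hooks follow an accumulate-then-aggregate pattern: the start hook resets the buffers, each step appends to them, and the end hook reduces them once. The plain-Python class below is a hypothetical stand-in that mirrors that lifecycle without Lightning; `EpochBuffers` and the argmax reduction are assumptions made for illustration, not the actual LitClassifier code.

```python
from typing import List, Sequence


class EpochBuffers:
    """Hypothetical plain-Python stand-in for the test-epoch hooks (no Lightning)."""

    def __init__(self) -> None:
        self.on_test_epoch_start()

    def on_test_epoch_start(self) -> None:
        # Start each epoch with empty buffers for probabilities and filenames.
        self.probabilities: List[List[float]] = []
        self.filenames: List[str] = []

    def test_step(self, batch_probs: Sequence[Sequence[float]], batch_names: Sequence[str]) -> None:
        # Append this batch's per-image class probabilities and filenames.
        self.probabilities.extend(list(p) for p in batch_probs)
        self.filenames.extend(batch_names)

    def on_test_epoch_end(self) -> List[int]:
        # Aggregate after the last batch: predicted class = argmax per image.
        return [max(range(len(p)), key=p.__getitem__) for p in self.probabilities]


buffers = EpochBuffers()
buffers.on_test_epoch_start()
buffers.test_step([[0.9, 0.1], [0.2, 0.8]], ["a.jpg", "b.jpg"])
buffers.test_step([[0.4, 0.6]], ["c.jpg"])
preds = buffers.on_test_epoch_end()
```

Resetting in the start hook matters because the same module instance can be tested more than once; without it, results from an earlier run would leak into the next aggregation.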

on_validation_epoch_end()

Aggregate the outputs of validation_step and log the confusion matrix at the end of the validation epoch.
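Both epoch-end hooks log a confusion matrix built from the labels and predictions gathered over the epoch. A minimal pure-Python sketch of that construction, assuming integer class indices; the function name is illustrative, and the package itself may rely on a library implementation instead.

```python
from typing import List, Sequence


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """Illustrative helper: rows are true classes, columns are predicted classes."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1  # count one sample of true class t predicted as p
    return matrix


# Labels and predictions concatenated from all batches of the epoch.
y_true = [0, 0, 1, 2, 2, 2]
y_pred = [0, 1, 1, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred, num_classes=3)
accuracy = sum(cm[i][i] for i in range(3)) / len(y_true)  # diagonal = correct
```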

on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

predict_step(batch) → None

Perform a prediction step on unlabeled data.

Parameters:
  • batch (tuple) – Input batch containing images.

test_step(batch, batch_idx)

Perform a test step.

Parameters:
  • batch (tuple) – Input batch containing images and filenames.

  • batch_idx (int) – Batch index.

training_step(batch, batch_idx)

Perform a training step.

Parameters:
  • batch (tuple) – Input batch containing images and labels.

  • batch_idx (int) – Batch index.

Returns:

Computed loss for the batch.

Return type:

torch.Tensor

validation_step(batch, batch_idx)

Perform a validation step.

Parameters:
  • batch (tuple) – Input batch containing images and labels.

  • batch_idx (int) – Batch index.

Returns:

Dictionary containing the loss and predictions.

Return type:

dict

lit_ecology_classifier.models.setup_model module

lit_ecology_classifier.models.setup_model.set_trainable_params(model, train_first=False, finetune=True)

Set the trainable parameters of the model.

Parameters:
  • model (nn.Module) – The model to configure.

  • train_first (bool, optional) – If True, train the first layer of the model. Defaults to False.

  • finetune (bool, optional) – If True, finetune the model. Defaults to True.

lit_ecology_classifier.models.setup_model.setup_model(pretrained=False, num_classes=None, checkpoint_path='checkpoints/backbone.safetensors', **kwargs)

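Set up and return the specified model architecture. The parameters listed below do not appear in the explicit signature and are passed through **kwargs.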

Parameters:
  • architecture (str) – The model architecture to use.

  • main_param_path (str) – Path to the directory containing main parameters.

  • ensemble (bool) – Whether to use model ensembling.

  • finetune (bool) – Whether to finetune the model or use it as is.

  • dataset (str) – The name of the dataset.

  • testing (bool, optional) – Set to True if in testing mode. Defaults to False.

  • train_first (bool, optional) – Set to True to train the first layer of the model. Defaults to False.

Returns:

The configured model.

Return type:

model

Module contents