Source code for nlgm.train

import torch
from torch import nn
from tqdm.auto import tqdm


def train_and_evaluate(
    model,
    train_loader,
    test_loader,
    epochs=10,
    device=torch.device("cpu"),
    progress_bar=True,
    *,
    lr=1e-3,
    criterion=None,
):
    """
    Train a model and evaluate on a test loader each epoch.

    Each batch yielded by the loaders is unpacked as ``(inputs, _)``; the
    second element is ignored and the inputs themselves are used as the
    target, i.e. the loss compares ``model(inputs)`` against ``inputs``.

    Parameters
    ----------
    model : nn.Module
        Model to train and evaluate.
    train_loader : iterable
        Training batches, e.g. a ``torch.utils.data.DataLoader``. Only
        iteration is required; ``len()`` is not used.
    test_loader : iterable
        Test batches, e.g. a ``torch.utils.data.DataLoader``.
    epochs : int, default=10
        Number of training epochs.
    device : torch.device, default=torch.device("cpu")
        Device used for training and evaluation.
    progress_bar : bool, default=True
        Whether to show tqdm progress bars (epoch-level and per-epoch
        training batches). When False, no tqdm bar is created.
    lr : float, default=1e-3
        Learning rate passed to ``torch.optim.Adam``.
    criterion : callable, optional
        Loss function called as ``criterion(output, target)``. Defaults to
        ``nn.MSELoss()``.

    Returns
    -------
    list[float]
        Per-epoch training losses (mean of per-batch losses). An epoch over
        an empty training loader reports NaN.
    float
        Final epoch test loss (mean of per-batch losses). NaN when
        ``epochs == 0`` or the test loader yields no batches.
    """
    if criterion is None:
        criterion = nn.MSELoss()

    # torch.optim docs: move the model to its device before constructing
    # the optimizer that references its parameters.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # While a tqdm bar is active, plain print() garbles the bar; tqdm.write
    # prints above it instead.
    log = tqdm.write if progress_bar else print

    train_losses = []
    test_loss = float("nan")  # returned as-is when epochs == 0
    epoch_iter = tqdm(range(epochs), desc="Epochs") if progress_bar else range(epochs)
    for epoch in epoch_iter:
        train_loss = _train_epoch(
            model, train_loader, criterion, optimizer, device, progress_bar
        )
        train_losses.append(train_loss)
        test_loss = _evaluate(model, test_loader, criterion, device)
        log(
            f"Epoch [{epoch + 1}/{epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}"
        )
    return train_losses, test_loss


def _train_epoch(model, loader, criterion, optimizer, device, progress_bar):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss (NaN if empty)."""
    model.train()
    batches = tqdm(loader, desc="Training Loop", leave=False) if progress_bar else loader
    total, n_batches = 0.0, 0
    for images, _ in batches:
        images = images.to(device)
        optimizer.zero_grad()
        # The input batch doubles as the target (reconstruction objective).
        loss = criterion(model(images), images)
        loss.backward()
        optimizer.step()
        total += loss.item()
        n_batches += 1
    # Counting batches (instead of len(loader)) supports unsized iterables
    # and avoids ZeroDivisionError on an empty loader.
    return total / n_batches if n_batches else float("nan")


def _evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``loader`` without gradients (NaN if empty)."""
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for images, _ in loader:
            images = images.to(device)
            total += criterion(model(images), images).item()
            n_batches += 1
    return total / n_batches if n_batches else float("nan")