Source code for nlgm.train
import torch
from torch import nn
from tqdm.auto import tqdm
[docs]
def train_and_evaluate(
    model,
    train_loader,
    test_loader,
    epochs=10,
    device=torch.device("cpu"),
    progress_bar=True,
):
    """
    Trains a reconstruction model and evaluates it on the test set after every epoch.

    The model is optimized with Adam (learning rate 0.001) to minimize the mean
    squared error between its output and its own input, i.e. ``model(images)``
    is compared against ``images``. Each batch from the loaders is unpacked as
    a pair whose second element is ignored. After each epoch the average
    per-batch train and test losses are printed.

    Args:
        model (nn.Module): The model to train and evaluate. It is moved to ``device``.
        train_loader (torch.utils.data.DataLoader): The data loader for training data.
        test_loader (torch.utils.data.DataLoader): The data loader for test data.
        epochs (int, optional): The number of epochs to train the model. Defaults to 10.
        device (torch.device, optional): The device to use for training. Defaults to torch.device("cpu").
        progress_bar (bool, optional): Whether to display progress bars (over epochs
            and over training batches) during training. Defaults to True.

    Returns:
        tuple: ``(train_losses, test_loss)`` where ``train_losses`` is a list with
        the mean per-batch training loss of every epoch and ``test_loss`` is the
        mean per-batch test loss measured after the last epoch. If ``epochs`` is
        0, no training or evaluation happens and ``test_loss`` is ``nan``.
    """
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    train_losses = []
    # Defined before the loop so the final ``return`` is valid even when
    # ``epochs`` is 0 and the loop body never runs.
    test_loss = float("nan")

    model.to(device)

    epoch_iter = tqdm(range(epochs), desc="Epochs") if progress_bar else range(epochs)
    for epoch in epoch_iter:
        # ---- training pass ----
        model.train()
        train_loss = 0
        # Honor ``progress_bar`` for the batch loop too; wrapping unconditionally
        # would show a bar even when the caller asked for none.
        batch_iter = (
            tqdm(train_loader, desc="Training Loop") if progress_bar else train_loader
        )
        for images, _ in batch_iter:
            images = images.to(device)
            optimizer.zero_grad()
            reconstructed_images = model(images)
            loss = criterion(reconstructed_images, images)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        # Average over the number of batches, not the number of samples.
        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        # ---- evaluation pass ----
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for images, _ in test_loader:
                images = images.to(device)
                reconstructed_images = model(images)
                loss = criterion(reconstructed_images, images)
                test_loss += loss.item()
        test_loss /= len(test_loader)

        print(
            f"Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}"
        )

    return train_losses, test_loss