We’ll apply Logistic Regression to the MNIST dataset. This means we’ll have one weight parameter per pixel per class and one bias parameter per class: a 10×784 weight matrix plus 10 biases, for 7,850 parameters in total.

  1. Download Data
  2. Setup Datasets and DataLoaders
  3. Define Model
  4. Train
  5. Make Predictions
  6. Save and Load Model

Download Data

from torchvision.datasets import MNIST
MNIST(root="data/", download=True)
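
A quick sanity check after downloading: the training set has 60,000 examples, each a PIL image paired with an integer label.

dataset = MNIST(root="data/", download=True)
print(len(dataset))      # 60000
img, label = dataset[0]
print(type(img), label)  # <class 'PIL.Image.Image'> 5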

Setup Datasets and DataLoaders

from torchvision.transforms import ToTensor
from torch.utils.data import random_split, DataLoader

train_and_val_dataset = MNIST(root="data/", train=True, transform=ToTensor())
train_dataset, val_dataset = random_split(train_and_val_dataset, (50000, 10000))
test_dataset = MNIST(root="data/", train=False, transform=ToTensor())

# batch size
bs = 100
train_dl = DataLoader(train_dataset, bs, shuffle=True)
val_dl = DataLoader(val_dataset, bs)
test_dl = DataLoader(test_dataset, bs)
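
Each batch from these loaders is a pair of tensors: images of shape (batch_size, 1, 28, 28), i.e. one channel of 28×28 pixels, and a vector of integer labels. A quick way to verify:

xb, yb = next(iter(train_dl))
print(xb.shape)  # torch.Size([100, 1, 28, 28])
print(yb.shape)  # torch.Size([100])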

Define Model

from torch.nn import Module, Linear

class Model(Module):
  def __init__(self):
    super().__init__()
    # Number of features is number of pixels in this logistic regression model
    # Number of classes is 10 (10 possible digits)
    self.linear = Linear(28*28, 10)
  
  def forward(self, xb):
    # We need to reshape xb to be number of batches x number of pixels
    xb = xb.reshape(-1, 28*28)
    logits = self.linear(xb)
    return logits
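
The single Linear layer holds all 7,850 parameters mentioned at the start: a weight matrix of shape (10, 784) and a bias vector of length 10. A quick check:

model = Model()
print(model.linear.weight.shape)  # torch.Size([10, 784])
print(model.linear.bias.shape)    # torch.Size([10])
print(sum(p.numel() for p in model.parameters()))  # 7850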

Train

To choose a learning rate, we tried a few values for 10 epochs each: lr=1e-5 reached only 21% validation accuracy, lr=1e-4 reached 68%, and lr=1e-3 reached 84%, so we stuck with lr=1e-3. With it, we achieve 85% test accuracy.

import torch
from torch.nn.functional import cross_entropy
from torch.optim import SGD
from torch import tensor

def accuracy(logits, labels):
  # The predicted label for each example is the index of its largest logit
  _, pred_labels = torch.max(logits, 1)
  # Returning a tensor isn't strictly necessary, but it keeps the return type consistent with cross_entropy
  return torch.sum(pred_labels == labels).float() / len(labels)
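
For example, on a toy batch of two examples where only the first prediction is correct, accuracy returns 0.5:

logits = tensor([[2.0, 1.0], [0.5, 3.0]])  # predicted labels: 0 and 1
labels = tensor([0, 0])
print(accuracy(logits, labels))  # tensor(0.5000)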

# For evaluating loss and accuracy on validation and test datasets
def evaluate(model, dl):
  batch_sizes, losses, accuracies = [], [], []
  for xb, yb in dl:
    # No gradient computations when evaluating. Disabling gradients when applying the model
    # ensures gradients won't be tracked in subsequent computations (cross_entropy, accuracy, ...)
    with torch.no_grad():
      logits = model(xb)
    # .item() turns each 0-dim tensor into a plain Python float, so the lists
    # can be converted to tensors below without copy-construct warnings
    losses.append(cross_entropy(logits, yb).item())
    batch_sizes.append(len(xb))
    accuracies.append(accuracy(logits, yb).item())

  # Weight each batch's loss and accuracy by its size (the last batch may be smaller)
  losses, accuracies = tensor(losses), tensor(accuracies)
  batch_sizes = tensor(batch_sizes, dtype=torch.float)
  total = torch.sum(batch_sizes)
  loss = torch.sum(losses * (batch_sizes / total)).item()
  acc = torch.sum(accuracies * (batch_sizes / total)).item()

  return loss, acc
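
As a sanity check, an untrained model should score near chance level: about 10% accuracy for 10 balanced classes, with a cross-entropy loss near ln(10) ≈ 2.30.

model = Model()
print(evaluate(model, val_dl))  # roughly (2.3, 0.10) before training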


def train(model, train_dl, val_dl, test_dl, epochs=10, lr=1e-3):
  torch.manual_seed(42)
  
  opt = SGD(model.parameters(), lr=lr)

  for epoch in range(epochs):
    # Training
    for xb, yb in train_dl:
      logits = model(xb)
      loss = cross_entropy(logits, yb)
      # Compute gradients
      loss.backward()
      # Update weights
      opt.step()
      # Zero gradients
      opt.zero_grad()

    val_loss, val_acc = evaluate(model, val_dl)
    print("Epoch {}: val_loss: {:.4f}, val_acc: {:.4f}".format(epoch+1, val_loss, val_acc))
    
  test_loss, test_acc = evaluate(model, test_dl)
  print("\ntest_loss: {:.4f}, test_acc: {:.4f}".format(test_loss, test_acc))


model = Model()
train(model, train_dl, val_dl, test_dl)
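
The learning-rate comparison described above can be reproduced with a simple sweep, training a fresh model for each value:

for lr in (1e-5, 1e-4, 1e-3):
  print("lr = {}".format(lr))
  train(Model(), train_dl, val_dl, test_dl, epochs=10, lr=lr)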

Make Predictions

def predict(model, img):
  # Make it a batch of 1
  img_b = img.unsqueeze(0)
  # We're not training, so there's no need to track gradients.
  # Disabling gradients when applying the model ensures they won't be tracked at
  # later steps (e.g., softmax, max, ...).
  with torch.no_grad():
    logits = model(img_b)
  sm = torch.softmax(logits, 1)
  max_probs, pred_labels = torch.max(sm, 1)
  return pred_labels[0].item(), max_probs[0].item()

img, label = test_dataset[0]
pred_label, prob = predict(model, img)
print("Predicted Label: {} ({:.4f}). Correct Label: {}".format(pred_label, prob, label))

Save and Load Model

torch.save(model.state_dict(), "logistic_regression.pth")
model = Model()
model.load_state_dict(torch.load("logistic_regression.pth"))
evaluate(model, test_dl)
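
On newer PyTorch versions, torch.load also accepts weights_only=True, which restricts unpickling to plain tensor data and is safer when loading checkpoints you didn't create yourself:

model = Model()
model.load_state_dict(torch.load("logistic_regression.pth", weights_only=True))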

img, label = test_dataset[0]
pred_label, prob = predict(model, img)
print("Predicted Label: {} ({:.4f}). Correct Label: {}".format(pred_label, prob, label))