Chapter 3: Under the Hood: Training a Digit Classifier

Use PyTorch and SGD to fit a quadratic curve

https://en.wikipedia.org/wiki/Stochastic_gradient_descent

  • First, generate a dataset:

    import torch
    import matplotlib.pyplot as plt
    time = torch.arange(0, 20).float()
    speed = (time - 10) ** 2 + (torch.randn(20) * 3)
    plt.scatter(time, speed)
    
  • We want to fit a quadratic of the form \(ax^2 + bx + c\) to this dataset, where a, b, and c are the weights (parameters) we want to tune. Define a generic quadratic function and a loss function (mean squared error):

    def f(x, params):
      a,b,c = params
      return a*(x**2) + b*(x) + c
    
    def loss_fn(actual, prediction):
      return ((actual - prediction) ** 2).mean()
    
  • Apply SGD:

    lr = 1e-5
    
    # 1. Initialize weights randomly.
    #    `requires_grad_` tells PyTorch that we'll want to calculate the
    #    gradient (derivative) of the loss with respect to these weights later.
    params = (torch.randn(3) * 20).requires_grad_()
    
    for i in range(2000):
        # 2. Use weights to obtain a result curve
        predictions = f(time, params)
          
        # 3. Use the loss function to compare the result curve to the actual curve
        loss = loss_fn(speed, predictions)
          
        # 4. Use `backward` to calculate the derivative of the loss with respect to each weight
        #    (the slope of the tangent to the loss, viewed as a function of that one weight).
        #    PyTorch records every operation applied to `params` in steps 2 and 3,
        #    so it can calculate these gradients automatically, which is _really_ cool!
        loss.backward()
          
        # 5. Update the weights based on this derivative. Subtract here because:
        #    - A negative gradient (downward slope) implies that we want to _increase_ the weight to reduce loss
        #    - A positive gradient (upward slope) implies that we want to _decrease_ the weight to reduce loss
        params.data -= lr * params.grad.data
        # Reset the gradient; `backward` adds to an existing `.grad` instead of overwriting it
        params.grad = None
    
    plt.scatter(time, predictions.detach().numpy())
    plt.scatter(time, speed)
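
  • To make step 4 less magical, here is a rough sketch of the gradients that `loss.backward()` computes for this model, derived by hand in plain Python with no autograd. It assumes noise-free data, speed = (t - 10)^2, and all-zero starting weights so the run is deterministic. The helper names `mse` and `mse_grad` are made up for illustration. Each gradient is the mean of 2 * (prediction - actual), multiplied by the derivative of f with respect to that weight: x^2 for a, x for b, and 1 for c.

```python
# A hand-derived version of what `loss.backward()` computes for the quadratic fit.
# Assumptions (illustration only): noise-free data and all-zero initial weights.

def f(x, params):
    a, b, c = params
    return a * x**2 + b * x + c

def mse(params, xs, ys):
    # Mean squared error, same as loss_fn above
    return sum((y - f(x, params)) ** 2 for x, y in zip(xs, ys)) / len(xs)

def mse_grad(params, xs, ys):
    # d(loss)/dp = mean(2 * (f(x) - y) * df/dp), with df/da = x^2, df/db = x, df/dc = 1
    n = len(xs)
    ga = gb = gc = 0.0
    for x, y in zip(xs, ys):
        err = f(x, params) - y
        ga += 2 * err * x**2 / n
        gb += 2 * err * x / n
        gc += 2 * err / n
    return [ga, gb, gc]

xs = [float(t) for t in range(20)]
ys = [(t - 10) ** 2 for t in xs]

params = [0.0, 0.0, 0.0]
lr = 1e-5
initial_loss = mse(params, xs, ys)
for _ in range(2000):
    grads = mse_grad(params, xs, ys)
    # The same update as `params.data -= lr * params.grad.data`
    params = [p - lr * g for p, g in zip(params, grads)]
final_loss = mse(params, xs, ys)
print(f"loss: {initial_loss:.1f} -> {final_loss:.1f}, params: {params}")
```

A central finite difference, (loss(p + eps) - loss(p - eps)) / (2 * eps), is a quick way to check hand-derived gradients like these.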
    

Train a classifier for the MNIST sample dataset

Train a classifier to tell threes and sevens apart, using a sample of the MNIST dataset of handwritten digits.

  • First, download the dataset and peek at it:

    from fastai.vision.all import *   # provides URLs, untar_data, tensor, Image, pd, DataLoader, etc.
    
    url = URLs.MNIST_SAMPLE
    path = untar_data(url)
    
    threes = (path/'train/3').ls()
    sevens = (path/'train/7').ls()
    
    valid_threes = (path/'valid/3').ls()
    valid_sevens = (path/'valid/7').ls()
    
    t = tensor(Image.open(threes[1]))
    df = pd.DataFrame(t)
    df.style.background_gradient(cmap='viridis')
    
  • Convert each image into a PyTorch tensor, then use torch.stack to combine each list of 28x28 tensors into a single 3-D tensor (number of images x 28 x 28):

    threes_tensor = [tensor(Image.open(i)) for i in threes]
    sevens_tensor = [tensor(Image.open(i)) for i in sevens]
    
    valid_threes_tensor = [tensor(Image.open(i)) for i in valid_threes]
    valid_sevens_tensor = [tensor(Image.open(i)) for i in valid_sevens]
    
    stacked_threes = torch.stack(threes_tensor)
    stacked_sevens = torch.stack(sevens_tensor)
    
    valid_stacked_threes = torch.stack(valid_threes_tensor)
    valid_stacked_sevens = torch.stack(valid_sevens_tensor)
    
    stacked_sevens.shape # torch.Size([6265, 28, 28])
    
  • Flatten each image from a 28x28 2-D array into a 784-element vector. This turns a whole batch into a 2-D matrix (one row per image) that can be multiplied by a weight vector in a single matrix multiplication:

    flattened_threes = stacked_threes.view(-1, 28*28)
    flattened_sevens = stacked_sevens.view(-1, 28*28)
    
    valid_flattened_threes = valid_stacked_threes.view(-1, 28*28)
    valid_flattened_sevens = valid_stacked_sevens.view(-1, 28*28)
    
    flattened_threes.shape, flattened_sevens.shape # (torch.Size([6131, 784]), torch.Size([6265, 784]))
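
  • `view(-1, 28*28)` keeps the pixels in row-major order. Pixel (row r, column c) ends up at index r * 28 + c of the flat vector, and `-1` tells PyTorch to infer the remaining dimension (the number of images). The plain-Python sketch below flattens one made-up 28x28 grid and checks that mapping. `flatten_image` is a hypothetical helper, not part of PyTorch.

```python
# Plain-Python sketch of flattening one 28x28 image in row-major order,
# the element order `view(-1, 28*28)` produces for a contiguous tensor.
ROWS, COLS = 28, 28

def flatten_image(img):
    # Concatenate the rows: row 0 first, then row 1, and so on
    return [px for row in img for px in row]

# A made-up image whose pixel values encode their own position
image = [[r * 100 + c for c in range(COLS)] for r in range(ROWS)]
flat = flatten_image(image)

print(len(flat))                          # 784
print(flat[2 * COLS + 5] == image[2][5])  # True
```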
    
  • Build datasets as lists of (image, label) tuples, one for the training set and one for the validation set, and wrap each one in a DataLoader:

    # Labels are 0 for a "3" and 1 for a "7".
    # `unsqueeze(1)` reshapes the labels from (N,) to (N, 1) so they match the (batch, 1)
    # predictions the model produces. Without it, elementwise operations between the two
    # would broadcast into a (batch, batch) matrix and the loss would be wrong.
    
    dset = list(zip(
      torch.cat([flattened_threes, flattened_sevens], 0).float(),
      torch.cat([tensor([0] * len(flattened_threes)), tensor([1] * len(flattened_sevens))], 0).unsqueeze(1)
    ))
    
    valid_dset = list(zip(
      torch.cat([valid_flattened_threes, valid_flattened_sevens], 0).float(),
      torch.cat([tensor([0] * len(valid_flattened_threes)), tensor([1] * len(valid_flattened_sevens))], 0).unsqueeze(1)
    ))
    
    dl = DataLoader(dset, batch_size=256, shuffle=True)
    valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=True)
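
  • `unsqueeze(1)` matters because of broadcasting. The model produces predictions of shape (batch, 1). If the targets had shape (batch,), an expression like `torch.where(targets == 1, 1 - predictions, predictions)` would broadcast the two into a (batch, batch) matrix and compare every prediction against every target. The sketch below re-implements the NumPy-style shape-broadcasting rule that PyTorch follows, in plain Python. `broadcast_shape` is a hypothetical helper; recent PyTorch versions expose the real rule as `torch.broadcast_shapes`.

```python
# Plain-Python sketch of the broadcasting rule PyTorch and NumPy use to combine shapes.
# `broadcast_shape` is a hypothetical helper for illustration.

def broadcast_shape(a, b):
    result = []
    # Compare dimensions from the right; a missing dimension counts as size 1
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast sizes {da} and {db}")
        # A size-1 dimension stretches to match the other one
        result.append(db if da == 1 else da)
    return tuple(reversed(result))

print(broadcast_shape((256, 1), (256,)))    # (256, 256): every prediction vs. every target
print(broadcast_shape((256, 1), (256, 1)))  # (256, 1): one target per prediction
```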
    
  • Next, define a model (the book calls this model `linear1`), a loss function, and an accuracy function.

    • Applying the weights produces one prediction per image. The matrix of all images' pixels (one row per image) is multiplied by a column of weights (one per pixel), and then the bias is added.
    • The bias is a constant offset that doesn't depend on the pixel values. Without it, an image whose pixels are all zero would always produce a prediction of exactly 0.
    • The loss and accuracy functions are distinct: the former drives SGD, and the latter is a metric for human consumption.
    • Both functions first squash the raw predictions into \((0,1)\) with the sigmoid function. The target for a prediction is 0 for the digit “3” and 1 for the digit “7”.
    • The loss function takes the mean of every prediction’s distance from its target. The accuracy function is binary: anything over 0.5 counts as a “7” and anything else as a “3”, so each prediction is either right or wrong.
    • Accuracy makes a poor loss function because it is constant almost everywhere and changes only when a prediction crosses the 0.5 threshold. Its gradients are therefore almost always 0, which leaves SGD nothing to follow. We need a loss function that responds to small improvements in the predictions.

    def apply_weights(initials, weights, bias):
      # initials: (256 x 784), i.e. (batch size x pixels per image)
      # weights:  (784 x 1), one weight per pixel
      # result:   (256 x 1), one prediction per image
      return initials@weights + bias
    
    def loss_fn(predictions, targets):
      # each target is 0 or 1; 0 implies "3", 1 implies "7"
      # each prediction is normalized to [0,1]
      #    - 0 implies 100% confidence in a prediction of 3
      #    - 1 implies 100% confidence in a prediction of 7
      # smaller loss implies a better fit
      predictions = predictions.sigmoid()
      return torch.where(targets == 1, 1 - predictions, predictions).mean()
    
    def accuracy_fn(predictions, targets):
      predictions = predictions.sigmoid()
      # A prediction over 0.5 is a "7", a prediction below 0.5 is a "3"
      return ((predictions > 0.5).float() == targets).float().mean()
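
  • To see why accuracy makes a poor loss, nudge the raw output for a single misclassified “7” (target 1) by a small amount in the right direction. The plain-Python sketch below mirrors `loss_fn` and `accuracy_fn` for one prediction. `loss_one` and `accuracy_one` are illustrative names, not from the code above. The loss drops, which gives SGD a signal to follow, while the accuracy doesn't change at all.

```python
import math

def sigmoid(x):
    return 1 / (1 + math.exp(-x))

def loss_one(activation, target):
    # Same rule as loss_fn: distance of the sigmoid output from the target
    p = sigmoid(activation)
    return 1 - p if target == 1 else p

def accuracy_one(activation, target):
    # Same rule as accuracy_fn: over 0.5 means "7" (1), otherwise "3" (0)
    predicted = 1 if sigmoid(activation) > 0.5 else 0
    return float(predicted == target)

target = 1                  # the image is a "7"
before, after = -2.0, -1.9  # a small step in the right direction

loss_change = loss_one(after, target) - loss_one(before, target)
acc_change = accuracy_one(after, target) - accuracy_one(before, target)
print(f"loss change: {loss_change:.4f}, accuracy change: {acc_change}")
```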
    
  • Finally, execute the learning process:

    epochs = 20
    lr = 1.
    weights = torch.randn((28*28, 1)).requires_grad_()
    bias = torch.randn(1).requires_grad_()
    
    def train_epoch():
        # A `DataLoader` splits data into batches, so we iterate over batches here.
        for initials, targets in dl:
            predictions = apply_weights(initials, weights, bias)
            loss = loss_fn(predictions, targets)
            loss.backward()
            for param in [weights, bias]:
                param.data -= lr * param.grad
                param.grad.zero_()
              
    def validate_epoch():
        accuracies = [accuracy_fn(apply_weights(initials, weights, bias), targets) for initials, targets in valid_dl]
        return torch.stack(accuracies).mean().item()
    
    for _ in range(epochs):
        train_epoch()
        print("Accuracy: " + str(validate_epoch()))
    
    # Accuracy: 0.8283473253250122
    # Accuracy: 0.9078061580657959
    # Accuracy: 0.8738964200019836
    # Accuracy: 0.9475355744361877
    # Accuracy: 0.948043704032898
    # Accuracy: 0.9533552527427673
    # Accuracy: 0.9523786902427673
    # Accuracy: 0.9636488556861877
    # Accuracy: 0.9628310799598694
    # Accuracy: 0.9648040533065796
    # Accuracy: 0.9336652159690857
    # Accuracy: 0.9371824264526367
    # Accuracy: 0.9515013694763184
    # Accuracy: 0.961814820766449
    # Accuracy: 0.9651533961296082
    # Accuracy: 0.943160891532898
    # Accuracy: 0.9676940441131592
    # Accuracy: 0.9699568152427673
    # Accuracy: 0.9549987316131592
    # Accuracy: 0.9730651378631592
    
  • PyTorch’s nn.Linear module does the same thing as apply_weights here, and holds both the weights and the bias as trainable parameters.
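
  • With a single output, `apply_weights` is just one dot product per image plus the bias. The tiny plain-Python version below uses made-up numbers, with 3 “pixels” per image instead of 784, to make the arithmetic concrete. `linear` is a hypothetical helper. `nn.Linear` computes the same thing, but it stores its weight matrix with shape (out_features, in_features) and computes `x @ weight.T + bias`.

```python
# Plain-Python sketch of `initials @ weights + bias` for a single output.
# The numbers are made up; real MNIST rows have 784 pixels.

def linear(batch, weights, bias):
    # One prediction per row: dot product of the row with the weights, plus the bias
    return [[sum(x * w for x, w in zip(row, weights)) + bias] for row in batch]

batch = [[0.0, 1.0, 2.0],   # "image" 1
         [3.0, 4.0, 5.0]]   # "image" 2
weights = [0.5, -1.0, 2.0]  # one weight per pixel
bias = 1.0

print(linear(batch, weights, bias))  # [[4.0], [8.5]]
```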

  • fastai provides an SGD optimizer that performs the parameter-update step following backward above (the param.data -= lr * param.grad update).
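
  • An optimizer packages the two lines inside `for param in [weights, bias]` into a `step` (update each parameter by `lr * grad`) and a `zero_grad` (reset the gradients). The plain-Python sketch below shows that split. `Param` and `PlainSGD` are hypothetical classes for illustration, not fastai's implementation. The gradient of the toy loss (w - 3)^2 is computed by hand in place of `backward()`.

```python
# Hypothetical sketch of what an SGD optimizer factors out of the training loop.

class Param:
    """Stand-in for a tensor that carries a value and an accumulated gradient."""
    def __init__(self, data):
        self.data = data
        self.grad = 0.0

class PlainSGD:
    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def step(self):
        # Same as `param.data -= lr * param.grad` for each parameter
        for p in self.params:
            p.data -= self.lr * p.grad

    def zero_grad(self):
        # Same as `param.grad.zero_()` for each parameter
        for p in self.params:
            p.grad = 0.0

w = Param(0.0)
opt = PlainSGD([w], lr=0.1)
for _ in range(100):
    # Toy loss (w - 3)^2; adding its gradient 2 * (w - 3) stands in for `loss.backward()`
    w.grad += 2 * (w.data - 3)
    opt.step()
    opt.zero_grad()

print(w.data)  # close to 3.0
```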

  • fastai also provides a Learner that can stand in for almost all of the orchestration code above:

    dls = DataLoaders(dl, valid_dl)
    learn = Learner(dls, nn.Linear(28*28, 1), opt_func=SGD, loss_func=loss_fn, metrics=accuracy_fn)
    learn.fit(20, lr=1.)
    
    • Writing the training loop by hand first, and then replacing it piece by piece, is a great way to build an intuition for what fastai seems to do by magic at first.

Train a classifier for the full MNIST dataset

Train a classifier (as above) on the full MNIST dataset: handwritten digits from 0 to 9.

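  • An initial attempt, with a model that outputs a single number per image and a loss that compares it directly to the digit's value, was not terrible, but the loss function is missing something: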

    # 1. Build dataset
    import os
    from functools import reduce
    
    url = URLs.MNIST
    path = untar_data(url)
    
    # Each entry is [digit label, list of image paths]; the label is the directory name
    train = [[int(os.path.basename(dir)), dir.ls()] for dir in (path/'training').ls()]
    valid = [[int(os.path.basename(dir)), dir.ls()] for dir in (path/'testing').ls()]
    
    def process_images(memo, val):
        i, t = val
        xb, yb = memo
          
        xb.append(torch.stack([tensor(Image.open(img)).float() for img in t]).view(-1, 28*28))
        yb.append(tensor([float(i)] * len(t)).unsqueeze(1))
                 
        return [xb, yb]
    
    xb, yb = reduce(process_images, sorted(train, key=first), ([], []))
    valid_xb, valid_yb = reduce(process_images, sorted(valid, key=first), ([], []))
    
    train_dset = list(zip(torch.cat(xb, 0), torch.cat(yb, 0)))
    valid_dset = list(zip(torch.cat(valid_xb, 0), torch.cat(valid_yb)))
    
    train_dl = DataLoader(train_dset, batch_size=256, shuffle=True)
    valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=True)
    
    # 2. Define loss/accuracy fns
    
    # Scale every prediction into `(0, 10)` with a sigmoid, and take its absolute distance to the target digit.
    def loss(predictions, targets):
      predictions = predictions.sigmoid() * 10
      return (targets - predictions).abs().mean()
        
    # Scale every prediction into `(0, 10)`; a prediction counts as correct if it's within 1.0 of the target digit.
    def accuracy(predictions, targets):
        predictions = predictions.sigmoid() * 10
        return ((targets - predictions).abs().float() < 1.).float().mean()
    
    # 3. Train
    
    dls = DataLoaders(train_dl, valid_dl)
      
    model = nn.Sequential(
        nn.Linear(28*28, 30),
        nn.ReLU(),
        nn.Linear(30, 1)
    )
      
    learner = Learner(dls, model, opt_func=SGD, loss_func=loss, metrics=accuracy)
    learner.fit(40, 1e-3)
    
  • After a bit of digging, one promising idea is to have the model return ten outputs per image, one per digit class, representing how likely the image is to belong to each class. I hadn't considered a model whose output doesn't collapse down to a single value, but on reflection that restriction was unnecessary.

  • The next chapter got into this a bit more, too, and introduced cross-entropy loss.

  • I used PyTorch’s nn.CrossEntropyLoss and fastai’s accuracy metric, both of which work with multiple categories. As expected, the results are much better:

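    import fastai.metrics
    
    # nn.CrossEntropyLoss expects each target to be an integer class index (shape (N,)),
    # not the float (N, 1) labels built above, so convert the labels first.
    train_dset = [(x, y.long().squeeze(0)) for x, y in train_dset]
    valid_dset = [(x, y.long().squeeze(0)) for x, y in valid_dset]
    train_dl = DataLoader(train_dset, batch_size=256, shuffle=True)
    valid_dl = DataLoader(valid_dset, batch_size=256)
    
    dls = DataLoaders(train_dl, valid_dl)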
      
    model = nn.Sequential(
        nn.Linear(28*28, 100),
        nn.ReLU(),
        nn.Linear(100, 10)
    )
      
    learner = Learner(dls, model, opt_func=SGD, loss_func=nn.CrossEntropyLoss(), metrics=fastai.metrics.accuracy)
    learner.fit(20, 1e-3)
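
  • Cross-entropy loss first applies softmax to the ten outputs, turning them into probabilities that sum to 1. It then takes the negative log of the probability assigned to the correct digit. Below is a rough plain-Python sketch of that calculation for a single image. The helper names are illustrative. `nn.CrossEntropyLoss` also averages over the batch by default.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

def cross_entropy(logits, target):
    # Negative log-probability of the correct class
    return -log_softmax(logits)[target]

uniform = [0.0] * 10        # the model has no preference among the 10 digits
confident = [0.0] * 10
confident[3] = 5.0          # strongly favors the digit 3

print(cross_entropy(uniform, 3))    # log(10), about 2.3026
print(cross_entropy(confident, 3))  # much smaller
print(cross_entropy(confident, 7))  # much larger
```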
    
  • I tested the model on a digit I drew myself (in Pixelmator, converted to grayscale with ImageMagick’s convert -colorspace Gray), and it correctly predicted a 3 with ~82% confidence:

    i = tensor(Image.open("/tmp/out.png")).float()
    df = pd.DataFrame(i.reshape([28, 28]))
      
    # Inference only, so skip gradient tracking; softmax turns the 10 outputs into probabilities
    with torch.no_grad():
        predictions = F.softmax(model(i.flatten()), dim=0)
      
    pred_idx = predictions.argmax()
    pred_confidence = predictions[pred_idx]
      
    print("Prediction: ", pred_idx.item(), ", confidence: ", pred_confidence.item())
    df.style.set_properties(**{'font-size':'1px'}).background_gradient(cmap='viridis')
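
  • Softmax preserves the ordering of the outputs, so `argmax` picks the same digit whether it's applied to the raw model outputs or to the probabilities. Softmax is only needed to turn the winning score into a confidence. Here is a plain-Python sketch with made-up outputs for three classes; `predict` is a hypothetical helper.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(logits):
    # Return (index of the most likely class, its probability)
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return idx, probs[idx]

logits = [1.0, 2.0, 0.5]   # made-up raw outputs for three classes
idx, confidence = predict(logits)
print(idx, round(confidence, 3))
```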
    