# Chapter 3: Under the Hood: Training a Digit Classifier

## Use Pytorch/SGD to fit a (quadratic) curve

*https://en.wikipedia.org/wiki/Stochastic_gradient_descent*

First, generate a dataset:

```python
# Assumes the usual fastai notebook preamble (e.g. `from fastai.vision.all import *`),
# which pulls in torch, plt, tensor, Image, pd, DataLoader, and friends.
time = torch.arange(0, 20).float()
speed = (time - 10) ** 2 + (torch.randn(20) * 3)
plt.scatter(time, speed)
```

We want to fit a quadratic equation of the form \(ax^2 + bx + c\) to this data set (a, b, and c are the *weights* we want to tune). Define a generic quadratic function and a loss function:

```python
def f(x, params):
    a, b, c = params
    return a*(x**2) + b*(x) + c

def loss_fn(actual, prediction):
    return ((actual - prediction) ** 2).mean()
```
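
A quick spot check (my own addition): with the noiseless coefficients \((a, b, c) = (1, -20, 100)\), `f` should reproduce \((x - 10)^2\) exactly.

```python
f(torch.tensor([0., 10., 20.]), (1., -20., 100.))  # tensor([100., 0., 100.])
```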

Apply SGD:

```python
lr = 1e-5

# 1. Initialize weights.
#    `requires_grad_` indicates that we want to calculate the
#    gradient/derivative for these weights at some point.
params = (torch.randn(3) * 20).requires_grad_()

for i in range(2000):
    # 2. Use weights to obtain a result curve
    predictions = f(time, params)

    # 3. Use the loss function to compare the result curve to the actual curve
    loss = loss_fn(speed, predictions)

    # 4. Use `backward` to calculate derivatives for the weights against the loss function
    #    (slope of the tangent; where x = one of the weights, and f(x) is the loss given that weight).
    #    All intermediate computation between 1. and 4. is preserved behind the scenes
    #    so pytorch can calculate this automatically, which is _really_ cool!
    loss.backward()

    # 5. Update the weights based on this derivative. Subtract here because:
    #    - A negative gradient (downward slope) implies that we want to _increase_ the weight to reduce loss
    #    - A positive gradient (upward slope) implies that we want to _decrease_ the weight to reduce loss
    params.data -= lr * params.grad.data
    params.grad = None

plt.scatter(time, to_np(predictions))
plt.scatter(time, speed)
```
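
As a sanity check (my own addition, not from the book), the learned parameters can be compared against the coefficients used to generate the data: expanding \((x - 10)^2\) gives \(x^2 - 20x + 100\), so ideally \(a \approx 1\), \(b \approx -20\), \(c \approx 100\).

```python
# The data came from (time - 10)**2 + noise, i.e. a=1, b=-20, c=100.
# Exact values vary run to run, and the slower-moving weights (b and c
# see much smaller gradients than a) may still be off after 2000 steps.
print(params.data)
print(loss_fn(speed, f(time, params)).item())
```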

## Train a classifier for the MNIST sample dataset

*Train a classifier to pick threes and sevens apart from a dataset of handwritten digits.*

First, download the dataset and peek at it:

```python
url = URLs.MNIST_SAMPLE
path = untar_data(url)

threes = (path/'train/3').ls()
sevens = (path/'train/7').ls()
valid_threes = (path/'valid/3').ls()
valid_sevens = (path/'valid/7').ls()

t = tensor(Image.open(threes[1]))
df = pd.DataFrame(t)
df.style.background_gradient(cmap='viridis')
```
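
A quick peek at what a single image tensor holds (my own check, not from the book):

```python
t.shape, t.dtype  # (torch.Size([28, 28]), torch.uint8): grayscale pixel values in 0-255
```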

Convert each image into a pytorch-compatible “tensor”, using `torch.stack` to fold the enclosing list into the tensor too:

```python
threes_tensor = [tensor(Image.open(i)) for i in threes]
sevens_tensor = [tensor(Image.open(i)) for i in sevens]
valid_threes_tensor = [tensor(Image.open(i)) for i in valid_threes]
valid_sevens_tensor = [tensor(Image.open(i)) for i in valid_sevens]

stacked_threes = torch.stack(threes_tensor)
stacked_sevens = torch.stack(sevens_tensor)
valid_stacked_threes = torch.stack(valid_threes_tensor)
valid_stacked_sevens = torch.stack(valid_sevens_tensor)

stacked_sevens.shape  # torch.Size([6265, 28, 28])
```

Flatten each image from a 28x28 2-d array into a flat vector, presumably to make matrix multiplication easier/possible:

```python
flattened_threes = stacked_threes.view(-1, 28*28)
flattened_sevens = stacked_sevens.view(-1, 28*28)
valid_flattened_threes = valid_stacked_threes.view(-1, 28*28)
valid_flattened_sevens = valid_stacked_sevens.view(-1, 28*28)

flattened_threes.shape, flattened_sevens.shape
# (torch.Size([6131, 784]), torch.Size([6265, 784]))
```
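
To see why the flat vectors matter, a quick shape check (my own illustration, with throwaway random weights) shows the matrix multiply lining up as (N, 784) @ (784, 1) -> (N, 1), one activation per image:

```python
random_weights = torch.randn(28*28, 1)             # hypothetical throwaway weights
(flattened_threes.float() @ random_weights).shape  # torch.Size([6131, 1])
```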

Build pytorch `Dataset`s, one each for the training set and the validation set, and build `DataLoader`s from them:

```python
# Not sure why `unsqueeze` is required here, but the model doesn't work without it.
# Best guess: it's used so the target listing has the same matrix dimensions as
# the predictions it's being compared to later.
dset = list(zip(
    torch.cat([flattened_threes, flattened_sevens], 0).float(),
    torch.cat([tensor([0] * len(flattened_threes)),
               tensor([1] * len(flattened_sevens))], 0).unsqueeze(1)
))
valid_dset = list(zip(
    torch.cat([valid_flattened_threes, valid_flattened_sevens], 0).float(),
    torch.cat([tensor([0] * len(valid_flattened_threes)),
               tensor([1] * len(valid_flattened_sevens))], 0).unsqueeze(1)
))

dl = DataLoader(dset, batch_size=256, shuffle=True)
valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=True)
```
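
A shape check supports the `unsqueeze` guess in the comment above: predictions come out of the model with shape (batch, 1), so without `unsqueeze(1)` the targets would have shape (batch,), and broadcasting the two against each other would silently produce a (batch, batch) result rather than an elementwise comparison.

```python
targets = tensor([0, 1, 0])
targets.shape                # torch.Size([3])
targets.unsqueeze(1).shape   # torch.Size([3, 1]): matches the (batch, 1) predictions
```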

Next, define a model (this is [apparently] called a `linear1` model), a loss function, and an accuracy function.

- Applying weights to the model results in one prediction per image: a matrix containing all images' pixels is multiplied by a matrix containing all weights. `bias` provides an additional offset that isn't influenced by the pixel values.
- The loss and accuracy functions are distinct; the former is used for SGD, the latter for human consumption (a toy example follows the code below).
- Both functions first normalize input predictions to \([0,1]\) via the `sigmoid` function. The `target` for a prediction is `0` for digit “3” and `1` for digit “7”.
- The loss function takes the mean of every prediction's distance from its target. The accuracy function is more binary: anything over 0.5 is considered a “7” and anything below is considered a “3”, so a given prediction is either right or wrong.
- The accuracy function is a poor fit for a loss function because it's constant almost everywhere except at the threshold (0.5), which makes SGD difficult: gradients are typically 0. So we need a loss function that's much more responsive to small improvements.

```python
def apply_weights(initials, weights, bias):
    # initials has dimensions (256 x 28*28) {batch size x image dimensions}
    # weights has dimensions (28*28 x 1)    {one weight per pixel}
    # the resultant prediction has dimensions (256 x 1) {one prediction per image}
    return initials@weights + bias

def loss_fn(predictions, targets):
    # Each target is 0 or 1; 0 implies "3", 1 implies "7".
    # Each prediction is normalized to [0,1]:
    #   - 0 implies 100% confidence in a prediction of 3
    #   - 1 implies 100% confidence in a prediction of 7
    # Smaller loss implies a better fit.
    predictions = predictions.sigmoid()
    return torch.where(targets == 1, 1 - predictions, predictions).mean()

def accuracy_fn(predictions, targets):
    predictions = predictions.sigmoid()
    # A prediction over 0.5 is a "7", a prediction below 0.5 is a "3".
    return ((predictions > 0.5).float() == targets).float().mean()
```
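
A toy batch (my own illustration) makes the loss/accuracy distinction concrete:

```python
preds = tensor([[-2.0], [3.0]])  # raw model outputs, pre-sigmoid
targs = tensor([[0.], [1.]])     # a "3" and a "7"
loss_fn(preds, targs)            # ~0.08: both predictions lean the right way, so loss is small
accuracy_fn(preds, targs)        # tensor(1.): both cross the 0.5 threshold correctly
```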

Finally, execute the learning process:

```python
epochs = 20
lr = 1.
weights = torch.randn((28*28, 1)).requires_grad_()
bias = torch.randn(1).requires_grad_()

def train_epoch():
    # A `DataLoader` splits data into batches, so we iterate over batches here.
    for initials, targets in dl:
        predictions = apply_weights(initials, weights, bias)
        loss = loss_fn(predictions, targets)
        loss.backward()
        for param in [weights, bias]:
            param.data -= (lr * param.grad)
            param.grad.zero_()

def validate_epoch():
    accuracies = [accuracy_fn(apply_weights(initials, weights, bias), targets)
                  for initials, targets in valid_dl]
    return torch.stack(accuracies).mean().item()

for _ in range(epochs):
    train_epoch()
    print("Accuracy: " + str(validate_epoch()))

# Accuracy: 0.8283473253250122
# Accuracy: 0.9078061580657959
# Accuracy: 0.8738964200019836
# Accuracy: 0.9475355744361877
# Accuracy: 0.948043704032898
# Accuracy: 0.9533552527427673
# Accuracy: 0.9523786902427673
# Accuracy: 0.9636488556861877
# Accuracy: 0.9628310799598694
# Accuracy: 0.9648040533065796
# Accuracy: 0.9336652159690857
# Accuracy: 0.9371824264526367
# Accuracy: 0.9515013694763184
# Accuracy: 0.961814820766449
# Accuracy: 0.9651533961296082
# Accuracy: 0.943160891532898
# Accuracy: 0.9676940441131592
# Accuracy: 0.9699568152427673
# Accuracy: 0.9549987316131592
# Accuracy: 0.9730651378631592
```

The pytorch `nn.Linear` model does the same thing as `apply_weights` here, and incorporates both weights and biases. `fastai` provides an `SGD` function that performs parameter adjustment (after `backward` above). `fastai` *also* provides a `Learner` that can stand in for almost the entire orchestration code above:

```python
dls = DataLoaders(dl, valid_dl)
learn = Learner(dls, nn.Linear(28*28, 1), opt_func=SGD, loss_func=loss_fn, metrics=accuracy_fn)
learn.fit(20, lr=1.)
```
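
A quick check of what `nn.Linear` bundles up (my own addition; note that it stores the weight matrix transposed relative to `apply_weights` above):

```python
linear_model = nn.Linear(28*28, 1)
w, b = linear_model.parameters()
w.shape, b.shape  # (torch.Size([1, 784]), torch.Size([1]))
```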

- This is a *great* way to build up an intuitive sense of the *magic* that fastai appears to be at first.

- This is a

## Train a classifier for the full MNIST dataset

*Train a classifier (as above) on the full MNIST dataset - handwritten digits from 0 - 9*

An initial attempt was not *terrible*. The loss function is missing something:

```python
from functools import reduce

# 1. Build dataset
url = URLs.MNIST
path = untar_data(url)

train = [[int(os.path.basename(d)), d.ls()] for d in (path/'training').ls()]
valid = [[int(os.path.basename(d)), d.ls()] for d in (path/'testing').ls()]

def process_images(memo, val):
    i, t = val
    xb, yb = memo
    xb.append(torch.stack([tensor(Image.open(img)).float() for img in t]).view(-1, 28*28))
    yb.append(tensor([float(i)] * len(t)).unsqueeze(1))
    return [xb, yb]

xb, yb = reduce(process_images, sorted(train, key=first), ([], []))
valid_xb, valid_yb = reduce(process_images, sorted(valid, key=first), ([], []))

train_dset = list(zip(torch.cat(xb, 0), torch.cat(yb, 0)))
valid_dset = list(zip(torch.cat(valid_xb, 0), torch.cat(valid_yb, 0)))

train_dl = DataLoader(train_dset, batch_size=256, shuffle=True)
valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=True)

# 2. Define loss/accuracy fns

# Squash every prediction into (0, 10), and take its (absolute) distance to the target.
def loss(predictions, targets):
    predictions = predictions.sigmoid() * 10
    return (targets - predictions).abs().mean()

# Squash every prediction into (0, 10); a prediction is correct if it's within 1.0 of the target.
def accuracy(predictions, targets):
    predictions = predictions.sigmoid() * 10
    return ((targets - predictions).abs().float() < 1.).float().mean()

# 3. Train
dls = DataLoaders(train_dl, valid_dl)
model = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 1)
)
learner = Learner(dls, model, opt_func=SGD, loss_func=loss, metrics=accuracy)
learner.fit(40, 1e-3)
```

After doing a bit of digging, one promising idea is to have the model return *ten* predictions per image, representing the probability of that image being a member of each of the digit classes. I hadn't considered a model that doesn't collapse its output down to a single value, but on reflection that restriction seems unnecessary.
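
A rough sketch of that idea (untested in these notes; pairing it with `nn.CrossEntropyLoss` is my assumption, not something from above):

```python
# Ten activations per image, one per digit class.
model = nn.Sequential(
    nn.Linear(28*28, 30),
    nn.ReLU(),
    nn.Linear(30, 10),
)
# nn.CrossEntropyLoss expects raw scores of shape (batch, 10) and integer
# class labels of shape (batch,), so the float `unsqueeze(1)` targets built
# above would need to become plain long tensors instead.
loss = nn.CrossEntropyLoss()
```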