August 02, 2019 · 5 mins read

Lesson 1 - Training a fruit classifier

After finishing CS229 Machine Learning by Stanford yesterday I decided to continue this fascinating journey with the fast.ai course. And wow! The first impressions are extremely good.

The style of the homework differs greatly between CS229 and fast.ai. Where the Stanford course provided the data and gave clear instructions on what to do, the fast.ai homework was just: “build an image classifier”. We even had to get the data ourselves! This was a little overwhelming at first, but it was super fun to build a project from the ground up and I’m quite happy with the results. In this blog post, I’m going to show you how to build a fruit classifier using the fastai deep learning framework.

Setting up

For this project I decided to use Google Cloud Platform, or GCP for short. Fast.ai has an excellent tutorial on setting it up, as well as tutorials for many other cloud computing services. There are a couple of reasons for this choice. First, Google provides a $300 credit to get started. Second, I think Google can be trusted when it comes to cloud computing. If I ever decide to do a bigger project, I want to work on a platform that I can fully trust. So I decided to learn GCP from the start.

Getting the data

The classifier I made simply classifies fruits. I figured I would start with a simple dataset so I can focus on learning the fastai framework.

The dataset I found is Fruits 360 on Kaggle. The version I used is 2019.07.07.0. It has 114 classes and well over 75,000 images.

Getting the data into the Jupyter Notebook was super easy thanks to the Kaggle API. The following code snippet requires little explanation:

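import os

from fastai.vision import *  # fastai v1: Path (with .ls()), verify_images, cnn_learner, ...

# Set the Kaggle credentials as environment variables (placeholders shown).
os.environ['KAGGLE_USERNAME'] = "xxxxxxx"
os.environ['KAGGLE_KEY'] = "xxxxxxxxxxxxxxxxxxxxx"

# Import after the credentials are set: the kaggle package may try to
# authenticate as soon as it is imported.
from kaggle.api.kaggle_api_extended import KaggleApi

# Authenticate and download the dataset into the current directory.
api = KaggleApi()
api.authenticate()
api.dataset_download_files("moltean/fruits", unzip=True)

# Folder with one sub-folder of images per fruit class.
path = Path('./fruits-360/Training')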
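
To double-check the class and image counts after the download, a small sketch like the following can help. It assumes the one-folder-per-class layout of the Training directory and common image extensions; the helper name count_classes_and_images is made up for this example and is not part of fastai or the Kaggle API.

```python
from pathlib import Path


def count_classes_and_images(root, suffixes=(".jpg", ".jpeg", ".png")):
    """Return (number of class folders, total number of image files) under root."""
    root = Path(root)
    # Every sub-folder of root is treated as one class.
    class_dirs = [d for d in root.iterdir() if d.is_dir()]
    # Count only files whose extension looks like an image.
    n_images = sum(
        1
        for d in class_dirs
        for f in d.iterdir()
        if f.is_file() and f.suffix.lower() in suffixes
    )
    return len(class_dirs), n_images
```

Calling count_classes_and_images('./fruits-360/Training') returns the two counts for the training folder only, so the image total will be lower than the figure for the whole dataset.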

Then, just for the sake of completeness, I verified the images using verify_images. Although this is not strictly needed on this particular dataset, I’m planning on using my notebooks as a reference for later so I decided to include it nevertheless.

for c in path.ls():
    print(c)
    # path.ls() already yields the full path of each class folder,
    # so it must not be joined onto the dataset root a second time.
    verify_images(c, delete=True, max_size=500)

Looking at the data

As the course prescribes, one should always look at the data before using it. This is surprisingly easy using the fastai library.

data.show_batch(rows=3, figsize=(7,8))
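
The data object used above is a fastai ImageDataBunch, and its creation is not shown in this post. A minimal sketch of how one could be built from the Training folder, assuming fastai v1: the function name build_data, the 20% validation split and the 224 pixel image size are illustrative assumptions, not necessarily the settings used here.

```python
def build_data(path, valid_pct=0.2, size=224):
    """Build a fastai v1 ImageDataBunch from a folder with one sub-folder per class."""
    # Imported lazily so the sketch can be loaded without fastai installed.
    from fastai.vision import ImageDataBunch, get_transforms, imagenet_stats

    return ImageDataBunch.from_folder(
        path,
        train=".",                # class folders sit directly under `path`
        valid_pct=valid_pct,      # assumed: hold out 20% of images for validation
        ds_tfms=get_transforms(), # fastai's default data augmentation
        size=size,                # assumed: resize images to 224x224
    ).normalize(imagenet_stats)   # match the statistics of the pretrained resnet
```

With this in place, data = build_data(path) would give an object that show_batch can be called on.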

Training

Training was by far the most astonishing part. After just 1 epoch (one pass over all the images), the mean validation error was already at and it decreased by an order of magnitude over the next 3 epochs, for a final validation loss of just and an error of in under 20 minutes of work!

I used the resnet34 model for this project, as does the course. Also, I found that using more than 4 epochs didn’t yield significantly better results.

learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
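
The error_rate metric passed to the learner is the fraction of validation images whose predicted class differs from the true label. A plain-Python sketch of that idea, using made-up labels (this is not fastai's implementation, which works on tensors; the name simple_error_rate is hypothetical):

```python
def simple_error_rate(predicted, actual):
    """Return the fraction of predictions that differ from the true labels."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    if not actual:
        raise ValueError("at least one example is required")
    wrong = sum(p != a for p, a in zip(predicted, actual))
    return wrong / len(actual)


# Made-up example: one of four predictions is wrong.
predicted = ["apple", "banana", "cherry", "apple"]
actual = ["apple", "banana", "apple", "apple"]
print(simple_error_rate(predicted, actual))  # 0.25
```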

Results

Using ClassificationInterpretation we can inspect the results of the model.

interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
len(data.valid_ds)==len(losses)==len(idxs)

We can plot the top losses like so:

interp.plot_top_losses(9, figsize=(15,11), heatmap=False)

Because our error was 0, there aren’t any examples the model got wrong in this project, so presumably it shows the examples the model was least confident about, i.e. those with the highest loss.
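
Even when every prediction is correct, each validation example still has a loss, and it is larger the less probability the model put on the true class. A small sketch of that ranking with made-up probabilities, assuming cross-entropy loss (the function name top_losses_sketch is hypothetical, and this is not how fastai computes it internally):

```python
import math


def top_losses_sketch(true_class_probs, k=None):
    """Return (loss, index) pairs sorted from highest to lowest cross-entropy loss."""
    # Cross-entropy for one example is -log(probability given to the true class).
    losses = [(-math.log(p), i) for i, p in enumerate(true_class_probs)]
    losses.sort(reverse=True)
    return losses[:k]


# Made-up probabilities assigned to the correct class of four images.
# All are above 0.5, so every prediction is right, yet the losses differ.
probs = [0.99, 0.60, 0.95, 0.80]
for loss, idx in top_losses_sketch(probs, k=2):
    print(idx, round(loss, 3))
```

The two examples at the top are the ones with probabilities 0.60 and 0.80, which are the least confident correct predictions.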

The confusion matrix, which is huge given the 114 classes, shows for each actual class how many examples the model assigned to each predicted class. Again, because our model made no mistakes on the validation set, only the diagonal has non-zero values.
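
To make the idea concrete, here is a plain-Python sketch of a confusion matrix stored as counts of (actual, predicted) pairs, with made-up labels. The helper names are hypothetical; in the notebook itself the interp object from above is what produces the matrix.

```python
from collections import Counter


def confusion_counts(actual, predicted):
    """Count how often each (actual, predicted) label pair occurs."""
    return Counter(zip(actual, predicted))


def off_diagonal(counts):
    """Keep only the misclassifications, i.e. pairs where actual != predicted."""
    return {pair: n for pair, n in counts.items() if pair[0] != pair[1]}


# Made-up labels for a perfect classifier: every pair lies on the diagonal.
actual = ["apple", "banana", "apple", "cherry"]
predicted = ["apple", "banana", "apple", "cherry"]
counts = confusion_counts(actual, predicted)
print(counts[("apple", "apple")])  # 2
print(off_diagonal(counts))        # {}
```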

Fine tuning learning rates

By unfreezing the model (training the first layers of the network as well as the last layers) we can halve our already low validation loss to .

To decide what range to use for the learning rate, we look at the learning rate graph, which plots the loss against the learning rate. Plotting this graph is super easy using fastai:

learn.lr_find()
learn.recorder.plot()

As we see in the figure above, we got the lowest loss at around 1e-6 up to 10^-5.5 (about 3e-6), so those are the values I decided to use.

learn.unfreeze()
learn.fit_one_cycle(2, max_lr=slice(1e-6,10 ** -5.5))
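
Passing a slice as max_lr asks fastai to train different layer groups with different learning rates: the smallest for the earliest layers and the largest for the last ones. The sketch below assumes the rates are spaced evenly on a log scale between the two ends, which is how fastai v1 appears to handle a slice with both ends set. The function name and the choice of three layer groups are illustrative assumptions.

```python
def spread_learning_rates(start, stop, n_groups):
    """Return n_groups learning rates spaced geometrically from start to stop."""
    if n_groups == 1:
        return [stop]
    # Constant multiplicative step between neighbouring layer groups.
    ratio = (stop / start) ** (1 / (n_groups - 1))
    return [start * ratio ** i for i in range(n_groups)]


# Same end points as in the call above, spread over an assumed 3 layer groups.
rates = spread_learning_rates(1e-6, 10 ** -5.5, 3)
for group, lr in enumerate(rates):
    print(group, f"{lr:.2e}")
```

Under these assumptions the middle group trains at roughly 1.8e-6, between 1e-6 for the first layers and about 3.2e-6 for the last.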

After just 1 epoch of not even 5 minutes, our error is at which is incredible.

Conclusion

As you’ve seen in this tutorial, it’s quite easy to train a near-perfect model in under an hour by writing just a few lines of code using fastai. As always, the code is on my GitHub. I was amazed by the results and look forward to the following lesson!

A huge thanks to Sam Miserendino for proofreading this post!