Tutorial: K-Nearest Neighbor classifier for MNIST

Previously, we looked at the Bayes classifier for MNIST data, using a multivariate Gaussian to model each class.

We use the same dimensionality-reduced dataset here.

The K-Nearest Neighbor (KNN) classifier is also often used as a “simple baseline” classifier, but there are a couple of interesting distinctions from the Bayes classifier.

1) KNN does not use probability distributions to model data. In fact, many powerful classifiers do not assume any probability distribution on the data.

2) KNN is a “lazy” classifier. No work is actually done at training time; the model just stores the input points X and their labels Y.
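
For illustration, the entire “training” step of a hypothetical KNN class (the class and names here are made up) could be just:

class KNN:
    def fit(self, X, Y):
        # "training" is nothing more than memorizing the data
        self.X = X
        self.Y = Y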

At classification time, the predicted class/label is chosen by looking at the “k nearest neighbors” of the input test point.

We choose “k” beforehand.

Here is a visual example for k = 3:

[Image: a test point with its 3 nearest neighbors, k = 3]

So for 3-NN, we go through all the training points, find the 3 closest ones, and choose the label that shows up the most.

e.g. if we get 2 whites and 1 black, we choose white. This is called “voting”.
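
In Python, this vote can be tallied with collections.Counter; a tiny sketch, with made-up colors standing in for labels:

from collections import Counter

neighbor_labels = ['white', 'white', 'black']  # labels of the 3 nearest neighbors
print(Counter(neighbor_labels).most_common(1)[0][0])  # prints 'white'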

Here is another example, where we might get a different answer depending on whether k = 3 or k = 5.

[Image: a test point whose predicted class differs between k = 3 and k = 5]

The idea is that our data should be clustered in “space”: points from the same class should be near each other. In fact, this idea prevails across pretty much all machine learning algorithms, so you should understand it well.

If data points that are near each other are from the same class, then it follows that a test point can be classified by looking at its neighbors’ classes.

Translated into Python (assuming k, the training data, and a distance function dist are already defined), the naive version might look like this:

from collections import Counter

def predict(x_test):
    Q = []  # holds (distance, label) pairs for the k nearest neighbors found so far
    for x, y in training_data:
        d = dist(x, x_test)
        if len(Q) < k or d < max(Q)[0]:
            Q.append((d, y))
            if len(Q) > k:
                Q.remove(max(Q))  # drop the farthest of the k+1 candidates
    # return the label that shows up most among the k neighbors
    return Counter(label for _, label in Q).most_common(1)[0][0]

A naive implementation loops through all N training points, and finding the max distance in Q takes O(k) per point, for a total of O(kN).

You could instead use a sorted data structure (such as a heap) for Q, making each insert and removal O(log k), for a total of O(N log k).

Here is the pseudocode that does this:

http://bit.ly/2KIYu12
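
As a sketch of that idea, here is a version using Python's heapq module (a min-heap, so we store negated distances to keep the farthest kept neighbor on top); X, Y, and dist are assumed to be the training data and a distance function:

import heapq
from collections import Counter

def predict(x_test, X, Y, k, dist):
    Q = []  # max-heap via negated distances: Q[0] is the farthest kept neighbor
    for x, y in zip(X, Y):
        d = dist(x, x_test)
        if len(Q) < k:
            heapq.heappush(Q, (-d, y))
        elif d < -Q[0][0]:
            heapq.heapreplace(Q, (-d, y))  # pop the farthest, push the new, O(log k)
    # majority vote among the k nearest neighbors
    return Counter(label for _, label in Q).most_common(1)[0][0]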

k = 1 achieves a test classification rate of 94.8%!
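
If you'd rather not roll your own, scikit-learn's KNeighborsClassifier does the same thing; a minimal sketch, assuming the dimensionality-reduced data is already split into Xtrain, Ytrain, Xtest, Ytest (these variable names are placeholders):

from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier(n_neighbors=1)  # k = 1
model.fit(Xtrain, Ytrain)                    # "training" just stores the data
print(model.score(Xtest, Ytest))             # test classification rate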

Learn about KNN and more in the course Data Science: Supervised Machine Learning in Python.