Bayes classifier and Naive Bayes tutorial (using the MNIST dataset)

The Naive Bayes classifier is a simple classifier that is often used as a baseline for comparison with more complex classifiers.

It is also conceptually very simple; as you’ll see, it is just a fancy application of Bayes’ rule from your probability class.

We will use the famous MNIST dataset for this tutorial.

The MNIST dataset is a set of handwritten digits, and our job is to build a computer program that takes an image of a digit as input and outputs which digit it is.

[Image: MNIST digits]

Recall Bayes’ rule:

$$ P(c | x) = \frac{P(x | c)P(c)}{P(x)} $$

If you’re like me, you may have found this notation a little confusing at first.

Here \( x \) represents the image, or more precisely, the pixel values of the image formatted as a vector, and \( c \) represents the digit, which can be 0, 1, …, 9.

Sidenote: Images are grayscale and of size 28×28. If flattened, each image yields a vector of length 28×28 = 784. If you look at the code (linked below), you can see how this is done. You can also see that we scale the pixel values to be between 0 and 1. Scaling isn’t strictly necessary, but it can be useful for many machine learning algorithms.
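
The flattening and scaling steps can be sketched as follows. This is a minimal illustration, not the linked code. It assumes the raw images arrive as a NumPy array of shape (N, 28, 28) with integer pixel values 0..255, and the function name `preprocess` is hypothetical.

```python
import numpy as np

def preprocess(images: np.ndarray) -> np.ndarray:
    """Flatten (N, 28, 28) images with values 0..255 into (N, 784) floats in [0, 1]."""
    n = images.shape[0]
    flat = images.reshape(n, 28 * 28)        # each row is one 784-dimensional vector x
    return flat.astype(np.float64) / 255.0   # scale 0..255 down to 0..1

# Fake data standing in for real MNIST images
rng = np.random.default_rng(0)
fake = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)
X = preprocess(fake)
print(X.shape)  # (5, 784)
```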

We can read the left side, \( P(c | x) \), as “the probability that the class is \( c \) given the data \( x \)”. This is called the “posterior”.

On the right side, \( P(x | c) \) reads as “the probability of observing the data \( x \) given that the class is \( c \)”. This is called the “likelihood”. The remaining term in the numerator, \( P(c) \), is called the “prior”.

One little efficiency trick we can do:

We don’t actually care about the exact value of \( P(c | x) \).

We care about the value of \( c \) itself. That tells us “which digit” the image belongs to.

The class chosen is simply the one that yields the highest probability for that data:

$$ c^* = \arg\max_{c} P(c | x) = \arg\max_{c} \frac{ P(x | c)P(c) }{ P(x) } $$

Notice that \( P(x) \) does not depend on \( c \). It has the same value no matter which class we plug into \( P(c | x) \).

So when I take the argmax over \( \frac{ P(x | c)P(c) }{ P(x) } \) I can ignore \( P(x) \).

As a simple example, suppose I have that \( A > B \), or numerically, say, \( 10 > 5 \).

If I multiply or divide both numbers by the same positive constant (\( P(x) \) will always be positive), the relationship still holds.

For example, \( 2A > 2B \), or \( 20 > 10 \).

Using this information, we can simplify our problem so that, in order to choose “which digit” given an image, all we need to do is calculate this argmax (notice \( P(x) \) is removed):

$$ c^* = \arg\max_{c} P(x | c)P(c) $$

The next step we can take is to think about how to calculate \( P(c) \).

This is just counting.

If I have 100 students in my class, and I want to figure out the probability that a student is born in January, how can I do that?

I simply count up all the students born in January, and divide by the total number of students.

If I want to know the probability of getting heads when I flip a coin, I flip the coin a bunch of times, and divide the number of heads by the total number of coin flips.

Therefore:

$$ P(c) = \frac{ \text{number of images of the digit } c }{ \text{total number of images} } $$
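
The counting above takes one line in NumPy. This is a minimal sketch, assuming the labels are stored as a NumPy integer array with values 0..9. The function name `class_priors` is hypothetical, and the labels below are toy values, not real MNIST counts.

```python
import numpy as np

def class_priors(y: np.ndarray, num_classes: int = 10) -> np.ndarray:
    """P(c) = (number of images labeled c) / (total number of images)."""
    counts = np.bincount(y, minlength=num_classes)  # how many images of each digit
    return counts / len(y)

y = np.array([0, 1, 1, 2, 2, 2, 9, 9, 9, 9])  # toy labels, not real MNIST
priors = class_priors(y)
print(priors[2])  # 0.3
```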


The challenge is choosing a model that accurately fits the data for \( P(x | c) \).

As a thought-exercise, think about how you’d do this naively.

Each image is of size 28×28, which means there are 784 (=28×28) pixels per image. Each pixel can take on integer values in the range 0..255 inclusive.

So, if you modeled this as a discrete probability distribution, you’d have \( 256^{784} \) different possibilities (256 possible values for each of the 784 pixels). That’s vastly more than the number of images you have (~50,000). You’d therefore never be able to use the “counting method” (used above for calculating \( P(c) \)) to accurately measure those probabilities.
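
To get a feel for the size of \( 256^{784} \), Python’s arbitrary-precision integers can compute it exactly. This is just arithmetic, not part of the classifier:

```python
import math

num_images = 256 ** 784            # every 28x28 image with 256 possible gray levels per pixel
print(len(str(num_images)))        # 1889 (decimal digits)
print(784 * math.log10(256))       # about 1888.06, so roughly 10**1888
```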


To make the problem tractable and easily computable, we recall that pixels represent light intensity, and light intensity is actually continuous. It’s only discrete inside a computer because computers are discrete.

A reasonable first guess for modeling continuous data is the multivariate Gaussian, also called the multivariate Normal distribution.

We can say that:

$$ P(x | c) = \frac{1}{\sqrt{ (2\pi)^D |\Sigma| }} \exp\left( -\frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu) \right) $$

Note that because the data are continuous, \( P(x | c) \) on the right is a probability density, not a probability. Luckily, Bayes’ rule still holds for probability densities.

Another thing to note: in high dimensions, likelihood values can be extremely close to 0, so we’ll work with the log-likelihood instead of the likelihood. Floating-point arithmetic can round such tiny numbers to exactly 0 (underflow). With logs, we’ll get ordinary-sized numbers (often negative) that a computer can represent accurately.
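
The underflow problem is easy to reproduce. This toy sketch multiplies 784 copies of a made-up per-pixel value of 0.01, an arbitrary illustrative number rather than a real MNIST density, and compares the result with the equivalent sum of logs:

```python
import math

p = 0.01            # made-up per-pixel value, for illustration only
product = 1.0
log_sum = 0.0
for _ in range(784):
    product *= p                # direct product of 784 small numbers
    log_sum += math.log(p)      # equivalent sum of logs

print(product)   # 0.0: the true value, 1e-1568, underflows in float64
print(log_sum)   # about -3610.45, easily representable
```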

The log-likelihood can be represented as:

$$ \log P(x | c) = -\frac{D}{2}\ln(2\pi) - \frac{1}{2}\ln|\Sigma| - \frac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu) $$

Try deriving this yourself by taking the natural log of the density above. Here \( D \) is the dimensionality of \( x \) (784 for MNIST).
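
Here is a minimal NumPy sketch of that log-density for a single image vector. The function name `gaussian_log_pdf` is hypothetical. It uses `np.linalg.slogdet` and `np.linalg.solve` rather than computing \( |\Sigma| \) and \( \Sigma^{-1} \) directly. That is a standard choice, and my assumption rather than necessarily what the linked code does: in 784 dimensions, the plain determinant can easily underflow or overflow.

```python
import numpy as np

def gaussian_log_pdf(x: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> float:
    """log N(x; mu, Sigma) for one D-dimensional vector x."""
    d = x.shape[0]
    diff = x - mu
    sign, logdet = np.linalg.slogdet(sigma)     # ln|Sigma| without forming |Sigma|
    if sign <= 0:
        raise ValueError("Sigma must be positive definite")
    quad = diff @ np.linalg.solve(sigma, diff)  # (x - mu)^T Sigma^{-1} (x - mu)
    return float(-0.5 * d * np.log(2 * np.pi) - 0.5 * logdet - 0.5 * quad)
```

For example, take \( D = 2 \), \( \mu = 0 \), and \( \Sigma = I \). Evaluating at \( x = 0 \) gives \( -\ln(2\pi) \approx -1.8379 \).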


By the way, to calculate \( \mu \) and \( \Sigma \), you can use the sample mean and covariance: https://en.wikipedia.org/wiki/Sample_mean_and_covariance. Note that it’s \( P(x | c) \), not just \( P(x) \). So, if we want to calculate the mean and covariance for all the images of the digit 9, then we’d first grab only the images of the digit 9 (ignoring the rest of the images), and calculate the sample mean and covariance from this subset.
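
In NumPy, selecting one class and computing its sample statistics might look like this. The sketch assumes `X` is the (N, 784) array of scaled images and `y` is the matching label array, and `class_mean_cov` is a hypothetical helper name. Note that `np.cov` with `rowvar=False` treats each row as one observation and divides by \( N - 1 \).

```python
import numpy as np

def class_mean_cov(X: np.ndarray, y: np.ndarray, c: int):
    """Sample mean and covariance using only the rows of X labeled c."""
    Xc = X[y == c]                      # e.g. only the images of the digit 9
    mu = Xc.mean(axis=0)                # sample mean, shape (D,)
    sigma = np.cov(Xc, rowvar=False)    # sample covariance (divides by N - 1), shape (D, D)
    return mu, sigma

# Toy data: three points labeled 9, one labeled 0 (ignored when c = 9)
X = np.array([[1.0, 2.0], [3.0, 6.0], [5.0, 10.0], [100.0, 100.0]])
y = np.array([9, 9, 9, 0])
mu9, sigma9 = class_mean_cov(X, y, 9)
```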

Earlier, we wanted the argmax over \( P(x | c)P(c) \). Since \( \log(AB) = \log A + \log B \), we can switch to log probabilities and choose the digit class using:

$$ c^* = \arg\max_{c} \left( \log P(x | c) + \log P(c) \right) $$

This works because the \( \log \) function is monotonically increasing: for positive \( A \) and \( B \), if \( A > B \) then \( \log A > \log B \). Try any 2 positive numbers on your calculator if you don’t believe me.

Now this problem is tractable.

Training the classifier would work as follows:

For each class c = 0..9
    get all x’s (images) for the class
    save the mean and covariance of those x’s, plus the prior P(c), with the class

Prediction using the classifier would work as follows:

Given a test point x:
    Calculate the probability that x is in each class c
    Return the class c that yields the highest posterior probability
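
The pseudocode above can be turned into a small runnable sketch. This is an illustration under stated assumptions, not the author’s linked code. The class name `GaussianBayes` is hypothetical, and the `smoothing` term added to the diagonal of each covariance is my own assumption. It is needed in practice because some MNIST pixels (for example, in the corners) are 0 in every image. Those pixels have zero variance, which makes the raw sample covariance singular and therefore not invertible.

```python
import numpy as np

class GaussianBayes:
    """Non-naive Bayes classifier: one full-covariance Gaussian per class."""

    def __init__(self, smoothing: float = 1e-2):
        # Assumption: added to each covariance diagonal so it can be inverted
        self.smoothing = smoothing

    def fit(self, X: np.ndarray, y: np.ndarray) -> "GaussianBayes":
        self.classes = np.unique(y)
        D = X.shape[1]
        self.params = {}
        for c in self.classes:
            Xc = X[y == c]                                   # all x's (images) for class c
            mu = Xc.mean(axis=0)
            sigma = np.cov(Xc, rowvar=False) + self.smoothing * np.eye(D)
            log_prior = np.log(len(Xc) / len(X))             # log P(c), by counting
            self.params[c] = (mu, sigma, log_prior)
        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        N, D = X.shape
        scores = np.empty((N, len(self.classes)))
        for k, c in enumerate(self.classes):
            mu, sigma, log_prior = self.params[c]
            _, logdet = np.linalg.slogdet(sigma)
            diff = X - mu                                    # shape (N, D)
            # (x - mu)^T Sigma^{-1} (x - mu) for every row of X at once
            quad = np.sum(diff * np.linalg.solve(sigma, diff.T).T, axis=1)
            log_lik = -0.5 * D * np.log(2 * np.pi) - 0.5 * logdet - 0.5 * quad
            scores[:, k] = log_lik + log_prior               # log P(x | c) + log P(c)
        return self.classes[np.argmax(scores, axis=1)]       # class with the highest score
```

Usage would look like `GaussianBayes().fit(Xtrain, Ytrain).predict(Xtest)`, where the array names are hypothetical placeholders for the flattened, scaled MNIST data. The scores are \( \log P(x | c) + \log P(c) \), never normalized into true posteriors, because \( P(x) \) was dropped from the argmax.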


What makes a Bayes classifier a Naive Bayes classifier?

So far we have only discussed general Bayes classifiers.

A Naive Bayes classifier is one that assumes all input dimensions of \( x \) are independent given the class \( c \) (that is, conditionally independent).

Recall that when 2 random variables \( A \) and \( B \) are independent, their joint probability can be expressed as the product of their individual probabilities:


$$ P(A, B) = P(A)P(B) $$


For the Gaussian case, this means that instead of a joint Gaussian with a full covariance matrix \( \Sigma \), we can express \( P(x | c) \) as a product of 784 individual univariate Gaussians:

$$ P(x | c) = \prod_{i=1}^{784} \frac{1}{\sqrt{2\pi\sigma_i^2}} \exp\left( -\frac{1}{2} \frac{(x_i - \mu_i)^2}{\sigma_i^2} \right) $$
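
Because the logarithm turns that product into a sum, the Naive Bayes log-likelihood is just a sum of 784 univariate terms. Below is a minimal sketch, assuming `mu` and `var` hold the per-pixel means and variances for one class (hypothetical names). As in the full-covariance case, pixels with zero variance would need a small positive value added to `var` before use; that is an assumption on my part.

```python
import numpy as np

def naive_log_likelihood(x: np.ndarray, mu: np.ndarray, var: np.ndarray) -> float:
    """log P(x | c) when every pixel is an independent univariate Gaussian."""
    # Log of each factor: -0.5 * ln(2*pi*var_i) - 0.5 * (x_i - mu_i)^2 / var_i
    per_pixel = -0.5 * np.log(2 * np.pi * var) - 0.5 * (x - mu) ** 2 / var
    return float(np.sum(per_pixel))     # log of a product = sum of the logs
```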


One advantage of Naive Bayes is that we don’t have to worry about any interactions between any \( x_i \) and \( x_j \) if \( i \neq j \).

In more practical terms, before we had \( \Sigma \), which is of size \( D \times D = 784^2 \).

Now, we only have \( \sigma_i^2, i=1…784 \).

That’s 784 times fewer numbers to store.

You can still express the product of 784 univariate Gaussians as one big multivariate Gaussian. It just means that the covariance matrix \( \Sigma \) is zero everywhere except along the diagonal, which stores the 784 univariate variances.

Having 784 individual variances means we don’t have to invert \( \Sigma \) to calculate the PDF or log PDF, which leads to even more savings.
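
To make these savings concrete, this sketch builds a diagonal \( \Sigma \) from 784 variances, using arbitrary random values. It checks that the general matrix computations of \( \ln|\Sigma| \) and \( (x - \mu)^T \Sigma^{-1} (x - \mu) \) match their diagonal shortcuts, which need no matrix at all:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 784
var = rng.uniform(0.1, 2.0, size=D)    # arbitrary positive per-pixel variances
mu = rng.normal(size=D)
x = rng.normal(size=D)
sigma = np.diag(var)                   # full D x D matrix, zero off the diagonal

# General route: log-determinant and a linear solve on the full matrix, O(D^3)
_, logdet_full = np.linalg.slogdet(sigma)
quad_full = (x - mu) @ np.linalg.solve(sigma, x - mu)

# Diagonal shortcut: no matrix at all, O(D)
logdet_diag = np.sum(np.log(var))
quad_diag = np.sum((x - mu) ** 2 / var)

print(np.isclose(logdet_full, logdet_diag), np.isclose(quad_full, quad_diag))  # True True
```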

The downside of Naive Bayes is that the Naive assumption (that all input dimensions are independent) is most often incorrect.


So what do we get after training our model?

Visually, we can “see” what the model has learned by plotting the mean \( \mu \) for each class.

Here are the plots you’d get:

[Images: the learned mean \( \mu \) of each digit class, 0 through 9, each shown as a 28×28 image]

As you’d expect, the mean of each class very closely captures what that digit typically looks like.
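
Plots like these can be produced by reshaping each class mean back to 28×28. This is a minimal sketch, assuming `X` is the (N, 784) array of scaled images and `y` is the labels. The helper names are hypothetical. Plotting uses matplotlib, which is imported inside `plot_means` so the averaging helper works without it.

```python
import numpy as np

def class_mean_images(X: np.ndarray, y: np.ndarray) -> dict:
    """Map each digit c to its mean image, reshaped from 784 back to 28x28."""
    return {int(c): X[y == c].mean(axis=0).reshape(28, 28) for c in np.unique(y)}

def plot_means(means: dict) -> None:
    import matplotlib.pyplot as plt     # only needed for plotting
    fig, axes = plt.subplots(1, len(means), figsize=(2 * len(means), 2))
    for ax, (c, img) in zip(np.atleast_1d(axes), sorted(means.items())):
        ax.imshow(img, cmap="gray")
        ax.set_title(str(c))
        ax.axis("off")
    plt.show()
```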

The code for this tutorial can be found here:

Non-Naive Bayes: https://bit.ly/2oWVc1N

Naive Bayes: https://bit.ly/2FvP2fm


With the (non-naive) Bayes classifier, we get about 94% accuracy on the test set, which is pretty good!

Notice that with Naive Bayes we achieve only about 80% on the test set, because the Naive assumption is pretty obviously incorrect here. For example, if a pixel in one of the corners is black, aren’t the pixels around it also very likely to be black?

Learn about KNN and more in the course Data Science: Supervised Machine Learning in Python.