Code for this tutorial is here:
Prerequisites for understanding this material:
- calculus (taking partial derivatives)
Linear regression is one of the simplest machine learning techniques you can use. It is often useful as a baseline relative to more powerful techniques.
To start, we will look at a simple 1-D case.
Like all regressions, we wish to map some input X to some output Y,
i.e.
Y = f(X)
With linear regression:
Y = aX + b
Or we can say:
h(X) = aX + b
Where “h” is our “hypothesis”.
You may recall from your high school studies that this is just the equation for a straight line.
When X is 1-D, i.e. when Y has only one explanatory variable, we call this “simple linear regression”.
When we use linear regression, we are using it to model linear relationships, or what we think may be linear relationships.
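To make this concrete, here is a tiny Python sketch of the hypothesis (the values of a and b below are made up purely for illustration; we have not fitted anything yet):

def h(X, a, b):
    # our hypothesis: a straight line with slope a and intercept b
    return a * X + b

print(h(3.0, a=2.0, b=1.0))  # predicted Y = 7.0 when X = 3.0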
As with all supervised machine learning problems, we are given labeled data points:
(X1, Y1), (X2, Y2), (X3, Y3), …, (Xn, Yn)
And we will try to fit the line (aX + b) as best we can to these data points.
This means we have to optimize the parameters “a” and “b”.
How do we do this?
We will define an error function and then find the “a” and “b” that will make the error as small as possible.
You will see that many regression problems work this way.
What is our error function?
We could use the difference between the predicted Y and the actual Y like so:
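Error = Σ (Yi − Ŷi), where Ŷi = aXi + b is our prediction for Xi.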
But if we had equal amounts of error where Y was bigger than the prediction and where Y was smaller than the prediction, the errors would cancel out, even though each individual error might be large.
Typically in machine learning, the squared error is a good place to start.
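E = Σ (Yi − Ŷi)²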
Now, whether or not the difference in the actual and predicted output is positive or negative, its contribution to the total error is still positive.
We call this sum the “sum of squared errors”.
Recall that we want to minimize it.
Recall from calculus that to minimize something, you want to take its derivative.
Because there are two parameters, we have to take the derivatives both with respect to a and with respect to b, set them to 0, and solve for a and b.
Luckily, because the error function is quadratic in a and b, it has a single minimum: the error only increases as (a, b) get further and further away from it, so setting the derivatives to 0 gives us the best possible a and b.
As an exercise I will let you calculate the derivatives.
You will get 2 equations (the derivatives) and 2 unknowns (a, b). From high school math you should know how to solve this by rearranging the terms.
Note that these equations can be solved analytically, meaning you can just plug and chug the values of your inputs and get the final values of a and b by blindly using a formula.
Note that this method is also called “ordinary least squares”.
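To make the “plug and chug” concrete, here is a minimal NumPy sketch of the closed-form solution (the data below is made up for illustration; see the code linked at the top for the tutorial’s own version):

import numpy as np

# made-up data, roughly Y = 2X + 1 plus noise
X = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
Y = np.array([3.1, 4.9, 7.2, 9.1, 10.8])

# ordinary least squares solution for the slope a and intercept b
denominator = X.dot(X) - X.mean() * X.sum()
a = (X.dot(Y) - Y.mean() * X.sum()) / denominator
b = Y.mean() - a * X.mean()

Yhat = a * X + b  # the fitted line's predictions
print(a, b)       # should come out close to 2 and 1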
Measuring the error (R-squared)
To determine how well our model fits the data, we need a measure called the “R-square”.
Note that in classification problems, we can simply use the “classification rate”, which is the number of correctly classified inputs divided by the total number of inputs. With the real-valued outputs we have in regression, this is not possible.
Here are the equations we use to compute the R-square:
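SS(residual) = Σ (Yi − Ŷi)²
SS(total) = Σ (Yi − Ȳ)²
R-square = 1 − SS(residual) / SS(total)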
SS(residual) is the sum of squared error between the actual and predicted output. This is the same as the error we were trying to minimize before!
SS(total) is the sum of squared error between each sample output and the mean of all the sample outputs, i.e. what the residual error would be if we just predicted the average output every time.
So the R-square, then, is just a measure of how much better our model is compared to predicting the mean each time. If we just predicted the mean each time, the R-square would be 1 - 1 = 0. If our model is perfect, then the R-square would be 1 - 0 = 1.
Something to think about: If our model performs worse than predicting the mean each time, what would be the R-square value?
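As a quick sketch, assuming NumPy arrays Y (the actual outputs) and Yhat (the model’s predictions, as in the earlier code), the R-square can be computed like this:

import numpy as np

def r_square(Y, Yhat):
    # 1 means a perfect fit; 0 means no better than always predicting the mean
    ss_residual = ((Y - Yhat) ** 2).sum()
    ss_total = ((Y - Y.mean()) ** 2).sum()
    return 1 - ss_residual / ss_total

# e.g. with the arrays from the ordinary least squares sketch above:
# print(r_square(Y, a * X + b))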
Limitations of Linear Regression
- It can only model linear relationships. You can model higher order polynomials (link to later post), but the model is still linear in its parameters (see the sketch after this list).
- It is sensitive to outliers, meaning that if we have one data point very far away from all the others, it could “pull” the regression line in its direction, away from all the other data points, just to minimize the squared error.
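As a sketch of the first point (the data here is made up for illustration): fitting Y = a2*X² + a1*X + a0 is still a linear regression problem, because the model is linear in the parameters a2, a1, a0; only the features are nonlinear.

import numpy as np

X = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
Y = np.array([1.2, 3.9, 11.1, 22.3, 37.0])  # roughly 2X² + X + 1 with noise

# design matrix with columns [X², X, 1]: nonlinear features, linear parameters
A = np.column_stack([X**2, X, np.ones_like(X)])
coeffs, *_ = np.linalg.lstsq(A, Y, rcond=None)  # ordinary least squares
print(coeffs)  # approximately [2, 1, 1]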
Learn more about Linear Regression in the course Deep Learning Prerequisites: Linear Regression in Python.