Bridging the Gap from Simple Algebra to Machine Learning
You probably know more about machine learning math than you think
The goal of this article is to help you:
Understand that machine learning math finds its roots in math you likely already know.
Understand conceptually how to get from the math you know to ML math.
This is a follow-up to an article I wrote previously that clarifies machine learning at a very high level and explains how it works. If you haven’t read that and don’t know how ML works, I would suggest reading it first.
Taking ML Back to its Roots
I see machine learning referred to as “magic” quite a bit. I know people don’t literally believe this, but the math behind ML can feel so out of reach that it’s easier to think of it that way. That really isn’t the case. The magic behind machine learning starts with algebra. Yes, the same algebra you learned in grade school is fundamental to machine learning math.
Let’s start with a basic principle you've likely already learned: the equation y=mx+b, known as the slope-intercept form of a line. You can plot this line on a graph by plugging in an x value and computing its corresponding y value. When you do this, the line will cross the y-axis at y=b, and the slope of the line will be equal to m. This means that for every one-unit increase in the x direction, the line rises by m in the y direction.
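To make the slope and intercept concrete, here is a small Python sketch using m = 2 and b = 5 (values picked only for illustration). It computes y for a few x values and checks that the line crosses the y-axis at b and rises by m for each one-unit step in x.

```python
# Slope-intercept form y = m*x + b, with illustrative values m = 2 and b = 5.
m, b = 2, 5

xs = [0, 1, 2, 3, 4]
ys = [m * x + b for x in xs]
print(list(zip(xs, ys)))  # [(0, 5), (1, 7), (2, 9), (3, 11), (4, 13)]

# At x = 0 the line crosses the y-axis at y = b.
print(ys[0] == b)  # True

# Each one-unit step in x adds m to y (it adds m; it does not multiply y by m).
steps = [ys[i + 1] - ys[i] for i in range(len(ys) - 1)]
print(steps)  # [2, 2, 2, 2]
```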
Schools focus on graphing these equations because it helps us visualize them. But graphs aren't what make equations like this important. What makes them important is that you can give the equation an input and it will give you an output. In this case, you can choose an x and the equation will give you a corresponding y.
In higher-level math and computer science, we call this equation a function. Functions are the key to understanding relationships between two (or more) variables. Eventually, in most algebra curricula, you'll see the y you’re used to seeing in equations change to f(x) (e.g. f(x)=mx+b). This just renames the output variable of the equation to reflect that it’s a function rather than just the y value on a graph. In plain English, f(x) is read aloud as “f of x.” Here, x is our input, and f is our function that acts upon x to give us an output. To put it simply: f(x)=y, but instead of emphasizing a y-coordinate, f(x) puts the focus on the equation being a function.
A function can be visualized as shown above. It's important to note that we use functions to define the relationships between multiple entities. In our f(x)=mx+b example, the equation models a relationship between an input (x) and an output (f(x)).
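In code, a function is exactly this kind of input-to-output rule. The Python sketch below uses a hypothetical helper, make_line, that takes a slope and an intercept and returns the function f(x)=mx+b. The fixed numbers m and b decide which line you get, while x is the input you vary.

```python
def make_line(m, b):
    """Return the function f(x) = m*x + b for a fixed slope m and intercept b."""
    def f(x):
        # f acts on the input x and returns the output f(x)
        return m * x + b
    return f

f = make_line(2, 5)   # f(x) = 2x + 5
print(f(0))   # 5
print(f(3))   # 11

g = make_line(-1, 4)  # a different line: g(x) = -x + 4
print(g(3))   # 1
```

The same recipe with different fixed numbers gives a different function, even though the input-to-output idea stays the same.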
This may seem really obvious, but you likely use functions every day without realizing it:
Managing money requires understanding how many of each bill you have and how they add to create a total.
When you're traveling somewhere, you calculate the time to get there based on the distance to that location and the speed you'll likely travel.
When cooking, you use quantities of ingredients to create a certain flavor profile.
All of these are relationships between an input and an output that you identify without even realizing you're doing it. This shows us the beauty of functions: they can be used to model relationships we understand in real life. This is exactly what machine learning is doing.
Machine learning uses advanced techniques to create functions that model relationships between inputs and outputs. Our slope-intercept form of a function isn't robust enough to capture the complex relationships machine learning models. In fact, it isn't even complex enough to model a relationship like the money example above, but it isn't far off. To model the money example, we can use a more complex function with one input for each type of bill: f(x,y,z)=10x+5y+z, where x, y, and z are the numbers of $10, $5, and $1 bills you have.
By adding more terms, we were able to capture the impact many different bills had on the output. Now let's take a look at how we can make this even more complex and how we get from where we are now to machine learning models.
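Written as code, the money function is just a function with three inputs. This is a minimal Python sketch; the name total_money is made up for illustration.

```python
def total_money(x, y, z):
    """f(x, y, z) = 10x + 5y + z, where x, y, z count $10, $5, and $1 bills."""
    return 10 * x + 5 * y + 1 * z

print(total_money(2, 1, 3))  # 28 = 2*$10 + 1*$5 + 3*$1

# Each term captures one bill's impact: one extra $5 bill adds exactly 5.
print(total_money(2, 2, 3) - total_money(2, 1, 3))  # 5
```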
If you’ve enjoyed reading up to this point, join Society’s Backend for more interesting AI breakdowns like this!
How do we model more complex relationships?
There are a few basic ways we can model more complex relationships, many of which you're probably familiar with:
Add more variables. We've already done this in our money example: we now have multiple inputs and one output. This allows us to model how several inputs together determine an output.
Add more complex terms. Think about a quadratic function such as f(x)=x²: the squared term lets the graph curve instead of staying a straight line, so it can model relationships a line can't. This one is most easily understood by visualizing it graphically (shown below).
Change the relationship between terms. Up to this point, we've only been adding terms together. We can model more complicated relationships by changing the way the terms interact with one another, for example by multiplying inputs together (as in xy) instead of only adding them.
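Each of these three changes can be sketched as a small Python function. The function names and coefficients below are arbitrary, chosen only to illustrate the idea: more inputs, a squared term, and an interaction term where inputs are multiplied.

```python
# 1. More variables: several inputs, one output (the money example).
def more_variables(x, y, z):
    return 10 * x + 5 * y + z

# 2. More complex terms: a squared term makes the graph a curve, not a line.
def quadratic(x):
    return x ** 2 - 2 * x + 1

# 3. Changing how terms interact: multiplying inputs instead of only adding them.
def interaction(x, y):
    return 3 * x * y + x + y

# A line changes by the same amount per step in x; this quadratic does not.
print([quadratic(x) for x in range(5)])  # [1, 0, 1, 4, 9]

# With an interaction term, the effect of raising x by 1 depends on y.
print(interaction(1, 0) - interaction(0, 0))  # 1
print(interaction(1, 2) - interaction(0, 2))  # 7
```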
Let's take a moment to visualize these changes we can make:
All of these changes allow us to model more complex relationships between inputs and outputs. Generally, as relationships get more complex, we need more complex equations to model them. The magic behind machine learning is using linear algebra to model even more complex relationships.
The equations above only have a few terms (in f(x)=2x+5, both 2x and 5 are considered a ‘term’). The numbers that define those terms, like the 2 and the 5, are what machine learning calls parameters, and an ML model can use billions of them to model a relationship. At that scale, the best way to organize all those parameters and the data they act on is through linear algebra, which uses matrices to organize and represent data. A matrix is similar to a spreadsheet in that it’s a two-dimensional arrangement of numbers.
A matrix allows us to perform large-scale operations (addition, multiplication, etc.) quickly by performing those operations on a matrix instead of between individual terms. We take the data we want an ML model to use when modeling a relationship and the parameters within our model and place them into matrices. This makes it easier to represent the complex relationships we need machine learning models to understand when we use billions of parameters to represent those models.
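Here is the money example recast as a matrix operation, as a rough sketch. Each row of the matrix is one wallet, the parameters 10, 5, and 1 form a vector, and a single matrix-vector product computes every wallet's total at once. It's written in plain Python (with a hypothetical matvec helper) to stay self-contained; numerical libraries such as NumPy perform the same operation, for example with the @ operator, far faster.

```python
# Each row is one wallet: [number of $10 bills, number of $5 bills, number of $1 bills].
wallets = [
    [2, 1, 3],
    [0, 4, 0],
    [1, 0, 7],
]
# The parameters of f(x, y, z) = 10x + 5y + z, stored as a vector.
weights = [10, 5, 1]

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector: one dot product per row."""
    return [sum(a * w for a, w in zip(row, vector)) for row in matrix]

totals = matvec(wallets, weights)
print(totals)  # [28, 20, 17]
```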
This is a very simplified view of linear algebra, but it's a good overview of why it's crucial for ML. For a full overview of linear algebra, I recommend Khan Academy’s Linear Algebra course.
What's the real benefit of ML?
In the simpler functions we've explored above, we've manually chosen the terms of the function and how they interact to best represent the relationship we want to model. We choose the function and its terms based on what we think works best for our data. For example, in our money-adding equation, we know there are $10, $5, and $1 bills available to us, so we know to include those in the function. This is how we arrive at the function f(x,y,z)=10x+5y+z.
Now, imagine if new bills were introduced that multiplied with the other bills to determine the total value of the money in your hand. That would be a much tougher function to come up with manually. Luckily, there are mathematical ways to do this. You've probably heard of a line of best fit. Given data, you can determine a line of best fit to fit a function to that data. It often looks like a line drawn through a scatterplot of the data points (see graphic below). That line is a function, and that function is what we use to mathematically represent the relationship within the data. While it may not perfectly represent the data, when we plug an input into the function, it gives us a likely output.
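One standard way to compute a line of best fit is ordinary least squares, which picks the m and b that minimize the sum of squared vertical distances between the line and the points. The sketch below applies the textbook formulas in plain Python; the function name and the data points are made up for illustration.

```python
def line_of_best_fit(xs, ys):
    """Ordinary least squares: return (m, b) minimizing sum((y - (m*x + b))**2)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # Slope = covariance of x and y divided by variance of x
    num = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    den = sum((x - mean_x) ** 2 for x in xs)
    m = num / den
    # The least-squares line passes through the point (mean_x, mean_y)
    b = mean_y - m * mean_x
    return m, b

# Made-up, slightly noisy points that roughly follow y = 2x + 1.
xs = [0, 1, 2, 3, 4]
ys = [1.1, 2.9, 5.2, 6.8, 9.0]
m, b = line_of_best_fit(xs, ys)
print(round(m, 2), round(b, 2))  # 1.97 1.06

# Plug a new input into the fitted function to get a likely output.
print(round(m * 5 + b, 2))  # 10.91
```

The fitted line doesn't pass through every point exactly, but it gives a reasonable output for inputs it hasn't seen, such as x = 5.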
As you can imagine, determining this line of best fit gets really complex as the data we're trying to model gets more complicated. This is what machine learning does for us.
The details of how this works depend on the type of machine learning we're doing, so I'll save them for another time. But machine learning algorithms use math to analyze the data and let the machine learning model find its own line of best fit, much more accurately and quickly than we could do it ourselves.
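As one illustration of a model "finding its own" line, here is a minimal sketch of gradient descent, a widely used technique (not the only one, and not necessarily what any particular model uses). It starts from a guess for m and b and repeatedly nudges both in the direction that reduces the mean squared error. The function name, learning rate, and step count are assumptions chosen to work on this toy data.

```python
def fit_line_gradient_descent(xs, ys, learning_rate=0.05, steps=2000):
    """Fit y ≈ m*x + b by gradient descent on the mean squared error."""
    m, b = 0.0, 0.0  # start from an arbitrary guess
    n = len(xs)
    for _ in range(steps):
        # Prediction errors with the current parameters
        errors = [(m * x + b) - y for x, y in zip(xs, ys)]
        # Partial derivatives of the mean squared error with respect to m and b
        grad_m = 2 / n * sum(e * x for e, x in zip(errors, xs))
        grad_b = 2 / n * sum(errors)
        # Move each parameter a small step downhill
        m -= learning_rate * grad_m
        b -= learning_rate * grad_b
    return m, b

# Made-up, slightly noisy points that roughly follow y = 2x + 1.
xs = [0, 1, 2, 3, 4]
ys = [1.1, 2.9, 5.2, 6.8, 9.0]
m, b = fit_line_gradient_descent(xs, ys)
print(round(m, 2), round(b, 2))  # 1.97 1.06
```

For this data it lands on essentially the same line the least-squares formula gives (slope about 1.97, intercept about 1.06). The advantage of the iterative approach is that the same idea extends to functions with far more parameters, where no simple formula exists. Once fitted, producing an output is just evaluating m * x + b for a new x.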
When an ML model is given an input, it uses its line of best fit to produce an output. Some real-world examples of models doing this include:
Choosing the posts to display on a social media feed
ChatGPT returning an answer to a question
Midjourney returning an image from a user’s prompt
Machine learning is built upon algebra you likely already know. It uses linear algebra and complex functions to represent complicated relationships between data. It also uses math to determine a line of best fit for that data similar to lines of best fit we all calculated in grade school.
The primary differences between simple functions and machine learning models are:
ML models are much better at determining this line of best fit than we are.
Machine learning can model relationships requiring billions of parameters.
Next week, I plan to go into more detail about how ML math actually works, broken down by type of ML. If this interests you, join Society’s Backend! We’re a network of individuals uncovering the engineering complexities behind the artificial intelligence you use every day. It’s free to join, and you’ll get more articles like this along with other member benefits.