Updated 8 DEC 2017: Improved the model and rewrote some parts
Since ancient times, it has been known that machines excel at math while humans are pretty good at detecting cats in pictures. But with the advent of deep learning, the boundaries have started to blur…
Today we want to teach the machines to do math - again! But instead of feeding them with optimized, known representations of numbers and calling hard-wired operations on them, we will feed in strings representing math formulas along with strings representing their results, character by character, and have the machine figure out how to interpret them and arrive at the results on its own.
The resulting model is much more human: Not all results are exact, but they are close. It’s more of an approximation than an exact calculation. As a 100% human myself, I surely can relate to almost-correct math, if anything.
Today’s post will be much more practical than usual, so if you want to build this at home, get your long-neglected Pythons out of their cages! You can find the complete code with lots of improvements here.
Recurrent Neural Networks and Long Short-Term Memory
As in my previous post, we’re going to use Recurrent Neural Networks, or RNNs for short. More specifically, a kind of RNN known under the fancy name of Long Short-Term Memory, or LSTM. If you need a primer on these, I can recommend Christopher Olah’s “Understanding LSTMs” and again Andrej Karpathy’s “The Unreasonable Effectiveness of Recurrent Neural Networks”. But to summarize: RNNs are essentially building blocks of Neural Networks that not only look at the current input, but also remember earlier inputs. LSTMs especially have a complex memory mechanism that can learn which parts of the data are important to remember, which can be ignored, and which can be forgotten.
Sequence to Sequence Learning
Sequence to sequence learning deals with problems in which a source sequence of inputs has to be mapped to a target sequence of outputs where each output is not necessarily directly dependent on a single input. The classical example is translation. How do you learn that a Chinese input phrase “他现在已经在路上了。” corresponds to “He is already on his way.” in English?
Even with RNNs it’s not directly obvious how to do this. Well, it wasn’t until Google published their paper “Sequence to Sequence Learning with Neural Networks” (Seq2Seq) in 2014. The idea is to train a joint encoder-decoder model. First, an encoder based on RNNs learns an abstract representation. Then a decoder also based on RNNs learns to decode it to another language, generating a new sequence from the encoding as output.
It turns out RNNs work much better if we let them look into the future as well. For this purpose, so-called bi-directional RNNs were invented. Each bi-directional RNN consists of two RNNs: One that looks at the sequence from start to end, and one that looks at it in reverse. At each part of the sequence we thus have information about what came before and what will come after. That way, an RNN can better learn about the context of each segment.
For most problems, data is hard to come by. Labeled data even more so. But math
equations are cheap: They can be easily generated, and Python gives us their
result simply by evaluating them.
Starting with simple addition of small, natural numbers, we can easily generate
a lot of equations along with their results and train a Seq2Seq model on it.
For example, a typical datapoint for addition could look like this:
input: '31 + 87'
output: '118'
Since we’re learning on just a fraction of all possible formulas, the model can’t just learn the results by heart. Instead, in order to generalize to all other equations, it really needs to “understand” what addition “means”. This could include, among others:
- An equation consists of numbers and operations
- Two numbers are added up digit-by-digit
- If two digits add up to a value higher than 9, they carry over to the next position
- The order of the two numbers does not matter:
a + b = b + a
- Adding zero does not actually do anything:
a + 0 = a
And this is just for addition! If we add more operations, or mix them, the network needs to grok even more rules like this.
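To make the digit-by-digit and carry rules from the list above concrete, here is a small sketch of schoolbook addition working purely on digit strings. The function name `add_digit_strings` is made up for illustration and is not part of the model code; it only spells out the rules we hope the network picks up by itself.

```python
def add_digit_strings(a, b):
    """Add two non-negative integers given as digit strings, column by column."""
    width = max(len(a), len(b))
    a, b = a.zfill(width), b.zfill(width)  # pad the shorter one with zeros
    digits, carry = [], 0
    # Walk from the least significant digit to the most significant one.
    for digit_a, digit_b in zip(reversed(a), reversed(b)):
        carry, digit = divmod(int(digit_a) + int(digit_b) + carry, 10)
        digits.append(str(digit))
    if carry:
        digits.append(str(carry))
    return ''.join(reversed(digits))


print(add_digit_strings('31', '87'))  # 118
```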
Setting up your Environment
If you’re doing this from scratch, you will want to work in a new virtualenv.
I also recommend using Python 3, because it’s really time we got over Python
2… If you don’t know how virtualenvs work, I really recommend looking into
them. You can find a short tutorial
here. To get a
virtualenv with Python 3, try
virtualenv -p python3 venv.
Once your virtualenv is activated, install these requirements:
pip install numpy tensorflow keras
If you run into trouble with TensorFlow, follow their guide.
Great! You can paste the code below into a Python shell or store it in a file; it’s up to you.
Generating Math Formulas as Strings
For now, let’s build some equations! We’re going to go real simple here and just work with easy-peasy addition of two natural numbers. Once we get that to work, we can build more funky stuff later on. See the full code to find out how to build more complex equations.
To make sure we build each possible equation only once, we’re going to use a
generator. With the handy
itertools Python standard package, we can generate
all possible permutations of two numbers and then create a formula for each
pair. Here’s our basic code. Note that I’m referencing some global variables
(IN ALL CAPS); we will define them later.
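A minimal sketch of such a generator could look like the following. The function name `generate_equations`, the shuffling, and the `max_count` and `seed` parameters are my assumptions for illustration; only the use of `itertools` permutations and the two global bounds come from the text.

```python
import itertools
import random

MIN_NUMBER = 0
MAX_NUMBER = 999


def generate_equations(shuffle=True, max_count=None, seed=None):
    """Yield unique equation strings such as '31 + 87'."""
    numbers = range(MIN_NUMBER, MAX_NUMBER + 1)
    # permutations() yields every ordered pair of two distinct numbers once.
    pairs = list(itertools.permutations(numbers, 2))
    if shuffle:
        random.Random(seed).shuffle(pairs)
    # A max_count of None keeps all pairs.
    for x, y in pairs[:max_count]:
        yield '{} + {}'.format(x, y)
```

Because every pair appears exactly once in the list, no equation can show up twice, no matter how many we draw.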
Assuming global config values
MIN_NUMBER = 0 and
MAX_NUMBER = 999,
this code will generate equations like these:
'89 + 7'
'316 + 189'
'240 + 935'
For any of these equation strings, we can easily get the result in Python by evaluating the string.
That was easy, right?
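One way to compute the target string is Python's built-in `eval`. Treat this as a sketch: the helper name `solve` is hypothetical, and `eval` is only acceptable here because we generated the strings ourselves. Never feed it input you do not control.

```python
def solve(equation):
    """Return the result of a trusted equation string, as a string."""
    # eval() executes arbitrary code, so only pass in self-generated strings.
    return str(eval(equation))


print(solve('31 + 87'))  # 118
```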
Encoding Strings for Neural Networks
Remember we want to look at the strings as sequences. The RNN will not see the input string as a whole, but one character after the other. It will then have to learn to remember the important parts in its internal state. That means we need to convert each input and output string into vectors first. Let’s also add an “end” character to each sequence so the neural network can learn when a sequence is finished. I’m using the (invisible) ‘\0’ character here.
We do this by encoding each character as a one-hot vector. Each string is then simply a matrix of these character-vectors. The way we encode is quite arbitrary, as long as we decode it the same way later on. Here’s some code I used for this:
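As a rough sketch of the idea, the helpers below encode and decode strings with plain Python lists. The alphabet `CHARS`, the function names, and the choice to pad short strings with the end character are assumptions on my part. In practice you would wrap the result in a NumPy array before handing it to Keras.

```python
CHARS = '0123456789+ ' + '\0'   # assumed alphabet; '\0' is the end marker
CHAR_TO_INDEX = {char: index for index, char in enumerate(CHARS)}
N_FEATURES = len(CHARS)


def one_hot(string, max_length):
    """Encode a string plus end marker as a list of one-hot rows."""
    padded = (string + '\0').ljust(max_length, '\0')  # pad with end markers
    if len(padded) > max_length:
        raise ValueError('string does not fit into max_length')
    matrix = [[0] * N_FEATURES for _ in range(max_length)]
    for position, char in enumerate(padded):
        matrix[position][CHAR_TO_INDEX[char]] = 1
    return matrix


def decode(matrix):
    """Invert one_hot(): take the argmax of each row, cut at the end marker."""
    chars = (CHARS[row.index(max(row))] for row in matrix)
    return ''.join(chars).split('\0')[0]


print(decode(one_hot('31 + 87', 10)))  # 31 + 87
```

Since `decode` only looks at the largest entry per row, it works just as well on the probability vectors the model will output later.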
With these helper functions, it is quite easy to generate the dataset:
And later to print out some examples along with their targets:
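Both steps might look roughly like the sketch below. It is self-contained, so it repeats a compact encoder and decoder. It also draws random number pairs instead of walking through all permutations, which is a simplification of mine. All names and the two maximum lengths are assumptions.

```python
import random

CHARS = '0123456789+ ' + '\0'   # assumed alphabet; '\0' marks the end
MAX_EQUATION_LENGTH = 10        # '999 + 999' plus end marker
MAX_RESULT_LENGTH = 5           # '1998' plus end marker


def encode(string, max_length):
    """One-hot encode a string, padding with the end marker."""
    padded = (string + '\0').ljust(max_length, '\0')
    return [[int(char == c) for c in CHARS] for char in padded]


def decode(matrix):
    """Pick the most likely character per row and cut at the end marker."""
    text = ''.join(CHARS[row.index(max(row))] for row in matrix)
    return text.split('\0')[0]


def build_dataset(n_samples, seed=None):
    """Return (x, y): encoded equations and their encoded results."""
    rng = random.Random(seed)
    x, y, seen = [], [], set()
    while len(x) < n_samples:
        a, b = rng.sample(range(0, 1000), 2)  # two distinct numbers
        if (a, b) in seen:
            continue                          # keep every equation unique
        seen.add((a, b))
        x.append(encode('{} + {}'.format(a, b), MAX_EQUATION_LENGTH))
        y.append(encode(str(a + b), MAX_RESULT_LENGTH))
    return x, y


def print_examples(x, y, count=3):
    """Decode and print the first few equations with their targets."""
    for equation, result in list(zip(x, y))[:count]:
        print('{} = {}'.format(decode(equation), decode(result)))


x, y = build_dataset(1000, seed=0)
print_examples(x, y)
```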
It is important to note here that the training data uses a max length for the input and output sequences only for numerical reasons. The model is not limited by that length and could look at or output longer sequences after this training.
Building the Model in Keras
Now, let’s build the model. Thanks to Keras, this is quite straightforward.
First, we need to decide what the shape of our inputs is supposed to be. Since
it’s a matrix with a one-hot-vector for each position in the equation string,
this is simply
(MAX_EQUATION_LENGTH, N_FEATURES). We’ll pass that input to a
first layer consisting of 20 (bidirectional) LSTM cells. Each will look at the input,
character by character, and output a single value.
There are several ways to build a Seq2Seq model. The simplest is to
collect all information about the input sequence in a fixed vector, then have
the decoder generate a sequence from that. To make this possible, we use
RepeatVector to feed the decoder network with the representation in each time
step. For a more detailed discussion about Seq2Seq models in Keras, see
So after we repeat the encoded vector
n times with
n being the (maximum)
length of our output sequences, we run this repeat-sequence through the
decoder: A (bidirectional) LSTM layer that will output sequences of vectors.
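Conceptually, the repeat step just copies the encoder's fixed-size vector once per output position. The plain-Python sketch below only illustrates the resulting shape. It is not how Keras implements `RepeatVector`, and the numbers are made up.

```python
def repeat_vector(vector, n):
    """Return n copies of vector, one per decoder time step."""
    return [list(vector) for _ in range(n)]


encoded = [0.1, -0.7, 0.3, 0.9]            # made-up encoder output
decoder_input = repeat_vector(encoded, 5)  # 5 output positions
print(len(decoder_input), len(decoder_input[0]))  # 5 4
```

Every decoder step sees the same summary of the equation. Only the decoder's own internal state tells it which output position it is currently working on.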
Finally, we want to combine each LSTM cell’s output at each point in the
sequence to a single output vector. This is done using Keras’
wrapper around a simple
Since we expect something like a one-hot vector for each output character,
we still need to apply
softmax as usual in classification problems. This
essentially yields us a probability distribution over the character classes.
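For reference, softmax itself is only a few lines. This is a plain-Python sketch with made-up scores, independent of the Keras implementation.

```python
import math


def softmax(scores):
    """Turn arbitrary scores into probabilities that sum to 1."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # shift for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])
print(round(sum(probs), 6))  # 1.0
```

The character with the highest score gets the highest probability, and picking the argmax at each output position gives us the predicted string.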
Note that from what I gathered, this way of building Seq2Seq models in Keras is not optimal and not exactly equivalent to what is proposed in the paper. It works nonetheless. For a more correct implementation, try out Fariz Rahman’s seq2seq package.
Either way, here is our code:
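What follows is my reading of the description above as a sketch, assuming a reasonably recent Keras. The constants, the `adam` optimizer, and the use of `TimeDistributed` around a `Dense` layer for the per-position output are assumptions; the 20 LSTM cells, the bidirectional layers, `RepeatVector`, and the softmax come from the text.

```python
MAX_EQUATION_LENGTH = 10   # assumed: '999 + 999' plus end marker
MAX_RESULT_LENGTH = 5      # assumed: '1998' plus end marker
N_FEATURES = 13            # assumed alphabet: digits, '+', ' ' and '\0'
LSTM_SIZE = 20             # number of LSTM cells, as stated in the text


def build_model():
    """Encoder-decoder built from bidirectional LSTMs and RepeatVector."""
    # Imported lazily so the constants can be inspected without Keras.
    from keras.models import Sequential
    from keras.layers import (Input, LSTM, Bidirectional, RepeatVector,
                              TimeDistributed, Dense)

    model = Sequential()
    model.add(Input(shape=(MAX_EQUATION_LENGTH, N_FEATURES)))
    # Encoder: reads the equation and returns one fixed-size vector.
    model.add(Bidirectional(LSTM(LSTM_SIZE)))
    # Feed that vector to the decoder at every output time step.
    model.add(RepeatVector(MAX_RESULT_LENGTH))
    # Decoder: returns one vector per output position.
    model.add(Bidirectional(LSTM(LSTM_SIZE, return_sequences=True)))
    # Per-position classifier over the character classes.
    model.add(TimeDistributed(Dense(N_FEATURES, activation='softmax')))
    model.compile(loss='categorical_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model
```

Calling `build_model().summary()` should show an output shape of `(None, 5, 13)`: one probability distribution over 13 characters for each of the 5 output positions.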
Training is easy. We just call the fit function of the model, passing handy
callbacks like the
ModelCheckpoint, which stores the best model after each
epoch. Here’s the main function of our code:
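A stripped-down stand-in for that main function might look like this. The function name `train`, the file name, the batch size, and the `validation_split` are assumptions; `x` and `y` are expected to be NumPy arrays of one-hot encoded equations and results.

```python
EPOCHS = 120                  # roughly where the text reports convergence
BATCH_SIZE = 128              # assumed value
MODEL_PATH = 'model.keras'    # hypothetical file name


def train(model, x, y, callbacks=None):
    """Fit the model, keeping the best checkpoint seen so far."""
    if callbacks is None:
        # Imported lazily so the function can be tested without Keras.
        from keras.callbacks import ModelCheckpoint
        callbacks = [ModelCheckpoint(MODEL_PATH, monitor='val_loss',
                                     save_best_only=True)]
    # Hold out 10% of the data for validation.
    return model.fit(x, y, epochs=EPOCHS, batch_size=BATCH_SIZE,
                     validation_split=0.1, callbacks=callbacks)
```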
Finally, we need to fill in some of the global variables we used throughout the code. With equations using two numbers from 0 to 999, there are 998k possible data points. Let’s use 30k of those for our dataset, and validate on 10% of that, meaning we’ll train on about 2.7% of possible equations and expect it to generalize to the remaining 97.3%. That may sound impressive, but Deep Learning is actually used to facing much more dire odds.
Here’s my complete config:
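The values mentioned in the text can be collected into a config sketch like the one below. Everything marked as assumed is my guess and not taken from the original configuration.

```python
MIN_NUMBER = 0        # from the text
MAX_NUMBER = 999      # from the text
N_EXAMPLES = 30000    # "30k of those for our dataset"
TEST_SPLIT = 0.1      # 10% held out for testing
LSTM_SIZE = 20        # cells per LSTM layer, from the text
EPOCHS = 120          # roughly where the text reports convergence
BATCH_SIZE = 128      # assumed, not given in the text

# Derived sizes of the two splits.
N_TEST = round(N_EXAMPLES * TEST_SPLIT)
N_TRAIN = N_EXAMPLES - N_TEST
print(N_TRAIN, N_TEST)  # 27000 3000
```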
Finally, we can start training. Either call
main() from the shell, or store
everything in a file
training.py and run it via
Running the code as it is described here, with some patience I get to an accuracy
of 1.0 on the test set after about 120 epochs. Yes, that’s right:
it seems we’re making zero mistakes on our test set of 3000 equations.
Here’s a graph showing how accuracy developed during training.
As you can see, overfitting is not a problem for us. The capacity of the
network is just way too small to be learning examples by heart.
That’s great, but running it on some more, unseen example equations, the model still makes the occasional mistake! Why is that? Did our model not generalize to all areas of the problem space?
Analyzing the Mistakes
Let’s look at some of the examples where the model failed to give the right answer. I tried 20k new equations, and out of those, only 10 were incorrect, for example:
47 + 58 = 115 (should have been 105)
94 + 909 = 1903 (should have been 1003)
2 + 7 = 19 (should have been 9)
989 + 811 = 1890 (should have been 1800)
22 + 78 = 00 (should have been 100)
Two things stand out:
- Most, but not all, errors are in the space of smaller numbers, i.e. where at least one number is less than 100.
- All errors seem to be in one single digit only, while the rest of the digits are correct.
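The single-digit claim is easy to check on the five examples listed above. The sketch below right-aligns the predicted and the expected string and counts the positions in which they differ. The right alignment is my choice for comparing strings of different length, and the function name is made up.

```python
# (predicted, expected) result strings taken from the examples above
ERRORS = [('115', '105'), ('1903', '1003'), ('19', '9'),
          ('1890', '1800'), ('00', '100')]


def differing_positions(predicted, expected):
    """Count positions that differ after right-aligning both strings."""
    width = max(len(predicted), len(expected))
    pairs = zip(predicted.rjust(width), expected.rjust(width))
    return sum(p != e for p, e in pairs)


for predicted, expected in ERRORS:
    print(predicted, expected, differing_positions(predicted, expected))
```

Each of the five pairs differs in exactly one position, counting a missing or superfluous leading digit as one difference.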
The first observation is easy to understand: In the lower number range, the rules of addition change a little bit (the first digit no longer counts the 100s). At the same time, this problematic space has seen fewer training examples than the easier parts. In the whole equation space, there are ~800k points where both numbers have three digits, and only about 80 where both numbers are below 10.
The second one is a bit harder to follow. It seems to be an utterly un-human thing to do. A human would make mistakes in the space of numbers, not in the space of strings. It’s perfectly alright to get the last digit wrong, but to report “1903” instead of “1003” is unacceptable! What happened?
In the case of “1903” vs “1003”, it looks like part of the decoder thought the result would be above 1000, so correctly spit out a “1”, while the next output thought “Nope, definitely below 1000” and output a “9” as in nine-hundred-something. Does each position in the output sequence follow its own stubborn logic, not really caring if the number as a whole makes any sense? This could probably be improved by using a better vector representation or coming up with a better loss function. Or it might be a general weakness in this kind of Seq2Seq model.
Anyway, since our input space has only two factors of variation (the two numbers that go into building the equation), we can plot the equation space in a 2D plane. I went ahead and created a scatter plot with green dots marking correct predictions and red dots marking incorrect ones. Find the code for that here.
The red dots are hard to see, but they are there!
There’s lots to do! You can get creative about getting over this lower-number
problem. We can also increase the complexity of the equations. I was able to get
quite far on ones as complex as
131 + 83 - 744 * 33 (= -24338), but haven’t
really gotten it to work with division.
Feel free to pass on hints, ideas for improvement, or your own results in the comments or as issues on my repository.