Introduction
I recently spent some time working on Assignment 1 for CS231n. As part of it, I needed to manually calculate the derivative of the softmax cross-entropy loss in order to update the weights of a two-layer neural network by hand. This was a bit complicated, and I struggled for a good 2-3 days before I understood how all the parts fit together.
Let's start by working through some key definitions. In order to understand everything, you'll need an understanding of the following few things:
- The Chain Rule
- What the softmax function is and how we can calculate its derivative
- What the cross-entropy loss is
- How we can apply the chain rule to more complex derivatives
Once we've gotten a hang of all these concepts, we'll get to the juicy part - deriving the derivative of the softmax cross-entropy loss.
I assume you've got a rough understanding of derivatives. I try to give a basic intuition for each derivation below, but it'll make much more sense if you spend a few minutes looking through your old calculus notebook. :)
Chain Rule
The chain rule can be understood with the following equation.

$$\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx}$$

Let's see a more concrete example. Let's say that we have the following two equations.

$$y = ax \qquad z = by$$

and we assume that we have the following values, where $a = 2$ and $b = 3$. This means that when $x = 1$, we can in turn derive the following values.

$$y = 2 \qquad z = 6$$

The chain rule would come in useful if we wanted to know how a change in the value of $x$ might affect something like $z$. We could of course inline $a$ and $b$ directly into the equation above in this simple example, but you'll find that this might not always be possible for more complex examples.

Let's first try to compute how $y$ changes when we change the value of $x$. Well, in our previous example, we had $x = 1$. What happens if we use $x = 2$ and $x = 3$?

Well, we can calculate the new values of $y$, assuming that $a$ remains constant. In our case, we notice that when $x = 2$, we get $y = 4$, whereas when $x = 3$, we get $y = 6$. This means that effectively, every time we modify the value of $x$ by 1, we observe a corresponding change in the value of $y$ of 2. We can write this more formally as

$$\frac{dy}{dx} = a = 2$$

Let's apply a similar logic to find how the value of $z$ changes when we modify $y$ by 1. If we perform the same steps, we'll find that

$$\frac{dz}{dy} = b = 3$$

Well, how can we then apply this to find the value of $\frac{dz}{dx}$? Well, the chain rule tells us that

$$\frac{dz}{dx} = \frac{dz}{dy} \cdot \frac{dy}{dx} = 3 \times 2 = 6$$
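To make this concrete, here's a tiny numeric sketch using a toy pair of equations (the specific values a = 2 and b = 3 are assumed for illustration), checking the chain rule against a finite-difference estimate.

```python
# Toy equations: y = a * x and z = b * y, with a = 2 and b = 3.
a, b = 2, 3

def y(x):
    return a * x

def z(x):
    return b * y(x)

# analytic chain rule: dz/dx = dz/dy * dy/dx = b * a
analytic = b * a

# finite-difference approximation of dz/dx at x = 1
h = 1e-6
numeric = (z(1 + h) - z(1)) / h

print(analytic, round(numeric, 3))  # 6 6.0
```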
So, with that we now know how to calculate derivatives for variables that are linked via a series of equations. Let's now move on to what the softmax is.
Softmax
What is the Softmax Function
Let's imagine that we have three possible events every time we go visit our grandparents - we get a candy, a toy, or a cake. If we visit them 10 times, and out of those 10 times we get
- 3 candies
- 3 toys
- 4 cakes
We could in turn represent the total distribution of events as
counts = [3,3,4]
But we could convert this to a probability distribution by dividing each count by the total number of events, giving us

$$p_i = \frac{c_i}{\sum_j c_j}$$
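In code, that's just dividing each count by the total:

```python
counts = [3, 3, 4]

# divide each count by the total number of events
probs = [c / sum(counts) for c in counts]
print(probs)  # [0.3, 0.3, 0.4]
```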
The softmax function works kind of similarly. It allows us to take arbitrary numbers and convert them to a purely positive range. This is because of how the exponential function behaves. You can see this from the graph below.

By mapping our inputs to the positive range, we've now obtained "counts" of our target variables. We can then normalise these counts, similar to the example above, into a vector of floats that we can treat as a probability distribution. This is done by applying the formula

$$p_i = \frac{e^{x_i}}{\sum_{j} e^{x_j}}$$
In fact, this is how we've chosen to interpret the results of most models. GPT models have a softmax layer at the end of their output layers, which we then sample from to generate the next token. Using probability distributions also gives us more interesting options. (Note: in Anthropic's Constitutional AI paper, they constrain the probability of generated outputs to ~40-60%. So we can see how we can take a raw probability distribution produced by the softmax and then modify it to generate a smaller distribution.)
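As a quick sketch, here's the naive formula applied to some made-up logits (before we worry about numerical stability):

```python
import numpy as np

x = np.array([2.0, 1.0, 0.1])  # arbitrary example inputs

# exponentiate to get positive "counts", then normalise them
counts = np.exp(x)
probs = counts / counts.sum()

print(np.round(probs, 3))      # [0.659 0.242 0.099]
print(round(probs.sum(), 6))   # 1.0
```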
Numerical Instability
However, we represent our values in code as floating point numbers. These are finite representations that cannot represent arbitrarily large values. Therefore, we can get some funky behaviour once our numbers get large. And it doesn't even take astronomically large inputs for this to happen.
x = [-10,-1,1000]
np.exp(x) / np.sum(np.exp(x))
> <ipython-input-4-c0960e0e33eb>:2: RuntimeWarning: overflow encountered in exp
np.exp(x) / np.sum(np.exp(x))
> <ipython-input-4-c0960e0e33eb>:2: RuntimeWarning: invalid value encountered in divide
np.exp(x) / np.sum(np.exp(x))
> array([ 0., 0., nan])
This poses a problem. If our numbers are too large, how can we compute the same probability distribution without floating point overflow? Well, it turns out that the softmax is invariant to subtracting a constant from every input, so we can do the following and still achieve the same distribution.

$$p_i = \frac{e^{x_i}}{\sum_j e^{x_j}} = \frac{e^{x_i - \max(x)}}{\sum_j e^{x_j - \max(x)}}$$
For a more thorough explanation, here's a walkthrough by Jay K Mody; a detailed treatment of softmax numerical instability is out of scope for this article.
x = [-10,-1,1000]
x = x - np.max(x)
np.exp(x) / np.sum(np.exp(x))
> array([0., 0., 1.])
This brings us to a numerically stable implementation of the softmax as seen below.
def softmax(x):
    # assumes x is a vector
    x = x - np.max(x)
    return np.exp(x) / np.sum(np.exp(x))
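For example, the stable version handles the input that overflowed earlier (the definition is repeated here so the snippet runs on its own):

```python
import numpy as np

def softmax(x):
    # subtracting the max leaves the output unchanged but avoids overflow
    x = x - np.max(x)
    return np.exp(x) / np.sum(np.exp(x))

probs = softmax(np.array([-10.0, -1.0, 1000.0]))
print(probs)  # [0. 0. 1.] with no overflow warnings this time
```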
Calculating the Derivative
Now, let's move on to figuring out what the derivative is. Specifically, given an input $x_j$, we want to calculate its impact on a final softmax output $p_i$, where we have an input of size $n$. Let's start by writing out our equation.

$$p_i = \frac{e^{x_i}}{\sum_{k=1}^{n} e^{x_k}}$$
And we want to find

$$\frac{\partial p_i}{\partial x_j}$$
Let's first do a quick substitution where $C = \sum_{k \neq j} e^{x_k}$. (Note here that we've simply pulled our $e^{x_j}$ term out of the original summation.) This means we can rewrite our original equation as

$$p_i = \frac{e^{x_i}}{e^{x_j} + C}$$
This allows us to apply the quotient rule, which states that

$$\left(\frac{f}{g}\right)' = \frac{f'g - fg'}{g^2}$$
which we can then rewrite as

$$\frac{\partial p_i}{\partial x_j} = \frac{\frac{\partial e^{x_i}}{\partial x_j}\left(e^{x_j} + C\right) - e^{x_i}\,\frac{\partial \left(e^{x_j} + C\right)}{\partial x_j}}{\left(e^{x_j} + C\right)^2}$$
Before we walk through the answer, note the following
- we're taking the derivative with respect to $x_j$
- if $i = j$, then this means that $\frac{\partial e^{x_i}}{\partial x_j} = e^{x_j}$
- if $i \neq j$ and we take the derivative with respect to $x_j$, then this means that we can treat $e^{x_i}$ as a constant, and $\frac{\partial e^{x_i}}{\partial x_j} = 0$
Let's now apply this new form to solve the two cases.
Case 1 : $i = j$

With the points above, we can rewrite our derivative and simplify it as

$$\frac{\partial p_j}{\partial x_j} = \frac{e^{x_j}\left(e^{x_j} + C\right) - e^{x_j} \cdot e^{x_j}}{\left(e^{x_j} + C\right)^2}$$

If you're confused about why the huge $\frac{\partial C}{\partial x_j}$ term was simplified away, consider the fact that there are no terms expressed in terms of $x_j$ in that summation, which implies that $\frac{\partial C}{\partial x_j} = 0$. Now, our derivative looks a little scary, but if we rewrite it a little bit, we get the following

$$\frac{\partial p_j}{\partial x_j} = \frac{e^{x_j}}{e^{x_j} + C} \cdot \frac{\left(e^{x_j} + C\right) - e^{x_j}}{e^{x_j} + C} = p_j\left(1 - p_j\right)$$
Case 2 : $i \neq j$

In this case, we can rewrite our equation as

$$\frac{\partial p_i}{\partial x_j} = \frac{0 \cdot \left(e^{x_j} + C\right) - e^{x_i} \cdot e^{x_j}}{\left(e^{x_j} + C\right)^2}$$

which simplifies to

$$\frac{\partial p_i}{\partial x_j} = -\frac{e^{x_i}}{e^{x_j} + C} \cdot \frac{e^{x_j}}{e^{x_j} + C} = -p_i p_j$$
This therefore means that we can express our derivative in these two cases

$$\frac{\partial p_i}{\partial x_j} = \begin{cases} p_j\left(1 - p_j\right) & \text{if } i = j \\ -p_i p_j & \text{if } i \neq j \end{cases}$$
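We can sanity-check these two cases numerically with a finite-difference approximation. This is just a verification sketch (the function names are mine, not from the assignment):

```python
import numpy as np

def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

def softmax_jacobian(x):
    # J[i, j] = p_j * (1 - p_j) when i == j, and -p_i * p_j otherwise,
    # which collapses to diag(p) - outer(p, p)
    p = softmax(x)
    return np.diag(p) - np.outer(p, p)

x = np.array([0.5, -0.2, 1.3])  # arbitrary example inputs
J = softmax_jacobian(x)

# numerically estimate each column dp/dx_j
h = 1e-6
J_num = np.zeros((3, 3))
for j in range(3):
    xp = x.copy()
    xp[j] += h
    J_num[:, j] = (softmax(xp) - softmax(x)) / h

print(np.allclose(J, J_num, atol=1e-4))  # True
```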
Cross-Entropy Loss
The cross-entropy loss is commonly used when we want to train a model to classify an input into a specific class or multiple classes. In this article, we'll be looking at the case where we're trying to classify an input into a single specific class.
Loss Function
The use of a positive loss function holds for the majority of loss functions that you'll see. There are some cases where loss functions seek to maximise a negative value and drive it closer to 0, as seen here, but they are quite rare in my personal experience.
When we have a loss function, we can take a prediction and compute how good it is w.r.t some ground truth. This means that the better the predictions, the lower the loss. Our goal in doing so is to create a single objective number that we can look to slowly minimise.
Conceptually, this is because minimising a loss function means driving it as close to 0 as possible. If we tried to create a loss that we could maximise instead, we would never be done maximising it, since its range would extend to $+\infty$.
So this brings us to two simple characteristics of a loss function
- It has to have a range between $0$ and $\infty$. If not, we would be minimising towards $-\infty$.
- The better our predictions, the lower our loss should be
Softmax and the Loss
We can see that our softmax outputs a vector which contains probabilities. These will range from 0 to 1. In our case, since we are looking to make a correct prediction, ideally our probability for the correct class should be 1. When it's a bad prediction, it will be close to 0.
If we took the raw probability as the loss, then this would mean that the better our predictions, the higher our loss. That's where the cross-entropy loss comes in. It has the formula of

$$L = -\sum_{i=1}^{n} y_i \log(p_i)$$
where $y_i$ represents the ground-truth label for the $i$-th class and $p_i$ represents the predicted probability that our softmax produces for class $i$. Since an item either belongs or does not belong to a class, this means that $y_i$ can only take on the values of $0$ or $1$. Since we're predicting a single class, this also means that we essentially have the following sum

$$\sum_{i=1}^{n} y_i = 1$$

with exactly one non-zero term.
Which brings our loss to

$$L = -\log(p_y)$$

where $y$ represents the index of the correct class. This derivative is simple to calculate and is just

$$\frac{\partial L}{\partial p_y} = -\frac{1}{p_y}$$
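A quick numeric check of this derivative (the probability value here is made up for illustration):

```python
import numpy as np

p_y = 0.7  # probability assigned to the correct class (arbitrary example)
loss = -np.log(p_y)

# analytic derivative: dL/dp_y = -1 / p_y
analytic = -1.0 / p_y

# finite-difference approximation
h = 1e-8
numeric = (-np.log(p_y + h) - loss) / h

print(np.isclose(analytic, numeric, atol=1e-4))  # True
```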
Why use the Negative Log Probability
Now that we have a rough idea of what the loss function looks like, let's take a step back to think about why we might want to use the negative log probability. A good place to start is to look at how log behaves for the range 0 to 1.

As we can see from the graph, as our value of $x$ goes from 0 to 1, the value of $\log(x)$ goes from $-\infty$ to 0. This fits one of the key characteristics of our loss function: we want a value closer to 0 as our predictions improve.
However, we're still faced with a small problem: we need our loss function to output positive results. To fix this, we just multiply the loss by $-1$ to get $-\log(x)$. This gives the graph below.

With this, we have the two properties of a loss function that we originally wanted. Note that in practice, we normally clip the predicted values before passing them into the cross-entropy loss, since $\log(0) = -\infty$. Therefore we use numpy's clip functionality to clip very small probabilities to some small value, as seen below.
np.clip([0.2,0.4,0.000001,0.2], 1e-7, 1-1e-7)
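To see why this matters, here's a small sketch comparing a zero probability with its clipped value (the exact zero is made up for illustration):

```python
import numpy as np

p = np.array([0.2, 0.4, 0.0, 0.2])  # one prediction is exactly 0
p_clipped = np.clip(p, 1e-7, 1 - 1e-7)

# without clipping, -log(0) would blow up to infinity;
# after clipping, the loss stays finite
print(round(-np.log(p_clipped[2]), 3))  # 16.118
```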
Bringing It All Together
We've now found two equations. The first is the derivative of the cross-entropy loss with respect to a specific softmax output.

$$\frac{\partial L}{\partial p_y} = -\frac{1}{p_y}$$
The second is the derivative of the softmax output w.r.t. its original inputs.

$$\frac{\partial p_i}{\partial x_j} = \begin{cases} p_j\left(1 - p_j\right) & \text{if } i = j \\ -p_i p_j & \text{if } i \neq j \end{cases}$$
We can simply apply the chain rule to get the derivative of the loss w.r.t. the original inputs to the softmax.

$$\frac{\partial L}{\partial x_j} = \frac{\partial L}{\partial p_y} \cdot \frac{\partial p_y}{\partial x_j}$$
Case 1 : $j = y$

$$\frac{\partial L}{\partial x_j} = -\frac{1}{p_y} \cdot p_y\left(1 - p_y\right) = p_y - 1$$
Case 2 : $j \neq y$

$$\frac{\partial L}{\partial x_j} = -\frac{1}{p_y} \cdot \left(-p_y p_j\right) = p_j$$
We can write this in a case form as

$$\frac{\partial L}{\partial x_j} = \begin{cases} p_j - 1 & \text{if } j = y \\ p_j & \text{if } j \neq y \end{cases}$$
This makes the derivative straightforward to implement in Python as an if-else during backpropagation.
for j in range(len(softmax_output)):
    if j == y[i]:
        # correct class: gradient is (p_j - 1), scaled by the input X[i]
        # via the chain rule through the linear layer's scores
        dW[:, j] += (softmax_output[j] - 1) * X[i]
    else:
        dW[:, j] += softmax_output[j] * X[i]
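Finally, here's a self-contained sketch that checks the combined result (the softmax output minus one at the correct class) against a finite-difference estimate of the gradient w.r.t. the scores. The function names and values are mine, chosen for illustration:

```python
import numpy as np

def softmax(x):
    x = x - np.max(x)
    e = np.exp(x)
    return e / e.sum()

def cross_entropy_loss(x, y):
    # loss for a single example with score vector x and correct class y
    return -np.log(softmax(x)[y])

x = np.array([1.0, 2.0, -0.5])  # arbitrary scores
y = 1                           # index of the correct class

# analytic gradient: softmax output with 1 subtracted at the correct class
grad = softmax(x)
grad[y] -= 1

# finite-difference check
h = 1e-6
grad_num = np.zeros_like(x)
for j in range(len(x)):
    xp = x.copy()
    xp[j] += h
    grad_num[j] = (cross_entropy_loss(xp, y) - cross_entropy_loss(x, y)) / h

print(np.allclose(grad, grad_num, atol=1e-4))  # True
```

Note that the gradient always sums to zero, since the probabilities sum to one: pushing the correct class's score up necessarily pushes the others down.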