Deriving the Derivative of the Softmax Cross-Entropy Loss

A simple walkthrough (with code) of deriving the derivative of the cross-entropy loss

Published 23 October, 2023

Introduction

I recently spent some time working on Assignment 1 for CS231n and, as part of it, we needed to manually calculate the derivative of the softmax cross-entropy loss in order to update the weights of a two-layer neural network. This was a bit complicated and I struggled for a good 2-3 days before I managed to understand how all the parts fit together.

Let's start by working through some key definitions. In order to understand everything, you'll need to have an understanding of the following few things

  1. The Chain Rule
  2. What is a Softmax and how can we calculate its derivative
  3. What is the cross-entropy loss
  4. How can we apply the chain rule to more complex derivatives

Once we've gotten the hang of all these concepts, we'll get to the juicy part - deriving the derivative of the softmax cross-entropy loss.

I assume you've got a rough understanding of derivatives. I try to give a basic understanding of how to derive the derivatives below, but it'll make much more sense if you spend a few minutes looking through your old calculus notebook. :)

Chain Rule

The chain rule can be understood with the following equation.

\frac{dy}{dx} = \frac{dy}{du} \times \frac{du}{dx}

Let's look at a more concrete example. Say we have the following two equations.

z = 2x + y
x = a - 3b

and we assume that we have the following values: a = 10, b = 1, y = 3. This means that we can in turn derive the following values.

x = 10 - 3 \times 1 = 7
z = 2 \times 7 + 3 = 17

The chain rule comes in useful if we want to know how a change in the value of b might affect something like z. We could of course substitute a and b directly into the equation for z in this simple example, but you'll find that this isn't always feasible for more complex cases.

Let's first try to compute how z changes when we change the value of x. Well, in our previous example, we had x = 7. What happens if we use x = 6 and x = 8?

Well, we can calculate the new values of z, assuming that y remains constant. In our case, we notice that when x = 6, we get z = 15, whereas when x = 7, we get z = 17. This means that every time we modify the value of x by 1, we observe a corresponding change of 2 in the value of z. We can write this more formally as

\frac{dz}{dx} = 2

Let's apply similar logic to find how the value of x changes when we modify b by 1. If we perform the same steps, we'll find that

\frac{dx}{db} = -3

How can we then apply this to find the value of \frac{dz}{db}? Well, the chain rule tells us that

\frac{dz}{db} = \frac{dz}{dx} \times \frac{dx}{db} = 2 \times (-3) = -6
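We can sanity-check this numerically with a quick sketch (the helper function and step size below are just for illustration):

def z_of(a, b, y):
    # z = 2x + y where x = a - 3b
    x = a - 3 * b
    return 2 * x + y

# Nudge b slightly and see how z responds
a, b, y = 10, 1, 3
eps = 1e-6
dz_db = (z_of(a, b + eps, y) - z_of(a, b, y)) / eps
print(round(dz_db, 3))  # -6.0, matching dz/dx * dx/db = 2 * (-3)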

So, with that we now know how to calculate derivatives for variables that are linked via a series of equations. Let's now move on to what the softmax is.

Softmax

What is the Softmax Function

Let's imagine that we had three possible events every time we visit our grandparents - get a candy, get a toy or get a cake. If we visit them 10 times and, out of those 10 times, we get

  • 3 candies
  • 3 toys
  • 4 cakes

We could in turn represent the total distribution of events as

counts = [3,3,4]

But we could convert this to a probability distribution by taking

P(\text{Get A Toy}) = \frac{3}{3 + 3 + 4} = \frac{3}{10}
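As a quick illustration, here's that counts-to-probabilities step in numpy (the array below just mirrors the example above):

import numpy as np

counts = np.array([3, 3, 4])   # candies, toys, cakes
probs = counts / counts.sum()  # normalise into a probability distribution
print(probs)                   # [0.3 0.3 0.4]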

The softmax function works in a similar way. It allows us to take arbitrary numbers and map them to a purely positive range. This works because of how the exponential function e^x behaves - you can see this from the graph below.

[Figure: Softmax Function]

By mapping our inputs to the positive range, we've now obtained "counts" for our target variables. We can then normalise these counts, similar to the example above, into a vector of floats that we can treat as a probability distribution. This is done by applying the formula

S(x_i) = \frac{e^{x_i}}{\sum_{j=0}^{n} e^{x_j}}

In fact, this is how we've chosen to interpret the output of most models: GPT models have a softmax layer at the end of their output layers, which we then sample from to generate the next token. Using probability distributions also gives us more interesting options. (Note: in Anthropic's Constitutional AI paper, they constrain the probability of generated outputs to ~40-60%, so we can see how we can take a raw probability distribution produced by the softmax and then modify it into a narrower distribution.)

Numerical Instability

However, we're representing our values in code as floating point numbers. These are finite representations which cannot represent arbitrarily large values. Therefore, we can get some funky behaviour once our inputs get large - and this doesn't even require particularly large numbers, as the example below shows.

x = [-10,-1,1000]
np.exp(x) / np.sum(np.exp(x))
 
> <ipython-input-4-c0960e0e33eb>:2: RuntimeWarning: overflow encountered in exp
  np.exp(x) / np.sum(np.exp(x))
> <ipython-input-4-c0960e0e33eb>:2: RuntimeWarning: invalid value encountered in divide
  np.exp(x) / np.sum(np.exp(x))
> array([ 0.,  0., nan])

This poses a problem. If our numbers are too large, how can we ensure we still get the same probability distribution without overflow? Well, it turns out that the softmax is invariant to shifting every input by the same constant (we're just multiplying the numerator and denominator by the same factor), so we can subtract the maximum value from every input and still get the same distribution.

For a more detailed walkthrough, here's an explanation by Jay K Mody. A full treatment of softmax numerical instability is out of scope for this article.

x = [-10,-1,1000]
x = x - np.max(x)
np.exp(x) / np.sum(np.exp(x))
 
> This yields [0,0,1]

This brings us to a numerically stable implementation of the softmax as seen below.

def softmax(x):
    # assumes x is a vector
    x = x - np.max(x)
    return np.exp(x) / np.sum(np.exp(x))
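As a quick check, this version handles the inputs that previously overflowed, and subtracting a constant from every input leaves the output unchanged (reusing the softmax function and numpy import from above; the test values are arbitrary):

x = np.array([-10.0, -1.0, 1000.0])
print(softmax(x))  # [0. 0. 1.] with no overflow warnings

small = np.array([0.5, -1.2, 2.0])
print(np.allclose(softmax(small), softmax(small - 100)))  # True - shift invariance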

Calculating the Derivative

Now, let's move on to figuring out what the derivative is. Specifically, given an input x_j, we want to calculate its impact on a final softmax output S_i, where we have an input size of n. Let's start by writing out our equation.

S_i = \frac{e^{x_i}}{\sum_{k=0}^{n} e^{x_k}}

And we want to find

\frac{\delta S_i}{\delta x_j} = \frac{\delta}{\delta x_j} \left[ \frac{e^{x_i}}{\sum_{k=0}^{n} e^{x_k}} \right]

Let's first do a quick substitution where u = e^{x_i} and v = e^{x_j} + \sum_{k=0,k \neq j} e^{x_k}. (Note that we've simply pulled the e^{x_j} term out of our original summation.) This means we can now rewrite our original equation as

S(x_i) = \frac{u}{v}

Which allows us to apply the quotient rule that states that

\frac{\delta S(x_i)}{\delta x_j} = \frac{v\frac{\delta u}{\delta x_j} - u\frac{\delta v}{\delta x_j}}{v^{2}}

which we can then rewrite as

\frac{\delta S(x_i)}{\delta x_j} = \frac{(e^{x_j} + \sum_{k=0,k \neq j} e^{x_k}) \frac{\delta}{\delta x_j} [e^{x_i}] - e^{x_i} \frac{\delta}{\delta x_j} [e^{x_j} + \sum_{k=0,k \neq j} e^{x_k}]}{(e^{x_j} + \sum_{k=0,k \neq j} e^{x_k})^2}

Before we walk through the answer, note the following

  • we're taking the derivative with respect to x_j
  • if y = e^x, then \frac{\delta y}{\delta x} = e^x
  • if y = e^{x_i} (with i \neq j) and we take the derivative with respect to x_j, then we can treat e^{x_i} as a constant, so \frac{\delta y}{\delta x_j} = 0

Let's now apply this new form to solve the two cases.

Case 1: i ≠ j

With the points above, we can rewrite our derivative and simplify it as

\begin{equation} \begin{split} \frac{\delta S(x_i)}{\delta x_j} &= \frac{(e^{x_j} + \sum_{k=0,k \neq j} e^{x_k}) \cdot 0 - e^{x_i} (e^{x_j})}{(\sum_{k=0}^n e^{x_k})^2} \\ &= \frac{-e^{x_i} e^{x_j}}{(\sum_{k=0}^n e^{x_k})^2} \end{split} \end{equation}

If you're confused about why the derivative of the term e^{x_j} + \sum_{k=0,k \neq j} e^{x_k} simplified to e^{x_j}, consider the fact that there are no terms involving x_j in the summation, and that y = e^{x_j} implies \frac{\delta y}{\delta x_j} = e^{x_j}. Our derivative looks a little scary, but if we rewrite it a little bit, we get the following

\begin{equation} \begin{split} \frac{\delta S(x_i)}{\delta x_j} &= -\frac{e^{x_i}}{\sum_{k=0}^n e^{x_k}} \times \frac{e^{x_j}}{\sum_{k=0}^n e^{x_k}} \\ &= -S(x_i) S(x_j) \end{split} \end{equation}

Case 2: i = j

In this case, we can rewrite our equation as

\frac{\delta S(x_i)}{\delta x_i} = \frac{(e^{x_i} + \sum_{k=0,k \neq i}^n e^{x_k}) \frac{\delta}{\delta x_i} [e^{x_i}] - e^{x_i} \frac{\delta}{\delta x_i} [e^{x_i} + \sum_{k=0,k \neq i}^n e^{x_k}]}{(e^{x_i} + \sum_{k=0,k \neq i}^n e^{x_k})^2}

which simplifies to

\begin{equation} \begin{split} \frac{\delta S(x_i)}{\delta x_i} &= \frac{(e^{x_i} + \sum_{k=0,k \neq i} e^{x_k}) e^{x_i} - e^{x_i} (e^{x_i})}{(\sum_{k=0}^n e^{x_k})^2} \\ &= \frac{e^{x_i} (\sum_{k=0,k \neq i} e^{x_k})}{(\sum_{k=0}^n e^{x_k})^2} \\ &= \frac{e^{x_i} ((\sum_{k=0}^n e^{x_k}) - e^{x_i})}{(\sum_{k=0}^n e^{x_k})^2} \\ &= S(x_i)(1 - S(x_i)) \end{split} \end{equation}

This means that we can express our derivative for the two cases as

\frac{\delta S(x_i)}{\delta x_j} = \begin{cases} S(x_i)(1 - S(x_i)) & i = j \\ -S(x_i) S(x_j) & i \neq j \end{cases}
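To convince ourselves the case analysis is correct, here's a small sketch comparing the analytical Jacobian against finite differences (reusing the softmax function from earlier; the test vector, step size and tolerance are arbitrary choices):

def softmax_jacobian(x):
    # dS_i/dx_j = S_i * (1 - S_i) when i == j, and -S_i * S_j otherwise
    s = softmax(x)
    return np.diag(s) - np.outer(s, s)

x = np.array([0.5, -1.2, 2.0])
eps = 1e-6
numerical = np.zeros((len(x), len(x)))
for j in range(len(x)):
    x_plus = x.copy()
    x_plus[j] += eps
    numerical[:, j] = (softmax(x_plus) - softmax(x)) / eps

print(np.allclose(numerical, softmax_jacobian(x), atol=1e-5))  # True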

Cross-Entropy Loss

The cross-entropy loss is commonly used when we want to train a model to classify an input into one class or multiple classes. In this article and example, we'll be looking at the case where we're trying to classify an input into a single class.

Loss Function

Most loss functions that you'll see are positive. There are some cases where a loss function seeks to maximise a negative value and drive it closer to 0, as seen here, but in my experience they are quite rare.

When we have a loss function, we can take a prediction x and compute how good it is w.r.t. some ground truth. The better the prediction, the lower the loss. Our goal in doing so is to create a single objective number that we can gradually minimise.

Conceptually, this is because minimising a loss function means driving it as close to 0 as possible. If we tried to create a loss that we maximise instead, there would be no clear target to reach, since the loss can span anywhere from 0 to ∞.

So this brings us to two simple characteristics of a loss function

  • It has to have a range between 0 and ∞. If not, we would be minimising towards -∞.
  • The better our predictions, the lower our loss should be

Softmax and the Loss

Our softmax outputs a vector of probabilities, each ranging from 0 to 1. Since we are looking to make a correct prediction, ideally the probability assigned to the correct class should be close to 1; for a bad prediction, it will be close to 0.

If we took the raw probability as the loss, then this would mean that the better our predictions, the higher our loss. That's where the Cross-Entropy Loss comes in. It has the formula of

L = -\sum_{i=0}^n t_i \log(S(x_i))

where t_i represents the true label for the i-th class and S(x_i) represents the predicted probability that our softmax produces for class i. Since an item either belongs or does not belong to a class, t_i can only take on the values 0 or 1. Since we're predicting a single class, this essentially means we have the following sum

L = -(0 \times \log(S(x_0)) + 0 \times \log(S(x_1)) + \dots + 1 \times \log(S(x_i)))

Which brings our loss to

L = -\log(S(x_i))

where i represents the index of the correct class. The derivative of this is simple to calculate and is just

\frac{\delta L}{\delta S(x_i)} = -\frac{1}{S(x_i)}
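As a quick numerical sketch of this derivative (the probability 0.3 below is just an arbitrary example):

s_i = 0.3     # softmax probability assigned to the correct class
eps = 1e-8
numerical = (-np.log(s_i + eps) + np.log(s_i)) / eps
print(numerical, -1 / s_i)  # both roughly -3.33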

Why use the Negative Log Probability

Now that we have a rough idea of what the loss function looks like, let's take a step back to think about why we might want to use the negative log probability. A good place to start is to look at how log behaves for the range 0 to 1.

As we can see from the graph, as our value of x goes from 0 to 1, the value of y = log(x) goes from -∞ up to 0. This gets us close to one of the key characteristics of our loss function - that the value produced shrinks in magnitude as our predictions improve.

However, we're still faced with a small problem: we need our loss function to output positive values. To fix this, we just multiply by -1 to get y = -log(x). This gives the graph below.

With this, we have the two properties of a loss function that we originally wanted. Note that in practice, we normally clip the values of the prediction before passing them into the cross entropy, since as x → 0, log(x) → -∞. Therefore we use numpy's clip functionality to clip very small probabilities to some small value, as seen below.

np.clip([0.2,0.4,0.000001,0.2], 1e-7, 1-1e-7)
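Without the clip, a probability of exactly 0 for the correct class would blow the loss up to infinity; with the clip, we just get a large but finite loss. A minimal sketch:

probs = np.array([0.0, 1.0])
print(-np.log(probs[0]))                            # inf, plus a divide-by-zero warning
print(-np.log(np.clip(probs, 1e-7, 1 - 1e-7)[0]))   # ~16.12, large but finite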

Bringing It All Together

We've now found two derivatives. The first is the derivative of the cross entropy loss with respect to a specific softmax output.

\frac{\delta L}{\delta S(x_i)} = -\frac{1}{S(x_i)}

The second is the derivative of the softmax output w.r.t. its original inputs.

\frac{\delta S(x_i)}{\delta x_j} = \begin{cases} S(x_i)(1 - S(x_i)) & i = j \\ -S(x_i) S(x_j) & i \neq j \end{cases}

We can simply apply the chain rule to get the derivative of the loss w.r.t the original inputs to the softmax.

Case 1: i ≠ j

\begin{equation}\begin{split} \frac{\delta L}{\delta x_j} &= \frac{\delta L}{\delta S(x_i)} \times \frac{\delta S(x_i)}{\delta x_j} \\ &= -\frac{1}{S(x_i)} \times -S(x_i) S(x_j) \\ &= S(x_j) \end{split}\end{equation}

Case 2: i = j

\begin{equation}\begin{split} \frac{\delta L}{\delta x_j} &= \frac{\delta L}{\delta S(x_i)} \times \frac{\delta S(x_i)}{\delta x_j} \\ &= -\frac{1}{S(x_i)} \times S(x_i)(1 - S(x_i)) \\ &= S(x_i) - 1 \end{split}\end{equation}

We can write this in a case form as

\frac{\delta L}{\delta x_j} = \begin{cases} S(x_i) - 1 & i = j \\ S(x_j) & i \neq j \end{cases}
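Again, we can sanity-check the combined result with finite differences (reusing the softmax function from earlier; the scores and target class below are arbitrary):

def cross_entropy(x, target):
    # -log of the softmax probability assigned to the correct class
    return -np.log(softmax(x)[target])

x, target = np.array([1.0, -0.5, 2.0]), 2

# Analytical gradient from the case analysis: S(x_j), minus 1 at the correct class
analytical = softmax(x)
analytical[target] -= 1

eps = 1e-6
numerical = np.zeros_like(x)
for j in range(len(x)):
    x_plus = x.copy()
    x_plus[j] += eps
    numerical[j] = (cross_entropy(x_plus, target) - cross_entropy(x, target)) / eps

print(np.allclose(numerical, analytical, atol=1e-5))  # True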

This makes the backpropagation of the derivative straightforward to implement in Python as an if-else, for example in the CS231n-style setting where we accumulate the weight gradient dW one training example at a time:

# Assumes we're inside a loop over training examples i (as in the CS231n
# softmax classifier), where X[i] is the feature vector for example i,
# y[i] is its correct class index, softmax_output is its vector of softmax
# probabilities and num_classes = len(softmax_output).
for j in range(num_classes):
    if j == y[i]:
        dW[:, j] += (softmax_output[j] - 1) * X[i]
    else:
        dW[:, j] += softmax_output[j] * X[i]
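In practice, you'd usually skip the if-else entirely and vectorise the whole thing. Here's a minimal sketch of the gradient with respect to the scores for a whole batch, assuming scores has shape (num_examples, num_classes) and y holds the integer class labels:

import numpy as np

def softmax_cross_entropy_backward(scores, y):
    # Row-wise, numerically stable softmax
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp_scores = np.exp(shifted)
    probs = exp_scores / exp_scores.sum(axis=1, keepdims=True)

    # dL/dscores is S(x_j) - 1 at the correct class and S(x_j) everywhere else,
    # averaged over the batch
    dscores = probs.copy()
    dscores[np.arange(len(y)), y] -= 1
    return dscores / len(y)

From there, the gradient with respect to the weights of a linear layer is just the layer's input transposed, multiplied by dscores.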