There’s a really good reason why we divide by \(\sqrt{d_k}\). When you compute the dot product between two vectors, you perform two operations: multiplying the corresponding components in each dimension, and summing all of those products up. Therefore, the more components your vectors have, the more terms you add together, and the larger (in magnitude) your result tends to be.
To make sense of why dividing by \(\sqrt{d_k}\) helps, we have to understand three assumptions we make about attention at this stage. These are not guaranteed outcomes of the computation; they are assumed by design.
Assumption 1: we assume that each individual entry in the query and key vectors is centered around zero, on average. This means that even if our vectors become longer (consider increasing the dimensionality of your vectors to make the model more expressive), the resulting dot product between your query and key vectors is still zero in expectation.
Assumption 2: we assume that each individual entry has variance one. Of course, we cannot compute the variance of a single number, but we assume it is drawn from a distribution with variance one (we make no assumptions about which distribution it is drawn from, however).
Assumption 3: we assume the different entries are independent. We need this to reason about how the variance grows as the dimensionality of the vectors grows.
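Written compactly, with \(q_i\) and \(k_i\) denoting the \(i\)-th entries of the query and key vectors, the three assumptions are:

\[
\mathbb{E}[q_i] = \mathbb{E}[k_i] = 0, \qquad \mathrm{Var}(q_i) = \mathrm{Var}(k_i) = 1, \qquad \text{all entries mutually independent.}
\]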
How do these assumptions explain why we divide by \(\sqrt{d_k}\)? When we compute a dot product between two vectors, we sum up the products of the corresponding entries. We want this result to be stable, in the sense that its statistical properties do not depend on the dimensionality of the vectors.
If we increase the dimensionality, the expected value of the dot product is still zero: the product of two independent random variables with mean zero has mean zero (Assumptions 1 and 3), and so does a sum of such products. We do not need to do anything here.
If each entry of both the query and the key has variance one (Assumption 2), and the entries are independent (Assumption 3), then each individual product term also has variance one. This comes from a property in probability theory: when you multiply two independent random variables with mean zero and variance one, their product also has variance one.
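The one-line derivation, using independence (Assumption 3) and the zero means (Assumption 1):

\[
\mathrm{Var}(q_i k_i) = \mathbb{E}[q_i^2 k_i^2] - \big(\mathbb{E}[q_i k_i]\big)^2 = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \big(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\big)^2 = 1 \cdot 1 - 0 = 1.
\]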
But we are adding up more than one term. For a vector of length 100, we add up 100 such terms. This creates a challenge, because variances add when you sum independent random variables: summing 100 such terms gives a total variance of 100, not 1. This is problematic because high-variance scores mean some scores will be very large or very small purely by chance, which in turn produces spiked probability distributions after applying softmax. A saturated softmax passes back vanishingly small gradients, which is fundamentally bad for learning.
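Formally, since the \(d_k\) product terms are independent, their variances add:

\[
\mathrm{Var}(q \cdot k) = \mathrm{Var}\!\left(\sum_{i=1}^{d_k} q_i k_i\right) = \sum_{i=1}^{d_k} \mathrm{Var}(q_i k_i) = d_k.
\]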
Dividing by \(\sqrt{d_k}\) solves this: dividing by the standard deviation of the scores (which is exactly \(\sqrt{d_k}\)) brings the total variance back to one. Therefore, regardless of the dimensionality of our query and key vectors, the computed scores have zero mean and variance one. All three assumptions are necessary for this result, and although we encourage them through appropriate scaling operations, in practice the query and key entries are not independent after training, and their marginal distributions may not stay at unit variance.
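Here is a minimal NumPy sketch of the effect. It is an illustration under the idealized assumptions above (entries drawn i.i.d. from a standard normal), not a trained model, and the shapes and variable names are just for demonstration:

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
d_k, n_keys = 256, 10

q = rng.standard_normal(d_k)             # one query vector
K = rng.standard_normal((n_keys, d_k))   # ten key vectors

raw = K @ q                    # unscaled scores: variance grows with d_k
scaled = raw / np.sqrt(d_k)    # scaled scores: variance back near 1

print(np.round(softmax(raw), 3))     # typically close to one-hot (spiked)
print(np.round(softmax(scaled), 3))  # noticeably flatter attention weights

# Variance check across many independent query/key pairs.
Q = rng.standard_normal((100_000, d_k))
Kmat = rng.standard_normal((100_000, d_k))
dots = (Q * Kmat).sum(axis=1)
print(dots.var(), (dots / np.sqrt(d_k)).var())  # roughly d_k and roughly 1.0
```

With \(d_k = 256\), the unscaled scores have a standard deviation of about 16, so the softmax over the ten keys almost always collapses onto a single key; after scaling, the weights are spread far more evenly.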
