BxD Primer Series: Attention Mechanism
The attention mechanism spotlights the most relevant parts of the input data to make a prediction.
Hey there 👋
Welcome to the BxD Primer Series, where we cover topics such as Machine learning models, Neural Nets, GPT, Ensemble models, and Hyper-automation in a ‘one-post-one-topic’ format. Today’s post is on Attention Neural Networks. Let’s get started:
Introduction to Attention Mechanism:
The attention mechanism is a key component of many modern deep learning models, particularly in natural language processing and in image and video processing. Tasks such as machine translation, text summarization, question answering, and image captioning often involve long input sequences with dependencies that require modeling context.
Traditional sequence-to-sequence models built on RNNs, LSTMs, or GRUs tend to struggle with long-range dependencies. The whole input must pass through a fixed-size hidden state, so relevant information becomes diluted or is lost as the sequence grows. The attention mechanism addresses this by allowing the model to dynamically weigh the importance of different parts of the input.
The attention mechanism works by creating a context vector for each step or token of the input sequence. This context vector is a weighted sum of input elements, where the weights indicate the importance or relevance of each element. The high-level steps are as follows (more detail in later parts):
Encoder: The encoder processes the input and generates a set of encoded representations. References use the term "encoder" in many senses, but here we use it strictly for this purpose.
Query, Key, and Value: The encoder representations are then projected, through learned linear transformations, into three sets of vectors: queries, keys, and values. These are used to calculate the attention weights.
Similarity Scores: Compute a similarity score between the query and each key, measuring how well each input element matches the query. Common similarity functions include dot product, scaled dot product, and concatenation with a learnable weight matrix.
Attention Weights: Similarity scores are passed through a softmax function to obtain attention weights, ensuring that the weights sum up to 1 and represent a valid probability distribution.
Weighted Sum: Attention weights are used to calculate a weighted sum of value vectors. This step combines the input representations according to their importance, generating a context vector that captures the most relevant information.
Context Vector and Output: The context vector is then concatenated with the output of the previous step in the model (e.g., the previous hidden state in an RNN) and fed into the next layer for further processing. The attention mechanism can be used in a single layer or across multiple layers in a deep model.
Repetition: The query, similarity, weighting, weighted-sum, and output steps are repeated for each step or token in the input sequence. This allows the model to attend to different parts of the input at each step.
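Here is a minimal pure-Python sketch of one attention step for a single query, to make the high-level steps concrete. It assumes dot-product similarity. To keep it short, it uses the encoded vectors directly as keys and values instead of learned projections. The function names and toy numbers are illustrative and are not part of any library.

```python
import math

def softmax(scores):
    # Subtract the max score for numerical stability; the result sums to 1.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    # Similarity scores: dot product between the query and each key.
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    # Attention weights: softmax over the scores.
    weights = softmax(scores)
    # Context vector: weighted sum of the value vectors, one dimension at a time.
    dim = len(values[0])
    context = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return weights, context

# Toy "encoder outputs" for a 3-token input (hypothetical values).
encoded = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
query = [2.0, 0.0]  # a query that favours the first dimension
weights, context = attend(query, encoded, encoded)
print([round(w, 3) for w in weights])
print([round(c, 3) for c in context])
```

With this query, the first and third tokens score equally high and share most of the attention. The second token, which points in a different direction, gets little weight.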
Attention Neural Networks:
Attention Neural Networks (ANNs) are designed to enhance a model's ability to selectively focus on specific parts of an input sequence or image. They have been successfully applied to natural language, image, and speech tasks.
The basic idea is that the model learns to assign varying degrees of importance, or "attention", to different parts of the input.
In traditional neural networks, the input is processed in a fixed order, and each element of the input is treated equally. In most tasks, however, some parts of the input are more relevant than others. For example, in language translation, certain words or phrases may matter more for translating a particular sentence accurately.
Attention Neural Networks address this by allowing the model to selectively attend to the most relevant parts of the input. This is achieved with a learnable weighting scheme that assigns varying degrees of attention to different parts of the input. The weights represent the relevance or importance of each input feature, and the parameters that produce them are learned during training.
There are two main types of Attention:
Soft Attention mechanisms produce a weighted sum of input features, where the weights are computed by the model from learned parameters. These weights represent the degree of importance or attention given to each input feature. This is continuous by nature.
Hard Attention mechanisms select a subset of input features to attend to, and discard the rest. This is discrete by nature.
One popular type of ANN is the Transformer model, which uses a self-attention mechanism that lets every position of the input sequence attend to every other position.
In this edition we will cover the basics of the attention mechanism, and move on to Transformer models in a future edition.
Soft v/s Hard Attention:
Soft Attention calculates a weighted sum of input features, where the weights represent the degree of attention given to each input feature. It is called "soft" because it produces a continuous distribution of weights over the input features. A continuous distribution allows the model to attend to multiple parts of the input simultaneously. Because the operation is differentiable, it can also be trained end-to-end with standard back-propagation.
Hard Attention, on the other hand, selects a subset of input features to attend to and discards the rest. This selection can follow either a learned or a fixed rule. It is called "hard" because it produces a discrete distribution of weights over the input features. A discrete distribution restricts the model to attending only to the selected part(s) of the input at a time. Because the selection is not differentiable, hard attention typically needs sampling-based training methods.
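Here is a small sketch contrasting the two, assuming the attention weights have already been computed. The hard variant simply keeps the highest-weighted element (a deterministic argmax), which is one simple selection rule. Stochastic hard attention instead samples a position from the weights and is commonly trained with reinforcement-learning style estimators such as REINFORCE. The function names are illustrative.

```python
def soft_attention(weights, values):
    # Continuous: every value contributes in proportion to its weight.
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]

def hard_attention(weights, values):
    # Discrete: keep only the highest-weighted value and discard the rest.
    best = max(range(len(weights)), key=lambda j: weights[j])
    return list(values[best])

weights = [0.2, 0.5, 0.3]                     # hypothetical attention weights
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
print(soft_attention(weights, values))        # a blend of all three values
print(hard_attention(weights, values))        # exactly values[1]
```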
Local v/s Global Attention:
Local attention mechanisms are used when the input data is sequential. They focus on a limited set of neighboring inputs rather than the entire input sequence. For example, in machine translation, the attention mechanism can focus on a few words before and after the current word being translated. Local attention is faster than global attention because it only needs to process a small subset of the input at any given time.
Global attention mechanisms can attend to any part of the input. They are used for sequential data when any position may be relevant, and also when the input has no natural sequential order, as in image recognition tasks. Global attention provides more comprehensive information to the model because it can access the entire input.
There is also a hybrid approach, called ‘local-global attention’, that combines the two. Here the model attends both to a small subset of neighboring inputs and to the entire input sequence, balancing computational efficiency with comprehensive access to information.
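One way to see the difference is as a mask over the similarity scores. Here is a minimal sketch, assuming a symmetric window of `window` positions around the current position i. The `window` parameter and the score values are hypothetical.

```python
import math

def softmax(scores):
    # math.exp(-inf) is 0.0, so masked positions receive zero weight.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_weights(scores, i, window=None):
    # Global attention (window=None): position i may attend to every position.
    # Local attention: only positions j with |i - j| <= window are allowed.
    masked = [
        s if window is None or abs(i - j) <= window else float("-inf")
        for j, s in enumerate(scores)
    ]
    return softmax(masked)

scores = [0.5, 1.0, 2.0, 0.1, 3.0]  # hypothetical scores for query position i = 2
print(attention_weights(scores, i=2))            # global: all 5 positions
print(attention_weights(scores, i=2, window=1))  # local: positions 1 to 3 only
```

This sketch computes every score and then masks it, so it shows the behaviour but not the speed-up. An efficient local attention would compute scores only inside the window.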
The How:
Generating attention-based context vectors from an input typically involves the steps below:
✪ Input Encoding: Assume we have an input sequence of tokens or elements denoted (x_1, x_2, …, x_n).
This input sequence is encoded into a set of representations using an encoding function E(·):
$$h_1, h_2, \ldots, h_n = E(x_1, x_2, \ldots, x_n)$$
✪ Query, Key, and Value: The encoded representations are projected into query (Q), key (K), and value (V) vectors through linear transformations:
Q_i = W_q × h_i - captures the current context, i.e. the information the step at position i is looking for.
K_j = W_k × h_j - represents the encoded features of the token at position j, against which queries are matched.
V_j = W_v × h_j - contains the actual information associated with the token at position j.
Where W_q, W_k, and W_v are learnable weight matrices.
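The projections can be sketched in pure Python as matrix-vector products. The 2×2 matrices below are hypothetical stand-ins for trained weights. The query is built for position i = 0 only, while keys and values are built for every position.

```python
def matvec(W, x):
    # Multiply matrix W (a list of rows) by vector x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

# Hypothetical learned weights (2x2 here); in a real model these are trained.
W_q = [[1.0, 0.0], [0.0, 2.0]]
W_k = [[0.5, 0.5], [0.0, 1.0]]
W_v = [[1.0, 1.0], [1.0, -1.0]]

h = [[1.0, 2.0], [3.0, 0.0], [0.0, 1.0]]  # encoded representations h_1..h_3

i = 0                                     # position whose query we build
Q_i = matvec(W_q, h[i])                   # query for position i
K = [matvec(W_k, h_j) for h_j in h]       # one key per position j
V = [matvec(W_v, h_j) for h_j in h]       # one value per position j
print(Q_i, K, V)
```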
✪ Similarity Scores: Compute similarity scores between the query and each key using a similarity function:
$$s_{ij} = \mathrm{sim}(Q_i, K_j)$$
where s_{ij} represents the similarity score between the query Q_i at position i and the key K_j at position j. The similarity function could be dot product, scaled dot product, concatenation, etc.
✪ Attention Weights: Apply a softmax function to the similarity scores to obtain attention weights. Softmax ensures that the weights are non-negative, sum to 1, and form a valid probability distribution:
$$\alpha_{ij} = \frac{\exp(s_{ij})}{\sum_{k=1}^{n} \exp(s_{ik})}$$
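In practice, softmax is usually computed after subtracting the largest score. This leaves the weights unchanged, because the common factor exp(-max) cancels between numerator and denominator, but it avoids overflow for large scores. Here is a minimal pure-Python sketch.

```python
import math

def softmax(scores):
    # Subtracting the maximum does not change the result mathematically,
    # but it prevents math.exp from overflowing on large scores.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

alpha = softmax([1000.0, 1001.0, 1002.0])  # a naive math.exp(1000.0) would raise OverflowError
print(alpha, sum(alpha))
```

Shifting all scores by a constant gives the same weights, so `softmax([0.0, 1.0, 2.0])` returns the same distribution as the call above.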
✪ Weighted Sum: The attention weights are used to calculate a weighted sum of the value vectors. This weighted sum is known as the context vector:
$$C_i = \sum_{j=1}^{n} \alpha_{ij} V_j$$
✪ Context Vector and Output: The context vector C_i is concatenated with the encoded representation h_i of the current step. The result is passed through a linear transformation and an activation function to generate the output y_i at the current step:
y_i = Activation(W_o × [C_i, h_i])
where W_o is a learnable weight matrix.
✪ Repeat all the steps above for each step or token in the input sequence, allowing the model to attend to different parts of the input at each step.
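The whole procedure above can be sketched as one pure-Python function. The text leaves some choices open, so this sketch makes assumptions: scaled dot-product similarity, tanh as the activation, and toy identity projection matrices. `attention_layer` is a hypothetical name.

```python
import math

def matvec(W, x):
    # Multiply matrix W (a list of rows) by vector x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention_layer(h, W_q, W_k, W_v, W_o):
    """Run the attention steps for every position i of the encoded sequence h."""
    K = [matvec(W_k, h_j) for h_j in h]           # keys K_j
    V = [matvec(W_v, h_j) for h_j in h]           # values V_j
    d = len(K[0])
    outputs = []
    for h_i in h:                                 # repeat for each step i
        Q_i = matvec(W_q, h_i)                    # query for step i
        s = [sum(q * k for q, k in zip(Q_i, K_j)) / math.sqrt(d) for K_j in K]
        alpha = softmax(s)                        # attention weights alpha_ij
        C_i = [sum(a * V_j[t] for a, V_j in zip(alpha, V)) for t in range(len(V[0]))]
        # Output: tanh(W_o [C_i; h_i]); list + list concatenates the two vectors.
        y_i = [math.tanh(z) for z in matvec(W_o, C_i + h_i)]
        outputs.append(y_i)
    return outputs

h = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]          # hypothetical encoded inputs
I = [[1.0, 0.0], [0.0, 1.0]]                      # identity projections for simplicity
W_o = [[0.5, 0.0, 0.5, 0.0], [0.0, 0.5, 0.0, 0.5]]  # maps [C_i; h_i] (4 dims) to 2 dims
y = attention_layer(h, I, I, I, W_o)
print(y)
```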
Common Similarity Functions:
Similarity functions are used to calculate compatibility between query and key vectors. Common similarity functions:
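✪ Dot Product measures similarity in terms of both the magnitude and the direction of the vectors:
$$s_{ij} = Q_i \cdot K_j = Q_i^{\top} K_j$$
✪ Scaled Dot Product is similar to dot product, but divides the scores by a scaling factor to control their magnitude. For large d, unscaled dot products grow large and push the softmax into regions with extremely small gradients, so scaling helps stabilize the gradients during training:
$$s_{ij} = \frac{Q_i^{\top} K_j}{\sqrt{d}}$$
where d is the dimensionality of the query and key vectors.
✪ Concatenation (additive) similarity concatenates the query and key vectors, then applies a linear transformation using a weight matrix. In the commonly used additive form, a tanh nonlinearity and a learnable vector v follow the transformation, which lets the model capture more complex interactions between query and key vectors:
$$s_{ij} = v^{\top} \tanh\left(W [Q_i ; K_j]\right)$$
✪ General similarity applies a learnable linear transformation W to the key vectors before computing similarity scores. This aligns the dimensions of the query and key vectors, which therefore need not be equal:
$$s_{ij} = Q_i^{\top} W K_j$$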
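The four similarity functions can be sketched side by side in pure Python. The weight matrices `W_c`, `W_g` and the vector `v` are hypothetical placeholders for learned parameters. The concatenation score uses the additive tanh form, which is one common choice rather than the only one.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    return [dot(row, x) for row in W]

def dot_score(q, k):
    # s = Q_i . K_j
    return dot(q, k)

def scaled_dot_score(q, k):
    # s = (Q_i . K_j) / sqrt(d), where d is the vector dimensionality
    return dot(q, k) / math.sqrt(len(q))

def concat_score(q, k, W, v):
    # Additive form: s = v^T tanh(W [Q_i; K_j]); q + k concatenates the lists.
    hidden = [math.tanh(z) for z in matvec(W, q + k)]
    return dot(v, hidden)

def general_score(q, k, W):
    # s = Q_i^T W K_j; W may be rectangular, so q and k can differ in size.
    return dot(q, matvec(W, k))

q = [1.0, 2.0]
k = [3.0, 4.0]
W_c = [[0.1, 0.2, 0.3, 0.4], [0.0, 0.1, 0.0, 0.1]]  # hypothetical 2x4 matrix
v = [1.0, -1.0]                                    # hypothetical vector
W_g = [[1.0, 0.0], [0.0, 1.0]]                     # identity: reduces to dot product
print(dot_score(q, k), scaled_dot_score(q, k))
print(concat_score(q, k, W_c, v), general_score(q, k, W_g))
```

With an identity `W_g`, the general score equals the plain dot product. A rectangular `W_g` is what lets queries and keys of different sizes be compared.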
Self-Attention:
Self-Attention is a type of attention mechanism that allows the model to attend to different parts of the input sequence to produce a better representation of that input. It is called "Self-Attention" because the queries, keys, and values all come from the same input sequence: the sequence attends to itself.
Self-Attention operates by computing a weighted sum over the input sequence. The weights represent the importance or attention given to each element of the input sequence relative to the other elements, so the model learns to weigh each element based on its relevance to the current context.
The attention weights themselves are not fixed parameters: they are recomputed for every input from the query and key vectors. The parameters learned through back-propagation, by minimizing a loss function, are the projection matrices (W_q, W_k, W_v). These matrices map each element of the input sequence to its query, key, and value vectors.
Query vector is used to calculate the similarity between the current element and every element of the input sequence
Key vector encodes the information about each element of the input sequence that queries are matched against
Value vector carries the content of each element, which is combined using the attention weights into the output of the Self-Attention layer
The Self-Attention mechanism is commonly used in natural language processing tasks, where it allows the model to attend to different parts of an input sentence to produce an accurate representation of its context.
Attention is a dynamic operation that can adaptively attend to any part of the input, based on its relevance.
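Here is a minimal pure-Python sketch of self-attention that returns the full n × n attention matrix alongside the output. It assumes scaled dot-product similarity. The projection matrices are hypothetical stand-ins for learned parameters, and the same matrices are applied to two different inputs to show that the weights themselves are recomputed per input.

```python
import math

def matvec(W, x):
    # Multiply matrix W (a list of rows) by vector x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X, W_q, W_k, W_v):
    """Queries, keys, and values all come from the same sequence X."""
    Q = [matvec(W_q, x) for x in X]
    K = [matvec(W_k, x) for x in X]
    V = [matvec(W_v, x) for x in X]
    d = len(K[0])
    # n x n attention matrix: row i holds the weights position i gives to each position j.
    A = [softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]) for q in Q]
    # Each output row is the attention-weighted sum of the value vectors.
    out = [[sum(w * v[t] for w, v in zip(row, V)) for t in range(len(V[0]))] for row in A]
    return A, out

# Hypothetical learned projections, shared by every input.
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[1.0, 1.0], [0.0, 1.0]]
W_v = [[1.0, 0.0], [0.0, 1.0]]

A1, _ = self_attention([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], W_q, W_k, W_v)
A2, _ = self_attention([[0.0, 2.0], [1.0, 0.0], [2.0, 2.0]], W_q, W_k, W_v)
for row in A1:
    print([round(w, 3) for w in row])
```

Both calls share the same W_q, W_k, and W_v yet produce different attention matrices: only the projections would be learned. A matrix like A1 is also what is typically plotted when visualizing attention.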
Multi-Head Attention:
Multi-head attention is a variant of the attention mechanism that enables the model to attend to different parts of the input sequence simultaneously. This enhances the model's ability to capture different types of dependencies and provides more expressive representations. It differs from basic (single-head) attention as follows:
In multi-head attention, the key, query, and value vectors are split into multiple heads, typically denoted by parameter "h". Each head has its own set of learned weight matrices for projection.
Each head independently performs a linear projection on key, query, and value vectors by multiplying the vectors with learned weight matrices.
Each head calculates similarity scores between projected query and key vectors.
Similarity scores are then passed through a softmax function to obtain attention weights, ensuring that the weights sum up to 1 and represent a valid probability distribution.
Each head independently computes a weighted sum of value vectors using attention weights.
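The weighted sums from all heads are concatenated to form the final output of the multi-head attention mechanism. In the Transformer, this concatenation is also passed through a learned output projection. The concatenated output combines information from all heads and provides a more comprehensive representation of the input.
Note: The number of attention heads, "h", is a hyper-parameter that can be tuned for the specific task and dataset. Increasing "h" provides more capacity to capture diverse patterns. If each head keeps its own full dimensionality, it also increases the computational cost of the model. The standard Transformer avoids this by dividing the model dimension among the heads, which keeps the total cost roughly constant as h changes.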
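Here is a minimal pure-Python sketch of multi-head attention. Each head gets its own projection matrices to a smaller dimension, runs scaled dot-product attention independently, and the head outputs are concatenated and passed through an output projection `W_o`. The two heads and all matrices are hypothetical toy values, chosen so each head looks at a different half of the input features.

```python
import math

def matvec(W, x):
    # Multiply matrix W (a list of rows) by vector x.
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def single_head(X, W_q, W_k, W_v):
    # Scaled dot-product attention for one head, with its own projections.
    Q = [matvec(W_q, x) for x in X]
    K = [matvec(W_k, x) for x in X]
    V = [matvec(W_v, x) for x in X]
    d = len(K[0])
    out = []
    for q in Q:
        alpha = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K])
        out.append([sum(w * v[t] for w, v in zip(alpha, V)) for t in range(len(V[0]))])
    return out

def multi_head(X, heads, W_o):
    """heads: list of (W_q, W_k, W_v) tuples, one per head, so h = len(heads)."""
    per_head = [single_head(X, *proj) for proj in heads]
    # Concatenate the outputs of all heads position by position.
    concat = [[x for out in per_head for x in out[i]] for i in range(len(X))]
    # Final output projection, as in the Transformer formulation.
    return [matvec(W_o, c) for c in concat]

# Two hypothetical heads, each projecting 4-dimensional inputs down to 2 dimensions.
P1 = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
P2 = [[0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
heads = [(P1, P1, P1), (P2, P2, P2)]
W_o = [[1.0 if r == c else 0.0 for c in range(4)] for r in range(4)]  # identity

X = [[1.0, 0.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0]]
Y = multi_head(X, heads, W_o)
print(len(Y), len(Y[0]))
```

Because `W_o` is the identity here, the first two output columns are exactly the first head's output. In a trained model, `W_o` mixes information across heads.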
The Why:
Reasons for using Attention Neural Networks:
Improves the performance of neural networks for tasks that involve long input sequences.
Provides a way to visualize which parts of input data the model is attending to for making its predictions. This helps in debugging and improving the model.
Variants such as local or hard attention focus only on the most relevant parts of the input data, which can reduce the number of computations required and improve the model's efficiency.
Useful for handling variable-length inputs in natural language processing tasks.
Can be more memory-efficient than other methods of processing variable-length inputs, such as fixed-size windows, sliding windows, or time-unfolding approaches.
Can process multi-modal input data, such as images with accompanying text descriptions, by selectively attending to most relevant parts of each modality.
The Why Not:
Reasons for not using Attention Neural Networks:
Attention mechanism can increase the risk of overfitting if the model is not properly regularized or if attention weights are over-emphasizing certain parts of input data.
Can increase the training time of neural networks on tasks with long input sequences, since the cost of full (global) attention grows quadratically with sequence length.
Can be sensitive to noise or irrelevant information if this is not properly handled or accounted for.
May be less effective when training data is limited.
Introduces additional hyper-parameters that need to be tuned, which can make the training process more difficult and time-consuming.
Time for you to support:
Reply to this email with your question
Forward/Share to a friend who can benefit from this
Chat on Substack with BxD
Engage with BxD on LinkedIn
In the next edition, we will cover Stable Diffusion Models.
Let us know your feedback!
Until then,
Have a great time! 😊