Supervised Learning - Recurrent Neural Networks and LSTMs
Recurrent Neural Networks (RNNs)
Basic Structure: RNNs process sequential data by maintaining a hidden state that is updated at each time step.
Equation: The hidden state h_t at time t is calculated using the formula h_t = f_W(h_{t-1}, x_t), where x_t is the input at time t and f_W is a function with weights W.
Output Layer: The output y_t at time t is computed from the hidden state: y_t = W_y h_t.
Computational Graph: RNNs can be visualized as a computational graph that unfolds over time, showing dependencies between inputs, hidden states, and outputs at each time step.
Many-to-Many: An RNN computational graph illustrates how inputs x_1, x_2, x_3, …, x_T are processed to produce outputs y_1, y_2, y_3, …, y_T with corresponding losses L_1, L_2, L_3, …, L_T.
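As a concrete reference, here is a minimal NumPy sketch of this unrolled many-to-many computation; the tanh nonlinearity, the dimensions, and the random weights are illustrative assumptions rather than anything specified above.

```python
import numpy as np

def rnn_forward(xs, W_h, W_x, W_y, h0):
    """Unrolled vanilla RNN: h_t = tanh(W_h h_{t-1} + W_x x_t), y_t = W_y h_t."""
    h, ys = h0, []
    for x in xs:                          # one iteration per time step
        h = np.tanh(W_h @ h + W_x @ x)    # update the hidden state
        ys.append(W_y @ h)                # emit an output at every step (many-to-many)
    return ys, h

# Toy sizes: input dim 3, hidden dim 4, output dim 2, sequence length 5.
rng = np.random.default_rng(0)
W_h, W_x, W_y = rng.normal(size=(4, 4)), rng.normal(size=(4, 3)), rng.normal(size=(2, 4))
xs = [rng.normal(size=3) for _ in range(5)]
ys, h_T = rnn_forward(xs, W_h, W_x, W_y, h0=np.zeros(4))
```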
Vanishing Gradient Problem
Description: A significant challenge in training RNNs, especially with long sequences.
Recursion Impact: Unrolling the recursion effectively adds one hidden layer per time step, making it difficult for backpropagation to update the weights because the gradients vanish as they flow back through many steps.
Mathematical Explanation: During backpropagation, the error gradients are summed over all time steps. The problematic term involves repeated multiplication by the recurrent weight matrix W_R: if the dominant eigenvalue of W_R is greater than 1, the gradient explodes; if it is less than 1, the gradient vanishes.
This can be expressed as \frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial W}.
The gradient depends on the eigenvalues of the recurrent weight matrix, since \frac{\partial h_t}{\partial h_{t-1}} = W_R.
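A small numerical illustration of the eigenvalue argument (the diagonal 2x2 matrices, the 50 steps, and the Frobenius norm are arbitrary choices for the demo): the product of many such Jacobians shrinks toward zero when the dominant eigenvalue is below 1 and blows up when it is above 1.

```python
import numpy as np

def repeated_jacobian_norm(W_R, steps):
    """Norm of W_R multiplied by itself `steps` times, mimicking prod_t dh_t/dh_{t-1}."""
    J = np.eye(W_R.shape[0])
    for _ in range(steps):
        J = W_R @ J
    return np.linalg.norm(J)

W_vanish  = np.array([[0.9, 0.0], [0.0, 0.8]])   # dominant eigenvalue 0.9 < 1
W_explode = np.array([[1.1, 0.0], [0.0, 1.05]])  # dominant eigenvalue 1.1 > 1
print(repeated_jacobian_norm(W_vanish, 50))      # ~0.005 -> gradient vanishes
print(repeated_jacobian_norm(W_explode, 50))     # ~118   -> gradient explodes
```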
Back Propagation Through Time (BPTT)
Process: Involves a forward pass through the entire sequence to compute the loss, followed by a backward pass through the entire sequence to compute gradients.
Truncated Back Propagation Through Time (TBPTT)
Description: An approach that makes BPTT tractable on long sequences by running the forward and backward passes over fixed-length chunks of the sequence rather than the entire sequence.
Advantage: Much faster than simple BPTT.
Mechanism: Hidden states are carried forward in time indefinitely, but backpropagation is limited to a smaller number of steps.
Chunk: The time window used during backpropagation is referred to as a “chunk.”
Disadvantage: Dependencies longer than the chunk length are not learned during training, as the contribution of gradients from distant steps is ignored.
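A sketch of how TBPTT is typically implemented in PyTorch (the framework, model, sizes, and loss below are assumptions for the illustration): the hidden state is carried across chunks, but detaching it stops gradients at the chunk boundary, so only `chunk` steps are backpropagated through.

```python
import torch
import torch.nn as nn

# Illustrative sizes; none of these come from the notes above.
seq_len, chunk, input_dim, hidden_dim = 100, 10, 1, 32
x = torch.randn(1, seq_len, input_dim)   # one long sequence (batch size 1)
y = torch.randn(1, seq_len, 1)           # dummy targets

rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
head = nn.Linear(hidden_dim, 1)
opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

h = torch.zeros(1, 1, hidden_dim)        # initial hidden state
for start in range(0, seq_len, chunk):
    x_c, y_c = x[:, start:start + chunk], y[:, start:start + chunk]
    out, h = rnn(x_c, h)                 # hidden state carried forward in time...
    loss = nn.functional.mse_loss(head(out), y_c)
    opt.zero_grad()
    loss.backward()                      # ...but gradients flow back only within the chunk
    opt.step()
    h = h.detach()                       # cut the computation graph at the chunk boundary
```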
Long Short-Term Memory (LSTM) Networks
Purpose: Designed to address the long-term dependency issues faced by RNNs due to the vanishing gradient problem.
Functionality: LSTMs can process entire sequences of data while retaining useful information from previous data points to aid in processing new ones, making them suitable for text, speech, and time-series data.
Example: Predicting monthly rainfall amounts where an LSTM network can learn the annual trend that repeats every 12 periods.
It retains longer-term context rather than just using the previous prediction.
Structure: LSTMs have a more complex repeating module compared to the simple structure in RNNs.
Inputs:
Current long-term memory (cell state)
Output at the previous time step (previous hidden state)
Input data at the current time step
Gates:
Forget Gate
Input Gate (often split into an input modulation gate, which proposes new values, and an input gate, which filters them)
Output Gate
These gates act as filters and are implemented as neural networks.
LSTM Mechanism - Step-by-Step
Forget Layer:
Decides which part of the cell state is useful based on the previous hidden state and new input data.
A neural network generates a vector with elements in the interval [0,1] using a sigmoid activation function \sigma.
Values close to 0 indicate irrelevance, while values close to 1 indicate relevance.
The previous cell state C_{t-1} is multiplied element-wise by the forget gate's output f_t:
f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
C_{t-1}^f = f_t * C_{t-1}
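A minimal NumPy sketch of the forget step under the same notation (the dimensions and random values are made up for the illustration; W_f acts on the concatenation [h_{t-1}, x_t]):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(1)
hidden, input_dim = 4, 3                          # toy sizes
W_f, b_f = rng.normal(size=(hidden, hidden + input_dim)), np.zeros(hidden)
h_prev, x_t = rng.normal(size=hidden), rng.normal(size=input_dim)
C_prev = rng.normal(size=hidden)                  # previous cell state C_{t-1}

f_t = sigmoid(W_f @ np.concatenate([h_prev, x_t]) + b_f)  # entries in (0, 1)
C_prev_f = f_t * C_prev                                   # element-wise forgetting
```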
Propose a New State:
Determines what new information should be added to the cell state.
Involves two networks:
A tanh-activated neural network generates a new memory update vector \tilde{C}_t with values in [-1, 1].
\tilde{C}_t = \tanh(W_m \cdot [h_{t-1}, x_t] + b_m)
A sigmoid-activated input gate produces a filter vector \omega_t = \sigma(W_\omega \cdot [h_{t-1}, x_t] + b_\omega) that identifies which components of the new memory vector are worth retaining.
The cell state is updated by adding the filtered new memory vector:
C_t = C_{t-1}^f + \omega_t * \tilde{C}_t
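Continuing the same toy NumPy setup, the propose-and-filter step can be sketched as follows (W_m, b_m, W_omega, b_omega and the sizes are illustrative assumptions):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(2)
hidden, input_dim = 4, 3
W_m, b_m = rng.normal(size=(hidden, hidden + input_dim)), np.zeros(hidden)
W_omega, b_omega = rng.normal(size=(hidden, hidden + input_dim)), np.zeros(hidden)
h_prev, x_t = rng.normal(size=hidden), rng.normal(size=input_dim)
C_prev_f = rng.normal(size=hidden)       # filtered cell state from the forget step

z = np.concatenate([h_prev, x_t])
C_tilde = np.tanh(W_m @ z + b_m)         # proposed memory update, entries in (-1, 1)
omega_t = sigmoid(W_omega @ z + b_omega) # input gate: which components to keep
C_t = C_prev_f + omega_t * C_tilde       # additive cell-state update
```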
Output Gate:
Defines the new hidden state h_t.
Applies the \tanh function to the current cell state to obtain the squished cell state, which now lies in [-1, 1].
Passes the previous hidden state and current input data through a sigmoid-activated neural network to obtain the filter vector o_t.
Applies this filter vector to the squished cell state by element-wise multiplication.
o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)
h_t = o_t * \tanh(C_t)
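Putting the three steps together, one full LSTM time step can be written as a single function (a didactic NumPy rendering of the equations above; production libraries pack the four weight matrices into one and order them differently):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, W_f, b_f, W_m, b_m, W_omega, b_omega, W_o, b_o):
    """One LSTM step following the forget / propose / output equations above."""
    z = np.concatenate([h_prev, x_t])
    f_t     = sigmoid(W_f @ z + b_f)          # forget gate
    C_tilde = np.tanh(W_m @ z + b_m)          # proposed new memory
    omega_t = sigmoid(W_omega @ z + b_omega)  # input gate
    o_t     = sigmoid(W_o @ z + b_o)          # output gate
    C_t = f_t * C_prev + omega_t * C_tilde    # new cell state
    h_t = o_t * np.tanh(C_t)                  # new hidden state
    return h_t, C_t

# Toy usage: hidden size 4, input size 3, sequence of 5 random steps.
rng = np.random.default_rng(3)
def init():
    return rng.normal(size=(4, 4 + 3)), np.zeros(4)
(W_f, b_f), (W_m, b_m), (W_omega, b_omega), (W_o, b_o) = init(), init(), init(), init()
h, C = np.zeros(4), np.zeros(4)
for x in [rng.normal(size=3) for _ in range(5)]:
    h, C = lstm_step(x, h, C, W_f, b_f, W_m, b_m, W_omega, b_omega, W_o, b_o)
```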
LSTM and Vanishing Gradient Problem
LSTMs address the vanishing gradient problem with specialized gates:
Forget Gate: C_{t-1}^f = f_t * C_{t-1}
Input Gate: C_t = C_{t-1}^f + \omega_t * \tilde{C}_t
Because the gate activations f_t and \omega_t are recomputed from h_{t-1} and x_t at every time step, the derivative along the cell-state path is f_t rather than a fixed recurrent matrix, so gradients are not repeatedly multiplied by the same factor and can survive over many steps.
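A rough numerical illustration of the difference (this treats \partial C_t / \partial C_{t-1} \approx f_t along the cell-state path and ignores the gates' own dependence on earlier states, which is a simplification):

```python
import numpy as np

rng = np.random.default_rng(4)
steps = 100

# Plain RNN path: the same factor 0.9 at every step -> the product vanishes.
rnn_factor = 0.9 ** steps                  # ~2.7e-5

# LSTM cell-state path: per-step forget-gate values that can stay near 1
# when the network learns to keep a memory -> the product survives.
f = rng.uniform(0.97, 1.0, size=steps)     # gate values close to 1
lstm_factor = np.prod(f)                   # ~0.2
print(rnn_factor, lstm_factor)
```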
Example: LSTM Parameter Calculation
Scenario: Univariate time series LSTM problem with cell state and hidden state having 32 units each, a sliding time window of 5 steps, and a bias term.
Calculation:
Input dimension is 1 (univariate).
Each of the 4 gate layers has a weight matrix of shape (1 + 32) \times 32 and a bias vector of length 32.
Number of LSTM parameters: 4 \times ((33 \times 32) + 32) = 4 \times 1088 = 4352.
The final dense layer adds 32 \times d + d parameters, where d is the output dimension.
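The arithmetic can be checked in a few lines of Python (for the sketch, a scalar output is assumed, so d = 1):

```python
input_dim, units, d = 1, 32, 1   # univariate input, 32 LSTM units, scalar output (assumed)

per_gate = (input_dim + units) * units + units   # weights on [h_{t-1}, x_t] plus bias
lstm_params = 4 * per_gate                       # forget, input, modulation, output gates
dense_params = units * d + d                     # final dense layer

print(per_gate, lstm_params, dense_params)       # 1088, 4352, 33
```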
Transformers
Limitations of LSTMs:
Sequential processing limits parallelization.
Difficulty in modeling very long sequences.
Long training times.
Fixed memory bottleneck for context retention.
Need for Paradigm Shift: Modern applications require: