Supervised Learning - Recurrent Neural Networks and LSTMs
Recurrent Neural Networks (RNNs)
- Basic Structure: RNNs process sequential data by maintaining a hidden state that is updated at each time step.
- Equation: The hidden state h_t at time t is calculated using the formula: h_t = f_W(h_{t−1}, x_t), where x_t is the input at time t and f_W is a function with weights W.
- Output Layer: The output y_t at time t is computed from the hidden state: y_t = W_y · h_t.
- Computational Graph: RNNs can be visualized as a computational graph that unfolds over time, showing dependencies between inputs, hidden states, and outputs at each time step.
- Many-to-Many: An RNN computational graph illustrates how inputs x_1, x_2, x_3, …, x_T are processed to produce outputs y_1, y_2, y_3, …, y_T with corresponding losses L_1, L_2, L_3, …, L_T.
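The unrolled recurrence can be sketched in a few lines of NumPy, assuming a tanh nonlinearity for f_W (the dimensions, the tanh choice, and the random weights are illustrative, not from the notes):

```python
import numpy as np

def rnn_forward(xs, W_hh, W_xh, W_hy, h0):
    """Unroll a vanilla RNN: h_t = tanh(W_hh·h_{t-1} + W_xh·x_t), y_t = W_hy·h_t."""
    h, hs, ys = h0, [], []
    for x in xs:
        h = np.tanh(W_hh @ h + W_xh @ x)  # hidden-state update f_W
        hs.append(h)
        ys.append(W_hy @ h)               # one output per step (many-to-many)
    return hs, ys

# toy dimensions (illustrative): hidden size 4, input size 3, output size 2
rng = np.random.default_rng(0)
H, D = 4, 3
xs = [rng.standard_normal(D) for _ in range(5)]
hs, ys = rnn_forward(xs,
                     0.1 * rng.standard_normal((H, H)),
                     0.1 * rng.standard_normal((H, D)),
                     0.1 * rng.standard_normal((2, H)),
                     np.zeros(H))
```

Note that the same weight matrices are reused at every time step; only the hidden state changes as the graph unfolds.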
Vanishing Gradient Problem
- Description: A significant challenge in training RNNs, especially with long sequences.
- Recursion Impact: Unrolling the recursion effectively adds one hidden layer per time step, so backpropagation struggles to update weights when gradients vanish over many steps.
- Mathematical Explanation: During backpropagation, error gradients are summed over all time steps. The problematic term involves repeated multiplication by the recurrent weight matrix W_R: if the dominant eigenvalue of W_R is greater than 1, the gradient explodes; if it is less than 1, the gradient vanishes.
- The total gradient can be expressed as ∂L/∂W = ∑_{t=1}^{T} ∂L_t/∂W.
- Each term contains repeated factors of ∂h_t/∂h_{t−1} = W_R (for a linear recurrence), so the gradient's magnitude is governed by the eigenvalues of W_R.
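The eigenvalue condition can be checked numerically. This sketch assumes a hypothetical diagonal W_R and a linear recurrence h_t = W_R·h_{t−1}, so the factor multiplying a gradient flowing back over many steps is exactly a power of W_R:

```python
import numpy as np

def repeated_jacobian_norm(dominant_eigenvalue, steps=50):
    """Norm of W_R^steps for a hypothetical diagonal W_R — the factor a
    gradient is multiplied by when backpropagated over `steps` time steps."""
    W_R = np.diag([dominant_eigenvalue, 0.5 * dominant_eigenvalue])
    return np.linalg.norm(np.linalg.matrix_power(W_R, steps))

vanishing = repeated_jacobian_norm(0.9)   # dominant eigenvalue < 1: shrinks toward 0
exploding = repeated_jacobian_norm(1.1)   # dominant eigenvalue > 1: grows without bound
```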
Back Propagation Through Time (BPTT)
- Process: Involves a forward pass through the entire sequence to compute the loss, followed by a backward pass through the entire sequence to compute gradients.
Truncated Back Propagation Through Time (TBPTT)
- Description: An approach that makes BPTT tractable on long sequences by running forward and backward passes over fixed-length chunks of the sequence rather than the entire sequence.
- Advantage: Much faster than simple BPTT.
- Mechanism: Hidden states are carried forward in time indefinitely, but backpropagation is limited to a smaller number of steps.
- Chunk: The time window used during backpropagation is referred to as a “chunk.”
- Disadvantage: Dependencies longer than the chunk length are not learned during training, as the contribution of gradients from distant steps is ignored.
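The truncation can be illustrated with a toy scalar linear RNN h_t = w·h_{t−1} + x_t (a simplification not in the notes): each loss term backpropagates through at most k steps, and contributions from more distant steps are dropped.

```python
def tbptt_grad(xs, targets, w, k):
    """Gradient of the total squared error for the scalar linear RNN
    h_t = w * h_{t-1} + x_t, with backpropagation truncated so each
    loss term looks back at most k steps (the chunk length)."""
    h, hs = 0.0, [0.0]                # hs[j] holds h_j, with h_0 = 0
    for x in xs:
        h = w * h + x
        hs.append(h)
    grad = 0.0
    for t in range(1, len(xs) + 1):
        dL_dh = 2.0 * (hs[t] - targets[t - 1])
        # walk back at most k steps; older contributions are ignored
        for j in range(t, max(t - k, 0), -1):
            grad += dL_dh * (w ** (t - j)) * hs[j - 1]
    return grad

xs, targets, w = [1.0, 0.5, -0.3], [0.5, 0.2, 0.1], 0.8
full = tbptt_grad(xs, targets, w, k=len(xs))   # full BPTT
trunc = tbptt_grad(xs, targets, w, k=1)        # chunk length 1
```

With k equal to the sequence length this reproduces the exact gradient; with k = 1 the terms that carry long-range dependencies are missing, which is exactly the disadvantage noted above.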
Long Short-Term Memory (LSTM) Networks
- Purpose: Designed to address the long-term dependency issues faced by RNNs due to the vanishing gradient problem.
- Functionality: LSTMs can process entire sequences of data while retaining useful information from previous data points to aid in processing new ones, making them suitable for text, speech, and time-series data.
- Example: Predicting monthly rainfall amounts where an LSTM network can learn the annual trend that repeats every 12 periods.
- It retains longer-term context rather than just using the previous prediction.
- Structure: LSTMs have a more complex repeating module compared to the simple structure in RNNs.
- Inputs:
- Current long-term memory (cell state)
- Output at the previous time step (previous hidden state)
- Input data at the current time step
- Gates:
- Forget Gate
- Input Gate (can be split into input modulation gate and input gate)
- Output Gate
- These gates act as filters and are implemented as neural networks.
LSTM Mechanism - Step-by-Step
- Forget Layer:
- Decides which part of the cell state is useful based on the previous hidden state and new input data.
- A neural network generates a vector with elements in the interval [0,1] using a sigmoid activation function σ.
- Values close to 0 indicate irrelevance, while values close to 1 indicate relevance.
- The previous cell state C_{t−1} is multiplied element-wise by the forget gate's output f_t:
- f_t = σ(W_f · [h_{t−1}, x_t] + b_f)
- C^f_{t−1} = f_t ∗ C_{t−1}
- Propose a New State:
- Determines what new information should be added to the cell state.
- Involves two networks:
- A tanh-activated neural network generates a new memory update vector C̃_t with values in [−1, 1].
- C̃_t = tanh(W_m · [h_{t−1}, x_t] + b_m)
- A sigmoid-activated input network produces a filter ω_t that identifies which components of the new memory vector are worth retaining.
- ω_t = σ(W_ω · [h_{t−1}, x_t] + b_ω)
- The cell state is updated by adding the filtered new memory vector:
- C_t = C^f_{t−1} + ω_t ∗ C̃_t
- Output Gate:
- Defines the new hidden state h_t.
- Applies the tanh function to the current cell state to obtain the squashed cell state, which lies in [−1, 1].
- Passes the previous hidden state and current input data through a sigmoid-activated neural network to obtain the filter vector o_t.
- Applies this filter vector to the squashed cell state by element-wise multiplication.
- o_t = σ(W_o · [h_{t−1}, x_t] + b_o)
- h_t = o_t ∗ tanh(C_t)
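Putting the gate equations together, a single LSTM step can be sketched in NumPy (the weight names mirror the equations above; the toy sizes and random initialization are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, C_prev, params):
    """One LSTM step; each gate's W acts on the concatenation [h_{t-1}, x_t]."""
    z = np.concatenate([h_prev, x])
    f = sigmoid(params["W_f"] @ z + params["b_f"])       # forget gate
    omega = sigmoid(params["W_w"] @ z + params["b_w"])   # input gate (filter)
    C_tilde = np.tanh(params["W_m"] @ z + params["b_m"]) # proposed new memory
    C = f * C_prev + omega * C_tilde                     # additive cell-state update
    o = sigmoid(params["W_o"] @ z + params["b_o"])       # output gate
    h = o * np.tanh(C)                                   # new hidden state
    return h, C

# toy sizes (illustrative): hidden size 4, input size 3
rng = np.random.default_rng(1)
H, D = 4, 3
params = {f"W_{g}": 0.1 * rng.standard_normal((H, H + D)) for g in "fwmo"}
params.update({f"b_{g}": np.zeros(H) for g in "fwmo"})
h, C = lstm_step(rng.standard_normal(D), np.zeros(H), np.zeros(H), params)
```

Each gate is a small dense network acting on [h_{t−1}, x_t], which is why every weight matrix has shape (units, units + input_dim).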
LSTM and Vanishing Gradient Problem
- LSTMs address the vanishing gradient problem with specialized gates:
- Forget Gate: C^f_{t−1} = f_t ∗ C_{t−1}
- Input Gate: C_t = C^f_{t−1} + ω_t ∗ C̃_t
- The gate activations f_t and ω_t depend on the time step t, and the cell state is updated additively, so the backpropagated gradient is not a repeated product of a single fixed matrix and is less prone to uniformly vanishing or exploding.
Example: LSTM Parameter Calculation
- Scenario: Univariate time series LSTM problem with cell state and hidden state having 32 units each, a sliding time window of 5 steps, and a bias term.
- Calculation:
- Input dimension is 1 (univariate).
- Each of the 4 gate networks has (1+32)×32 weights and 32 biases.
- Total LSTM parameters: 4×((33×32)+32) = 4352.
- Final dense layer: 32×d+d, where d is the output dimension.
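The arithmetic can be checked directly. Note that the 5-step sliding window does not change the count, because the same weights are shared across time steps:

```python
# Parameter count for the worked example: univariate input (d_in = 1),
# 32 hidden units, 4 gate networks, and a dense output layer of size d.
d_in, units, d = 1, 32, 1
per_gate = (d_in + units) * units + units   # weights on [h, x] plus biases
lstm_params = 4 * per_gate                  # forget, input, modulation, output
dense_params = units * d + d
print(lstm_params, dense_params)            # prints: 4352 33
```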
- Limitations of LSTMs:
- Sequential processing limits parallelization.
- Difficulty in modeling very long sequences.
- Long training times.
- Fixed memory bottleneck for context retention.
- Need for Paradigm Shift: Modern applications require:
- Efficient parallel processing.
- Access to long-range dependencies.
- Scalable memory of context.
- Transformers: