Scaling Laws for Forgetting When Fine-Tuning Large Language Models

When fine-tuning Large Language Models (LLMs) like GPT-3 or BERT for specific tasks, a common challenge encountered is "forgetting" – where the model loses some of its pre-trained capabilities. This phenomenon is particularly noticeable in Parameter-Efficient Fine-Tuning (PEFT) methods such as Low-Rank Adapters (LoRA).

Fine-Tuning and Forgetting: The Balancing Act

The total number of parameters in a pre-trained LLM is fixed, often running into the billions. However, not all of these parameters are necessarily updated during the fine-tuning process. In PEFT, the idea is to adjust only a subset of the model's parameters to tailor the model to a specific task, and the size of that subset can vary.

What does "Increasing with the Number of Parameters Fine-Tuned" Mean?

This phrase refers to how the extent of forgetting depends on how many of the model's parameters are adjusted during fine-tuning. It is a balancing act: tuning more parameters generally yields better task performance, but it also increases how much of the pre-trained behavior is forgotten.

This relationship between the number of parameters fine-tuned and the extent of forgetting is crucial. It provides insights into how to balance improving performance on specific tasks while retaining the general knowledge the model gained during its extensive pre-training.

The Fine-Tuning Conundrum: Balancing Performance and Forgetting

A significant finding of the study is that the amount of forgetting is not merely a consequence of the number of parameters being fine-tuned but is intricately linked to the fine-tuning performance itself. The research demonstrates that forgetting scales with both the number of non-embedding parameters fine-tuned and the number of gradient update steps, following a shifted power law. This implies a coupled relationship between the scale of parameter adjustment during fine-tuning and the resulting performance and forgetting.
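As a concrete illustration of what fitting such a law looks like, the sketch below fits a shifted power law of the form Lf(n, t) = A * n^alpha * t^beta + E, where n is the number of non-embedding parameters fine-tuned and t is the number of update steps. The functional form, the variable names, and the data points are assumptions for illustration only, not the paper's fitted coefficients.

```python
# Illustrative sketch: fit a shifted power law L_f(n, t) = A * n**alpha * t**beta + E.
# The data values and initial guesses below are made up for demonstration.
import numpy as np
from scipy.optimize import curve_fit

def shifted_power_law(x, A, alpha, beta, E):
    n, t = x
    return A * n**alpha * t**beta + E

# Hypothetical measurements of the forgetting metric L_f at several
# (fine-tuned parameter count, gradient step count) settings.
n = np.array([1e6, 1e6, 1e7, 1e7, 1e8, 1e8])
t = np.array([100, 1000, 100, 1000, 100, 1000])
Lf = np.array([0.02, 0.05, 0.04, 0.11, 0.09, 0.24])

params, _ = curve_fit(shifted_power_law, (n, t), Lf,
                      p0=(1e-4, 0.3, 0.3, 0.01), maxfev=10000)
A, alpha, beta, E = params
print(f"A={A:.3g}, alpha={alpha:.3g}, beta={beta:.3g}, E={E:.3g}")
```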

The Dilemma of Parameter Selection

When we talk about "increasing with the number of parameters fine-tuned," it's crucial to clarify that while the total number of parameters in a pre-trained model is fixed, the subset of parameters modified during fine-tuning can vary. The study leverages the LoRA fine-tuning technique, which adds a tunable "adapter" module to any subset of the pre-trained model's weights, thereby varying the number of parameters adjusted during fine-tuning. This approach allows for a controlled exploration of how adjusting different numbers of parameters affects forgetting and performance.
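A minimal sketch of the LoRA idea in PyTorch (the class name, shapes, and hyperparameters here are illustrative, not the paper's code): the pre-trained weight is frozen and only a low-rank update B·A is trained, so the number of fine-tuned parameters is controlled by the rank r and by which weight matrices receive adapters.

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen pre-trained linear layer plus a trainable low-rank adapter.

    Effective weight: W + (alpha / r) * B @ A, with A of shape (r, in) and
    B of shape (out, r). Only A and B are updated during fine-tuning.
    """
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pre-trained weights
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: output starts equal to the base layer
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * (x @ self.A.T) @ self.B.T
```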

Insights on Optimal Parameter Adjustment

The paper suggests that the relationship between fine-tuning performance and forgetting is complex and cannot simply be mitigated by early stopping or by tuning a smaller number of parameters. This observation challenges conventional wisdom and highlights the need for new strategies to balance the dual goals of minimizing forgetting and maximizing fine-tuning performance.

Quantifying the Forgetting Effect

When fine-tuning Large Language Models (LLMs) on specific tasks, it's crucial to measure the extent to which these models "forget" their pre-trained capabilities. This section outlines the methodology used to quantify the forgetting effect, utilizing a metric based on cross-entropy loss between the fine-tuned model's predictions and those of the base pre-trained model.

Step 1: Define the Forgetting Metric (Lf)

The forgetting metric, denoted as Lf, employs cross-entropy loss, where the target predictions are sourced from the base pre-trained model instead of the actual ground truth data. This approach addresses issues that arise from using standard loss measurements, such as the inability to account for the pre-trained model's initial performance on the evaluation dataset and the potential overlap between training and evaluation datasets.
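In symbols, one way to write this metric is shown below. The notation is assumed here for illustration; in particular, using the base model's greedy next-token predictions as hard targets follows the step-by-step description later in this section.

```latex
% Forgetting metric: cross-entropy of the fine-tuned model's predictions
% against the base model's predicted next tokens (notation illustrative).
L_f = -\frac{1}{\sum_{x \in D_{\text{eval}}} |x|}
      \sum_{x \in D_{\text{eval}}} \sum_{i=1}^{|x|}
      \log p_{\text{ft}}\!\left(\hat{y}_i \mid x_{<i}\right),
\qquad
\hat{y}_i = \arg\max_{v}\; p_{\text{base}}\!\left(v \mid x_{<i}\right)
```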

Step 2: Evaluate Base Model Predictions

  1. Base Model Evaluation: Run the base pre-trained model on the evaluation dataset to obtain its predictions. These predictions serve as the "target" for calculating the forgetting metric.

  2. Record Predictions: Store the predicted next tokens or outputs from the base model for each input sequence in the evaluation dataset.
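A minimal sketch of this step using Hugging Face transformers (the "gpt2" checkpoint is a stand-in; any causal LM works the same way): run the frozen base model once over the evaluation inputs and cache its greedy next-token predictions.

```python
import torch
from transformers import AutoModelForCausalLM

# Placeholder base model; swap in the actual pre-trained checkpoint under study.
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
base_model.eval()

@torch.no_grad()
def base_predictions(input_ids: torch.Tensor) -> torch.Tensor:
    """Greedy next-token prediction of the frozen base model at every position.

    These predictions serve as the "targets" for the forgetting metric Lf.
    """
    logits = base_model(input_ids).logits  # (batch, seq_len, vocab)
    return logits.argmax(dim=-1)
```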

Step 3: Fine-tune the LLM

  1. Select Parameters for Fine-tuning: Choose a subset of the LLM's parameters to update during the fine-tuning process. This can be a critical decision, as the number of parameters fine-tuned can impact the degree of forgetting.

  2. Fine-tune on Specific Task: Train the selected parameters of the LLM on the task-specific dataset, applying a fine-tuning technique such as Low-Rank Adapters (LoRA) for parameter efficiency.
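One way to implement this step is with the Hugging Face peft library; the sketch below is illustrative, and the rank, dropout, and target modules are assumed values that would need to be chosen per model and per experiment. The rank r and the choice of target modules are what determine the number of parameters actually fine-tuned.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model

# r and target_modules control how many parameters are trained, which is the
# quantity the scaling law relates to forgetting.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["c_attn"],  # GPT-2's fused attention projection; model-specific
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()  # reports the fine-tuned parameter count
```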

Step 4: Evaluate Fine-tuned Model Predictions

  1. Fine-tuned Model Evaluation: Run the fine-tuned model on the same evaluation dataset used to assess the base model's predictions.

  2. Record Predictions: Capture the predicted next tokens or outputs from the fine-tuned model for each input sequence.

Step 5: Calculate the Forgetting Metric (Lf)

  1. Compute Cross-Entropy Loss: For each input sequence in the evaluation dataset, calculate the cross-entropy loss between the fine-tuned model's predictions and the base model's predictions (serving as the target).

  2. Aggregate Loss: Aggregate these individual loss values across the entire evaluation dataset to compute the overall forgetting metric Lf.
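Putting steps 2, 4, and 5 together, a minimal sketch of the computation is shown below. The function name, the hard-target formulation, and token-level averaging are assumptions based on the description above, not the paper's reference implementation.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def forgetting_metric(base_model, ft_model, eval_batches):
    """Average cross-entropy of the fine-tuned model's predictions against the
    base model's greedy next-token predictions, over batches of input_ids."""
    total_loss, total_tokens = 0.0, 0
    for input_ids in eval_batches:
        base_targets = base_model(input_ids).logits.argmax(dim=-1)  # step 2: base "targets"
        ft_logits = ft_model(input_ids).logits                      # step 4: fine-tuned predictions
        # Step 5: cross-entropy at every position; both models predict the same
        # next token at each position, so no extra alignment shift is needed.
        loss = F.cross_entropy(
            ft_logits.reshape(-1, ft_logits.size(-1)),
            base_targets.reshape(-1),
            reduction="sum",
        )
        total_loss += loss.item()
        total_tokens += base_targets.numel()
    return total_loss / total_tokens  # L_f: higher means more forgetting
```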

Step 6: Interpret the Results

A larger Lf indicates that the fine-tuned model's predictions have drifted further from those of the base model, i.e., more forgetting; an Lf near zero means the fine-tuned model still closely reproduces the base model's behavior on the evaluation data. By systematically following these steps, researchers and practitioners can quantify the forgetting effect in LLMs, providing a basis for developing strategies to mitigate forgetting and preserve pre-trained knowledge during fine-tuning.

Key Takeaways

Understanding and managing the trade-off between fine-tuning performance and forgetting is essential for leveraging the full potential of LLMs in various applications, from natural language understanding and generation to more specialized tasks requiring nuanced domain knowledge.
