Harnessing Zephyr's Breeze: DPO Training on Mistral-7B-GPTQ for Language Model Alignment

We've taken on the exciting challenge of implementing the cutting-edge strategies presented in "ZEPHYR: Direct Distillation of LM Alignment". This paper's approach is not just theoretical—it's a blueprint for a significant leap in language model training. By adopting ZEPHYR's distilled direct preference optimization (dDPO), we've embarked on a code journey that brings these innovations from concept to reality.

Understanding the Gust of ZEPHYR

The implementation of ZEPHYR revolves around the concept of Direct Preference Optimization (DPO), a technique designed to fine-tune language models not just for accuracy, but for alignment with human values and intentions. This process involves training models to prefer certain types of responses over others, effectively teaching them what we, as humans, consider a 'better' reply.
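
Before diving into the scripts, it helps to see what the trainer is actually optimizing. The snippet below is a minimal, illustrative sketch of the DPO objective, not part of the repository's code or TRL's API; the `dpo_loss` helper and its argument names are ours. Given the summed log-probabilities of the chosen and rejected responses under the policy and a frozen reference model, the loss pushes the policy to widen the margin in favor of the chosen response, with `beta` (0.1 in the training script below) controlling how far the policy may drift from the reference.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_chosen_logps: torch.Tensor,
             policy_rejected_logps: torch.Tensor,
             ref_chosen_logps: torch.Tensor,
             ref_rejected_logps: torch.Tensor,
             beta: float = 0.1) -> torch.Tensor:
    """Illustrative DPO loss over summed response log-probs."""
    # Implicit reward of each response: how much more the policy
    # favors it than the frozen reference model does.
    chosen_logratios = policy_chosen_logps - ref_chosen_logps
    rejected_logratios = policy_rejected_logps - ref_rejected_logps
    # Logistic loss on the margin: reward the chosen response more
    # than the rejected one, scaled by beta.
    logits = beta * (chosen_logratios - rejected_logratios)
    return -F.logsigmoid(logits).mean()
```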

The Code That Steers the Wind

To bring this concept to life, developers rely on a series of Python scripts, each fulfilling a pivotal role in the training and deployment of these aligned models. These scripts are the sails of our vessel, turning the theoretical ZEPHYR recipe into a practical tool.

  1. Configuring the Sails: config.py

The journey begins with config.py, a script that sets the environment for our model. It defines the model's identity, the dataset it will train on, and the hyperparameters that guide its learning. The intricacies of GPTQ, LoRA, and training configurations are established here, forming the blueprint of our model's architecture.

```python
from pydantic_settings import BaseSettings


class Config(BaseSettings):
    MODEL_ID: str = "TheBloke/OpenHermes-2-Mistral-7B-GPTQ"
    DATASET_ID: str = "HuggingFaceH4/ultrafeedback_binarized"

    # GPTQ config
    BITS: int = 4
    DISABLE_EXLLAMA: bool = True

    # AutoModelForCausalLM config
    DEVICE_MAP: str = "auto"

    # LoRA config
    LORA_R: int = 4
    LORA_ALPHA: int = 8
    LORA_DROPOUT: float = 0.1
    LORA_TARGET_MODULES: list = ["q_proj", "v_proj"]
    LORA_TASK_TYPE: str = "CAUSAL_LM"
    LORA_BIAS: str = "none"
    INFERENCE_MODE: bool = False

    # DPOTrainer config
    BATCH_SIZE: int = 1
    MAX_STEPS: int = 50
    REMOVE_UNUSED_COLUMNS: bool = False
    GRAD_ACCUMULATION_STEPS: int = 1
    LEARNING_RATE: float = 3e-4
    EVALUATION_STRATEGY: str = "steps"
    LOGGING_FIRST_STEP: bool = True
    LOGGING_STEPS: int = 10
    OUTPUT_DIR: str = "openhermes-mistral-gptq-dpo"
    OPTIM: str = "paged_adamw_32bit"
    WARMUP_STEPS: int = 2
    FP16: bool = True
    PUSH_TO_HUB: bool = True

    class Config:
        env_prefix = ''  # defaults to no prefix, i.e. ""
```

  2. Charting the Course: data_utils.py

Next, data_utils.py charts the course by preparing the dataset. It reshapes the raw UltraFeedback data into the structure the DPO trainer expects: a prompt, a preferred (chosen) response, and a rejected response, much like a navigator charting a path through the stars.

```python
from datasets import Dataset, load_dataset
from mistral.dpo.config import Config
import warnings
warnings.filterwarnings("ignore")


def dpo_data(dataset_id, split: str = 'train_prefs') -> Dataset:

    dataset = load_dataset(
        dataset_id,
        split=split,
        use_auth_token=True
    )

    original_columns = dataset.column_names

    def return_prompt_and_responses(samples):
        return {
            "prompt": samples["prompt"],
            "chosen": samples["chosen"],
            "rejected": samples["rejected"]
        }

    return dataset.map(
        return_prompt_and_responses,
        batched=True,
        remove_columns=original_columns,
    )


# Create triple (prompt, chosen, rejected) dataset
def create_dataset(dataset_id, split='train_prefs'):
    dataset = dpo_data(dataset_id, split=split)
    df = dataset.to_pandas()
    df["chosen"] = df["chosen"].apply(lambda x: x[1]["content"])
    df["rejected"] = df["rejected"].apply(lambda x: x[1]["content"])
    df = df.dropna()
    dataset = Dataset.from_pandas(df)
    return dataset
```
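
As an optional sanity check (this snippet is illustrative and not part of the repository's scripts), you can build one split and print a single triple; it assumes you are authenticated with the Hugging Face Hub, since the dataset is loaded with `use_auth_token=True`:

```python
from mistral.dpo.config import Config
from mistral.dpo.data_utils import create_dataset

config = Config()
dataset = create_dataset(config.DATASET_ID, split='train_prefs')

# Each row carries a plain-text (prompt, chosen, rejected) triple.
sample = dataset[0]
print(sample["prompt"][:100])    # the user prompt
print(sample["chosen"][:100])    # the preferred assistant reply
print(sample["rejected"][:100])  # the rejected assistant reply
```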

  3. Setting Sail: dpo_trainer.py

With the path charted, dpo_trainer.py sets the sails. This script is where the model begins its training, learning from the data prepared earlier. It loads the GPTQ-quantized base model, attaches LoRA adapters to both the policy and the reference model, and lets DPOTrainer adjust the adapter weights toward the preferences we've outlined, so that every generated response is a step closer to our ideal.

```python
import torch
from datasets import Dataset
from peft import AutoPeftModelForCausalLM, LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoTokenizer, TrainingArguments, AutoModelForCausalLM, GPTQConfig
from trl import DPOTrainer
from mistral.dpo.config import Config
from mistral.dpo.data_utils import create_dataset
import warnings
warnings.filterwarnings("ignore")


class MistralDPOTrainer:
    def __init__(self, config: Config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(self.config.MODEL_ID)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    # DPOTrainer requires a triple dataset (prompt, chosen, rejected)
    def create_triple_dataset(self):
        dataset = create_dataset(self.config.DATASET_ID, split='train_prefs')
        df = dataset.to_pandas()
        train_size = int(len(df) * 0.8)
        train_df = df[:train_size].sample(1000)
        train_dataset = Dataset.from_pandas(train_df)
        val_df = df[train_size:].sample(200)
        val_dataset = Dataset.from_pandas(val_df)
        test_dataset = create_dataset(self.config.DATASET_ID, split='test_prefs')
        return train_dataset, val_dataset, test_dataset

    def prepare_model(self):
        gptq_config = GPTQConfig(bits=self.config.BITS, disable_exllama=self.config.DISABLE_EXLLAMA)
        model = AutoModelForCausalLM.from_pretrained(self.config.MODEL_ID, torch_dtype=torch.float16,
                                                     low_cpu_mem_usage=True,
                                                     quantization_config=gptq_config,
                                                     device_map=self.config.DEVICE_MAP)
        model_ref = AutoModelForCausalLM.from_pretrained(self.config.MODEL_ID, torch_dtype=torch.float16,
                                                         low_cpu_mem_usage=True,
                                                         quantization_config=gptq_config,
                                                         device_map=self.config.DEVICE_MAP)
        print("Load model from pretrained checkpoint")
        print(model)

        peft_config = LoraConfig(
            r=self.config.LORA_R,
            lora_alpha=self.config.LORA_ALPHA,
            lora_dropout=self.config.LORA_DROPOUT,
            target_modules=self.config.LORA_TARGET_MODULES,
            task_type=self.config.LORA_TASK_TYPE,
            bias=self.config.LORA_BIAS,
            inference_mode=self.config.INFERENCE_MODE)

        model = prepare_model_for_kbit_training(model)
        model.config.use_cache = False
        model.gradient_checkpointing_enable()
        model.config.pretraining_tp = 1
        model = get_peft_model(model, peft_config)

        print("Load model with LoRA Adapter")
        print(model)

        # DPOTrainer requires a reference model
        model_ref = prepare_model_for_kbit_training(model_ref)
        model_ref.config.use_cache = False
        model_ref.gradient_checkpointing_enable()
        model_ref.config.pretraining_tp = 1
        model_ref = get_peft_model(model_ref, peft_config)

        print("Load reference model with LoRA Adapter")
        print(model_ref)

        return model, model_ref, peft_config

    def set_training_arguments(self):
        '''
        Sets the arguments for the training loop in TrainingArguments class
        '''
        training_arguments = TrainingArguments(
            per_device_train_batch_size=self.config.BATCH_SIZE,
            max_steps=self.config.MAX_STEPS,
            remove_unused_columns=self.config.REMOVE_UNUSED_COLUMNS,
            gradient_accumulation_steps=self.config.GRAD_ACCUMULATION_STEPS,
            learning_rate=self.config.LEARNING_RATE,
            evaluation_strategy=self.config.EVALUATION_STRATEGY,
            logging_first_step=self.config.LOGGING_FIRST_STEP,
            logging_steps=self.config.LOGGING_STEPS,
            output_dir=self.config.OUTPUT_DIR,
            optim=self.config.OPTIM,
            warmup_steps=self.config.WARMUP_STEPS,
            fp16=self.config.FP16,
            push_to_hub=self.config.PUSH_TO_HUB
        )
        return training_arguments

    def train(self):
        train_dataset, val_dataset, test_dataset = self.create_triple_dataset()
        print('triple dataset for DPO', '*' * 20)
        print('train_dataset', train_dataset)
        print('val_dataset', val_dataset)
        print('test_dataset', test_dataset)
        print('train_dataset', '*' * 20)
        model, model_ref, peft_config = self.prepare_model()

        training_args = self.set_training_arguments()

        dpo_trainer = DPOTrainer(
            model,
            model_ref,
            args=training_args,
            beta=0.1,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=self.tokenizer,
            max_length=256,
            max_target_length=128,
            max_prompt_length=128
        )
        dpo_trainer.train()
        dpo_trainer.push_to_hub("jamesliu23/" + self.config.OUTPUT_DIR)


if __name__ == '__main__':
    config = Config()
    dpo_trainer = MistralDPOTrainer(config)
    dpo_trainer.train()
```

  4. Navigating the Currents: dpo_inference.py

Finally, dpo_inference.py navigates the currents of real-world application. It takes the helm, loading the trained adapter and using it to generate responses to new prompts. It's the moment of truth, where we see our DPO-aligned model come to life, steering its generated text toward the preferences it has learned.

```python
from peft import AutoPeftModelForCausalLM
from transformers import GenerationConfig
from transformers import AutoTokenizer
import torch
from mistral.dpo.config import Config

if __name__ == '__main__':
    config = Config()
    tokenizer = AutoTokenizer.from_pretrained("Vasanth/openhermes-mistral-dpo-gptq")

    inputs = tokenizer("""I have dropped my phone in water. Now it is not working what should I do now?""", return_tensors="pt").to("cuda")

    model = AutoPeftModelForCausalLM.from_pretrained(
        config.OUTPUT_DIR,
        low_cpu_mem_usage=True,
        return_dict=True,
        torch_dtype=torch.float16,
        device_map="cuda")

    generation_config = GenerationConfig(
        do_sample=True,
        top_k=1,
        temperature=0.1,
        max_new_tokens=256,
        pad_token_id=tokenizer.eos_token_id
    )
```
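
The script above stops after constructing the GenerationConfig. A minimal way to complete the round trip, assuming the standard transformers generate/decode API (these lines are our addition, not part of the original script), is:

```python
# Generate a response with the trained adapter and decode it back to text.
outputs = model.generate(**inputs, generation_config=generation_config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```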

Reflecting on the Voyage

The code behind ZEPHYR is more than a set of Python instructions; it's a testament to human ingenuity and our desire to make technology reflect our better selves. The scripts are the embodiment of the paper's vision, each line a step closer to creating language models that understand not just our words, but our meanings and intentions.

The journey of ZEPHYR is ongoing. Each implementation, each model trained, is another breeze harnessed, another step toward a future where AI and humans speak not just the same language, but share the same understanding.


Exploring ZEPHYR's code is akin to a nautical voyage, where each script is a crucial part of the vessel, navigating the vast seas of AI alignment. As we refine these scripts, we refine our journey, ever striving for that perfect alignment, like a sailor seeking the ideal wind to fill their sails.
