Reading time: 15 minutes • Implementation time: 2-3 days • Difficulty: Advanced

Overview

GRPO (Group Relative Policy Optimization) is the core algorithm for training ATLAS teacher models. This guide walks through the complete training pipeline, from SFT warmup to full RL optimization.

Prerequisites

  • 4-8 H100 or A100 GPUs (40GB+ VRAM each)
  • CUDA 11.8+
  • Python 3.8+
  • 100GB+ disk space for checkpoints
  • Weights & Biases account (optional but recommended)
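Once PyTorch is installed (Step 1), a short check like the sketch below can confirm the hardware meets these requirements (a minimal Python sanity check; adjust the thresholds to your own setup):

import shutil
import torch

# Minimal environment check (illustrative; adjust thresholds to your setup)
assert torch.cuda.is_available(), "CUDA is not available"
print(f"CUDA version: {torch.version.cuda}")
print(f"GPUs detected: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
    vram_gb = torch.cuda.get_device_properties(i).total_memory / 1e9
    print(f"  GPU {i}: {torch.cuda.get_device_name(i)} ({vram_gb:.0f} GB)")

free_gb = shutil.disk_usage('.').free / 1e9
print(f"Free disk space: {free_gb:.0f} GB (100+ GB recommended for checkpoints)")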

Training Pipeline

1. Environment Setup

Install dependencies and configure the environment:
# Clone repository
git clone https://github.com/Arc-Computer/ATLAS.git
cd ATLAS

# Create environment
conda create -n atlas python=3.10
conda activate atlas

# Install PyTorch with CUDA
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install training dependencies
pip install -r requirements.txt

# Install Flash Attention 2 (recommended)
pip install flash-attn --no-build-isolation

# Login to Hugging Face
huggingface-cli login

# Configure Weights & Biases (optional)
wandb login
Flash Attention 2 significantly improves training speed and memory efficiency. Install it if your GPU supports it (Ampere or newer).
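To confirm the GPU meets the compute-capability requirement (8.0+, i.e. Ampere or newer) and that flash-attn imports cleanly, a quick check like this sketch works:

import torch

# Flash Attention 2 needs compute capability >= 8.0 (Ampere or newer)
major, minor = torch.cuda.get_device_capability(0)
print(f"Compute capability: {major}.{minor}")
if major >= 8:
    try:
        import flash_attn
        print("flash-attn is installed")
    except ImportError:
        print("flash-attn missing; run: pip install flash-attn --no-build-isolation")
else:
    print("GPU predates Ampere; skip Flash Attention 2")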
2. SFT Warmup

Start with supervised fine-tuning to establish base capabilities:
# Run SFT training
scripts/launch.sh 8 configs/run/teacher_sft.yaml \
  output_dir=checkpoints/sft \
  num_train_epochs=1 \
  learning_rate=2e-5 \
  per_device_train_batch_size=4 \
  gradient_accumulation_steps=4
Key parameters:
  • num_train_epochs: 1-2 epochs typically sufficient
  • learning_rate: Start with 2e-5, adjust based on loss
  • gradient_accumulation_steps: Increase for a larger effective batch size (e.g., 4 per device × 8 GPUs × 4 accumulation steps = 128)
Monitor training:
# Check training progress
tail -f checkpoints/sft/training.log

# Monitor GPU usage
nvidia-smi -l 1
SFT is critical for GRPO success. Ensure loss converges before proceeding.
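A quick way to confirm convergence is to inspect the logged losses in the SFT output directory. The sketch below assumes the standard Hugging Face trainer_state.json layout (the file may instead live inside the latest checkpoint-* subdirectory):

import json

# Print the last few logged SFT losses to confirm the curve has flattened
with open('checkpoints/sft/trainer_state.json') as f:
    state = json.load(f)

losses = [entry['loss'] for entry in state['log_history'] if 'loss' in entry]
print("Last 5 logged losses:", [round(loss, 4) for loss in losses[-5:]])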
3. Launch vLLM Server

Start the inference server for distributed generation:
# Terminal 1: Launch vLLM server
./scripts/launch_vllm_server.sh \
  --model-path checkpoints/sft/final \
  --tensor-parallel-size 4 \
  --gpu-memory-utilization 0.9 \
  --max-model-len 2048 \
  --port 8000

# Verify server is running
curl http://localhost:8000/v1/models
Server configuration options:
# configs/vllm_config.yaml
tensor_parallel_size: 4     # GPUs for inference
pipeline_parallel_size: 1    # Pipeline parallelism
gpu_memory_utilization: 0.9  # Memory allocation
max_model_len: 2048          # Max sequence length
enable_prefix_caching: true  # Cache repeated prompts
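Beyond the /v1/models check, you can send a small completion request through the OpenAI-compatible endpoint to confirm generation works end to end (a sketch; use the model name that /v1/models reports):

import requests

# Minimal end-to-end generation check against the vLLM server
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "checkpoints/sft/final",  # must match the name reported by /v1/models
        "prompt": "Sanity check: 2 + 2 =",
        "max_tokens": 8,
        "temperature": 0.0,
    },
    timeout=30,
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])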
4. Run GRPO Training

Execute the main RL training with the SFT checkpoint:
# Terminal 2: Run GRPO training
scripts/launch_with_server.sh 4 4 configs/run/teacher_rcl.yaml \
  model_name_or_path=checkpoints/sft/final \
  output_dir=checkpoints/grpo \
  num_train_epochs=3 \
  learning_rate=5e-7 \
  beta=0.04 \
  temperature=0.7
The script automatically:
  1. Distributes training across 4 GPUs
  2. Uses 4 GPUs for vLLM generation
  3. Manages distributed communication
  4. Handles checkpointing
The first number (4) is training GPUs, second (4) is inference GPUs. Adjust based on your hardware.
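A quick sanity check that the requested split fits the visible hardware (a minimal sketch):

import torch

# The command above assumes 4 training GPUs plus 4 inference GPUs
train_gpus, inference_gpus = 4, 4
available = torch.cuda.device_count()
assert train_gpus + inference_gpus <= available, (
    f"Requested {train_gpus + inference_gpus} GPUs but only {available} are visible"
)
print(f"OK: {train_gpus} training + {inference_gpus} inference GPUs on {available} available")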
5. Monitor Training

Track key metrics during training:
# Monitor with Weights & Biases
wandb sync checkpoints/grpo

# Or use TensorBoard
tensorboard --logdir checkpoints/grpo/tensorboard

# Key metrics to watch:
# - rewards/mean_reward: Should increase
# - kl_divergence: Should stay < 10
# - learning_rate: Verify schedule
# - loss/policy_loss: Should decrease
Real-time monitoring script:
import json
import time

def monitor_training(state_path='checkpoints/grpo/trainer_state.json'):
    """Poll the trainer state file and print the latest metrics."""
    while True:
        with open(state_path) as f:
            state = json.load(f)

        # The most recent metrics are appended to the log_history list
        latest = next(
            (entry for entry in reversed(state.get('log_history', [])) if 'loss' in entry),
            {}
        )

        print(f"Step: {state.get('global_step')}")
        print(f"Loss: {latest.get('loss', float('nan')):.4f}")
        print(f"Learning Rate: {latest.get('learning_rate', float('nan')):.2e}")
        if state.get('best_metric') is not None:
            print(f"Best Metric: {state['best_metric']:.4f}")
        print("-" * 40)

        time.sleep(10)

monitor_training()
6. Evaluate Model

Test the trained model on validation data:
# Run evaluation
python scripts/evaluate_model.py \
  --model-path checkpoints/grpo/best_model \
  --dataset Arc-Intelligence/Arc-ATLAS-Teach-v0 \
  --split validation \
  --metrics accuracy improvement safety

# Test specific capabilities
python scripts/test_teaching.py \
  --teacher checkpoints/grpo/best_model \
  --student Qwen/Qwen3-4B-Instruct \
  --tasks sre_debugging math_reasoning code_generation
Expected metrics:
  • Improvement rate: >15%
  • Non-degradation: >95%
  • Token efficiency: <250 tokens average
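If your evaluation run writes its metrics to a JSON file, a small gate like the following can enforce these thresholds automatically. The file path and key names below are assumptions for illustration; adapt them to the evaluator's actual output:

import json

# Hypothetical results file and keys; adapt to the real evaluator output
with open('checkpoints/grpo/eval_results.json') as f:
    metrics = json.load(f)

assert metrics['improvement_rate'] > 0.15, "Improvement rate below 15%"
assert metrics['non_degradation_rate'] > 0.95, "Non-degradation rate below 95%"
assert metrics['avg_guidance_tokens'] < 250, "Average guidance exceeds 250 tokens"
print("All target thresholds met")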

Configuration Deep Dive

GRPO Hyperparameters

Critical parameters for successful training:
# configs/trainer/teacher_grpo.yaml
# Algorithm parameters
beta: 0.04                    # KL divergence coefficient (0.01-0.1)
temperature: 0.7              # Sampling temperature (0.5-1.0)
grpo_alpha: 0.5              # PPO-style clipping (0.1-1.0)

# Generation settings
max_new_tokens: 512          # Response length limit
top_p: 0.95                  # Nucleus sampling
do_sample: true              # Enable sampling

# Optimization
learning_rate: 5e-7          # Peak learning rate
warmup_ratio: 0.1           # Warmup proportion
weight_decay: 0.01          # L2 regularization
max_grad_norm: 1.0          # Gradient clipping

# Efficiency
gradient_accumulation_steps: 4  # Effective batch size multiplier
gradient_checkpointing: true    # Memory vs compute tradeoff
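For intuition on what these knobs control: beta scales the KL penalty that keeps the policy close to the SFT reference, and GRPO computes advantages relative to the group of responses sampled for each prompt rather than from a learned value function. A minimal sketch of that group-relative advantage computation (illustrative, not the trainer's exact code):

import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Normalize rewards within each group of sampled responses.

    rewards has shape (num_prompts, group_size): one reward per sampled response.
    """
    mean = rewards.mean(dim=-1, keepdim=True)
    std = rewards.std(dim=-1, keepdim=True)
    return (rewards - mean) / (std + eps)

# Example: 2 prompts with 4 sampled responses each
print(group_relative_advantages(torch.tensor([[1.0, 0.0, 0.5, 0.2],
                                              [0.9, 0.9, 0.1, 0.3]])))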

Reward Function Configuration

Customize rewards for your use case:
# configs/reward/adaptive_teaching.yaml
# Core reward components
degradation_penalty_multiplier: 2.0  # Penalty for performance drops
efficiency_weight: 1.0              # Reward for concise teaching
baseline_threshold: 0.5             # Minimum performance for rewards

# Advanced settings
diversity_bonus: 0.1                # Encourage varied strategies
consistency_weight: 0.2             # Reward stable performance
max_probe_tokens: 50               # Diagnostic limit
max_guidance_tokens: 200           # Teaching limit
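To see how these settings fit together, here is an illustrative reward-shaping function wired to the same parameter names. It is a sketch of the intent (reward improvement, penalize degradation more heavily, favor concise guidance), not the repository's actual implementation:

def adaptive_teaching_reward(score_with_teaching: float,
                             score_baseline: float,
                             guidance_tokens: int,
                             cfg: dict) -> float:
    """Illustrative reward: pay for improvement, punish degradation, favor brevity."""
    delta = score_with_teaching - score_baseline

    if delta < 0:
        # Degradation is penalized more heavily than improvement is rewarded
        return cfg['degradation_penalty_multiplier'] * delta

    if score_with_teaching < cfg['baseline_threshold']:
        return 0.0  # below the minimum performance bar, no positive reward

    # Bonus for staying under the guidance token budget
    efficiency = max(0.0, 1.0 - guidance_tokens / cfg['max_guidance_tokens'])
    return delta + cfg['efficiency_weight'] * efficiency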

Advanced Training Techniques

Curriculum Learning

Implement progressive difficulty:
class CurriculumScheduler:
    """
    Adjust task difficulty during training
    """

    def __init__(self, num_epochs):
        self.num_epochs = num_epochs
        self.current_epoch = 0

    def step(self):
        """Advance to the next epoch; call once per completed epoch."""
        self.current_epoch = min(self.current_epoch + 1, self.num_epochs)

    def get_task_distribution(self):
        """
        Return task mixture for current epoch
        """
        progress = self.current_epoch / self.num_epochs

        if progress < 0.3:
            # Start with simple tasks
            return {
                'simple': 0.6,
                'medium': 0.3,
                'hard': 0.1
            }
        elif progress < 0.7:
            # Balanced distribution
            return {
                'simple': 0.33,
                'medium': 0.34,
                'hard': 0.33
            }
        else:
            # Focus on hard tasks
            return {
                'simple': 0.1,
                'medium': 0.3,
                'hard': 0.6
            }
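Usage sketch: draw a difficulty bucket from the current distribution when sampling tasks, and advance the scheduler once per epoch:

import random

scheduler = CurriculumScheduler(num_epochs=3)
for epoch in range(3):
    dist = scheduler.get_task_distribution()
    bucket = random.choices(list(dist.keys()), weights=list(dist.values()), k=1)[0]
    print(f"Epoch {epoch}: distribution={dist}, sampled bucket={bucket}")
    scheduler.step()  # advance to the next epoch's mixture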

Mixed Precision Training

Enable for faster training:
# Add to training config
fp16: true                    # Use float16
fp16_opt_level: "O2"         # Optimization level
fp16_backend: "amp"          # Use automatic mixed precision

# Or use bfloat16 (recommended for A100/H100)
bf16: true
bf16_full_eval: true
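You can check bfloat16 support before enabling it; Ampere and newer GPUs report True:

import torch

# Prefer bf16 on A100/H100; fall back to fp16 elsewhere
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
print("bf16: true" if use_bf16 else "fp16: true")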

Gradient Accumulation Strategy

Optimize for your GPU memory:
def calculate_accumulation_steps(
    desired_batch_size=128,
    per_device_batch_size=4,
    num_gpus=8
):
    """
    Calculate gradient accumulation steps
    """
    effective_batch_per_step = per_device_batch_size * num_gpus
    accumulation_steps = max(1, desired_batch_size // effective_batch_per_step)
    effective_batch_size = accumulation_steps * effective_batch_per_step

    print(f"Per-device batch size: {per_device_batch_size}")
    print(f"Number of GPUs: {num_gpus}")
    print(f"Gradient accumulation steps: {accumulation_steps}")
    print(f"Effective batch size: {effective_batch_size}")

    return accumulation_steps
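For example, if memory pressure forces the per-device batch size down to 1 on 8 GPUs, the helper reports the accumulation needed to keep the effective batch at 128:

# 1 sample per device x 8 GPUs -> 16 accumulation steps for an effective batch of 128
steps = calculate_accumulation_steps(desired_batch_size=128,
                                     per_device_batch_size=1,
                                     num_gpus=8)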

Troubleshooting

Problem: OOM errors during training
Solutions:
# Reduce batch size
per_device_train_batch_size=1

# Enable gradient checkpointing
gradient_checkpointing=true

# Use DeepSpeed ZeRO
deepspeed=configs/deepspeed/zero2.json

# Offload to CPU
offload=true
Problem: Rewards go to zero or negative
Solutions:
# Increase KL penalty
beta=0.1

# Reduce learning rate
learning_rate=1e-7

# Add reward clipping
reward_clip_threshold=5.0

# Check data quality
validate_training_data()
Problem: Connection refused or timeout
Solutions:
# Check server is running
ps aux | grep vllm

# Check port availability
lsof -i :8000

# Restart server with more memory
--gpu-memory-utilization 0.95

# Reduce the maximum sequence length
--max-model-len 1024
Problem: Training is slower than expected
Solutions:
# Enable Flash Attention
attn_implementation=flash_attention_2

# Use torch.compile (PyTorch 2.0+)
torch_compile=true

# Optimize data loading
dataloader_num_workers=4
dataloader_pin_memory=true

# Profile a short run to find bottlenecks
python -m torch.utils.bottleneck train.py configs/run/teacher_rcl.yaml

Performance Optimization

Multi-Node Training

Scale across multiple machines:
# Node 1 (master)
torchrun \
  --nproc_per_node=8 \
  --nnodes=2 \
  --node_rank=0 \
  --master_addr=10.0.0.1 \
  --master_port=29500 \
  train.py configs/run/teacher_rcl.yaml

# Node 2
torchrun \
  --nproc_per_node=8 \
  --nnodes=2 \
  --node_rank=1 \
  --master_addr=10.0.0.1 \
  --master_port=29500 \
  train.py configs/run/teacher_rcl.yaml

DeepSpeed Integration

Use ZeRO optimization for large models:
// configs/deepspeed/zero2.json
{
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000
  },
  "gradient_clipping": 1.0,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}

Validation and Testing

Automated Evaluation Suite

def evaluate_teacher_model(model_path, test_suite):
    """
    Comprehensive evaluation of trained model
    """
    results = {}

    # Test per-domain improvement; evaluate_domain is assumed to come from the project's eval utilities
    for domain in ['sre', 'math', 'code', 'reasoning']:
        domain_results = evaluate_domain(model_path, domain)
        results[domain] = {
            'improvement': domain_results['mean_improvement'],
            'safety': domain_results['non_degradation_rate'],
            'efficiency': domain_results['avg_tokens']
        }

    # Verify safety constraints
    assert all(r['safety'] > 0.95 for r in results.values()), \
           "Safety threshold not met"

    # Check minimum performance
    assert all(r['improvement'] > 0.10 for r in results.values()), \
           "Improvement threshold not met"

    return results

A/B Testing Framework

import numpy as np
from scipy.stats import ttest_rel

class ModelComparison:
    """
    Compare new model against baseline
    """

    def __init__(self, baseline_path, new_model_path):
        # load_model is assumed to come from your evaluation utilities
        self.baseline = load_model(baseline_path)
        self.new_model = load_model(new_model_path)

    def run_comparison(self, test_data, num_samples=100):
        results = {
            'baseline': [],
            'new_model': []
        }

        for sample in test_data[:num_samples]:
            # Test both models
            baseline_result = self.baseline.process(sample)
            new_result = self.new_model.process(sample)

            results['baseline'].append(baseline_result['score'])
            results['new_model'].append(new_result['score'])

        # Paired t-test for statistical significance
        t_stat, p_value = ttest_rel(
            results['new_model'],
            results['baseline']
        )

        return {
            'baseline_mean': np.mean(results['baseline']),
            'new_model_mean': np.mean(results['new_model']),
            'improvement': np.mean(results['new_model']) - np.mean(results['baseline']),
            'p_value': p_value,
            'significant': p_value < 0.05
        }
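Usage sketch (the checkpoint paths are from the steps above; load_validation_samples is a placeholder for your own data loader):

comparison = ModelComparison(
    baseline_path='checkpoints/sft/final',
    new_model_path='checkpoints/grpo/best_model',
)
# load_validation_samples() is a placeholder for however you load evaluation samples
report = comparison.run_comparison(test_data=load_validation_samples(), num_samples=100)
print(report)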

Next Steps