Overview

ATLAS provides specialized trainer classes for different training paradigms. Each trainer extends HuggingFace's base Trainer with RL-specific capabilities.

Typical Usage

from trainers import GRPOTrainer
from configs import GRPOConfig

# Initialize configuration
config = GRPOConfig(
    model_name_or_path="Arc-Intelligence/ATLAS-8B-Thinking",
    learning_rate=5e-6,
    num_train_epochs=3,
    beta=0.04  # KL penalty
)

# Create trainer
trainer = GRPOTrainer(
    config=config,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=tokenizer
)

# Train model
trainer.train()

# Save final model
trainer.save_model("./output/final_model")

GRPOTrainer

Main trainer for Group Relative Policy Optimization.

Class Overview

GRPOTrainer is the main trainer class for Group Relative Policy Optimization, extending the standard HuggingFace Trainer with reinforcement learning capabilities.
| Parameter | Type | Description |
|---|---|---|
| config | GRPOConfig | Training configuration |
| model | PreTrainedModel | Model to train (policy network) |
| ref_model | PreTrainedModel | Reference model for KL penalty |
| tokenizer | PreTrainedTokenizer | Tokenizer for encoding/decoding |
| train_dataset | Dataset | Training data |
| eval_dataset | Dataset | Evaluation data |
| reward_model | PreTrainedModel | Optional external reward model |
| compute_metrics | Callable | Custom metrics function |
| callbacks | List[TrainerCallback] | Training callbacks |
| optimizers | Tuple | Custom optimizer and scheduler |
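For orientation, here is a construction sketch that exercises the parameters in the table above. It reuses the objects from the Typical Usage example (model, tokenizer, datasets) and treats ref_model and reward_model as optional; verify keyword names and defaults against trainers/grpo.py.

from trainers import GRPOTrainer
from configs import GRPOConfig

config = GRPOConfig(model_name_or_path="Arc-Intelligence/ATLAS-8B-Thinking")

# Sketch only: keyword names follow the parameter table above.
trainer = GRPOTrainer(
    config=config,
    model=model,                      # policy network to optimize
    ref_model=ref_model,              # frozen reference for the KL penalty
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=eval_data,
    reward_model=reward_model,        # optional external reward model
    compute_metrics=compute_metrics,  # optional custom metrics function
    callbacks=callbacks,              # optional list of TrainerCallback
)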
Key capabilities:
  • GRPO Training Loop: Implements the complete reinforcement learning training process with policy gradient optimization and KL divergence constraints.
  • Generation Support: Supports both local generation and distributed generation via vLLM server integration.
  • Memory Management: Includes optimizations for training large models with gradient checkpointing and model offloading.
  • Reward Integration: Handles multiple reward functions and reward weighting for complex optimization objectives.
  • Implementation: See trainers/grpo.py for complete method signatures and implementation details.
Override these methods for custom behavior:
def on_epoch_begin(self, args, state, control, **kwargs):
    """Called at the beginning of each epoch"""
    pass

def on_step_end(self, args, state, control, **kwargs):
    """Called at the end of each training step"""
    # Log custom metrics as plain floats
    self.log({
        "rewards/mean": self.current_rewards.mean().item(),
        "kl_divergence": self.current_kl.mean().item()
    })

def on_evaluate(self, args, state, control, metrics=None, **kwargs):
    """Called after evaluation"""
    # Custom evaluation logic
    pass
Source: trainers/grpo.py

TeacherGRPOTrainer

Specialized trainer for adaptive teaching with teacher-student paradigm.

Class Overview

TeacherGRPOTrainer extends GRPOTrainer to implement the two-pass teaching protocol. From the actual source code (trainers/teacher_trainers.py), this trainer:
  • Inherits from both GRPOTrainer and TeacherTrainer
  • Accepts student_model parameter in constructor (see the construction sketch after this list)
  • Implements diagnostic probing and adaptive teaching templates
  • Manages both teacher and student models during training
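A construction sketch for the teacher-student setup is shown below. Only student_model is documented above; the remaining keywords are assumed to mirror GRPOTrainer and should be checked against trainers/teacher_trainers.py.

from trainers import TeacherGRPOTrainer

# Sketch only: the teacher policy is optimized while the student model is
# queried for diagnostic probes and guided generations.
trainer = TeacherGRPOTrainer(
    config=config,
    model=teacher_model,          # teacher policy being trained
    student_model=student_model,  # student used in the two-pass protocol
    tokenizer=tokenizer,
    train_dataset=train_data,
    eval_dataset=eval_data,
)
trainer.train()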

diagnostic_probe()

Assess student capability:
def diagnostic_probe(
    self,
    task: str,
    max_tokens: int = 50
) -> DiagnosticResult:
    """
    Probe student understanding of given task

    Args:
        task: Task or problem statement to probe
        max_tokens: Maximum tokens for probe response

    Returns:
        DiagnosticResult: Assessment containing:
            - capability_level: Student capability (0.0-1.0)
            - confidence_score: Confidence in assessment (0.0-1.0)
            - identified_gaps: List of specific knowledge gaps
            - response_quality: Quality metrics for student response
            - probe_tokens: Number of tokens used in probe

    Raises:
        ValueError: If task is empty or max_tokens < 1
        RuntimeError: If student model inference fails
        TypeError: If task is not a string

    Example:
        result = trainer.diagnostic_probe("Solve: 2x + 5 = 11")
        if result.capability_level < 0.3:
            print("Student needs significant help")
    """

generate_guidance()

Create adaptive teaching:
def generate_guidance(
    self,
    task: str,
    diagnostic: DiagnosticResult,
    max_tokens: int = 200
) -> str:
    """
    Generate teaching guidance based on diagnosis

    Args:
        task: Original task or problem statement
        diagnostic: Results from diagnostic_probe()
        max_tokens: Maximum tokens for guidance response

    Returns:
        str: Tailored teaching guidance text
            Format depends on capability_level:
            - Low (0.0-0.3): Step-by-step walkthrough
            - Medium (0.3-0.7): Hints and scaffolding
            - High (0.7-1.0): Minimal guidance or verification

    Raises:
        ValueError: If max_tokens < 10 or diagnostic is None
        RuntimeError: If teacher model fails to generate guidance
        TypeError: If task is not a string

    Example:
        diagnostic = trainer.diagnostic_probe("Solve quadratic equation")
        guidance = trainer.generate_guidance(
            "Solve: x² - 5x + 6 = 0",
            diagnostic,
            max_tokens=150
        )
        print(f"Teaching guidance: {guidance}")
    """

compute_teaching_reward()

Calculate teaching effectiveness:
def compute_teaching_reward(
    self,
    baseline_score: float,
    enhanced_score: float,
    teaching_length: int
) -> float:
    """
    Compute reward for teaching quality

    Args:
        baseline_score: Student performance without guidance (0.0-1.0)
        enhanced_score: Student performance with guidance (0.0-1.0)
        teaching_length: Number of tokens in teaching guidance

    Returns:
        float: Teaching reward score
            - Positive: Effective teaching (improvement achieved)
            - Zero: No improvement or neutral
            - Negative: Degraded performance (safety penalty)

    Raises:
        ValueError: If scores not in [0.0, 1.0] or teaching_length < 0
        TypeError: If inputs are not numeric

    Reward Components:
        - Improvement bonus: (enhanced_score - baseline_score)
        - Efficiency bonus: max(0, 1 - teaching_length / 200)
        - Safety penalty: -2.0 if enhanced_score < baseline_score

    Example:
        reward = trainer.compute_teaching_reward(
            baseline_score=0.6,
            enhanced_score=0.8,
            teaching_length=120
        )
        # reward ≈ 0.2 + 0.4 = 0.6 (improvement + efficiency)
    """
The two-pass protocol implementation:
def teaching_step(self, batch):
    """Execute one teaching interaction"""

    # Phase 1: Diagnostic
    diagnostics = []
    for prompt in batch["prompts"]:
        diag = self.diagnostic_probe(prompt)
        diagnostics.append(diag)

    # Phase 2: Guidance generation
    guidances = []
    for prompt, diag in zip(batch["prompts"], diagnostics):
        guidance = self.generate_guidance(prompt, diag)
        guidances.append(guidance)

    # Phase 3: Student enhancement
    baseline_responses = self.student_model.generate(batch["prompts"])
    enhanced_responses = self.student_model.generate(
        batch["prompts"],
        guidance=guidances
    )

    # Phase 4: Reward computation
    rewards = []
    for base, enh, guid in zip(baseline_responses, enhanced_responses, guidances):
        reward = self.compute_teaching_reward(
            self.score(base),
            self.score(enh),
            len(self.tokenizer.encode(guid))  # teaching_length is measured in tokens
        )
        rewards.append(reward)

    return rewards
Source: trainers/teacher_trainers.py

SFTTrainer

Supervised fine-tuning trainer for warmup before RL.

Constructor

class SFTTrainer(Trainer):
    def __init__(
        self,
        model: PreTrainedModel,
        args: TrainingArguments,
        train_dataset: Dataset,
        tokenizer: PreTrainedTokenizer,
        eval_dataset: Optional[Dataset] = None,
        data_collator: Optional[DataCollator] = None,
        max_seq_length: int = 2048,
        packing: bool = False,
        formatting_func: Optional[Callable] = None,
    ):
  • Sequence packing: Efficient batching of variable-length sequences (see the usage sketch after this list)
  • Custom formatting: Apply templates to raw data
  • Gradient accumulation: Handle large effective batch sizes
  • Mixed precision: FP16/BF16 training support
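A usage sketch based on the constructor above; training_args, the dataset names, and the field names inside format_example are placeholders, and defaults should be confirmed against trainers/sft.py.

from trainers import SFTTrainer

def format_example(example):
    # Placeholder formatting: join hypothetical prompt/completion fields
    return {"text": example["prompt"] + "\n" + example["completion"]}

sft_trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    tokenizer=tokenizer,
    max_seq_length=2048,
    packing=True,                    # pack short sequences into full blocks
    formatting_func=format_example,
)
sft_trainer.train()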

format_dataset()

Prepare data for training:
def format_dataset(self, dataset):
    """Format dataset for SFT training"""

    def formatting_func(example):
        # Apply chat template
        messages = example["messages"]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False
        )
        return {"text": text}

    return dataset.map(formatting_func)

pack_sequences()

Efficient sequence packing:
def pack_sequences(self, tokenized_dataset):
    """Pack multiple sequences into single training example"""
    # Implementation for efficient GPU utilization
Source: trainers/sft.py
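pack_sequences is shown as a stub above. Conceptually, packing concatenates tokenized examples into fixed-length blocks so that little compute is wasted on padding; the rough sketch below illustrates the idea and is not the ATLAS implementation.

def pack_sequences_sketch(tokenized_examples, max_seq_length=2048, eos_token_id=2):
    """Greedily concatenate token-id lists into fixed-length training blocks."""
    packed, buffer = [], []
    for ids in tokenized_examples:           # each item: a list of token ids
        buffer.extend(ids + [eos_token_id])  # separate documents with an EOS token
        while len(buffer) >= max_seq_length:
            packed.append(buffer[:max_seq_length])
            buffer = buffer[max_seq_length:]
    return packed                            # leftover tokens in buffer are dropped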

Custom Trainer Implementation

Create your own trainer by extending base classes:
from trainers import GRPOTrainer
import torch

class CustomRewardTrainer(GRPOTrainer):
    """Custom trainer with modified reward computation"""

    def compute_rewards(self, completions, prompts):
        """Override reward computation"""
        rewards = []
        for completion, prompt in zip(completions, prompts):
            # Custom reward logic
            reward = self.custom_reward_function(completion, prompt)
            rewards.append(reward)
        return torch.tensor(rewards)

    def custom_reward_function(self, completion, prompt):
        """Implement domain-specific rewards"""
        # Example: Length penalty, normalized against a 500-character budget
        length_penalty = min(1.0, len(completion) / 500)

        # Example: Quality score from a user-provided scoring model
        # (self.quality_model is assumed to be attached by the caller)
        quality = self.quality_model(completion)

        return quality * length_penalty
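Usage follows the same construction pattern as the other trainers; attaching the assumed quality_model attribute before training keeps custom_reward_function self-contained (my_quality_scorer is a placeholder for any callable that maps a completion string to a score).

trainer = CustomRewardTrainer(
    config=config,
    train_dataset=train_data,
    eval_dataset=eval_data,
    tokenizer=tokenizer,
)
trainer.quality_model = my_quality_scorer  # placeholder: str -> float
trainer.train()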

Callbacks and Monitoring

Available Callbacks

from trainers.callbacks import (
    WandbCallback,
    TensorBoardCallback,
    EarlyStoppingCallback,
    ModelCheckpointCallback
)

# Configure callbacks
callbacks = [
    WandbCallback(
        project="atlas-training",
        name="experiment-1"
    ),
    EarlyStoppingCallback(
        early_stopping_patience=3,
        early_stopping_threshold=0.001
    ),
    ModelCheckpointCallback(
        save_steps=500,
        save_total_limit=3
    )
]

trainer = GRPOTrainer(
    config=config,
    callbacks=callbacks
)

Custom Metrics

def compute_metrics(eval_predictions):
    """Custom metrics computation"""
    predictions, labels = eval_predictions

    # accuracy_score comes from scikit-learn; perplexity, diversity_score,
    # and safety_check stand in for project-specific metric helpers
    return {
        "accuracy": accuracy_score(labels, predictions),
        "perplexity": perplexity(predictions),
        "diversity": diversity_score(predictions),
        "safety_rate": safety_check(predictions)
    }

trainer = GRPOTrainer(
    config=config,
    compute_metrics=compute_metrics
)

Distributed Training

Multi-GPU Setup

from trainers import GRPOTrainer
from accelerate import Accelerator

accelerator = Accelerator()

trainer = GRPOTrainer(
    config=config,
    model=model,
    accelerator=accelerator
)

# Trainer automatically handles distributed setup
trainer.train()

DeepSpeed Integration

# deepspeed_config.json
{
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu"
        }
    },
    "fp16": {
        "enabled": true
    }
}

trainer = GRPOTrainer(
    config=config,
    deepspeed="deepspeed_config.json"
)

Implementation Notes

ATLAS trainers extend standard HuggingFace Trainer classes with RL-specific functionality. The implementation details can be found in:
  • trainers/grpo.py - Main GRPO trainer implementation
  • trainers/teacher_trainers.py - Teacher-student training logic
  • trainers/grpo_config.py - Configuration parameters

Troubleshooting

Problem: CUDA OOM during training
Solutions:
# Reduce batch size
config.per_device_train_batch_size = 1
config.gradient_accumulation_steps = 32

# Enable gradient checkpointing
config.gradient_checkpointing = True

# Use mixed precision
config.fp16 = True
Problem: Training is slower than expected
Solutions:
# Enable compilation (PyTorch 2.0+)
model = torch.compile(model)

# Use Flash Attention
config.attn_implementation = "flash_attention_2"

# Optimize data loading
config.dataloader_num_workers = 4
config.dataloader_pin_memory = True
Problem: Loss spikes or NaN values
Solutions:
# Reduce learning rate
config.learning_rate = 1e-6

# Increase KL penalty
config.beta = 0.1

# Clip gradients
config.max_grad_norm = 0.5

Next Steps