ATLAS provides specialized trainer classes for different training paradigms. Each trainer extends HuggingFace’s base trainer with RL-specific capabilities.
GRPOTrainer is the main trainer class for Group Relative Policy Optimization, extending the standard HuggingFace Trainer with reinforcement learning capabilities.
Parameters
| Parameter       | Type                  | Description                      |
|-----------------|-----------------------|----------------------------------|
| config          | GRPOConfig            | Training configuration           |
| model           | PreTrainedModel       | Model to train (policy network)  |
| ref_model       | PreTrainedModel       | Reference model for KL penalty   |
| tokenizer       | PreTrainedTokenizer   | Tokenizer for encoding/decoding  |
| train_dataset   | Dataset               | Training data                    |
| eval_dataset    | Dataset               | Evaluation data                  |
| reward_model    | PreTrainedModel       | Optional external reward model   |
| compute_metrics | Callable              | Custom metrics function          |
| callbacks       | List[TrainerCallback] | Training callbacks               |
| optimizers      | Tuple                 | Custom optimizer and scheduler   |
Key Features
GRPO Training Loop: Implements the complete reinforcement learning training process with policy gradient optimization and KL divergence constraints.

Generation Support: Supports both local generation and distributed generation via vLLM server integration.

Memory Management: Includes optimizations for training large models with gradient checkpointing and model offloading.

Reward Composition: Handles multiple reward functions (including optional RIM-based scoring) and reward weighting for complex optimization objectives; a generic sketch of this weighting follows this list.

Implementation: See src/atlas_core/training/algorithms/grpo.py for complete method signatures and implementation details.
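The reward composition described above amounts to a weighted sum over the scores produced by each reward function. The sketch below is a generic illustration of that idea, not the trainer's actual implementation; compose_rewards, its argument names, and the assumed one-score-per-completion signature are all hypothetical.

```python
from typing import Callable, Sequence

def compose_rewards(
    prompts: Sequence[str],
    completions: Sequence[str],
    reward_funcs: Sequence[Callable[..., Sequence[float]]],
    reward_weights: Sequence[float],
) -> list[float]:
    """Hypothetical helper: weighted sum of per-function rewards for each completion."""
    totals = [0.0] * len(completions)
    for func, weight in zip(reward_funcs, reward_weights):
        scores = func(prompts=prompts, completions=completions)  # assumed: one score per completion
        totals = [t + weight * s for t, s in zip(totals, scores)]
    return totals
```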
Training Hooks
Override these methods for custom behavior:
```python
def on_epoch_begin(self):
    """Called at the beginning of each epoch"""
    pass

def on_step_end(self, args, state, control, **kwargs):
    """Called at the end of each training step"""
    # Log custom metrics
    self.log({
        "rewards/mean": self.current_rewards.mean(),
        "kl_divergence": self.current_kl.mean()
    })

def on_evaluate(self, args, state, control, metrics=None, **kwargs):
    """Called after evaluation"""
    # Custom evaluation logic
    pass
```
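For instance, a subclass can extend on_step_end with additional logging. This is a sketch under the same assumptions as the construction example above (assumed GRPOTrainer import path; current_rewards populated as a tensor by the training loop).

```python
# Sketch: subclass GRPOTrainer to layer extra metrics onto the step-end hook.
class LoggingGRPOTrainer(GRPOTrainer):
    def on_step_end(self, args, state, control, **kwargs):
        super().on_step_end(args, state, control, **kwargs)
        # Log reward dispersion alongside the defaults; assumes current_rewards is a tensor.
        self.log({"rewards/std": self.current_rewards.std().item()})
```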
TeacherGRPOTrainer extends GRPOTrainer to implement the two-pass teaching protocol. From the actual source code (src/atlas_core/training/algorithms/teacher_trainers.py), this trainer:
Inherits from both GRPOTrainer and TeacherTrainer
Accepts a student_model parameter in its constructor (see the construction sketch after this list)
Implements diagnostic probing and verifying-teacher guidance templates
Manages both teacher and student models during training
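A minimal construction sketch, assuming the import path matches the source file noted above and that the remaining keyword arguments mirror GRPOTrainer's; the checkpoint paths are placeholders.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from atlas_core.training.algorithms.teacher_trainers import TeacherGRPOTrainer  # assumed path

teacher_model = AutoModelForCausalLM.from_pretrained("path/to/teacher")  # policy being optimized
student_model = AutoModelForCausalLM.from_pretrained("path/to/student")  # probed during training
tokenizer = AutoTokenizer.from_pretrained("path/to/teacher")

trainer = TeacherGRPOTrainer(
    config=config,                # GRPOConfig and train_dataset as in the GRPOTrainer sketch above
    model=teacher_model,
    student_model=student_model,  # documented constructor parameter
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
```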
```python
def diagnostic_probe(
    self,
    task: str,
    max_tokens: int = 50
) -> DiagnosticResult:
    """
    Probe student understanding of given task

    Args:
        task: Task or problem statement to probe
        max_tokens: Maximum tokens for probe response

    Returns:
        DiagnosticResult: Assessment containing:
            - capability_level: Student capability (0.0-1.0)
            - confidence_score: Confidence in assessment (0.0-1.0)
            - identified_gaps: List of specific knowledge gaps
            - response_quality: Quality metrics for student response
            - probe_tokens: Number of tokens used in probe

    Raises:
        ValueError: If task is empty or max_tokens < 1
        RuntimeError: If student model inference fails
        TypeError: If task is not a string

    Example:
        result = trainer.diagnostic_probe("Solve: 2x + 5 = 11")
        if result.capability_level < 0.3:
            print("Student needs significant help")
    """
```
```python
def generate_guidance(
    self,
    task: str,
    diagnostic: DiagnosticResult,
    max_tokens: int = 200
) -> str:
    """
    Generate teaching guidance based on diagnosis

    Args:
        task: Original task or problem statement
        diagnostic: Results from diagnostic_probe()
        max_tokens: Maximum tokens for guidance response

    Returns:
        str: Tailored teaching guidance text.
            Format depends on capability_level:
            - Low (0.0-0.3): Step-by-step walkthrough
            - Medium (0.3-0.7): Hints and scaffolding
            - High (0.7-1.0): Minimal guidance or verification

    Raises:
        ValueError: If max_tokens < 10 or diagnostic is None
        RuntimeError: If teacher model fails to generate guidance
        TypeError: If task is not a string

    Example:
        diagnostic = trainer.diagnostic_probe("Solve quadratic equation")
        guidance = trainer.generate_guidance(
            "Solve: x² - 5x + 6 = 0",
            diagnostic,
            max_tokens=150
        )
        print(f"Teaching guidance: {guidance}")
    """
```
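Taken together, the two methods implement the two-pass protocol: probe first, then scale the guidance to the diagnosed capability band. A short usage sketch built only from the documented signatures and thresholds above:

```python
task = "Solve: x^2 - 5x + 6 = 0"

# Pass 1: diagnose the student's current understanding of the task.
diagnostic = trainer.diagnostic_probe(task, max_tokens=50)

# Pass 2: generate guidance whose depth tracks the capability band.
if diagnostic.capability_level < 0.3:
    guidance = trainer.generate_guidance(task, diagnostic, max_tokens=300)  # step-by-step walkthrough
elif diagnostic.capability_level < 0.7:
    guidance = trainer.generate_guidance(task, diagnostic, max_tokens=150)  # hints and scaffolding
else:
    guidance = trainer.generate_guidance(task, diagnostic, max_tokens=50)   # minimal verification
```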
TeacherGRPOTrainer expects reward_funcs to supply the evaluation signal. When you pass an instance of RIMReward, each call returns both the aggregated reward and an information dictionary containing per-judge scores, principles, and rationales. The trainer logs these details under rim_rewards and rim_explanations, making it possible to inspect accuracy, helpfulness, process, and diagnostic scores separately. To switch configurations during an experiment, update the Hydra override so that RIMReward loads either reward_system/interpretation.yaml or reward_system/interpretation_offline.yaml. The trainer does not need any code changes when you modify judge prompts, thresholds, or model choices.

Source: src/atlas_core/training/algorithms/teacher_trainers.py
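A minimal wiring sketch, assuming a hypothetical RIMReward import path and that reward_funcs accepts a list of callables; the Hydra override in the comment infers the group name reward_system from the config paths above, so verify both against your checkout.

```python
from atlas_core.rewards.rim import RIMReward  # hypothetical import path; adjust to your checkout

# Loads reward_system/interpretation.yaml by default (assumption); switch configs from the
# command line with a Hydra override such as: reward_system=interpretation_offline
rim_reward = RIMReward()

trainer = TeacherGRPOTrainer(
    config=config,
    model=teacher_model,
    student_model=student_model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    reward_funcs=[rim_reward],  # aggregated reward plus per-judge details, logged by the trainer
                                # under rim_rewards and rim_explanations
)
```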