Overview

Generalized Knowledge Distillation (GKD) enables on-policy distillation of Atlas runtime traces into smaller, faster student models. Instead of re-running a full reinforcement-learning cycle, use GKD to compress a reliable teacher checkpoint while staying within the reward-weighted data captured in production.

Atlas Training Stack Fit

Problem: Atlas SDK deployments often produce a reliable teacher policy plus verified runtime traces, yet the teacher checkpoint is too large or costly to redeploy, and running the full SFT → GRPO stack again would add days of compute.

Solution: Use GKD to transfer the teacher’s action distribution to a smaller student directly from the same on-policy traces, reducing latency while respecting the reward-weighted data collected in production.

Technical Implementation: The run config (configs/run/teacher_gkd.yaml) wires the GKD trainer (configs/trainer/gkd.yaml) and the Postgres-backed dataset builder in trainers.gkd_dataset. Trace exports match the schema captured in the MCP Tool Learning example, so you can replay identical workloads through distillation or GRPO. For the reinforcement-learning path, see grpo-training.mdx; both flows follow the override patterns documented in the Training Configuration Guide.

Core Capabilities

GKD handles multi-turn traces directly from Postgres, applies the same tokenization and augmentation stack that GRPO uses, and records baseline-comparison telemetry (success delta and token efficiency) in WandB for side-by-side evaluation with other Atlas trainers. The same Hydra overrides that configure the GRPO pipeline apply here, so swapping between distillation and reinforcement learning becomes a config change rather than a separate code path.

When to Use GKD vs GRPO

| Criterion | GKD | GRPO |
| --- | --- | --- |
| Data source | Atlas runtime traces | Interactive environment |
| Compute cost | Low (supervised + KL) | High (PPO + rollouts) |
| Speed | Fast (single pass) | Slow (multi-epoch) |
| Best for | Distill teacher → student | Train from scratch with RL |
| Training time | Hours | Days |

Rule of thumb: Use GKD when you have Atlas traces and want to distill knowledge into a smaller model. Use GRPO when training a new policy from scratch via reinforcement learning.

Quick Start

Prerequisites

For local smoke tests (for example scripts/validate_gkd.py) install torch>=2.0.0 with CUDA or BF16 support alongside transformers==4.51.1, accelerate==1.4.0, trl==0.14.0, datasets==3.2.0, and bitsandbytes>=0.48 when you plan to load either model in 8-bit. Everything else in requirements-py312.txt supports wider Atlas workflows (Hydra configs, serving, logging) and is not required for the math validation run.
# Ensure TRL is installed
pip install "trl>=0.12.0"

# Set database URL
export ATLAS_DB_URL="postgresql://user:pass@host:5432/atlas"

Basic Training

Train a distilled student model from Atlas traces:
python train.py \
  --config-name teacher_gkd \
  teacher_model_name_or_path=Qwen/Qwen2.5-14B-Instruct \
  model.model_name_or_path=Qwen/Qwen2.5-7B-Instruct \
  trainer.min_reward=0.8
These same Qwen checkpoints (Qwen/Qwen2.5-14B-Instruct teacher and Qwen/Qwen2.5-7B-Instruct student) are used by scripts/validate_gkd.py to measure the GSM8K lift before scaling to customer traces. Hydra composes the trainer from configs/run/teacher_gkd.yaml, which overrides _global_.trainer=gkd and the shared data presets documented in Training Configuration. Override any field inline (for example trainer.learning_key) using the same syntax shown in the GRPO guide.

Running the command streams Atlas runtime traces directly from Postgres (the same database you populate with arc-atlas --database-url postgresql://... --include-status approved --output traces/runtime.jsonl), fine-tunes the 7B student to mimic the 14B teacher, logs the Baseline Comparison metrics (success delta and token reduction) to WandB, and writes checkpoints into outputs/gkd/. If you prefer to operate on JSONL exports, point the dataset adapter at the CLI output; both paths preserve the SDK schema so Atlas Core sees identical conversation records.

For a step-by-step walkthrough of exporting traces, running the validation script, and reading the metrics file, see the Developer Example: Running GKD.

Configuration Files

GKD training uses two main config files. The trainer config (configs/trainer/gkd.yaml) houses the GKD-specific hyperparameters (lmbda, beta, temperature) plus database connectivity and general training settings. The run config (configs/run/teacher_gkd.yaml) specifies student and teacher checkpoints, the Baseline Comparison reference metrics, and the output/logging targets.

Key Parameters

GKD Parameters

lmbda (On-Policy Fraction)

Set lmbda to 1.0 when the student should generate every response and receive teacher guidance token by token; this setting keeps the run fully on-policy and matches the configuration we use for the GSM8K validation sweeps. A mid-range value such as 0.5 mixes teacher- and student-generated continuations when you want to temper exploration, while 0.0 reverts to supervised distillation. Start with 1.0 and only back off when trace quality is noisy or you need consistency with earlier supervised checkpoints (Issue #40 recommendation).
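As a rough illustration of how lmbda mixes data sources, the sketch below assumes the per-batch behavior described for TRL's GKDTrainer: a batch trains on student-generated completions with probability lmbda and on the recorded trace completions otherwise. The helper name and the "student"/"trace" labels are hypothetical and are not part of the Atlas trainer.

```python
import random


def choose_batch_source(lmbda: float, rng: random.Random) -> str:
    """Return which completions a batch trains on (hypothetical helper)."""
    # rng.random() lies in [0.0, 1.0), so lmbda=1.0 always selects the
    # student (fully on-policy) and lmbda=0.0 never does (supervised).
    return "student" if rng.random() < lmbda else "trace"


rng = random.Random(0)
counts = {"student": 0, "trace": 0}
for _ in range(1000):
    counts[choose_batch_source(0.5, rng)] += 1

print(counts)  # roughly an even split at lmbda=0.5
```

With lmbda=1.0 every batch is generated by the student, which is the fully on-policy setting recommended above as the starting point.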

beta (KL Divergence Balance)

beta tunes the interpolation inside the generalized Jensen-Shannon Divergence. Push it toward 0.0 to emphasize the teacher distribution (forward KL), move toward 1.0 to penalize deviations from the student distribution (reverse KL), and stay at the default 0.5 for a balanced view. Start at 0.5, then sweep ±0.2 once you have telemetry on success deltas and token savings.
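To make the role of beta concrete, here is a minimal sketch of the generalized Jensen-Shannon divergence over two small discrete distributions. It follows the definition in the On-Policy Distillation paper and treats the endpoints as forward and reverse KL, as described above. It is not the Atlas loss code, which works on per-token logits; the function names and example distributions are illustrative.

```python
import math


def kl(p, q):
    """KL(p || q) for discrete distributions with matching support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def generalized_jsd(teacher, student, beta=0.5):
    """Interpolate between forward KL (beta=0) and reverse KL (beta=1)."""
    if beta == 0.0:
        return kl(teacher, student)  # forward KL: emphasizes the teacher
    if beta == 1.0:
        return kl(student, teacher)  # reverse KL: penalizes student mass the teacher lacks
    # Mixture M = beta * teacher + (1 - beta) * student
    mixture = [beta * t + (1 - beta) * s for t, s in zip(teacher, student)]
    return beta * kl(teacher, mixture) + (1 - beta) * kl(student, mixture)


teacher = [0.7, 0.2, 0.1]
student = [0.5, 0.3, 0.2]
for beta in (0.0, 0.3, 0.5, 0.7, 1.0):
    print(beta, round(generalized_jsd(teacher, student, beta), 4))
```

At beta=0.5 the divergence is symmetric in teacher and student, which is why it serves as the balanced default.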

temperature

Sampling temperature controls how aggressively the student explores during generation. Values around 0.9 yield more diverse reasoning chains and match the public validation scripts; dropping to 0.5–0.7 tightens answers when traces have deterministic formats. Stick with 0.9 for early runs and lower it only after observing long completions or oscillating eval loss.
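The effect of temperature on sampling can be sketched with a plain softmax. The logits below are made up for illustration, and the helper is not taken from the Atlas codebase; it only shows why lower values concentrate probability on the top token.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # illustrative next-token scores
for temperature in (0.5, 0.9):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

The 0.5 setting puts more mass on the highest-scoring token than 0.9 does, which matches the guidance to lower temperature when traces follow deterministic formats.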

Database Filtering

min_reward

Minimum session reward threshold for training data:
trainer:
  min_reward: 0.8  # Only use high-quality traces
Higher values (0.8-0.9) ensure training on successful sessions only.

learning_key

Filter traces by task type:
trainer:
  learning_key: "crm_workflows"  # Only CRM-related traces
Set to null to use all traces.
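The two filters combine as a simple conjunction. The sketch below assumes each session record exposes reward and learning_key fields and that the reward threshold is inclusive. Both are assumptions for illustration, and the real dataset builder applies its filters against Postgres rather than an in-memory list.

```python
def filter_sessions(sessions, min_reward=0.8, learning_key=None):
    """Keep sessions that clear the reward threshold and match the task key."""
    kept = []
    for session in sessions:
        if session["reward"] < min_reward:
            continue  # drop low-reward sessions
        if learning_key is not None and session["learning_key"] != learning_key:
            continue  # drop sessions from other task types
        kept.append(session)
    return kept


# Hypothetical records; field names and the "billing" key are illustrative only.
sessions = [
    {"id": 1, "reward": 0.92, "learning_key": "crm_workflows"},
    {"id": 2, "reward": 0.65, "learning_key": "crm_workflows"},
    {"id": 3, "reward": 0.88, "learning_key": "billing"},
]

print([s["id"] for s in filter_sessions(sessions, 0.8, "crm_workflows")])  # [1]
print([s["id"] for s in filter_sessions(sessions, 0.8, None)])  # [1, 3]
```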

Baseline Comparison Metrics

Track distillation quality against a baseline:
trainer:
  baseline_success: 0.75     # Baseline task success rate
  baseline_tokens: 1200      # Baseline tokens per episode
Atlas logs metrics/success_delta, metrics/token_reduction_pct, and metrics/meets_target alongside loss curves so you can judge whether the distilled checkpoint clears your success and token budgets without exporting separate spreadsheets.
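The arithmetic behind the first two metrics is straightforward. The sketch below assumes success_delta is the raw difference in success rate (a fraction, reported as percentage points when multiplied by 100) and token_reduction_pct is the relative drop in tokens per episode. The function name and the evaluation inputs are illustrative; the thresholds behind metrics/meets_target are not specified here, so that flag is left out.

```python
def baseline_comparison(success_rate, avg_tokens, baseline_success, baseline_tokens):
    """Compare evaluation results against the configured baseline (assumed formulas)."""
    success_delta = success_rate - baseline_success
    token_reduction_pct = 100.0 * (baseline_tokens - avg_tokens) / baseline_tokens
    return {
        "success_delta": success_delta,
        "token_reduction_pct": token_reduction_pct,
    }


# Illustrative evaluation numbers against the baseline values configured above.
result = baseline_comparison(
    success_rate=0.873,
    avg_tokens=777.6,
    baseline_success=0.75,
    baseline_tokens=1200,
)
print(f"success delta={result['success_delta'] * 100:.1f} pp")
print(f"token reduction={result['token_reduction_pct']:.1f}%")
```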

Example Configurations

High-Quality Distillation

For maximum quality, use a higher reward threshold and more epochs:
# configs/run/gkd_high_quality.yaml
defaults:
  - override /trainer@_global_: gkd

teacher_model_name_or_path: Qwen/Qwen2.5-14B-Instruct
model_name_or_path: Qwen/Qwen2.5-7B-Instruct

trainer:
  min_reward: 0.9           # Only excellent traces
  num_train_epochs: 5       # More training
  learning_rate: 3e-6       # Lower learning rate

Fast Iteration

For rapid experimentation:
# configs/run/gkd_fast.yaml
defaults:
  - override /trainer@_global_: gkd

trainer:
  min_reward: 0.7
  num_train_epochs: 1
  eval_steps: 50
  save_steps: 50

Task-Specific Distillation

For a specific workflow:
trainer:
  learning_key: "crm_contact_management"
  min_reward: 0.85
  baseline_success: 0.78

Monitoring Training

WandB Metrics

The trainer streams train/loss, train/learning_rate, eval/loss, eval/success_rate, eval/avg_tokens, and the Baseline Comparison trio (metrics/success_delta, metrics/token_reduction_pct, metrics/meets_target) to WandB so you can correlate convergence with task-level improvements without building a custom dashboard.

Command Line Output

Starting GKD training with Baseline Comparison reference: success=75.00%, tokens=1200
Loaded datasets: train=850, eval=150 conversations
AtlasGKDTrainer initialized with lmbda=1.0, beta=0.5

Epoch 1/3
  Step 100: loss=0.245, eval_loss=0.312
  ✅ Baseline Comparison targets MET: success delta=12.3 pp, token reduction=35.2%

Epoch 2/3
  Step 200: loss=0.198, eval_loss=0.276
  ✅ Baseline Comparison targets MET: success delta=14.1 pp, token reduction=38.7%

Training Telemetry Example

Recent GSM8K validation runs illustrate what a healthy loop looks like. With the stock configuration (lmbda=1.0, temperature=0.9, min_reward=0.8, learning_rate=2e-5, max_steps=500), training loss fell from 0.0676 to 0.0294 within 500 steps (epoch ≈0.21) and evaluation loss followed closely (0.0437 → 0.0397 → 0.0394). Gradient norms stayed between 0.88 and 1.70 and the entire run finished in 11h 45m on a single DGX Spark. When we extended the same configuration across the full GSM8K export, evaluation loss continued drifting downward toward 0.0341 while gradients remained below 1.3 even as the learning rate decayed to 7e-6. Use those ranges as a reference point—if training loss stalls above ~0.04 with steady gradients, tighten min_reward or lower temperature before changing optimizer settings.

Troubleshooting

Issue: “No conversations found in database”

Cause: Database connectivity or filtering too strict. Solutions: Confirm ATLAS_DB_URL points to the right cluster, run atlas.training_data.get_training_sessions() to verify rows are available, temporarily lower trainer.min_reward, and clear learning_key to widen the task filter while debugging.

Issue: “Out of memory (OOM) during training”

Cause: Student + teacher models exceed GPU memory. Solutions: Lower the per-device batch by setting per_device_train_batch_size=2 with gradient_accumulation_steps=8 (the effective batch stays at 16 per device while peak activation memory drops), keep gradient checkpointing enabled (default), load the teacher in 8-bit via teacher_model_init_kwargs.load_in_8bit=true, or drop to a smaller teacher checkpoint when the hardware budget is tight.

Issue: “Metrics not improving”

Possible causes and solutions: If metrics stall, first extend the schedule (num_train_epochs: 5) so the loss has room to decay, then sweep beta toward 0.3 or 0.7 to change the KL emphasis. Raising min_reward to 0.85 filters out marginal traces, and if the student remains constrained, move to a larger base checkpoint before adjusting optimizer settings.

Issue: “Training too slow”

Cap max_steps (for example 500 instead of multi-epoch sweeps), relax eval_steps to 200 to avoid constant validation passes, or set limit: 5000 on the dataset loader when you only need a smoke test.

Advanced Topics

Continual Learning with GKD

To preserve existing skills while distilling new knowledge:
# Include rehearsal data from previous tasks
trainer:
  learning_key: null  # All tasks
  min_reward: 0.85
Monitor for catastrophic forgetting using regression tests.

Multi-Stage Distillation

Distill progressively smaller models:
14B Teacher → 7B Student₁ → 3B Student₂ → 1.5B Student₃
Each stage uses the previous student as the teacher:
# Stage 1: 14B → 7B
python train.py --config-name teacher_gkd \
  teacher_model_name_or_path=Qwen/Qwen2.5-14B-Instruct \
  model.model_name_or_path=Qwen/Qwen2.5-7B-Instruct

# Stage 2: 7B → 3B
python train.py --config-name teacher_gkd \
  teacher_model_name_or_path=outputs/gkd_7b/final \
  model.model_name_or_path=Qwen/Qwen2.5-3B-Instruct
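The staging pattern reduces to pairing each model in the chain with its successor. A minimal sketch, using size labels as stand-ins for checkpoint paths (the helper name is hypothetical, and where each stage writes its output depends on your run config):

```python
def distillation_stages(chain):
    """Pair each model with the next smaller one as (teacher, student)."""
    return list(zip(chain, chain[1:]))


chain = ["14B", "7B", "3B", "1.5B"]
for stage, (teacher, student) in enumerate(distillation_stages(chain), start=1):
    print(f"Stage {stage}: {teacher} -> {student}")
```

Every stage after the first takes the previous stage's saved student as its teacher, so each run has to finish and write its checkpoint before the next one starts.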

Integration with Arc-CRM-Benchmark

For Stage 3 evaluation (Issue #42):
# 1. Train distilled model from baseline reference traces
python train.py --config-name teacher_gkd \
  trainer.learning_key="crm_workflows" \
  trainer.min_reward=0.8

# 2. Evaluate distilled model (no guidance)
# ... use arc-crm-benchmark evaluation scripts

API Reference

AtlasGKDTrainer

from trainers import AtlasGKDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig

# Same checkpoints as the Quick Start command
student = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-14B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

args = GKDConfig(
    output_dir="outputs/gkd",
    per_device_train_batch_size=4,
    lmbda=1.0,
    beta=0.5,
)

trainer = AtlasGKDTrainer(
    model=student,
    teacher_model=teacher,
    args=args,
    db_url="postgresql://localhost:5432/atlas",
    min_reward=0.8,
    processing_class=tokenizer,
)

trainer.train()
trainer.save_model("outputs/gkd/final")

Dataset Functions

from trainers.gkd_dataset import build_gkd_dataset

train_ds, eval_ds = build_gkd_dataset(
    db_url="postgresql://localhost:5432/atlas",
    min_reward=0.8,
    learning_key="crm_workflows",
    eval_split=0.15,
)

Baseline Comparison Metrics

from trainers.gkd_evaluator import compute_baseline_summary

summary = compute_baseline_summary(
    eval_results,
    baseline_success=0.75,
    baseline_tokens=1200,
)

print(f"Success delta: {summary['success_delta']*100:.1f} pp")
print(f"Token reduction: {summary['token_reduction_pct']:.1f}%")
print(f"Meets targets: {summary['meets_all_targets']}")

Next Steps

Review the Configuration Reference for override syntax, compare with the reinforcement-learning workflow described in grpo-training.mdx, and validate distilled checkpoints with the Evaluation Harnesses.

Status and Contributions

GKD support in Atlas Core remains in beta. We’re building expanded dataset filters, staged teacher → student schedules, and deeper telemetry hooks; track progress in Issue #40. Contributions are welcome: open a PR with trace snippets, Hydra overrides, or MCP-focused repro steps (the MCP Tool Learning example is the easiest shared workload) so we can iterate on the trainer together. For a qualitative look at how teams alternate between fast validation runs and longer reliability sweeps, see the ongoing gkd_two_gear_gkd_blog_draft.md research note.

References

Consult the On-Policy Distillation paper for the underlying method, the TRL GKDTrainer docs for library configuration, and Issue #40 for Atlas-specific implementation notes.