Skip to content

PR15: SameDiff Core & Training#10448

Open
agibsonccc wants to merge 8 commits into
masterfrom
pr/15-samediff-core-training
Open

PR15: SameDiff Core & Training#10448
agibsonccc wants to merge 8 commits into
masterfrom
pr/15-samediff-core-training

Conversation

@agibsonccc

@agibsonccc agibsonccc commented Jun 15, 2026

Copy link
Copy Markdown
Contributor

Summary

PR 15 of 22 PRs in the ag_new_release_updates_2 branch split. Merge after Layer 3 (native platform backends + DSP engine).

  • PEFT adapters: 15 adapter types via PeftModel/PeftModelFactory (LoRA, QLoRA, DoRA, AdaLoRA, LoHa, LoKr, IA3, VeRA, DyLoRA, LoftQ, prefix/prompt tuning, adapters, BitFit)
  • LoRA core: LoraLayer injects low-rank A/B matrices; LoftQInitializer minimizes quantization error; LoraAdapterCache provides LRU multi-adapter serving
  • RL alignment trainers: 9 trainers (GRPO, DPO, DAPO, DrGRPO, PPO, KTO, ORPO, SimPO, GSPO) build loss entirely as SameDiff computation graphs; gradients flow through execBackwards()
  • GRPO loss graph: generates groupSize completions, computes z-score advantages, builds clipped surrogate loss + KL penalty via native SameDiff ops on named placeholders
  • VLM GRPO: VlmGRPOTrainer extends GRPO for image-conditioned reward computation
  • Reward functions: CompositeRewardFunction (weighted sum), RuleBasedRewardFunction (format/length checks), SameDiffRewardFunction (reward as SameDiff graph)
  • Training pipelines: SFTTrainingPipeline (SFT with gradient accumulation + checkpointing), RLAlignmentPipeline (full RLHF loop)
  • Mixed-precision stack: GradientAccumulator (N micro-batch accumulation), LossScaler (dynamic FP16/BF16 scaling), FP8ScaleManager (per-tensor E4M3/E5M2 amax history)
  • SameDiff extensions: applyPeft(), getTrainableParameters(), distillFrom(), fitGRPO(), saveAdapters(), loadAdapters() added to SameDiff.java
  • Session routing: AbstractSession now routes to DynamicShapePlanCompiler when DSP is enabled; falls through on control flow ops
  • Multi-GPU memory: 3 new session memory managers allocate activation workspace via DeviceMemoryManager.selectDeviceForAllocation()
  • Legacy removal: GraphExecutioner.java and NativeGraphExecutioner.java deleted; replaced by DynamicShapePlanExecutor

What Changed

PEFT Infrastructure (org/nd4j/autodiff/samediff/peft/)

  • PeftModel.java — HuggingFace-compatible adapter wrapper; injects layers, tracks trainable params, supports mergeAndUnload()
  • PeftModelFactory.java — dispatches adapter injection by PeftType enum
  • LoraLayer.java — wraps existing linear variable with low-rank A/B matrices
  • LoraAdapterCache.java — thread-safe LRU adapter weight cache
  • LoftQInitializer.java — quantize base weight, init LoRA A/B to minimize reconstruction error

PEFT Config Classes (config/ — 19 files)

  • PeftConfig.java / PeftType.java — base config and 13-entry type enum
  • LoraConfig.java — rank, alpha, dropout, bias mode, init strategy
  • QLoraConfig.java / DoraConfig.java / AdaLoraConfig.java / DyLoraConfig.java — variant configs
  • LohaConfig.java / LokrConfig.java / IA3Config.java / VeraConfig.java — additional adapter configs
  • LoftQConfig.java / PrefixTuningConfig.java / PromptTuningConfig.java / AdapterConfig.java — remaining configs

RL Alignment Trainers (org/nd4j/autodiff/samediff/rl/)

  • GRPOTrainer.java — z-score advantages, clipped surrogate loss min(r*A, clip(r,1±eps)*A) + kl*KL(pi||ref)
  • DPOTrainer.java / DAPOTrainer.java / DrGRPOTrainer.java / PPOTrainer.java — DPO-family trainers
  • KTOTrainer.java / ORPOTrainer.java / SimPOTrainer.java / GSPOTrainer.java — additional alignment trainers
  • VlmGRPOTrainer.java — GRPO for vision-language models
  • RewardFunction.java / CompositeRewardFunction.java / RuleBasedRewardFunction.java / SameDiffRewardFunction.java — reward function hierarchy
  • RewardModelTrainer.java / SamplingStrategy.java / TopKSamplingStrategy.java — reward model and sampling

Training Pipelines (org/nd4j/autodiff/samediff/training/)

  • SFTTrainingPipeline.java — supervised fine-tuning with gradient accumulation, eval, checkpointing
  • RLAlignmentPipeline.java — full RLHF loop: freeze reference → rollout → reward → update
  • GradientAccumulator.java — accumulates execBackwards() over N micro-batches before stepping
  • LossScaler.java — doubles scale every scaleWindow steps; halves on NaN/Inf
  • FP8ScaleManager.java — per-tensor rolling amax window; E4M3 max=448.0, E5M2 max=57344.0
  • CheckpointManager.java / CheckpointOffloadManager.java — checkpoint save/load with async offload
  • ContinuedPretrainingWorkflow.java / PreferencePair.java / TrainingResult.java — supporting types

High-Level Trainer Classes

  • DistillationTrainer.java — teacher/student distillation with KL, attention, feature loss
  • RLAlignmentTrainer.java — abstract base for all RL alignment trainers
  • TransferLearning.java / TransferLearningHelper.java — freeze layers, add heads, selective fine-tuning

Training Config Classes (config/ — 18 files)

  • FineTuneConfiguration.java / SFTConfig.java / DistillationConfig.java — fine-tune and distillation configs
  • FP8TrainingConfig.java / GradientCheckpointConfig.java / LossScaleConfig.java — precision and memory configs
  • GRPOConfig.java / DPOConfig.java / PPOConfig.java / DAPOConfig.java / DrGRPOConfig.java — per-trainer configs
  • KTOConfig.java / ORPOConfig.java / SimPOConfig.java / GSPOConfig.java — remaining trainer configs
  • RLAlignmentConfig.java / RLPipelineConfig.java / RewardModelConfig.java — RL pipeline configs
  • VlmTrainingConfig.java / VlmFineTuneConfig.java / VlmGRPOConfig.java / TtsTrainingConfig.java / TtsFineTuneConfig.java — modality-specific configs
  • ContinuedPretrainingConfig.java / TaskType.java / VariableGroup.java / KernelConfiguration.java — supporting config types

SameDiff Core and Session Updates

  • SameDiff.javaapplyPeft(), getTrainableParameters(), printTrainableParameters(), distillFrom(), fitGRPO(), saveAdapters(), loadAdapters()
  • SDVariable.javasetTrainable(boolean), isTrainable(), group-based freeze/unfreeze
  • TrainingConfig.javaGradientCheckpointConfig, LossScaleConfig, FP8TrainingConfig builder methods
  • AbstractSession.java — routes to DynamicShapePlanCompiler; falls through for control flow graphs
  • InferenceSession.java — DSP plan lookup uses DspPlanDiskCache
  • TrainingSession.java — integrated GradientAccumulator, LossScaler, FP8ScaleManager

Session Memory Management

  • internal/memory/MultiGpuWorkspaceSessionMemMgr.java — allocates on device with most free memory
  • internal/memory/MultiBackendWorkspaceSessionMemMgr.java — CPU+GPU hybrid execution memory
  • internal/memory/WorkspaceSessionMemMgr.java — standard workspace memory manager
  • internal/memory/CleanupDiagnostics.java — tracks array cleanup decisions
  • internal/memory/DependencyMap.java — multi-GPU liveness dependency tracking

SameDiff Op Namespace Extensions

  • ops/SDMath.java / SDNN.java / SDLoss.java / SDCNN.java — extended with new ops
  • ops/SDAudio.java — new audio op namespace (MelSpectrogram, MFCC, Whisper)
  • ops/SDSignal.java — new signal processing namespace (DFT, STFT, windowing)
  • ops/SDTraining.java — new training ops namespace (gradient accumulation, loss scaling, EMA)
  • ops/SDBaseOps.java / SDImage.java / SDLinalg.java / SDRNN.java — updated with new ops

Serialization Extensions

  • serde/FlatBuffersMapper.java / SameDiffSerializer.java / SDZSerializer.java — extended for new ops and adapter weights
  • serde/ModelLoadingContext.java — model loading with shard support and lazy weight loading
  • serde/ModelSizeInfo.java — pre-load size estimation for memory planning

Removed Legacy Files

  • ~~GraphExecutioner.java~~ — deleted; replaced by DSP-based InferenceSession
  • ~~NativeGraphExecutioner.java~~ — deleted; replaced by DynamicShapePlanExecutor

Dependencies

  • Depends on: PR12 (GraphScope, DeviceMemoryManager, LifecycleSubsystem), PR13 (LoraMatMul, DoraMatMul, RmsNorm, GRPOConfig ops), PR14 (updated executioners and workspace managers)
  • Required by: PR16 (DSP executor uses InferenceSession, AbstractSession), PR17-PR19 (import pipelines use SameDiff training APIs)

Merge Order

This PR is in Layer 4 (SameDiff core training — parallel with PR12/PR13/PR14, all needed before PR16).

Layer PRs
0 (no deps) PR01, PR02, PR20
1 (build/infra) PR03, PR04
2 (native core) PR05, PR06, PR07
3 (native feat) PR08, PR09, PR10, PR11
4 (java core) PR12, PR13, PR14, PR15
5 (java feat) PR16
6 (import/gen) PR17, PR18, PR19, PR21
7 (validation) PR22

Part of the 22-PR split of ag_new_release_updates_2 branch.
Merge layer: 4 (java core)
Files: 119

See pr-plans/00-master-plan.md for the full split plan and merge order.
…-training

# Conflicts:
#	nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copilot wasn't able to review this pull request because it exceeds the maximum number of lines (20,000). Try reducing the number of changed lines and requesting a review from Copilot again.

…to slf4j

- Change visualizationEnabled default from true to false in AbstractSession
  (was firing recordStep on every inference step unconditionally)
- Convert all System.out/err.println calls to log.debug in SameDiffExecutionVisualizer
- Add @slf4j annotation
@agibsonccc

Copy link
Copy Markdown
Contributor Author

Architecture Overview

This PR extends SameDiff with training infrastructure: PEFT adapters, RL alignment trainers, mixed-precision gradient accumulation, and DSP session routing. It removes the legacy GraphExecutioner and replaces it with DynamicShapePlanExecutor, making DSP the default execution path.

Highlights

  • 15 PEFT adapter types with RL alignment trainersPeftModel/PeftModelFactory supports LoRA, QLoRA, DoRA, AdaLoRA, LoHa, LoKr, IA3, VeRA, DyLoRA, LoftQ, prefix/prompt tuning, adapters, and BitFit; 9 RL trainers (GRPO, DPO, DAPO, DrGRPO, PPO, KTO, ORPO, SimPO, GSPO) build loss entirely as SameDiff computation graphs with gradients flowing through execBackwards()
  • Mixed-precision training stackGradientAccumulator (N micro-batch accumulation), LossScaler (dynamic FP16/BF16 scaling), FP8ScaleManager (per-tensor E4M3/E5M2 amax history); AbstractSession now routes to DynamicShapePlanCompiler when DSP is enabled, with DynamicShapePlanExecutor managing warmup → freeze → CUDA graph capture → replay lifecycle

- CleanupDiagnostics: replace manual getters with @Getter, add @DaTa
  to inner CleanupResult class, remove backward-compat overload

- GradCheckUtil: remove DSP disabling (setDspAutoCompileEnabled(false))
  and replace with sd.invalidateAllPlanCaches() after each putScalar
  perturbation. DSP stays enabled throughout gradient checking.

- SameDiff: make invalidateAllPlanCaches() public for use by
  GradCheckUtil and other callers that mutate arrays in-place
- Replace all fully-qualified class names with proper imports:
  IdentityHashMap, DataBuffer, ArrayList, Arrays, Pointer, Merge, Switch

- Change log.info to log.debug for internal cache invalidation messages
  that fire frequently during normal execution

- Encapsulate 18 timing fields into TimingState inner class with
  reset() and printSummary() methods. All timing field accesses now
  go through the timing instance rather than scattered across the class.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants