PR15: SameDiff Core & Training #10448
Open
agibsonccc wants to merge 8 commits into
Part of the 22-PR split of the ag_new_release_updates_2 branch. Merge layer: 4 (Java core). Files: 119. See pr-plans/00-master-plan.md for the full split plan and merge order.
…-training # Conflicts: # nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java
This was referenced Jun 15, 2026
…to slf4j
- Change visualizationEnabled default from true to false in AbstractSession (it was firing recordStep unconditionally on every inference step)
- Convert all System.out/err.println calls to log.debug in SameDiffExecutionVisualizer
- Add @Slf4j annotation
Architecture Overview
This PR extends SameDiff with training infrastructure: PEFT adapters, RL alignment trainers, mixed-precision gradient accumulation, and DSP session routing. It removes the legacy graph executioner classes listed under Removed Legacy Files.
Highlights
- CleanupDiagnostics: replace manual getters with @Getter, add @Data to the inner CleanupResult class, remove the backward-compat overload
- GradCheckUtil: remove DSP disabling (setDspAutoCompileEnabled(false)) and call sd.invalidateAllPlanCaches() after each putScalar perturbation instead. DSP stays enabled throughout gradient checking.
- SameDiff: make invalidateAllPlanCaches() public for use by GradCheckUtil and other callers that mutate arrays in-place
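The GradCheckUtil change above hinges on one point: a perturbation written in place is invisible to anything that cached the previous result. The following is a minimal toy sketch of that pattern, not the PR's code. `CachedFunction`, `invalidateCache()`, and `numericalGradient()` are hypothetical names made up for illustration; `invalidateCache()` merely plays the role that `sd.invalidateAllPlanCaches()` plays in the commit message.

```java
/** Toy gradient check showing why cached results must be invalidated after in-place edits. */
final class GradCheckSketch {

    /** Hypothetical stand-in for a graph whose output is cached between runs. */
    static final class CachedFunction {
        private final double[] params;   // mutated in place by the checker
        private Double cachedOutput;     // goes stale as soon as params change

        CachedFunction(double[] params) {
            this.params = params;
        }

        /** Plays the role of a plan-cache invalidation in this toy model. */
        void invalidateCache() {
            cachedOutput = null;
        }

        /** f(p) = sum of p_i^2, recomputed only when no cached value exists. */
        double eval() {
            if (cachedOutput == null) {
                double sum = 0.0;
                for (double p : params) {
                    sum += p * p;
                }
                cachedOutput = sum;
            }
            return cachedOutput;
        }
    }

    /** Central-difference estimate of df/dp_i, invalidating after every perturbation. */
    static double numericalGradient(CachedFunction f, double[] params, int i, double eps) {
        double original = params[i];

        params[i] = original + eps;
        f.invalidateCache();
        double plus = f.eval();

        params[i] = original - eps;
        f.invalidateCache();
        double minus = f.eval();

        params[i] = original;            // restore the parameter
        f.invalidateCache();
        return (plus - minus) / (2.0 * eps);
    }

    public static void main(String[] args) {
        double[] params = {1.0, -2.0, 3.0};
        CachedFunction f = new CachedFunction(params);
        f.eval(); // warm the cache, as an earlier forward pass would
        // The analytic gradient of sum(p^2) with respect to p_1 is 2 * p_1 = -4.
        System.out.println(numericalGradient(f, params, 1, 1e-4));
    }
}
```

If the two `invalidateCache()` calls around the perturbations are dropped, both evaluations return the warmed value and the estimated gradient collapses to zero, which is the failure mode the commit guards against without turning DSP off.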
- Replace all fully-qualified class names with proper imports: IdentityHashMap, DataBuffer, ArrayList, Arrays, Pointer, Merge, Switch
- Change log.info to log.debug for internal cache invalidation messages that fire frequently during normal execution
- Encapsulate 18 timing fields into TimingState inner class with reset() and printSummary() methods. All timing field accesses now go through the timing instance rather than scattered across the class.
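To make the TimingState refactor concrete, here is a small sketch of the shape such an inner class could take. Only the names `TimingState`, `reset()`, and `printSummary()` come from the commit message; the enclosing class, the three fields shown (the PR mentions 18), and `runOp()` are assumptions made for illustration.

```java
/** Sketch of grouping scattered timing fields behind one inner-class instance. */
final class SessionTimingSketch {

    /** Hypothetical subset of the timing fields; field names are invented for this example. */
    static final class TimingState {
        long planLookupNanos;
        long opExecNanos;
        int opCount;

        /** Zero every counter so a new run starts from a clean slate. */
        void reset() {
            planLookupNanos = 0L;
            opExecNanos = 0L;
            opCount = 0;
        }

        /** Build a one-line summary; kept separate from printing so it can be inspected. */
        String summary() {
            return "ops=" + opCount
                    + " opExecNanos=" + opExecNanos
                    + " planLookupNanos=" + planLookupNanos;
        }

        void printSummary() {
            System.out.println(summary());
        }
    }

    // Single access point for all timing data in the enclosing class.
    private final TimingState timing = new TimingState();

    /** Run one op and record its wall-clock time through the timing instance. */
    void runOp(Runnable op) {
        long start = System.nanoTime();
        op.run();
        timing.opExecNanos += System.nanoTime() - start;
        timing.opCount++;
    }

    TimingState timing() {
        return timing;
    }

    public static void main(String[] args) {
        SessionTimingSketch session = new SessionTimingSketch();
        session.runOp(() -> Math.sqrt(2.0));
        session.timing().printSummary();
        session.timing().reset();
    }
}
```

Keeping the counters in one object means a reset cannot miss a field that was added later, which is harder to guarantee when the fields are spread across the class.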
Apply Java compilation fixes for SameDiff training files.
Summary
PR 15 of 22 PRs in the ag_new_release_updates_2 branch split. Merge after Layer 3 (native platform backends + DSP engine).
- PeftModel/PeftModelFactory (LoRA, QLoRA, DoRA, AdaLoRA, LoHa, LoKr, IA3, VeRA, DyLoRA, LoftQ, prefix/prompt tuning, adapters, BitFit)
- LoraLayer injects low-rank A/B matrices; LoftQInitializer minimizes quantization error; LoraAdapterCache provides LRU multi-adapter serving
- … execBackwards()
- … groupSize completions, computes z-score advantages, builds clipped surrogate loss + KL penalty via native SameDiff ops on named placeholders
- VlmGRPOTrainer extends GRPO for image-conditioned reward computation
- CompositeRewardFunction (weighted sum), RuleBasedRewardFunction (format/length checks), SameDiffRewardFunction (reward as SameDiff graph)
- SFTTrainingPipeline (SFT with gradient accumulation + checkpointing), RLAlignmentPipeline (full RLHF loop)
- GradientAccumulator (N micro-batch accumulation), LossScaler (dynamic FP16/BF16 scaling), FP8ScaleManager (per-tensor E4M3/E5M2 amax history)
- applyPeft(), getTrainableParameters(), distillFrom(), fitGRPO(), saveAdapters(), loadAdapters() added to SameDiff.java
- AbstractSession now routes to DynamicShapePlanCompiler when DSP is enabled; falls through on control flow ops
- … DeviceMemoryManager.selectDeviceForAllocation()
- GraphExecutioner.java and NativeGraphExecutioner.java deleted; replaced by DynamicShapePlanExecutor
What Changed
PEFT Infrastructure (org/nd4j/autodiff/samediff/peft/)
- PeftModel.java: HuggingFace-compatible adapter wrapper; injects layers, tracks trainable params, supports mergeAndUnload()
- PeftModelFactory.java: dispatches adapter injection by PeftType enum
- LoraLayer.java: wraps existing linear variable with low-rank A/B matrices
- LoraAdapterCache.java: thread-safe LRU adapter weight cache
- LoftQInitializer.java: quantize base weight, init LoRA A/B to minimize reconstruction error
PEFT Config Classes (config/, 19 files)
- PeftConfig.java / PeftType.java: base config and 13-entry type enum
- LoraConfig.java: rank, alpha, dropout, bias mode, init strategy
- QLoraConfig.java / DoraConfig.java / AdaLoraConfig.java / DyLoraConfig.java: variant configs
- LohaConfig.java / LokrConfig.java / IA3Config.java / VeraConfig.java: additional adapter configs
- LoftQConfig.java / PrefixTuningConfig.java / PromptTuningConfig.java / AdapterConfig.java: remaining configs
RL Alignment Trainers (org/nd4j/autodiff/samediff/rl/)
- GRPOTrainer.java: z-score advantages, clipped surrogate loss min(r*A, clip(r,1±eps)*A) + kl*KL(pi||ref)
- DPOTrainer.java / DAPOTrainer.java / DrGRPOTrainer.java / PPOTrainer.java: DPO-family trainers
- KTOTrainer.java / ORPOTrainer.java / SimPOTrainer.java / GSPOTrainer.java: additional alignment trainers
- VlmGRPOTrainer.java: GRPO for vision-language models
- RewardFunction.java / CompositeRewardFunction.java / RuleBasedRewardFunction.java / SameDiffRewardFunction.java: reward function hierarchy
- RewardModelTrainer.java / SamplingStrategy.java / TopKSamplingStrategy.java: reward model and sampling
Training Pipelines (org/nd4j/autodiff/samediff/training/)
- SFTTrainingPipeline.java: supervised fine-tuning with gradient accumulation, eval, checkpointing
- RLAlignmentPipeline.java: full RLHF loop: freeze reference → rollout → reward → update
- GradientAccumulator.java: accumulates execBackwards() over N micro-batches before stepping
- LossScaler.java: doubles scale every scaleWindow steps; halves on NaN/Inf
- FP8ScaleManager.java: per-tensor rolling amax window; E4M3 max=448.0, E5M2 max=57344.0
- CheckpointManager.java / CheckpointOffloadManager.java: checkpoint save/load with async offload
- ContinuedPretrainingWorkflow.java / PreferencePair.java / TrainingResult.java: supporting types
High-Level Trainer Classes
- DistillationTrainer.java: teacher/student distillation with KL, attention, feature loss
- RLAlignmentTrainer.java: abstract base for all RL alignment trainers
- TransferLearning.java / TransferLearningHelper.java: freeze layers, add heads, selective fine-tuning
Training Config Classes (config/, 18 files)
- FineTuneConfiguration.java / SFTConfig.java / DistillationConfig.java: fine-tune and distillation configs
- FP8TrainingConfig.java / GradientCheckpointConfig.java / LossScaleConfig.java: precision and memory configs
- GRPOConfig.java / DPOConfig.java / PPOConfig.java / DAPOConfig.java / DrGRPOConfig.java: per-trainer configs
- KTOConfig.java / ORPOConfig.java / SimPOConfig.java / GSPOConfig.java: remaining trainer configs
- RLAlignmentConfig.java / RLPipelineConfig.java / RewardModelConfig.java: RL pipeline configs
- VlmTrainingConfig.java / VlmFineTuneConfig.java / VlmGRPOConfig.java / TtsTrainingConfig.java / TtsFineTuneConfig.java: modality-specific configs
- ContinuedPretrainingConfig.java / TaskType.java / VariableGroup.java / KernelConfiguration.java: supporting config types
SameDiff Core and Session Updates
- SameDiff.java: applyPeft(), getTrainableParameters(), printTrainableParameters(), distillFrom(), fitGRPO(), saveAdapters(), loadAdapters()
- SDVariable.java: setTrainable(boolean), isTrainable(), group-based freeze/unfreeze
- TrainingConfig.java: GradientCheckpointConfig, LossScaleConfig, FP8TrainingConfig builder methods
- AbstractSession.java: routes to DynamicShapePlanCompiler; falls through for control flow graphs
- InferenceSession.java: DSP plan lookup uses DspPlanDiskCache
- TrainingSession.java: integrated GradientAccumulator, LossScaler, FP8ScaleManager
Session Memory Management
- internal/memory/MultiGpuWorkspaceSessionMemMgr.java: allocates on device with most free memory
- internal/memory/MultiBackendWorkspaceSessionMemMgr.java: CPU+GPU hybrid execution memory
- internal/memory/WorkspaceSessionMemMgr.java: standard workspace memory manager
- internal/memory/CleanupDiagnostics.java: tracks array cleanup decisions
- internal/memory/DependencyMap.java: multi-GPU liveness dependency tracking
SameDiff Op Namespace Extensions
- ops/SDMath.java / SDNN.java / SDLoss.java / SDCNN.java: extended with new ops
- ops/SDAudio.java: new audio op namespace (MelSpectrogram, MFCC, Whisper)
- ops/SDSignal.java: new signal processing namespace (DFT, STFT, windowing)
- ops/SDTraining.java: new training ops namespace (gradient accumulation, loss scaling, EMA)
- ops/SDBaseOps.java / SDImage.java / SDLinalg.java / SDRNN.java: updated with new ops
Serialization Extensions
- serde/FlatBuffersMapper.java / SameDiffSerializer.java / SDZSerializer.java: extended for new ops and adapter weights
- serde/ModelLoadingContext.java: model loading with shard support and lazy weight loading
- serde/ModelSizeInfo.java: pre-load size estimation for memory planning
Removed Legacy Files
- ~~GraphExecutioner.java~~: deleted; replaced by DSP-based InferenceSession
- ~~NativeGraphExecutioner.java~~: deleted; replaced by DynamicShapePlanExecutor
Dependencies
… InferenceSession, AbstractSession), PR17-PR19 (import pipelines use SameDiff training APIs)
Merge Order
This PR is in Layer 4 (SameDiff core training, parallel with PR12/PR13/PR14; all are needed before PR16).
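As a reading aid for the GRPOTrainer.java entry under What Changed, the arithmetic it names (z-score advantages, then the clipped surrogate `min(r*A, clip(r,1±eps)*A)` with a KL term) can be sketched on plain arrays. This is an illustration under stated assumptions, not the trainer's implementation: the PR describes building these terms from native SameDiff ops on named placeholders, while the class and method names here are hypothetical. Using the population standard deviation, adding a small `eps` to the denominator, and negating the surrogate so the result can be minimized are all assumptions.

```java
import java.util.Arrays;

/** Plain-array sketch of the GRPO quantities named in the PR description. */
final class GrpoSketch {

    /** Group-relative advantages: (reward - group mean) / (group std + eps). */
    static double[] zScoreAdvantages(double[] rewards, double eps) {
        int n = rewards.length;
        double mean = 0.0;
        for (double r : rewards) {
            mean += r;
        }
        mean /= n;

        double variance = 0.0;
        for (double r : rewards) {
            variance += (r - mean) * (r - mean);
        }
        variance /= n; // population variance (an assumption)
        double std = Math.sqrt(variance);

        double[] advantages = new double[n];
        for (int i = 0; i < n; i++) {
            advantages[i] = (rewards[i] - mean) / (std + eps);
        }
        return advantages;
    }

    /** min(r * A, clip(r, 1 - clipEps, 1 + clipEps) * A) for one sample. */
    static double clippedSurrogate(double ratio, double advantage, double clipEps) {
        double clipped = Math.max(1.0 - clipEps, Math.min(1.0 + clipEps, ratio));
        return Math.min(ratio * advantage, clipped * advantage);
    }

    /** Loss to minimize: negated mean surrogate plus a weighted KL penalty. */
    static double loss(double[] ratios, double[] advantages,
                       double clipEps, double klCoef, double kl) {
        double sum = 0.0;
        for (int i = 0; i < ratios.length; i++) {
            sum += clippedSurrogate(ratios[i], advantages[i], clipEps);
        }
        return -(sum / ratios.length) + klCoef * kl;
    }

    public static void main(String[] args) {
        double[] rewards = {1.0, 2.0, 3.0};          // one group of sampled completions
        double[] advantages = zScoreAdvantages(rewards, 1e-8);
        double[] ratios = {1.1, 1.0, 0.7};           // pi(a|s) / pi_old(a|s), made-up values
        System.out.println(Arrays.toString(advantages));
        System.out.println(loss(ratios, advantages, 0.2, 0.04, 0.01));
    }
}
```

Clipping only limits the update when it would push the ratio further out of the trust region: for a positive advantage the ratio's contribution is capped at `1 + clipEps`, and for a negative advantage the `min` keeps the more pessimistic of the two terms.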
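The LossScaler.java entry under Training Pipelines states its rule in one line: double the scale every `scaleWindow` steps, halve it on NaN/Inf. A minimal sketch of that rule follows. It is not the PR's class: the constructor and method names are hypothetical, and three details are assumptions: that the window counts consecutive overflow-free steps, that the optimizer step is skipped when an overflow is seen, and that the scale is floored at 1.0.

```java
/** Sketch of dynamic loss scaling: grow after a clean window, shrink on overflow. */
final class LossScalerSketch {
    private double scale;
    private final int scaleWindow;
    private int cleanSteps; // consecutive steps without NaN/Inf gradients

    LossScalerSketch(double initialScale, int scaleWindow) {
        this.scale = initialScale;
        this.scaleWindow = scaleWindow;
    }

    double scale() {
        return scale;
    }

    /**
     * Inspect the (scaled) gradients of one step and adjust the scale.
     * Returns true if the step is usable, false if it should be skipped.
     */
    boolean update(double[] gradients) {
        for (double g : gradients) {
            if (Double.isNaN(g) || Double.isInfinite(g)) {
                scale = Math.max(1.0, scale / 2.0); // halve on overflow, floor is an assumption
                cleanSteps = 0;
                return false;
            }
        }
        cleanSteps++;
        if (cleanSteps >= scaleWindow) {
            scale *= 2.0;                            // double after a full clean window
            cleanSteps = 0;
        }
        return true;
    }

    public static void main(String[] args) {
        LossScalerSketch scaler = new LossScalerSketch(1024.0, 2);
        scaler.update(new double[]{0.1, -0.3});
        scaler.update(new double[]{0.2, 0.4});
        System.out.println(scaler.scale());          // scale after two clean steps
        scaler.update(new double[]{Double.NaN});
        System.out.println(scaler.scale());          // scale after an overflow
    }
}
```

In a training loop the loss would be multiplied by `scale()` before the backward pass and the gradients divided by it before the optimizer step; that plumbing is left out here to keep the scaling rule itself in view.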