Training Guide¶
Train LandmarkDiff from scratch on your own data. This guide covers every step: data preparation, configuration, single-GPU and multi-GPU training, curriculum learning, monitoring, checkpointing, and SLURM submission.
Overview¶
Training has two phases:
- Phase A (synthetic data, diffusion loss only): teaches the model to generate faces conditioned on deformed landmark meshes
- Phase B (clinical data, full loss): fine-tunes on real surgical before/after pairs with identity and perceptual losses
The multi-term loss for Phase B:

L_total = L_diffusion + 0.1 * L_identity + 0.05 * L_perceptual + 0.1 * L_mask

Phase B resumes from a Phase A checkpoint.
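As a sketch, the Phase B objective is a weighted sum of the four loss terms, using the default weights from the Phase B config shown later in this guide (`phase_b_loss` is an illustrative helper, not a function in the codebase):

```python
def phase_b_loss(l_diffusion, l_identity, l_perceptual, l_mask,
                 w_identity=0.1, w_perceptual=0.05, w_mask=0.1):
    """Weighted sum of the four Phase B loss terms (default config weights)."""
    return (l_diffusion
            + w_identity * l_identity
            + w_perceptual * l_perceptual
            + w_mask * l_mask)
```

The same weights appear under `loss_weights` in the Phase B config, so changing them there changes this sum.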
Prerequisites¶
# Install training dependencies
pip install -e ".[train]"
# Verify PyTorch + CUDA
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else \"none\"}')"
The [train] extra includes wandb, deepspeed, webdataset, and accelerate.
Data Preparation Pipeline¶
Step 1: Download face images¶
# FFHQ faces (5K for quick experiments, 50K for full training)
python scripts/download_ffhq.py --num 5000 --resolution 512 --output data/ffhq_samples/
# Or use multiple sources for diversity
python scripts/download_faces_multi.py --num 10000 --output data/faces_all/
Step 2: Generate synthetic training pairs¶
Each training pair consists of:
- Input image: original face (512x512)
- Conditioning image: deformed wireframe mesh overlay
- Target image: TPS-warped face (what the model should produce)
- Procedure label: which procedure was applied
python scripts/generate_synthetic_data.py \
--input data/ffhq_samples/ \
--output data/synthetic_pairs/ \
--num 50000
# Check what was generated
python scripts/dataset_stats.py data/synthetic_pairs/
For large datasets, use the SLURM parallel generation script.
The generator randomly samples from all registered procedures (rhinoplasty, blepharoplasty, rhytidectomy, orthognathic, brow_lift, mentoplasty) unless you restrict it with --procedure.
Step 3: (Optional) Add clinical augmentations¶
Clinical photos have different characteristics than FFHQ: variable lighting, JPEG compression, color temperature shifts, and noise. Adding these augmentations to synthetic data helps the model generalize:
python scripts/augment_pairs.py \
--input data/synthetic_pairs/ \
--output data/augmented_pairs/ \
--augmentations lighting,color_temp,jpeg,noise
Clinical augmentation can also be applied online during training by setting clinical_augment: true in your config (Phase B only; disabled by default for Phase A).
Step 4: Build the combined training dataset¶
python scripts/build_training_dataset.py \
--input data/synthetic_pairs/ \
--output data/training_combined/
This creates the directory structure expected by the training script: *_input.png, *_conditioning.png, *_target.png triplets plus a metadata.json for curriculum learning.
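A quick sanity check over that triplet layout can be sketched as follows (`find_incomplete_triplets` is an illustrative helper, not one of the project's scripts):

```python
from pathlib import Path

def find_incomplete_triplets(data_dir):
    """Return pair IDs that have an *_input.png but are missing the
    matching *_conditioning.png or *_target.png file."""
    data_dir = Path(data_dir)
    ids = {p.name[:-len("_input.png")] for p in data_dir.glob("*_input.png")}
    missing = []
    for pid in sorted(ids):
        for suffix in ("_conditioning.png", "_target.png"):
            if not (data_dir / f"{pid}{suffix}").exists():
                missing.append(pid)
                break
    return missing
```

Running this before `create_test_split.py` catches partially generated pairs early instead of mid-training.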
Step 5: Create train/val splits¶
python scripts/create_test_split.py \
--data_dir data/training_combined/ \
--output_dir data/splits/ \
--val_fraction 0.05
Step 6: Run the preflight check¶
Before submitting a training job, run the preflight script to verify everything is in order.
The preflight script checks dataset completeness, metadata presence, val/test splits, config validity, dependency installation, GPU availability, disk space, and existing checkpoints for resume.
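The preflight script itself is not reproduced here; a minimal stand-in covering a few of those checks might look like this (function name and the free-space threshold are assumptions):

```python
import shutil
from pathlib import Path

def preflight(data_dir, min_free_gb=50):
    """Run a few of the preflight checks: dataset presence, metadata,
    disk space, and GPU availability. Returns a dict of check results."""
    checks = {}
    data_dir = Path(data_dir)
    checks["dataset_exists"] = data_dir.is_dir()
    checks["metadata_present"] = (data_dir / "metadata.json").is_file()
    free_gb = shutil.disk_usage(".").free / 1e9
    checks["disk_space"] = free_gb >= min_free_gb
    try:
        import torch
        checks["gpu_available"] = torch.cuda.is_available()
    except ImportError:
        checks["gpu_available"] = False
    return checks
```

Any `False` entry is worth fixing before the job hits the queue.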
Configuration with YAML¶
All training parameters live in YAML config files under configs/. You can either edit a config file or pass CLI arguments to override individual settings.
Available configs¶
| Config | Steps | Purpose |
|---|---|---|
| `phaseA_quick.yaml` | 500 | Smoke test, debug loop |
| `phaseA_default.yaml` | 10,000 | Quick Phase A validation |
| `phaseA_production.yaml` | 50,000 | Full Phase A production run |
| `phaseA_v3_curriculum.yaml` | 100,000 | Phase A with curriculum learning |
| `phaseB.yaml` | 25,000 | Phase B fine-tuning |
| `phaseB_production.yaml` | 50,000 | Full Phase B production run |
| `phaseB_identity.yaml` | -- | Phase B with identity emphasis |
Config structure¶
# configs/phaseA_v3_curriculum.yaml
experiment_name: phaseA_v3_curriculum
model:
base_model: runwayml/stable-diffusion-v1-5
controlnet_conditioning_channels: 3
controlnet_conditioning_scale: 1.0
use_ema: true
ema_decay: 0.9999
gradient_checkpointing: true
training:
phase: A
learning_rate: 1.0e-5
batch_size: 4
gradient_accumulation_steps: 4 # effective batch = 16
max_train_steps: 100000
warmup_steps: 1000
mixed_precision: bf16 # never fp16
seed: 42
optimizer: adamw
adam_beta1: 0.9
adam_beta2: 0.999
weight_decay: 0.01
max_grad_norm: 1.0
lr_scheduler: cosine
save_every_n_steps: 10000
resume_from_checkpoint: auto
validate_every_n_steps: 5000
num_validation_samples: 8
data:
train_dir: data/training_combined
val_dir: data/validation
image_size: 512
num_workers: 8
random_flip: true
random_rotation: 5.0
color_jitter: 0.1
procedures:
- rhinoplasty
- blepharoplasty
- rhytidectomy
- orthognathic
displacement_model_path: data/displacement_model.npz
wandb:
enabled: true
project: landmarkdiff
tags: [phase-a, curriculum, v3-data]
output_dir: outputs/phaseA_v3_curriculum
CLI overrides¶
Any config field can be overridden on the command line:
python scripts/train_controlnet.py \
--config configs/phaseA_v3_curriculum.yaml \
--learning_rate 5e-6 \
--batch_size 2 \
--num_train_steps 5000
Single GPU Training¶
The simplest way to start:
python scripts/train_controlnet.py \
--data_dir data/training_combined/ \
--output_dir checkpoints/ \
--learning_rate 1e-5 \
--train_batch_size 4 \
--gradient_accumulation_steps 4 \
--num_train_steps 50000 \
--checkpoint_every 5000 \
--sample_every 1000 \
--resume_from_checkpoint latest \
--phase A
Or with a config file:
python scripts/train_controlnet.py --config configs/phaseA_production.yaml
Dry run¶
To test the training loop without actually running for many steps, use the 500-step smoke-test config:
python scripts/train_controlnet.py --config configs/phaseA_quick.yaml
GPU memory recommendations¶
| GPU | VRAM | Batch size | Gradient accumulation | Effective batch |
|---|---|---|---|---|
| P100 | 16 GB | 2 | 8 | 16 |
| V100 | 32 GB | 4 | 4 | 16 |
| A6000 | 48 GB | 4-8 | 2-4 | 16-32 |
| A100 (40 GB) | 40 GB | 4 | 4 | 16 |
| A100 (80 GB) | 80 GB | 8 | 4 | 32 |
| H100 | 80 GB | 8 | 4 | 32 |
| L40S | 48 GB | 4-8 | 2-4 | 16-32 |
If you run out of VRAM, enable gradient checkpointing (gradient_checkpointing: true in the model config) and reduce batch size.
Multi-GPU Distributed Training (DDP)¶
The training script automatically detects PyTorch DDP environment variables (RANK, LOCAL_RANK, WORLD_SIZE) and activates distributed mode. Use torchrun to launch:
2 GPUs on a single node¶
torchrun --nproc_per_node=2 scripts/train_controlnet.py \
--config configs/phaseA_v3_curriculum.yaml \
--output_dir checkpoints/phaseA_ddp/
4 GPUs on a single node¶
torchrun --nproc_per_node=4 scripts/train_controlnet.py \
--config configs/phaseA_v3_curriculum.yaml \
--output_dir checkpoints/phaseA_ddp/ \
--train_batch_size 4 \
--gradient_accumulation_steps 2
With 4 GPUs, batch size 4, and gradient accumulation 2, the effective batch size is 4 * 4 * 2 = 32.
DDP behavior¶
- Only rank 0 saves checkpoints, logs to WandB, and generates sample images
- All ranks participate in gradient computation and synchronization
- The learning rate does not need to be scaled; effective batch size increases naturally through more GPUs
- Use the NCCL backend for GPU-to-GPU communication (the default on Linux)
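The environment-variable detection that `torchrun` relies on can be sketched as:

```python
import os

def ddp_info():
    """Read torchrun's DDP environment variables, falling back to
    single-process defaults when they are absent."""
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    is_main = rank == 0  # only this rank should checkpoint, log, and sample
    return rank, local_rank, world_size, is_main
```

Without `torchrun`, none of the variables are set and the script runs in plain single-GPU mode.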
Multi-node training (advanced)¶
For training across multiple nodes on a SLURM cluster:
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gres=gpu:4
srun torchrun \
--nnodes=$SLURM_JOB_NUM_NODES \
--nproc_per_node=4 \
--rdzv_id=$SLURM_JOB_ID \
--rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
scripts/train_controlnet.py \
--config configs/phaseA_production.yaml
Curriculum Learning (Phase A to Phase B)¶
Phase A: Synthetic data, diffusion loss only¶
Phase A trains the ControlNet to follow landmark mesh conditioning using synthetic TPS-warped data. The model learns mesh-to-face generation without needing real surgical data.
Key Phase A settings:
- Loss: L_diffusion only
- Data: synthetic TPS pairs
- Learning rate: 1e-5
- LR schedule: cosine decay with warmup
- Steps: 50,000-100,000 depending on dataset size
Phase B: Clinical data, full multi-term loss¶
Phase B fine-tunes the Phase A checkpoint on paired clinical data with the full 4-term loss. This phase requires before/after surgical image pairs.
Key Phase B settings:
- Loss: L_diffusion + 0.1 * L_identity + 0.05 * L_perceptual + 0.1 * L_mask
- Data: clinical before/after pairs with augmentation
- Learning rate: 5e-6 (lower than Phase A)
- Resume from: checkpoints/phaseA/latest
- Steps: 25,000-50,000
The Phase B config automatically loads the Phase A checkpoint:
training:
phase: B
resume_from: checkpoints/phaseA/latest
loss_weights:
diffusion: 1.0
identity: 0.1
perceptual: 0.05
mask: 0.1
Curriculum progression¶
The phaseA_v3_curriculum.yaml config supports progressive difficulty scheduling across waves of data. Each wave introduces more varied and challenging training examples. Metadata in metadata.json tracks which wave each pair belongs to, and the dataloader can sample accordingly.
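Assuming each record in metadata.json carries a numeric `wave` field (the exact schema is an assumption, not the project's documented format), wave-aware sampling might look like:

```python
import json
import random

def sample_pair_id(metadata_path, current_wave, rng=None):
    """Sample one pair ID from all waves up to and including current_wave."""
    rng = rng or random.Random()
    with open(metadata_path) as f:
        meta = json.load(f)
    eligible = [pid for pid, rec in meta.items()
                if rec.get("wave", 0) <= current_wave]
    return rng.choice(eligible)
```

Raising `current_wave` on a schedule gradually admits the harder examples into the training mix.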
Monitoring with Weights & Biases¶
Online mode (local machines)¶
Leave wandb.enabled: true in your config and authenticate once with wandb login. Check https://wandb.ai for live loss curves, sample generations, and system metrics.
Offline mode (HPC clusters)¶
Most HPC clusters have restricted internet access. Use offline mode:
export WANDB_MODE=offline
After training completes, sync the offline run from a machine with internet access:
wandb sync outputs/*/wandb/latest-run/
Key metrics to watch¶
Phase A (target at 50K steps):
| Metric | Target | Notes |
|---|---|---|
| Training loss | < 0.15 | Should decrease monotonically |
| FID | < 120 | Improves with more data and steps |
| Generated samples | -- | Faces should follow landmark structure |
Phase B (target at 50K steps):
| Metric | Target | Notes |
|---|---|---|
| FID | < 50 | Significant improvement over Phase A |
| NME (landmark error) | < 0.05 | Landmarks match surgical plan |
| Identity similarity | > 0.85 | ArcFace cosine similarity |
| SSIM | > 0.80 | Structural similarity to target |
Real-time monitoring¶
# Follow the SLURM log
tail -f slurm-*.out
# Use the training dashboard
python scripts/training_dashboard.py --output_dir checkpoints/
# Plot loss curves from existing runs
python scripts/plot_training_curves.py --run_dir outputs/phaseA_v3_curriculum/
Checkpointing and Resume¶
Automatic checkpointing¶
Checkpoints are saved every save_every_n_steps (default: 5,000 for Phase A, 1,000 for Phase B). Each checkpoint contains the ControlNet weights, optimizer state, EMA weights, and the training step count.
Resume from interruption¶
Set resume_from_checkpoint: auto (or latest) in your config, or pass --resume_from_checkpoint=latest on the CLI. The training script will find the most recent checkpoint and continue from that step.
# Explicit resume from a specific checkpoint
python scripts/train_controlnet.py \
--config configs/phaseA_default.yaml \
--resume_from_checkpoint checkpoints/checkpoint-15000
EMA weights¶
Exponential Moving Average (EMA) weights at decay rate 0.9999 are maintained throughout training and saved alongside each checkpoint. Use EMA weights for inference; they produce smoother, more stable outputs. The training script saves separate checkpoint-*-ema directories.
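As a plain-Python sketch (real implementations operate on full parameter tensors in the model state dict), the per-step EMA update is:

```python
def ema_update(ema_params, model_params, decay=0.9999):
    """One EMA step: ema <- decay * ema + (1 - decay) * current weight."""
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(ema_params, model_params)]
```

With decay 0.9999, each step moves the EMA copy only 0.01% toward the current weights, which is what smooths out the high-frequency noise in raw checkpoints.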
Evaluate a checkpoint¶
python scripts/evaluate_checkpoint.py \
--checkpoint checkpoints/checkpoint-50000 \
--test_dir data/splits/test/ \
--output eval_results/
SLURM Submission¶
Single-GPU SLURM job¶
The provided scripts/train_slurm.sh handles everything for a single-GPU run, including preemption handling:
#!/bin/bash
#SBATCH --job-name=surgery_controlnet
#SBATCH --partition=batch_gpu
#SBATCH --account=your_gpu_acc
#SBATCH --gres=gpu:1
#SBATCH --mem=64G
#SBATCH --cpus-per-task=8
#SBATCH --time=48:00:00
#SBATCH --output=slurm-%j.out
#SBATCH --signal=B:USR1@300
#SBATCH --requeue
# === Skip-logic: don't rerun if already completed ===
CKPT_DIR="/path/to/LandmarkDiff/checkpoints"
FINAL_STEP=50000
if [ -d "$CKPT_DIR/checkpoint-${FINAL_STEP}" ]; then
echo "Training already complete at step ${FINAL_STEP}. Exiting."
exit 0
fi
# === Critical HPC safeguards ===
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export WANDB_MODE=offline
# Trap preemption signal -> save checkpoint -> requeue
trap 'echo "Caught USR1 - saving checkpoint..."; kill -INT $TRAIN_PID; wait $TRAIN_PID; scontrol requeue $SLURM_JOB_ID' USR1
WORK_DIR="/path/to/LandmarkDiff"
DATA_DIR="${WORK_DIR}/data/training_combined"
WANDB_DIR="${WORK_DIR}/wandb"
mkdir -p "$CKPT_DIR" "$WANDB_DIR"
cd "$WORK_DIR"
# Activate environment
source $HOME/miniconda3/etc/profile.d/conda.sh
conda activate landmarkdiff
python scripts/train_controlnet.py \
--data_dir=$DATA_DIR \
--output_dir=$CKPT_DIR \
--wandb_dir=$WANDB_DIR \
--learning_rate=1e-5 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--num_train_steps=${FINAL_STEP} \
--checkpoint_every=5000 \
--sample_every=1000 \
--resume_from_checkpoint=latest \
--phase=A &
TRAIN_PID=$!
wait $TRAIN_PID
Key SLURM features:
- `--signal=B:USR1@300`: sends USR1 to the job 300 seconds before wall-time expiration
- `--requeue`: requeues the job after preemption
- `trap ... USR1`: catches the signal, sends SIGINT to the training process (triggering a checkpoint save), then requeues
- `--resume_from_checkpoint=latest`: picks up from the last saved checkpoint after requeue
- Skip-logic: checks if training is already complete before starting
Submit with:
sbatch scripts/train_slurm.sh
Monitor the job:
squeue -u $USER # check queue status
tail -f slurm-*.out # follow training log
sacct -j <jobid> --format=JobID,State,Elapsed,MaxRSS # check resource usage
Multi-GPU SLURM job¶
For multi-GPU DDP training on SLURM, replace the python command with torchrun:
#SBATCH --gres=gpu:4
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=32
torchrun --nproc_per_node=4 scripts/train_controlnet.py \
--config configs/phaseA_production.yaml \
--output_dir=$CKPT_DIR \
--resume_from_checkpoint=latest &
TRAIN_PID=$!
wait $TRAIN_PID
Critical Training Safeguards¶
These settings are non-negotiable. Training will produce garbage without them:
| Safeguard | Setting | Why |
|---|---|---|
| Mixed precision | BF16 only | FP16 overflows on SD1.5 activations |
| VAE | Frozen | Gradient leak corrupts the entire latent space |
| EMA | 0.9999 | Without it, checkpoints have high-frequency artifacts |
| Normalization | GroupNorm | BatchNorm is unstable at batch size 4 |
| LR schedule | Cosine | Constant LR causes late-stage oscillation |
| Grad clipping | max_norm 1.0 | Prevents gradient explosions |
| Resume | `--resume_from_checkpoint=latest` | Preemption restarts from step 0 without this |
| SLURM signal | `--signal=B:USR1@300` | Saves checkpoint before wall-time preemption |
| Phase A loss | L_diffusion only | Perceptual loss against TPS warps penalizes realism |
| TPS warps | Pre-computed | On-the-fly TPS CPU-bottlenecks the GPU |
| ControlNet scale | max 1.2 | Values above 1.2 cause activation saturation |
Common Training Issues and Fixes¶
Loss spikes or NaN¶
- Check that mixed precision is set to `bf16`, not `fp16`. FP16 overflow is the most common cause.
- Verify `max_grad_norm` is set (1.0 is a good default).
- If the gradient watchdog fires (logged as `GradientWatchdog: explosion detected`), the emergency save will trigger automatically. Resume from the last clean checkpoint.
Generated images are blank or noisy¶
- Confirm the VAE is frozen (`vae_frozen: true`). Unfreezing the VAE is the #1 cause of latent space collapse.
- Check that EMA is enabled. Non-EMA checkpoints often have high-frequency noise.
- Inspect training samples at `sample_every` intervals. If early samples (step 1000) show no face structure at all, the conditioning images may be malformed.
Training is slow¶
- Enable gradient checkpointing to trade compute for memory, allowing a larger batch size.
- Pre-compute TPS warps instead of generating them on-the-fly. The CPU-to-GPU data transfer is the bottleneck.
- Use `num_workers: 8` (or more, up to your CPU count) for the dataloader.
- On HPC with Lustre, set file striping: `lfs setstripe -c -1 $DATA_DIR`
SLURM job keeps restarting from step 0¶
- Make sure `--resume_from_checkpoint=latest` is set.
- Verify checkpoint files exist in the output directory: `ls checkpoints/checkpoint-*`
- Check that the SLURM trap is correctly wired. The training process must receive SIGINT (not SIGKILL) to save a checkpoint.
Out of disk space mid-training¶
- Each SD1.5 checkpoint is about 2 GB. A 100K-step run with checkpoints every 10K steps produces ~20 GB of checkpoints.
- Point `output_dir` to a scratch filesystem with enough space.
- Use `python scripts/clean_data.py` to prune old checkpoints, keeping only the latest N.
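If you prefer not to rely on the cleanup script, pruning can be sketched as follows (illustrative helper, assuming the `checkpoint-<step>` directory naming shown above):

```python
import re
import shutil
from pathlib import Path

def prune_checkpoints(ckpt_dir, keep=3):
    """Remove all but the `keep` highest-step checkpoint-<step> directories.

    EMA directories (checkpoint-<step>-ema) are left untouched because the
    pattern only matches names that end in digits."""
    pat = re.compile(r"^checkpoint-(\d+)$")
    found = []
    for p in Path(ckpt_dir).iterdir():
        m = pat.match(p.name)
        if m and p.is_dir():
            found.append((int(m.group(1)), p))
    found.sort()
    for _, p in found[:-keep]:
        shutil.rmtree(p)
    return [p.name for _, p in found[-keep:]]
```

Run it (or an equivalent cron job) between checkpoint intervals so disk usage stays bounded over long runs.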
WandB issues on HPC¶
- Set `WANDB_MODE=offline` before training.
- Sync after the job completes: `wandb sync outputs/*/wandb/latest-run/`
- If WandB is not installed, the training script falls back to console logging automatically.
Phase B identity loss not decreasing¶
- The identity loss uses ArcFace embeddings. Make sure `insightface` and `onnxruntime` are installed: `pip install insightface onnxruntime`.
- If identity similarity starts very low (<0.3), the Phase A checkpoint may not be generating recognizable faces yet. Train Phase A longer.
Next Steps¶
- Evaluation: evaluate your trained checkpoints
- Custom Procedures: add new surgical procedures
- GPU Training Guide: HPC-specific setup details
- Deployment: deploy your trained model