PyTorch Lightning in sintesi
PyTorch Lightning è un wrapper leggero che separa la logica del modello dalla logica di training, permettendo di scrivere codice più pulito e modulare. Con pochi cambiamenti è possibile passare da un semplice script a una pipeline distribuita.
La struttura tipica prevede le classi LightningModule e Trainer. Il primo incapsula modello, ottimizzatore e funzioni di loss; il secondo gestisce loop di training, validazione, logging e checkpointing.
Distribuzione multi-GPU e multi-node
L’API Trainer supporta diversi backend: ddp, dp e ddp_spawn. Basta impostare l’attributo accelerator='gpu' e il parametro devices=2 per eseguire il training su due GPU. Per cluster più grandi, PyTorch Lightning si integra con Ray o SLURM senza modifiche al codice modello.
Checkpointing intelligente e profiling
L’opzione enable_checkpointing=True salva automaticamente lo stato del modello nei punti di miglioramento della validazione. Con auto_lr_find=True è possibile eseguire un’analisi automatica dei learning rate, riducendo i tempi di ricerca iperparametri.
Profiling per ottimizzare le prestazioni
L’interfaccia Profiler('simple') fornisce metriche dettagliate su tempo di forward/backward e utilizzo della GPU. Queste informazioni aiutano a identificare colli di bottiglia, come operazioni CUDA non sincronizzate o carichi di memoria elevati.