PyTorch Lightning in sintesi

PyTorch Lightning è un wrapper leggero che separa la logica del modello dalla logica di training, permettendo di scrivere codice più pulito e modulare. Con pochi cambiamenti è possibile passare da un semplice script a una pipeline distribuita.

La struttura tipica prevede le classi LightningModule e Trainer. Il primo incapsula modello, ottimizzatore e funzioni di loss; il secondo gestisce loop di training, validazione, logging e checkpointing.

Distribuzione multi-GPU e multi-node

L’API Trainer supporta diversi backend: ddp, dp e ddp_spawn. Basta impostare l’attributo accelerator='gpu' e il parametro devices=2 per eseguire il training su due GPU. Per cluster più grandi, PyTorch Lightning si integra con Ray o SLURM senza modifiche al codice modello.

Checkpointing intelligente e profiling

L’opzione enable_checkpointing=True salva automaticamente lo stato del modello nei punti di miglioramento della validazione. Con auto_lr_find=True è possibile eseguire un’analisi automatica dei learning rate, riducendo i tempi di ricerca iperparametri.

Profiling per ottimizzare le prestazioni

L’interfaccia Profiler('simple') fornisce metriche dettagliate su tempo di forward/backward e utilizzo della GPU. Queste informazioni aiutano a identificare colli di bottiglia, come operazioni CUDA non sincronizzate o carichi di memoria elevati.