Diviser les données avec LightningDataModule
Vous devrez suivre la méthode de l'setup dans un LightningDataModule. Un partitionnement adéquat des ensembles de données garantit que le modèle est entraîné sur un sous-ensemble et validé sur un autre, ce qui évite le surapprentissage.
Le module d'extension « dataset » a déjà été pré-importé.
Cet exercice fait partie du cours
Modèles d'IA évolutifs avec PyTorch Lightning
Instructions
- Importez l'
random_splitpour diviser l'ensemble de données en deux parties : formation et validation. - Divisez l'ensemble de données en deux parties : formation (80 %) et validation (20 %) à l'aide de l'
random_split.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
# Import libraries
import lightning.pytorch as pl
from torch.utils.data import ____
class SplitDataModule(pl.LightningDataModule):
def __init__(self):
super().__init__()
self.train_data = None
self.val_data = None
def setup(self, stage=None):
# Split the dataset into training (80%) and validation (20%)
self.____, self.____ = random_split(dataset, [____, ____])