Boucle de formation
Enfin, tous les efforts que vous avez consacrés à la définition des architectures du modèle et des fonctions de perte portent leurs fruits : c'est l'heure de l'entraînement ! Votre mission consiste à mettre en œuvre et à exécuter la boucle d'entraînement GAN. Remarque : une instruction d'break est placée après le premier lot de données afin d'éviter un temps d'exécution trop long.
Les deux optimiseurs, disc_opt et gen_opt, ont été initialisés en tant qu'optimiseurs d'Adam(). Les fonctions permettant de calculer les pertes que vous avez définies précédemment, gen_loss() et disc_loss(), sont à votre disposition. Une déclaration de confidentialité ( dataloader ) est également mise à votre disposition.
Rappelons que :
disc_loss()Les arguments avancés sont les suivants :gen,disc,real,cur_batch_size,z_dim.gen_loss()Les arguments avancés sont les suivants :gen,disc,cur_batch_size,z_dim.
Cet exercice fait partie du cours
Deep learning pour les images avec PyTorch
Instructions
- Calculez la perte du discriminateur à l'aide de la fonction «
disc_loss()» en lui transmettant, dans cet ordre, le générateur, le discriminateur, l'échantillon d'images réelles, la taille du lot actuel et la taille du bruit de «16», puis attribuez le résultat à «d_loss». - Veuillez calculer les gradients à l'aide de l'
d_loss. - Calculez la perte du générateur à l'aide de
gen_loss()en lui transmettant le générateur, le discriminateur, la taille du lot actuel et la taille du bruit de16, dans cet ordre, puis attribuez le résultat àg_loss. - Veuillez calculer les gradients à l'aide de l'
g_loss.
Exercice interactif pratique
Essayez cet exercice en complétant cet exemple de code.
for epoch in range(1):
for real in dataloader:
cur_batch_size = len(real)
disc_opt.zero_grad()
# Calculate discriminator loss
d_loss = ____
# Compute gradients
____
disc_opt.step()
gen_opt.zero_grad()
# Calculate generator loss
g_loss = ____
# Compute generator gradients
____
gen_opt.step()
print(f"Generator loss: {g_loss}")
print(f"Discriminator loss: {d_loss}")
break