Saya tahu cara menyimpan dan memuat nn.Model, tetapi tidak dapat menemukan cara membuat pos pemeriksaan untuk nn.Parameter. Saya mencoba versi ini, tetapi pengoptimal tidak mengubah nilai nn.Parameter setelah memulihkan.

from torch import nn as nn
import torch
from torch.optim import Adam

alpha = torch.ones(10)
lr = 0.001
alpha = nn.Parameter(alpha)
print(alpha)
alpha_optimizer = Adam([alpha], lr=lr)

for i in range(10):
   alpha_loss = - alpha.mean()
   alpha_optimizer.zero_grad()
   alpha_loss.backward()
   alpha_optimizer.step()
   print(alpha)
path = "./test.pt"
state = dict(alpha_optimizer=alpha_optimizer.state_dict(),
               alpha=alpha)
torch.save(state, path)
checkpoint = torch.load(path)
alpha = checkpoint["alpha"]
alpha_optimizer.load_state_dict(checkpoint["alpha_optimizer"])
for i in range(10):
   alpha_loss = - alpha.mean()
   alpha_optimizer.zero_grad()
   alpha_loss.backward()
   alpha_optimizer.step()
   print(alpha)
3
Andrii Zadaianchuk 12 Mei 2021, 17:30

1 menjawab

Jawaban Terbaik

Masalahnya adalah pengoptimal masih dengan referensi ke alpha lama (periksa id(alpha) vs. id(alpha_optimizer.param_groups[0]["params"][0]) sebelum loop for terakhir), sementara yang baru disetel saat Anda memuatnya dari pos pemeriksaan di alpha = checkpoint["alpha"].

Anda perlu memperbarui parameter pengoptimal sebelum memuat statusnya:

# ....
torch.save(state, path)
checkpoint = torch.load(path)

# here's where the reference of alpha changes, and the source of the problem
alpha = checkpoint["alpha"]

# reset optim
alpha_optimizer = Adam([alpha], lr=lr)
alpha_optimizer.load_state_dict(checkpoint["alpha_optimizer"])

for i in range(10):
   alpha_loss = - alpha.mean()
   alpha_optimizer.zero_grad()
   alpha_loss.backward()
   alpha_optimizer.step()
   print(alpha)
2
Berriel 12 Mei 2021, 17:33