Saya menggunakan pytorch untuk membuat model sederhana seperti VGG16, dan saya telah membebani fungsi forward dalam model saya.

Saya menemukan semua orang cenderung menggunakan model(input) untuk mendapatkan output daripada model.forward(input), dan saya tertarik pada perbedaan di antara mereka. Saya mencoba memasukkan data yang sama, tetapi hasilnya berbeda. Saya bingung.

Saya telah mengeluarkan layer_weight sebelum saya memasukkan data, bobotnya tidak diubah, dan saya tahu ketika kami menggunakan model(input) menggunakan fungsi __call__, dan fungsi ini akan memanggil model.forward.

   vgg = VGG()
   vgg.double()
   for layer in vgg.modules():
      if isinstance(layer,torch.nn.Linear):
         print(layer.weight)
   print("   use model.forward(input)     ")
   result = vgg.forward(array)

   for layer in vgg.modules():
     if isinstance(layer,torch.nn.Linear):
       print(layer.weight) 
   print("   use model(input)     ")
   result_2 = vgg(array)
   print(result)
   print(result_2)

Keluaran:

    Variable containing:1.00000e-02 *
    -0.2931  0.6716 -0.3497 -2.0217 -0.0764  1.2162  1.4983 -1.2881
    [torch.DoubleTensor of size 1x8]

    Variable containing:
    1.00000e-02 *
    0.5302  0.4494 -0.6866 -2.1657 -0.9504  1.0211  0.8308 -1.1665
    [torch.DoubleTensor of size 1x8]
6
KaguyaSan 25 Maret 2019, 16:19

1 menjawab

Jawaban Terbaik

model.forward hanya memanggil operasi penerusan seperti yang Anda sebutkan tetapi __call__ melakukan sedikit tambahan.

Jika Anda menggali kode dari nn.Module Anda akan melihat __call__ pada akhirnya memanggil ke depan tetapi secara internal menangani kait maju atau mundur dan mengelola beberapa status yang diizinkan oleh pytorch. Saat memanggil model sederhana seperti MLP, itu mungkin tidak benar-benar diperlukan tetapi model yang lebih kompleks seperti lapisan normalisasi spektral memiliki kait dan oleh karena itu Anda harus menggunakan tanda tangan model(.) sebanyak mungkin kecuali Anda secara eksplisit hanya ingin memanggil model.forward

Lihat juga Memanggil fungsi penerusan tanpa .forward()

Namun, dalam hal ini, perbedaannya mungkin karena beberapa lapisan putus, Anda harus memanggil vgg.eval() untuk memastikan semua stokastik dalam jaringan dimatikan sebelum membandingkan output.

8
Umang Gupta 25 Maret 2019, 17:54