Saya sangat pemula dalam hal pembelajaran mesin. Jadi untuk tujuan pembelajaran saya mencoba mengembangkan CNN sederhana untuk mengklasifikasikan bidak catur. Jaring sudah berfungsi dan saya dapat melatihnya tetapi saya memiliki masalah dengan fungsi validasi saya.

Saya tidak dapat membandingkan prediksi saya dengan target_data karena prediksi saya hanya tensor ukuran 13 sedangkan target.data adalah [batch_size]x13. Saya tidak tahu di mana kesalahan saya. Contoh PyTorch hampir semuanya menggunakan fungsi ini untuk membandingkan prediksi dengan data target.

Akan sangat bagus jika ada yang bisa membantu saya di sini.

Anda dapat mencari kode lainnya di sini: https://github.com/ michaelwolz/ChessML/blob/master/train.ipynb

def validate(model, validation_data, criterion):
    model.eval()
    loss = 0
    correct = 0

    for i in range(len(validation_data)):
        data, target = validation_data[i][0], validation_data[i][1]
        target = torch.Tensor(target)

        if torch.cuda.is_available():
            data = data.cuda()
            target = target.cuda()

        out = model(data)

        loss += criterion(out, target).item()

        _, prediction = torch.max(out.data, 1)
        correct += (prediction == target.data).sum().item()

    loss = loss / len(validation_data)
    print("###################################")
    print("Average loss:", loss)
    print("Accuracy:", 100. * correct / len(validation_data))
    print("###################################")

Kesalahan:

<ipython-input-6-6b21e2bfb8a6> in validate(model, validation_data, 

criterion)
     17 
     18         _, prediction = torch.max(out.data, 1)
---> 19         correct += (prediction == target.data).sum().item()
     20 
     21     loss = loss / len(validation_data)

RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 1

Sunting: Label saya terlihat seperti ini:

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

Setiap indeks mewakili satu kelas.
Output dari fungsi torch.max() tampaknya menjadi indeks kelas. Saya tidak mengerti bagaimana saya bisa membandingkan indeks dengan target_label. Maksud saya, saya hanya bisa menulis fungsi yang memeriksa apakah ada 1 pada indeks yang diprediksi tetapi saya pikir kesalahan saya ada di tempat lain.

0
Michael Wolz 7 Maret 2019, 17:59

1 menjawab

Jawaban Terbaik

Cukup jalankan "argmax" pada target juga:

_, target = torch.max(target.data, 1)

Atau lebih baik lagi, pertahankan target sebagai [example_1_class, example_2_class, ...], alih-alih encoding 1-hot.

1
dedObed 7 Maret 2019, 15:54