Saya tidak mengerti baris ini:

lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

Tidak ada komentar, jadi apakah ini idiom Python (atau PyTorch?) yang terkenal? Bisakah seseorang menjelaskan apa artinya, atau menunjukkan cara lain yang membuat maksudnya lebih jelas?

lprobs adalah pytorch Tensor, dan dapat berisi semua tipe float ukuran (saya ragu kode ini dimaksudkan untuk mendukung tipe int atau kompleks). Sejauh yang saya tahu, kelas Tensor tidak menimpa fungsi __ne__.

4
Darren Cook 12 Mei 2021, 11:09

1 menjawab

Jawaban Terbaik

Ini adalah kombinasi dari pengindeksan mewah dengan topeng boolean, dan "trik" (meskipun dirancang) untuk memeriksa NaN: x != x berlaku jika x adalah NaN (untuk float, yaitu).

Mereka bisa saja menulis

lprobs[torch.isnan(lprobs)] = torch.tensor(-math.inf).to(lprobs)

Atau, mungkin lebih idiomatis, menggunakan torch.nan_to_num (tapi hati-hati bahwa yang terakhir juga memiliki perilaku khusus terhadap tak terhingga).

Varian yang tidak memperbarui di atas adalah

torch.where(torch.isnan(lprobs), torch.tensor(-math.inf), lprobs)
7
phipsgabler 12 Mei 2021, 11:50