Saya memiliki array seperti ini:

>>> a = np.arange(60).reshape([3,4,5])
>>> a
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

Dan saya ingin mengambil nilai k teratas di sepanjang salah satu dimensi. Misalnya saya akan memilih k=2 dan sepanjang dimensi tengah.

Saya telah mencoba menggunakan argpartition dan tampaknya melakukan hal yang benar, tetapi saya mengalami kesulitan menggunakan outputnya untuk mengambil nilai dari array asli. Inilah cara saya menggunakan argpartition:

>>> indices = np.argpartition(a, 2, axis=1)
>>> indices
array([[[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

>>> indices[:,-2:,:]
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

Tapi saya tidak bisa mendapatkan nilai dengan mengiris menggunakan indeks ini.

>>> a[:,indices[:,-2:,:],:].shape
(3, 3, 2, 5, 5)

Saya mengharapkan untuk melihat array bentuk (3,2,5) (karena saya mencari top-2 di sepanjang sumbu tengah) yang saya bayangkan terlihat seperti ini:

>>> magic_output
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

Bagaimana cara mengakses nilai menggunakan indeks dari argpartition?

2
Mark McDonald 16 Maret 2017, 14:23

2 jawaban

Jawaban Terbaik

Nah np.argpartition mendapatkan indeks k terkecil. Jadi, untuk mendapatkan indeks k teratas, kita perlu menggunakan array input yang dinegasikan di sepanjang sumbu yang diinginkan. Kemudian, kita perlu menggunakan indeks ini untuk mengindeks ke sumbu itu menggunakan NumPy's advanced-indexing dan mendapatkan output yang diinginkan.

Dengan demikian, implementasinya akan -

k = 2
m,n = a.shape[0], a.shape[2]
idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
out = a[np.arange(m)[:,None,None], idx, np.arange(n)]

Contoh lari -

1) Masukan:

In [180]: a
Out[180]: 
array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])

2) Kode yang diusulkan:

In [206]: k = 2
     ...: m,n = a.shape[0], a.shape[2]
     ...: idx = np.argpartition(-a,k,axis=1)[:,k-1::-1]
     ...: out = a[np.arange(m)[:,None,None], idx, np.arange(n)]
     ...: 

3) Periksa kembali hasil dan keluaran antara:

In [207]: idx
Out[207]: 
array([[[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]],

       [[2, 2, 2, 2, 2],
        [3, 3, 3, 3, 3]]])

In [208]: out
Out[208]: 
array([[[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]],

       [[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]]])
5
Divakar 16 Maret 2017, 12:09

Anda tidak perlu argpartition cukup urutkan array Anda di sepanjang sumbu kedua menggunakan np.sort() dan pilih 2 item terakhir dari sumbu itu:

np.sort(a, 2)[:, -2:, ]

Berikut adalah contoh pada versi acak dari array Anda:

In [15]: np.random.shuffle(a)

In [16]: a
Out[16]: 
array([[[40, 41, 42, 43, 44],
        [45, 46, 47, 48, 49],
        [50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[20, 21, 22, 23, 24],
        [25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]])

In [17]: np.sort(a, 2)[:, -2:, ]
Out[17]: 
array([[[50, 51, 52, 53, 54],
        [55, 56, 57, 58, 59]],

       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19]],

       [[30, 31, 32, 33, 34],
        [35, 36, 37, 38, 39]]])
1
Kasravnd 16 Maret 2017, 11:31