NumPy提出了一种通过np.argmax获取数组最大值索引的方法。

我想要一个类似的东西,但返回N个最大值的索引。

例如,如果我有一个数组[1,3,2,4,5],那么nargmax(array, n=3)将返回对应于元素[5,4,3]的下标[4,3,1]。


当前回答

如果你正在处理nan和/或理解np有问题。试试pandas.DataFrame.sort_values。

import numpy as np
import pandas as pd    

a = np.array([9, 4, 4, 3, 3, 9, 0, 4, 6, 0])

df = pd.DataFrame(a, columns=['array'])
max_values = df['array'].sort_values(ascending=False, na_position='last')
ind = max_values[0:3].index.to_list()

这个例子给出了3个最大的非nan值的索引。可能效率很低,但易于阅读和定制。

其他回答

Use:

from operator import itemgetter
from heapq import nlargest
result = nlargest(N, enumerate(your_list), itemgetter(1))

现在,结果列表将包含N个元组(index, value),其中value是最大的。

我能想到的最简单的是:

>>> import numpy as np
>>> arr = np.array([1, 3, 2, 4, 5])
>>> arr.argsort()[-3:][::-1]
array([4, 3, 1])

这涉及到一个完整的数组。我想知道numpy是否提供了一种内置的方法来进行部分排序;到目前为止我还没有找到。

如果这个解决方案太慢(特别是对于小n),那么可能值得考虑用Cython编写一些东西。

对于多维数组,可以使用axis关键字,以便沿着预期的轴应用分区。

# For a 2D array
indices = np.argpartition(arr, -N, axis=1)[:, -N:]

对于抓取物品:

x = arr.shape[0]
arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)

但请注意,这不会返回一个排序的结果。在这种情况下,你可以沿着预期的轴使用np.argsort():

indices = np.argsort(arr, axis=1)[:, -N:]

# Result
x = arr.shape[0]
arr[np.repeat(np.arange(x), N), indices.ravel()].reshape(x, N)

这里有一个例子:

In [42]: a = np.random.randint(0, 20, (10, 10))

In [44]: a
Out[44]:
array([[ 7, 11, 12,  0,  2,  3,  4, 10,  6, 10],
       [16, 16,  4,  3, 18,  5, 10,  4, 14,  9],
       [ 2,  9, 15, 12, 18,  3, 13, 11,  5, 10],
       [14,  0,  9, 11,  1,  4,  9, 19, 18, 12],
       [ 0, 10,  5, 15,  9, 18,  5,  2, 16, 19],
       [14, 19,  3, 11, 13, 11, 13, 11,  1, 14],
       [ 7, 15, 18,  6,  5, 13,  1,  7,  9, 19],
       [11, 17, 11, 16, 14,  3, 16,  1, 12, 19],
       [ 2,  4, 14,  8,  6,  9, 14,  9,  1,  5],
       [ 1, 10, 15,  0,  1,  9, 18,  2,  2, 12]])

In [45]: np.argpartition(a, np.argmin(a, axis=0))[:, 1:] # 1 is because the first item is the minimum one.
Out[45]:
array([[4, 5, 6, 8, 0, 7, 9, 1, 2],
       [2, 7, 5, 9, 6, 8, 1, 0, 4],
       [5, 8, 1, 9, 7, 3, 6, 2, 4],
       [4, 5, 2, 6, 3, 9, 0, 8, 7],
       [7, 2, 6, 4, 1, 3, 8, 5, 9],
       [2, 3, 5, 7, 6, 4, 0, 9, 1],
       [4, 3, 0, 7, 8, 5, 1, 2, 9],
       [5, 2, 0, 8, 4, 6, 3, 1, 9],
       [0, 1, 9, 4, 3, 7, 5, 2, 6],
       [0, 4, 7, 8, 5, 1, 9, 2, 6]])

In [46]: np.argpartition(a, np.argmin(a, axis=0))[:, -3:]
Out[46]:
array([[9, 1, 2],
       [1, 0, 4],
       [6, 2, 4],
       [0, 8, 7],
       [8, 5, 9],
       [0, 9, 1],
       [1, 2, 9],
       [3, 1, 9],
       [5, 2, 6],
       [9, 2, 6]])

In [89]: a[np.repeat(np.arange(x), 3), ind.ravel()].reshape(x, 3)
Out[89]:
array([[10, 11, 12],
       [16, 16, 18],
       [13, 15, 18],
       [14, 18, 19],
       [16, 18, 19],
       [14, 14, 19],
       [15, 18, 19],
       [16, 17, 19],
       [ 9, 14, 14],
       [12, 15, 18]])

比较了编码的便捷性和速度

速度对我的需求很重要,所以我测试了这个问题的三个答案。

根据我的具体情况,对这三个答案中的代码进行了修改。

然后我比较了每种方法的速度。

编码智慧:

NPE的回答是最优雅的,也足够快地满足我的需求。 Fred foo的回答需要最多的重构来满足我的需求,但却是最快的。我选择了这个答案,因为尽管它需要更多的工作,但它并不太糟糕,并且具有显著的速度优势。 Off99555的回答是最优雅的,但也是最慢的。

测试和比较的完整代码

import numpy as np
import time
import random
import sys
from operator import itemgetter
from heapq import nlargest

''' Fake Data Setup '''
a1 = list(range(1000000))
random.shuffle(a1)
a1 = np.array(a1)

''' ################################################ '''
''' NPE's Answer Modified A Bit For My Case '''
t0 = time.time()
indices = np.flip(np.argsort(a1))[:5]
results = []
for index in indices:
    results.append((index, a1[index]))
t1 = time.time()
print("NPE's Answer:")
print(results)
print(t1 - t0)
print()

''' Fred Foos Answer Modified A Bit For My Case'''
t0 = time.time()
indices = np.argpartition(a1, -6)[-5:]
results = []
for index in indices:
    results.append((a1[index], index))
results.sort(reverse=True)
results = [(b, a) for a, b in results]
t1 = time.time()
print("Fred Foo's Answer:")
print(results)
print(t1 - t0)
print()

''' off99555's Answer - No Modification Needed For My Needs '''
t0 = time.time()
result = nlargest(5, enumerate(a1), itemgetter(1))
t1 = time.time()
print("off99555's Answer:")
print(result)
print(t1 - t0)

输出速度报告

肺水肿的回答是:

[(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
0.1349949836730957

Fred Foo的回答:

[(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
0.011161565780639648

off99555的回答是:

[(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
0.439760684967041

如果你碰巧在使用一个多维数组,那么你需要平展和解开索引:

def largest_indices(ary, n):
    """Returns the n largest indices from a numpy array."""
    flat = ary.flatten()
    indices = np.argpartition(flat, -n)[-n:]
    indices = indices[np.argsort(-flat[indices])]
    return np.unravel_index(indices, ary.shape)

例如:

>>> xs = np.sin(np.arange(9)).reshape((3, 3))
>>> xs
array([[ 0.        ,  0.84147098,  0.90929743],
       [ 0.14112001, -0.7568025 , -0.95892427],
       [-0.2794155 ,  0.6569866 ,  0.98935825]])
>>> largest_indices(xs, 3)
(array([2, 0, 0]), array([2, 2, 1]))
>>> xs[largest_indices(xs, 3)]
array([ 0.98935825,  0.90929743,  0.84147098])