NumPy提出了一种通过np.argmax获取数组最大值索引的方法。

我想要一个类似的东西,但返回N个最大值的索引。

例如,如果我有一个数组[1,3,2,4,5],那么nargmax(array, n=3)将返回对应于元素[5,4,3]的下标[4,3,1]。


当前回答

这将比完整排序更快,这取决于原始数组的大小和选择的大小:

>>> A = np.random.randint(0,10,10)
>>> A
array([5, 1, 5, 5, 2, 3, 2, 4, 1, 0])
>>> B = np.zeros(3, int)
>>> for i in xrange(3):
...     idx = np.argmax(A)
...     B[i]=idx; A[idx]=0 #something smaller than A.min()
...     
>>> B
array([0, 2, 3])

当然,这涉及到对原始数组的篡改。你可以修复(如果需要)通过复制或替换回原始值. ...对你的用例来说,哪个更便宜。

其他回答

Use:

def max_indices(arr, k):
    '''
    Returns the indices of the k first largest elements of arr
    (in descending order in values)
    '''
    assert k <= arr.size, 'k should be smaller or equal to the array size'
    arr_ = arr.astype(float)  # make a copy of arr
    max_idxs = []
    for _ in range(k):
        max_element = np.max(arr_)
        if np.isinf(max_element):
            break
        else:
            idx = np.where(arr_ == max_element)
        max_idxs.append(idx)
        arr_[idx] = -np.inf
    return max_idxs

它也适用于2D数组。例如,

In [0]: A = np.array([[ 0.51845014,  0.72528114],
                     [ 0.88421561,  0.18798661],
                     [ 0.89832036,  0.19448609],
                     [ 0.89832036,  0.19448609]])
In [1]: max_indices(A, 8)
Out[1]:
    [(array([2, 3], dtype=int64), array([0, 0], dtype=int64)),
     (array([1], dtype=int64), array([0], dtype=int64)),
     (array([0], dtype=int64), array([1], dtype=int64)),
     (array([0], dtype=int64), array([0], dtype=int64)),
     (array([2, 3], dtype=int64), array([1, 1], dtype=int64)),
     (array([1], dtype=int64), array([1], dtype=int64))]

In [2]: A[max_indices(A, 8)[0]][0]
Out[2]: array([ 0.89832036])

我认为最省时的方法是手动遍历数组并保持k-size的min-heap,正如其他人所提到的那样。

我还想出了一个蛮力方法:

top_k_index_list = [ ]
for i in range(k):
    top_k_index_list.append(np.argmax(my_array))
    my_array[top_k_index_list[-1]] = -float('inf')

在使用argmax获取其索引后,将最大的元素设置为一个较大的负值。然后argmax的下一次调用将返回第二大的元素。 您可以记录这些元素的原始值,并在需要时恢复它们。

您可以简单地使用字典来查找numpy数组中的前k个值和下标。 例如,如果你想找到前2个最大值和索引

import numpy as np
nums = np.array([0.2, 0.3, 0.25, 0.15, 0.1])


def TopK(x, k):
    a = dict([(i, j) for i, j in enumerate(x)])
    sorted_a = dict(sorted(a.items(), key = lambda kv:kv[1], reverse=True))
    indices = list(sorted_a.keys())[:k]
    values = list(sorted_a.values())[:k]
    return (indices, values)

print(f"Indices: {TopK(nums, k = 2)[0]}")
print(f"Values: {TopK(nums, k = 2)[1]}")


Indices: [1, 2]
Values: [0.3, 0.25]

Use:

from operator import itemgetter
from heapq import nlargest
result = nlargest(N, enumerate(your_list), itemgetter(1))

现在,结果列表将包含N个元组(index, value),其中value是最大的。

比较了编码的便捷性和速度

速度对我的需求很重要,所以我测试了这个问题的三个答案。

根据我的具体情况,对这三个答案中的代码进行了修改。

然后我比较了每种方法的速度。

编码智慧:

NPE的回答是最优雅的,也足够快地满足我的需求。 Fred foo的回答需要最多的重构来满足我的需求,但却是最快的。我选择了这个答案,因为尽管它需要更多的工作,但它并不太糟糕,并且具有显著的速度优势。 Off99555的回答是最优雅的,但也是最慢的。

测试和比较的完整代码

import numpy as np
import time
import random
import sys
from operator import itemgetter
from heapq import nlargest

''' Fake Data Setup '''
a1 = list(range(1000000))
random.shuffle(a1)
a1 = np.array(a1)

''' ################################################ '''
''' NPE's Answer Modified A Bit For My Case '''
t0 = time.time()
indices = np.flip(np.argsort(a1))[:5]
results = []
for index in indices:
    results.append((index, a1[index]))
t1 = time.time()
print("NPE's Answer:")
print(results)
print(t1 - t0)
print()

''' Fred Foos Answer Modified A Bit For My Case'''
t0 = time.time()
indices = np.argpartition(a1, -6)[-5:]
results = []
for index in indices:
    results.append((a1[index], index))
results.sort(reverse=True)
results = [(b, a) for a, b in results]
t1 = time.time()
print("Fred Foo's Answer:")
print(results)
print(t1 - t0)
print()

''' off99555's Answer - No Modification Needed For My Needs '''
t0 = time.time()
result = nlargest(5, enumerate(a1), itemgetter(1))
t1 = time.time()
print("off99555's Answer:")
print(result)
print(t1 - t0)

输出速度报告

肺水肿的回答是:

[(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
0.1349949836730957

Fred Foo的回答:

[(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
0.011161565780639648

off99555的回答是:

[(631934, 999999), (788104, 999998), (413003, 999997), (536514, 999996), (81029, 999995)]
0.439760684967041