我知道Python列表有一个方法可以返回某个对象的第一个索引:

>>> xs = [1, 2, 3]
>>> xs.index(2)
1

NumPy数组也有类似的东西吗?


当前回答

L.index (x)返回最小的I,使得I是x在列表中第一次出现的索引。

可以放心地假设,Python中的index()函数的实现使它在找到第一个匹配后停止,这将导致最佳的平均性能。

要在NumPy数组中找到第一个匹配后停止的元素,请使用迭代器(ndenumerate)。

In [67]: l=range(100)

In [68]: l.index(2)
Out[68]: 2

NumPy数组:

In [69]: a = np.arange(100)

In [70]: next((idx for idx, val in np.ndenumerate(a) if val==2))
Out[70]: (2L,)

注意,如果没有找到元素,index()和next方法都会返回一个错误。使用next,可以使用第二个参数在未找到元素时返回一个特殊值,例如:

In [77]: next((idx for idx, val in np.ndenumerate(a) if val==400),None)

NumPy中还有其他函数(argmax, where和nonzero)可用于在数组中查找元素,但它们都有一个缺点,即遍历整个数组查找所有出现的元素,因此无法优化以查找第一个元素。还要注意,where和非零返回数组,因此需要选择第一个元素来获取索引。

In [71]: np.argmax(a==2)
Out[71]: 2

In [72]: np.where(a==2)
Out[72]: (array([2], dtype=int64),)

In [73]: np.nonzero(a==2)
Out[73]: (array([2], dtype=int64),)

时间比较

只是检查对于大型数组,当搜索项位于数组的开头时,使用迭代器的解决方案更快(在IPython shell中使用%timeit):

In [285]: a = np.arange(100000)

In [286]: %timeit next((idx for idx, val in np.ndenumerate(a) if val==0))
100000 loops, best of 3: 17.6 µs per loop

In [287]: %timeit np.argmax(a==0)
1000 loops, best of 3: 254 µs per loop

In [288]: %timeit np.where(a==0)[0][0]
1000 loops, best of 3: 314 µs per loop

这是一个开放的NumPy GitHub问题。

参见:Numpy:快速找到第一个值索引

其他回答

对于1D数组,我推荐np。平坦非零(array == value)[0],它等价于np。非零(array == value)[0][0]和np。其中(array == value)[0][0],但避免了对一个单元素元组开箱的丑陋。

要在任何标准上建立索引,你可以这样做:

In [1]: from numpy import *
In [2]: x = arange(125).reshape((5,5,5))
In [3]: y = indices(x.shape)
In [4]: locs = y[:,x >= 120] # put whatever you want in place of x >= 120
In [5]: pts = hsplit(locs, len(locs[0]))
In [6]: for pt in pts:
   .....:         print(', '.join(str(p[0]) for p in pt))
4, 4, 0
4, 4, 1
4, 4, 2
4, 4, 3
4, 4, 4

这里有一个快速函数,它可以做list.index()所做的事情,只是如果没有找到它,它不会引发异常。注意——这在大型数组上可能非常慢。如果你想把它作为一个方法,你也可以把它拼凑到数组上。

def ndindex(ndarray, item):
    if len(ndarray.shape) == 1:
        try:
            return [ndarray.tolist().index(item)]
        except:
            pass
    else:
        for i, subarray in enumerate(ndarray):
            try:
                return [i] + ndindex(subarray, item)
            except:
                pass

In [1]: ndindex(x, 103)
Out[1]: [4, 0, 3]

另一个之前没有提到的选项是bisect模块,它也适用于列表,但需要一个预先排序的列表/数组:

import bisect
import numpy as np
z = np.array([104,113,120,122,126,138])
bisect.bisect_left(z, 122)

收益率

3

Bisect还会在您要查找的数字在数组中不存在时返回一个结果,以便将该数字插入正确的位置。

NumPy中有很多操作可以放在一起来完成这个任务。这将返回等于item的元素的下标:

numpy.nonzero(array - item)

然后你可以取列表的第一个元素来得到一个元素。

8种方法的比较

TL; diana:

(注:适用于100M元素以下的1d数组)

为了获得最佳性能,请使用index_of__v5 (numba + numpy. 5)。枚举+ for循环;参见下面的代码)。 如果numba不可用: 如果期望在前100k个元素中找到目标值,请使用index_of__v7 (for循环+枚举)。 否则使用index_of__v2/v3/v4 (numpy. exe)。Argmax或numpy。基于flatnonzero)。

由perfplot提供

import numpy as np
from numba import njit

# Based on: numpy.argmax()
# Proposed by: John Haberstroh (https://stackoverflow.com/a/67497472/7204581)
def index_of__v1(arr: np.array, v):
    is_v = (arr == v)
    return is_v.argmax() if is_v.any() else -1


# Based on: numpy.argmax()
def index_of__v2(arr: np.array, v):
    return (arr == v).argmax() if v in arr else -1


# Based on: numpy.flatnonzero()
# Proposed by: 1'' (https://stackoverflow.com/a/42049655/7204581)
def index_of__v3(arr: np.array, v):
    idxs = np.flatnonzero(arr == v)
    return idxs[0] if len(idxs) > 0 else -1


# Based on: numpy.argmax()
def index_of__v4(arr: np.array, v):
    return np.r_[False, (arr == v)].argmax() - 1


# Based on: numba, for loop
# Proposed by: MSeifert (https://stackoverflow.com/a/41578614/7204581)
@njit
def index_of__v5(arr: np.array, v):
    for idx, val in np.ndenumerate(arr):
        if val == v:
            return idx[0]
    return -1


# Based on: numpy.ndenumerate(), for loop
def index_of__v6(arr: np.array, v):
    return next((idx[0] for idx, val in np.ndenumerate(arr) if val == v), -1)


# Based on: enumerate(), for loop
# Proposed by: Noyer282 (https://stackoverflow.com/a/40426159/7204581)
def index_of__v7(arr: np.array, v):
    return next((idx for idx, val in enumerate(arr) if val == v), -1)


# Based on: list.index()
# Proposed by: Hima (https://stackoverflow.com/a/23994923/7204581)
def index_of__v8(arr: np.array, v):
    l = list(arr)
    try:
        return l.index(v)
    except ValueError:
        return -1

去Colab