我知道Python列表有一个方法可以返回某个对象的第一个索引:
>>> xs = [1, 2, 3]
>>> xs.index(2)
1
NumPy数组也有类似的东西吗?
我知道Python列表有一个方法可以返回某个对象的第一个索引:
>>> xs = [1, 2, 3]
>>> xs.index(2)
1
NumPy数组也有类似的东西吗?
当前回答
从np.where()中选择第一个元素的替代方法是使用生成器表达式和enumerate,例如:
>>> import numpy as np
>>> x = np.arange(100) # x = array([0, 1, 2, 3, ... 99])
>>> next(i for i, x_i in enumerate(x) if x_i == 2)
2
对于二维数组,可以这样做:
>>> x = np.arange(100).reshape(10,10) # x = array([[0, 1, 2,... 9], [10,..19],])
>>> next((i,j) for i, x_i in enumerate(x)
... for j, x_ij in enumerate(x_i) if x_ij == 2)
(0, 2)
这种方法的优点是,它在找到第一个匹配后停止检查数组的元素,而np。Where检查所有元素是否匹配。如果在数组的前面有匹配,生成器表达式会更快。
其他回答
如果你只需要第一次出现一个值的索引,你可以使用nonzero(或where,在这种情况下相当于相同的东西):
>>> t = array([1, 1, 1, 2, 2, 3, 8, 3, 8, 8])
>>> nonzero(t == 8)
(array([6, 8, 9]),)
>>> nonzero(t == 8)[0][0]
6
如果需要多个值中的每个值的第一个索引,显然可以重复执行上述操作,但有一个技巧可能更快。下面的代码查找每个子序列的第一个元素的下标:
>>> nonzero(r_[1, diff(t)[:-1]])
(array([0, 3, 5, 6, 7, 8]),)
注意,它找到了3s的子序列和8s的子序列的开头:
[1, 1, 1, 2, 2, 3, 8, 3, 8, 8]
这和求每个值的第一次出现有点不同。在你的程序中,你可以使用t的排序版本来得到你想要的:
>>> st = sorted(t)
>>> nonzero(r_[1, diff(st)[:-1]])
(array([0, 3, 5, 7]),)
只是添加一个非常高性能和方便的numba替代np。Ndenumerate来查找第一个索引:
from numba import njit
import numpy as np
@njit
def index(array, item):
for idx, val in np.ndenumerate(array):
if val == item:
return idx
# If no item was found return None, other return types might be a problem due to
# numbas type inference.
这非常快,并且自然地处理多维数组:
>>> arr1 = np.ones((100, 100, 100))
>>> arr1[2, 2, 2] = 2
>>> index(arr1, 2)
(2, 2, 2)
>>> arr2 = np.ones(20)
>>> arr2[5] = 2
>>> index(arr2, 2)
(5,)
这比任何使用np的方法都要快得多(因为它使操作短路)。Where或np. non0。
然而np。Argwhere也可以优雅地处理多维数组(你需要手动将它转换为元组,而且不会短路),但如果没有找到匹配,它就会失败:
>>> tuple(np.argwhere(arr1 == 2)[0])
(2, 2, 2)
>>> tuple(np.argwhere(arr2 == 2)[0])
(5,)
另一个之前没有提到的选项是bisect模块,它也适用于列表,但需要一个预先排序的列表/数组:
import bisect
import numpy as np
z = np.array([104,113,120,122,126,138])
bisect.bisect_left(z, 122)
收益率
3
Bisect还会在您要查找的数字在数组中不存在时返回一个结果,以便将该数字插入正确的位置。
8种方法的比较
TL; diana:
(注:适用于100M元素以下的1d数组)
为了获得最佳性能,请使用index_of__v5 (numba + numpy. 5)。枚举+ for循环;参见下面的代码)。 如果numba不可用: 如果期望在前100k个元素中找到目标值,请使用index_of__v7 (for循环+枚举)。 否则使用index_of__v2/v3/v4 (numpy. exe)。Argmax或numpy。基于flatnonzero)。
由perfplot提供
import numpy as np
from numba import njit
# Based on: numpy.argmax()
# Proposed by: John Haberstroh (https://stackoverflow.com/a/67497472/7204581)
def index_of__v1(arr: np.array, v):
is_v = (arr == v)
return is_v.argmax() if is_v.any() else -1
# Based on: numpy.argmax()
def index_of__v2(arr: np.array, v):
return (arr == v).argmax() if v in arr else -1
# Based on: numpy.flatnonzero()
# Proposed by: 1'' (https://stackoverflow.com/a/42049655/7204581)
def index_of__v3(arr: np.array, v):
idxs = np.flatnonzero(arr == v)
return idxs[0] if len(idxs) > 0 else -1
# Based on: numpy.argmax()
def index_of__v4(arr: np.array, v):
return np.r_[False, (arr == v)].argmax() - 1
# Based on: numba, for loop
# Proposed by: MSeifert (https://stackoverflow.com/a/41578614/7204581)
@njit
def index_of__v5(arr: np.array, v):
for idx, val in np.ndenumerate(arr):
if val == v:
return idx[0]
return -1
# Based on: numpy.ndenumerate(), for loop
def index_of__v6(arr: np.array, v):
return next((idx[0] for idx, val in np.ndenumerate(arr) if val == v), -1)
# Based on: enumerate(), for loop
# Proposed by: Noyer282 (https://stackoverflow.com/a/40426159/7204581)
def index_of__v7(arr: np.array, v):
return next((idx for idx, val in enumerate(arr) if val == v), -1)
# Based on: list.index()
# Proposed by: Hima (https://stackoverflow.com/a/23994923/7204581)
def index_of__v8(arr: np.array, v):
l = list(arr)
try:
return l.index(v)
except ValueError:
return -1
去Colab
对于我的用例,我不能提前对数组排序,因为元素的顺序很重要。这是我的全部numpy实现:
import numpy as np
# The array in question
arr = np.array([1,2,1,2,1,5,5,3,5,9])
# Find all of the present values
vals=np.unique(arr)
# Make all indices up-to and including the desired index positive
cum_sum=np.cumsum(arr==vals.reshape(-1,1),axis=1)
# Add zeros to account for the n-1 shape of diff and the all-positive array of the first index
bl_mask=np.concatenate([np.zeros((cum_sum.shape[0],1)),cum_sum],axis=1)>=1
# The desired indices
idx=np.where(np.diff(bl_mask))[1]
# Show results
print(list(zip(vals,idx)))
>>> [(1, 0), (2, 1), (3, 7), (5, 5), (9, 9)]
我认为它解释了重复值的无序数组。