在numpy数组上映射函数的最有效方法是什么?我目前正在做:
import numpy as np
x = np.array([1, 2, 3, 4, 5])
# Obtain array of square of each element in x
squarer = lambda t: t ** 2
squares = np.array([squarer(xi) for xi in x])
然而,这可能非常低效,因为我在将新数组转换回numpy数组之前,使用列表推导式将其构造为Python列表。我们能做得更好吗?
我已经测试了所有建议的方法加上np。数组(list(map(f, x)))和perfplot(我的一个小项目)。
消息#1:如果可以使用numpy的本机函数,就使用它。
如果你试图向量化的函数已经被向量化了(就像原始文章中的x**2例子),使用它比其他任何方法都快得多(注意对数尺度):
如果你真的需要向量化,用哪个变量并不重要。
代码重现图:
import numpy as np
import perfplot
import math
def f(x):
# return math.sqrt(x)
return np.sqrt(x)
vf = np.vectorize(f)
def array_for(x):
return np.array([f(xi) for xi in x])
def array_map(x):
return np.array(list(map(f, x)))
def fromiter(x):
return np.fromiter((f(xi) for xi in x), x.dtype)
def vectorize(x):
return np.vectorize(f)(x)
def vectorize_without_init(x):
return vf(x)
b = perfplot.bench(
setup=np.random.rand,
n_range=[2 ** k for k in range(20)],
kernels=[
f,
array_for,
array_map,
fromiter,
vectorize,
vectorize_without_init,
],
xlabel="len(x)",
)
b.save("out1.svg")
b.show()
以上所有答案都比较好,但如果您需要使用自定义函数进行映射,并且您有numpy。Ndarray,你需要保留数组的形状。
我只比较了两个,但它将保留ndarray的形状。我使用了包含100万个条目的数组进行比较。这里我使用square函数,它也内置在numpy中,具有很大的性能提升,因为需要一些东西,您可以使用您选择的函数。
import numpy, time
def timeit():
y = numpy.arange(1000000)
now = time.time()
numpy.array([x * x for x in y.reshape(-1)]).reshape(y.shape)
print(time.time() - now)
now = time.time()
numpy.fromiter((x * x for x in y.reshape(-1)), y.dtype).reshape(y.shape)
print(time.time() - now)
now = time.time()
numpy.square(y)
print(time.time() - now)
输出
>>> timeit()
1.162431240081787 # list comprehension and then building numpy array
1.0775556564331055 # from numpy.fromiter
0.002948284149169922 # using inbuilt function
在这里,你可以清楚地看到numpy.fromiter工作得很好,考虑到简单的方法,如果内置函数可用,请使用它。
编辑:原来的答案是误导性的,np。SQRT直接应用于数组,开销很小。
在多维情况下,您希望应用一个内建函数,操作1d数组numpy。Apply_along_axis是一个不错的选择,对于numpy和scipy中更复杂的函数组合也是如此。
先前的误导性陈述:
添加方法:
def along_axis(x):
return np.apply_along_axis(f, 0, x)
perfplot代码给出接近np.sqrt的性能结果。
squares = squarer(x)
数组上的算术运算以元素方式自动应用,高效的c级循环避免了适用于python级循环或理解的所有解释器开销。
您希望应用到NumPy数组elementwise的大多数函数都可以正常工作,尽管有些函数可能需要更改。例如,if不能在元素方面工作。你需要将它们转换为使用numpy.where这样的结构:
def using_if(x):
if x < 5:
return x
else:
return x**2
就变成了
def using_where(x):
return numpy.where(x < 5, x, x**2)