在三维空间中有两个点
a = (ax, ay, az)
b = (bx, by, bz)
我想计算它们之间的距离:
dist = sqrt((ax-bx)^2 + (ay-by)^2 + (az-bz)^2)
我如何用NumPy做到这一点?我有:
import numpy
a = numpy.array((ax, ay, az))
b = numpy.array((bx, by, bz))
在三维空间中有两个点
a = (ax, ay, az)
b = (bx, by, bz)
我想计算它们之间的距离:
dist = sqrt((ax-bx)^2 + (ay-by)^2 + (az-bz)^2)
我如何用NumPy做到这一点?我有:
import numpy
a = numpy.array((ax, ay, az))
b = numpy.array((bx, by, bz))
当前回答
从Python 3.8开始
从Python 3.8开始,数学模块包含了math.dist()函数。 请看这里https://docs.python.org/3.8/library/math.html#math.dist。
数学。dist (p1, p2) 返回两点p1和p2之间的欧氏距离, 每一个都以坐标序列(或可迭代对象)给出。
import math
print( math.dist( (0,0), (1,1) )) # sqrt(2) -> 1.4142
print( math.dist( (0,0,0), (1,1,1) )) # sqrt(3) -> 1.7321
其他回答
这种解决问题方法的另一个例子:
def dist(x,y):
return numpy.sqrt(numpy.sum((x-y)**2))
a = numpy.array((xa,ya,za))
b = numpy.array((xb,yb,zb))
dist_a_b = dist(a,b)
我想用各种表演笔记来阐述这个简单的答案。Np.linalg.norm可能会做的比你需要的更多:
dist = numpy.linalg.norm(a-b)
首先,这个函数被设计用于处理一个列表并返回所有的值,例如比较pA到点集sP的距离:
sP = set(points)
pA = point
distances = np.linalg.norm(sP - pA, ord=2, axis=1.) # 'distances' is a list
记住几件事:
Python函数调用的开销很大。 [常规]Python不缓存名称查找。
So
def distance(pointA, pointB):
dist = np.linalg.norm(pointA - pointB)
return dist
并不像看上去那么无辜。
>>> dis.dis(distance)
2 0 LOAD_GLOBAL 0 (np)
2 LOAD_ATTR 1 (linalg)
4 LOAD_ATTR 2 (norm)
6 LOAD_FAST 0 (pointA)
8 LOAD_FAST 1 (pointB)
10 BINARY_SUBTRACT
12 CALL_FUNCTION 1
14 STORE_FAST 2 (dist)
3 16 LOAD_FAST 2 (dist)
18 RETURN_VALUE
首先,每次我们调用它时,我们都必须对“np”进行全局查找,对“linalg”进行范围查找,对“norm”进行范围查找,而仅仅调用这个函数的开销就相当于几十条python指令。
最后,我们浪费了两个操作来存储结果并重新加载它以返回…
改进的第一步:使查找更快,跳过存储
def distance(pointA, pointB, _norm=np.linalg.norm):
return _norm(pointA - pointB)
我们得到了更精简的:
>>> dis.dis(distance)
2 0 LOAD_FAST 2 (_norm)
2 LOAD_FAST 0 (pointA)
4 LOAD_FAST 1 (pointB)
6 BINARY_SUBTRACT
8 CALL_FUNCTION 1
10 RETURN_VALUE
不过,函数调用开销仍然需要一些工作。你会想要做基准测试,以确定你自己做数学是否会更好:
def distance(pointA, pointB):
return (
((pointA.x - pointB.x) ** 2) +
((pointA.y - pointB.y) ** 2) +
((pointA.z - pointB.z) ** 2)
) ** 0.5 # fast sqrt
在某些平台上,**0.5比math.sqrt快。你的里程可能会有所不同。
****高级性能说明。
你为什么要计算距离?如果唯一的目的是展示它,
print("The target is %.2fm away" % (distance(a, b)))
沿着。但是如果你在比较距离,进行范围检查等等,我想添加一些有用的性能观察。
让我们以两种情况为例:按距离排序或将列表剔除到满足范围约束的项。
# Ultra naive implementations. Hold onto your hat.
def sort_things_by_distance(origin, things):
return things.sort(key=lambda thing: distance(origin, thing))
def in_range(origin, range, things):
things_in_range = []
for thing in things:
if distance(origin, thing) <= range:
things_in_range.append(thing)
我们需要记住的第一件事是,我们使用毕达哥拉斯来计算距离(dist =根号(x²+ y²+ z²)),所以我们做了很多根号调用。数学101:
dist = root ( x^2 + y^2 + z^2 )
:.
dist^2 = x^2 + y^2 + z^2
and
sq(N) < sq(M) iff M > N
and
sq(N) > sq(M) iff N > M
and
sq(N) = sq(M) iff N == M
简而言之:直到我们真正需要以X为单位的距离,而不是X^2,我们才能消除计算中最难的部分。
# Still naive, but much faster.
def distance_sq(left, right):
""" Returns the square of the distance between left and right. """
return (
((left.x - right.x) ** 2) +
((left.y - right.y) ** 2) +
((left.z - right.z) ** 2)
)
def sort_things_by_distance(origin, things):
return things.sort(key=lambda thing: distance_sq(origin, thing))
def in_range(origin, range, things):
things_in_range = []
# Remember that sqrt(N)**2 == N, so if we square
# range, we don't need to root the distances.
range_sq = range**2
for thing in things:
if distance_sq(origin, thing) <= range_sq:
things_in_range.append(thing)
很好,这两个函数都不再做昂贵的平方根了。这样会快得多,但在进一步讨论之前,请检查自己:为什么sort_things_by_distance两次都需要一个“天真”的免责声明?在最下面回答(*a1)。
我们可以通过将in_range转换为生成器来改进它:
def in_range(origin, range, things):
range_sq = range**2
yield from (thing for thing in things
if distance_sq(origin, thing) <= range_sq)
如果你在做以下事情,这尤其有好处:
if any(in_range(origin, max_dist, things)):
...
但如果你接下来要做的事需要一段距离,
for nearby in in_range(origin, walking_distance, hotdog_stands):
print("%s %.2fm" % (nearby.name, distance(origin, nearby)))
考虑生成元组:
def in_range_with_dist_sq(origin, range, things):
range_sq = range**2
for thing in things:
dist_sq = distance_sq(origin, thing)
if dist_sq <= range_sq: yield (thing, dist_sq)
如果你可能要进行连锁范围检查(“找到X附近和Y的Nm范围内的东西”,因为你不需要再次计算距离),这可能特别有用。
但如果我们在搜索一个很大的列表,我们预计其中有很多不值得考虑呢?
其实有一个很简单的优化:
def in_range_all_the_things(origin, range, things):
range_sq = range**2
for thing in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
这是否有用取决于“事物”的大小。
def in_range_all_the_things(origin, range, things):
range_sq = range**2
if len(things) >= 4096:
for thing in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
elif len(things) > 32:
for things in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2 + (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
else:
... just calculate distance and range-check it ...
同样,考虑生成dist_sq。热狗的例子就变成了:
# Chaining generators
info = in_range_with_dist_sq(origin, walking_distance, hotdog_stands)
info = (stand, dist_sq**0.5 for stand, dist_sq in info)
for stand, dist in info:
print("%s %.2fm" % (stand, dist))
(*a1: sort_things_by_distance的排序键为每一项调用distance_sq,而那个看起来无辜的键是一个lambda,这是必须调用的第二个函数…)
import math
dist = math.hypot(math.hypot(xa-xb, ya-yb), za-zb)
使用numpy.linalg.norm:
dist = numpy.linalg.norm(a-b)
这是因为欧氏距离是l2范数,而numpy.linalg.norm中ord参数的默认值是2。 要了解更多理论,请参阅数据挖掘介绍:
从Python 3.8开始,math模块直接提供dist函数,它返回两点之间的欧几里得距离(以元组或坐标列表的形式给出):
from math import dist
dist((1, 2, 6), (-2, 3, 2)) # 5.0990195135927845
如果你使用列表:
dist([1, 2, 6], [-2, 3, 2]) # 5.0990195135927845