如何从numpy数组中删除一些特定的元素?说我有
import numpy as np
a = np.array([1,2,3,4,5,6,7,8,9])
然后我想从a中删除3,4,7。我所知道的是这些值的下标(index=[2,3,6])。
如何从numpy数组中删除一些特定的元素?说我有
import numpy as np
a = np.array([1,2,3,4,5,6,7,8,9])
然后我想从a中删除3,4,7。我所知道的是这些值的下标(index=[2,3,6])。
当前回答
有一个numpy内置函数可以帮助实现这一点。
import numpy as np
>>> a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> b = np.array([3,4,7])
>>> c = np.setdiff1d(a,b)
>>> c
array([1, 2, 5, 6, 8, 9])
其他回答
有一个numpy内置函数可以帮助实现这一点。
import numpy as np
>>> a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> b = np.array([3,4,7])
>>> c = np.setdiff1d(a,b)
>>> c
array([1, 2, 5, 6, 8, 9])
列表理解也是一种有趣的方法。
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
index = np.array([2, 3, 6]) #index is changed to an array.
out = [val for i, val in enumerate(a) if all(i != index)]
>>> [1, 2, 5, 6, 8, 9]
如果我们知道要删除的元素的索引,使用np.delete是最快的方法。但是,为了完整起见,让我添加另一种“删除”数组元素的方法,使用在np.isin的帮助下创建的布尔掩码。该方法允许我们通过直接指定或通过索引来删除元素:
import numpy as np
a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
按指数移除:
indices_to_remove = [2, 3, 6]
a = a[~np.isin(np.arange(a.size), indices_to_remove)]
按元素删除(不要忘记重新创建原来的a,因为它在前一行中被重写了):
elements_to_remove = a[indices_to_remove] # [3, 4, 7]
a = a[~np.isin(a, elements_to_remove)]
你也可以使用集合:
a = numpy.array([10, 20, 30, 40, 50, 60, 70, 80, 90])
the_index_list = [2, 3, 6]
the_big_set = set(numpy.arange(len(a)))
the_small_set = set(the_index_list)
the_delta_row_list = list(the_big_set - the_small_set)
a = a[the_delta_row_list]
我不是一个麻木的人,我试了一下:
>>> import numpy as np
>>> import itertools
>>>
>>> a = np.array([1,2,3,4,5,6,7,8,9])
>>> index=[2,3,6]
>>> a = np.array(list(itertools.compress(a, [i not in index for i in range(len(a))])))
>>> a
array([1, 2, 5, 6, 8, 9])
根据我的测试,这优于numpy.delete()。我不知道为什么会这样,也许是因为初始数组的大小较小?
python -m timeit -s "import numpy as np" -s "import itertools" -s "a = np.array([1,2,3,4,5,6,7,8,9])" -s "index=[2,3,6]" "a = np.array(list(itertools.compress(a, [i not in index for i in range(len(a))])))"
100000 loops, best of 3: 12.9 usec per loop
python -m timeit -s "import numpy as np" -s "a = np.array([1,2,3,4,5,6,7,8,9])" -s "index=[2,3,6]" "np.delete(a, index)"
10000 loops, best of 3: 108 usec per loop
这是一个相当显著的差异(与我预期的方向相反),有人知道为什么会这样吗?
更奇怪的是,传递numpy.delete()一个列表的性能比遍历列表并给它单个索引的性能更差。
python -m timeit -s "import numpy as np" -s "a = np.array([1,2,3,4,5,6,7,8,9])" -s "index=[2,3,6]" "for i in index:" " np.delete(a, i)"
10000 loops, best of 3: 33.8 usec per loop
编辑:这似乎与数组的大小有关。对于大型数组,numpy.delete()要快得多。
python -m timeit -s "import numpy as np" -s "import itertools" -s "a = np.array(list(range(10000)))" -s "index=[i for i in range(10000) if i % 2 == 0]" "a = np.array(list(itertools.compress(a, [i not in index for i in range(len(a))])))"
10 loops, best of 3: 200 msec per loop
python -m timeit -s "import numpy as np" -s "a = np.array(list(range(10000)))" -s "index=[i for i in range(10000) if i % 2 == 0]" "np.delete(a, index)"
1000 loops, best of 3: 1.68 msec per loop
显然,这一切都是相当无关紧要的,因为您应该始终保持清晰,避免重复工作,但我发现它有点有趣,所以我想我就把它留在这里。