这可能是另一种选择。技巧在于包装器函数,它返回传递给pool.map的另一个函数。下面的代码读取一个输入数组,对于其中的每个(唯一)元素,返回该元素在数组中出现的次数(即计数)。例如,如果输入是
np.eye(3) = [ [1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
然后零出现6次,一出现3次
import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import cpu_count
def extract_counts(label_array):
labels = np.unique(label_array)
out = extract_counts_helper([label_array], labels)
return out
def extract_counts_helper(args, labels):
n = max(1, cpu_count() - 1)
pool = ThreadPool(n)
results = {}
pool.map(wrapper(args, results), labels)
pool.close()
pool.join()
return results
def wrapper(argsin, results):
def inner_fun(label):
label_array = argsin[0]
counts = get_label_counts(label_array, label)
results[label] = counts
return inner_fun
def get_label_counts(label_array, label):
return sum(label_array.flatten() == label)
if __name__ == "__main__":
img = np.ones([2,2])
out = extract_counts(img)
print('input array: \n', img)
print('label counts: ', out)
print("========")
img = np.eye(3)
out = extract_counts(img)
print('input array: \n', img)
print('label counts: ', out)
print("========")
img = np.random.randint(5, size=(3, 3))
out = extract_counts(img)
print('input array: \n', img)
print('label counts: ', out)
print("========")
你应该得到:
input array:
[[1. 1.]
[1. 1.]]
label counts: {1.0: 4}
========
input array:
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
label counts: {0.0: 6, 1.0: 3}
========
input array:
[[4 4 0]
[2 4 3]
[2 3 1]]
label counts: {0: 1, 1: 1, 2: 2, 3: 2, 4: 3}
========