在Python多处理库中,是否有支持多个参数的pool.map变体?

import multiprocessing

text = "test"

def harvester(text, case):
    X = case[0]
    text + str(X)

if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=6)
    case = RAW_DATASET
    pool.map(harvester(text, case), case, 1)
    pool.close()
    pool.join()

当前回答

有一个叫做pathos的多处理分支(注意:使用GitHub上的版本),它不需要starmap——map函数镜像Python map的API,因此map可以接受多个参数。

使用pathos,您通常也可以在解释器中执行多处理,而不是陷入__main__块。Pathos将在经过一些轻微的更新后发布——主要是转换为Python3.x。

  Python 2.7.5 (default, Sep 30 2013, 20:15:49)
  [GCC 4.2.1 (Apple Inc. build 5566)] on darwin
  Type "help", "copyright", "credits" or "license" for more information.
  >>> def func(a,b):
  ...     print a,b
  ...
  >>>
  >>> from pathos.multiprocessing import ProcessingPool
  >>> pool = ProcessingPool(nodes=4)
  >>> pool.map(func, [1,2,3], [1,1,1])
  1 1
  2 1
  3 1
  [None, None, None]
  >>>
  >>> # also can pickle stuff like lambdas
  >>> result = pool.map(lambda x: x**2, range(10))
  >>> result
  [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
  >>>
  >>> # also does asynchronous map
  >>> result = pool.amap(pow, [1,2,3], [4,5,6])
  >>> result.get()
  [1, 32, 729]
  >>>
  >>> # or can return a map iterator
  >>> result = pool.imap(pow, [1,2,3], [4,5,6])
  >>> result
  <processing.pool.IMapIterator object at 0x110c2ffd0>
  >>> list(result)
  [1, 32, 729]

pathos有几种方法可以让你得到星图的精确行为。

>>> def add(*x):
...   return sum(x)
...
>>> x = [[1,2,3],[4,5,6]]
>>> import pathos
>>> import numpy as np
>>> # use ProcessPool's map and transposing the inputs
>>> pp = pathos.pools.ProcessPool()
>>> pp.map(add, *np.array(x).T)
[6, 15]
>>> # use ProcessPool's map and a lambda to apply the star
>>> pp.map(lambda x: add(*x), x)
[6, 15]
>>> # use a _ProcessPool, which has starmap
>>> _pp = pathos.pools._ProcessPool()
>>> _pp.starmap(add, x)
[6, 15]
>>>

其他回答

pool.map是否有支持多个参数的变体?

Python 3.3包含pool.starmap()方法:

#!/usr/bin/env python3
from functools import partial
from itertools import repeat
from multiprocessing import Pool, freeze_support

def func(a, b):
    return a + b

def main():
    a_args = [1,2,3]
    second_arg = 1
    with Pool() as pool:
        L = pool.starmap(func, [(1, 1), (2, 1), (3, 1)])
        M = pool.starmap(func, zip(a_args, repeat(second_arg)))
        N = pool.map(partial(func, b=second_arg), a_args)
        assert L == M == N

if __name__=="__main__":
    freeze_support()
    main()

对于旧版本:

#!/usr/bin/env python2
import itertools
from multiprocessing import Pool, freeze_support

def func(a, b):
    print a, b

def func_star(a_b):
    """Convert `f([1,2])` to `f(1,2)` call."""
    return func(*a_b)

def main():
    pool = Pool()
    a_args = [1,2,3]
    second_arg = 1
    pool.map(func_star, itertools.izip(a_args, itertools.repeat(second_arg)))

if __name__=="__main__":
    freeze_support()
    main()

输出

1 1
2 1
3 1

注意这里是如何使用itertools.izip()和itertools.crepeat()的。

由于@unsubu提到的错误,您不能在Python 2.6上使用functools.partial()或类似功能,因此应该显式定义简单包装函数func_tar()。另请参阅uptimebox建议的解决方法。

这可能是另一种选择。技巧在于包装器函数,它返回传递给pool.map的另一个函数。下面的代码读取一个输入数组,对于其中的每个(唯一)元素,返回该元素在数组中出现的次数(即计数)。例如,如果输入是

np.eye(3) = [ [1. 0. 0.]
              [0. 1. 0.]
              [0. 0. 1.]]

然后零出现6次,一出现3次

import numpy as np
from multiprocessing.dummy import Pool as ThreadPool
from multiprocessing import cpu_count


def extract_counts(label_array):
    labels = np.unique(label_array)
    out = extract_counts_helper([label_array], labels)
    return out

def extract_counts_helper(args, labels):
    n = max(1, cpu_count() - 1)
    pool = ThreadPool(n)
    results = {}
    pool.map(wrapper(args, results), labels)
    pool.close()
    pool.join()
    return results

def wrapper(argsin, results):
    def inner_fun(label):
        label_array = argsin[0]
        counts = get_label_counts(label_array, label)
        results[label] = counts
    return inner_fun

def get_label_counts(label_array, label):
    return sum(label_array.flatten() == label)

if __name__ == "__main__":
    img = np.ones([2,2])
    out = extract_counts(img)
    print('input array: \n', img)
    print('label counts: ', out)
    print("========")
           
    img = np.eye(3)
    out = extract_counts(img)
    print('input array: \n', img)
    print('label counts: ', out)
    print("========")
    
    img = np.random.randint(5, size=(3, 3))
    out = extract_counts(img)
    print('input array: \n', img)
    print('label counts: ', out)
    print("========")

你应该得到:

input array: 
 [[1. 1.]
 [1. 1.]]
label counts:  {1.0: 4}
========
input array: 
 [[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]
label counts:  {0.0: 6, 1.0: 3}
========
input array: 
 [[4 4 0]
 [2 4 3]
 [2 3 1]]
label counts:  {0: 1, 1: 1, 2: 2, 3: 2, 4: 3}
========

这是我用来将多个参数传递给pool.imap fork中使用的单参数函数的例程的示例:

from multiprocessing import Pool

# Wrapper of the function to map:
class makefun:
    def __init__(self, var2):
        self.var2 = var2
    def fun(self, i):
        var2 = self.var2
        return var1[i] + var2

# Couple of variables for the example:
var1 = [1, 2, 3, 5, 6, 7, 8]
var2 = [9, 10, 11, 12]

# Open the pool:
pool = Pool(processes=2)

# Wrapper loop
for j in range(len(var2)):
    # Obtain the function to map
    pool_fun = makefun(var2[j]).fun

    # Fork loop
    for i, value in enumerate(pool.imap(pool_fun, range(len(var1))), 0):
        print(var1[i], '+' ,var2[j], '=', value)

# Close the pool
pool.close()

另一种方法是将列表列表传递给单参数例程:

import os
from multiprocessing import Pool

def task(args):
    print "PID =", os.getpid(), ", arg1 =", args[0], ", arg2 =", args[1]

pool = Pool()

pool.map(task, [
        [1,2],
        [3,4],
        [5,6],
        [7,8]
    ])

然后可以用自己喜欢的方法构造一个参数列表。

在官方文档中,它只支持一个可迭代的参数。在这种情况下,我喜欢使用apply_async。如果是你,我会:

from multiprocessing import Process, Pool, Manager

text = "test"
def harvester(text, case, q = None):
 X = case[0]
 res = text+ str(X)
 if q:
  q.put(res)
 return res


def block_until(q, results_queue, until_counter=0):
 i = 0
 while i < until_counter:
  results_queue.put(q.get())
  i+=1

if __name__ == '__main__':
 pool = multiprocessing.Pool(processes=6)
 case = RAW_DATASET
 m = Manager()
 q = m.Queue()
 results_queue = m.Queue() # when it completes results will reside in this queue
 blocking_process = Process(block_until, (q, results_queue, len(case)))
 blocking_process.start()
 for c in case:
  try:
   res = pool.apply_async(harvester, (text, case, q = None))
   res.get(timeout=0.1)
  except:
   pass
 blocking_process.join()