我有一个由多重yield返回的生成器对象。准备调用这个生成器是相当耗时的操作。这就是为什么我想重复使用发电机几次。

y = FunctionWithYield()
for x in y: print(x)
#here must be something to reset 'y'
for x in y: print(x)

当然,我正在考虑将内容复制到简单的列表中。有办法重置我的发电机吗?


请参见:如何提前查看Python生成器中的一个元素?


当前回答

现在可以使用more_itertools。Seekable(第三方工具),允许重置迭代器。

通过> pip Install more_itertools安装

import more_itertools as mit


y = mit.seekable(FunctionWithYield())
for x in y:
    print(x)

y.seek(0)                                              # reset iterator
for x in y:
    print(x)

注意:内存消耗会随着迭代器的增加而增加,所以要警惕大型迭代对象。

其他回答

我想为一个老问题提供一个不同的解决方案

class IterableAdapter:
    def __init__(self, iterator_factory):
        self.iterator_factory = iterator_factory

    def __iter__(self):
        return self.iterator_factory()

squares = IterableAdapter(lambda: (x * x for x in range(5)))

for x in squares: print(x)
for x in squares: print(x)

与list(iterator)相比,这样做的好处是它的空间复杂度是O(1),而list(iterator)是O(n)。缺点是,如果你只能访问迭代器,而不能访问产生迭代器的函数,那么你就不能使用这个方法。例如,这样做似乎是合理的,但它不会起作用。

g = (x * x for x in range(5))

squares = IterableAdapter(lambda: g)

for x in squares: print(x)
for x in squares: print(x)

没有重置迭代器的选项。迭代器通常在遍历next()函数时弹出。唯一的方法是在迭代迭代器对象之前进行备份。下面的检查。

创建包含0到9项的迭代器对象

i=iter(range(10))

遍历将弹出的next()函数

print(next(i))

将迭代器对象转换为list

L=list(i)
print(L)
output: [1, 2, 3, 4, 5, 6, 7, 8, 9]

所以第0项已经跳出来了。此外,当我们将迭代器转换为list时,所有的项都会弹出。

next(L) 

Traceback (most recent call last):
  File "<pyshell#129>", line 1, in <module>
    next(L)
StopIteration

因此,在开始迭代之前,需要将迭代器转换为列表以备备份。 List可以用iter(< List -object>)转换为迭代器

如果希望使用预定义的参数集多次重用此生成器,可以使用functools.partial。

from functools import partial
func_with_yield = partial(FunctionWithYield, arg0, arg1)

for i in range(100):
    for x in func_with_yield():
        print(x)

这将把生成器函数包装到另一个函数中,因此每次调用func_with_yield()时,它都会创建相同的生成器函数。

我不知道你说的昂贵的准备是什么意思,但我猜你确实有

data = ... # Expensive computation
y = FunctionWithYield(data)
for x in y: print(x)
#here must be something to reset 'y'
# this is expensive - data = ... # Expensive computation
# y = FunctionWithYield(data)
for x in y: print(x)

如果是这样的话,为什么不重用数据呢?

我的答案解决了稍微不同的问题:如果初始化生成器的开销很大,生成每个生成的对象的开销也很大。但是我们需要在多个函数中多次使用生成器。为了只调用一次生成器和每个生成的对象,我们可以使用线程并在不同的线程中运行每个消费方法。由于GIL,我们可能无法实现真正的并行,但我们将实现我们的目标。

这种方法在以下情况下做得很好:深度学习模型处理了大量图像。结果是图像上的很多物体都有很多遮罩。每个掩码都会消耗内存。我们有大约10种方法来进行不同的统计和度量,但它们都是一次性拍摄所有图像。所有的图像都装不下内存。方法可以很容易地重写为接受迭代器。

class GeneratorSplitter:
'''
Split a generator object into multiple generators which will be sincronised. Each call to each of the sub generators will cause only one call in the input generator. This way multiple methods on threads can iterate the input generator , and the generator will cycled only once.
'''

def __init__(self, gen):
    self.gen = gen
    self.consumers: List[GeneratorSplitter.InnerGen] = []
    self.thread: threading.Thread = None
    self.value = None
    self.finished = False
    self.exception = None

def GetConsumer(self):
    # Returns a generator object. 
    cons = self.InnerGen(self)
    self.consumers.append(cons)
    return cons

def _Work(self):
    try:
        for d in self.gen:
            for cons in self.consumers:
                cons.consumed.wait()
                cons.consumed.clear()

            self.value = d

            for cons in self.consumers:
                cons.readyToRead.set()

        for cons in self.consumers:
            cons.consumed.wait()

        self.finished = True

        for cons in self.consumers:
            cons.readyToRead.set()
    except Exception as ex:
        self.exception = ex
        for cons in self.consumers:
            cons.readyToRead.set()

def Start(self):
    self.thread = threading.Thread(target=self._Work)
    self.thread.start()

class InnerGen:
    def __init__(self, parent: "GeneratorSplitter"):
        self.parent: "GeneratorSplitter" = parent
        self.readyToRead: threading.Event = threading.Event()
        self.consumed: threading.Event = threading.Event()
        self.consumed.set()

    def __iter__(self):
        return self

    def __next__(self):
        self.readyToRead.wait()
        self.readyToRead.clear()
        if self.parent.finished:
            raise StopIteration()
        if self.parent.exception:
            raise self.parent.exception
        val = self.parent.value
        self.consumed.set()
        return val

Ussage:

genSplitter = GeneratorSplitter(expensiveGenerator)

metrics={}
executor = ThreadPoolExecutor(max_workers=3)
f1 = executor.submit(mean,genSplitter.GetConsumer())
f2 = executor.submit(max,genSplitter.GetConsumer())
f3 = executor.submit(someFancyMetric,genSplitter.GetConsumer())
genSplitter.Start()

metrics.update(f1.result())
metrics.update(f2.result())
metrics.update(f3.result())