是否有一种简单的方法可以用列表理解来扁平化一个可迭代对象列表,或者如果没有,你们都认为什么是扁平化这样一个浅列表的最好方法,平衡性能和可读性?

我尝试用一个嵌套的列表理解来扁平化这样一个列表,就像这样:

[image for image in menuitem for menuitem in list_of_menuitems]

但是我遇到了NameError的麻烦,因为名字‘menuitem’没有定义。在google和Stack Overflow上搜索之后,我用一个reduce语句得到了想要的结果:

reduce(list.__add__, map(lambda x: list(x), list_of_menuitems))

但是这个方法是相当不可读的,因为我需要调用list(x)因为x是Django QuerySet对象。

结论:

感谢每个为这个问题做出贡献的人。以下是我所学到的一份总结。我也把它变成了一个社区维基,以防其他人想要添加或纠正这些观察。

我原来的reduce语句是多余的,最好这样写:

>>> reduce(list.__add__, (list(mi) for mi in list_of_menuitems))

这是嵌套列表理解的正确语法(聪明的总结dF!)

>>> [image for mi in list_of_menuitems for image in mi]

但这两种方法都不如使用itertools.chain有效:

>>> from itertools import chain
>>> list(chain(*list_of_menuitems))

正如@cdleary所指出的那样,使用chain.from_iterable来避免*操作符魔法可能是更好的风格,就像这样:

>>> chain = itertools.chain.from_iterable([[1,2],[3],[5,89],[],[6]])
>>> print(list(chain))
>>> [1, 2, 3, 5, 89, 6]

当前回答

在Python 2或3中实现这一点的最简单方法是使用pip install morph使用morph库。

代码是:

import morph

list = [[1,2],[3],[5,89],[],[6]]
flattened_list = morph.flatten(list)  # returns [1, 2, 3, 5, 89, 6]

其他回答

@。洛特:你启发我写了一个timeit应用程序。

我认为它也会根据分区的数量(容器列表中的迭代器的数量)而变化——你的评论没有提到这30个项目中有多少个分区。这个图在每次运行中平摊1000个项目,使用不同数量的分区。这些物品均匀地分布在各个分区中。

代码(Python 2.6):

#!/usr/bin/env python2.6

"""Usage: %prog item_count"""

from __future__ import print_function

import collections
import itertools
import operator
from timeit import Timer
import sys

import matplotlib.pyplot as pyplot

def itertools_flatten(iter_lst):
    return list(itertools.chain(*iter_lst))

def itertools_iterable_flatten(iter_iter):
    return list(itertools.chain.from_iterable(iter_iter))

def reduce_flatten(iter_lst):
    return reduce(operator.add, map(list, iter_lst))

def reduce_lambda_flatten(iter_lst):
    return reduce(operator.add, map(lambda x: list(x), [i for i in iter_lst]))

def comprehension_flatten(iter_lst):
    return list(item for iter_ in iter_lst for item in iter_)

METHODS = ['itertools', 'itertools_iterable', 'reduce', 'reduce_lambda',
           'comprehension']

def _time_test_assert(iter_lst):
    """Make sure all methods produce an equivalent value.
    :raise AssertionError: On any non-equivalent value."""
    callables = (globals()[method + '_flatten'] for method in METHODS)
    results = [callable(iter_lst) for callable in callables]
    if not all(result == results[0] for result in results[1:]):
        raise AssertionError

def time_test(partition_count, item_count_per_partition, test_count=10000):
    """Run flatten methods on a list of :param:`partition_count` iterables.
    Normalize results over :param:`test_count` runs.
    :return: Mapping from method to (normalized) microseconds per pass.
    """
    iter_lst = [[dict()] * item_count_per_partition] * partition_count
    print('Partition count:    ', partition_count)
    print('Items per partition:', item_count_per_partition)
    _time_test_assert(iter_lst)
    test_str = 'flatten(%r)' % iter_lst
    result_by_method = {}
    for method in METHODS:
        setup_str = 'from test import %s_flatten as flatten' % method
        t = Timer(test_str, setup_str)
        per_pass = test_count * t.timeit(number=test_count) / test_count
        print('%20s: %.2f usec/pass' % (method, per_pass))
        result_by_method[method] = per_pass
    return result_by_method

if __name__ == '__main__':
    if len(sys.argv) != 2:
        raise ValueError('Need a number of items to flatten')
    item_count = int(sys.argv[1])
    partition_counts = []
    pass_times_by_method = collections.defaultdict(list)
    for partition_count in xrange(1, item_count):
        if item_count % partition_count != 0:
            continue
        items_per_partition = item_count / partition_count
        result_by_method = time_test(partition_count, items_per_partition)
        partition_counts.append(partition_count)
        for method, result in result_by_method.iteritems():
            pass_times_by_method[method].append(result)
    for method, pass_times in pass_times_by_method.iteritems():
        pyplot.plot(partition_counts, pass_times, label=method)
    pyplot.legend()
    pyplot.title('Flattening Comparison for %d Items' % item_count)
    pyplot.xlabel('Number of Partitions')
    pyplot.ylabel('Microseconds')
    pyplot.show()

编辑:决定让它成为社区维基。

注意:METHODS可能应该使用装饰器进行积累,但我认为这样更容易让人们阅读。

如果您只是希望迭代一个扁平的数据结构版本,并且不需要可索引序列,请考虑itertools。连锁店和公司。

>>> list_of_menuitems = [['image00', 'image01'], ['image10'], []]
>>> import itertools
>>> chain = itertools.chain(*list_of_menuitems)
>>> print(list(chain))
['image00', 'image01', 'image10']

它可以在任何可迭代的东西上工作,其中应该包括Django的可迭代QuerySets,它似乎是你在问题中使用的。

编辑:无论如何,这可能和reduce一样好,因为reduce将有相同的开销将项复制到正在扩展的列表中。如果你在最后运行list(Chain), Chain只会引起这种(相同的)开销。

Meta-Edit:实际上,它的开销比问题提出的解决方案要少,因为当您使用临时列表扩展原始列表时,会丢弃您创建的临时列表。

编辑:正如J.F.塞巴斯蒂安说itertools.chain.from_iterable避免解包,你应该使用它来避免*魔术,但timeit应用程序显示可以忽略的性能差异。

Sum (list_of_lists,[])将使它变平。

l = [['image00', 'image01'], ['image10'], []]
print sum(l,[]) # prints ['image00', 'image01', 'image10']

下面是一个使用集合处理多层列表的版本。Iterable:

import collections

def flatten(o, flatten_condition=lambda i: isinstance(i,
               collections.Iterable) and not isinstance(i, str)):
    result = []
    for i in o:
        if flatten_condition(i):
            result.extend(flatten(i, flatten_condition))
        else:
            result.append(i)
    return result

如果你要扁平化一个更复杂的列表与不可迭代元素或深度大于2,你可以使用以下函数:

def flat_list(list_to_flat):
    if not isinstance(list_to_flat, list):
        yield list_to_flat
    else:
        for item in list_to_flat:
            yield from flat_list(item)

它将返回一个生成器对象,您可以使用list()函数将其转换为列表。注意,yield from syntax可以从python3.3开始使用,但是可以使用显式迭代。 例子:

>>> a = [1, [2, 3], [1, [2, 3, [1, [2, 3]]]]]
>>> print(list(flat_list(a)))
[1, 2, 3, 1, 2, 3, 1, 2, 3]