itertools --- 建立產生高效率迴圈之疊代器的函式


本模块实现一系列 iterator ,这些迭代器受到APL,Haskell和SML的启发。为了适用于Python,它们都被重新写过。

本模块标准化了一个快速、高效利用内存的核心工具集,这些工具本身或组合都很有用。它们一起形成了“迭代器代数”,这使得在纯Python中有可能创建简洁又高效的专用工具。

例如,SML有一个制表工具: tabulate(f),它可产生一个序列 f(0), f(1), ...。在Python中可以组合 map()count() 实现: map(f, count())

这些工具及其内置对应物也能很好地配合 operator 模块中的快速函数来使用。 例如,乘法运算符可以被映射到两个向量之间执行高效的点积: sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

无穷迭代器:

迭代器

引數

結果

範例

count()

[start[, step]]

start, start+step, start+2*step, ...

count(10) 10 11 12 13 14 ...

cycle()

p

p0, p1, ... plast, p0, p1, ...

cycle('ABCD') A B C D A B C D ...

repeat()

elem [,n]

elem, elem, elem, ... 重复无限次或n次

repeat(10, 3) 10 10 10

根据最短输入序列长度停止的迭代器:

迭代器

引數

結果

範例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, ...

accumulate([1,2,3,4,5]) 1 3 6 10 15

batched()

p, n

(p0, p1, ..., p_n-1), ...

batched('ABCDEFG', n=3) ABC DEF G

chain()

p, q, ...

p0, p1, ... plast, q0, q1, ...

chain('ABC', 'DEF') A B C D E F

chain.from_iterable()

iterable -- 可迭代对象

p0, p1, ... plast, q0, q1, ...

chain.from_iterable(['ABC', 'DEF']) A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), ...

compress('ABCDEF', [1,0,1,0,1,1]) A C E F

dropwhile()

predicate, seq

seq[n], seq[n+1], 从 predicate 未通过时开始

dropwhile(lambda x: x<5, [1,4,6,4,1]) 6 4 1

filterfalse()

predicate, seq

predicate(elem) 未通过的 seq 元素

filterfalse(lambda x: x%2, range(10)) 0 2 4 6 8

groupby()

iterable[, key]

根据key(v)值分组的迭代器

islice()

seq, [start,] stop [, step]

seq[start:stop:step]中的元素

islice('ABCDEFG', 2, None) C D E F G

pairwise()

iterable -- 可迭代对象

(p[0], p[1]), (p[1], p[2])

pairwise('ABCDEFG') AB BC CD DE EF FG

starmap()

func, seq

func(*seq[0]), func(*seq[1]), ...

starmap(pow, [(2,5), (3,2), (10,3)]) 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1], 直到 predicate 未通过

takewhile(lambda x: x<5, [1,4,6,4,1]) 1 4

tee()

it, n

it1, it2, ... itn 将一个迭代器拆分为n个迭代器

zip_longest()

p, q, ...

(p[0], q[0]), (p[1], q[1]), ...

zip_longest('ABCD', 'xy', fillvalue='-') Ax By C- D-

排列组合迭代器:

迭代器

引數

結果

product()

p, q, ... [repeat=1]

笛卡尔积,相当于嵌套的for循环

permutations()

p[, r]

长度r元组,所有可能的排列,无重复元素

combinations()

p, r

长度r元组,有序,无重复元素

combinations_with_replacement()

p, r

长度r元组,有序,元素可重复

例子

結果

product('ABCD', repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations('ABCD', 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations('ABCD', 2)

AB AC AD BC BD CD

combinations_with_replacement('ABCD', 2)

AA AB AC AD BB BC BD CC CD DD

Itertool 函数

下列模块函数均创建并返回迭代器。有些迭代器不限制输出流长度,所以它们只应在能截断输出流的函数或循环中使用。

itertools.accumulate(iterable[, func, *, initial=None])

创建一个迭代器,返回累积汇总值或其他双目运算函数的累积结果值(通过可选的 func 参数指定)。

如果提供了 func,它应当为带有两个参数的函数。 输入 iterable 的元素可以是能被 func 接受为参数的任意类型。 (例如,对于默认的加法运算,元素可以是任何可相加的类型包括 DecimalFraction。)

通常,输出的元素数量与输入的可迭代对象是一致的。 但是,如果提供了关键字参数 initial,则累加会以 initial 值开始,这样输出就比输入的可迭代对象多一个元素。

大致等價於:

def accumulate(iterable, func=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120
    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return
    yield total
    for element in iterator:
        total = func(total, element)
        yield total

The func argument can be set to min() for a running minimum, max() for a running maximum, or operator.mul() for a running product. Amortization tables can be built by accumulating interest and applying payments:

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> account_update = lambda bal, pmt: round(bal * 1.05) + pmt
>>> list(accumulate(repeat(-90, 10), account_update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

参考一个类似函数 functools.reduce() ,它只返回一个最终累积值。

Added in version 3.2.

在 3.3 版的變更: 新增選用的 func 參數。

在 3.8 版的變更: 新增選用的 initial 參數。

itertools.batched(iterable, n)

来自 iterable 的长度为 n 元组形式的批次数据。 最后一个批次可能短于 n

循环处理输入可迭代对象并将数据积累为长度至多为 n 的元组。 输入将被惰性地消耗,能填满一个批次即可。 结果将在批次填满或输入可迭代对象被耗尽时产生:

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致等價於:

def batched(iterable, n):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterable = iter(iterable)
    while batch := tuple(islice(iterable, n)):
        yield batch

Added in version 3.12.

itertools.chain(*iterables)

创建一个迭代器,它首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素。可将多个序列处理为单个序列。大致相当于:

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable
classmethod chain.from_iterable(iterable)

构建类似 chain() 迭代器的另一个选择。从一个单独的可迭代参数中得到链式输入,该参数是延迟计算的。大致相当于:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
itertools.combinations(iterable, r)

返回由输入 iterable 中元素组成长度为 r 的子序列。

The combination tuples are emitted in lexicographic order according to the order of the input iterable. So, if the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. So, if the input elements are unique, there will be no repeated values within each combination.

大致等價於:

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123
    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)

combinations() 的代码可被改写为 permutations() 过滤后的子序列,(相对于元素在输入中的位置)元素不是有序的。

def combinations(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    for indices in permutations(range(n), r):
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)

0 <= r <= n 时,返回项的个数是 n! / r! / (n-r)!;当 r > n 时,返回项个数为0。

itertools.combinations_with_replacement(iterable, r)

返回由输入 iterable 中元素组成的长度为 r 的子序列,允许每个元素可重复出现。

The combination tuples are emitted in lexicographic order according to the order of the input iterable. So, if the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. So, if the input elements are unique, the generated combinations will also be unique.

大致等價於:

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC
    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r
    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

combinations_with_replacement() 的代码可被改写为 production() 过滤后的子序列,(相对于元素在输入中的位置)元素不是有序的。

def combinations_with_replacement(iterable, r):
    pool = tuple(iterable)
    n = len(pool)
    for indices in product(range(n), repeat=r):
        if sorted(indices) == list(indices):
            yield tuple(pool[i] for i in indices)

n > 0 时,返回项个数为 (n+r-1)! / r! / (n-1)!.

Added in version 3.1.

itertools.compress(data, selectors)

Make an iterator that filters elements from data returning only those that have a corresponding element in selectors is true. Stops when either the data or selectors iterables have been exhausted. Roughly equivalent to:

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

Added in version 3.1.

itertools.count(start=0, step=1)

创建一个迭代器,它从 start 值开始,返回均匀间隔的值。常用于 map() 中的实参来生成连续的数据点。此外,还用于 zip() 来添加序列号。大致相当于:

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

当对浮点数计数时,替换为乘法代码有时精度会更好,例如: (start + step * i for i in count())

在 3.1 版的變更: 新增 step 引數並允許非整數引數。

itertools.cycle(iterable)

创建一个迭代器,返回 iterable 中所有元素并保存一个副本。当取完 iterable 中所有元素,返回副本中的所有元素。无限重复。大致相当于:

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...
    saved = []
    for element in iterable:
        yield element
        saved.append(element)
    while saved:
        for element in saved:
            yield element

注意,该函数可能需要相当大的辅助空间(取决于 iterable 的长度)。

itertools.dropwhile(predicate, iterable)

创建一个迭代器,如果 predicate 为true,迭代器丢弃这些元素,然后返回其他元素。注意,迭代器在 predicate 首次为false之前不会产生任何输出,所以可能需要一定长度的启动时间。大致相当于:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8
    iterable = iter(iterable)
    for x in iterable:
        if not predicate(x):
            yield x
            break
    for x in iterable:
        yield x
itertools.filterfalse(predicate, iterable)

创建一个过滤来自 iterable 中元素从而只返回其中 predicate 为假值的元素的迭代器。 如果 predicateNone,则返回为假值的项。 大致等价于:

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
    if predicate is None:
        predicate = bool
    for x in iterable:
        if not predicate(x):
            yield x
itertools.groupby(iterable, key=None)

创建一个迭代器,返回 iterable 中连续的键和组。key 是一个计算元素键值函数。如果未指定或为 Nonekey 缺省为恒等函数(identity function),返回元素不变。一般来说,iterable 需用同一个键值函数预先排序。

groupby() 操作类似于Unix中的 uniq。当每次 key 函数产生的键值改变时,迭代器会分组或生成一个新组(这就是为什么通常需要使用同一个键值函数先对数据进行排序)。这种行为与SQL的GROUP BY操作不同,SQL的操作会忽略输入的顺序将相同键值的元素分在同组中。

返回的组本身也是一个迭代器,它与 groupby() 共享底层的可迭代对象。因为源是共享的,当 groupby() 对象向后迭代时,前一个组将消失。因此如果稍后还需要返回结果,可保存为列表:

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby() 大致等價於:

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

制作一个从可迭代对象返回选定元素的迭代器。 如果 start 为非零值,则会跳过可迭代对象中的部分元素直至到达 start 位置。 在此之后,将连续返回元素除非 step 被设为大于 1 的值而会间隔跳过部分结果。 如果 stopNone,则迭代将持续进行直至迭代器中的元素耗尽;在其他情况下,它将在指定的位置上停止。

如果 startNone,迭代从0开始。如果 stepNone ,步长缺省为1。

与常规的切片不同,islice() 不支持 start, stopstep 为负值。 可被用来从内部结构已被展平的数据中提取相关字段(例如,一个多行报告可以每三行列出一个名称字段)。

大致等價於:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step
itertools.pairwise(iterable)

返回从输入 iterable 中获取的连续重叠对。

输出迭代器中 2 元组的数量将比输入的数量少一个。 如果输入可迭代对象中少于两个值则它将为空。

大致等價於:

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG
    iterator = iter(iterable)
    a = next(iterator, None)
    for b in iterator:
        yield a, b
        a = b

Added in version 3.10.

itertools.permutations(iterable, r=None)

连续返回由 iterable 元素生成长度为 r 的排列。

如果 r 未指定或为 Noner 默认设置为 iterable 的长度,这种情况下,生成所有全长排列。

组合元组是根据输入的 iterable 顺序以词典排序的形式发出的。 因此,如果输入的 iterable 是已排序的,则输出的元组将按排序后的顺序产生。

Elements are treated as unique based on their position, not on their value. So, if the input elements are unique, there will be no repeated values within a permutation.

大致等價於:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return

The code for permutations() can be also expressed as a subsequence of product() filtered to exclude entries with repeated elements (those from the same position in the input pool):

def permutations(iterable, r=None):
    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    for indices in product(range(n), repeat=r):
        if len(set(indices)) == r:
            yield tuple(pool[i] for i in indices)

0 <= r <= n ,返回项个数为 n! / (n-r)! ;当 r > n ,返回项个数为0。

itertools.product(*iterables, repeat=1)

可迭代对象输入的笛卡儿积。

大致相当于生成器表达式中的嵌套循环。例如, product(A, B)((x,y) for x in A for y in B) 返回结果一样。

嵌套循环像里程表那样循环变动,每次迭代时将最右侧的元素向后迭代。这种模式形成了一种字典序,因此如果输入的可迭代对象是已排序的,笛卡尔积元组依次序发出。

要计算可迭代对象自身的笛卡尔积,将可选参数 repeat 设定为要重复的次数。例如,product(A, repeat=4)product(A, A, A, A) 是一样的。

该函数大致相当于下面的代码,只不过实际实现方案不会在内存中创建中间结果。

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111
    pools = [tuple(pool) for pool in iterables] * repeat
    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)

product() 运行之前,它会完全耗尽输入的可迭代对象,在内存中保留值的临时池以生成结果积。 相应地,它只适用于有限的输入。

itertools.repeat(object[, times])

创建一个持续地返回 object 的迭代器。 将会无限期地运行除非指定了 times 参数。

大致等價於:

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

repeat 的一个常见用途是向 mapzip 提供一个常量值的流:

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
itertools.starmap(function, iterable)

创建一个使用从 iterable 获取的参数来运算 function 的迭代器。 当参数形参已经从单个可迭代对象分组为多个元组时(当数据已被 "预 zip" 时)代替 map() 使用。

map()starmap() 之间的区别类似于 function(a,b)function(*c) 之间的差异。 大致相当于:

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
    for args in iterable:
        yield function(*args)
itertools.takewhile(predicate, iterable)

创建一个迭代器,只要 predicate 为真就从可迭代对象中返回元素。大致相当于:

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
    for x in iterable:
        if not predicate(x):
            break
        yield x

Note, the element that first fails the predicate condition is consumed from the input iterator and there is no way to access it. This could be an issue if an application wants to further consume the input iterator after takewhile has been run to exhaustion. To work around this problem, consider using more-iterools before_and_after() instead.

itertools.tee(iterable, n=2)

从一个可迭代对象中返回 n 个独立的迭代器。

大致等價於:

def tee(iterable, n=2):
    iterator = iter(iterable)
    shared_link = [None, None]
    return tuple(_tee(iterator, shared_link) for _ in range(n))

def _tee(iterator, link):
    try:
        while True:
            if link[1] is None:
                link[0] = next(iterator)
                link[1] = [None, None]
            value, link = link
            yield value
    except StopIteration:
        return

一旦 tee() 已被创建,原有的 iterable 就不应在任何其他地方使用;否则,iterable 可能会被向下执行而不通知 tee 对象。

tee 迭代器不是线程安全的。 当同时使用由同一个 tee() 调用所返回的迭代器时可能引发 RuntimeError,即使原本的 iterable 是线程安全的。is threadsafe.

该迭代工具可能需要相当大的辅助存储空间(这取决于要保存多少临时数据)。通常,如果一个迭代器在另一个迭代器开始之前就要使用大部份或全部数据,使用 list() 会比 tee() 更快。

itertools.zip_longest(*iterables, fillvalue=None)

创建一个迭代器,从每个可迭代对象中收集元素。如果可迭代对象的长度未对齐,将根据 fillvalue 填充缺失值。迭代持续到耗光最长的可迭代对象。大致相当于:

def zip_longest(*iterables, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

    iterators = list(map(iter, iterables))
    num_active = len(iterators)
    if not num_active:
        return

    while True:
        values = []
        for i, iterator in enumerate(iterators):
            try:
                value = next(iterator)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

如果其中一个可迭代对象有无限长度,zip_longest() 函数应封装在限制调用次数的场景中(例如 islice()takewhile())。除非指定, fillvalue 默认为 None

itertools 配方

本节将展示如何使用现有的 itertools 作为基础构件来创建扩展的工具集。

这些 itertools 专题的主要目的是教学。 各个专题显示了对单个工具的各种思维方式 — 例如,chain.from_iterable 被关联到展平的概念。 这些专题还给出了有关这些工具的组合方式的想法 — 例如,starmap()repeat() 应当如何一起工作。 这些专题还显示了 itertools 与 operatorcollections 模块以及内置迭代工具如 map(), filter(), reversed()enumerate() 相互配合的使用模式。

这些例程的次要目的是作为一个孵化器使用。 accumulate(), compress()pairwise() 等迭代工具最初就是作为例程引入的。 目前,sliding_window(), iter_index()sieve() 例程正在被测试以确定它们是否堪当大任。

基本上所有这些配方和许许多多其他配方都可以通过 Python Package Index 上的 more-itertools 项目来安装:

python -m pip install more-itertools

许多例程提供了与底层工具集相当的高性能。 更好的内存效率是通过每次只处理一个元素而不是将整个可迭代对象放入内存来保证的。 代码量的精简是通过以 函数式风格 来链接工具来实现的。 运行的早速度是通过选择使用“矢量化”构件来取代会导致较大解释器开销的 for 循环和 生成器 来达成的。

import collections
import contextlib
import functools
import math
import operator
import random

def take(n, iterable):
    "Return first n items of the iterable as a list."
    return list(islice(iterable, n))

def prepend(value, iterable):
    "Prepend a single value in front of an iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    return chain([value], iterable)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def repeatfunc(func, times=None, *args):
    "Repeat calls to func with specified arguments."
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))

def flatten(list_of_lists):
    "Flatten one level of nesting."
    return chain.from_iterable(list_of_lists)

def ncycles(iterable, n):
    "Returns the sequence elements n times."
    return chain.from_iterable(repeat(tuple(iterable), n))

def tail(n, iterable):
    "Return an iterator over the last n items."
    # tail(3, 'ABCDEFG') → E F G
    return iter(collections.deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        collections.deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value."
    return next(islice(iterable, n, None), default)

def quantify(iterable, predicate=bool):
    "Given a predicate that returns True or False, count the True results."
    return sum(map(predicate, iterable))

def first_true(iterable, default=False, predicate=None):
    "Returns the first true value or the *default* if there is no true value."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    return next(filter(predicate, iterable), default)

def all_equal(iterable, key=None):
    "Returns True if all the elements are equal to each other."
    # all_equal('4٤௪౪໔', key=int) → True
    return len(take(2, groupby(iterable, key))) <= 1

def unique_justseen(iterable, key=None):
    "List unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is None:
        return map(operator.itemgetter(0), groupby(iterable))
    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))

def unique_everseen(iterable, key=None):
    "List unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def sliding_window(iterable, n):
    "Collect data into overlapping fixed-length chunks or blocks."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    iterator = iter(iterable)
    window = collections.deque(islice(iterator, n - 1), maxlen=n)
    for x in iterator:
        window.append(x)
        yield tuple(window)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def partition(predicate, iterable):
    """Partition entries into false entries and true entries.

    If *predicate* is slow, consider wrapping it with functools.lru_cache().
    """
    # partition(is_odd, range(10)) → 0 2 4 6 8   and  1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(predicate, t1), filter(predicate, t2)

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(operator.getitem, repeat(seq), slices)

def iter_index(iterable, value, start=0, stop=None):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        iterator = islice(iterable, start, stop)
        for i, element in enumerate(iterator, start):
            if element is value or element == value:
                yield i
    else:
        stop = len(iterable) if stop is None else stop
        i = start
        with contextlib.suppress(ValueError):
            while True:
                yield (i := seq_index(value, i, stop))
                i += 1

def iter_except(func, exception, first=None):
    "Convert a call-until-exception interface to an iterator interface."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with contextlib.suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield func()

下面的例程具有更数学化的风格:

def powerset(iterable):
    "powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) → 1400
    return math.sumprod(*tee(iterable))

def reshape(matrix, cols):
    "Reshape a 2-D matrix to have a given number of columns."
    # reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
    return batched(chain.from_iterable(matrix), cols)

def transpose(matrix):
    "Swap the rows and columns of a 2-D matrix."
    # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
    return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    windowed_signal = sliding_window(padded_signal, n)
    return map(math.sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    factors = zip(repeat(1), map(operator.neg, roots))
    return list(functools.reduce(convolve, factors, [1]))

def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.
    """
    # Evaluate x³ -4x² -17x + 60 at x = 5
    # polynomial_eval([1, -4, -17, 60], x=5) → 0
    n = len(coefficients)
    if not n:
        return type(x)(0)
    powers = map(pow, repeat(x), reversed(range(n)))
    return math.sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

       f(x)  =  x³ -4x² -17x + 60
       f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(operator.mul, coefficients, powers))

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    for prime in sieve(math.isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n

def totient(n):
    "Count of natural numbers up to n that are coprime to n."
    # https://mathworld.wolfram.com/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    for prime in set(factor(n)):
        n -= n // prime
    return n