itertools --- 建立產生高效率迴圈之疊代器的函式


本模块实现一系列 iterator ,这些迭代器受到APL,Haskell和SML的启发。为了适用于Python,它们都被重新写过。

本模块标准化了一个快速、高效利用内存的核心工具集,这些工具本身或组合都很有用。它们一起形成了“迭代器代数”,这使得在纯Python中有可能创建简洁又高效的专用工具。

例如,SML有一个制表工具: tabulate(f),它可产生一个序列 f(0), f(1), ...。在Python中可以组合 map()count() 实现: map(f, count())

这些工具及其内置对应物也能很好地配合 operator 模块中的快速函数来使用。 例如,乘法运算符可以被映射到两个向量之间执行高效的点积: sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

无穷迭代器:

迭代器

引數

結果

範例

count()

[start[, step]]

start, start+step, start+2*step, ...

count(10) 10 11 12 13 14 ...

cycle()

p

p0, p1, ... plast, p0, p1, ...

cycle('ABCD') A B C D A B C D ...

repeat()

elem [,n]

elem, elem, elem, ... 重复无限次或n次

repeat(10, 3) 10 10 10

根据最短输入序列长度停止的迭代器:

迭代器

引數

結果

範例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, ...

accumulate([1,2,3,4,5]) 1 3 6 10 15

batched()

p, n

(p0, p1, ..., p_n-1), ...

batched('ABCDEFG', n=3) ABC DEF G

chain()

p, q, ...

p0, p1, ... plast, q0, q1, ...

chain('ABC', 'DEF') A B C D E F

chain.from_iterable()

iterable -- 可迭代对象

p0, p1, ... plast, q0, q1, ...

chain.from_iterable(['ABC', 'DEF']) A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), ...

compress('ABCDEF', [1,0,1,0,1,1]) A C E F

dropwhile()

predicate, seq

seq[n], seq[n+1], 从 predicate 未通过时开始

dropwhile(lambda x: x<5, [1,4,6,3,8]) 6 3 8

filterfalse()

predicate, seq

predicate(elem) 未通过的 seq 元素

filterfalse(lambda x: x<5, [1,4,6,3,8]) 6 8

groupby()

iterable[, key]

根据key(v)值分组的迭代器

groupby(['A','B','ABC'], len) (1, A B) (3, ABC)

islice()

seq, [start,] stop [, step]

seq[start:stop:step]中的元素

islice('ABCDEFG', 2, None) C D E F G

pairwise()

iterable -- 可迭代对象

(p[0], p[1]), (p[1], p[2])

pairwise('ABCDEFG') AB BC CD DE EF FG

starmap()

func, seq

func(*seq[0]), func(*seq[1]), ...

starmap(pow, [(2,5), (3,2), (10,3)]) 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1], 直到 predicate 未通过

takewhile(lambda x: x<5, [1,4,6,3,8]) 1 4

tee()

it, n

it1, it2, ... itn 将一个迭代器拆分为n个迭代器

tee('ABC', 2) A B C, A B C

zip_longest()

p, q, ...

(p[0], q[0]), (p[1], q[1]), ...

zip_longest('ABCD', 'xy', fillvalue='-') Ax By C- D-

排列组合迭代器:

迭代器

引數

結果

product()

p, q, ... [repeat=1]

笛卡尔积,相当于嵌套的for循环

permutations()

p[, r]

长度r元组,所有可能的排列,无重复元素

combinations()

p, r

长度r元组,有序,无重复元素

combinations_with_replacement()

p, r

长度r元组,有序,元素可重复

例子

結果

product('ABCD', repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations('ABCD', 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations('ABCD', 2)

AB AC AD BC BD CD

combinations_with_replacement('ABCD', 2)

AA AB AC AD BB BC BD CC CD DD

Itertool 函数

下列模块函数均创建并返回迭代器。有些迭代器不限制输出流长度,所以它们只应在能截断输出流的函数或循环中使用。

itertools.accumulate(iterable[, function, *, initial=None])

Make an iterator that returns accumulated sums or accumulated results from other binary functions.

The function defaults to addition. The function should accept two arguments, an accumulated total and a value from the iterable.

If an initial value is provided, the accumulation will start with that value and the output will have one more element than the input iterable.

大致等價於:

def accumulate(iterable, function=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return

    yield total
    for element in iterator:
        total = function(total, element)
        yield total

The function argument can be set to min() for a running minimum, max() for a running maximum, or operator.mul() for a running product. Amortization tables can be built by accumulating interest and applying payments:

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

参考一个类似函数 functools.reduce() ,它只返回一个最终累积值。

Added in version 3.2.

在 3.3 版的變更: 新增選用的 function 參數。

在 3.8 版的變更: 新增選用的 initial 參數。

itertools.batched(iterable, n, *, strict=False)

来自 iterable 的长度为 n 元组形式的批次数据。 最后一个批次可能短于 n

如果 strict 为真值,将在最终的批次短于 n 时引发 ValueError

循环处理输入可迭代对象并将数据积累为长度至多为 n 的元组。 输入将被惰性地消耗,能填满一个批次即可。 结果将在批次填满或输入可迭代对象被耗尽时产生:

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致等價於:

def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch

Added in version 3.12.

在 3.13 版的變更: 增加了 strict 选项。

itertools.chain(*iterables)

创建一个迭代器,它首先返回第一个可迭代对象中所有元素,接着返回下一个可迭代对象中所有元素,直到耗尽所有可迭代对象中的元素。可将多个序列处理为单个序列。大致相当于:

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable
classmethod chain.from_iterable(iterable)

构建类似 chain() 迭代器的另一个选择。从一个单独的可迭代参数中得到链式输入,该参数是延迟计算的。大致相当于:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
itertools.combinations(iterable, r)

返回由输入 iterable 中元素组成长度为 r 的子序列。

The output is a subsequence of product() keeping only entries that are subsequences of the iterable. The length of the output is given by math.comb() which computes n! / r! / (n - r)! when 0 r n or zero when r > n.

The combination tuples are emitted in lexicographic order according to the order of the input iterable. If the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. If the input elements are unique, there will be no repeated values within each combination.

大致等價於:

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
itertools.combinations_with_replacement(iterable, r)

返回由输入 iterable 中元素组成的长度为 r 的子序列,允许每个元素可重复出现。

The output is a subsequence of product() that keeps only entries that are subsequences (with possible repeated elements) of the iterable. The number of subsequence returned is (n + r - 1)! / r! / (n - 1)! when n > 0.

The combination tuples are emitted in lexicographic order according to the order of the input iterable. if the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. If the input elements are unique, the generated combinations will also be unique.

大致等價於:

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

Added in version 3.1.

itertools.compress(data, selectors)

Make an iterator that returns elements from data where the corresponding element in selectors is true. Stops when either the data or selectors iterables have been exhausted. Roughly equivalent to:

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

Added in version 3.1.

itertools.count(start=0, step=1)

Make an iterator that returns evenly spaced values beginning with start. Can be used with map() to generate consecutive data points or with zip() to add sequence numbers. Roughly equivalent to:

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

当对浮点数计数时,替换为乘法代码有时精度会更好,例如: (start + step * i for i in count())

在 3.1 版的變更: 新增 step 引數並允許非整數引數。

itertools.cycle(iterable)

Make an iterator returning elements from the iterable and saving a copy of each. When the iterable is exhausted, return elements from the saved copy. Repeats indefinitely. Roughly equivalent to:

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...
    saved = []
    for element in iterable:
        yield element
        saved.append(element)
    while saved:
        for element in saved:
            yield element

This itertool may require significant auxiliary storage (depending on the length of the iterable).

itertools.dropwhile(predicate, iterable)

Make an iterator that drops elements from the iterable while the predicate is true and afterwards returns every element. Roughly equivalent to:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

    iterator = iter(iterable)
    for x in iterator:
        if not predicate(x):
            yield x
            break

    for x in iterator:
        yield x

Note this does not produce any output until the predicate first becomes false, so this itertool may have a lengthy start-up time.

itertools.filterfalse(predicate, iterable)

Make an iterator that filters elements from the iterable returning only those for which the predicate returns a false value. If predicate is None, returns the items that are false. Roughly equivalent to:

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
    if predicate is None:
        predicate = bool
    for x in iterable:
        if not predicate(x):
            yield x
itertools.groupby(iterable, key=None)

创建一个迭代器,返回 iterable 中连续的键和组。key 是一个计算元素键值函数。如果未指定或为 Nonekey 缺省为恒等函数(identity function),返回元素不变。一般来说,iterable 需用同一个键值函数预先排序。

groupby() 操作类似于Unix中的 uniq。当每次 key 函数产生的键值改变时,迭代器会分组或生成一个新组(这就是为什么通常需要使用同一个键值函数先对数据进行排序)。这种行为与SQL的GROUP BY操作不同,SQL的操作会忽略输入的顺序将相同键值的元素分在同组中。

返回的组本身也是一个迭代器,它与 groupby() 共享底层的可迭代对象。因为源是共享的,当 groupby() 对象向后迭代时,前一个组将消失。因此如果稍后还需要返回结果,可保存为列表:

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby() 大致等價於:

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

Make an iterator that returns selected elements from the iterable. Works like sequence slicing but does not support negative values for start, stop, or step.

If start is zero or None, iteration starts at zero. Otherwise, elements from the iterable are skipped until start is reached.

If stop is None, iteration continues until the iterator is exhausted, if at all. Otherwise, it stops at the specified position.

If step is None, the step defaults to one. Elements are returned consecutively unless step is set higher than one which results in items being skipped.

大致等價於:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step
itertools.pairwise(iterable)

返回从输入 iterable 中获取的连续重叠对。

输出迭代器中 2 元组的数量将比输入的数量少一个。 如果输入可迭代对象中少于两个值则它将为空。

大致等價於:

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG
    iterator = iter(iterable)
    a = next(iterator, None)
    for b in iterator:
        yield a, b
        a = b

Added in version 3.10.

itertools.permutations(iterable, r=None)

Return successive r length permutations of elements from the iterable.

如果 r 未指定或为 Noner 默认设置为 iterable 的长度,这种情况下,生成所有全长排列。

The output is a subsequence of product() where entries with repeated elements have been filtered out. The length of the output is given by math.perm() which computes n! / (n - r)! when 0 r n or zero when r > n.

The permutation tuples are emitted in lexicographic order according to the order of the input iterable. If the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. If the input elements are unique, there will be no repeated values within a permutation.

大致等價於:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
itertools.product(*iterables, repeat=1)

可迭代对象输入的笛卡儿积。

大致相当于生成器表达式中的嵌套循环。例如, product(A, B)((x,y) for x in A for y in B) 返回结果一样。

嵌套循环像里程表那样循环变动,每次迭代时将最右侧的元素向后迭代。这种模式形成了一种字典序,因此如果输入的可迭代对象是已排序的,笛卡尔积元组依次序发出。

要计算可迭代对象自身的笛卡尔积,将可选参数 repeat 设定为要重复的次数。例如,product(A, repeat=4)product(A, A, A, A) 是一样的。

该函数大致相当于下面的代码,只不过实际实现方案不会在内存中创建中间结果。

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

    pools = [tuple(pool) for pool in iterables] * repeat

    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]

    for prod in result:
        yield tuple(prod)

product() 运行之前,它会完全耗尽输入的可迭代对象,在内存中保留值的临时池以生成结果积。 相应地,它只适用于有限的输入。

itertools.repeat(object[, times])

创建一个持续地返回 object 的迭代器。 将会无限期地运行除非指定了 times 参数。

大致等價於:

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

repeat 的一个常见用途是向 mapzip 提供一个常量值的流:

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
itertools.starmap(function, iterable)

Make an iterator that computes the function using arguments obtained from the iterable. Used instead of map() when argument parameters have already been "pre-zipped" into tuples.

map()starmap() 之间的区别类似于 function(a,b)function(*c) 之间的差异。 大致相当于:

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
    for args in iterable:
        yield function(*args)
itertools.takewhile(predicate, iterable)

Make an iterator that returns elements from the iterable as long as the predicate is true. Roughly equivalent to:

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
    for x in iterable:
        if not predicate(x):
            break
        yield x

Note, the element that first fails the predicate condition is consumed from the input iterator and there is no way to access it. This could be an issue if an application wants to further consume the input iterator after takewhile has been run to exhaustion. To work around this problem, consider using more-iterools before_and_after() instead.

itertools.tee(iterable, n=2)

从一个可迭代对象中返回 n 个独立的迭代器。

大致等價於:

def tee(iterable, n=2):
    iterator = iter(iterable)
    shared_link = [None, None]
    return tuple(_tee(iterator, shared_link) for _ in range(n))

def _tee(iterator, link):
    try:
        while True:
            if link[1] is None:
                link[0] = next(iterator)
                link[1] = [None, None]
            value, link = link
            yield value
    except StopIteration:
        return

一旦 tee() 已被创建,原有的 iterable 就不应在任何其他地方使用;否则,iterable 可能会被向下执行而不通知 tee 对象。

tee 迭代器不是线程安全的。 当同时使用由同一个 tee() 调用所返回的迭代器时可能引发 RuntimeError,即使原本的 iterable 是线程安全的。is threadsafe.

该迭代工具可能需要相当大的辅助存储空间(这取决于要保存多少临时数据)。通常,如果一个迭代器在另一个迭代器开始之前就要使用大部份或全部数据,使用 list() 会比 tee() 更快。

itertools.zip_longest(*iterables, fillvalue=None)

Make an iterator that aggregates elements from each of the iterables.

If the iterables are of uneven length, missing values are filled-in with fillvalue. If not specified, fillvalue defaults to None.

Iteration continues until the longest iterable is exhausted.

大致等價於:

def zip_longest(*iterables, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

    iterators = list(map(iter, iterables))
    num_active = len(iterators)
    if not num_active:
        return

    while True:
        values = []
        for i, iterator in enumerate(iterators):
            try:
                value = next(iterator)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

If one of the iterables is potentially infinite, then the zip_longest() function should be wrapped with something that limits the number of calls (for example islice() or takewhile()).

itertools 配方

本节将展示如何使用现有的 itertools 作为基础构件来创建扩展的工具集。

这些 itertools 专题的主要目的是教学。 各个专题显示了对单个工具的各种思维方式 — 例如,chain.from_iterable 被关联到展平的概念。 这些专题还给出了有关这些工具的组合方式的想法 — 例如,starmap()repeat() 应当如何一起工作。 这些专题还显示了 itertools 与 operatorcollections 模块以及内置迭代工具如 map(), filter(), reversed()enumerate() 相互配合的使用模式。

这些例程的次要目的是作为一个孵化器使用。 accumulate(), compress()pairwise() 等迭代工具最初就是作为例程引入的。 目前,sliding_window(), iter_index()sieve() 例程正在被测试以确定它们是否堪当大任。

基本上所有这些配方和许许多多其他配方都可以通过 Python Package Index 上的 more-itertools 项目来安装:

python -m pip install more-itertools

许多例程提供了与底层工具集相当的高性能。 更好的内存效率是通过每次只处理一个元素而不是将整个可迭代对象放入内存来保证的。 代码量的精简是通过以 函数式风格 来链接工具来实现的。 运行的早速度是通过选择使用“矢量化”构件来取代会导致较大解释器开销的 for 循环和 生成器 来达成的。

import collections
import contextlib
import functools
import math
import operator
import random

def take(n, iterable):
    "Return first n items of the iterable as a list."
    return list(islice(iterable, n))

def prepend(value, iterable):
    "Prepend a single value in front of an iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    return chain([value], iterable)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def repeatfunc(func, times=None, *args):
    "Repeat calls to func with specified arguments."
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))

def flatten(list_of_lists):
    "Flatten one level of nesting."
    return chain.from_iterable(list_of_lists)

def ncycles(iterable, n):
    "Returns the sequence elements n times."
    return chain.from_iterable(repeat(tuple(iterable), n))

def tail(n, iterable):
    "Return an iterator over the last n items."
    # tail(3, 'ABCDEFG') → E F G
    return iter(collections.deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        collections.deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value."
    return next(islice(iterable, n, None), default)

def quantify(iterable, predicate=bool):
    "Given a predicate that returns True or False, count the True results."
    return sum(map(predicate, iterable))

def first_true(iterable, default=False, predicate=None):
    "Returns the first true value or the *default* if there is no true value."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    return next(filter(predicate, iterable), default)

def all_equal(iterable, key=None):
    "Returns True if all the elements are equal to each other."
    # all_equal('4٤௪౪໔', key=int) → True
    return len(take(2, groupby(iterable, key))) <= 1

def unique_justseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is None:
        return map(operator.itemgetter(0), groupby(iterable))
    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))

def unique_everseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def unique(iterable, key=None, reverse=False):
   "Yield unique elements in sorted order. Supports unhashable inputs."
   # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
   return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)

def sliding_window(iterable, n):
    "Collect data into overlapping fixed-length chunks or blocks."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    iterator = iter(iterable)
    window = collections.deque(islice(iterator, n - 1), maxlen=n)
    for x in iterator:
        window.append(x)
        yield tuple(window)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def partition(predicate, iterable):
    """Partition entries into false entries and true entries.

    If *predicate* is slow, consider wrapping it with functools.lru_cache().
    """
    # partition(is_odd, range(10)) → 0 2 4 6 8   and  1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(predicate, t1), filter(predicate, t2)

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(operator.getitem, repeat(seq), slices)

def iter_index(iterable, value, start=0, stop=None):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        iterator = islice(iterable, start, stop)
        for i, element in enumerate(iterator, start):
            if element is value or element == value:
                yield i
    else:
        stop = len(iterable) if stop is None else stop
        i = start
        with contextlib.suppress(ValueError):
            while True:
                yield (i := seq_index(value, i, stop))
                i += 1

def iter_except(func, exception, first=None):
    "Convert a call-until-exception interface to an iterator interface."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with contextlib.suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield func()

下面的例程具有更数学化的风格:

def powerset(iterable):
    "powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) → 1400
    return math.sumprod(*tee(iterable))

def reshape(matrix, cols):
    "Reshape a 2-D matrix to have a given number of columns."
    # reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
    return batched(chain.from_iterable(matrix), cols, strict=True)

def transpose(matrix):
    "Swap the rows and columns of a 2-D matrix."
    # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
    return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    windowed_signal = sliding_window(padded_signal, n)
    return map(math.sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    factors = zip(repeat(1), map(operator.neg, roots))
    return list(functools.reduce(convolve, factors, [1]))

def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.
    """
    # Evaluate x³ -4x² -17x + 60 at x = 5
    # polynomial_eval([1, -4, -17, 60], x=5) → 0
    n = len(coefficients)
    if not n:
        return type(x)(0)
    powers = map(pow, repeat(x), reversed(range(n)))
    return math.sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

       f(x)  =  x³ -4x² -17x + 60
       f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(operator.mul, coefficients, powers))

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    for prime in sieve(math.isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n

def totient(n):
    "Count of natural numbers up to n that are coprime to n."
    # https://mathworld.wolfram.com/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    for prime in set(factor(n)):
        n -= n // prime
    return n