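itertools --- Functions creating iterators for efficient looping


This module implements a number of iterator building blocks inspired by constructs from APL, Haskell, and SML. Each has been recast in a form suitable for Python.

The module standardizes a core set of fast, memory-efficient tools that are useful by themselves or in combination. Together, they form an "iterator algebra" that makes it possible to construct specialized tools succinctly and efficiently in pure Python.

For instance, SML provides a tabulation tool: tabulate(f), which produces a sequence f(0), f(1), .... The same effect can be achieved in Python by combining map() and count() to form map(f, count()).

These tools and their built-in counterparts also work well with the high-speed functions in the operator module. For example, the multiplication operator can be mapped across two vectors to form an efficient dot product: sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))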

Infinite iterators:

Iterator  Arguments        Results                                           Example
count()   [start[, step]]  start, start+step, start+2*step, ...              count(10) → 10 11 12 13 14 ...
cycle()   p                p0, p1, ... plast, p0, p1, ...                    cycle('ABCD') → A B C D A B C D ...
repeat()  elem [,n]        elem, elem, elem, ... endlessly or up to n times  repeat(10, 3) → 10 10 10

Iterators terminating on the shortest input sequence:

Iterator               Arguments                    Results                                          Example
accumulate()           p [,func]                    p0, p0+p1, p0+p1+p2, ...                         accumulate([1,2,3,4,5]) → 1 3 6 10 15
batched()              p, n                         (p0, p1, ..., p_n-1), ...                        batched('ABCDEFG', n=3) → ABC DEF G
chain()                p, q, ...                    p0, p1, ... plast, q0, q1, ...                   chain('ABC', 'DEF') → A B C D E F
chain.from_iterable()  iterable                     p0, p1, ... plast, q0, q1, ...                   chain.from_iterable(['ABC', 'DEF']) → A B C D E F
compress()             data, selectors              (d[0] if s[0]), (d[1] if s[1]), ...              compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
dropwhile()            predicate, seq               seq[n], seq[n+1], starting when predicate fails  dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8
filterfalse()          predicate, seq               elements of seq where predicate(elem) fails      filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
groupby()              iterable[, key]              sub-iterators grouped by value of key(v)
islice()               seq, [start,] stop [, step]  elements from seq[start:stop:step]               islice('ABCDEFG', 2, None) → C D E F G
pairwise()             iterable                     (p[0], p[1]), (p[1], p[2])                       pairwise('ABCDEFG') → AB BC CD DE EF FG
starmap()              func, seq                    func(*seq[0]), func(*seq[1]), ...                starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
takewhile()            predicate, seq               seq[0], seq[1], until predicate fails            takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
tee()                  it, n                        it1, it2, ... itn splits one iterator into n
zip_longest()          p, q, ...                    (p[0], q[0]), (p[1], q[1]), ...                  zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

Combinatoric iterators:

Iterator                         Arguments             Results
product()                        p, q, ... [repeat=1]  cartesian product, equivalent to a nested for-loop
permutations()                   p[, r]                r-length tuples, all possible orderings, no repeated elements
combinations()                   p, r                  r-length tuples, in sorted order, no repeated elements
combinations_with_replacement()  p, r                  r-length tuples, in sorted order, with repeated elements

Examples                                   Results
product('ABCD', repeat=2)                  AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD
permutations('ABCD', 2)                    AB AC AD BA BC BD CA CB CD DA DB DC
combinations('ABCD', 2)                    AB AC AD BC BD CD
combinations_with_replacement('ABCD', 2)   AA AB AC AD BB BC BD CC CD DD

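Itertool Functions

The following module functions all construct and return iterators. Some provide streams of infinite length, so they should only be accessed by functions or loops that truncate the stream.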

itertools.accumulate(iterable[, function, *, initial=None])

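Make an iterator that returns accumulated sums, or accumulated results from other binary (two-argument) functions.

The function defaults to addition. The function should accept two arguments, an accumulated total and a value from the iterable.

If an initial value is provided, the accumulation will start with that value and the output will have one more element than the input iterable.

Roughly equivalent to: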

def accumulate(iterable, function=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return

    yield total
    for element in iterator:
        total = function(total, element)
        yield total

The function argument can be set to min() for a running minimum, max() for a running maximum, or operator.mul() for a running product. Amortization tables can be built by accumulating interest and applying payments:

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

See functools.reduce() for a similar function that returns only the final accumulated value.
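The relationship between the two functions, and the effect of initial, can be seen in a short sketch (the sample data is arbitrary):

```python
import functools
import operator
from itertools import accumulate

data = [3, 4, 6, 2, 1]

running = list(accumulate(data, operator.mul))   # every intermediate product
final = functools.reduce(operator.mul, data)     # only the last one
seeded = list(accumulate(data, initial=100))     # one more element than the input

print(running)  # [3, 12, 72, 144, 144]
print(final)    # 144
print(seeded)   # [100, 103, 107, 113, 115, 116]
```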

Added in version 3.2.

Changed in version 3.3: Added the optional function parameter.

Changed in version 3.8: Added the optional initial parameter.

itertools.batched(iterable, n)

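Batch data from the iterable into tuples of length n. The last batch may be shorter than n.

Loops over the input iterable and accumulates data into tuples up to size n. The input is consumed lazily, just enough to fill a batch. The result is yielded as soon as the batch is full or when the input iterable is exhausted: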

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

Roughly equivalent to:

def batched(iterable, n):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        yield batch

Added in version 3.12.

itertools.chain(*iterables)

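Make an iterator that returns elements from the first iterable until it is exhausted, then proceeds to the next iterable, until all of the iterables are exhausted. Used for treating consecutive sequences as a single sequence. Roughly equivalent to: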

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable

classmethod chain.from_iterable(iterable)

Alternate constructor for chain(). Gets chained inputs from a single iterable argument that is evaluated lazily. Roughly equivalent to:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
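
The lazy evaluation matters when the outer iterable is itself unbounded. A minimal sketch, in which `rows` is a hypothetical generator invented for illustration:

```python
from itertools import chain, count, islice

def rows():
    # Hypothetical infinite source of small iterables: (0, 1), (1, 2), (2, 3), ...
    for i in count():
        yield (i, i + 1)

# chain(*rows()) would never return, because argument unpacking must
# exhaust rows() first. chain.from_iterable() pulls one inner iterable at a time.
flat = chain.from_iterable(rows())
first_six = list(islice(flat, 6))

print(first_six)  # [0, 1, 1, 2, 2, 3]
```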
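
itertools.combinations(iterable, r)

Return r length subsequences of elements from the input iterable.

The output is a subsequence of product() keeping only entries that are subsequences of the iterable. The length of the output is given by math.comb(), which computes n! / r! / (n - r)! when 0 ≤ r ≤ n or zero when r > n.

The combination tuples are emitted in lexicographic order according to the order of the input iterable. If the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. If the input elements are unique, there will be no repeated values within each combination.

Roughly equivalent to: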

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
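
The length formula and the position-based notion of uniqueness can be checked with a short sketch (sample inputs are arbitrary):

```python
import math
from itertools import combinations

pool = 'ABCD'
pairs = list(combinations(pool, 2))
too_long = list(combinations(pool, 5))   # r > n produces nothing

# Uniqueness is positional: the two 'A's count as distinct elements
dupes = list(combinations('AAB', 2))

print(len(pairs), math.comb(len(pool), 2))  # 6 6
print(too_long)                             # []
print(dupes)  # [('A', 'A'), ('A', 'B'), ('A', 'B')]
```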
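
itertools.combinations_with_replacement(iterable, r)

Return r length subsequences of elements from the input iterable, allowing individual elements to be repeated more than once.

The output is a subsequence of product() that keeps only entries that are subsequences (with possible repeated elements) of the iterable. The number of subsequences returned is (n + r - 1)! / r! / (n - 1)! when n > 0.

The combination tuples are emitted in lexicographic order according to the order of the input iterable. If the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. If the input elements are unique, the generated combinations will also be unique.

Roughly equivalent to: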

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

Added in version 3.1.
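As a quick sanity check of the count formula for combinations_with_replacement(), the following sketch compares it with the actual output length for one arbitrary input:

```python
import math
from itertools import combinations_with_replacement

pool = 'ABC'
r = 2
n = len(pool)

result = list(combinations_with_replacement(pool, r))
# (n + r - 1)! / r! / (n - 1)!
expected_count = math.factorial(n + r - 1) // math.factorial(r) // math.factorial(n - 1)

print(result)
# [('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'B'), ('B', 'C'), ('C', 'C')]
print(len(result), expected_count)  # 6 6
```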

itertools.compress(data, selectors)

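Make an iterator that returns elements from data where the corresponding element in selectors is true. Stops when either the data or selectors iterables have been exhausted. Roughly equivalent to: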

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

Added in version 3.1.

itertools.count(start=0, step=1)

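Make an iterator that returns evenly spaced values beginning with start. Can be used with map() to generate consecutive data points or with zip() to add sequence numbers. Roughly equivalent to: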

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

When counting with floating-point numbers, better accuracy can sometimes be achieved by substituting multiplicative code such as: (start + step * i for i in count()).
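A small sketch of the difference, using a step of 0.1 (which has no exact binary representation) as an assumed example. Repeated addition accumulates rounding error, whereas the multiplicative form rounds only once per value:

```python
from itertools import count, islice

# Additive counting: each value is the previous value plus the step
additive = list(islice(count(0.0, 0.1), 11))

# Multiplicative counting: each value is computed directly from its index
multiplicative = list(islice((0.0 + 0.1 * i for i in count()), 11))

print(additive[10])        # 0.9999999999999999
print(multiplicative[10])  # 1.0
```

The multiplicative form is not exact for every index either (0.1 * 3 is still 0.30000000000000004), but its error does not grow with the number of steps.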

Changed in version 3.1: Added step argument and allowed non-integer arguments.

itertools.cycle(iterable)

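Make an iterator returning elements from the iterable and saving a copy of each. When the iterable is exhausted, return elements from the saved copy. Repeats indefinitely. Roughly equivalent to: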

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...
    saved = []
    for element in iterable:
        yield element
        saved.append(element)
    while saved:
        for element in saved:
            yield element

Note, this member of the toolkit may require significant auxiliary storage (depending on the length of the iterable).

itertools.dropwhile(predicate, iterable)

Make an iterator that drops elements from the iterable while the predicate is true and afterwards returns every element. Roughly equivalent to:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

    iterator = iter(iterable)
    for x in iterator:
        if not predicate(x):
            yield x
            break

    for x in iterator:
        yield x

Note that this does not produce any output until the predicate first becomes false, so this itertool may have a lengthy start-up time.

itertools.filterfalse(predicate, iterable)

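Make an iterator that filters elements from the iterable, returning only those for which the predicate returns a false value. If predicate is None, returns the items that are false. Roughly equivalent to: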

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
    if predicate is None:
        predicate = bool
    for x in iterable:
        if not predicate(x):
            yield x
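
itertools.groupby(iterable, key=None)

Make an iterator that returns consecutive keys and groups from the iterable. The key is a function computing a key value for each element. If not specified or is None, key defaults to an identity function and returns the element unchanged. Generally, the iterable needs to already be sorted on the same key function.

The operation of groupby() is similar to the uniq filter in Unix. It generates a break or new group every time the value of the key function changes (which is why it is usually necessary to have sorted the data using the same key function). That behavior differs from SQL's GROUP BY, which aggregates common elements regardless of their input order.

The returned group is itself an iterator that shares the underlying iterable with groupby(). Because the source is shared, when the groupby() object is advanced, the previous group is no longer visible. So, if that data is needed later, it should be stored as a list: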

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby() is roughly equivalent to:

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass

itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

Make an iterator that returns selected elements from the iterable. Works like sequence slicing but does not support negative values for start, stop, or step.

If start is zero or None, iteration starts at zero. Otherwise, elements from the iterable are skipped until start is reached.

If stop is None, iteration continues until the iterator is exhausted, if at all. Otherwise, it stops at the specified position.

If step is None, the step defaults to one. Elements are returned consecutively unless step is set higher than one which results in items being skipped.

Roughly equivalent to:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step
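
Because islice() works on any iterable, it can slice sources that do not support indexing at all. A minimal sketch with an assumed infinite generator, also showing that negative arguments are rejected:

```python
from itertools import count, islice

# An infinite generator cannot be sliced with evens[2:10:3]
evens = (n for n in count() if n % 2 == 0)
picked = list(islice(evens, 2, 10, 3))   # positions 2, 5, 8

# Negative values for start, stop, or step raise ValueError
try:
    islice(range(10), -1)
    rejected = False
except ValueError:
    rejected = True

print(picked)    # [4, 10, 16]
print(rejected)  # True
```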

itertools.pairwise(iterable)

Return successive overlapping pairs taken from the input iterable.

The number of 2-tuples in the output iterator will be one fewer than the number of inputs. It will be empty if the input iterable has fewer than two values.

Roughly equivalent to:

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG
    iterator = iter(iterable)
    a = next(iterator, None)
    for b in iterator:
        yield a, b
        a = b

Added in version 3.10.

itertools.permutations(iterable, r=None)

Return successive r length permutations of elements from the iterable.

If r is not specified or is None, then r defaults to the length of the iterable and all possible full-length permutations are generated.

The output is a subsequence of product() where entries with repeated elements have been filtered out. The length of the output is given by math.perm(), which computes n! / (n - r)! when 0 ≤ r ≤ n or zero when r > n.

The permutation tuples are emitted in lexicographic order according to the order of the input iterable. If the input iterable is sorted, the output tuples will be produced in sorted order.

Elements are treated as unique based on their position, not on their value. If the input elements are unique, there will be no repeated values within a permutation.

Roughly equivalent to:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
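
The output length and the position-based treatment of duplicates can be checked with a brief sketch (inputs are arbitrary):

```python
import math
from itertools import permutations

perms = list(permutations('ABC', 2))

# Elements are unique by position, so duplicate values yield equal tuples
all_perms = list(permutations('AAB'))
distinct = sorted(set(all_perms))

print(len(perms), math.perm(3, 2))  # 6 6
print(len(all_perms))               # 6
print(distinct)
# [('A', 'A', 'B'), ('A', 'B', 'A'), ('B', 'A', 'A')]
```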

itertools.product(*iterables, repeat=1)

Cartesian product of input iterables.

Roughly equivalent to nested for-loops in a generator expression. For example, product(A, B) returns the same as ((x,y) for x in A for y in B).

The nested loops cycle like an odometer with the rightmost element advancing on every iteration. This pattern creates a lexicographic ordering so that if the input's iterables are sorted, the product tuples are emitted in sorted order.

To compute the product of an iterable with itself, specify the number of repetitions with the optional repeat keyword argument. For example, product(A, repeat=4) means the same as product(A, A, A, A).

This function is roughly equivalent to the following code, except that the actual implementation does not build up intermediate results in memory:

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

    pools = [tuple(pool) for pool in iterables] * repeat

    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]

    for prod in result:
        yield tuple(prod)

Before product() runs, it completely consumes the input iterables, keeping pools of values in memory to generate the products. Accordingly, it is only useful with finite inputs.
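The nested-loop equivalence, the odometer ordering, and the repeat argument can be confirmed with a short sketch; `A` and `B` are sample inputs chosen for illustration:

```python
from itertools import product

A = 'ab'
B = [1, 2]

pairs = list(product(A, B))
nested = [(x, y) for x in A for y in B]    # same result as product(A, B)

cubes = list(product(A, repeat=3))
explicit = list(product(A, A, A))          # same result as repeat=3

print(pairs)       # [('a', 1), ('a', 2), ('b', 1), ('b', 2)]
print(len(cubes))  # 8
```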

itertools.repeat(object[, times])

Make an iterator that returns object over and over again. Runs indefinitely unless the times argument is specified.

Roughly equivalent to:

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

A common use for repeat is to supply a stream of constant values to map or zip:

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

itertools.starmap(function, iterable)

Make an iterator that computes the function using arguments obtained from the iterable. Used instead of map() when argument parameters have already been "pre-zipped" into tuples.

The difference between map() and starmap() parallels the distinction between function(a,b) and function(*c). Roughly equivalent to:

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
    for args in iterable:
        yield function(*args)

itertools.takewhile(predicate, iterable)

Make an iterator that returns elements from the iterable as long as the predicate is true. Roughly equivalent to:

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
    for x in iterable:
        if not predicate(x):
            break
        yield x

Note, the element that first fails the predicate condition is consumed from the input iterator and there is no way to access it. This could be an issue if an application wants to further consume the input iterator after takewhile has been run to exhaustion. To work around this problem, consider using more-itertools before_and_after() instead.
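The lost element is easy to observe when the input is a shared iterator. A minimal sketch:

```python
from itertools import takewhile

it = iter([1, 4, 6, 3, 8])
head = list(takewhile(lambda x: x < 5, it))
rest = list(it)   # continue consuming the same iterator

print(head)  # [1, 4]
print(rest)  # [3, 8]   (6 was consumed by takewhile and discarded)
```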

itertools.tee(iterable, n=2)

Return n independent iterators from a single iterable.

Roughly equivalent to:

def tee(iterable, n=2):
    iterator = iter(iterable)
    shared_link = [None, None]
    return tuple(_tee(iterator, shared_link) for _ in range(n))

def _tee(iterator, link):
    try:
        while True:
            if link[1] is None:
                link[0] = next(iterator)
                link[1] = [None, None]
            value, link = link
            yield value
    except StopIteration:
        return

Once a tee() has been created, the original iterable should not be used anywhere else; otherwise, the iterable could get advanced without the tee objects being informed.

tee iterators are not threadsafe. A RuntimeError may be raised when simultaneously using iterators returned by the same tee() call, even if the original iterable is threadsafe.

This itertool may require significant auxiliary storage (depending on how much temporary data needs to be stored). In general, if one iterator uses most or all of the data before another iterator starts, it is faster to use list() instead of tee().
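A short sketch of how the returned iterators advance independently of one another:

```python
from itertools import tee

source = iter(range(5))
a, b = tee(source)        # from here on, use a and b only, never source

first_two = [next(a), next(a)]
all_of_b = list(b)        # b is unaffected by the progress of a
rest_of_a = list(a)

print(first_two)  # [0, 1]
print(all_of_b)   # [0, 1, 2, 3, 4]
print(rest_of_a)  # [2, 3, 4]
```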

itertools.zip_longest(*iterables, fillvalue=None)

Make an iterator that aggregates elements from each of the iterables.

If the iterables are of uneven length, missing values are filled-in with fillvalue. If not specified, fillvalue defaults to None.

Iteration continues until the longest iterable is exhausted.

Roughly equivalent to:

def zip_longest(*iterables, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

    iterators = list(map(iter, iterables))
    num_active = len(iterators)
    if not num_active:
        return

    while True:
        values = []
        for i, iterator in enumerate(iterators):
            try:
                value = next(iterator)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

If one of the iterables is potentially infinite, then the zip_longest() function should be wrapped with something that limits the number of calls (for example islice() or takewhile()).
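For instance, pairing an infinite counter with a short sequence never terminates on its own, so the sketch below limits it with islice(); the inputs are arbitrary:

```python
from itertools import count, islice, zip_longest

labels = 'ab'
numbered = zip_longest(count(1), labels, fillvalue='-')   # count() never ends
limited = list(islice(numbered, 4))

print(limited)  # [(1, 'a'), (2, 'b'), (3, '-'), (4, '-')]
```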

Itertools Recipes

This section shows recipes for creating an extended toolset using the existing itertools as building blocks.

The primary purpose of the itertools recipes is educational. The recipes show various ways of thinking about individual tools — for example, that chain.from_iterable is related to the concept of flattening. The recipes also give ideas about ways that the tools can be combined — for example, how starmap() and repeat() can work together. The recipes also show patterns for using itertools with the operator and collections modules as well as with the built-in itertools such as map(), filter(), reversed(), and enumerate().

A secondary purpose of the recipes is to serve as an incubator. The accumulate(), compress(), and pairwise() itertools started out as recipes. Currently, the sliding_window(), iter_index(), and sieve() recipes are being tested to see whether they prove their worth.

Substantially all of these recipes and many, many others can be installed from the more-itertools project found on the Python Package Index:

python -m pip install more-itertools

Many of the recipes offer the same high performance as the underlying toolset. Superior memory performance is kept by processing elements one at a time rather than bringing the whole iterable into memory all at once. Code volume is kept small by linking the tools together in a functional style. High speed is retained by preferring "vectorized" building blocks over the use of for-loops and generators which incur interpreter overhead.

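import collections
import contextlib
import functools
import math
import operator
import random
from itertools import (batched, chain, combinations, count, cycle,
                       filterfalse, groupby, islice, product, repeat,
                       starmap, tee, zip_longest)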

def take(n, iterable):
    "Return first n items of the iterable as a list."
    return list(islice(iterable, n))

def prepend(value, iterable):
    "Prepend a single value in front of an iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    return chain([value], iterable)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def repeatfunc(func, times=None, *args):
    "Repeat calls to func with specified arguments."
    if times is None:
        return starmap(func, repeat(args))
    return starmap(func, repeat(args, times))

def flatten(list_of_lists):
    "Flatten one level of nesting."
    return chain.from_iterable(list_of_lists)

def ncycles(iterable, n):
    "Returns the sequence elements n times."
    return chain.from_iterable(repeat(tuple(iterable), n))

def tail(n, iterable):
    "Return an iterator over the last n items."
    # tail(3, 'ABCDEFG') → E F G
    return iter(collections.deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        collections.deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value."
    return next(islice(iterable, n, None), default)

def quantify(iterable, predicate=bool):
    "Given a predicate that returns True or False, count the True results."
    return sum(map(predicate, iterable))

def first_true(iterable, default=False, predicate=None):
    "Returns the first true value or the *default* if there is no true value."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    return next(filter(predicate, iterable), default)

def all_equal(iterable, key=None):
    "Returns True if all the elements are equal to each other."
    # all_equal('4٤௪౪໔', key=int) → True
    return len(take(2, groupby(iterable, key))) <= 1

def unique_justseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is None:
        return map(operator.itemgetter(0), groupby(iterable))
    return map(next, map(operator.itemgetter(1), groupby(iterable, key)))

def unique_everseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def unique(iterable, key=None, reverse=False):
    "Yield unique elements in sorted order. Supports unhashable inputs."
    # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
    return unique_justseen(sorted(iterable, key=key, reverse=reverse), key=key)

def sliding_window(iterable, n):
    "Collect data into overlapping fixed-length chunks or blocks."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    iterator = iter(iterable)
    window = collections.deque(islice(iterator, n - 1), maxlen=n)
    for x in iterator:
        window.append(x)
        yield tuple(window)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def partition(predicate, iterable):
    """Partition entries into false entries and true entries.

    If *predicate* is slow, consider wrapping it with functools.lru_cache().
    """
    # partition(is_odd, range(10)) → 0 2 4 6 8   and  1 3 5 7 9
    t1, t2 = tee(iterable)
    return filterfalse(predicate, t1), filter(predicate, t2)

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(operator.getitem, repeat(seq), slices)

def iter_index(iterable, value, start=0, stop=None):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        iterator = islice(iterable, start, stop)
        for i, element in enumerate(iterator, start):
            if element is value or element == value:
                yield i
    else:
        stop = len(iterable) if stop is None else stop
        i = start
        with contextlib.suppress(ValueError):
            while True:
                yield (i := seq_index(value, i, stop))
                i += 1

def iter_except(func, exception, first=None):
    "Convert a call-until-exception interface to an iterator interface."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with contextlib.suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield func()

The following recipes have a more mathematical flavor:

def powerset(iterable):
    "powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) → 1400
    return math.sumprod(*tee(iterable))

def reshape(matrix, cols):
    "Reshape a 2-D matrix to have a given number of columns."
    # reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
    return batched(chain.from_iterable(matrix), cols)

def transpose(matrix):
    "Swap the rows and columns of a 2-D matrix."
    # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
    return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(math.sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    windowed_signal = sliding_window(padded_signal, n)
    return map(math.sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    factors = zip(repeat(1), map(operator.neg, roots))
    return list(functools.reduce(convolve, factors, [1]))

def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.
    """
    # Evaluate x³ -4x² -17x + 60 at x = 5
    # polynomial_eval([1, -4, -17, 60], x=5) → 0
    n = len(coefficients)
    if not n:
        return type(x)(0)
    powers = map(pow, repeat(x), reversed(range(n)))
    return math.sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

       f(x)  =  x³ -4x² -17x + 60
       f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(operator.mul, coefficients, powers))

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    for prime in sieve(math.isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n

def totient(n):
    "Count of natural numbers up to n that are coprime to n."
    # https://mathworld.wolfram.com/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    for prime in set(factor(n)):
        n -= n // prime
    return n