itertools --- 建立產生高效率迴圈之疊代器的函式


這個模組實作了許多 疊代器 (iterator) 構建塊 (building block),其靈感來自 APL、Haskell 和 SML 的結構。每個構建塊都以適合 Python 的形式來重新設計。

這個模組標準化了快速且高效率利用記憶體的核心工具集,這些工具本身或組合使用都很有用。它們共同構成了一個「疊代器代數 (iterator algebra)」,使得在純 Python 中簡潔且高效地建構專用工具成為可能。

例如,SML 提供了一個造表工具:tabulate(f),它產生一個序列 f(0), f(1), ...。在 Python 中,可以透過結合 map()count() 組成 map(f, count()) 以達到同樣的效果。

無限疊代器:

疊代器

引數

結果

範例

count()

[start[, step]]

start, start+step, start+2*step, ...

count(10) 10 11 12 13 14 ...

cycle()

p

p0, p1, ... plast, p0, p1, ...

cycle('ABCD') A B C D A B C D ...

repeat()

elem [,n]

elem, elem, elem,... 重複無限次或 n 次

repeat(10, 3) 10 10 10

在最短輸入序列 (shortest input sequence) 處終止的疊代器:

疊代器

引數

結果

範例

accumulate()

p [,func]

p0, p0+p1, p0+p1+p2, ...

accumulate([1,2,3,4,5]) 1 3 6 10 15

batched()

p, n

(p0, p1, ..., p_n-1), ...

batched('ABCDEFG', n=3) ABC DEF G

chain()

p, q, ...

p0, p1, ... plast, q0, q1, ...

chain('ABC', 'DEF') A B C D E F

chain.from_iterable()

可疊代物件

p0, p1, ... plast, q0, q1, ...

chain.from_iterable(['ABC', 'DEF']) A B C D E F

compress()

data, selectors

(d[0] if s[0]), (d[1] if s[1]), ...

compress('ABCDEF', [1,0,1,0,1,1]) A C E F

dropwhile()

predicate, seq

seq[n], seq[n+1],當 predicate 失敗時開始

dropwhile(lambda x: x<5, [1,4,6,3,8]) 6 3 8

filterfalse()

predicate, seq

當 predicate(elem) 失敗時 seq 的元素

filterfalse(lambda x: x<5, [1,4,6,3,8]) 6 8

groupby()

iterable[, key]

根據 key(v) 的值分組的子疊代器

groupby(['A','B','DEF'], len) (1, A B) (3, DEF)

islice()

seq, [start,] stop [, step]

seq[start:stop:step] 的元素

islice('ABCDEFG', 2, None) C D E F G

pairwise()

可疊代物件

(p[0], p[1]), (p[1], p[2])

pairwise('ABCDEFG') AB BC CD DE EF FG

starmap()

func, seq

func(*seq[0]), func(*seq[1]), ...

starmap(pow, [(2,5), (3,2), (10,3)]) 32 9 1000

takewhile()

predicate, seq

seq[0], seq[1],直到 predicate 失敗

takewhile(lambda x: x<5, [1,4,6,3,8]) 1 4

tee()

it, n

it1, it2, ... itn,將一個疊代器分成 n 個

tee('ABC', 2) A B C, A B C

zip_longest()

p, q, ...

(p[0], q[0]), (p[1], q[1]), ...

zip_longest('ABCD', 'xy', fillvalue='-') Ax By C- D-

組合疊代器:

疊代器

引數

結果

product()

p, q, ... [repeat=1]

笛卡爾乘積 (cartesian product),相當於巢狀的 for 迴圈

permutations()

p[, r]

長度為 r 的元組,所有可能的定序,無重複元素

combinations()

p, r

長度為 r 的元組,按照排序過後的定序,無重複元素

combinations_with_replacement()

p, r

長度為 r 的元組,按照排序過後的定序,有重複元素

範例

結果

product('ABCD', repeat=2)

AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD

permutations('ABCD', 2)

AB AC AD BA BC BD CA CB CD DA DB DC

combinations('ABCD', 2)

AB AC AD BC BD CD

combinations_with_replacement('ABCD', 2)

AA AB AC AD BB BC BD CC CD DD

Itertool 函式

以下的函式都會建構並回傳疊代器。一些函式提供無限長度的串流 (stream),因此應僅由截斷串流的函式或迴圈來存取它們。

itertools.accumulate(iterable[, function, *, initial=None])

建立一個回傳累積和的疊代器,或其他二進位函式的累積結果。

function 預設為加法。function 應接受兩個引數,即累積總和和來自 iterable 的值。

如果提供了 initial 值,則累積將從該值開始,並且輸出的元素數將比輸入的可疊代物件多一個。

大致等價於:

def accumulate(iterable, function=operator.add, *, initial=None):
    'Return running totals'
    # accumulate([1,2,3,4,5]) → 1 3 6 10 15
    # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
    # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

    iterator = iter(iterable)
    total = initial
    if initial is None:
        try:
            total = next(iterator)
        except StopIteration:
            return

    yield total
    for element in iterator:
        total = function(total, element)
        yield total

function 引數可以被設定為 min() 以得到連續的最小值,設定為 max() 以得到連續的最大值,或者設定為 operator.mul() 以得到連續的乘積。也可以透過累積利息和付款來建立攤銷表 (Amortization tables)

>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]

可參見 functools.reduce(),其是個類似的函式,但僅回傳最終的累積值。

在 3.2 版被加入.

在 3.3 版的變更: 新增可選的 function 參數。

在 3.8 版的變更: 新增可選的 initial 參數。

itertools.batched(iterable, n, *, strict=False)

將來自 iterable 的資料分批為長度為 n 的元組。最後一個批次可能比 n 短。

If strict is true, will raise a ValueError if the final batch is shorter than n.

對輸入的可疊代物件進行迴圈,並將資料累積到大小為 n 的元組中。輸入是惰性地被消耗 (consumed lazily) 的,會剛好足夠填充一批的資料。一旦批次填滿或輸入的可疊代物件耗盡,就會 yield 出結果:

>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]

大致等價於:

def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch

在 3.12 版被加入.

在 3.13 版的變更: 新增 strict 選項。

itertools.chain(*iterables)

建立一個疊代器,從第一個可疊代物件回傳元素直到其耗盡,然後繼續處理下一個可疊代物件,直到所有可疊代物件都耗盡。這將多個資料來源結合為單一個疊代器。大致等價於:

def chain(*iterables):
    # chain('ABC', 'DEF') → A B C D E F
    for iterable in iterables:
        yield from iterable
classmethod chain.from_iterable(iterable)

chain() 的另一個建構函式。從單個可疊代的引數中得到鏈接的輸入,該引數是惰性計算的。大致等價於:

def from_iterable(iterables):
    # chain.from_iterable(['ABC', 'DEF']) → A B C D E F
    for iterable in iterables:
        yield from iterable
itertools.combinations(iterable, r)

從輸入 iterable 中回傳長度為 r 的元素的子序列。

輸出是 product() 的子序列,僅保留作為 iterable 子序列的條目。輸出的長度由 math.comb() 給定,當 0 r n 時,長度為 n! / r! / (n - r)!,當 r > n 時為零。

根據輸入值 iterable 的順序,組合的元組會按照字典順序輸出。如果輸入的 iterable 已經排序,則輸出的元組也將按排序的順序產生。

元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,則每個組合內將不會有重複的值。

大致等價於:

def combinations(iterable, r):
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            return
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
itertools.combinations_with_replacement(iterable, r)

回傳來自輸入 iterable 的長度為 r 的子序列,且允許個別元素重複多次。

其輸出是一個 product() 的子序列,僅保留作為 iterable 子序列(可能有重複元素)的條目。當 n > 0 時,回傳的子序列數量為 (n + r - 1)! / r! / (n - 1)!

根據輸入值 iterable 的順序,組合的元組會按照字典順序輸出。如果輸入的 iterable 已經排序,則輸出的元組也將按排序的順序產生。

元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,生成的組合也將是獨特的。

大致等價於:

def combinations_with_replacement(iterable, r):
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            return
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)

在 3.1 版被加入.

itertools.compress(data, selectors)

建立一個疊代器,回傳 data 中對應 selectors 的元素為 true 的元素。當 dataselectors 可疊代物件耗盡時停止。大致等價於:

def compress(data, selectors):
    # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
    return (datum for datum, selector in zip(data, selectors) if selector)

在 3.1 版被加入.

itertools.count(start=0, step=1)

建立一個疊代器,回傳從 start 開始的等差的值。可以與 map() 一起使用來產生連續的資料點,或與 zip() 一起使用來增加序列號。大致等價於:

def count(start=0, step=1):
    # count(10) → 10 11 12 13 14 ...
    # count(2.5, 0.5) → 2.5 3.0 3.5 ...
    n = start
    while True:
        yield n
        n += step

當用浮點數計數時,將上述程式碼替換為乘法有時可以獲得更好的精確度,例如:(start + step * i for i in count())

在 3.1 版的變更: 新增 step 引數並允許非整數引數。

itertools.cycle(iterable)

建立一個疊代器,回傳 iterable 中的元素並保存每個元素的副本。當可疊代物件耗盡時,從保存的副本中回傳元素。會無限次的重複。大致等價於:

def cycle(iterable):
    # cycle('ABCD') → A B C D A B C D A B C D ...

    saved = []
    for element in iterable:
        yield element
        saved.append(element)

    while saved:
        for element in saved:
            yield element

此 itertool 可能需要大量的輔助儲存空間(取決於可疊代物件的長度)。

itertools.dropwhile(predicate, iterable)

建立一個疊代器,在 predicate 為 true 時丟棄 iterable 中的元素,之後回傳每個元素。大致等價於:

def dropwhile(predicate, iterable):
    # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

    iterator = iter(iterable)
    for x in iterator:
        if not predicate(x):
            yield x
            break

    for x in iterator:
        yield x

注意,在 predicate 首次變為 False 之前,這不會產生任何輸出,所以此 itertool 可能會有較長的啟動時間。

itertools.filterfalse(predicate, iterable)

建立一個疊代器,過濾 iterable 中的元素,僅回傳 predicate 為 False 值的元素。如果 predicateNone,則回傳為 False 的項目。大致等價於:

def filterfalse(predicate, iterable):
    # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8

    if predicate is None:
        predicate = bool

    for x in iterable:
        if not predicate(x):
            yield x
itertools.groupby(iterable, key=None)

建立一個疊代器,回傳 iterable 中連續的鍵和群組。key 是一個為每個元素計算鍵值的函式。如果其未指定或為 None,則 key 預設為一個識別性函式 (identity function),並回傳未被更改的元素。一般來說,可疊代物件需要已經用相同的鍵函式進行排序。

groupby() 的操作類似於 Unix 中的 uniq 過濾器。每當鍵函式的值發生變化時,它會產生一個 break 或新的群組(這就是為什麼通常需要使用相同的鍵函式對資料進行排序)。這種行為不同於 SQL 的 GROUP BY,其無論輸入順序如何都會聚合相同的元素。

回傳的群組本身是一個與 groupby() 共享底層可疊代物件的疊代器。由於來源是共享的,當 groupby() 物件前進時,前一個群組將不再可見。因此,如果之後需要該資料,應將其儲存為串列:

groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)

groupby() 大致等價於:

def groupby(iterable, key=None):
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False

    def _grouper(target_key):
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        if curr_key == target_key:
            for _ in curr_group:
                pass
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

建立一個疊代器,回傳從 iterable 中選取的元素。其作用類似於序列切片 (sequence slicing),但不支援負數的 startstopstep 的值。

如果 start 為零或 None,則從零開始疊代。否則在達到 start 之前,會跳過 iterable 中的元素。

如果 stopNone,則疊代將繼續前進直到輸入耗盡。如果指定了 stop,則在達到指定位置時停止。

如果 stepNone,則步長 (step) 預設為一。元素會連續回傳,除非將 step 設定為大於一,這會導致一些項目被跳過。

大致等價於:

def islice(iterable, *args):
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step

If the input is an iterator, then fully consuming the islice advances the input iterator by max(start, stop) steps regardless of the step value.

itertools.pairwise(iterable)

回傳從輸入的 iterable 中提取的連續重疊對。

輸出疊代器中的 2 元組數量將比輸入少一個。如果輸入的可疊代物件中的值少於兩個,則輸出將為空值。

大致等價於:

def pairwise(iterable):
    # pairwise('ABCDEFG') → AB BC CD DE EF FG

    iterator = iter(iterable)
    a = next(iterator, None)

    for b in iterator:
        yield a, b
        a = b

在 3.10 版被加入.

itertools.permutations(iterable, r=None)

回傳 iterable 中連續且長度為 r元素排列

如果未指定 r 或其值為 None,則 r 預設為 iterable 的長度,並產生所有可能的完整長度的排列。

輸出是 product() 的子序列,其中重複元素的條目已被濾除。輸出的長度由 math.perm() 給定,當 0 r n 時,長度為 n! / (n - r)!,當 r > n 時為零。

根據輸入值 iterable 的順序,排列的元組會按照字典順序輸出。如果輸入的 iterable 已排序,則輸出的元組也將按排序的順序產生。

元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,則排列中將不會有重複的值。

大致等價於:

def permutations(iterable, r=None):
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            return
itertools.product(*iterables, repeat=1)

Cartesian product of the input iterables.

大致等價於產生器運算式中的巢狀 for 迴圈。例如,product(A, B) 的回傳結果與 ((x,y) for x in A for y in B) 相同。

巢狀迴圈的循環類似於里程表,最右邊的元素在每次疊代時前進。這種模式會建立字典順序,因此如果輸入的 iterables 已排序,則輸出的乘積元組也將按排序的順序產生。

要計算可疊代物件自身的乘積,可以使用可選的 repeat 關鍵字引數來指定重複次數。例如,product(A, repeat=4)product(A, A, A, A) 相同。

此函式大致等價於以下的程式碼,不同之處在於真正的實作不會在記憶體中建立中間結果:

def product(*iterables, repeat=1):
    # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

    if repeat < 0:
        raise ValueError('repeat argument cannot be negative')
    pools = [tuple(pool) for pool in iterables] * repeat

    result = [[]]
    for pool in pools:
        result = [x+[y] for x in result for y in pool]

    for prod in result:
        yield tuple(prod)

product() 執行之前,它會完全消耗輸入的 iterables,並將值的池 (pools of values) 保存在記憶體中以產生乘積。因此,它僅對有限的輸入有用。

itertools.repeat(object[, times])

建立一個疊代器,反覆回傳 object。除非指定了 times 引數,否則會執行無限次。

大致等價於:

def repeat(object, times=None):
    # repeat(10, 3) → 10 10 10
    if times is None:
        while True:
            yield object
    else:
        for i in range(times):
            yield object

repeat 的常見用途是為 mapzip 提供定值的串流:

>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
itertools.starmap(function, iterable)

建立一個疊代器,使用從 iterable 獲取的引數計算 function 。當引數參數已經被「預先壓縮 (pre-zipped)」成元組時,使用此方法代替 map()

map()starmap() 之間的區別類似於 function(a,b)function(*c) 之間的區別。大致等價於:

def starmap(function, iterable):
    # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
    for args in iterable:
        yield function(*args)
itertools.takewhile(predicate, iterable)

建立一個疊代器,只在 predicate 為 true 時回傳 iterable 中的元素。大致等價於:

def takewhile(predicate, iterable):
    # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
    for x in iterable:
        if not predicate(x):
            break
        yield x

注意,第一個不符合條件判斷的元素將從輸入疊代器中被消耗,且無法再存取它。如果應用程式希望在 takewhile 耗盡後進一步消耗輸入疊代器,這可能會是個問題。為了解決這個問題,可以考慮使用 more-itertools 中的 before_and_after() 作為替代。

itertools.tee(iterable, n=2)

從一個 iterable 中回傳 n 個獨立的疊代器。

大致等價於:

def tee(iterable, n=2):
    if n < 0:
        raise ValueError
    if n == 0:
        return ()
    iterator = _tee(iterable)
    result = [iterator]
    for _ in range(n - 1):
        result.append(_tee(iterator))
    return tuple(result)

class _tee:

    def __init__(self, iterable):
        it = iter(iterable)
        if isinstance(it, _tee):
            self.iterator = it.iterator
            self.link = it.link
        else:
            self.iterator = it
            self.link = [None, None]

    def __iter__(self):
        return self

    def __next__(self):
        link = self.link
        if link[1] is None:
            link[0] = next(self.iterator)
            link[1] = [None, None]
        value, self.link = link
        return value

When the input iterable is already a tee iterator object, all members of the return tuple are constructed as if they had been produced by the upstream tee() call. This "flattening step" allows nested tee() calls to share the same underlying data chain and to have a single update step rather than a chain of calls.

The flattening property makes tee iterators efficiently peekable:

def lookahead(tee_iterator):
     "Return the next value without moving the input forward"
     [forked_iterator] = tee(tee_iterator, 1)
     return next(forked_iterator)
>>> iterator = iter('abcdef')
>>> [iterator] = tee(iterator, 1)   # Make the input peekable
>>> next(iterator)                  # Move the iterator forward
'a'
>>> lookahead(iterator)             # Check next value
'b'
>>> next(iterator)                  # Continue moving forward
'b'

tee 疊代器不是執行緒安全 (threadsafe) 的。當同時使用由同一個 tee() 呼叫所回傳的疊代器時,即使原始的 iterable 是執行緒安全的,也可能引發 RuntimeError

此 itertool 可能需要大量的輔助儲存空間(取決於需要儲存多少臨時資料)。通常如果一個疊代器在另一個疊代器開始之前使用了大部分或全部的資料,使用 list() 會比 tee() 更快。

itertools.zip_longest(*iterables, fillvalue=None)

建立一個疊代器,聚合來自每個 iterables 中的元素。

如果 iterables 的長度不一,則使用 fillvalue 填充缺少的值。如果未指定,fillvalue 會預設為 None

疊代將持續直到最長的可疊代物件耗盡為止。

大致等價於:

def zip_longest(*iterables, fillvalue=None):
    # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

    iterators = list(map(iter, iterables))
    num_active = len(iterators)
    if not num_active:
        return

    while True:
        values = []
        for i, iterator in enumerate(iterators):
            try:
                value = next(iterator)
            except StopIteration:
                num_active -= 1
                if not num_active:
                    return
                iterators[i] = repeat(fillvalue)
                value = fillvalue
            values.append(value)
        yield tuple(values)

如果其中一個 iterables 可能是無限的,那麼應該用別的可以限制呼叫次數的方法來包裝 zip_longest() 函式(例如 islice()takewhile())。

Itertools 應用技巧

此段落展示了使用現有的 itertools 作為構建塊來建立擴展工具集的應用技巧。

itertools 應用技巧的主要目的是教學。這些應用技巧展示了對單個工具進行思考的各種方式 —— 例如,chain.from_iterable 與攤平 (flattening) 的概念相關。這些應用技巧還提供了組合使用工具的想法 —— 例如,starmap()repeat() 如何一起工作。另外還展示了將 itertools 與 operatorcollections 模組一同使用以及與內建 itertools(如 map()filter()reversed()enumerate())一同使用的模式。

應用技巧的次要目的是作為 itertools 的孵化器。accumulate(), compress()pairwise() itertools 最初都是作為應用技巧出現的。目前,sliding_window()iter_index()sieve() 的應用技巧正在被測試,以確定它們是否有價值被收錄到內建的 itertools 中。

幾乎所有這些應用技巧以及許多其他應用技巧都可以從 Python Package Index 上的 more-itertools 專案中安裝:

python -m pip install more-itertools

許多應用技巧提供了與底層工具集相同的高性能。透過一次處理一個元素而不是將整個可疊代物件一次性引入記憶體,能保持優異的記憶體性能。以函式風格 (functional style) 將工具連接在一起,能將程式碼的數量維持在較少的情況。透過優先使用「向量化 (vectorized)」的構建塊而不是使用會造成直譯器負擔的 for 迴圈和產生器,則能保持高速度。

from collections import deque
from contextlib import suppress
from functools import reduce
from math import sumprod, isqrt
from operator import itemgetter, getitem, mul, neg

def take(n, iterable):
    "Return first n items of the iterable as a list."
    return list(islice(iterable, n))

def prepend(value, iterable):
    "Prepend a single value in front of an iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    return chain([value], iterable)

def tabulate(function, start=0):
    "Return function(0), function(1), ..."
    return map(function, count(start))

def repeatfunc(function, times=None, *args):
    "Repeat calls to a function with specified arguments."
    if times is None:
        return starmap(function, repeat(args))
    return starmap(function, repeat(args, times))

def flatten(list_of_lists):
    "Flatten one level of nesting."
    return chain.from_iterable(list_of_lists)

def ncycles(iterable, n):
    "Returns the sequence elements n times."
    return chain.from_iterable(repeat(tuple(iterable), n))

def loops(n):
    "Loop n times. Like range(n) but without creating integers."
    # for _ in loops(100): ...
    return repeat(None, n)

def tail(n, iterable):
    "Return an iterator over the last n items."
    # tail(3, 'ABCDEFG') → E F G
    return iter(deque(iterable, maxlen=n))

def consume(iterator, n=None):
    "Advance the iterator n-steps ahead. If n is None, consume entirely."
    # Use functions that consume iterators at C speed.
    if n is None:
        deque(iterator, maxlen=0)
    else:
        next(islice(iterator, n, n), None)

def nth(iterable, n, default=None):
    "Returns the nth item or a default value."
    return next(islice(iterable, n, None), default)

def quantify(iterable, predicate=bool):
    "Given a predicate that returns True or False, count the True results."
    return sum(map(predicate, iterable))

def first_true(iterable, default=False, predicate=None):
    "Returns the first true value or the *default* if there is no true value."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    return next(filter(predicate, iterable), default)

def all_equal(iterable, key=None):
    "Returns True if all the elements are equal to each other."
    # all_equal('4٤௪౪໔', key=int) → True
    return len(take(2, groupby(iterable, key))) <= 1

def unique_justseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember only the element just seen."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is None:
        return map(itemgetter(0), groupby(iterable))
    return map(next, map(itemgetter(1), groupby(iterable, key)))

def unique_everseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    seen = set()
    if key is None:
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def unique(iterable, key=None, reverse=False):
   "Yield unique elements in sorted order. Supports unhashable inputs."
   # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
   sequenced = sorted(iterable, key=key, reverse=reverse)
   return unique_justseen(sequenced, key=key)

def sliding_window(iterable, n):
    "Collect data into overlapping fixed-length chunks or blocks."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    iterator = iter(iterable)
    window = deque(islice(iterator, n - 1), maxlen=n)
    for x in iterator:
        window.append(x)
        yield tuple(window)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
    "Collect data into non-overlapping fixed-length chunks or blocks."
    # grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
    # grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
    # grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
    iterators = [iter(iterable)] * n
    match incomplete:
        case 'fill':
            return zip_longest(*iterators, fillvalue=fillvalue)
        case 'strict':
            return zip(*iterators, strict=True)
        case 'ignore':
            return zip(*iterators)
        case _:
            raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    slices = starmap(slice, combinations(range(len(seq) + 1), 2))
    return map(getitem, repeat(seq), slices)

def iter_index(iterable, value, start=0, stop=None):
    "Return indices where a value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is None:
        iterator = islice(iterable, start, stop)
        for i, element in enumerate(iterator, start):
            if element is value or element == value:
                yield i
    else:
        stop = len(iterable) if stop is None else stop
        i = start
        with suppress(ValueError):
            while True:
                yield (i := seq_index(value, i, stop))
                i += 1

def iter_except(function, exception, first=None):
    "Convert a call-until-exception interface to an iterator interface."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with suppress(exception):
        if first is not None:
            yield first()
        while True:
            yield function()

以下的應用技巧具有更多的數學風格:

def powerset(iterable):
    "Subsequences of the iterable from shortest to longest."
    # powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    s = list(iterable)
    return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
    "Add up the squares of the input values."
    # sum_of_squares([10, 20, 30]) → 1400
    return sumprod(*tee(iterable))

def reshape(matrix, columns):
    "Reshape a 2-D matrix to have a given number of columns."
    # reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
    return batched(chain.from_iterable(matrix), columns, strict=True)

def transpose(matrix):
    "Swap the rows and columns of a 2-D matrix."
    # transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
    return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    n = len(m2[0])
    return batched(starmap(sumprod, product(m1, transpose(m2))), n)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    Video:    https://www.youtube.com/watch?v=KuXjwB4LzSA
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    windowed_signal = sliding_window(padded_signal, n)
    return map(sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

       (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    factors = zip(repeat(1), map(neg, roots))
    return list(reduce(convolve, factors, [1]))

def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Computes with better numeric stability than Horner's method.
    """
    # Evaluate x³ -4x² -17x + 60 at x = 5
    # polynomial_eval([1, -4, -17, 60], x=5) → 0
    n = len(coefficients)
    if not n:
        return type(x)(0)
    powers = map(pow, repeat(x), reversed(range(n)))
    return sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

       f(x)  =  x³ -4x² -17x + 60
       f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    n = len(coefficients)
    powers = reversed(range(1, n))
    return list(map(mul, coefficients, powers))

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    data = bytearray((0, 1)) * (n // 2)
    for p in iter_index(data, 1, start=3, stop=isqrt(n) + 1):
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    for prime in sieve(isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    if n > 1:
        yield n

def is_prime(n):
    "Return True if n is prime."
    # is_prime(1_000_000_000_000_403) → True
    return n > 1 and next(factor(n)) == n

def totient(n):
    "Count of natural numbers up to n that are coprime to n."
    # https://mathworld.wolfram.com/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    for prime in set(factor(n)):
        n -= n // prime
    return n