# `itertools` — 建立產生高效率迴圈之疊代器的函式

`count()`

[start[, step]]

start, start+step, start+2*step, ...

`count(10) → 10 11 12 13 14 ...`

`cycle()`

p

p0, p1, ... plast, p0, p1, ...

`cycle('ABCD') → A B C D A B C D ...`

`repeat()`

elem [,n]

elem, elem, elem，... 重複無限次或 n 次

`repeat(10, 3) → 10 10 10`

`accumulate()`

p [,func]

p0, p0+p1, p0+p1+p2, ...

`accumulate([1,2,3,4,5]) → 1 3 6 10 15`

`batched()`

p, n

(p0, p1, ..., p_n-1), ...

`batched('ABCDEFG', n=3) → ABC DEF G`

`chain()`

p, q, ...

p0, p1, ... plast, q0, q1, ...

`chain('ABC', 'DEF') → A B C D E F`

`chain.from_iterable()`

iterable

p0, p1, ... plast, q0, q1, ...

`chain.from_iterable(['ABC', 'DEF']) → A B C D E F`

`compress()`

data, selectors

(d[0] if s[0]), (d[1] if s[1]), ...

`compress('ABCDEF', [1,0,1,0,1,1]) → A C E F`

`dropwhile()`

predicate, seq

seq[n], seq[n+1]，當 predicate 失敗時開始

`dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8`

`filterfalse()`

predicate, seq

seq 中 predicate(elem) 為假的元素

`filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8`

`groupby()`

iterable[, key]

依 key(v) 的值分組的子疊代器

`islice()`

seq, [start,] stop [, step]

seq[start:stop:step] 的元素

`islice('ABCDEFG', 2, None) → C D E F G`

`pairwise()`

iterable

(p[0], p[1]), (p[1], p[2]), ...

`pairwise('ABCDEFG') → AB BC CD DE EF FG`

`starmap()`

func, seq

func(*seq[0]), func(*seq[1]), ...

`starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000`

`takewhile()`

predicate, seq

seq[0], seq[1]，直到 predicate 失敗

`takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4`

`tee()`

it, n

it1, it2, ... itn，將一個疊代器分成 n 個

`zip_longest()`

p, q, ...

(p[0], q[0]), (p[1], q[1]), ...

`zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-`

`product()`

p, q, ... [repeat=1]

笛卡兒乘積 (cartesian product)，相當於巢狀的 for 迴圈

`permutations()`

p[, r]

長度為 r 的元組，所有可能的排列順序，元素不重複

`combinations()`

p, r

長度為 r 的元組，依排序順序，元素不重複

`combinations_with_replacement()`

p, r

長度為 r 的元組，依排序順序，元素可重複

`product('ABCD', repeat=2)`

`AA AB AC AD BA BB BC BD CA CB CC CD DA DB DC DD`

`permutations('ABCD', 2)`

`AB AC AD BA BC BD CA CB CD DA DB DC`

`combinations('ABCD', 2)`

`AB AC AD BC BD CD`

`combinations_with_replacement('ABCD', 2)`

`AA AB AC AD BB BC BD CC CD DD`

## Itertool 函式

itertools.accumulate(iterable[, function, *, initial=None])

function 預設為加法。function 應接受兩個引數，即累積總和和來自 iterable 的值。

```def accumulate(iterable, function=operator.add, *, initial=None):
'Return running totals'
# accumulate([1,2,3,4,5]) → 1 3 6 10 15
# accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115
# accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120

iterator = iter(iterable)
total = initial
if initial is None:
try:
total = next(iterator)
except StopIteration:
return

yield total
for element in iterator:
total = function(total, element)
yield total
```

function 引數可以被設定為 `min()` 以得到連續的最小值，設定為 `max()` 以得到連續的最大值，或者設定為 `operator.mul()` 以得到連續的乘積。也可以透過累積利息和付款來建立攤銷表 (Amortization tables)：

```
>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8]
>>> list(accumulate(data, max))              # running maximum
[3, 4, 6, 6, 6, 9, 9, 9, 9, 9]
>>> list(accumulate(data, operator.mul))     # running product
[3, 12, 72, 144, 144, 1296, 0, 0, 0, 0]

# Amortize a 5% loan of 1000 with 10 annual payments of 90
>>> update = lambda balance, payment: round(balance * 1.05) - payment
>>> list(accumulate(repeat(90, 10), update, initial=1_000))
[1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]
```

itertools.batched(iterable, n)

```
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet']
>>> unflattened = list(batched(flattened_data, 2))
>>> unflattened
[('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
```

```def batched(iterable, n):
# batched('ABCDEFG', 3) → ABC DEF G
if n < 1:
raise ValueError('n must be at least one')
iterator = iter(iterable)
while batch := tuple(islice(iterator, n)):
yield batch
```

itertools.chain(*iterables)

```def chain(*iterables):
# chain('ABC', 'DEF') → A B C D E F
for iterable in iterables:
yield from iterable
```
classmethod chain.from_iterable(iterable)

`chain()` 的另一個建構函式。從單個可疊代的引數中得到鏈接的輸入，該引數是惰性計算的。大致等價於：

```def from_iterable(iterables):
# chain.from_iterable(['ABC', 'DEF']) → A B C D E F
for iterable in iterables:
yield from iterable
```
itertools.combinations(iterable, r)

```
def combinations(iterable, r):
    """Yield r-length tuples of elements from *iterable*, in index order.

    Elements are distinguished by position, so no position is used twice
    in one tuple. Yields nothing when r is larger than the input length.
    """
    # combinations('ABCD', 2) → AB AC AD BC BD CD
    # combinations(range(4), 3) → 012 013 023 123

    pool = tuple(iterable)
    n = len(pool)
    if r > n:
        return
    # Start from the smallest index tuple: 0, 1, ..., r-1.
    indices = list(range(r))

    yield tuple(pool[i] for i in indices)
    while True:
        # Find the rightmost position whose index is below its maximum
        # possible value (position i can reach at most i + n - r).
        for i in reversed(range(r)):
            if indices[i] != i + n - r:
                break
        else:
            # Every position is at its maximum: nothing left to produce.
            return
        # Bump that index and reset every position to its right to the
        # consecutive run that follows it.
        indices[i] += 1
        for j in range(i+1, r):
            indices[j] = indices[j-1] + 1
        yield tuple(pool[i] for i in indices)
```
itertools.combinations_with_replacement(iterable, r)

```
def combinations_with_replacement(iterable, r):
    """Yield r-length tuples where a position may be reused, in index order.

    Yields nothing when *iterable* is empty and r is non-zero.
    """
    # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC

    pool = tuple(iterable)
    n = len(pool)
    if not n and r:
        return
    # Start with every slot pointing at the first element.
    indices = [0] * r

    yield tuple(pool[i] for i in indices)
    while True:
        # Find the rightmost slot that can still be incremented.
        for i in reversed(range(r)):
            if indices[i] != n - 1:
                break
        else:
            # All slots hold n - 1: the last tuple was already produced.
            return
        # Increment that slot and copy its new value into every slot to its
        # right, which keeps the index tuple non-decreasing.
        indices[i:] = [indices[i] + 1] * (r - i)
        yield tuple(pool[i] for i in indices)
```

itertools.compress(data, selectors)

```def compress(data, selectors):
# compress('ABCDEF', [1,0,1,0,1,1]) → A C E F
return (datum for datum, selector in zip(data, selectors) if selector)
```

itertools.count(start=0, step=1)

```def count(start=0, step=1):
# count(10) → 10 11 12 13 14 ...
# count(2.5, 0.5) → 2.5 3.0 3.5 ...
n = start
while True:
yield n
n += step
```

itertools.cycle(iterable)

```def cycle(iterable):
# cycle('ABCD') → A B C D A B C D A B C D ...
saved = []
for element in iterable:
yield element
saved.append(element)
while saved:
for element in saved:
yield element
```

itertools.dropwhile(predicate, iterable)

```def dropwhile(predicate, iterable):
# dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8

iterator = iter(iterable)
for x in iterator:
if not predicate(x):
yield x
break

for x in iterator:
yield x
```

itertools.filterfalse(predicate, iterable)

```def filterfalse(predicate, iterable):
# filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8
if predicate is None:
predicate = bool
for x in iterable:
if not predicate(x):
yield x
```
itertools.groupby(iterable, key=None)

`groupby()` 的操作類似於 Unix 中的 `uniq` 過濾器。每當鍵函式的值發生變化時，它會產生一個 break 或新的群組（這就是為什麼通常需要使用相同的鍵函式對資料進行排序）。這種行為不同於 SQL 的 GROUP BY，其無論輸入順序如何都會聚合相同的元素。

```
# Typical groupby usage: sort by the same key first so that equal keys
# become adjacent, then collect each group.
# NOTE(review): ``data`` and ``keyfunc`` are placeholders the reader supplies.
groups = []
uniquekeys = []
data = sorted(data, key=keyfunc)
for k, g in groupby(data, keyfunc):
    # g is drained by groupby when it advances, so copy it into a list now.
    groups.append(list(g))      # Store group iterator as a list
    uniquekeys.append(k)
```

`groupby()` 大致等價於：

```
def groupby(iterable, key=None):
    """Yield (key, group) pairs, one per run of consecutive equal keys.

    *key* defaults to the identity function. Each group is a lazy iterator
    that shares the underlying input with groupby itself. When groupby
    advances, any part of the current group the caller did not read is
    drained and discarded.
    """
    # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B
    # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D

    keyfunc = (lambda x: x) if key is None else key
    iterator = iter(iterable)
    exhausted = False  # becomes True once the input iterator runs out

    def _grouper(target_key):
        # Yield the current value, then keep pulling values while their key
        # equals target_key. The first mismatching value and key stay in
        # curr_value/curr_key, so the outer loop starts the next group there.
        nonlocal curr_value, curr_key, exhausted
        yield curr_value
        for curr_value in iterator:
            curr_key = keyfunc(curr_value)
            if curr_key != target_key:
                return
            yield curr_value
        exhausted = True

    try:
        curr_value = next(iterator)
    except StopIteration:
        return
    curr_key = keyfunc(curr_value)

    while not exhausted:
        target_key = curr_key
        curr_group = _grouper(target_key)
        yield curr_key, curr_group
        # If the caller did not read the group up to its boundary, curr_key
        # is still target_key: drain the group to find the next boundary.
        if curr_key == target_key:
            for _ in curr_group:
                pass
```
itertools.islice(iterable, stop)
itertools.islice(iterable, start, stop[, step])

```
def islice(iterable, *args):
    """Yield selected elements of *iterable*, like ``seq[start:stop:step]``.

    Accepts ``(stop)`` or ``(start, stop[, step])``, parsed with ``slice``.
    When iteration starts, a negative start or stop, or a step <= 0,
    raises ValueError.
    """
    # islice('ABCDEFG', 2) → A B
    # islice('ABCDEFG', 2, 4) → C D
    # islice('ABCDEFG', 2, None) → C D E F G
    # islice('ABCDEFG', 0, None, 2) → A C E G

    s = slice(*args)
    start = 0 if s.start is None else s.start
    stop = s.stop
    step = 1 if s.step is None else s.step
    if start < 0 or (stop is not None and stop < 0) or step <= 0:
        raise ValueError

    # ``indices`` is zip's first argument. When it runs out, zip stops
    # before pulling another element from the input, so no extra element
    # is consumed past the bound.
    indices = count() if stop is None else range(max(start, stop))
    next_i = start
    for i, element in zip(indices, iterable):
        if i == next_i:
            yield element
            next_i += step
```
itertools.pairwise(iterable)

```def pairwise(iterable):
# pairwise('ABCDEFG') → AB BC CD DE EF FG
iterator = iter(iterable)
a = next(iterator, None)
for b in iterator:
yield a, b
a = b
```

itertools.permutations(iterable, r=None)

```
def permutations(iterable, r=None):
    """Yield r-length orderings of elements from *iterable* (r defaults to n).

    Elements are distinguished by position. Yields nothing when r > n.
    """
    # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC
    # permutations(range(3)) → 012 021 102 120 201 210

    pool = tuple(iterable)
    n = len(pool)
    r = n if r is None else r
    if r > n:
        return

    indices = list(range(n))
    # cycles[i] starts at n - i and counts down the swaps left for
    # position i before it wraps around.
    cycles = list(range(n, n-r, -1))
    yield tuple(pool[i] for i in indices[:r])

    while n:
        # Work from the rightmost position leftwards.
        for i in reversed(range(r)):
            cycles[i] -= 1
            if cycles[i] == 0:
                # Position i has wrapped: rotate indices[i:] left by one
                # (moving indices[i] to the end) and reset its counter.
                indices[i:] = indices[i+1:] + indices[i:i+1]
                cycles[i] = n - i
            else:
                # Swap position i with the slot cycles[i] places from the
                # end, producing the next ordering.
                j = cycles[i]
                indices[i], indices[-j] = indices[-j], indices[i]
                yield tuple(pool[i] for i in indices[:r])
                break
        else:
            # Every position wrapped around: all orderings were produced.
            return
```
itertools.product(*iterables, repeat=1)

```def product(*iterables, repeat=1):
# product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy
# product(range(2), repeat=3) → 000 001 010 011 100 101 110 111

pools = [tuple(pool) for pool in iterables] * repeat

result = [[]]
for pool in pools:
result = [x+[y] for x in result for y in pool]

for prod in result:
yield tuple(prod)
```

`product()` 在產生任何結果之前，會先完全消耗輸入的 iterables，並將值的池 (pools of values) 保存在記憶體中以產生乘積。因此，它僅適用於有限的輸入。

itertools.repeat(object[, times])

```def repeat(object, times=None):
# repeat(10, 3) → 10 10 10
if times is None:
while True:
yield object
else:
for i in range(times):
yield object
```

`repeat` 的常見用途是為 `map` 或 `zip` 提供定值的串流：

```
>>> list(map(pow, range(10), repeat(2)))
[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```
itertools.starmap(function, iterable)

`map()` 與 `starmap()` 之間的區別類似於 `function(a,b)` 與 `function(*c)` 之間的區別。大致等價於：

```def starmap(function, iterable):
# starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000
for args in iterable:
yield function(*args)
```
itertools.takewhile(predicate, iterable)

```def takewhile(predicate, iterable):
# takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4
for x in iterable:
if not predicate(x):
break
yield x
```

itertools.tee(iterable, n=2)

```def tee(iterable, n=2):
iterator = iter(iterable)
return tuple(_tee(iterator, shared_link) for _ in range(n))

try:
while True:
yield value
except StopIteration:
return
```

`tee` 疊代器不是執行緒安全 (threadsafe) 的。當同時使用由同一個 `tee()` 呼叫所回傳的疊代器時，即使原始的 iterable 是執行緒安全的，也可能引發 `RuntimeError`。

itertools.zip_longest(*iterables, fillvalue=None)

```def zip_longest(*iterables, fillvalue=None):
# zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D-

iterators = list(map(iter, iterables))
num_active = len(iterators)
if not num_active:
return

while True:
values = []
for i, iterator in enumerate(iterators):
try:
value = next(iterator)
except StopIteration:
num_active -= 1
if not num_active:
return
iterators[i] = repeat(fillvalue)
value = fillvalue
values.append(value)
yield tuple(values)
```

## Itertools 應用技巧

itertools 應用技巧的主要目的是教學。這些應用技巧展示了對單個工具進行思考的各種方式 —— 例如，`chain.from_iterable` 與攤平 (flattening) 的概念相關。這些應用技巧還提供了組合使用工具的想法 —— 例如，`starmap()` 和 `repeat()` 如何一起工作。另外還展示了將 itertools 與 `operator` 和 `collections` 模組一同使用，以及與內建疊代工具（如 `map()`、`filter()`、`reversed()` 和 `enumerate()`）一同使用的模式。

幾乎所有這些應用技巧，以及許多其他的應用技巧，都可以從 Python Package Index 上的 more-itertools 專案安裝：

```
python -m pip install more-itertools
```

```import collections
import contextlib
import functools
import math
import operator
import random

def take(n, iterable):
    "Return the first n items of iterable as a list (fewer if it runs short)."
    head = islice(iterable, n)
    return list(head)

def prepend(value, iterable):
    "Return an iterator yielding value followed by every item of iterable."
    # prepend(1, [2, 3, 4]) → 1 2 3 4
    head = (value,)
    return chain(head, iterable)

def tabulate(function, start=0):
    "Return an endless iterator of function(start), function(start + 1), ..."
    arguments = count(start)
    return map(function, arguments)

def repeatfunc(func, times=None, *args):
    "Call func(*args) repeatedly: forever, or times times when times is given."
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)

def flatten(list_of_lists):
    "Remove one level of nesting, lazily yielding items of each inner iterable."
    inner_iterables = list_of_lists
    return chain.from_iterable(inner_iterables)

def ncycles(iterable, n):
    "Yield the elements of iterable, the whole sequence repeated n times."
    snapshot = tuple(iterable)
    return chain.from_iterable(repeat(snapshot, n))

def tail(n, iterable):
    "Return an iterator over the last n items of iterable."
    # tail(3, 'ABCDEFG') → E F G
    # A bounded deque discards from the left, keeping only the newest n.
    last_items = collections.deque(iterable, maxlen=n)
    return iter(last_items)

def consume(iterator, n=None):
    "Advance iterator n steps ahead; exhaust it entirely when n is None."
    # Both branches let C code do the looping, for speed.
    if n is not None:
        # An empty slice starting at n makes islice skip n items.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque discards everything it is fed.
    collections.deque(iterator, maxlen=0)

def nth(iterable, n, default=None):
    "Return the item at index n, or default when iterable is too short."
    remainder = islice(iterable, n, None)
    return next(remainder, default)

def quantify(iterable, predicate=bool):
    "Sum predicate(x) over iterable, i.e. count the elements where it is true."
    outcomes = map(predicate, iterable)
    return sum(outcomes)

def first_true(iterable, default=False, predicate=None):
    "Return the first value passing predicate (truthiness if None), else default."
    # first_true([a,b,c], x) → a or b or c or x
    # first_true([a,b], x, f) → a if f(a) else b if f(b) else x
    matches = filter(predicate, iterable)
    return next(matches, default)

def all_equal(iterable, key=None):
    "Return True when all elements (or all key(element) values) are equal."
    # all_equal('4٤௪౪໔', key=int) → True
    # groupby starts a new group at every change of key, so at most one
    # group means no change was seen; take() stops after the second group.
    first_two_groups = take(2, groupby(iterable, key))
    return len(first_two_groups) <= 1

def unique_justseen(iterable, key=None):
    "Collapse each run of equal elements (or equal keys) to its first element."
    # unique_justseen('AAAABBBCCDAABBB') → A B C D A B
    # unique_justseen('ABBcCAD', str.casefold) → A B c A D
    if key is not None:
        # Each group is an iterator whose first item starts the run.
        groups = map(operator.itemgetter(1), groupby(iterable, key))
        return map(next, groups)
    # Without a key, each group's key is the element itself.
    return map(operator.itemgetter(0), groupby(iterable))

def unique_everseen(iterable, key=None):
    "Yield unique elements, preserving order. Remember all elements ever seen."
    # unique_everseen('AAAABBBCCDAABBB') → A B C D
    # unique_everseen('ABBcCAD', str.casefold) → A B c D
    # Elements (or their keys) must be hashable, since they go into a set.
    seen = set()
    if key is None:
        # filterfalse checks seen.__contains__ lazily for each element, so
        # recording an element here filters its later repeats out.
        for element in filterfalse(seen.__contains__, iterable):
            seen.add(element)
            yield element
    else:
        for element in iterable:
            k = key(element)
            if k not in seen:
                seen.add(k)
                yield element

def unique(iterable, key=None, reverse=False):
    "Yield each distinct element once, in sorted order; no hashing required."
    # unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
    # Sorting makes equal elements adjacent, so dropping adjacent
    # duplicates is enough.
    ordered = sorted(iterable, key=key, reverse=reverse)
    return unique_justseen(ordered, key=key)

def sliding_window(iterable, n):
    "Yield overlapping tuples of n consecutive elements."
    # sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
    source = iter(iterable)
    # Preload n - 1 items. Each new item then completes a window, and
    # maxlen makes the deque drop its oldest item automatically.
    buffer = collections.deque(islice(source, n - 1), maxlen=n)
    for incoming in source:
        buffer.append(incoming)
        yield tuple(buffer)

def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')

def roundrobin(*iterables):
    "Visit input iterables in a cycle until each is exhausted."
    # roundrobin('ABC', 'D', 'EF') → A D E B F C
    # Algorithm credited to George Sakkis
    iterators = map(iter, iterables)
    # Each pass cycles over the num_active iterators still alive. When
    # map(next, ...) reaches an exhausted one, its StopIteration ends the
    # pass. The cycle is then positioned just after the dead iterator, so
    # taking num_active - 1 items from it on the next pass keeps exactly
    # the live ones, in round-robin order.
    for num_active in range(len(iterables), 0, -1):
        iterators = cycle(islice(iterators, num_active))
        yield from map(next, iterators)

def partition(predicate, iterable):
    """Split iterable into two lazy iterators: elements failing predicate, then passing it.

    If *predicate* is slow, consider wrapping it with functools.lru_cache().
    """
    # partition(is_odd, range(10)) → 0 2 4 6 8   and  1 3 5 7 9
    # tee gives each half its own view of the input; predicate therefore
    # runs once per element in each half.
    falses_src, trues_src = tee(iterable)
    falses = filterfalse(predicate, falses_src)
    trues = filter(predicate, trues_src)
    return falses, trues

def subslices(seq):
    "Return all contiguous non-empty subslices of a sequence."
    # subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
    # Every pair of cut points i < j taken from 0..len(seq) names seq[i:j].
    cut_pairs = combinations(range(len(seq) + 1), 2)
    return map(operator.getitem, repeat(seq), starmap(slice, cut_pairs))

def iter_index(iterable, value, start=0, stop=None):
    "Yield every index at which value occurs in a sequence or iterable."
    # iter_index('AABCADEAF', 'A') → 0 1 4 7
    seq_index = getattr(iterable, 'index', None)
    if seq_index is not None:
        # Sequence fast path: call its index() method repeatedly, resuming
        # one past the previous hit, until it raises ValueError.
        end = len(iterable) if stop is None else stop
        position = start
        with contextlib.suppress(ValueError):
            while True:
                position = seq_index(value, position, end)
                yield position
                position += 1
        return
    # General path: scan the requested window, numbering from start.
    window = islice(iterable, start, stop)
    for position, element in enumerate(window, start):
        # Check identity before equality, as the original comparison does.
        if element is value or element == value:
            yield position

def iter_except(func, exception, first=None):
    "Yield func() results until func raises exception; yield first() first if given."
    # iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
    with contextlib.suppress(exception):
        # first(), when supplied, runs once before func and its result is
        # yielded ahead of func's results.
        if first is not None:
            yield first()
        while True:
            result = func()
            yield result
```

```def powerset(iterable):
"powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)"
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))

def sum_of_squares(iterable):
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) → 1400
return math.sumprod(*tee(iterable))

def reshape(matrix, cols):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) →  (0, 1, 2), (3, 4, 5)
return batched(chain.from_iterable(matrix), cols)

def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
# transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
return zip(*matrix, strict=True)

def matmul(m1, m2):
    "Multiply two matrices, producing the result one row tuple at a time."
    # matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
    width = len(m2[0])
    # product() pairs every row of m1 with every column of m2 in row-major
    # order; batching the dot products by column count rebuilds the rows.
    cells = starmap(math.sumprod, product(m1, transpose(m2)))
    return batched(cells, width)

def convolve(signal, kernel):
    """Discrete linear convolution of two iterables.
    Equivalent to polynomial multiplication.

    Convolutions are mathematically commutative; however, the inputs are
    evaluated differently.  The signal is consumed lazily and can be
    infinite. The kernel is fully consumed before the calculations begin.

    Article:  https://betterexplained.com/articles/intuitive-convolution/
    """
    # convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
    # convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
    # convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
    # convolve(data, [1, -2, 1]) → 2nd derivative estimate
    kernel = tuple(kernel)[::-1]
    n = len(kernel)
    # Pad both ends with n-1 zeros so that partially overlapping positions
    # are included, giving len(signal) + len(kernel) - 1 output terms.
    padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
    # Each window of n consecutive padded values is dotted with the
    # reversed kernel to produce one output term.
    windowed_signal = sliding_window(padded_signal, n)
    return map(math.sumprod, repeat(kernel), windowed_signal)

def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

    (x - 5) (x + 4) (x - 3)  expands to:   x³ -4x² -17x + 60
    """
    # polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
    # Each root r contributes the linear factor (x - r), with coefficients
    # (1, -r); multiplying polynomials means convolving their coefficients.
    linear_factors = ((1, -root) for root in roots)
    return list(functools.reduce(convolve, linear_factors, [1]))

def polynomial_eval(coefficients, x):
"""Evaluate a polynomial at a specific value.

Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 5
# polynomial_eval([1, -4, -17, 60], x=5) → 0
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return math.sumprod(coefficients, powers)

def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

    f(x)  =  x³ -4x² -17x + 60
    f'(x) = 3x² -8x  -17
    """
    # polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
    # Pair each coefficient (highest power first) with its exponent. The
    # constant term has no partner exponent, so zip drops it.
    exponents = range(len(coefficients) - 1, 0, -1)
    return [c * e for c, e in zip(coefficients, exponents)]

def sieve(n):
    "Primes less than n."
    # sieve(30) → 2 3 5 7 11 13 17 19 23 29
    if n > 2:
        yield 2
    # data[i] == 1 marks i as a prime candidate. The (0, 1) pattern marks
    # every odd index and clears every even one, so after 2 is yielded
    # above, only odd candidates remain.
    data = bytearray((0, 1)) * (n // 2)
    # iter_index searches data through its index() method, so candidates
    # cleared by earlier iterations are skipped. Only p <= isqrt(n) is needed.
    for p in iter_index(data, 1, start=3, stop=math.isqrt(n) + 1):
        # Clear the odd multiples p*p, p*p + 2p, ...; even multiples were
        # never marked.
        data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
    # Every odd index >= 3 still marked is prime.
    yield from iter_index(data, 1, start=3)

def factor(n):
    "Prime factors of n, smallest first, repeated according to multiplicity."
    # factor(99) → 3 3 11
    # factor(1_000_000_000_000_007) → 47 59 360620266859
    # factor(1_000_000_000_000_403) → 1000000000000403
    if n < 1:
        # 0 is divisible by every prime, so trial division would yield 2
        # forever; negative n already failed inside math.isqrt.
        raise ValueError('n must be a positive integer')
    # Trial division by every prime up to the square root of the input.
    for prime in sieve(math.isqrt(n) + 1):
        while not n % prime:
            yield prime
            n //= prime
            if n == 1:
                return
    # A cofactor left above 1 has no prime factor <= sqrt(input), so it
    # is itself prime.
    if n > 1:
        yield n

def totient(n):
    "Euler's totient: how many of the integers 1..n are coprime to n."
    # https://mathworld.wolfram.com/TotientFunction.html
    # totient(12) → 4 because len([1, 5, 7, 11]) == 4
    # Product formula n * Π(1 - 1/p) over the distinct primes p dividing n,
    # kept in exact integer arithmetic.
    distinct_primes = set(factor(n))
    result = n
    for p in distinct_primes:
        result -= result // p
    return result
```