itertools
--- 建立產生高效率迴圈之疊代器的函式¶
這個模組實作了許多 疊代器 (iterator) 構建塊 (building block),其靈感來自 APL、Haskell 和 SML 的結構。每個構建塊都以適合 Python 的形式來重新設計。
這個模組標準化了快速且高效率利用記憶體的核心工具集,這些工具本身或組合使用都很有用。它們共同構成了一個「疊代器代數 (iterator algebra)」,使得在純 Python 中簡潔且高效地建構專用工具成為可能。
例如,SML 提供了一個造表工具:tabulate(f)
,它產生一個序列 f(0), f(1), ...
。在 Python 中,可以透過結合 map()
和 count()
組成 map(f, count())
以達到同樣的效果。
無限疊代器:
疊代器 |
引數 |
結果 |
範例 |
---|---|---|---|
[start[, step]] |
start, start+step, start+2*step, ... |
|
|
p |
p0, p1, ... plast, p0, p1, ... |
|
|
elem [,n] |
elem, elem, elem,... 重複無限次或 n 次 |
|
在最短輸入序列 (shortest input sequence) 處終止的疊代器:
疊代器 |
引數 |
結果 |
範例 |
---|---|---|---|
p [,func] |
p0, p0+p1, p0+p1+p2, ... |
|
|
p, n |
(p0, p1, ..., p_n-1), ... |
|
|
p, q, ... |
p0, p1, ... plast, q0, q1, ... |
|
|
可疊代物件 |
p0, p1, ... plast, q0, q1, ... |
|
|
data, selectors |
(d[0] if s[0]), (d[1] if s[1]), ... |
|
|
predicate, seq |
seq[n], seq[n+1],當 predicate 失敗時開始 |
|
|
predicate, seq |
當 predicate(elem) 失敗時 seq 的元素 |
|
|
iterable[, key] |
根據 key(v) 的值分組的子疊代器 |
|
|
seq, [start,] stop [, step] |
seq[start:stop:step] 的元素 |
|
|
可疊代物件 |
(p[0], p[1]), (p[1], p[2]) |
|
|
func, seq |
func(*seq[0]), func(*seq[1]), ... |
|
|
predicate, seq |
seq[0], seq[1],直到 predicate 失敗 |
|
|
it, n |
it1, it2, ... itn,將一個疊代器分成 n 個 |
|
|
p, q, ... |
(p[0], q[0]), (p[1], q[1]), ... |
|
組合疊代器:
疊代器 |
引數 |
結果 |
---|---|---|
p, q, ... [repeat=1] |
笛卡爾乘積 (cartesian product),相當於巢狀的 for 迴圈 |
|
p[, r] |
長度為 r 的元組,所有可能的定序,無重複元素 |
|
p, r |
長度為 r 的元組,按照排序過後的定序,無重複元素 |
|
p, r |
長度為 r 的元組,按照排序過後的定序,有重複元素 |
範例 |
結果 |
---|---|
|
|
|
|
|
|
|
|
Itertool 函式¶
以下的函式都會建構並回傳疊代器。一些函式提供無限長度的串流 (stream),因此應僅由截斷串流的函式或迴圈來存取它們。
- itertools.accumulate(iterable[, function, *, initial=None])¶
建立一個回傳累積和的疊代器,或其他二進位函式的累積結果。
function 預設為加法。function 應接受兩個引數,即累積總和和來自 iterable 的值。
如果提供了 initial 值,則累積將從該值開始,並且輸出的元素數將比輸入的可疊代物件多一個。
大致等價於:
def accumulate(iterable, function=operator.add, *, initial=None): 'Return running totals' # accumulate([1,2,3,4,5]) → 1 3 6 10 15 # accumulate([1,2,3,4,5], initial=100) → 100 101 103 106 110 115 # accumulate([1,2,3,4,5], operator.mul) → 1 2 6 24 120 iterator = iter(iterable) total = initial if initial is None: try: total = next(iterator) except StopIteration: return yield total for element in iterator: total = function(total, element) yield total
function 引數可以被設定為
min()
以得到連續的最小值,設定為max()
以得到連續的最大值,或者設定為operator.mul()
以得到連續的乘積。也可以透過累積利息和付款來建立攤銷表 (Amortization tables) :>>> data = [3, 4, 6, 2, 1, 9, 0, 7, 5, 8] >>> list(accumulate(data, max)) # running maximum [3, 4, 6, 6, 6, 9, 9, 9, 9, 9] >>> list(accumulate(data, operator.mul)) # running product [3, 12, 72, 144, 144, 1296, 0, 0, 0, 0] # Amortize a 5% loan of 1000 with 10 annual payments of 90 >>> update = lambda balance, payment: round(balance * 1.05) - payment >>> list(accumulate(repeat(90, 10), update, initial=1_000)) [1000, 960, 918, 874, 828, 779, 728, 674, 618, 559, 497]
可參見
functools.reduce()
,其是個類似的函式,但僅回傳最終的累積值。在 3.2 版被加入.
在 3.3 版的變更: 新增可選的 function 參數。
在 3.8 版的變更: 新增可選的 initial 參數。
- itertools.batched(iterable, n, *, strict=False)¶
將來自 iterable 的資料分批為長度為 n 的元組。最後一個批次可能比 n 短。
If strict is true, will raise a
ValueError
if the final batch is shorter than n.對輸入的可疊代物件進行迴圈,並將資料累積到大小為 n 的元組中。輸入是惰性地被消耗 (consumed lazily) 的,會剛好足夠填充一批的資料。一旦批次填滿或輸入的可疊代物件耗盡,就會 yield 出結果:
>>> flattened_data = ['roses', 'red', 'violets', 'blue', 'sugar', 'sweet'] >>> unflattened = list(batched(flattened_data, 2)) >>> unflattened [('roses', 'red'), ('violets', 'blue'), ('sugar', 'sweet')]
大致等價於:
def batched(iterable, n, *, strict=False): # batched('ABCDEFG', 3) → ABC DEF G if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(islice(iterator, n)): if strict and len(batch) != n: raise ValueError('batched(): incomplete batch') yield batch
在 3.12 版被加入.
在 3.13 版的變更: 新增 strict 選項。
- itertools.chain(*iterables)¶
建立一個疊代器,從第一個可疊代物件回傳元素直到其耗盡,然後繼續處理下一個可疊代物件,直到所有可疊代物件都耗盡。這將多個資料來源結合為單一個疊代器。大致等價於:
def chain(*iterables): # chain('ABC', 'DEF') → A B C D E F for iterable in iterables: yield from iterable
- classmethod chain.from_iterable(iterable)¶
chain()
的另一個建構函式。從單個可疊代的引數中得到鏈接的輸入,該引數是惰性計算的。大致等價於:def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) → A B C D E F for iterable in iterables: yield from iterable
- itertools.combinations(iterable, r)¶
從輸入 iterable 中回傳長度為 r 的元素的子序列。
輸出是
product()
的子序列,僅保留作為 iterable 子序列的條目。輸出的長度由math.comb()
給定,當0 ≤ r ≤ n
時,長度為n! / r! / (n - r)!
,當r > n
時為零。根據輸入值 iterable 的順序,組合的元組會按照字典順序輸出。如果輸入的 iterable 已經排序,則輸出的元組也將按排序的順序產生。
元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,則每個組合內將不會有重複的值。
大致等價於:
def combinations(iterable, r): # combinations('ABCD', 2) → AB AC AD BC BD CD # combinations(range(4), 3) → 012 013 023 123 pool = tuple(iterable) n = len(pool) if r > n: return indices = list(range(r)) yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != i + n - r: break else: return indices[i] += 1 for j in range(i+1, r): indices[j] = indices[j-1] + 1 yield tuple(pool[i] for i in indices)
- itertools.combinations_with_replacement(iterable, r)¶
回傳來自輸入 iterable 的長度為 r 的子序列,且允許個別元素重複多次。
其輸出是一個
product()
的子序列,僅保留作為 iterable 子序列(可能有重複元素)的條目。當n > 0
時,回傳的子序列數量為(n + r - 1)! / r! / (n - 1)!
。根據輸入值 iterable 的順序,組合的元組會按照字典順序輸出。如果輸入的 iterable 已經排序,則輸出的元組也將按排序的順序產生。
元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,生成的組合也將是獨特的。
大致等價於:
def combinations_with_replacement(iterable, r): # combinations_with_replacement('ABC', 2) → AA AB AC BB BC CC pool = tuple(iterable) n = len(pool) if not n and r: return indices = [0] * r yield tuple(pool[i] for i in indices) while True: for i in reversed(range(r)): if indices[i] != n - 1: break else: return indices[i:] = [indices[i] + 1] * (r - i) yield tuple(pool[i] for i in indices)
在 3.1 版被加入.
- itertools.compress(data, selectors)¶
建立一個疊代器,回傳 data 中對應 selectors 的元素為 true 的元素。當 data 或 selectors 可疊代物件耗盡時停止。大致等價於:
def compress(data, selectors): # compress('ABCDEF', [1,0,1,0,1,1]) → A C E F return (datum for datum, selector in zip(data, selectors) if selector)
在 3.1 版被加入.
- itertools.count(start=0, step=1)¶
建立一個疊代器,回傳從 start 開始的等差的值。可以與
map()
一起使用來產生連續的資料點,或與zip()
一起使用來增加序列號。大致等價於:def count(start=0, step=1): # count(10) → 10 11 12 13 14 ... # count(2.5, 0.5) → 2.5 3.0 3.5 ... n = start while True: yield n n += step
當用浮點數計數時,將上述程式碼替換為乘法有時可以獲得更好的精確度,例如:
(start + step * i for i in count())
。在 3.1 版的變更: 新增 step 引數並允許非整數引數。
- itertools.cycle(iterable)¶
建立一個疊代器,回傳 iterable 中的元素並保存每個元素的副本。當可疊代物件耗盡時,從保存的副本中回傳元素。會無限次的重複。大致等價於:
def cycle(iterable): # cycle('ABCD') → A B C D A B C D A B C D ... saved = [] for element in iterable: yield element saved.append(element) while saved: for element in saved: yield element
此 itertool 可能需要大量的輔助儲存空間(取決於可疊代物件的長度)。
- itertools.dropwhile(predicate, iterable)¶
建立一個疊代器,在 predicate 為 true 時丟棄 iterable 中的元素,之後回傳每個元素。大致等價於:
def dropwhile(predicate, iterable): # dropwhile(lambda x: x<5, [1,4,6,3,8]) → 6 3 8 iterator = iter(iterable) for x in iterator: if not predicate(x): yield x break for x in iterator: yield x
注意,在 predicate 首次變為 False 之前,這不會產生任何輸出,所以此 itertool 可能會有較長的啟動時間。
- itertools.filterfalse(predicate, iterable)¶
建立一個疊代器,過濾 iterable 中的元素,僅回傳 predicate 為 False 值的元素。如果 predicate 是
None
,則回傳為 False 的項目。大致等價於:def filterfalse(predicate, iterable): # filterfalse(lambda x: x<5, [1,4,6,3,8]) → 6 8 if predicate is None: predicate = bool for x in iterable: if not predicate(x): yield x
- itertools.groupby(iterable, key=None)¶
建立一個疊代器,回傳 iterable 中連續的鍵和群組。key 是一個為每個元素計算鍵值的函式。如果其未指定或為
None
,則 key 預設為一個識別性函式 (identity function),並回傳未被更改的元素。一般來說,可疊代物件需要已經用相同的鍵函式進行排序。groupby()
的操作類似於 Unix 中的uniq
過濾器。每當鍵函式的值發生變化時,它會產生一個 break 或新的群組(這就是為什麼通常需要使用相同的鍵函式對資料進行排序)。這種行為不同於 SQL 的 GROUP BY,其無論輸入順序如何都會聚合相同的元素。回傳的群組本身是一個與
groupby()
共享底層可疊代物件的疊代器。由於來源是共享的,當groupby()
物件前進時,前一個群組將不再可見。因此,如果之後需要該資料,應將其儲存為串列:groups = [] uniquekeys = [] data = sorted(data, key=keyfunc) for k, g in groupby(data, keyfunc): groups.append(list(g)) # Store group iterator as a list uniquekeys.append(k)
groupby()
大致等價於:def groupby(iterable, key=None): # [k for k, g in groupby('AAAABBBCCDAABBB')] → A B C D A B # [list(g) for k, g in groupby('AAAABBBCCD')] → AAAA BBB CC D keyfunc = (lambda x: x) if key is None else key iterator = iter(iterable) exhausted = False def _grouper(target_key): nonlocal curr_value, curr_key, exhausted yield curr_value for curr_value in iterator: curr_key = keyfunc(curr_value) if curr_key != target_key: return yield curr_value exhausted = True try: curr_value = next(iterator) except StopIteration: return curr_key = keyfunc(curr_value) while not exhausted: target_key = curr_key curr_group = _grouper(target_key) yield curr_key, curr_group if curr_key == target_key: for _ in curr_group: pass
- itertools.islice(iterable, stop)¶
- itertools.islice(iterable, start, stop[, step])
建立一個疊代器,回傳從 iterable 中選取的元素。其作用類似於序列切片 (sequence slicing),但不支援負數的 start、stop 或 step 的值。
如果 start 為零或
None
,則從零開始疊代。否則在達到 start 之前,會跳過 iterable 中的元素。如果 stop 為
None
,則疊代將繼續前進直到輸入耗盡。如果指定了 stop,則在達到指定位置時停止。如果 step 為
None
,則步長 (step) 預設為一。元素會連續回傳,除非將 step 設定為大於一,這會導致一些項目被跳過。大致等價於:
def islice(iterable, *args): # islice('ABCDEFG', 2) → A B # islice('ABCDEFG', 2, 4) → C D # islice('ABCDEFG', 2, None) → C D E F G # islice('ABCDEFG', 0, None, 2) → A C E G s = slice(*args) start = 0 if s.start is None else s.start stop = s.stop step = 1 if s.step is None else s.step if start < 0 or (stop is not None and stop < 0) or step <= 0: raise ValueError indices = count() if stop is None else range(max(start, stop)) next_i = start for i, element in zip(indices, iterable): if i == next_i: yield element next_i += step
If the input is an iterator, then fully consuming the islice advances the input iterator by
max(start, stop)
steps regardless of the step value.
- itertools.pairwise(iterable)¶
回傳從輸入的 iterable 中提取的連續重疊對。
輸出疊代器中的 2 元組數量將比輸入少一個。如果輸入的可疊代物件中的值少於兩個,則輸出將為空值。
大致等價於:
def pairwise(iterable): # pairwise('ABCDEFG') → AB BC CD DE EF FG iterator = iter(iterable) a = next(iterator, None) for b in iterator: yield a, b a = b
在 3.10 版被加入.
- itertools.permutations(iterable, r=None)¶
回傳 iterable 中連續且長度為 r 的元素排列 。
如果未指定 r 或其值為
None
,則 r 預設為 iterable 的長度,並產生所有可能的完整長度的排列。輸出是
product()
的子序列,其中重複元素的條目已被濾除。輸出的長度由math.perm()
給定,當0 ≤ r ≤ n
時,長度為n! / (n - r)!
,當r > n
時為零。根據輸入值 iterable 的順序,排列的元組會按照字典順序輸出。如果輸入的 iterable 已排序,則輸出的元組也將按排序的順序產生。
元素是根據它們的位置(而非值)來決定其唯一性。如果輸入的元素都是獨特的,則排列中將不會有重複的值。
大致等價於:
def permutations(iterable, r=None): # permutations('ABCD', 2) → AB AC AD BA BC BD CA CB CD DA DB DC # permutations(range(3)) → 012 021 102 120 201 210 pool = tuple(iterable) n = len(pool) r = n if r is None else r if r > n: return indices = list(range(n)) cycles = list(range(n, n-r, -1)) yield tuple(pool[i] for i in indices[:r]) while n: for i in reversed(range(r)): cycles[i] -= 1 if cycles[i] == 0: indices[i:] = indices[i+1:] + indices[i:i+1] cycles[i] = n - i else: j = cycles[i] indices[i], indices[-j] = indices[-j], indices[i] yield tuple(pool[i] for i in indices[:r]) break else: return
- itertools.product(*iterables, repeat=1)¶
Cartesian product of the input iterables.
大致等價於產生器運算式中的巢狀 for 迴圈。例如,
product(A, B)
的回傳結果與((x,y) for x in A for y in B)
相同。巢狀迴圈的循環類似於里程表,最右邊的元素在每次疊代時前進。這種模式會建立字典順序,因此如果輸入的 iterables 已排序,則輸出的乘積元組也將按排序的順序產生。
要計算可疊代物件自身的乘積,可以使用可選的 repeat 關鍵字引數來指定重複次數。例如,
product(A, repeat=4)
與product(A, A, A, A)
相同。此函式大致等價於以下的程式碼,不同之處在於真正的實作不會在記憶體中建立中間結果:
def product(*iterables, repeat=1): # product('ABCD', 'xy') → Ax Ay Bx By Cx Cy Dx Dy # product(range(2), repeat=3) → 000 001 010 011 100 101 110 111 if repeat < 0: raise ValueError('repeat argument cannot be negative') pools = [tuple(pool) for pool in iterables] * repeat result = [[]] for pool in pools: result = [x+[y] for x in result for y in pool] for prod in result: yield tuple(prod)
在
product()
執行之前,它會完全消耗輸入的 iterables,並將值的池 (pools of values) 保存在記憶體中以產生乘積。因此,它僅對有限的輸入有用。
- itertools.repeat(object[, times])¶
建立一個疊代器,反覆回傳 object。除非指定了 times 引數,否則會執行無限次。
大致等價於:
def repeat(object, times=None): # repeat(10, 3) → 10 10 10 if times is None: while True: yield object else: for i in range(times): yield object
repeat 的常見用途是為 map 或 zip 提供定值的串流:
>>> list(map(pow, range(10), repeat(2))) [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
- itertools.starmap(function, iterable)¶
建立一個疊代器,使用從 iterable 獲取的引數計算 function 。當引數參數已經被「預先壓縮 (pre-zipped)」成元組時,使用此方法代替
map()
。map()
和starmap()
之間的區別類似於function(a,b)
和function(*c)
之間的區別。大致等價於:def starmap(function, iterable): # starmap(pow, [(2,5), (3,2), (10,3)]) → 32 9 1000 for args in iterable: yield function(*args)
- itertools.takewhile(predicate, iterable)¶
建立一個疊代器,只在 predicate 為 true 時回傳 iterable 中的元素。大致等價於:
def takewhile(predicate, iterable): # takewhile(lambda x: x<5, [1,4,6,3,8]) → 1 4 for x in iterable: if not predicate(x): break yield x
注意,第一個不符合條件判斷的元素將從輸入疊代器中被消耗,且無法再存取它。如果應用程式希望在 takewhile 耗盡後進一步消耗輸入疊代器,這可能會是個問題。為了解決這個問題,可以考慮使用 more-itertools 中的 before_and_after() 作為替代。
- itertools.tee(iterable, n=2)¶
從一個 iterable 中回傳 n 個獨立的疊代器。
大致等價於:
def tee(iterable, n=2): if n < 0: raise ValueError if n == 0: return () iterator = _tee(iterable) result = [iterator] for _ in range(n - 1): result.append(_tee(iterator)) return tuple(result) class _tee: def __init__(self, iterable): it = iter(iterable) if isinstance(it, _tee): self.iterator = it.iterator self.link = it.link else: self.iterator = it self.link = [None, None] def __iter__(self): return self def __next__(self): link = self.link if link[1] is None: link[0] = next(self.iterator) link[1] = [None, None] value, self.link = link return value
When the input iterable is already a tee iterator object, all members of the return tuple are constructed as if they had been produced by the upstream
tee()
call. This "flattening step" allows nestedtee()
calls to share the same underlying data chain and to have a single update step rather than a chain of calls.The flattening property makes tee iterators efficiently peekable:
def lookahead(tee_iterator): "Return the next value without moving the input forward" [forked_iterator] = tee(tee_iterator, 1) return next(forked_iterator)
>>> iterator = iter('abcdef') >>> [iterator] = tee(iterator, 1) # Make the input peekable >>> next(iterator) # Move the iterator forward 'a' >>> lookahead(iterator) # Check next value 'b' >>> next(iterator) # Continue moving forward 'b'
tee
疊代器不是執行緒安全 (threadsafe) 的。當同時使用由同一個tee()
呼叫所回傳的疊代器時,即使原始的 iterable 是執行緒安全的,也可能引發RuntimeError
。此 itertool 可能需要大量的輔助儲存空間(取決於需要儲存多少臨時資料)。通常如果一個疊代器在另一個疊代器開始之前使用了大部分或全部的資料,使用
list()
會比tee()
更快。
- itertools.zip_longest(*iterables, fillvalue=None)¶
建立一個疊代器,聚合來自每個 iterables 中的元素。
如果 iterables 的長度不一,則使用 fillvalue 填充缺少的值。如果未指定,fillvalue 會預設為
None
。疊代將持續直到最長的可疊代物件耗盡為止。
大致等價於:
def zip_longest(*iterables, fillvalue=None): # zip_longest('ABCD', 'xy', fillvalue='-') → Ax By C- D- iterators = list(map(iter, iterables)) num_active = len(iterators) if not num_active: return while True: values = [] for i, iterator in enumerate(iterators): try: value = next(iterator) except StopIteration: num_active -= 1 if not num_active: return iterators[i] = repeat(fillvalue) value = fillvalue values.append(value) yield tuple(values)
如果其中一個 iterables 可能是無限的,那麼應該用別的可以限制呼叫次數的方法來包裝
zip_longest()
函式(例如islice()
或takewhile()
)。
Itertools 應用技巧¶
此段落展示了使用現有的 itertools 作為構建塊來建立擴展工具集的應用技巧。
itertools 應用技巧的主要目的是教學。這些應用技巧展示了對單個工具進行思考的各種方式 —— 例如,chain.from_iterable
與攤平 (flattening) 的概念相關。這些應用技巧還提供了組合使用工具的想法 —— 例如,starmap()
和 repeat()
如何一起工作。另外還展示了將 itertools 與 operator
和 collections
模組一同使用以及與內建 itertools(如 map()
、filter()
、reversed()
和 enumerate()
)一同使用的模式。
應用技巧的次要目的是作為 itertools 的孵化器。accumulate()
, compress()
和 pairwise()
itertools 最初都是作為應用技巧出現的。目前,sliding_window()
、iter_index()
和 sieve()
的應用技巧正在被測試,以確定它們是否有價值被收錄到內建的 itertools 中。
幾乎所有這些應用技巧以及許多其他應用技巧都可以從 Python Package Index 上的 more-itertools 專案中安裝:
python -m pip install more-itertools
許多應用技巧提供了與底層工具集相同的高性能。透過一次處理一個元素而不是將整個可疊代物件一次性引入記憶體,能保持優異的記憶體性能。以函式風格 (functional style) 將工具連接在一起,能將程式碼的數量維持在較少的情況。透過優先使用「向量化 (vectorized)」的構建塊而不是使用會造成直譯器負擔的 for 迴圈和產生器,則能保持高速度。
from collections import deque
from contextlib import suppress
from functools import reduce
from math import sumprod, isqrt
from operator import itemgetter, getitem, mul, neg
def take(n, iterable):
"Return first n items of the iterable as a list."
return list(islice(iterable, n))
def prepend(value, iterable):
"Prepend a single value in front of an iterable."
# prepend(1, [2, 3, 4]) → 1 2 3 4
return chain([value], iterable)
def tabulate(function, start=0):
"Return function(0), function(1), ..."
return map(function, count(start))
def repeatfunc(function, times=None, *args):
"Repeat calls to a function with specified arguments."
if times is None:
return starmap(function, repeat(args))
return starmap(function, repeat(args, times))
def flatten(list_of_lists):
"Flatten one level of nesting."
return chain.from_iterable(list_of_lists)
def ncycles(iterable, n):
"Returns the sequence elements n times."
return chain.from_iterable(repeat(tuple(iterable), n))
def loops(n):
"Loop n times. Like range(n) but without creating integers."
# for _ in loops(100): ...
return repeat(None, n)
def tail(n, iterable):
"Return an iterator over the last n items."
# tail(3, 'ABCDEFG') → E F G
return iter(deque(iterable, maxlen=n))
def consume(iterator, n=None):
"Advance the iterator n-steps ahead. If n is None, consume entirely."
# Use functions that consume iterators at C speed.
if n is None:
deque(iterator, maxlen=0)
else:
next(islice(iterator, n, n), None)
def nth(iterable, n, default=None):
"Returns the nth item or a default value."
return next(islice(iterable, n, None), default)
def quantify(iterable, predicate=bool):
"Given a predicate that returns True or False, count the True results."
return sum(map(predicate, iterable))
def first_true(iterable, default=False, predicate=None):
"Returns the first true value or the *default* if there is no true value."
# first_true([a,b,c], x) → a or b or c or x
# first_true([a,b], x, f) → a if f(a) else b if f(b) else x
return next(filter(predicate, iterable), default)
def all_equal(iterable, key=None):
"Returns True if all the elements are equal to each other."
# all_equal('4٤௪౪໔', key=int) → True
return len(take(2, groupby(iterable, key))) <= 1
def unique_justseen(iterable, key=None):
"Yield unique elements, preserving order. Remember only the element just seen."
# unique_justseen('AAAABBBCCDAABBB') → A B C D A B
# unique_justseen('ABBcCAD', str.casefold) → A B c A D
if key is None:
return map(itemgetter(0), groupby(iterable))
return map(next, map(itemgetter(1), groupby(iterable, key)))
def unique_everseen(iterable, key=None):
"Yield unique elements, preserving order. Remember all elements ever seen."
# unique_everseen('AAAABBBCCDAABBB') → A B C D
# unique_everseen('ABBcCAD', str.casefold) → A B c D
seen = set()
if key is None:
for element in filterfalse(seen.__contains__, iterable):
seen.add(element)
yield element
else:
for element in iterable:
k = key(element)
if k not in seen:
seen.add(k)
yield element
def unique(iterable, key=None, reverse=False):
"Yield unique elements in sorted order. Supports unhashable inputs."
# unique([[1, 2], [3, 4], [1, 2]]) → [1, 2] [3, 4]
sequenced = sorted(iterable, key=key, reverse=reverse)
return unique_justseen(sequenced, key=key)
def sliding_window(iterable, n):
"Collect data into overlapping fixed-length chunks or blocks."
# sliding_window('ABCDEFG', 4) → ABCD BCDE CDEF DEFG
iterator = iter(iterable)
window = deque(islice(iterator, n - 1), maxlen=n)
for x in iterator:
window.append(x)
yield tuple(window)
def grouper(iterable, n, *, incomplete='fill', fillvalue=None):
"Collect data into non-overlapping fixed-length chunks or blocks."
# grouper('ABCDEFG', 3, fillvalue='x') → ABC DEF Gxx
# grouper('ABCDEFG', 3, incomplete='strict') → ABC DEF ValueError
# grouper('ABCDEFG', 3, incomplete='ignore') → ABC DEF
iterators = [iter(iterable)] * n
match incomplete:
case 'fill':
return zip_longest(*iterators, fillvalue=fillvalue)
case 'strict':
return zip(*iterators, strict=True)
case 'ignore':
return zip(*iterators)
case _:
raise ValueError('Expected fill, strict, or ignore')
def roundrobin(*iterables):
"Visit input iterables in a cycle until each is exhausted."
# roundrobin('ABC', 'D', 'EF') → A D E B F C
# Algorithm credited to George Sakkis
iterators = map(iter, iterables)
for num_active in range(len(iterables), 0, -1):
iterators = cycle(islice(iterators, num_active))
yield from map(next, iterators)
def subslices(seq):
"Return all contiguous non-empty subslices of a sequence."
# subslices('ABCD') → A AB ABC ABCD B BC BCD C CD D
slices = starmap(slice, combinations(range(len(seq) + 1), 2))
return map(getitem, repeat(seq), slices)
def iter_index(iterable, value, start=0, stop=None):
"Return indices where a value occurs in a sequence or iterable."
# iter_index('AABCADEAF', 'A') → 0 1 4 7
seq_index = getattr(iterable, 'index', None)
if seq_index is None:
iterator = islice(iterable, start, stop)
for i, element in enumerate(iterator, start):
if element is value or element == value:
yield i
else:
stop = len(iterable) if stop is None else stop
i = start
with suppress(ValueError):
while True:
yield (i := seq_index(value, i, stop))
i += 1
def iter_except(function, exception, first=None):
"Convert a call-until-exception interface to an iterator interface."
# iter_except(d.popitem, KeyError) → non-blocking dictionary iterator
with suppress(exception):
if first is not None:
yield first()
while True:
yield function()
以下的應用技巧具有更多的數學風格:
def powerset(iterable):
"Subsequences of the iterable from shortest to longest."
# powerset([1,2,3]) → () (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
s = list(iterable)
return chain.from_iterable(combinations(s, r) for r in range(len(s)+1))
def sum_of_squares(iterable):
"Add up the squares of the input values."
# sum_of_squares([10, 20, 30]) → 1400
return sumprod(*tee(iterable))
def reshape(matrix, columns):
"Reshape a 2-D matrix to have a given number of columns."
# reshape([(0, 1), (2, 3), (4, 5)], 3) → (0, 1, 2), (3, 4, 5)
return batched(chain.from_iterable(matrix), columns, strict=True)
def transpose(matrix):
"Swap the rows and columns of a 2-D matrix."
# transpose([(1, 2, 3), (11, 22, 33)]) → (1, 11) (2, 22) (3, 33)
return zip(*matrix, strict=True)
def matmul(m1, m2):
"Multiply two matrices."
# matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]) → (49, 80), (41, 60)
n = len(m2[0])
return batched(starmap(sumprod, product(m1, transpose(m2))), n)
def convolve(signal, kernel):
"""Discrete linear convolution of two iterables.
Equivalent to polynomial multiplication.
Convolutions are mathematically commutative; however, the inputs are
evaluated differently. The signal is consumed lazily and can be
infinite. The kernel is fully consumed before the calculations begin.
Article: https://betterexplained.com/articles/intuitive-convolution/
Video: https://www.youtube.com/watch?v=KuXjwB4LzSA
"""
# convolve([1, -1, -20], [1, -3]) → 1 -4 -17 60
# convolve(data, [0.25, 0.25, 0.25, 0.25]) → Moving average (blur)
# convolve(data, [1/2, 0, -1/2]) → 1st derivative estimate
# convolve(data, [1, -2, 1]) → 2nd derivative estimate
kernel = tuple(kernel)[::-1]
n = len(kernel)
padded_signal = chain(repeat(0, n-1), signal, repeat(0, n-1))
windowed_signal = sliding_window(padded_signal, n)
return map(sumprod, repeat(kernel), windowed_signal)
def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
(x - 5) (x + 4) (x - 3) expands to: x³ -4x² -17x + 60
"""
# polynomial_from_roots([5, -4, 3]) → [1, -4, -17, 60]
factors = zip(repeat(1), map(neg, roots))
return list(reduce(convolve, factors, [1]))
def polynomial_eval(coefficients, x):
"""Evaluate a polynomial at a specific value.
Computes with better numeric stability than Horner's method.
"""
# Evaluate x³ -4x² -17x + 60 at x = 5
# polynomial_eval([1, -4, -17, 60], x=5) → 0
n = len(coefficients)
if not n:
return type(x)(0)
powers = map(pow, repeat(x), reversed(range(n)))
return sumprod(coefficients, powers)
def polynomial_derivative(coefficients):
"""Compute the first derivative of a polynomial.
f(x) = x³ -4x² -17x + 60
f'(x) = 3x² -8x -17
"""
# polynomial_derivative([1, -4, -17, 60]) → [3, -8, -17]
n = len(coefficients)
powers = reversed(range(1, n))
return list(map(mul, coefficients, powers))
def sieve(n):
"Primes less than n."
# sieve(30) → 2 3 5 7 11 13 17 19 23 29
if n > 2:
yield 2
data = bytearray((0, 1)) * (n // 2)
for p in iter_index(data, 1, start=3, stop=isqrt(n) + 1):
data[p*p : n : p+p] = bytes(len(range(p*p, n, p+p)))
yield from iter_index(data, 1, start=3)
def factor(n):
"Prime factors of n."
# factor(99) → 3 3 11
# factor(1_000_000_000_000_007) → 47 59 360620266859
# factor(1_000_000_000_000_403) → 1000000000000403
for prime in sieve(isqrt(n) + 1):
while not n % prime:
yield prime
n //= prime
if n == 1:
return
if n > 1:
yield n
def is_prime(n):
"Return True if n is prime."
# is_prime(1_000_000_000_000_403) → True
return n > 1 and next(factor(n)) == n
def totient(n):
"Count of natural numbers up to n that are coprime to n."
# https://mathworld.wolfram.com/TotientFunction.html
# totient(12) → 4 because len([1, 5, 7, 11]) == 4
for prime in set(factor(n)):
n -= n // prime
return n