Befriending the itertools module for efficient iteration in Python

Prerequisites

You will need a working knowledge of Python to get the best out of this article. Most code samples are complete Python programs: paste the contents into a Python file and run it in your terminal with python name_of_your_file.py. Most examples also have an accompanying repl on replit where you can run the code, although you will need a replit account. You can sign up here.

Exploring itertools

The itertools module has various utilities for performing efficient operations on iterables. Some combine elements from different iterables, some group the items in an iterable by a key, and some create infinite iterators that produce values lazily instead of storing them all in memory. Getting familiar with these utilities will save you from reimplementing the same functionality yourself. If you work with large lists or datasets, or just want to indulge your curiosity, the itertools module is definitely worth a look.

I call them utilities because although they walk and quack like functions, they are actually implemented as classes, except tee, which is a built-in function.

Let's use the inspect module to get a high-level overview of the itertools module. We skip names beginning with _, since these are generally not meant for public use.

import inspect
import itertools
import pprint
import types
 
from collections import defaultdict
 
report = defaultdict(list)
 
for name, obj in inspect.getmembers(itertools):
    if name.startswith('_'):
        continue
 
    if inspect.isclass(obj):
        report['classes'].append(name)
    elif inspect.isgeneratorfunction(obj):
        # checked before isfunction(), since every generator function is also a function
        report['generator_functions'].append(name)
    elif inspect.isfunction(obj):
        report['functions'].append(name)
    elif inspect.isbuiltin(obj) and isinstance(obj, types.BuiltinFunctionType):
        report['builtin_functions'].append(name)
 
for key, values in report.items():
    if values:
        pprint.pprint(f"{key}: {', '.join(values)}")

You can run this example on replit.

On Python 3.10 or 3.11, you should see the following (Python 3.12 and later also list batched among the classes):

('classes: accumulate, chain, combinations, combinations_with_replacement, '
 'compress, count, cycle, dropwhile, filterfalse, groupby, islice, pairwise, '
 'permutations, product, repeat, starmap, takewhile, zip_longest')
'builtin_functions: tee'

itertools.accumulate

accumulate(iterable[, function, *, initial=None])

Makes an iterator that yields the running results of applying function cumulatively to the items of the iterable, from left to right. function takes two arguments: the accumulated total so far and the next value from the iterable. If function is not provided, accumulate yields cumulative sums of the elements. For example, accumulate([1, 2, 3]) returns an iterator that yields 1, 3, 6: the results of (1), (1 + 2), and (1 + 2 + 3).

If initial is provided, it becomes the first value in the output iterator and serves as the starting point for the accumulation. initial is a keyword-only argument and was added in Python 3.8.
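A small sketch of how initial and a custom function change the output. The numbers and the running product with operator.mul are my own illustration, not part of the temperature example below. Note that initial adds one extra item at the front.

```python
import operator
from itertools import accumulate

values = [1, 2, 3]

# Default behaviour: running sums
running_sums = list(accumulate(values))

# initial is keyword-only; it is yielded first, so the output is one item longer
with_initial = list(accumulate(values, initial=100))

# Any two-argument function works, e.g. a running product
running_products = list(accumulate([1, 2, 3, 4], operator.mul))

print(running_sums)      # [1, 3, 6]
print(with_initial)      # [100, 101, 103, 106]
print(running_products)  # [1, 2, 6, 24]
```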

Let's look at this example that tracks daily temperatures and records the maximum and minimum temperatures encountered so far.

from itertools import accumulate
 
daily_temps = [20, 23, 18, 25, 24, 22, 27]  # Celsius
 
max_temps = list(accumulate(daily_temps, max))
min_temps = list(accumulate(daily_temps, min))
 
print("Daily temperatures:", daily_temps)
print("Highest temp so far:", max_temps)
print("Lowest temp so far:", min_temps)

You can run this example on replit.

You should see this output:

Daily temperatures: [20, 23, 18, 25, 24, 22, 27]
Highest temp so far: [20, 23, 23, 25, 25, 25, 27]
Lowest temp so far: [20, 20, 18, 18, 18, 18, 18]

itertools.count

count(start=0, step=1)

Makes an iterator that yields numbers indefinitely, beginning at start and increasing by step. start defaults to 0 and step defaults to 1.
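Because count never stops on its own, you usually pair it with something that does. A minimal sketch, using my own sample data: zip stops at its shortest input, and islice takes a finite prefix. The negative step is just to show that step does not have to be positive.

```python
from itertools import count, islice

# zip stops when the list runs out, so the infinite count() is safe here
labelled = list(zip(count(1), ["a", "b", "c"]))

# islice takes the first 5 values of an infinite countdown
countdown = list(islice(count(10, -2), 5))

print(labelled)   # [(1, 'a'), (2, 'b'), (3, 'c')]
print(countdown)  # [10, 8, 6, 4, 2]
```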

SQL databases such as MySQL use the AUTO_INCREMENT keyword to create auto-incrementing ids. These start at 1 by default, although you can provide a different starting value. Here is how we can use itertools.count to implement similar auto-increment functionality.

import pprint
from collections import namedtuple
from itertools import count
 
Person = namedtuple('Person', ['person_id', 'first_name', 'last_name', 'age'])
 
class Table:
    def __init__(self, start=1):
        self.idfactory = count(start=start)
        self.records = []
 
    def add_record(self, first_name, last_name, age):
        person_id = next(self.idfactory)
        record = Person(person_id, first_name, last_name, age)
        self.records.append(record)
        return record
 
tb = Table()
tb.add_record('James', 'Waiyaki', 70)
tb.add_record('Martin', 'Mungai', 69)
pprint.pprint(tb.records)

You can run this example on replit.

The output confirms that the auto-increment is working:

[Person(person_id=1, first_name='James', last_name='Waiyaki', age=70),
 Person(person_id=2, first_name='Martin', last_name='Mungai', age=69)]

itertools.cycle

cycle(iterable)

Makes an iterator that yields elements from an iterable one by one until all items are exhausted, then starts over from the beginning, indefinitely. Since cycle saves a copy of each element as it goes, this works even when the input is a one-shot iterator such as a generator. It is useful when you want to "cycle" through the same elements over and over again.

Consider the traffic light simulator below. Notice how the color goes from Red to Yellow to Green and back to Red over and over. The loop never ends on its own, so press Ctrl+C to stop it.

import time
from itertools import cycle
 
 
def traffic_light_simulator():
    colors = cycle(['Red', 'Yellow', 'Green'])
    durations = {'Red': 20, 'Yellow': 5, 'Green': 30}
 
    for color in colors:
        print(f"Traffic light is now {color}")
        time.sleep(durations[color])
 
traffic_light_simulator()

You can run this example on replit.
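A bounded use of cycle is round-robin assignment. In this sketch the agents and tickets are hypothetical data; zip stops when the tickets run out, so the infinite cycle never loops forever.

```python
from itertools import cycle

# Hypothetical data: hand out support tickets to agents in turn
agents = ["Amina", "Brian", "Chao"]
tickets = ["T1", "T2", "T3", "T4", "T5"]

# zip stops at the shortest input (tickets), ending the infinite cycle
assignments = list(zip(tickets, cycle(agents)))

for ticket, agent in assignments:
    print(f"{ticket} -> {agent}")
```

T4 goes back to Amina because cycle starts over once Chao has been used.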

itertools.repeat

repeat(object[, times])

Makes an iterator that yields object repeatedly. If times is provided, it stops after that many repetitions; otherwise it yields object indefinitely.

We can use itertools.repeat together with the csv module to initialize a CSV template with a header row followed by 1,000 empty rows:

import csv
from itertools import repeat
 
 
def create_spreadsheet(filename: str, headers: list[str], num_rows: int):
    with open(filename, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(headers)
        writer.writerows(repeat([''] * len(headers), num_rows))
 
create_spreadsheet('template.csv', ['Name', 'Age', 'City'], 1000)

You can run this example on replit.
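Another common use of repeat is supplying a constant argument to map. A minimal sketch with my own numbers: map stops when range(5) is exhausted, so the infinite repeat(2) is safe.

```python
from itertools import repeat

# repeat(2) supplies the exponent on every call to pow
squares = list(map(pow, range(5), repeat(2)))

# With times, repeat is finite on its own
greetings = list(repeat("hi", 3))

print(squares)    # [0, 1, 4, 9, 16]
print(greetings)  # ['hi', 'hi', 'hi']
```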

itertools.islice

islice(iterable, stop)
islice(iterable, start, stop)
islice(iterable, start, stop, step)

Works much like good ol' slicing, but returns an iterator, whereas slicing returns a new list, tuple, or whatever type you are slicing. It also works on any iterable, including generators, which do not support slicing at all.

Let's create a list of the numbers from 0 to 100:

numbers = list(range(101))

To get multiples of 10 between 10 and 50 (including 10 & 50), we can create the following slice:

multiples = numbers[10:51:10] # [10, 20, 30, 40, 50]

This is the same as writing

multiples = numbers[slice(10, 51, 10)] # [10, 20, 30, 40, 50]

Notice how slicing a list returns a new list. This can be expensive memory-wise if we are creating large slices. How can we create a slice without building the whole thing in memory? Enter itertools.islice. It creates an iterator that lazily yields the same elements, one at a time.

import itertools
 
multiples_iter = itertools.islice(numbers, 10, 51, 10)  # stop is exclusive, as with slicing

Both multiples and multiples_iter can be iterated over to get the same elements. However, multiples_iter takes up less memory than multiples. Let's see:

import sys
 
print(f'size of multiples is: {sys.getsizeof(multiples)} bytes')
print(f'size of multiples_iter is: {sys.getsizeof(multiples_iter)} bytes')

Full code for reference:

import sys
from itertools import islice
 
numbers = list(range(101))
 
multiples = numbers[10:51:10]

# Equivalent, using an explicit slice object
multiples = numbers[slice(10, 51, 10)]

multiples_iter = islice(numbers, 10, 51, 10)
 
print(f'size of multiples is: {sys.getsizeof(multiples)} bytes')
print(f'size of multiples_iter is: {sys.getsizeof(multiples_iter)} bytes')

I get the output below when I run the comparison (exact sizes vary by Python version and platform):

size of multiples is: 96 bytes
size of multiples_iter is: 72 bytes

You can run this example on replit.

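The difference here is small, but the islice object stays the same size no matter how many elements the slice covers, while a list grows with every element it holds. Over larger slices, the savings can be significant.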

A few differences between slicing and islice:

  • islice does not support negative values for start, stop, or step.
  • Iterators returned by islice are not subscriptable. For example, we cannot do multiples_iter[0].
  • An islice iterator can only be consumed once; after that, it yields nothing.
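Where islice really pays off is on sources that cannot be sliced at all, such as generators. In this sketch, numbers_gen is a hypothetical stand-in for any lazy source (for example, the lines of a large file). It also shows that, like any iterator, an islice iterator is exhausted after one pass.

```python
from itertools import islice


def numbers_gen():
    # A generator stands in for any lazy source, such as lines of a large file
    yield from range(101)


# Regular slicing fails on a generator
try:
    numbers_gen()[10:51:10]
except TypeError as exc:
    print(f"TypeError: {exc}")

multiples_iter = islice(numbers_gen(), 10, 51, 10)

first_pass = list(multiples_iter)
second_pass = list(multiples_iter)  # already exhausted, so nothing is left

print(first_pass)   # [10, 20, 30, 40, 50]
print(second_pass)  # []
```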

itertools.filterfalse

filterfalse(predicate, iterable)

Makes an iterator that "picks" the items for which predicate returns False or a falsy value. predicate can be None or a function that takes an element from the iterable and returns a boolean value. If predicate is None, filterfalse returns the items that are themselves False or falsy.

Consider this example, where we want to keep only the odd numbers from a list of the integers from 1 to 10.

import itertools
 
def is_even(num):
    return num % 2 == 0
 
nums = list(range(1, 11))
odd_nums = list(itertools.filterfalse(is_even, nums))
 
print(f'Numbers: {nums}')
print(f'Odd numbers: {odd_nums}')

You can run this example on replit. Running this gives the following output:

Numbers: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Odd numbers: [1, 3, 5, 7, 9]
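To see what happens when predicate is None, here is a small sketch with a mix of values I picked for illustration. filterfalse keeps the falsy items, while the built-in filter does the opposite.

```python
from itertools import filterfalse

values = [0, 1, "", "text", None, [], [3], False]

# With predicate=None, filterfalse keeps the items that are falsy themselves
falsy = list(filterfalse(None, values))

# The built-in filter keeps the truthy items instead
truthy = list(filter(None, values))

print(falsy)   # [0, '', None, [], False]
print(truthy)  # [1, 'text', [3]]
```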

itertools.chain

chain(*iterables)

Creates an iterator that returns elements from the first iterable until it is exhausted, then elements from the next iterable, and so on. This is useful for iterating over multiple iterables as if they were one, without creating a new list.

Let's see an example where we combine multiple shopping lists into a cart.

from collections import namedtuple
from itertools import chain
 
Item = namedtuple('Item', ['name', 'price'])
 
produce = [
    Item('apples', 152.99),
    Item('bananas', 121.59),
    Item('cherries', 64.99)
]
dairy = [Item('milk', 73.49), Item('cheese', 555.99)]
bakery = [Item('bread', 92.49), Item('mandazi pack', 55.99)]
 
cart = chain(produce, dairy, bakery)
 
total_cost = 0
total_items = 0
 
print("Cart:")
print("-" * 30)
for item in cart:
    print(f"{item.name:<15} KES {item.price:,.2f}")
    total_cost += item.price
    total_items += 1
 
print("-" * 30)
print(f"Total:          KES {total_cost:,.2f}")
 
print(f"\nNumber of items: {total_items}")

You can run this example on replit. You should see the output below:

Cart:
------------------------------
apples          KES 152.99
bananas         KES 121.59
cherries        KES 64.99
milk            KES 73.49
cheese          KES 555.99
bread           KES 92.49
mandazi pack    KES 55.99
------------------------------
Total:          KES 1,117.53

Number of items: 7

itertools.batched

batched(iterable, n)

Introduced in Python 3.12, batched splits the iterable into "batches" (tuples) of size n without loading the entire iterable into memory. The last batch may be shorter than n if the items don't divide evenly.

Imagine you have a service where people upload image files and you want to compress them in batches of size n.

import itertools
import io
 
 
def compress_batch(batch):
    for image in batch:
        # image compression logic
        print('compressing image')
 
    print(f'compressed {len(batch)} images')
 
 
def compress(images, batch_size=100):
    for batch in itertools.batched(images, batch_size):
        compress_batch(batch)
 
 
images = (io.BytesIO(b'\x00' * 1024) for _ in range(100000))  # using a generator expression to make 100,000 fake files lazily
compress(images)

replit does not explicitly support Python 3.12 so you will have to try this example locally.

batched can also be used to batch requests to external APIs that implement rate limiting.
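If you are on a Python version older than 3.12, here is a minimal sketch of an equivalent built from islice, modeled on the batched recipe from the itertools documentation. batched_fallback is a hypothetical name, not part of itertools. It needs Python 3.8+ for the := operator.

```python
from itertools import islice


def batched_fallback(iterable, n):
    # Hypothetical helper approximating itertools.batched for Python < 3.12
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # islice pulls at most n items per loop; an empty tuple means the input is exhausted
    while batch := tuple(islice(it, n)):
        yield batch


print(list(batched_fallback("ABCDEFG", 3)))
# [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
```

Calling iter() once up front matters: each islice call then continues where the previous batch stopped instead of starting over.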

itertools.chain.from_iterable

chain.from_iterable(iterable)

This is useful for flattening nested sequences by one level. The iterable here can be, for example, a list of lists or a list of tuples. Unlike chain(*iterable), it reads the outer iterable lazily.

from itertools import chain
 
nested = [[1, 2, 3], [4, 5, 6]]
flat = list(chain.from_iterable(nested))
 
print(f'nested: {nested}')
print(f'flat: {flat}')

You can run this code here. The output will be as below:

nested: [[1, 2, 3], [4, 5, 6]]
flat: [1, 2, 3, 4, 5, 6]
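Two details are worth seeing in action: only one level of nesting is removed, and the outer iterable can itself be lazy. The data below is my own illustration.

```python
from itertools import chain

# Only one level is flattened; the inner [3, 4] stays a list
nested = [[1, 2], [[3, 4], 5]]
one_level = list(chain.from_iterable(nested))

# The outer iterable is consumed lazily, so a generator of rows works too
rows = (range(i) for i in range(1, 4))
from_rows = list(chain.from_iterable(rows))

print(one_level)  # [1, 2, [3, 4], 5]
print(from_rows)  # [0, 0, 1, 0, 1, 2]
```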

itertools.compress

compress(data, selectors)

Makes an iterator that returns the items in data whose corresponding element in selectors is True or truthy.

compress stops as soon as either data or selectors is exhausted. So if selectors is shorter than data, the remaining items in data are ignored.

Let's say we have two lists: voter names and their registration status. We can iterate over them and print the names of registered voters as follows:

from itertools import compress
 
voters = {
    'names': ['Jane', 'Doe', 'Muthoni'],
    'registered': [True, False, True]
}

for voter in compress(voters['names'], voters['registered']):
    print(voter)

You can test this here. The expected output is as follows:

Jane
Muthoni
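Selectors don't have to be booleans, and they can be shorter than data. In this sketch, 'Otieno' is a hypothetical extra voter with no matching selector, so compress never considers him.

```python
from itertools import compress

names = ['Jane', 'Doe', 'Muthoni', 'Otieno']
registered = [1, 0, 1]  # any truthy/falsy values work; one selector short

# compress stops when selectors run out, so 'Otieno' is dropped
selected = list(compress(names, registered))

print(selected)  # ['Jane', 'Muthoni']
```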

itertools.dropwhile

dropwhile(predicate, iterable)

Skips items at the beginning of an iterable as long as the predicate function returns True for each one. Once it encounters an element for which the predicate returns False, it yields that item and all subsequent items, regardless of what the predicate returns for them.

Imagine a mixed-weather Formula 1 race: it starts out raining, but the rain stops at some point and the rest of the race is dry. The drivers pit for slicks (dry-weather tires) and their lap times start improving significantly. Let's assume that a dry lap time is expected to be 105 seconds (1 min 45 s) or less. We can use dropwhile to skip lap times above 105 seconds, effectively finding the start of the dry-weather laps. Once a lap time of 105 seconds or less is found, that lap and all subsequent laps are included, even if a driver makes a mistake afterwards and posts a lap time above 105 seconds.

The code below uses dropwhile in the detect_transition function to find the transition point and keep only the dry laps. As examples, I used my top three favorite drivers on the 2024 grid 😉.

from itertools import dropwhile
 
drivers = {
    "Hamilton": [122, 119, 114, 108, 103, 99, 97, 96, 95, 94],
    "Leclerc": [120, 118, 115, 110, 105, 100, 98, 97, 96, 95],
    "Piastri": [121, 117, 113, 109, 106, 101, 99, 98, 97, 96]
}
 
def seconds_to_min_sec(seconds):
    minutes = int(seconds // 60)
    remaining_seconds = round(seconds % 60, 3)
    return f"{minutes}:{remaining_seconds:06.3f}"
 
def detect_transition(lap_times, threshold):
    dry_laps = list(dropwhile(lambda x: x > threshold, lap_times))
    transition_lap = len(lap_times) - len(dry_laps)
    return transition_lap, dry_laps
 
def analyze_drivers(drivers_data, threshold):
    results = {}
    for driver, laps in drivers_data.items():
        transition_lap, dry_laps = detect_transition(laps, threshold)
        results[driver] = {
            "transition_lap": transition_lap,
            "dry_laps": dry_laps,
            "avg_dry_time": sum(dry_laps) / len(dry_laps) if dry_laps else None
        }
    return results
 
analysis = analyze_drivers(drivers, 105)
for driver, data in analysis.items():
    print(f"{driver}: Transition at lap {data['transition_lap']}, Avg dry time: {seconds_to_min_sec(data['avg_dry_time'])}")

Here is the associated repl for this example. Fork it, run it and you should see the output below:

Hamilton: Transition at lap 4, Avg dry time: 1:37.333
Leclerc: Transition at lap 4, Avg dry time: 1:38.500
Piastri: Transition at lap 5, Avg dry time: 1:38.200
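The difference between dropwhile and the built-in filter is easiest to see side by side. In this sketch (lap times made up for illustration), 108 comes after a dry lap, so dropwhile keeps it, while filter tests every item and removes it.

```python
from itertools import dropwhile

laps = [110, 104, 108, 100]

# dropwhile only skips the leading run of slow laps
dropped = list(dropwhile(lambda t: t > 105, laps))

# filter checks every lap independently
filtered = list(filter(lambda t: t <= 105, laps))

print(dropped)   # [104, 108, 100]
print(filtered)  # [104, 100]
```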

itertools.groupby

groupby(iterable, key=None)

Accepts an iterable and a key function and makes an iterator that groups consecutive elements based on the key function. It's like creating subgroups within the iterable, where each subgroup contains elements that are considered equal according to the key function. If a key function is not provided, consecutive identical elements are grouped together. If the same key appears again later but not consecutively, it starts a new group. This is why you usually sort the data by the same key first if you want all equal elements grouped together. Also note that each group shares the underlying iterator with groupby, so once groupby moves on to the next group, the previous one is no longer available; convert a group to a list if you need it later.

Scenario 1: Grouping when key is None

from itertools import groupby
 
data = "AAAABBBCCAADDDD"
 
for key, group in groupby(data):
    group = list(group)
    print(f"Key: {key}, Group: {''.join(group)}, Count: {len(group)}")

See this repl. This code creates 5 groups as seen in the output below:

Key: A, Group: AAAA, Count: 4
Key: B, Group: BBB, Count: 3
Key: C, Group: CC, Count: 2
Key: A, Group: AA, Count: 2
Key: D, Group: DDDD, Count: 4

Notice how we have two groups of A. However, if we sort the data first, similar letters will be next to each other and thus grouped together.

for key, group in groupby(sorted(data)):
    group = list(group)
    print(f"Key: {key}, Group: {''.join(group)}, Count: {len(group)}")

Let's run the code again. We can now see that all identical characters are grouped together. Test it out on this repl.

Key: A, Group: AAAAAA, Count: 6
Key: B, Group: BBB, Count: 3
Key: C, Group: CC, Count: 2
Key: D, Group: DDDD, Count: 4

Scenario 2: Grouping using a key function

Let's say you have a list of cars and would like to group them by the make.

from itertools import groupby
from operator import itemgetter
 
cars = [
    (2022, "Mercedes", "E53"),
    (2021, "Mercedes", "GLE 53"),
    (2023, "Audi", "RS6"),
    (2023, "Audi", "A5"),
    (2023, "Audi", "Q5"),
    (2018, "Audi", "SQ5"),
    (2022, "Porsche", "911"),
    (2021, "Porsche", "Cayenne"),
    (2024, "Porsche", "Cayenne S"),
    (2024, "Porsche", "Cayenne GTS"),
    (2023, "Volkswagen", "Golf R"),
    (2023, "Volkswagen", "Tiguan R")
]
 
make_key = itemgetter(1) # same as lambda x: x[1]
 
cars.sort(key=make_key) # rearrange the list to put similar items next to each other so they can be in the same group
 
grouped_by_make = groupby(cars, key=make_key)
 
for make, make_group in grouped_by_make:
    print(f"\n{make} cars:")
    for car in make_group:
        print(f"- {car}")

Running the code above, we see that our cars have indeed been grouped by make. You can confirm by forking and running the code in this repl.

Audi cars:
- (2023, 'Audi', 'RS6')
- (2023, 'Audi', 'A5')
- (2023, 'Audi', 'Q5')
- (2018, 'Audi', 'SQ5')

Mercedes cars:
- (2022, 'Mercedes', 'E53')
- (2021, 'Mercedes', 'GLE 53')

Porsche cars:
- (2022, 'Porsche', '911')
- (2021, 'Porsche', 'Cayenne')
- (2024, 'Porsche', 'Cayenne S')
- (2024, 'Porsche', 'Cayenne GTS')

Volkswagen cars:
- (2023, 'Volkswagen', 'Golf R')
- (2023, 'Volkswagen', 'Tiguan R')

If you would like to group by year, just change the sort key and groupby key from itemgetter(1) to itemgetter(0).
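A common pattern is collecting groupby results into a dictionary. This sketch uses a made-up word list and a hypothetical first_letter key function; it sorts by the same key first and turns each group into a list before groupby advances.

```python
from itertools import groupby

words = ["banana", "apple", "cherry", "avocado", "blueberry"]


def first_letter(word):
    return word[0]


# Sort by the same key so equal keys are adjacent, then materialize
# each group with list() before groupby moves to the next key
by_letter = {
    letter: list(group)
    for letter, group in groupby(sorted(words, key=first_letter), key=first_letter)
}

print(by_letter)
# {'a': ['apple', 'avocado'], 'b': ['banana', 'blueberry'], 'c': ['cherry']}
```

Because sorted is stable, words with the same first letter keep their original relative order.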

itertools.pairwise

pairwise(iterable)

Makes an iterator that returns successive overlapping pairs from an iterable. Can be useful for time series data analysis.

Consider a scenario where you have the stock price for a hypothetical company over a 5-day period and you would like to analyze the day-to-day price movements.

from itertools import pairwise
 
stock_prices = [
    ("Monday", 150.25),
    ("Tuesday", 152.50),
    ("Wednesday", 151.75),
    ("Thursday", 153.00),
    ("Friday", 155.50)
]
 
print("\nDay-to-day stock price changes:\n")
for (day1, price1), (day2, price2) in pairwise(stock_prices):
    change = price2 - price1
    percentage_change = (change / price1) * 100
    direction = "\033[32m\u25B2\033[0m" if change > 0 else "\033[31m\u25BC\033[0m" if change < 0 else "no change"

    print(f"{day1} to {day2}: {direction} {change:.2f} ({percentage_change:.2f}%)")

You can try out this example by forking this repl. Run the code and you should see the output below:

Day-to-day stock price changes:

Monday to Tuesday: ▲ 2.25 (1.50%)
Tuesday to Wednesday: ▼ -0.75 (-0.49%)
Wednesday to Thursday: ▲ 1.25 (0.82%)
Thursday to Friday: ▲ 2.50 (1.63%)

Note that pairwise is only available in Python 3.10 and later.
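On older versions, a well-known tee-based sketch produces the same pairs. pairwise_fallback is a hypothetical name; it is not part of itertools.

```python
from itertools import tee


def pairwise_fallback(iterable):
    # Hypothetical helper approximating itertools.pairwise for Python < 3.10
    a, b = tee(iterable)
    next(b, None)  # advance the second copy by one item (safe on empty input)
    return zip(a, b)


print(list(pairwise_fallback("ABCD")))  # [('A', 'B'), ('B', 'C'), ('C', 'D')]
```

zip stops when the advanced copy runs out, which is why n items give n - 1 pairs.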

itertools.starmap

starmap(function, iterable)

Makes an iterator that calls function using arguments obtained from the iterable. starmap unpacks each item of the iterable and passes the pieces to function as separate positional arguments. So if you provide an iterable like [("Alice", 28, "New York")], function will be called with the arguments "Alice", 28, and "New York".

Imagine you run a newsletter service and keep your subscribers as a list of tuples, each containing a subscriber's name and email address. We can generate an email for each subscriber like so:

import textwrap
from itertools import starmap

subscribers = [
    ("Alice", "alice@example.com"),
    ("Bob", "bob@example.com"),
    ("Charlie", "charlie68@example.com"),
    ("Diana", "dee@example.com")
]


def email_greeting(name: str, email: str) -> str:
    # dedent strips the common leading indentation, so the email prints flush left
    resp = textwrap.dedent(f"""\
        mailto:{email}

        Dear {name},

        You might be thinking, "Oh great, another automated newsletter." Here's a mind-blowing fact: I, personally hand-write every single one of these emails.

        So, the next time you're tempted to fire off a reply asking if I'm a real person, remember this: I'm so real, I stub my toe on furniture and cry when I chop onions.

        Bye for now, be on the lookout for my next email.

        Partially yours,
        Alex
    """)
    return resp
 
emails = starmap(email_greeting, subscribers)
print(next(emails))

We call next to get the next item from emails. You can also loop over emails using a for loop.

You can view, fork and run this example here.

Running the code above will produce the output below:

mailto:alice@example.com

Dear Alice,

You might be thinking, "Oh great, another automated newsletter." Here's a mind-blowing fact: I, personally hand-write every single one of these emails.

So, the next time you're tempted to fire off a reply asking if I'm a real person, remember this: I'm so real, I stub my toe on furniture and cry when I chop onions.

Bye for now, be on the lookout for my next email.

Partially yours,
Alex

You might notice that starmap's signature is similar to that of the built-in map function. Both accept a function as the first argument, followed by an iterable. However, there are two major differences:

  • map accepts one or more iterables
  • map does not unpack arguments whereas starmap unpacks arguments.
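Here is a minimal side-by-side sketch with made-up numbers: map takes the arguments from parallel iterables, while starmap takes them pre-packed as tuples.

```python
from itertools import starmap

bases = [2, 3, 10]
exponents = [5, 2, 3]

# map pulls one argument from each iterable per call
with_map = list(map(pow, bases, exponents))

# starmap unpacks each (base, exponent) tuple into pow(base, exponent)
with_starmap = list(starmap(pow, zip(bases, exponents)))

print(with_map)      # [32, 9, 1000]
print(with_starmap)  # [32, 9, 1000]
```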

itertools.takewhile

takewhile(predicate, iterable)

Accepts a predicate function and an iterable and makes an iterator that yields elements from the iterable as long as the predicate function returns True. It stops as soon as the predicate function returns False for an element.

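I picked some logs from the ruff linter I use in my editor. The example below uses takewhile to print the logs from before 20:51:49 on 30 August 2024. Because the logs are in chronological order, takewhile can stop at the first entry past the cutoff.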

from datetime import datetime
from itertools import takewhile
 
logs = [
    "2024-08-30 20:51:46 [info] Name: Ruff",
    "2024-08-30 20:51:46 [info] Module: ruff",
    "2024-08-30 20:51:46 [info] Using interpreter:",
    "2024-08-30 20:51:48 [error] Unable to find any Python environment for the interpreter path:",
    "2024-08-30 20:53:52 [info] Using interpreter:",
    "2024-08-30 20:53:52 [error] Unable to find any Python environment for the interpreter path:",
]
 
def process_cutoff(log: str) -> bool:
    cutoff_time = datetime(2024, 8, 30, 20, 51, 49)
    timestamp = log[:19]
    return datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S") < cutoff_time
 
relevant_logs = takewhile(process_cutoff, logs)
 
for log in relevant_logs:
    print(log)

Running the code above will lead to the output below. You can run this example here.

2024-08-30 20:51:46 [info] Name: Ruff
2024-08-30 20:51:46 [info] Module: ruff
2024-08-30 20:51:46 [info] Using interpreter:
2024-08-30 20:51:48 [error] Unable to find any Python environment for the interpreter path:
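That ordering assumption matters. This sketch, with made-up readings, shows that takewhile stops for good at the first failing item, unlike filter, which checks everything.

```python
from itertools import takewhile

readings = [1, 4, 6, 3, 2]

# takewhile stops at 6, even though 3 and 2 would pass the test
taken = list(takewhile(lambda x: x < 5, readings))

# filter keeps every item that passes
kept = list(filter(lambda x: x < 5, readings))

print(taken)  # [1, 4]
print(kept)   # [1, 4, 3, 2]
```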

itertools.tee

tee(iterable, n=2)

Accepts an iterable and creates n independent iterators from it. tee lets you "split" an iterable into several "copies" that can be consumed independently. Once tee has been called, the original iterable should not be used anywhere else; otherwise it could be advanced without the tee iterators being informed. Also note that tee iterators are not thread-safe, and that tee may use significant memory, because it has to store every item that one copy has consumed but another has not yet.

Here is an example where we split an iterable and use one iterator for computing squares and another for computing logarithms to base 10. The .2f in the f-string rounds the logs to 2 decimal places.

from itertools import tee
from math import log10
 
original = list(range(1,11))
 
original1, original2 = tee(original)
 
squares = [num * num for num in original1]
 
logs = [f'{log10(num):.2f}' for num in original2]
 
print(squares)
print(logs)

You can test this example by forking this repl here.

The expected output is as follows:

[1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
['0.00', '0.30', '0.48', '0.60', '0.70', '0.78', '0.85', '0.90', '0.95', '1.00']
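The example above uses a list, which can already be iterated as many times as you like. tee matters most for one-shot iterators such as generators. A small sketch with my own data:

```python
from itertools import tee

# A generator can normally be consumed only once
squares_gen = (n * n for n in range(5))

first, second = tee(squares_gen)

total = sum(first)     # consumes the first copy completely
largest = max(second)  # the second copy still sees every item

print(total, largest)  # 30 16
```

Because first is fully consumed before second starts, tee ends up buffering every item here; for this exact pattern, calling list() once would be just as good.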

itertools.zip_longest

zip_longest(*iterables, fillvalue=None)

Makes an iterator that combines elements from multiple iterables. However, unlike the built-in zip function, zip_longest doesn't stop when the shortest iterable is exhausted; it keeps going until the longest one is.

fillvalue is used to "fill" in missing values once a shorter iterable runs out of items. Think of it as a default value. If fillvalue is not provided, it defaults to None.

Let's have a look at an example that uses zip_longest to combine incomplete sales data from different branches of a supermarket.

from itertools import zip_longest
 
branch_a = [100, 150, 200, 180]
branch_b = [120, 110, 140]
branch_c = [90, 95, 85, 100, 110]
 
combined_sales = list(zip_longest(branch_a, branch_b, branch_c, fillvalue=0))
 
# Print combined sales report
print("Day | Branch A | Branch B | Branch C")
print("-" * 35)
for day, sales in enumerate(combined_sales, start=1):
    a, b, c = sales
    print(f"{day:3} | {a:7} | {b:7} | {c:7}")
 
totals = [sum(branch) for branch in zip(*combined_sales)]
print("-" * 35)
print(f"Total | {totals[0]:7} | {totals[1]:7} | {totals[2]:7}")

You can test this example by forking and running this repl. The expected output will be:

Day | Branch A | Branch B | Branch C
-----------------------------------
  1 |     100 |     120 |      90
  2 |     150 |     110 |      95
  3 |     200 |     140 |      85
  4 |     180 |       0 |     100
  5 |       0 |       0 |     110
-----------------------------------
Total |     630 |     370 |     480
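For a compact comparison with the built-in zip, here is a sketch using two short lists I made up:

```python
from itertools import zip_longest

a = [1, 2, 3]
b = ["x"]

shortest = list(zip(a, b))                     # stops when b runs out
longest = list(zip_longest(a, b))              # pads with None by default
filled = list(zip_longest(a, b, fillvalue="-"))  # pads with a custom value

print(shortest)  # [(1, 'x')]
print(longest)   # [(1, 'x'), (2, None), (3, None)]
print(filled)    # [(1, 'x'), (2, '-'), (3, '-')]
```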

itertools.product

product(*iterables, repeat=1)

Makes an iterator that returns the Cartesian product of the input iterables. Imagine you have two or more lists, and you want to combine every item from the first list with every item from the second list (and so on if you have more lists).

For example, product([1, 2], [3, 4, 5]), will create an iterator that yields the items (1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5).

Consider the code below, which generates all possible combinations of T-shirt sizes and colors.

from itertools import product
 
sizes = ['S', 'M', 'L', 'XL']
colors = ['Black', 'White']
 
tshirt_combinations = product(sizes, colors)
 
print("Available T-shirt combinations:")
for size, color in tshirt_combinations:
    print(f"{color} {size}")

You can fork this repl and run the example to get the output below:

Available T-shirt combinations:
Black S
White S
Black M
White M
Black L
White L
Black XL
White XL
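The repeat argument in the signature is shorthand for taking the product of an iterable with itself. A small sketch, with bits as my own example data; it also checks that the number of results is the product of the input lengths.

```python
from itertools import product

# repeat=2 means product(bits, bits)
bits = [0, 1]
bit_pairs = list(product(bits, repeat=2))

# 4 sizes x 2 colors gives 8 combinations
n_tshirts = len(list(product(['S', 'M', 'L', 'XL'], ['Black', 'White'])))

print(bit_pairs)  # [(0, 0), (0, 1), (1, 0), (1, 1)]
print(n_tshirts)  # 8
```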

itertools.combinations

combinations(iterable, r)

Makes an iterator that generates all possible combinations of length r from an iterable. Order doesn't matter, so ('A', 'B') and ('B', 'A') count as the same combination, and each element can appear at most once in a combination.

itertools.combinations_with_replacement

combinations_with_replacement(iterable, r)

Makes an iterator that yields combinations of length r from the iterable, but the same element can appear more than once in a combination. Order doesn't matter.

itertools.permutations

permutations(iterable, r=None)

Makes an iterator that yields all possible arrangements of length r from an iterable. Order matters, and each element appears at most once in an arrangement. If r is not provided, it defaults to the length of the iterable.

Let's look at an example that compares combinations, combinations_with_replacement and permutations.

from itertools import permutations, combinations, combinations_with_replacement
 
items = ['A', 'B', 'C']
r = 2
 
perms = list(permutations(items, r))
combis = list(combinations(items, r))
combis_w_replacement = list(combinations_with_replacement(items, r))
 
print(f'permutations: {perms}')
print(f'combinations: {combis}')
print(f'combinations_w_replacement: {combis_w_replacement}')

We should get the output below:

permutations: [('A', 'B'), ('A', 'C'), ('B', 'A'), ('B', 'C'), ('C', 'A'), ('C', 'B')]
combinations: [('A', 'B'), ('A', 'C'), ('B', 'C')]
combinations_w_replacement: [('A', 'A'), ('A', 'B'), ('A', 'C'), ('B', 'B'), ('B', 'C'), ('C', 'C')]

Apologies, I ran out of repls. You'll have to run this one locally.
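As a sanity check on the output above, the number of results matches the standard counting formulas, which the math module (Python 3.8+) can compute directly. This is a sketch of that check, not something the original example does.

```python
import math
from itertools import combinations, combinations_with_replacement, permutations

items = ['A', 'B', 'C']
n, r = len(items), 2

# ordered arrangements: n! / (n - r)!
assert len(list(permutations(items, r))) == math.perm(n, r) == 6
# unordered selections: n! / (r! (n - r)!)
assert len(list(combinations(items, r))) == math.comb(n, r) == 3
# unordered selections with repetition: C(n + r - 1, r)
assert len(list(combinations_with_replacement(items, r))) == math.comb(n + r - 1, r) == 6

print("counts match")
```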

Conclusion

itertools is an amazing module. I hope you have learnt something.

P.S. If you find any typos, just pretend they're artisanal imperfections. Thanks for reading. You can share feedback with me at kiuraalex@gmail.com