How to (properly) chunkify a list in Python
Let's start with the obvious solution:
def chunkify(iterable, chunk_size):
    chunk = []
    while True:
        try:
            item = next(iterable)
        except StopIteration:
            break
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
            chunk.clear()
    if chunk:
        yield chunk
for chunk in chunkify(iter(range(10)), 3):
    print(chunk)
# <<< [0, 1, 2]
# ... [3, 4, 5]
# ... [6, 7, 8]
# ... [9]
Detour
I purposefully included an error because I thought it would be interesting. Can you spot it?
How about now?
a, b, c, d = chunkify(iter(range(10)), 3)
a, b, c, d
# <<< ([9], [9], [9], [9])
a is b, a is c, a is d
# <<< (True, True, True)
So, the issue is that the chunk we are yielding is always the same list object. When we used .clear(), we didn't create a new list but mutated the existing one, so if you don't consume each chunk before fetching the next one, its contents will have been replaced.
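To see the aliasing in isolation, here is a minimal sketch (nothing beyond built-ins is assumed):

chunk = [1, 2, 3]
alias = chunk  # both names refer to the same list object
chunk.clear()  # mutates that one object in place
alias
# <<< []

chunk = [1, 2, 3]
alias = chunk
chunk = []  # rebinds the name; the original object is untouched
alias
# <<< [1, 2, 3]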
We can fix this with:
def chunkify(iterable, chunk_size):
    chunk = []
    while True:
        try:
            item = next(iterable)
        except StopIteration:
            break
        chunk.append(item)
        if len(chunk) == chunk_size:
            yield chunk
-           chunk.clear()
+           chunk = []
    if chunk:
        yield chunk
a, b, c, d = chunkify(iter(range(10)), 3)
a, b, c, d
# <<< ([0, 1, 2], [3, 4, 5], [6, 7, 8], [9])
Detour over
So, what's wrong with this implementation? Nothing really, it will work fine. Can we improve it, though?
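As a quick aside before we try: if you are on Python 3.12 or later, the standard library already ships an eager chunker, itertools.batched, which yields tuples:

from itertools import batched  # Python 3.12+

for chunk in batched(range(10), 3):
    print(chunk)
# <<< (0, 1, 2)
# ... (3, 4, 5)
# ... (6, 7, 8)
# ... (9,)

Like our version so far, it materializes each chunk in memory, which brings us to the scenario below.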
The hypothetical scenario is that:
- The iterable we use for input does not require any RAM. Let's assume, for example, that it is data streamed from a file or network resource.
- The chunk size is big.
With the implementation we have so far, we have to store each chunk in RAM, and maybe this is something we can avoid.
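To make this concrete, here is a sketch of such an input (the file name and format are hypothetical):

def stream_numbers(path):
    # Lazily yield one integer per line; only one line is
    # ever held in memory at a time.
    with open(path) as f:
        for line in f:
            yield int(line)

# With the list-based chunkify, a chunk_size of 1_000_000 means
# a million-item list must sit in RAM before we can touch it:
# chunkify(stream_numbers("numbers.txt"), 1_000_000)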
To get a feel for what we are going after, let's look at how itertools.groupby works:
from itertools import groupby

for div, group in groupby(range(10), lambda i: i // 3):
    print(div, list(group))
# <<< 0 [0, 1, 2]
# ... 1 [3, 4, 5]
# ... 2 [6, 7, 8]
# ... 3 [9]
So, how does this save on RAM? The answer is that each group in the iteration above is not a list, but a lazy iterator:
list(groupby(range(10), lambda i: i // 3))
# <<< [
# ...     (0, <itertools._grouper object at 0x7f35b3838070>),
# ...     (1, <itertools._grouper object at 0x7f35c0c5a920>),
# ...     (2, <itertools._grouper object at 0x7f35b37bc160>),
# ...     (3, <itertools._grouper object at 0x7f35b37bfe50>)
# ... ]
And, because each group is not a list, it is populated directly from the initial iterable as we iterate over it. This means that we must consume each group before we fetch the next one. Otherwise, by the time we fetch the next group, the previous one will have already been consumed, since the initial iterable has to advance past it to reach the new group.
groups = groupby(range(10), lambda i: i // 3)
_, group_1 = next(groups)
_, group_2 = next(groups)
list(group_1)
# <<< []
list(group_2)
# <<< [3, 4, 5]
This is the reason why you can’t (or rather shouldn’t) iterate over the same group twice:
for _, group in groupby(range(10), lambda i: i // 3):
    print(list(group))
    print(list(group))
    break
# <<< [0, 1, 2]
# ... []
Side note: if you did want to iterate over the group twice, it would be easy to do:
for _, group in groupby(range(10), lambda i: i // 3):
    group = list(group)
    print(group)
    print(group)
    break
# <<< [0, 1, 2]
# ... [0, 1, 2]
This is also the reason why the iterable fed into groupby must already be sorted. Every time the return value of the lambda function changes, a new group is created, regardless of whether the same return value has been encountered before:
for mod, group in groupby(range(10), lambda i: i % 2):
    print(mod, list(group))
# <<< 0 [0]
# ... 1 [1]
# ... 0 [2]
# ... 1 [3]
# ... 0 [4]
# ... 1 [5]
# ... 0 [6]
# ... 1 [7]
# ... 0 [8]
# ... 1 [9]
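If we sort the input by the same key first (giving up laziness, since sorted() materializes everything), the duplicate keys collapse into single groups:

for mod, group in groupby(sorted(range(10), key=lambda i: i % 2), lambda i: i % 2):
    print(mod, list(group))
# <<< 0 [0, 2, 4, 6, 8]
# ... 1 [1, 3, 5, 7, 9]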
So, now we have all the ingredients we need to make our zero-RAM chunkifier. In fact, since groupby works so well, we are going to take advantage of it:
from itertools import groupby

def chunkify(iterable, chunk_size):
    for div, enumerated_chunk in groupby(
        enumerate(iterable), lambda i: i[0] // chunk_size
    ):
        yield (item for index, item in enumerated_chunk)
In case you haven't encountered it before, enumerate does this:

list(enumerate("abcd"))
# <<< [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'd')]
enumerate is also "zero-RAM", i.e. it consumes its argument lazily.
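Putting the two together, the key function floor-divides each item's index by the chunk size, so indices 0-2 map to key 0, indices 3-5 to key 1, and so on:

[(index // 3, item) for index, item in enumerate("abcdefg")]
# <<< [(0, 'a'), (0, 'b'), (0, 'c'), (1, 'd'), (1, 'e'), (1, 'f'), (2, 'g')]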
And let's run it through our, by now usual, tests to see how it behaves:
for chunk in chunkify(range(10), 3):
    print(list(chunk))
# <<< [0, 1, 2]
# ... [3, 4, 5]
# ... [6, 7, 8]
# ... [9]
chunks = chunkify(range(10), 3)
a = next(chunks)
b = next(chunks)
list(a)
# <<< []
list(b)
# <<< [3, 4, 5]
for chunk in chunkify(range(10), 3):
    print(list(chunk))
    print(list(chunk))
    break
# <<< [0, 1, 2]
# ... []
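For completeness, here is a sketch of an equivalent zero-RAM chunkifier built on itertools.islice instead of groupby; it carries the same caveats (consume each chunk before fetching the next, and don't iterate a chunk twice):

from itertools import chain, islice

def chunkify(iterable, chunk_size):
    iterator = iter(iterable)
    while True:
        # Pull one item up front so we can detect exhaustion
        # without buffering a whole chunk.
        try:
            first = next(iterator)
        except StopIteration:
            return
        # Re-attach the peeked item in front of the next
        # chunk_size - 1 lazily-taken items.
        yield chain([first], islice(iterator, chunk_size - 1))

for chunk in chunkify(range(10), 3):
    print(list(chunk))
# <<< [0, 1, 2]
# ... [3, 4, 5]
# ... [6, 7, 8]
# ... [9]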