Monday, October 16, 2017

Thread-Safe Generators in Python

In this post, we will go over how to create thread-safe generators in Python. This post is heavily based on this excellent article.

Generators make life so much easier when coding in Python, but there is a catch: raw generators are not thread-safe. Consider the example below:
# implement a simple generator and show that it is not thread-safe
import multiprocessing.pool as mp

def simple_generator(n):
    result = 0
    while True:
        if result >= n:
            result = 0
        yield result
        result += 1

def task(gen):
    for _ in range(10):
        print next(gen)

if __name__ == '__main__':
    # single thread
    gen = simple_generator(100)
    for _ in range(10):
        task(gen)
    # multi-threads
    gen = simple_generator(100)
    pool = mp.ThreadPool(4)
    for _ in range(10):
        pool.apply_async(task, (gen,))
    pool.close()
    pool.join()

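We see that the generator does not produce correct output when multiple threads access it at the same time. In CPython, a call to next on a generator that another thread is still executing can raise ValueError: generator already executing, and because apply_async does not report an exception unless its result is retrieved, the affected tasks simply stop printing.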
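
To see the underlying error in isolation, the sketch below forces two overlapping next calls on the same generator instead of relying on thread timing. It is only an illustration, written for Python 3: blocking_generator, concurrent_next_error and the two events are hypothetical names that are not part of the original example.

```python
import threading

entered = threading.Event()   # set once the worker is inside the generator body
release = threading.Event()   # set by the main thread to let the generator continue


def blocking_generator():
    # Hypothetical generator that pauses mid-execution so that another
    # thread can try to enter it at the same time.
    while True:
        entered.set()
        release.wait()
        yield 1


def concurrent_next_error():
    """Return the message of the error raised by an overlapping next() call."""
    entered.clear()
    release.clear()
    gen = blocking_generator()
    worker = threading.Thread(target=next, args=(gen,))
    worker.start()
    entered.wait()            # the worker is now running inside the generator
    try:
        next(gen)             # second caller enters while the first is still inside
        message = None
    except ValueError as exc:
        message = str(exc)
    finally:
        release.set()         # let the worker finish its next() call
        worker.join()
    return message


if __name__ == '__main__':
    print(concurrent_next_error())  # generator already executing
```

In the thread pool example the same exception is raised inside a worker, so nothing is printed for it unless the AsyncResult returned by apply_async is inspected with get().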

One easy way to make it thread-safe is to wrap the generator in a class that uses a threading lock so that only one thread at a time can execute the generator's next method. This is shown below:
# implement a simple generator and add thread-safe support
import threading
import multiprocessing.pool as mp

class thread_safe_generator(object):
    def __init__(self, gen):
        self.gen = gen
        self.lock = threading.Lock()

    def next(self):
        with self.lock:
            return next(self.gen)

def simple_generator(n):
    result = 0
    while True:
        if result >= n:
            result = 0
        yield result
        result += 1

def task(gen):
    for _ in range(10):
        print next(gen)

if __name__ == '__main__':
    # single thread
    gen = thread_safe_generator(simple_generator(100))
    for _ in range(10):
        task(gen)
    # multi-threads
    gen = thread_safe_generator(simple_generator(100))
    pool = mp.ThreadPool(4)
    for _ in range(10):
        pool.apply_async(task, (gen,))
    pool.close()
    pool.join()
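
Printed output from several threads is hard to verify by eye, so one way to check the wrapper is to have each task return its values and compare the combined result with what a single thread would produce. The sketch below is a hedged illustration rather than part of the original listing: LockedGenerator, count_up, take and collect are hypothetical names, a finite generator stands in for simple_generator, and a __next__ method is added so that the check also runs on Python 3.

```python
import threading
from multiprocessing.pool import ThreadPool


class LockedGenerator(object):
    """Same idea as the wrapper above; __next__ is added so it also runs on Python 3."""

    def __init__(self, gen):
        self.gen = gen
        self.lock = threading.Lock()

    def __next__(self):
        with self.lock:
            return next(self.gen)

    next = __next__  # Python 2 spelling of the same method


def count_up(n):
    # Hypothetical finite generator: yields 0, 1, ..., n - 1.
    for i in range(n):
        yield i


def take(gen, count):
    # Each task pulls `count` values and returns them instead of printing.
    return [next(gen) for _ in range(count)]


def collect(num_tasks=10, per_task=10):
    gen = LockedGenerator(count_up(num_tasks * per_task))
    pool = ThreadPool(4)
    pending = [pool.apply_async(take, (gen, per_task)) for _ in range(num_tasks)]
    pool.close()
    pool.join()
    values = []
    for result in pending:
        values.extend(result.get())  # get() re-raises any exception from the task
    return sorted(values)


if __name__ == '__main__':
    # Every value should be handed out exactly once across all threads.
    print(collect() == list(range(100)))
```

Because every call to next goes through the lock, each of the 100 values is handed to exactly one task, whichever thread happens to run it.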

Note that the generator is now thread-safe, but calls to its next method are serialized by the lock rather than executed in parallel. You can also use a Python decorator to make this more convenient, although it does basically the same thing.
# implement a simple generator and add thread-safe support
import threading
import multiprocessing.pool as mp

class thread_safe_generator(object):
    def __init__(self, gen):
        self.gen = gen
        self.lock = threading.Lock()

    def next(self):
        with self.lock:
            return next(self.gen)

def thread_safe(f):
    def g(*a, **kw):
        return thread_safe_generator(f(*a, **kw))
    return g

@thread_safe
def simple_generator(n):
    result = 0
    while True:
        if result >= n:
            result = 0
        yield result
        result += 1

def task(gen):
    for _ in range(10):
        print next(gen)

if __name__ == '__main__':
    # single thread
    gen = simple_generator(100)
    for _ in range(10):
        task(gen)
    # multi-threads
    gen = simple_generator(100)
    pool = mp.ThreadPool(4)
    for _ in range(10):
        pool.apply_async(task, (gen,))
    pool.close()
    pool.join()
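
The listings above use Python 2 syntax. As a rough sketch of how the decorator version might look if it had to run on both Python 2 and Python 3, the variant below defines __next__ with a next alias, adds __iter__ so the wrapper can be used in a for loop, and uses functools.wraps to keep the decorated function's name. The class name ThreadSafeIterator is hypothetical, and these additions are assumptions layered on top of the original code, not the author's implementation.

```python
import functools
import threading


class ThreadSafeIterator(object):
    """Serializes next() calls on the wrapped iterator with a lock."""

    def __init__(self, it):
        self.it = it
        self.lock = threading.Lock()

    def __iter__(self):
        # Returning self makes the wrapper usable in for loops.
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

    next = __next__  # Python 2 looks up `next` instead of `__next__`


def thread_safe(f):
    @functools.wraps(f)  # keep the name and docstring of the generator function
    def wrapper(*args, **kwargs):
        return ThreadSafeIterator(f(*args, **kwargs))
    return wrapper


@thread_safe
def simple_generator(n):
    result = 0
    while True:
        if result >= n:
            result = 0
        yield result
        result += 1


if __name__ == '__main__':
    gen = simple_generator(3)
    print([next(gen) for _ in range(5)])  # [0, 1, 2, 0, 1]
```

The lock still covers only the call to next, so any expensive work on the returned values happens outside it and can overlap between threads.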

2 comments:

  1. This code runs in Python2, but not in Python3.
    In Python3, you get this error: TypeError: 'thread_safe_generator' object is not an iterator

    1. I fixed it. Just replace thread_safe_generator.next with thread_safe_generator.__next__. Python3 needs the double underscore.
