Generators make life so much easier when coding in Python, but there is a catch: raw generators are not thread-safe. Consider the example below:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Implement a simple generator and show that it is not thread-safe
import multiprocessing.pool as mp | |
def simple_generator(n):
    """Yield an endless counter that starts at 0 and wraps before reaching n.

    The sequence is 0, 1, ..., n - 1, 0, 1, ... and never terminates.  When
    ``n`` is 1 or less every yielded value is 0.  The generator does no
    locking of its own.
    """
    value = 0
    while True:
        yield value
        # Step forward, falling back to 0 once the next value would hit n.
        value = 0 if value + 1 >= n else value + 1
def task(gen):
    """Pull ten values from ``gen`` and print each on its own line.

    Args:
        gen: Any object accepted by the built-in ``next()``.
    """
    for _ in range(10):
        # Parenthesised single-argument print is valid on both Python 2
        # (print statement) and Python 3 (print function).
        print(next(gen))
if __name__ == '__main__':
    # single thread: ten sequential task() calls print 0..99 in order
    gen = simple_generator(100)
    for _ in range(10):
        task(gen)

    # multi-threads: ten tasks share ONE raw generator across four workers
    gen = simple_generator(100)
    pool = mp.ThreadPool(4)
    results = [pool.apply_async(task, (gen,)) for _ in range(10)]
    pool.close()
    pool.join()

    # apply_async keeps a worker's exception inside its AsyncResult; fetch
    # every result so a task that died (e.g. "generator already executing")
    # is reported instead of disappearing silently.
    for result in results:
        try:
            result.get()
        except ValueError as exc:
            print('task failed: %s' % exc)
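We see that the generator does not produce correct output when multiple threads access it at the same time: a thread that calls next() while the generator is still running in another thread can get "ValueError: generator already executing", which ends that task early, so some values are never printed.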
One easy way to make it thread-safe is to create a wrapper class that uses a threading lock so that only one thread can execute the generator's next method at any given time. This is shown below:
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Implement a simple generator and add thread-safe support
import threading | |
import multiprocessing.pool as mp | |
class thread_safe_generator(object):
    """Iterator wrapper that serialises ``next()`` calls with a lock.

    Only one thread at a time may advance the wrapped iterator; other
    callers block on the lock until it is released.  The wrapper supports
    both the Python 3 iterator protocol (``__next__``) and the Python 2
    spelling (``next``).

    Args:
        gen: The iterator (typically a generator object) to protect.
    """

    def __init__(self, gen):
        self.gen = gen
        self.lock = threading.Lock()

    def __iter__(self):
        """Return self so the wrapper can be used directly in ``for`` loops."""
        return self

    def __next__(self):
        """Return the next value of the wrapped iterator under the lock.

        Raises:
            StopIteration: Propagated from the wrapped iterator when it is
                exhausted.
        """
        with self.lock:
            return next(self.gen)

    # Python 2 looks up ``next`` instead of ``__next__``.
    next = __next__
def simple_generator(n):
    """Yield an endless counter that starts at 0 and wraps before reaching n.

    The sequence is 0, 1, ..., n - 1, 0, 1, ... and never terminates.  When
    ``n`` is 1 or less every yielded value is 0.  The generator does no
    locking of its own.
    """
    value = 0
    while True:
        yield value
        # Step forward, falling back to 0 once the next value would hit n.
        value = 0 if value + 1 >= n else value + 1
def task(gen):
    """Pull ten values from ``gen`` and print each on its own line.

    Args:
        gen: Any object accepted by the built-in ``next()``.
    """
    for _ in range(10):
        # Parenthesised single-argument print is valid on both Python 2
        # (print statement) and Python 3 (print function).
        print(next(gen))
if __name__ == '__main__':
    # single thread
    gen = thread_safe_generator(simple_generator(100))
    for _ in range(10):
        task(gen)

    # multi-threads: ten tasks share one lock-protected generator
    gen = thread_safe_generator(simple_generator(100))
    pool = mp.ThreadPool(4)
    results = [pool.apply_async(task, (gen,)) for _ in range(10)]
    pool.close()
    pool.join()

    # apply_async keeps a worker's exception inside its AsyncResult; get()
    # re-raises it here so a failed task cannot go unnoticed.
    for result in results:
        result.get()
Note that the generator is now thread-safe, but calls to its next method are serialized rather than executed in parallel. You can also use a Python decorator to make it look even easier, although it basically does the same thing.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Implement a simple generator and add thread-safe support via a decorator
import threading | |
import multiprocessing.pool as mp | |
class thread_safe_generator(object):
    """Iterator wrapper that serialises ``next()`` calls with a lock.

    Only one thread at a time may advance the wrapped iterator; other
    callers block on the lock until it is released.  The wrapper supports
    both the Python 3 iterator protocol (``__next__``) and the Python 2
    spelling (``next``).

    Args:
        gen: The iterator (typically a generator object) to protect.
    """

    def __init__(self, gen):
        self.gen = gen
        self.lock = threading.Lock()

    def __iter__(self):
        """Return self so the wrapper can be used directly in ``for`` loops."""
        return self

    def __next__(self):
        """Return the next value of the wrapped iterator under the lock.

        Raises:
            StopIteration: Propagated from the wrapped iterator when it is
                exhausted.
        """
        with self.lock:
            return next(self.gen)

    # Python 2 looks up ``next`` instead of ``__next__``.
    next = __next__
def thread_safe(f):
    """Decorate a generator function so its result is lock-protected.

    Args:
        f: A generator function (or any callable returning an iterator).

    Returns:
        A function accepting the same arguments as ``f`` that returns
        ``thread_safe_generator(f(*args, **kwargs))``.  The wrapper keeps
        ``f``'s ``__name__`` and ``__doc__``.
    """
    # Imported locally so this block does not depend on the module's
    # import section.
    import functools

    @functools.wraps(f)
    def g(*a, **kw):
        return thread_safe_generator(f(*a, **kw))
    return g
@thread_safe
def simple_generator(n):
    """Yield an endless counter that starts at 0 and wraps before reaching n.

    The sequence is 0, 1, ..., n - 1, 0, 1, ... and never terminates.  When
    ``n`` is 1 or less every yielded value is 0.  The ``thread_safe``
    decorator wraps the generator this function returns.
    """
    value = 0
    while True:
        yield value
        # Step forward, falling back to 0 once the next value would hit n.
        value = 0 if value + 1 >= n else value + 1
def task(gen):
    """Pull ten values from ``gen`` and print each on its own line.

    Args:
        gen: Any object accepted by the built-in ``next()``.
    """
    for _ in range(10):
        # Parenthesised single-argument print is valid on both Python 2
        # (print statement) and Python 3 (print function).
        print(next(gen))
if __name__ == '__main__':
    # single thread
    gen = simple_generator(100)
    for _ in range(10):
        task(gen)

    # multi-threads: ten tasks share the one object simple_generator returned
    gen = simple_generator(100)
    pool = mp.ThreadPool(4)
    results = [pool.apply_async(task, (gen,)) for _ in range(10)]
    pool.close()
    pool.join()

    # apply_async keeps a worker's exception inside its AsyncResult; get()
    # re-raises it here so a failed task cannot go unnoticed.
    for result in results:
        result.get()
This code runs in Python2, but not in Python3.
In Python3, you get this error: TypeError: 'thread_safe_generator' object is not an iterator
I fixed it. Just replace thread_safe_generator.next with thread_safe_generator.__next__. Python3 needs the double underscore.
Delete