When multiple threads or processes operate concurrently, especially in complex machine learning workflows, they often need to access shared resources or coordinate their actions. Without proper coordination, you risk race conditions, where the outcome depends unpredictably on the timing of operations, leading to corrupted data, incorrect model states, or inconsistent results. Python's `threading` and `multiprocessing` modules provide several synchronization primitives to manage concurrent access and ensure predictable behavior. Understanding these tools is essential for building reliable parallel ML systems.
We will focus on three fundamental primitives: Locks, Semaphores, and Events.
A Lock, often called a mutex (mutual exclusion), is the simplest synchronization primitive. It provides a mechanism to ensure that only one thread or process can execute a specific block of code, known as the critical section, at any given time.
Concept: A Lock has two primary states: locked and unlocked. It starts unlocked. The `acquire()` method attempts to obtain the lock. If the lock is unlocked, the calling thread/process acquires it and the state changes to locked. If the lock is already held by another thread/process, `acquire()` blocks until the lock is released using the `release()` method.
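Note that the `threading` version of `acquire()` also accepts `blocking` and `timeout` arguments, so a caller can avoid waiting indefinitely. A minimal sketch of both forms:

```python
import threading

lock = threading.Lock()

# Non-blocking attempt: returns immediately with True or False.
if lock.acquire(blocking=False):
    try:
        pass  # critical section
    finally:
        lock.release()
else:
    print("Lock busy; doing other work instead")

# Bounded wait: give up after 2 seconds if the lock is not released.
if lock.acquire(timeout=2.0):
    try:
        pass  # critical section
    finally:
        lock.release()
else:
    print("Timed out waiting for the lock")
```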
Use Cases in ML: protecting shared model parameters during asynchronous updates, aggregating metrics or results from parallel workers, and serializing writes to a shared log or checkpoint file.
Python Implementation:
Both `threading.Lock` and `multiprocessing.Lock` offer the same interface. Using the `with` statement is the preferred way to handle locks, as it automatically acquires the lock upon entering the block and releases it upon exiting, even if errors occur.
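For reference, the `with` block is shorthand for the explicit acquire/release pattern below; the `finally` clause guarantees the lock is released even if the critical section raises:

```python
import threading

lock = threading.Lock()

# Equivalent to: with lock: ...
lock.acquire()
try:
    pass  # critical section
finally:
    lock.release()  # always runs, even on an exception
```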
```python
import threading
import time
import random

# Example: Simulating shared resource update (e.g., aggregating results)
shared_results = {}
results_lock = threading.Lock()
num_workers = 5
num_items_per_worker = 10

def worker_task(worker_id):
    """Simulates a worker processing items and updating shared results."""
    local_sum = 0
    for i in range(num_items_per_worker):
        # Simulate work
        time.sleep(random.uniform(0.01, 0.05))
        local_sum += i

    # Critical Section: Update shared results
    print(f"Worker {worker_id} waiting to acquire lock...")
    with results_lock:
        print(f"Worker {worker_id} acquired lock.")
        # Simulate update logic that needs to be atomic
        current_total = shared_results.get('total_sum', 0)
        time.sleep(random.uniform(0.02, 0.06))  # Simulate time spent holding lock
        shared_results['total_sum'] = current_total + local_sum
        shared_results[f'worker_{worker_id}_sum'] = local_sum
        print(f"Worker {worker_id} updated total to {shared_results['total_sum']}; releasing lock.")
    print(f"Worker {worker_id} finished.")

threads = []
for i in range(num_workers):
    thread = threading.Thread(target=worker_task, args=(i,))
    threads.append(thread)
    thread.start()

for thread in threads:
    thread.join()

print("\nFinal Shared Results:")
print(shared_results)

# Calculate expected total sum
expected_sum = num_workers * sum(range(num_items_per_worker))
print(f"\nExpected Total Sum: {expected_sum}")
print(f"Actual Total Sum: {shared_results.get('total_sum', 'Not Found')}")

# Verify individual worker sums add up
calculated_total = sum(v for k, v in shared_results.items() if k.startswith('worker_'))
print(f"Sum of individual worker sums: {calculated_total}")
assert calculated_total == shared_results.get('total_sum'), "Mismatch in sums!"
```
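Because `multiprocessing.Lock` exposes the same interface, the identical pattern protects state shared across processes. A minimal sketch, using a `multiprocessing.Value` as the shared state (a plain dict is not shared between processes):

```python
import multiprocessing as mp

def increment(lock, counter):
    for _ in range(1000):
        with lock:              # same with-statement pattern as threading.Lock
            counter.value += 1  # read-modify-write must happen atomically

if __name__ == "__main__":
    lock = mp.Lock()
    counter = mp.Value("i", 0)  # shared integer visible to all processes
    procs = [mp.Process(target=increment, args=(lock, counter)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(counter.value)  # reliably 4000 with the lock; often less without it
```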
Reentrant Locks (`RLock`): A standard `Lock` cannot be acquired a second time by the thread/process that already holds it; attempting to do so deadlocks. A `threading.RLock` or `multiprocessing.RLock` (reentrant lock) can be acquired multiple times by the same thread/process. It keeps track of the acquisition count and is only fully released when `release()` has been called the same number of times as `acquire()`. This is useful in complex functions that call other functions which also need to acquire the same lock.
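A minimal sketch of the reentrancy this enables, using a hypothetical pair of functions where one calls the other while already holding the lock:

```python
import threading

state_lock = threading.RLock()

def validate_state():
    with state_lock:      # re-acquired by the same thread: fine with RLock,
        pass              # but a plain Lock would deadlock here

def update_state():
    with state_lock:      # first acquisition
        validate_state()  # helper re-acquires the lock it already holds

update_state()
print("No deadlock: RLock allowed reentrant acquisition")
```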
Deadlocks: A common issue with locks occurs when two or more threads/processes are blocked forever, each waiting for a lock held by the other. For example, Thread A holds Lock 1 and waits for Lock 2, while Thread B holds Lock 2 and waits for Lock 1. Careful design, such as always acquiring locks in a consistent global order, can prevent deadlocks.
Figure: A simple deadlock scenario where Thread A holds Lock 1 and needs Lock 2, while Thread B holds Lock 2 and needs Lock 1.
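A minimal sketch of that ordering discipline: if every thread acquires `lock_1` before `lock_2`, the circular wait shown in the figure cannot occur:

```python
import threading

lock_1 = threading.Lock()
lock_2 = threading.Lock()

def thread_a():
    with lock_1:      # both threads acquire in the same global order:
        with lock_2:  # lock_1 first, then lock_2
            pass      # work requiring both resources

def thread_b():
    with lock_1:      # same order as thread_a, so no circular wait
        with lock_2:
            pass

threads = [threading.Thread(target=t) for t in (thread_a, thread_b)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("Both threads completed without deadlock")
```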
A Semaphore is a more generalized synchronization primitive than a Lock. It maintains an internal counter which is decremented by each `acquire()` call and incremented by each `release()` call. If the counter is zero, `acquire()` blocks until another thread/process calls `release()`.
Concept: Semaphores are typically initialized with a value greater than zero. This value represents the number of threads/processes that can simultaneously access a limited resource or execute a particular section of code. A Semaphore initialized with 1 behaves exactly like a Lock.
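A minimal sketch of the counter semantics, using non-blocking `acquire()` to observe the counter hitting zero:

```python
import threading

sem = threading.Semaphore(value=2)

sem.acquire()                       # counter: 2 -> 1
sem.acquire()                       # counter: 1 -> 0
print(sem.acquire(blocking=False))  # False: counter is 0, would block
sem.release()                       # counter: 0 -> 1
print(sem.acquire(blocking=False))  # True: counter is above zero again
sem.release()
sem.release()                       # counter restored to its initial value
```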
Use Cases in ML: limiting the number of concurrent GPU-bound jobs, capping simultaneous downloads of dataset shards, or bounding connections to a database, feature store, or external API.
Python Implementation:
`threading.Semaphore` and `multiprocessing.Semaphore` take an optional `value` argument specifying the initial counter (the default is 1).
```python
import threading
import time
import random

# Example: Limiting concurrent downloads
max_concurrent_downloads = 3
download_semaphore = threading.Semaphore(value=max_concurrent_downloads)
num_files_to_download = 10

def download_file(file_id):
    """Simulates downloading a file, limited by the semaphore."""
    print(f"File {file_id}: Waiting to acquire semaphore...")
    with download_semaphore:  # Acquire semaphore, block if count is 0
        print(f"File {file_id}: Acquired semaphore. Starting download...")
        # Simulate download time
        download_duration = random.uniform(0.5, 1.5)
        time.sleep(download_duration)
        print(f"File {file_id}: Download finished in {download_duration:.2f}s. Releasing semaphore.")
    # Semaphore automatically released upon exiting 'with' block

threads = []
print(f"Attempting to download {num_files_to_download} files with max {max_concurrent_downloads} concurrent downloads.")
for i in range(num_files_to_download):
    thread = threading.Thread(target=download_file, args=(i,))
    threads.append(thread)
    thread.start()
    time.sleep(0.1)  # Stagger thread starts slightly

for thread in threads:
    thread.join()

print("\nAll file download tasks completed.")
```
In this example, even though we start 10 threads quickly, only a maximum of 3 will execute the "download" part (the critical section protected by the semaphore) concurrently.
Bounded Semaphores: `threading.BoundedSemaphore` is similar to `Semaphore`, but it raises a `ValueError` if `release()` is called more times than `acquire()`, which would otherwise push the counter beyond its initial value. This helps catch programming errors where resources are released incorrectly.
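A minimal sketch of the difference: the extra `release()` would pass silently on a plain `Semaphore` but raises on a `BoundedSemaphore`:

```python
import threading

sem = threading.BoundedSemaphore(value=2)

sem.acquire()
sem.release()       # balanced: counter returns to 2
try:
    sem.release()   # unbalanced: would push the counter to 3
except ValueError as err:
    print(f"Caught release bug: {err}")
```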
An Event object manages an internal flag that can be set (`set()`) or cleared (`clear()`). Threads/processes can wait for this flag to be set (`wait()`). Events are a simple yet effective way to coordinate actions between concurrent tasks.

Concept: An Event starts with its internal flag cleared (False). Calling `wait()` blocks until the flag becomes True (set by another thread/process calling `set()`). If `wait()` is called when the flag is already True, it returns immediately. `is_set()` checks the flag's status without blocking, and `clear()` resets the flag to False.
Use Cases in ML: signaling that a dataset or model checkpoint is ready for downstream stages, broadcasting a shutdown or early-stopping signal to worker threads, or gating pipeline stages on one another.
Python Implementation:
`threading.Event` and `multiprocessing.Event` provide the core methods `set()`, `clear()`, `wait(timeout=None)`, and `is_set()`.
```python
import threading
import time
import random

# Example: Coordinating data loading and processing
data_ready_event = threading.Event()
processing_finished_event = threading.Event()
shared_data = None

def data_loader():
    """Simulates loading data."""
    global shared_data
    print("Loader: Starting data load...")
    load_time = random.uniform(1, 3)
    time.sleep(load_time)
    shared_data = list(range(10))  # Simulate loaded data
    print(f"Loader: Data loaded in {load_time:.2f}s. Signaling data_ready.")
    data_ready_event.set()  # Signal that data is ready

def data_processor():
    """Waits for data and then processes it."""
    print("Processor: Waiting for data...")
    data_ready_event.wait()  # Block until data_ready_event is set
    print("Processor: Data received. Starting processing...")
    if shared_data is not None:
        # Simulate processing
        processed_sum = sum(item * 2 for item in shared_data)
        process_time = random.uniform(0.5, 1.5)
        time.sleep(process_time)
        print(f"Processor: Processing finished in {process_time:.2f}s. Result: {processed_sum}")
    else:
        print("Processor: Error! Shared data is None.")
    print("Processor: Signaling processing finished.")
    processing_finished_event.set()

loader_thread = threading.Thread(target=data_loader)
processor_thread = threading.Thread(target=data_processor)
loader_thread.start()
processor_thread.start()

# Main thread can wait for processing to complete
print("Main: Waiting for processing to finish...")
processing_finished_event.wait(timeout=10)  # Wait with a timeout
if processing_finished_event.is_set():
    print("Main: Processor has finished.")
else:
    print("Main: Timeout waiting for processor.")

loader_thread.join()
processor_thread.join()
print("Main: All threads joined. Workflow complete.")
```
Here, the `data_processor` thread waits efficiently using `data_ready_event.wait()` instead of polling, proceeding only once the `data_loader` signals completion via `data_ready_event.set()`. The main thread similarly waits for the processor using another event.
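One small refinement: `wait()` returns `True` if the flag was set and `False` if the timeout elapsed, so the `is_set()` check in the example can be folded into a single call. A self-contained sketch (the flag is never set here, so the timeout branch runs):

```python
import threading

processing_finished_event = threading.Event()

# Equivalent to the wait-then-is_set() pattern in the example above.
if processing_finished_event.wait(timeout=0.1):
    print("Main: Processor has finished.")
else:
    print("Main: Timeout waiting for processor.")
```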
When choosing between these primitives:

- Use a Lock (`threading.Lock`, `multiprocessing.Lock`) when you need to ensure only one thread/process can access a critical section or shared resource at a time (mutual exclusion).
- Use a Semaphore (`threading.Semaphore`, `multiprocessing.Semaphore`) when you need to limit the number of concurrent threads/processes accessing a resource or executing code, allowing up to N concurrent accesses (where N is the semaphore's initial value).
- Use an Event (`threading.Event`, `multiprocessing.Event`) when you need a simple mechanism for one or more threads/processes to wait for a signal or condition triggered by another thread/process.
- `threading` vs. `multiprocessing`: Remember that primitives from `threading` work for coordinating threads within the same process, while primitives from `multiprocessing` are needed to coordinate separate processes, as they handle the necessary inter-process communication mechanisms.

Synchronization primitives are fundamental tools for managing shared state and coordinating tasks in concurrent Python applications. Applying them correctly is essential for harnessing the power of parallelism in machine learning pipelines while maintaining correctness and avoiding subtle bugs like race conditions and deadlocks.