Skip to content

Commit d4d8fd8

Browse files
authored
Update async_dofn.py (#35924)
1 parent 3e59ea9 commit d4d8fd8

1 file changed

Lines changed: 13 additions & 6 deletions

File tree

sdks/python/apache_beam/transforms/async_dofn.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
3434
from apache_beam.transforms.userstate import TimerSpec
3535
from apache_beam.transforms.userstate import on_timer
36+
from apache_beam.utils.shared import Shared
3637
from apache_beam.utils.timestamp import Duration
3738
from apache_beam.utils.timestamp import Timestamp
3839

@@ -114,11 +115,17 @@ def __init__(
114115
self.timer_frequency_ = callback_frequency
115116
self.parallelism_ = parallelism
116117
self._next_time_to_fire = Timestamp.now() + Duration(seconds=5)
118+
self._shared_handle = Shared()
119+
120+
@staticmethod
121+
def initialize_pool(parallelism):
122+
return lambda: ThreadPoolExecutor(max_workers=parallelism)
117123

118124
@staticmethod
119125
def reset_state():
120126
for pool in AsyncWrapper._pool.values():
121-
pool.shutdown(wait=True, cancel_futures=True)
127+
pool.acquire(AsyncWrapper.initialize_pool(1)).shutdown(
128+
wait=True, cancel_futures=True)
122129
with AsyncWrapper._lock:
123130
AsyncWrapper._pool = {}
124131
AsyncWrapper._processing_elements = {}
@@ -129,8 +136,7 @@ def setup(self):
129136
self._sync_fn.setup()
130137
with AsyncWrapper._lock:
131138
if not self._uuid in AsyncWrapper._pool:
132-
AsyncWrapper._pool[self._uuid] = ThreadPoolExecutor(
133-
max_workers=self._parallelism)
139+
AsyncWrapper._pool[self._uuid] = Shared()
134140
AsyncWrapper._processing_elements[self._uuid] = {}
135141
AsyncWrapper._items_in_buffer[self._uuid] = 0
136142

@@ -202,9 +208,10 @@ def schedule_if_room(self, element, ignore_buffer=False, *args, **kwargs):
202208
logging.info('item %s already in processing elements', element)
203209
return True
204210
if self.accepting_items() or ignore_buffer:
205-
result = AsyncWrapper._pool[self._uuid].submit(
206-
lambda: self.sync_fn_process(element, *args, **kwargs),
207-
)
211+
result = AsyncWrapper._pool[self._uuid].acquire(
212+
AsyncWrapper.initialize_pool(self._parallelism)).submit(
213+
lambda: self.sync_fn_process(element, *args, **kwargs),
214+
)
208215
result.add_done_callback(self.decrement_items_in_buffer)
209216
AsyncWrapper._processing_elements[self._uuid][element] = result
210217
AsyncWrapper._items_in_buffer[self._uuid] += 1

0 commit comments

Comments
 (0)