diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py
index 68add6ea3c1a..92ddda428bee 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner.py
@@ -27,6 +27,7 @@
 import logging
 import typing
 
+import more_itertools
 from google.protobuf import wrappers_pb2
 
 import apache_beam as beam
@@ -56,6 +57,8 @@
 __all__ = ['BundleBasedDirectRunner', 'DirectRunner', 'SwitchingDirectRunner']
 
 _LOGGER = logging.getLogger(__name__)
+K = typing.TypeVar('K')
+V = typing.TypeVar('V')
 
 
 class SwitchingDirectRunner(PipelineRunner):
@@ -263,6 +266,29 @@ def expand(self, pcoll):
     return PCollection.from_(pcoll)
 
 
+@typehints.with_input_types(typing.Tuple[K, V])
+@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
+class _GroupIntoBatches(PTransform):
+  """
+  Non-timer based implementation of GroupIntoBatches.
+  """
+  def __init__(self, batch_size: int):
+    if batch_size <= 0:
+      raise ValueError("batch_size must be a positive integer.")
+    self.batch_size = batch_size
+
+  def expand(self, pcoll):
+    return (
+        pcoll | beam.GroupByKey() | "BatchGroupedValues" >> beam.FlatMap(
+            self._batch_elements, batch_size=self.batch_size)
+        | beam.Reshuffle())
+
+  @staticmethod
+  def _batch_elements(key_values: tuple, batch_size: int):
+    k, values = key_values
+    return ((k, batch) for batch in more_itertools.batched(values, batch_size))
+
+
 @typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
 @typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
 class _GroupAlsoByWindow(ParDo):
@@ -472,10 +498,25 @@ def get_replacement_transform_for_applied_ptransform(
           self, applied_ptransform):
         return _GroupByKey()
 
+    class GroupIntoBatchesOverride(PTransformOverride):
+      """A ``PTransformOverride`` for ``GroupIntoBatches``.
+
+      This replaces the Beam implementation as a primitive.
+      """
+      def matches(self, applied_ptransform):
+        # Imported here to avoid circular dependencies.
+        # pylint: disable=wrong-import-order, wrong-import-position
+        from apache_beam.transforms.util import GroupIntoBatches
+        return isinstance(applied_ptransform.transform, GroupIntoBatches)
+
+      def get_replacement_transform(self, ptransform):
+        return _GroupIntoBatches(ptransform.params.batch_size)
+
     overrides = [
         # This needs to be the first and the last override. Other overrides depend
         # on the GroupByKey implementation to be composed of _GroupByKeyOnly and
         # _GroupAlsoByWindow.
+        GroupIntoBatchesOverride(),
         GroupByKeyPTransformOverride(),
         SplittableParDoOverride(),
         ProcessKeyedElementsViaKeyedWorkItemsOverride(),
diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
index 008a1bd47215..c8d21db95e9b 100644
--- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py
+++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py
@@ -35,6 +35,7 @@
 from apache_beam.runners import DirectRunner
 from apache_beam.runners import TestDirectRunner
 from apache_beam.runners import create_runner
+from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
 from apache_beam.runners.direct.evaluation_context import _ExecutionContext
 from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
 from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
@@ -166,6 +167,14 @@ def test_impulse(self):
     with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
       assert_that(p | beam.Impulse(), equal_to([b'']))
 
+  def test_groupintobatches(self):
+    with beam.Pipeline(runner=BundleBasedDirectRunner()) as p:
+      groups = (
+          p
+          | "Create input" >> beam.Create([(0, 0)])
+          | "Batch in groups" >> beam.GroupIntoBatches(5))
+      assert_that(groups, equal_to([(0, (0, ))]))
+
 class DirectRunnerRetryTests(unittest.TestCase):
   def test_retry_fork_graph(self):