Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions sdks/python/apache_beam/runners/direct/direct_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import logging
import typing

import more_itertools
from google.protobuf import wrappers_pb2

import apache_beam as beam
Expand Down Expand Up @@ -56,6 +57,8 @@
__all__ = ['BundleBasedDirectRunner', 'DirectRunner', 'SwitchingDirectRunner']

_LOGGER = logging.getLogger(__name__)
K = typing.TypeVar('K')
V = typing.TypeVar('V')


class SwitchingDirectRunner(PipelineRunner):
Expand Down Expand Up @@ -263,6 +266,29 @@ def expand(self, pcoll):
return PCollection.from_(pcoll)


@typehints.with_input_types(typing.Tuple[K, V])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupIntoBatches(PTransform):
  """Non-timer based implementation of GroupIntoBatches.

  Groups the input by key, then splits each key's values into fixed-size
  batches and emits one (key, batch) pair per batch.
  """
  def __init__(self, batch_size: int):
    # Fail fast at pipeline-construction time on a nonsensical size.
    if batch_size <= 0:
      raise ValueError("batch_size must be a positive integer.")
    self.batch_size = batch_size

  def expand(self, pcoll):
    grouped = pcoll | beam.GroupByKey()
    batched = grouped | "BatchGroupedValues" >> beam.FlatMap(
        self._batch_elements, batch_size=self.batch_size)
    return batched | beam.Reshuffle()

  @staticmethod
  def _batch_elements(key_values: tuple, batch_size: int):
    # FlatMap flattens this generator: one output element per batch.
    key, values = key_values
    for batch in more_itertools.batched(values, batch_size):
      yield key, batch


@typehints.with_input_types(typing.Tuple[K, typing.Iterable[V]])
@typehints.with_output_types(typing.Tuple[K, typing.Iterable[V]])
class _GroupAlsoByWindow(ParDo):
Expand Down Expand Up @@ -472,10 +498,25 @@ def get_replacement_transform_for_applied_ptransform(
self, applied_ptransform):
return _GroupByKey()

class GroupIntoBatchesOverride(PTransformOverride):
  """A ``PTransformOverride`` for ``GroupIntoBatches``.

  This replaces the Beam implementation as a primitive, substituting the
  non-timer based ``_GroupIntoBatches`` defined in this module.
  """
  def matches(self, applied_ptransform):
    # Imported here to avoid circular dependencies.
    # pylint: disable=wrong-import-order, wrong-import-position
    from apache_beam.transforms.util import GroupIntoBatches
    return isinstance(applied_ptransform.transform, GroupIntoBatches)

  def get_replacement_transform_for_applied_ptransform(
      self, applied_ptransform):
    # Implement the applied-ptransform hook for consistency with the other
    # overrides in this module (e.g. the override that returns _GroupByKey),
    # rather than the legacy get_replacement_transform(ptransform) hook.
    return _GroupIntoBatches(applied_ptransform.transform.params.batch_size)

overrides = [
# This needs to be the first and the last override. Other overrides depend
# on the GroupByKey implementation to be composed of _GroupByKeyOnly and
# _GroupAlsoByWindow.
GroupIntoBatchesOverride(),
GroupByKeyPTransformOverride(),
SplittableParDoOverride(),
ProcessKeyedElementsViaKeyedWorkItemsOverride(),
Expand Down
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/runners/direct/direct_runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from apache_beam.runners import DirectRunner
from apache_beam.runners import TestDirectRunner
from apache_beam.runners import create_runner
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.direct.evaluation_context import _ExecutionContext
from apache_beam.runners.direct.transform_evaluator import _GroupByKeyOnlyEvaluator
from apache_beam.runners.direct.transform_evaluator import _TransformEvaluator
Expand Down Expand Up @@ -166,6 +167,14 @@ def test_impulse(self):
with test_pipeline.TestPipeline(runner='BundleBasedDirectRunner') as p:
assert_that(p | beam.Impulse(), equal_to([b'']))

  def test_groupintobatches(self):
    """GroupIntoBatches runs on BundleBasedDirectRunner (single-element batch).

    A single (0, 0) element with batch_size=5 yields one batch containing
    just that value, so the output is [(0, (0,))].
    """
    # assert_that is applied inside the pipeline context so it executes
    # when the pipeline runs on `with`-block exit.
    with beam.Pipeline(runner=BundleBasedDirectRunner()) as p:
      groups = (
          p
          | "Create input" >> beam.Create([(0, 0)])
          | "Batch in groups" >> beam.GroupIntoBatches(5))
      assert_that(groups, equal_to([(0, (0, ))]))


class DirectRunnerRetryTests(unittest.TestCase):
def test_retry_fork_graph(self):
Expand Down
Loading