@@ -30,20 +30,15 @@
 import typing
 import unittest
 import uuid
-from typing import Any
-from typing import Dict
-from typing import Iterator
-from typing import List
-from typing import Tuple
-from typing import no_type_check
+from contextlib import contextmanager
+from typing import Any, Dict, Iterator, List, Tuple, no_type_check

 import hamcrest  # pylint: disable=ungrouped-imports
 import numpy as np
 import pytest
 from hamcrest.core.matcher import Matcher
 from hamcrest.core.string_description import StringDescription
-from tenacity import retry
-from tenacity import stop_after_attempt
+from tenacity import retry, stop_after_attempt

 import apache_beam as beam
 from apache_beam.coders import coders
@@ -53,31 +48,22 @@
 from apache_beam.metrics import monitoring_infos
 from apache_beam.metrics.execution import MetricKey
 from apache_beam.metrics.metricbase import MetricName
-from apache_beam.options.pipeline_options import DebugOptions
-from apache_beam.options.pipeline_options import DirectOptions
-from apache_beam.options.pipeline_options import PipelineOptions
-from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.options.pipeline_options import (
+    DebugOptions, DirectOptions, PipelineOptions, StandardOptions)
 from apache_beam.options.value_provider import RuntimeValueProvider
 from apache_beam.portability import python_urns
 from apache_beam.runners.portability import fn_api_runner
 from apache_beam.runners.portability.fn_api_runner import fn_runner
 from apache_beam.runners.sdf_utils import RestrictionTrackerView
-from apache_beam.runners.worker import data_plane
-from apache_beam.runners.worker import statesampler
+from apache_beam.runners.worker import data_plane, statesampler
 from apache_beam.runners.worker.operations import InefficientExecutionWarning
 from apache_beam.testing.synthetic_pipeline import SyntheticSDFAsSource
 from apache_beam.testing.test_stream import TestStream
-from apache_beam.testing.util import assert_that
-from apache_beam.testing.util import equal_to
-from apache_beam.testing.util import has_at_least_one
+from apache_beam.testing.util import assert_that, equal_to, has_at_least_one
 from apache_beam.tools import utils
-from apache_beam.transforms import environments
-from apache_beam.transforms import trigger
-from apache_beam.transforms import userstate
-from apache_beam.transforms import window
+from apache_beam.transforms import environments, trigger, userstate, window
 from apache_beam.transforms.periodicsequence import PeriodicImpulse
-from apache_beam.utils import timestamp
-from apache_beam.utils import windowed_value
+from apache_beam.utils import timestamp, windowed_value

 if statesampler.FAST_SAMPLER:
   DEFAULT_SAMPLING_PERIOD_MS = statesampler.DEFAULT_SAMPLING_PERIOD_MS
@@ -87,6 +73,46 @@
 _LOGGER = logging.getLogger(__name__)


+@contextmanager
+def patch_portable_runner_for_test():
+  """Captures worker-thread exceptions and re-raises them from
+  wait_until_finish(), so runner failures surface deterministically."""
+  captured = {}
+
+  orig_excepthook = getattr(threading, "excepthook", None)
+
+  def _capture_excepthook(args):
+    # Keep only the first exception raised on any worker thread.
+    captured.setdefault("exc", args.exc_value)
+
+  if orig_excepthook is not None:
+    threading.excepthook = _capture_excepthook
+
+  orig_pipeline_run = beam.Pipeline.run
+
+  def wrapped_pipeline_run(pipeline_self, *a, **kw):
+    result = orig_pipeline_run(pipeline_self, *a, **kw)
+    if hasattr(result, "wait_until_finish"):
+      orig_wait = result.wait_until_finish
+
+      def wrapped_wait(*wa, **wk):
+        try:
+          return orig_wait(*wa, **wk)
+        finally:
+          # Surface any exception captured while the pipeline ran.
+          exc = captured.get("exc")
+          if exc:
+            raise exc
+
+      result.wait_until_finish = wrapped_wait
+    return result
+
+  beam.Pipeline.run = wrapped_pipeline_run
+
+  try:
+    yield
+  finally:
+    # Restore the patched globals even if the test body raises.
+    beam.Pipeline.run = orig_pipeline_run
+    if orig_excepthook is not None:
+      threading.excepthook = orig_excepthook
+
+
 def _matcher_or_equal_to(value_or_matcher):
   """Pass-thru for matchers, and wraps value inputs in an equal_to matcher."""
   if value_or_matcher is None:
@@ -112,9 +138,10 @@ def create_pipeline(self, is_drain=False):
   def test_assert_that(self):
     # TODO: figure out a way for fn_api_runner to parse and raise the
     # underlying exception.
-    with self.assertRaisesRegex(Exception, 'Failed assert'):
-      with self.create_pipeline() as p:
-        assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))
+    with patch_portable_runner_for_test():  # pylint: disable=not-context-manager
Contributor
Thanks - this is a good find.

With that said, I think that fixing the test is probably the wrong move here. If this test is flaky, it indicates that the underlying runner is not doing the right thing (the runner should fail this test every time, without any patches). So I think we need to address the behavior in the underlying runner, not in the test.

Contributor Author
Totally agree that the issue isn't the test.

I changed PortableRunner.wait_until_finish() so it re-raises worker errors instead of only logging them to the message stream. Before, failures could get buried in the logs, and the timing sometimes made the run look “successful”, which caused the flake. Now the error bubbles up, the job fails deterministically, and the test passes without any test changes.

The changes are in portable_runner.py, plus a tiny companion tweak in local_job_service.py.
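
For anyone skimming the thread, here is a minimal, self-contained sketch of that idea. Everything below (`SketchPipelineResult`, `JobState`, the message-stream shape) is an illustrative assumption for this comment, not Beam's actual internals; the real change lives in portable_runner.py:

```python
import enum


class JobState(enum.Enum):
  RUNNING = 1
  DONE = 2
  FAILED = 3


ERROR = 'ERROR'


class SketchPipelineResult:
  """Retains worker errors and re-raises them from wait_until_finish()."""

  def __init__(self, message_stream):
    # message_stream yields (severity, text, state) tuples; in the real
    # runner these come from the job service's message stream.
    self._message_stream = message_stream
    self._last_error = None
    self._state = JobState.RUNNING

  def wait_until_finish(self):
    for severity, text, state in self._message_stream:
      if severity == ERROR:
        # Previously such messages were only logged; retaining the text
        # lets us attach it to the exception below.
        self._last_error = text
      self._state = state
    if self._state == JobState.FAILED:
      # Fail deterministically instead of burying the error in the logs.
      raise RuntimeError(self._last_error or 'Pipeline failed.')
    return self._state


# Usage: a failed run now raises instead of returning quietly.
result = SketchPipelineResult([
    ('INFO', 'starting', JobState.RUNNING),
    (ERROR, 'Failed assert: ...', JobState.FAILED),
])
result.wait_until_finish()  # raises RuntimeError('Failed assert: ...')
```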

+      with self.assertRaisesRegex(Exception, 'Failed assert'):
+        with self.create_pipeline() as p:
+          assert_that(p | beam.Create(['a', 'b']), equal_to(['a']))

   def test_create(self):
     with self.create_pipeline() as p: