diff --git a/.github/trigger_files/beam_PostCommit_XVR_JavaUsingPython_Dataflow.json b/.github/trigger_files/beam_PostCommit_XVR_JavaUsingPython_Dataflow.json
new file mode 100644
index 000000000000..6a55e29ae15d
--- /dev/null
+++ b/.github/trigger_files/beam_PostCommit_XVR_JavaUsingPython_Dataflow.json
@@ -0,0 +1,4 @@
+{
+ "comment": "Modify this file in a trivial way to cause this test suite to run.",
+ "modification": 1
+}
\ No newline at end of file
diff --git a/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml b/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml
index 52e8467b1624..53cd7b7ad4d0 100644
--- a/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml
+++ b/sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml
@@ -60,6 +60,7 @@
+
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java
index c41b2151d4cc..06288c07dbff 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java
@@ -17,17 +17,29 @@
*/
package org.apache.beam.sdk.util.construction;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+
+import com.google.cloud.secretmanager.v1.ProjectName;
+import com.google.cloud.secretmanager.v1.SecretManagerServiceClient;
+import com.google.cloud.secretmanager.v1.SecretName;
+import com.google.cloud.secretmanager.v1.SecretPayload;
+import com.google.protobuf.ByteString;
import java.io.IOException;
import java.io.Serializable;
+import java.security.SecureRandom;
import java.util.Arrays;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesJavaExpansionService;
import org.apache.beam.sdk.testing.UsesPythonExpansionService;
import org.apache.beam.sdk.testing.ValidatesRunner;
@@ -42,8 +54,13 @@
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -286,6 +303,118 @@ public void test() {
}
}
+  /**
+   * Motivation behind GroupByKeyWithGbekTest.
+   *
+   * <p>Target transform – GroupByKey
+   * (https://beam.apache.org/documentation/programming-guide/#groupbykey) Test scenario – Grouping
+   * a collection of KV<K, V> to a collection of KV<K, Iterable<V>> by key Boundary
+   * conditions checked – –> PCollection<KV<?, ?>> to external transforms –>
+   * PCollection<KV<?, Iterable<?>>> from external transforms while using
+   * GroupByEncryptedKey overrides
+   */
+ @RunWith(JUnit4.class)
+ public static class GroupByKeyWithGbekTest extends ValidateRunnerXlangTestBase {
+ @Rule public ExpectedException thrown = ExpectedException.none();
+ private static final String PROJECT_ID = "apache-beam-testing";
+ private static final String SECRET_ID = "gbek-test";
+ private static String gcpSecretVersionName;
+ private static String secretId;
+
+ @BeforeClass
+ public static void setUpClass() {
+ secretId = String.format("%s-%d", SECRET_ID, new SecureRandom().nextInt(10000));
+ try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
+ ProjectName projectName = ProjectName.of(PROJECT_ID);
+ SecretName secretName = SecretName.of(PROJECT_ID, secretId);
+
+ try {
+ client.getSecret(secretName);
+ } catch (Exception e) {
+ com.google.cloud.secretmanager.v1.Secret secret =
+ com.google.cloud.secretmanager.v1.Secret.newBuilder()
+ .setReplication(
+ com.google.cloud.secretmanager.v1.Replication.newBuilder()
+ .setAutomatic(
+ com.google.cloud.secretmanager.v1.Replication.Automatic.newBuilder()
+ .build())
+ .build())
+ .build();
+ client.createSecret(projectName, secretId, secret);
+ byte[] secretBytes = new byte[32];
+ new SecureRandom().nextBytes(secretBytes);
+ client.addSecretVersion(
+ secretName,
+ SecretPayload.newBuilder()
+ .setData(
+ ByteString.copyFrom(java.util.Base64.getUrlEncoder().encode(secretBytes)))
+ .build());
+ }
+ gcpSecretVersionName = secretName.toString() + "/versions/latest";
+ } catch (IOException e) {
+ gcpSecretVersionName = null;
+ return;
+ }
+ expansionAddr =
+ String.format("localhost:%s", Integer.valueOf(System.getProperty("expansionPort")));
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ if (gcpSecretVersionName != null) {
+ try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
+ SecretName secretName = SecretName.of(PROJECT_ID, secretId);
+ client.deleteSecret(secretName);
+ } catch (IOException e) {
+ // Do nothing.
+ }
+ }
+ }
+
+ @After
+ @Override
+ public void tearDown() {
+ // Override tearDown since we're doing our own assertion instead of relying on base class
+ // assertions
+ }
+
+ @Test
+ @Category({
+ ValidatesRunner.class,
+ UsesJavaExpansionService.class,
+ UsesPythonExpansionService.class
+ })
+ public void test() {
+ if (gcpSecretVersionName == null) {
+ // Skip test if we couldn't set up secret manager
+ return;
+ }
+ PipelineOptions options = TestPipeline.testingPipelineOptions();
+ options.setGbek(String.format("type:gcpsecret;version_name:%s", gcpSecretVersionName));
+ Pipeline pipeline = Pipeline.create(options);
+ groupByKeyTest(pipeline);
+ PipelineResult pipelineResult = pipeline.run();
+ pipelineResult.waitUntilFinish();
+ assertThat(pipelineResult.getState(), equalTo(PipelineResult.State.DONE));
+ }
+
+ @Test
+ @Category({
+ ValidatesRunner.class,
+ UsesJavaExpansionService.class,
+ UsesPythonExpansionService.class
+ })
+ public void testFailure() {
+ thrown.expect(Exception.class);
+ PipelineOptions options = TestPipeline.testingPipelineOptions();
+ options.setGbek("version_name:fake_secret");
+ Pipeline pipeline = Pipeline.create(options);
+ groupByKeyTest(pipeline);
+ PipelineResult pipelineResult = pipeline.run();
+ pipelineResult.waitUntilFinish();
+ assertThat(pipelineResult.getState(), equalTo(PipelineResult.State.DONE));
+ }
+ }
+
/**
* Motivation behind coGroupByKeyTest.
*
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 3fc5151156f1..247daebd2c6d 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -64,6 +64,11 @@
# that have a destination(dest) in parser.add_argument() different
# from the flag name and whose default value is `None`.
_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'}
+# Set of options which should not be overridden when applying options from a
+# different language. This is relevant when using x-lang transforms where the
+# expansion service is started up with some pipeline options, and will
+# impact which options are passed in to expanded transforms' expand functions.
+_NON_OVERIDABLE_XLANG_OPTIONS = ['runner', 'experiments']
def _static_value_provider_of(value_type):
@@ -287,6 +292,10 @@ def _smart_split(self, values):
class PipelineOptions(HasDisplayData):
+  # Set of options which should not be overridden when pipeline options are
+  # being merged (see from_runner_api). This primarily comes up when expanding
+  # transforms in the Python expansion service.
+
"""This class and subclasses are used as containers for command line options.
These classes are wrappers over the standard argparse Python module
@@ -592,15 +601,19 @@ def to_struct_value(o):
})
@classmethod
- def from_runner_api(cls, proto_options):
+ def from_runner_api(cls, proto_options, original_options=None):
def from_urn(key):
assert key.startswith('beam:option:')
assert key.endswith(':v1')
return key[12:-3]
- return cls(
- **{from_urn(key): value
- for (key, value) in proto_options.items()})
+ parsed = {from_urn(key): value for (key, value) in proto_options.items()}
+ if original_options is None:
+ return cls(**parsed)
+ for (key, value) in parsed.items():
+ if value and key not in _NON_OVERIDABLE_XLANG_OPTIONS:
+ original_options._all_options[key] = value
+ return original_options
def display_data(self):
return self.get_all_options(drop_default=True, retain_unknown_options=True)
diff --git a/sdks/python/apache_beam/options/pipeline_options_test.py b/sdks/python/apache_beam/options/pipeline_options_test.py
index cd6cce204b78..b9c2061744b8 100644
--- a/sdks/python/apache_beam/options/pipeline_options_test.py
+++ b/sdks/python/apache_beam/options/pipeline_options_test.py
@@ -34,6 +34,7 @@
from apache_beam.options.pipeline_options import JobServerOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import ProfilingOptions
+from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.options.pipeline_options import _BeamArgumentParser
@@ -308,6 +309,26 @@ def _add_argparse_args(cls, parser):
self.assertEqual(result['test_arg_int'], 5)
self.assertEqual(result['test_arg_none'], None)
+ def test_merging_options(self):
+ opts = PipelineOptions(flags=['--num_workers', '5'])
+ actual_opts = PipelineOptions.from_runner_api(opts.to_runner_api())
+ actual = actual_opts.view_as(WorkerOptions).num_workers
+ self.assertEqual(5, actual)
+
+ def test_merging_options_with_overriden_options(self):
+ opts = PipelineOptions(flags=['--num_workers', '5'])
+ base = PipelineOptions(flags=['--num_workers', '2'])
+ actual_opts = PipelineOptions.from_runner_api(opts.to_runner_api(), base)
+ actual = actual_opts.view_as(WorkerOptions).num_workers
+ self.assertEqual(5, actual)
+
+ def test_merging_options_with_overriden_runner(self):
+ opts = PipelineOptions(flags=['--runner', 'FnApiRunner'])
+ base = PipelineOptions(flags=['--runner', 'Direct'])
+ actual_opts = PipelineOptions.from_runner_api(opts.to_runner_api(), base)
+ actual = actual_opts.view_as(StandardOptions).runner
+ self.assertEqual('Direct', actual)
+
def test_from_kwargs(self):
class MyOptions(PipelineOptions):
@classmethod
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py
index 12e3ffb69702..4464d2f89b07 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service.py
@@ -56,16 +56,8 @@ def __init__(self, options=None, loopback_address=None):
def Expand(self, request, context=None):
try:
options = copy.deepcopy(self._options)
- request_options = pipeline_options.PipelineOptions.from_runner_api(
- request.pipeline_options)
- # TODO(https://github.com/apache/beam/issues/20090): Figure out the
- # correct subset of options to apply to expansion.
- if request_options.view_as(
- pipeline_options.StreamingOptions).update_compatibility_version:
- options.view_as(
- pipeline_options.StreamingOptions
- ).update_compatibility_version = request_options.view_as(
- pipeline_options.StreamingOptions).update_compatibility_version
+ options = pipeline_options.PipelineOptions.from_runner_api(
+ request.pipeline_options, options)
pipeline = beam_pipeline.Pipeline(options=options)
def with_pipeline(component, pcoll_id=None):
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 5af9d904895a..b3b2e9699b14 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -565,12 +565,15 @@ def expand(self, pcoll):
gbk = beam.GroupByKey()
gbk._inside_gbek = True
+ output_type = Tuple[key_type, Iterable[value_type]]
return (
pcoll
| beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
| gbk
- | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
+ | beam.ParDo(
+ _DecryptMessage(self._hmac_key, key_coder,
+ value_coder)).with_output_types(output_type))
class _BatchSizeEstimator(object):