encryptedKv : c.element().getValue()) {
byte[] iv = Arrays.copyOfRange(encryptedKv.getKey(), 0, 12);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(128, iv);
@@ -251,7 +252,8 @@ public void processElement(ProcessContext c) throws Exception {
byte[] decryptedKeyBytes = this.cipher.doFinal(encryptedKey);
K key = decode(this.keyCoder, decryptedKeyBytes);
- if (key != null) {
+ // A null key is valid only if it decoded from empty bytes; null from non-empty bytes must throw.
+ if (key != null || decryptedKeyBytes == null || decryptedKeyBytes.length == 0) {
if (!decryptedKvs.containsKey(key)) {
decryptedKvs.put(key, new java.util.ArrayList<>());
}
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java
index c41b2151d4cc..06288c07dbff 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java
@@ -17,17 +17,29 @@
*/
package org.apache.beam.sdk.util.construction;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.equalTo;
+
+import com.google.cloud.secretmanager.v1.ProjectName;
+import com.google.cloud.secretmanager.v1.SecretManagerServiceClient;
+import com.google.cloud.secretmanager.v1.SecretName;
+import com.google.cloud.secretmanager.v1.SecretPayload;
+import com.google.protobuf.ByteString;
import java.io.IOException;
import java.io.Serializable;
+import java.security.SecureRandom;
import java.util.Arrays;
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.SchemaTranslation;
import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesJavaExpansionService;
import org.apache.beam.sdk.testing.UsesPythonExpansionService;
import org.apache.beam.sdk.testing.ValidatesRunner;
@@ -42,8 +54,13 @@
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
+import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -286,6 +303,118 @@ public void test() {
}
}
+ /**
+ * Motivation behind GroupByKeyWithGbekTest.
+ *
+ * <p>Target transform – GroupByKey
+ * (https://beam.apache.org/documentation/programming-guide/#groupbykey) Test scenario – Grouping
+ * a collection of KV<K, V> to a collection of KV<K, Iterable<V>> by key Boundary conditions
+ * checked – –> PCollection<KV<?, ?>> to external transforms –> PCollection<KV<?, Iterable<?>>>
+ * from external transforms while using GroupByEncryptedKey overrides
+ */
+ @RunWith(JUnit4.class)
+ public static class GroupByKeyWithGbekTest extends ValidateRunnerXlangTestBase {
+ @Rule public ExpectedException thrown = ExpectedException.none();
+ private static final String PROJECT_ID = "apache-beam-testing";
+ private static final String SECRET_ID = "gbek-test";
+ private static String gcpSecretVersionName;
+ private static String secretId;
+
+ @BeforeClass
+ public static void setUpClass() {
+ secretId = String.format("%s-%d", SECRET_ID, new SecureRandom().nextInt(10000));
+ try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
+ ProjectName projectName = ProjectName.of(PROJECT_ID);
+ SecretName secretName = SecretName.of(PROJECT_ID, secretId);
+
+ try {
+ client.getSecret(secretName);
+ } catch (Exception e) {
+ com.google.cloud.secretmanager.v1.Secret secret =
+ com.google.cloud.secretmanager.v1.Secret.newBuilder()
+ .setReplication(
+ com.google.cloud.secretmanager.v1.Replication.newBuilder()
+ .setAutomatic(
+ com.google.cloud.secretmanager.v1.Replication.Automatic.newBuilder()
+ .build())
+ .build())
+ .build();
+ client.createSecret(projectName, secretId, secret);
+ byte[] secretBytes = new byte[32];
+ new SecureRandom().nextBytes(secretBytes);
+ client.addSecretVersion(
+ secretName,
+ SecretPayload.newBuilder()
+ .setData(
+ ByteString.copyFrom(java.util.Base64.getUrlEncoder().encode(secretBytes)))
+ .build());
+ }
+ gcpSecretVersionName = secretName.toString() + "/versions/latest";
+ } catch (IOException e) {
+ gcpSecretVersionName = null;
+ return;
+ }
+ expansionAddr =
+ String.format("localhost:%s", Integer.valueOf(System.getProperty("expansionPort")));
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ if (gcpSecretVersionName != null) {
+ try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
+ SecretName secretName = SecretName.of(PROJECT_ID, secretId);
+ client.deleteSecret(secretName);
+ } catch (IOException e) {
+ // Best-effort cleanup: an IOException here is deliberately ignored.
+ }
+ }
+ }
+
+ @After
+ @Override
+ public void tearDown() {
+ // Override tearDown since we're doing our own assertion instead of relying on base class
+ // assertions
+ }
+
+ @Test
+ @Category({
+ ValidatesRunner.class,
+ UsesJavaExpansionService.class,
+ UsesPythonExpansionService.class
+ })
+ public void test() {
+ if (gcpSecretVersionName == null) {
+ // Skip test if we couldn't set up secret manager
+ return;
+ }
+ PipelineOptions options = TestPipeline.testingPipelineOptions();
+ options.setGbek(String.format("type:gcpsecret;version_name:%s", gcpSecretVersionName));
+ Pipeline pipeline = Pipeline.create(options);
+ groupByKeyTest(pipeline);
+ PipelineResult pipelineResult = pipeline.run();
+ pipelineResult.waitUntilFinish();
+ assertThat(pipelineResult.getState(), equalTo(PipelineResult.State.DONE));
+ }
+
+ @Test
+ @Category({
+ ValidatesRunner.class,
+ UsesJavaExpansionService.class,
+ UsesPythonExpansionService.class
+ })
+ public void testFailure() {
+ thrown.expect(Exception.class);
+ PipelineOptions options = TestPipeline.testingPipelineOptions();
+ options.setGbek("version_name:fake_secret");
+ Pipeline pipeline = Pipeline.create(options);
+ groupByKeyTest(pipeline);
+ PipelineResult pipelineResult = pipeline.run();
+ pipelineResult.waitUntilFinish();
+ assertThat(pipelineResult.getState(), equalTo(PipelineResult.State.DONE));
+ }
+ }
+
/**
* Motivation behind coGroupByKeyTest.
*
diff --git a/sdks/python/apache_beam/options/pipeline_options.py b/sdks/python/apache_beam/options/pipeline_options.py
index 3fc5151156f1..247daebd2c6d 100644
--- a/sdks/python/apache_beam/options/pipeline_options.py
+++ b/sdks/python/apache_beam/options/pipeline_options.py
@@ -64,6 +64,11 @@
# that have a destination(dest) in parser.add_argument() different
# from the flag name and whose default value is `None`.
_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'}
+# Set of options which should not be overridden when applying options from a
+# different language. This is relevant when using x-lang transforms where the
+# expansion service is started up with some pipeline options, and will
+# impact which options are passed in to expanded transforms' expand functions.
+_NON_OVERIDABLE_XLANG_OPTIONS = ['runner', 'experiments']
def _static_value_provider_of(value_type):
@@ -287,6 +292,10 @@ def _smart_split(self, values):
class PipelineOptions(HasDisplayData):
+ # Options listed in _NON_OVERIDABLE_XLANG_OPTIONS are not overridden when
+ # pipeline options are merged (see from_runner_api). This primarily comes
+ # up when expanding transforms in the Python expansion service.
+
"""This class and subclasses are used as containers for command line options.
These classes are wrappers over the standard argparse Python module
@@ -592,15 +601,19 @@ def to_struct_value(o):
})
@classmethod
- def from_runner_api(cls, proto_options):
+ def from_runner_api(cls, proto_options, original_options=None):
def from_urn(key):
assert key.startswith('beam:option:')
assert key.endswith(':v1')
return key[12:-3]
- return cls(
- **{from_urn(key): value
- for (key, value) in proto_options.items()})
+ parsed = {from_urn(key): value for (key, value) in proto_options.items()}
+ if original_options is None:
+ return cls(**parsed)
+ for (key, value) in parsed.items():
+ if value and key not in _NON_OVERIDABLE_XLANG_OPTIONS:
+ original_options._all_options[key] = value
+ return original_options
def display_data(self):
return self.get_all_options(drop_default=True, retain_unknown_options=True)
diff --git a/sdks/python/apache_beam/runners/portability/expansion_service.py b/sdks/python/apache_beam/runners/portability/expansion_service.py
index 12e3ffb69702..4464d2f89b07 100644
--- a/sdks/python/apache_beam/runners/portability/expansion_service.py
+++ b/sdks/python/apache_beam/runners/portability/expansion_service.py
@@ -56,16 +56,8 @@ def __init__(self, options=None, loopback_address=None):
def Expand(self, request, context=None):
try:
options = copy.deepcopy(self._options)
- request_options = pipeline_options.PipelineOptions.from_runner_api(
- request.pipeline_options)
- # TODO(https://github.com/apache/beam/issues/20090): Figure out the
- # correct subset of options to apply to expansion.
- if request_options.view_as(
- pipeline_options.StreamingOptions).update_compatibility_version:
- options.view_as(
- pipeline_options.StreamingOptions
- ).update_compatibility_version = request_options.view_as(
- pipeline_options.StreamingOptions).update_compatibility_version
+ options = pipeline_options.PipelineOptions.from_runner_api(
+ request.pipeline_options, options)
pipeline = beam_pipeline.Pipeline(options=options)
def with_pipeline(component, pcoll_id=None):
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index 5af9d904895a..4a5cf0794c45 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -546,13 +546,16 @@ def expand(self, pcoll):
pcoll.element_type).tuple_types)
kv_type_hint = typehints.KV[key_type, value_type]
if kv_type_hint and kv_type_hint != typehints.Any:
- coder = coders.registry.get_coder(kv_type_hint).as_deterministic_coder(
- f'GroupByEncryptedKey {self.label}'
- 'The key coder is not deterministic. This may result in incorrect '
- 'pipeline output. This can be fixed by adding a type hint to the '
- 'operation preceding the GroupByKey step, and for custom key '
- 'classes, by writing a deterministic custom Coder. Please see the '
- 'documentation for more details.')
+ coder = coders.registry.get_coder(kv_type_hint)
+ try:
+ coder = coder.as_deterministic_coder(self.label)
+ except ValueError:
+ logging.warning(f'GroupByEncryptedKey {self.label}: '
+ 'The key coder is not deterministic. This may result in incorrect '
+ 'pipeline output. This can be fixed by adding a type hint to the '
+ 'operation preceding the GroupByKey step, and for custom key '
+ 'classes, by writing a deterministic custom Coder. Please see the '
+ 'documentation for more details.')
if not coder.is_kv_coder():
raise ValueError(
'Input elements to the transform %s with stateful DoFn must be '
@@ -565,12 +568,15 @@ def expand(self, pcoll):
gbk = beam.GroupByKey()
gbk._inside_gbek = True
+ output_type = Tuple[key_type, Iterable[value_type]]
return (
pcoll
| beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
| gbk
- | beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
+ | beam.ParDo(
+ _DecryptMessage(self._hmac_key, key_coder,
+ value_coder)).with_output_types(output_type))
class _BatchSizeEstimator(object):
diff --git a/sdks/python/container/run_validatescontainer.sh b/sdks/python/container/run_validatescontainer.sh
index 95130f7559bb..8266df3c03a7 100755
--- a/sdks/python/container/run_validatescontainer.sh
+++ b/sdks/python/container/run_validatescontainer.sh
@@ -142,6 +142,7 @@ pytest -o log_cli=True -o log_level=Info -o junit_suite_name=$IMAGE_NAME \
--output=$GCS_LOCATION/output \
--sdk_location=$SDK_LOCATION \
--num_workers=1 \
+ --gbek=type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest \
$MACHINE_TYPE_ARGS \
--docker_registry_push_url=$PREBUILD_SDK_CONTAINER_REGISTRY_PATH"
diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle
index 6a0777bd667c..39492da527a4 100644
--- a/sdks/python/test-suites/dataflow/common.gradle
+++ b/sdks/python/test-suites/dataflow/common.gradle
@@ -391,6 +391,7 @@ task validatesDistrolessContainer() {
"project" : "apache-beam-testing",
"region" : "us-central1",
"runner" : "TestDataflowRunner",
+ "gbek" : "type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest",
"sdk_container_image": "${imageURL}",
"sdk_location" : "container",
"staging_location" : "gs://temp-storage-for-end-to-end-tests/staging-it",
@@ -438,6 +439,7 @@ def tensorRTTests = tasks.create("tensorRTtests") {
"input": "gs://apache-beam-ml/testing/inputs/tensorrt_image_file_names.txt",
"output": "gs://apache-beam-ml/outputs/tensorrt_predictions.txt",
"engine_path": "gs://apache-beam-ml/models/ssd_mobilenet_v2_320x320_coco17_tpu-8.trt",
+ "gbek": "type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest",
"disk_size_gb": 75
]
def cmdArgs = mapToArgString(argMap)
@@ -466,6 +468,7 @@ def vllmTests = tasks.create("vllmTests") {
"region": "us-central1",
"model": "facebook/opt-125m",
"output": "gs://apache-beam-ml/outputs/vllm_predictions.txt",
+ "gbek": "type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest",
"disk_size_gb": 75
]
def cmdArgs = mapToArgString(argMap)
@@ -499,7 +502,8 @@ task vertexAIInferenceTest {
"suite": "VertexAITests-df-py${pythonVersionSuffix}",
"collect": "vertex_ai_postcommit" ,
"runner": "TestDataflowRunner",
- "requirements_file": "$requirementsFile"
+ "requirements_file": "$requirementsFile",
+ "gbek": "type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest"
]
def cmdArgs = mapToArgString(argMap)
exec {
@@ -527,6 +531,7 @@ task geminiInferenceTest {
"suite": "GeminiTests-df-py${pythonVersionSuffix}",
"collect": "gemini_postcommit" ,
"runner": "TestDataflowRunner",
+ "gbek": "type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest",
"requirements_file": "$requirementsFile"
]
def cmdArgs = mapToArgString(argMap)
@@ -629,7 +634,8 @@ project(":sdks:python:test-suites:xlang").ext.xlangTasks.each { taskMetadata ->
"--project=${gcpProject}",
"--region=${gcpRegion}",
"--sdk_container_image=gcr.io/apache-beam-testing/beam-sdk/beam_python${project.ext.pythonVersion}_sdk:latest",
- "--sdk_harness_container_image_overrides=.*java.*,gcr.io/apache-beam-testing/beam-sdk/beam_java11_sdk:latest"
+ "--sdk_harness_container_image_overrides=.*java.*,gcr.io/apache-beam-testing/beam-sdk/beam_java11_sdk:latest",
+ "--gbek=type:GcpSecret;version_name:projects/apache-beam-testing/secrets/gbek_secret_tests_dannystest/versions/latest"
],
pytestOptions: basicPytestOpts,
additionalDeps: taskMetadata.additionalDeps,