Skip to content

Commit 7a9a4e6

Browse files
authored
x-lang GroupByEncryptedKey (Java to Python) (#36418)
* x-lang gbek tests * Add java test * missing import * Move towards standardizing on base64 * url encoded * More doc * yapf * test cleanup * progress, kick presubmits * use options * Additional pieces * Add pipeline options piece * Format * Move gbek into own test class * Remove python -> java tests (see #36457) * Simplify to get faster repro * Get it working, need to figure out actual issue though * Fix type hinting * Clean up * pipeline options tests * simplify/lint * resolve gemini comments (minor) * extra test * Lint: import ordering
1 parent 75eda20 commit 7a9a4e6

7 files changed

Lines changed: 178 additions & 15 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"comment": "Modify this file in a trivial way to cause this test suite to run.",
3+
"modification": 1
4+
}

sdks/java/build-tools/src/main/resources/beam/checkstyle/suppressions.xml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*core.*GroupByEncryptedKeyTest.*" />
6161
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*core.*GroupByKeyTest.*" />
6262
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*core.*GroupByKeyIT.*" />
63+
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*core.*ValidateRunnerXlangTest.*" />
6364
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*extensions.*ml.*" />
6465
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*io.*gcp.*" />
6566
<suppress id="ForbidNonVendoredGrpcProtobuf" files=".*sdk.*io.*googleads.*DummyRateLimitPolicy\.java" />

sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ValidateRunnerXlangTest.java

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,29 @@
1717
*/
1818
package org.apache.beam.sdk.util.construction;
1919

20+
import static org.hamcrest.MatcherAssert.assertThat;
21+
import static org.hamcrest.Matchers.equalTo;
22+
23+
import com.google.cloud.secretmanager.v1.ProjectName;
24+
import com.google.cloud.secretmanager.v1.SecretManagerServiceClient;
25+
import com.google.cloud.secretmanager.v1.SecretName;
26+
import com.google.cloud.secretmanager.v1.SecretPayload;
27+
import com.google.protobuf.ByteString;
2028
import java.io.IOException;
2129
import java.io.Serializable;
30+
import java.security.SecureRandom;
2231
import java.util.Arrays;
2332
import org.apache.beam.model.pipeline.v1.ExternalTransforms;
2433
import org.apache.beam.sdk.Pipeline;
34+
import org.apache.beam.sdk.PipelineResult;
2535
import org.apache.beam.sdk.coders.RowCoder;
36+
import org.apache.beam.sdk.options.PipelineOptions;
2637
import org.apache.beam.sdk.schemas.Schema;
2738
import org.apache.beam.sdk.schemas.Schema.Field;
2839
import org.apache.beam.sdk.schemas.Schema.FieldType;
2940
import org.apache.beam.sdk.schemas.SchemaTranslation;
3041
import org.apache.beam.sdk.testing.PAssert;
42+
import org.apache.beam.sdk.testing.TestPipeline;
3143
import org.apache.beam.sdk.testing.UsesJavaExpansionService;
3244
import org.apache.beam.sdk.testing.UsesPythonExpansionService;
3345
import org.apache.beam.sdk.testing.ValidatesRunner;
@@ -42,8 +54,13 @@
4254
import org.apache.beam.sdk.values.Row;
4355
import org.apache.beam.sdk.values.TypeDescriptors;
4456
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
57+
import org.junit.After;
58+
import org.junit.AfterClass;
59+
import org.junit.BeforeClass;
60+
import org.junit.Rule;
4561
import org.junit.Test;
4662
import org.junit.experimental.categories.Category;
63+
import org.junit.rules.ExpectedException;
4764
import org.junit.runner.RunWith;
4865
import org.junit.runners.JUnit4;
4966

@@ -286,6 +303,118 @@ public void test() {
286303
}
287304
}
288305

306+
/**
307+
* Motivation behind GroupByKeyWithGbekTest.
308+
*
309+
* <p>Target transform – GroupByKey
310+
* (https://beam.apache.org/documentation/programming-guide/#groupbykey) Test scenario – Grouping
311+
* a collection of KV<K,V> to a collection of KV<K, Iterable<V>> by key Boundary conditions
312+
* checked – –> PCollection<KV<?, ?>> to external transforms –> PCollection<KV<?, Iterable<?>>>
313+
* from external transforms while using GroupByEncryptedKey overrides
314+
*/
315+
@RunWith(JUnit4.class)
316+
public static class GroupByKeyWithGbekTest extends ValidateRunnerXlangTestBase {
317+
@Rule public ExpectedException thrown = ExpectedException.none();
318+
private static final String PROJECT_ID = "apache-beam-testing";
319+
private static final String SECRET_ID = "gbek-test";
320+
private static String gcpSecretVersionName;
321+
private static String secretId;
322+
323+
@BeforeClass
324+
public static void setUpClass() {
325+
secretId = String.format("%s-%d", SECRET_ID, new SecureRandom().nextInt(10000));
326+
try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
327+
ProjectName projectName = ProjectName.of(PROJECT_ID);
328+
SecretName secretName = SecretName.of(PROJECT_ID, secretId);
329+
330+
try {
331+
client.getSecret(secretName);
332+
} catch (Exception e) {
333+
com.google.cloud.secretmanager.v1.Secret secret =
334+
com.google.cloud.secretmanager.v1.Secret.newBuilder()
335+
.setReplication(
336+
com.google.cloud.secretmanager.v1.Replication.newBuilder()
337+
.setAutomatic(
338+
com.google.cloud.secretmanager.v1.Replication.Automatic.newBuilder()
339+
.build())
340+
.build())
341+
.build();
342+
client.createSecret(projectName, secretId, secret);
343+
byte[] secretBytes = new byte[32];
344+
new SecureRandom().nextBytes(secretBytes);
345+
client.addSecretVersion(
346+
secretName,
347+
SecretPayload.newBuilder()
348+
.setData(
349+
ByteString.copyFrom(java.util.Base64.getUrlEncoder().encode(secretBytes)))
350+
.build());
351+
}
352+
gcpSecretVersionName = secretName.toString() + "/versions/latest";
353+
} catch (IOException e) {
354+
gcpSecretVersionName = null;
355+
return;
356+
}
357+
expansionAddr =
358+
String.format("localhost:%s", Integer.valueOf(System.getProperty("expansionPort")));
359+
}
360+
361+
@AfterClass
362+
public static void tearDownClass() {
363+
if (gcpSecretVersionName != null) {
364+
try (SecretManagerServiceClient client = SecretManagerServiceClient.create()) {
365+
SecretName secretName = SecretName.of(PROJECT_ID, secretId);
366+
client.deleteSecret(secretName);
367+
} catch (IOException e) {
368+
// Do nothing.
369+
}
370+
}
371+
}
372+
373+
@After
374+
@Override
375+
public void tearDown() {
376+
// Override tearDown since we're doing our own assertion instead of relying on base class
377+
// assertions
378+
}
379+
380+
@Test
381+
@Category({
382+
ValidatesRunner.class,
383+
UsesJavaExpansionService.class,
384+
UsesPythonExpansionService.class
385+
})
386+
public void test() {
387+
if (gcpSecretVersionName == null) {
388+
// Skip test if we couldn't set up secret manager
389+
return;
390+
}
391+
PipelineOptions options = TestPipeline.testingPipelineOptions();
392+
options.setGbek(String.format("type:gcpsecret;version_name:%s", gcpSecretVersionName));
393+
Pipeline pipeline = Pipeline.create(options);
394+
groupByKeyTest(pipeline);
395+
PipelineResult pipelineResult = pipeline.run();
396+
pipelineResult.waitUntilFinish();
397+
assertThat(pipelineResult.getState(), equalTo(PipelineResult.State.DONE));
398+
}
399+
400+
@Test
401+
@Category({
402+
ValidatesRunner.class,
403+
UsesJavaExpansionService.class,
404+
UsesPythonExpansionService.class
405+
})
406+
public void testFailure() {
407+
thrown.expect(Exception.class);
408+
PipelineOptions options = TestPipeline.testingPipelineOptions();
409+
options.setGbek("version_name:fake_secret");
410+
Pipeline pipeline = Pipeline.create(options);
411+
groupByKeyTest(pipeline);
412+
PipelineResult pipelineResult = pipeline.run();
413+
pipelineResult.waitUntilFinish();
414+
assertThat(pipelineResult.getState(), equalTo(PipelineResult.State.DONE));
415+
}
416+
}
417+
289418
/**
290419
* Motivation behind coGroupByKeyTest.
291420
*

sdks/python/apache_beam/options/pipeline_options.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@
6464
# that have a destination(dest) in parser.add_argument() different
6565
# from the flag name and whose default value is `None`.
6666
_FLAG_THAT_SETS_FALSE_VALUE = {'use_public_ips': 'no_use_public_ips'}
67+
# Set of options which should not be overriden when applying options from a
68+
# different language. This is relevant when using x-lang transforms where the
69+
# expansion service is started up with some pipeline options, and will
70+
# impact which options are passed in to expanded transforms' expand functions.
71+
_NON_OVERIDABLE_XLANG_OPTIONS = ['runner', 'experiments']
6772

6873

6974
def _static_value_provider_of(value_type):
@@ -287,6 +292,10 @@ def _smart_split(self, values):
287292

288293

289294
class PipelineOptions(HasDisplayData):
295+
# Set of options which should not be overriden when pipeline options are
296+
# being merged (see from_runner_api). This primarily comes up when expanding
297+
# the Python expansion service
298+
290299
"""This class and subclasses are used as containers for command line options.
291300
292301
These classes are wrappers over the standard argparse Python module
@@ -592,15 +601,19 @@ def to_struct_value(o):
592601
})
593602

594603
@classmethod
595-
def from_runner_api(cls, proto_options):
604+
def from_runner_api(cls, proto_options, original_options=None):
596605
def from_urn(key):
597606
assert key.startswith('beam:option:')
598607
assert key.endswith(':v1')
599608
return key[12:-3]
600609

601-
return cls(
602-
**{from_urn(key): value
603-
for (key, value) in proto_options.items()})
610+
parsed = {from_urn(key): value for (key, value) in proto_options.items()}
611+
if original_options is None:
612+
return cls(**parsed)
613+
for (key, value) in parsed.items():
614+
if value and key not in _NON_OVERIDABLE_XLANG_OPTIONS:
615+
original_options._all_options[key] = value
616+
return original_options
604617

605618
def display_data(self):
606619
return self.get_all_options(drop_default=True, retain_unknown_options=True)

sdks/python/apache_beam/options/pipeline_options_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from apache_beam.options.pipeline_options import JobServerOptions
3535
from apache_beam.options.pipeline_options import PipelineOptions
3636
from apache_beam.options.pipeline_options import ProfilingOptions
37+
from apache_beam.options.pipeline_options import StandardOptions
3738
from apache_beam.options.pipeline_options import TypeOptions
3839
from apache_beam.options.pipeline_options import WorkerOptions
3940
from apache_beam.options.pipeline_options import _BeamArgumentParser
@@ -308,6 +309,26 @@ def _add_argparse_args(cls, parser):
308309
self.assertEqual(result['test_arg_int'], 5)
309310
self.assertEqual(result['test_arg_none'], None)
310311

312+
def test_merging_options(self):
313+
opts = PipelineOptions(flags=['--num_workers', '5'])
314+
actual_opts = PipelineOptions.from_runner_api(opts.to_runner_api())
315+
actual = actual_opts.view_as(WorkerOptions).num_workers
316+
self.assertEqual(5, actual)
317+
318+
def test_merging_options_with_overriden_options(self):
319+
opts = PipelineOptions(flags=['--num_workers', '5'])
320+
base = PipelineOptions(flags=['--num_workers', '2'])
321+
actual_opts = PipelineOptions.from_runner_api(opts.to_runner_api(), base)
322+
actual = actual_opts.view_as(WorkerOptions).num_workers
323+
self.assertEqual(5, actual)
324+
325+
def test_merging_options_with_overriden_runner(self):
326+
opts = PipelineOptions(flags=['--runner', 'FnApiRunner'])
327+
base = PipelineOptions(flags=['--runner', 'Direct'])
328+
actual_opts = PipelineOptions.from_runner_api(opts.to_runner_api(), base)
329+
actual = actual_opts.view_as(StandardOptions).runner
330+
self.assertEqual('Direct', actual)
331+
311332
def test_from_kwargs(self):
312333
class MyOptions(PipelineOptions):
313334
@classmethod

sdks/python/apache_beam/runners/portability/expansion_service.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,16 +56,8 @@ def __init__(self, options=None, loopback_address=None):
5656
def Expand(self, request, context=None):
5757
try:
5858
options = copy.deepcopy(self._options)
59-
request_options = pipeline_options.PipelineOptions.from_runner_api(
60-
request.pipeline_options)
61-
# TODO(https://github.com/apache/beam/issues/20090): Figure out the
62-
# correct subset of options to apply to expansion.
63-
if request_options.view_as(
64-
pipeline_options.StreamingOptions).update_compatibility_version:
65-
options.view_as(
66-
pipeline_options.StreamingOptions
67-
).update_compatibility_version = request_options.view_as(
68-
pipeline_options.StreamingOptions).update_compatibility_version
59+
options = pipeline_options.PipelineOptions.from_runner_api(
60+
request.pipeline_options, options)
6961
pipeline = beam_pipeline.Pipeline(options=options)
7062

7163
def with_pipeline(component, pcoll_id=None):

sdks/python/apache_beam/transforms/util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,12 +570,15 @@ def expand(self, pcoll):
570570

571571
gbk = beam.GroupByKey()
572572
gbk._inside_gbek = True
573+
output_type = Tuple[key_type, Iterable[value_type]]
573574

574575
return (
575576
pcoll
576577
| beam.ParDo(_EncryptMessage(self._hmac_key, key_coder, value_coder))
577578
| gbk
578-
| beam.ParDo(_DecryptMessage(self._hmac_key, key_coder, value_coder)))
579+
| beam.ParDo(
580+
_DecryptMessage(self._hmac_key, key_coder,
581+
value_coder)).with_output_types(output_type))
579582

580583

581584
class _BatchSizeEstimator(object):

0 commit comments

Comments
 (0)