Skip to content

Commit db2020b

Browse files
committed
fix tableRowFromMessage
1 parent c8df4da commit db2020b

16 files changed

Lines changed: 1878 additions & 388 deletions
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.transforms;
19+
20+
import org.checkerframework.checker.nullness.qual.Nullable;
21+
22+
public class SerializableBiFunctions {
23+
public static <T, U, R> SerializableBiFunction<T, U, T> Select1st(
24+
SerializableBiFunction<@Nullable T, @Nullable U, R> biFunction) {
25+
return (t, u) -> t;
26+
}
27+
28+
public static <T, U, R> SerializableBiFunction<T, U, U> Select2nd(
29+
SerializableBiFunction<@Nullable T, @Nullable U, R> biFunction) {
30+
return (t, u) -> u;
31+
}
32+
33+
public static <T, U, R> SerializableFunction<U, R> fix1st(
34+
SerializableBiFunction<@Nullable T, @Nullable U, R> biFunction, @Nullable T value) {
35+
return u -> biFunction.apply(value, u);
36+
}
37+
38+
public static <T, U, R> SerializableFunction<T, R> fix2nd(
39+
SerializableBiFunction<@Nullable T, @Nullable U, R> biFunction, @Nullable U value) {
40+
return t -> biFunction.apply(t, value);
41+
}
42+
43+
public static <T, U, R> SerializableBiFunction<T, U, R> ignore1st(
44+
SerializableFunction<@Nullable U, R> function) {
45+
return (t, u) -> function.apply(u);
46+
}
47+
48+
public static <T, U, R> SerializableBiFunction<T, U, R> ignore2nd(
49+
SerializableFunction<@Nullable T, R> function) {
50+
return (t, u) -> function.apply(t);
51+
}
52+
}

sdks/java/extensions/protobuf/src/test/proto/proto3_schema_messages.proto

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,51 @@ import "proto3_schema_options.proto";
3333

3434
option java_package = "org.apache.beam.sdk.extensions.protobuf";
3535

36+
message PrimitiveEncodedFields {
37+
int64 encoded_timestamp = 1;
38+
int32 encoded_date = 2;
39+
bytes encoded_numeric = 3;
40+
bytes encoded_bignumeric = 4;
41+
int64 encoded_packed_datetime = 5;
42+
int64 encoded_packed_time = 6;
43+
}
44+
45+
message NestedEncodedFields {
46+
PrimitiveEncodedFields nested = 1;
47+
repeated PrimitiveEncodedFields nested_list = 2;
48+
}
49+
50+
message PrimitiveUnEncodedFields {
51+
string timestamp = 1;
52+
string date = 2;
53+
string numeric = 3;
54+
string bignumeric = 4;
55+
string datetime = 5;
56+
string time = 6;
57+
}
58+
59+
message NestedUnEncodedFields {
60+
PrimitiveUnEncodedFields nested = 1;
61+
repeated PrimitiveUnEncodedFields nested_list = 2;
62+
}
63+
64+
message WrapperUnEncodedFields {
65+
google.protobuf.FloatValue float = 1;
66+
google.protobuf.DoubleValue double = 2;
67+
google.protobuf.BoolValue bool = 3;
68+
google.protobuf.Int32Value int32 = 4;
69+
google.protobuf.Int64Value int64 = 5;
70+
google.protobuf.UInt32Value uint32 = 6;
71+
google.protobuf.UInt64Value uint64 = 7;
72+
google.protobuf.BytesValue bytes = 8;
73+
google.protobuf.Timestamp timestamp = 9;
74+
}
75+
76+
message NestedWrapperUnEncodedFields {
77+
WrapperUnEncodedFields nested = 1;
78+
repeated WrapperUnEncodedFields nested_list = 2;
79+
}
80+
3681
message Primitive {
3782
double primitive_double = 1;
3883
float primitive_float = 2;
@@ -287,4 +332,4 @@ message NoWrapPrimitive {
287332
optional bool bool = 13;
288333
optional string string = 14;
289334
optional bytes bytes = 15;
290-
}
335+
}

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AppendClientInfo.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ Descriptors.Descriptor getDescriptorIgnoreRequired() {
170170
public TableRow toTableRow(ByteString protoBytes, Predicate<String> includeField) {
171171
try {
172172
return TableRowToStorageApiProto.tableRowFromMessage(
173+
getSchemaInformation(),
173174
DynamicMessage.parseFrom(
174175
TableRowToStorageApiProto.wrapDescriptorProto(getDescriptor()), protoBytes),
175176
true,

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java

Lines changed: 84 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import com.google.cloud.bigquery.storage.v1.AppendRowsRequest;
4343
import com.google.cloud.bigquery.storage.v1.CreateReadSessionRequest;
4444
import com.google.cloud.bigquery.storage.v1.DataFormat;
45+
import com.google.cloud.bigquery.storage.v1.ProtoSchemaConverter;
4546
import com.google.cloud.bigquery.storage.v1.ReadSession;
4647
import com.google.cloud.bigquery.storage.v1.ReadStream;
4748
import com.google.gson.JsonArray;
@@ -119,6 +120,8 @@
119120
import org.apache.beam.sdk.transforms.PTransform;
120121
import org.apache.beam.sdk.transforms.ParDo;
121122
import org.apache.beam.sdk.transforms.Reshuffle;
123+
import org.apache.beam.sdk.transforms.SerializableBiFunction;
124+
import org.apache.beam.sdk.transforms.SerializableBiFunctions;
122125
import org.apache.beam.sdk.transforms.SerializableFunction;
123126
import org.apache.beam.sdk.transforms.SerializableFunctions;
124127
import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -2297,10 +2300,61 @@ public static <T extends Message> Write<T> writeProtos(Class<T> protoMessageClas
22972300
if (DynamicMessage.class.equals(protoMessageClass)) {
22982301
throw new IllegalArgumentException("DynamicMessage is not supported.");
22992302
}
2300-
return BigQueryIO.<T>write()
2301-
.withFormatFunction(
2302-
m -> TableRowToStorageApiProto.tableRowFromMessage(m, false, Predicates.alwaysTrue()))
2303-
.withWriteProtosClass(protoMessageClass);
2303+
try {
2304+
return BigQueryIO.<T>write()
2305+
.toBuilder()
2306+
.setFormatFunction(FormatProto.fromClass(protoMessageClass))
2307+
.build()
2308+
.withWriteProtosClass(protoMessageClass);
2309+
} catch (Exception e) {
2310+
throw new RuntimeException(e);
2311+
}
2312+
}
2313+
2314+
private static class FormatProto<T extends Message>
2315+
implements SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow> {
2316+
transient TableRowToStorageApiProto.SchemaInformation inferredSchemaInformation;
2317+
final Class<T> protoMessageClass;
2318+
2319+
FormatProto(Class<T> protoMessageClass) {
2320+
this.protoMessageClass = protoMessageClass;
2321+
}
2322+
2323+
TableRowToStorageApiProto.SchemaInformation inferSchemaInformation() {
2324+
try {
2325+
if (inferredSchemaInformation == null) {
2326+
Descriptors.Descriptor descriptor =
2327+
(Descriptors.Descriptor)
2328+
org.apache.beam.sdk.util.Preconditions.checkStateNotNull(
2329+
protoMessageClass.getMethod("getDescriptor"))
2330+
.invoke(null);
2331+
Descriptors.Descriptor convertedDescriptor =
2332+
TableRowToStorageApiProto.wrapDescriptorProto(
2333+
ProtoSchemaConverter.convert(descriptor).getProtoDescriptor());
2334+
TableSchema tableSchema =
2335+
TableRowToStorageApiProto.protoSchemaToTableSchema(
2336+
TableRowToStorageApiProto.tableSchemaFromDescriptor(convertedDescriptor));
2337+
this.inferredSchemaInformation =
2338+
TableRowToStorageApiProto.SchemaInformation.fromTableSchema(tableSchema);
2339+
}
2340+
return inferredSchemaInformation;
2341+
} catch (Exception e) {
2342+
throw new RuntimeException(e);
2343+
}
2344+
}
2345+
2346+
static <T extends Message> FormatProto<T> fromClass(Class<T> protoMessageClass)
2347+
throws Exception {
2348+
return new FormatProto<>(protoMessageClass);
2349+
}
2350+
2351+
@Override
2352+
public TableRow apply(TableRowToStorageApiProto.SchemaInformation schemaInformation, T input) {
2353+
TableRowToStorageApiProto.SchemaInformation localSchemaInformation =
2354+
schemaInformation != null ? schemaInformation : inferSchemaInformation();
2355+
return TableRowToStorageApiProto.tableRowFromMessage(
2356+
localSchemaInformation, input, false, Predicates.alwaysTrue());
2357+
}
23042358
}
23052359

23062360
/** Implementation of {@link #write}. */
@@ -2354,9 +2408,13 @@ public enum Method {
23542408
abstract @Nullable SerializableFunction<ValueInSingleWindow<T>, TableDestination>
23552409
getTableFunction();
23562410

2357-
abstract @Nullable SerializableFunction<T, TableRow> getFormatFunction();
2411+
abstract @Nullable SerializableBiFunction<
2412+
TableRowToStorageApiProto.SchemaInformation, T, TableRow>
2413+
getFormatFunction();
23582414

2359-
abstract @Nullable SerializableFunction<T, TableRow> getFormatRecordOnFailureFunction();
2415+
abstract @Nullable SerializableBiFunction<
2416+
TableRowToStorageApiProto.SchemaInformation, T, TableRow>
2417+
getFormatRecordOnFailureFunction();
23602418

23612419
abstract RowWriterFactory.@Nullable AvroRowWriterFactory<T, ?, ?> getAvroRowWriterFactory();
23622420

@@ -2467,10 +2525,13 @@ abstract static class Builder<T> {
24672525
abstract Builder<T> setTableFunction(
24682526
SerializableFunction<ValueInSingleWindow<T>, TableDestination> tableFunction);
24692527

2470-
abstract Builder<T> setFormatFunction(SerializableFunction<T, TableRow> formatFunction);
2528+
abstract Builder<T> setFormatFunction(
2529+
SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
2530+
formatFunction);
24712531

24722532
abstract Builder<T> setFormatRecordOnFailureFunction(
2473-
SerializableFunction<T, TableRow> formatFunction);
2533+
SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
2534+
formatFunction);
24742535

24752536
abstract Builder<T> setAvroRowWriterFactory(
24762537
RowWriterFactory.AvroRowWriterFactory<T, ?, ?> avroRowWriterFactory);
@@ -2718,7 +2779,9 @@ public Write<T> to(DynamicDestinations<T, ?> dynamicDestinations) {
27182779

27192780
/** Formats the user's type into a {@link TableRow} to be written to BigQuery. */
27202781
public Write<T> withFormatFunction(SerializableFunction<T, TableRow> formatFunction) {
2721-
return toBuilder().setFormatFunction(formatFunction).build();
2782+
return toBuilder()
2783+
.setFormatFunction(SerializableBiFunctions.ignore1st(formatFunction))
2784+
.build();
27222785
}
27232786

27242787
/**
@@ -2733,7 +2796,9 @@ public Write<T> withFormatFunction(SerializableFunction<T, TableRow> formatFunct
27332796
*/
27342797
public Write<T> withFormatRecordOnFailureFunction(
27352798
SerializableFunction<T, TableRow> formatFunction) {
2736-
return toBuilder().setFormatRecordOnFailureFunction(formatFunction).build();
2799+
return toBuilder()
2800+
.setFormatRecordOnFailureFunction(SerializableBiFunctions.ignore1st(formatFunction))
2801+
.build();
27372802
}
27382803

27392804
/**
@@ -3599,9 +3664,10 @@ && getStorageApiTriggeringFrequency(bqOptions) != null) {
35993664
private <DestinationT> WriteResult expandTyped(
36003665
PCollection<T> input, DynamicDestinations<T, DestinationT> dynamicDestinations) {
36013666
boolean optimizeWrites = getOptimizeWrites();
3602-
SerializableFunction<T, TableRow> formatFunction = getFormatFunction();
3603-
SerializableFunction<T, TableRow> formatRecordOnFailureFunction =
3604-
getFormatRecordOnFailureFunction();
3667+
SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
3668+
formatFunction = getFormatFunction();
3669+
SerializableBiFunction<TableRowToStorageApiProto.SchemaInformation, T, TableRow>
3670+
formatRecordOnFailureFunction = getFormatRecordOnFailureFunction();
36053671
RowWriterFactory.AvroRowWriterFactory<T, ?, DestinationT> avroRowWriterFactory =
36063672
(RowWriterFactory.AvroRowWriterFactory<T, ?, DestinationT>) getAvroRowWriterFactory();
36073673

@@ -3623,7 +3689,8 @@ private <DestinationT> WriteResult expandTyped(
36233689
// If no format function set, then we will automatically convert the input type to a
36243690
// TableRow.
36253691
// TODO: it would be trivial to convert to avro records here instead.
3626-
formatFunction = BigQueryUtils.toTableRow(input.getToRowFunction());
3692+
formatFunction =
3693+
SerializableBiFunctions.ignore1st(BigQueryUtils.toTableRow(input.getToRowFunction()));
36273694
}
36283695
// Infer the TableSchema from the input Beam schema.
36293696
// TODO: If the user provided a schema, we should use that. There are things that can be
@@ -3769,8 +3836,9 @@ private <DestinationT> WriteResult continueExpandTyped(
37693836
getCreateDisposition(),
37703837
dynamicDestinations,
37713838
elementCoder,
3772-
tableRowWriterFactory.getToRowFn(),
3773-
tableRowWriterFactory.getToFailsafeRowFn())
3839+
SerializableBiFunctions.fix1st(tableRowWriterFactory.getToRowFn(), null),
3840+
SerializableBiFunctions.fix1st(
3841+
tableRowWriterFactory.getToFailsafeRowFn(), null))
37743842
.withInsertRetryPolicy(retryPolicy)
37753843
.withTestServices(getBigQueryServices())
37763844
.withExtendedErrorInfo(getExtendedErrorInfo())

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.apache.beam.sdk.schemas.Schema.FieldType;
5656
import org.apache.beam.sdk.schemas.logicaltypes.NanosDuration;
5757
import org.apache.beam.sdk.transforms.PTransform;
58+
import org.apache.beam.sdk.transforms.SerializableBiFunctions;
5859
import org.apache.beam.sdk.transforms.SerializableFunction;
5960
import org.apache.beam.sdk.transforms.errorhandling.BadRecord;
6061
import org.apache.beam.sdk.transforms.errorhandling.BadRecordRouter;
@@ -641,15 +642,17 @@ public Write<?> fromConfigRow(Row configRow, PipelineOptions options) {
641642
if (formatFunctionBytes != null) {
642643
builder =
643644
builder.setFormatFunction(
644-
(SerializableFunction<?, TableRow>) fromByteArray(formatFunctionBytes));
645+
SerializableBiFunctions.ignore1st(
646+
(SerializableFunction<?, TableRow>) fromByteArray(formatFunctionBytes)));
645647
}
646648
byte[] formatRecordOnFailureFunctionBytes =
647649
configRow.getBytes("format_record_on_failure_function");
648650
if (formatRecordOnFailureFunctionBytes != null) {
649651
builder =
650652
builder.setFormatRecordOnFailureFunction(
651-
(SerializableFunction<?, TableRow>)
652-
fromByteArray(formatRecordOnFailureFunctionBytes));
653+
SerializableBiFunctions.ignore1st(
654+
(SerializableFunction<?, TableRow>)
655+
fromByteArray(formatRecordOnFailureFunctionBytes)));
653656
}
654657
byte[] avroRowWriterFactoryBytes = configRow.getBytes("avro_row_writer_factory");
655658
if (avroRowWriterFactoryBytes != null) {

sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/RowWriterFactory.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import java.io.Serializable;
2323
import org.apache.avro.Schema;
2424
import org.apache.avro.io.DatumWriter;
25+
import org.apache.beam.sdk.transforms.SerializableBiFunction;
26+
import org.apache.beam.sdk.transforms.SerializableBiFunctions;
2527
import org.apache.beam.sdk.transforms.SerializableFunction;
2628
import org.checkerframework.checker.nullness.qual.Nullable;
2729

@@ -41,29 +43,45 @@ abstract BigQueryRowWriter<ElementT> createRowWriter(
4143
String tempFilePrefix, DestinationT destination) throws Exception;
4244

4345
static <ElementT, DestinationT> RowWriterFactory<ElementT, DestinationT> tableRows(
44-
SerializableFunction<ElementT, TableRow> toRow,
45-
SerializableFunction<ElementT, TableRow> toFailsafeRow) {
46+
SerializableBiFunction<
47+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
48+
toRow,
49+
SerializableBiFunction<
50+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
51+
toFailsafeRow) {
4652
return new TableRowWriterFactory<ElementT, DestinationT>(toRow, toFailsafeRow);
4753
}
4854

4955
static final class TableRowWriterFactory<ElementT, DestinationT>
5056
extends RowWriterFactory<ElementT, DestinationT> {
5157

52-
private final SerializableFunction<ElementT, TableRow> toRow;
53-
private final SerializableFunction<ElementT, TableRow> toFailsafeRow;
58+
private final SerializableBiFunction<
59+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
60+
toRow;
61+
private final SerializableBiFunction<
62+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
63+
toFailsafeRow;
5464

5565
private TableRowWriterFactory(
56-
SerializableFunction<ElementT, TableRow> toRow,
57-
SerializableFunction<ElementT, TableRow> toFailsafeRow) {
66+
SerializableBiFunction<
67+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
68+
toRow,
69+
SerializableBiFunction<
70+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
71+
toFailsafeRow) {
5872
this.toRow = toRow;
5973
this.toFailsafeRow = toFailsafeRow;
6074
}
6175

62-
public SerializableFunction<ElementT, TableRow> getToRowFn() {
76+
public SerializableBiFunction<
77+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
78+
getToRowFn() {
6379
return toRow;
6480
}
6581

66-
public SerializableFunction<ElementT, TableRow> getToFailsafeRowFn() {
82+
public SerializableBiFunction<
83+
TableRowToStorageApiProto.@Nullable SchemaInformation, ElementT, TableRow>
84+
getToFailsafeRowFn() {
6785
if (toFailsafeRow == null) {
6886
return toRow;
6987
}
@@ -76,9 +94,10 @@ public OutputType getOutputType() {
7694
}
7795

7896
@Override
97+
@SuppressWarnings("nullness")
7998
public BigQueryRowWriter<ElementT> createRowWriter(
8099
String tempFilePrefix, DestinationT destination) throws Exception {
81-
return new TableRowWriter<>(tempFilePrefix, toRow);
100+
return new TableRowWriter<>(tempFilePrefix, SerializableBiFunctions.fix1st(toRow, null));
82101
}
83102

84103
@Override

0 commit comments

Comments
 (0)