diff --git a/sdks/go/container/boot.go b/sdks/go/container/boot.go index ab2da3169319..469285821f7e 100644 --- a/sdks/go/container/boot.go +++ b/sdks/go/container/boot.go @@ -149,6 +149,7 @@ func main() { log.Fatalf("Endpoint not set: %v", err) } logger := &tools.Logger{Endpoint: *loggingEndpoint} + log.SetOutput(tools.NewBufferedLoggerWithFlushInterval(ctx, logger, 0)) logger.Printf(ctx, "Initializing Go harness: %v", strings.Join(os.Args, " ")) // (1) Obtain the pipeline options @@ -158,6 +159,9 @@ func main() { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject artifact validation enabled state into context + ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks")) + // (2) Retrieve the staged files. // // The Go SDK harness downloads the worker binary and invokes diff --git a/sdks/go/pkg/beam/artifact/materialize.go b/sdks/go/pkg/beam/artifact/materialize.go index 624e30efcd2b..db624f3776af 100644 --- a/sdks/go/pkg/beam/artifact/materialize.go +++ b/sdks/go/pkg/beam/artifact/materialize.go @@ -51,6 +51,23 @@ const ( NoArtifactsStaged = "__no_artifacts_staged__" ) +type validationKey string + +const artifactValidationKey validationKey = "artifact_validation_enabled" + +// WithArtifactValidation returns a new context carrying the artifact validation enabled state. +func WithArtifactValidation(ctx context.Context, enabled bool) context.Context { + return context.WithValue(ctx, artifactValidationKey, enabled) +} + +// isArtifactValidationEnabled reports whether the context carries artifact validation enabled state (set via WithArtifactValidation); defaults to true when unset. +func isArtifactValidationEnabled(ctx context.Context) bool { + if val, ok := ctx.Value(artifactValidationKey).(bool); ok { + return val + } + return true +} + // Materialize is a convenience helper for ensuring that all artifacts are // present and uncorrupted. It interprets each artifact name as a relative // path under the dest directory. 
It does not retrieve valid artifacts already @@ -131,6 +148,7 @@ func newMaterializeWithClient(ctx context.Context, client jobpb.ArtifactRetrieva RoleUrn: URNStagingTo, RolePayload: rolePayload, }, + expectedSha256: filePayload.Sha256, }) } @@ -183,8 +201,9 @@ func MustExtractFilePayload(artifact *pipepb.ArtifactInformation) (string, strin } type artifact struct { - client jobpb.ArtifactRetrievalServiceClient - dep *pipepb.ArtifactInformation + client jobpb.ArtifactRetrievalServiceClient + dep *pipepb.ArtifactInformation + expectedSha256 string } func (a artifact) retrieve(ctx context.Context, dest string) error { @@ -231,7 +250,19 @@ func (a artifact) retrieve(ctx context.Context, dest string) error { stat, _ := fd.Stat() log.Printf("Downloaded: %v (sha256: %v, size: %v)", filename, sha256Hash, stat.Size()) - return fd.Close() + if err := fd.Close(); err != nil { + return err + } + + if isArtifactValidationEnabled(ctx) { + if a.expectedSha256 == "" { + log.Printf("WARN: Artifact validation skipped for file: %v", filename) + } else if sha256Hash != a.expectedSha256 { + return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.expectedSha256) + } + } + + return nil } func writeChunks(stream jobpb.ArtifactRetrievalService_GetArtifactClient, w io.Writer) (string, error) { @@ -442,8 +473,12 @@ func retrieve(ctx context.Context, client jobpb.LegacyArtifactRetrievalServiceCl } // Artifact Sha256 hash is an optional field in metadata so we should only validate when its present. 
- if a.Sha256 != "" && sha256Hash != a.Sha256 { - return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256) + if isArtifactValidationEnabled(ctx) { + if a.Sha256 == "" { + log.Printf("WARN: Artifact validation skipped for file: %v", filename) + } else if sha256Hash != a.Sha256 { + return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256) + } } return nil } diff --git a/sdks/go/pkg/beam/artifact/materialize_test.go b/sdks/go/pkg/beam/artifact/materialize_test.go index 31890ed045cc..bf27e13e8a89 100644 --- a/sdks/go/pkg/beam/artifact/materialize_test.go +++ b/sdks/go/pkg/beam/artifact/materialize_test.go @@ -82,6 +82,52 @@ func TestMultiRetrieve(t *testing.T) { } } +func TestRetrieveWithBadShaFails(t *testing.T) { + cc := startServer(t) + defer cc.Close() + + ctx := grpcx.WriteWorkerID(context.Background(), "idA") + keys := []string{"foo"} + st := "whatever" + rt, artifacts := populate(ctx, cc, t, keys, 300, st) + + dst := makeTempDir(t) + defer os.RemoveAll(dst) + + client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc) + for _, a := range artifacts { + a.Sha256 = "badhash" // mutate hash + if err := Retrieve(ctx, client, a, rt, dst); err == nil { + t.Errorf("expected materialization to fail due to bad sha256 mismatch") + } + } +} + +func TestRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) { + cc := startServer(t) + defer cc.Close() + + ctx := WithArtifactValidation(grpcx.WriteWorkerID(context.Background(), "idA"), false) + keys := []string{"foo"} + st := "whatever" + rt, artifacts := populate(ctx, cc, t, keys, 300, st) + + dst := makeTempDir(t) + defer os.RemoveAll(dst) + + client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc) + for _, a := range artifacts { + originalHash := a.Sha256 + a.Sha256 = "badhash" // mutate hash + filename := makeFilename(dst, a.Name) + if err := Retrieve(ctx, client, a, rt, dst); err != nil { + t.Errorf("materialize failed but should have succeeded because 
validation was disabled via experiment: %v", err) + continue + } + verifySHA256(t, filename, originalHash) + } +} + // populate stages a set of artifacts with the given keys, each with // slightly different sizes and chucksizes. func populate(ctx context.Context, cc *grpc.ClientConn, t *testing.T, keys []string, size int, st string) (string, []*jobpb.ArtifactMetadata) { @@ -266,6 +312,55 @@ func TestNewRetrieveWithResolution(t *testing.T) { checkStagedFiles(mds, dest, expected, t) } +func TestIsArtifactValidationEnabled(t *testing.T) { + ctx := context.Background() + if !isArtifactValidationEnabled(ctx) { + t.Errorf("empty context should have validation enabled") + } + + ctx2 := WithArtifactValidation(ctx, false) + if isArtifactValidationEnabled(ctx2) { + t.Errorf("context with validation disabled should have validation disabled") + } +} + +func TestNewRetrieveWithBadShaFails(t *testing.T) { + expected := map[string]string{"a.txt": "a"} + client := &fakeRetrievalService{artifacts: expected} + dest := makeTempDir(t) + defer os.RemoveAll(dest) + ctx := grpcx.WriteWorkerID(context.Background(), "worker") + + _, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest) + if err == nil { + t.Fatalf("expected materialization to fail due to bad sha256 mismatch") + } +} + +func TestNewRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) { + expected := map[string]string{"a.txt": "a"} + client := &fakeRetrievalService{artifacts: expected} + dest := makeTempDir(t) + defer os.RemoveAll(dest) + + ctx := WithArtifactValidation(grpcx.WriteWorkerID(context.Background(), "worker"), false) + + mds, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest) + if err != nil { + t.Fatalf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err) + } + + generated := make(map[string]string) + for _, md := range mds { + name, _ := MustExtractFilePayload(md) + payload, _ := 
proto.Marshal(&pipepb.ArtifactStagingToRolePayload{ + StagedName: name}) + generated[name] = string(payload) + } + + checkStagedFiles(mds, dest, generated, t) +} + func checkStagedFiles(mds []*pipepb.ArtifactInformation, dest string, expected map[string]string, t *testing.T) { if len(mds) != len(expected) { t.Errorf("wrong number of artifacts staged %v vs %v", len(mds), len(expected)) @@ -323,6 +418,21 @@ func (fake *fakeRetrievalService) fileArtifactsWithoutStagingTo() []*pipepb.Arti return artifacts } +func (fake *fakeRetrievalService) fileArtifactsWithBadSha() []*pipepb.ArtifactInformation { + var artifacts []*pipepb.ArtifactInformation + for name := range fake.artifacts { + payload, _ := proto.Marshal(&pipepb.ArtifactFilePayload{ + Path: filepath.Join("/tmp", name), + Sha256: "badhash", + }) + artifacts = append(artifacts, &pipepb.ArtifactInformation{ + TypeUrn: URNFileArtifact, + TypePayload: payload, + }) + } + return artifacts +} + func (fake *fakeRetrievalService) urlArtifactsWithoutStagingTo() []*pipepb.ArtifactInformation { var artifacts []*pipepb.ArtifactInformation for name := range fake.artifacts { diff --git a/sdks/go/pkg/beam/artifact/options.go b/sdks/go/pkg/beam/artifact/options.go new file mode 100644 index 000000000000..47356433161c --- /dev/null +++ b/sdks/go/pkg/beam/artifact/options.go @@ -0,0 +1,48 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package artifact + +import ( + structpb "google.golang.org/protobuf/types/known/structpb" +) + +// GetExperiments extracts a list of experiments from the pipeline options. +func GetExperiments(options *structpb.Struct) []string { + if options == nil { + return nil + } + + var exps []string + // Try legacy style + for _, v := range options.GetFields()["options"].GetStructValue().GetFields()["experiments"].GetListValue().GetValues() { + exps = append(exps, v.GetStringValue()) + } + // Try URN style + for _, v := range options.GetFields()["beam:option:experiments:v1"].GetListValue().GetValues() { + exps = append(exps, v.GetStringValue()) + } + return exps +} + +// HasExperiment checks if a specific experiment is enabled in the pipeline options. +func HasExperiment(options *structpb.Struct, experiment string) bool { + for _, exp := range GetExperiments(options) { + if exp == experiment { + return true + } + } + return false +} diff --git a/sdks/go/pkg/beam/artifact/options_test.go b/sdks/go/pkg/beam/artifact/options_test.go new file mode 100644 index 000000000000..a9f0e4bb7e35 --- /dev/null +++ b/sdks/go/pkg/beam/artifact/options_test.go @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package artifact + +import ( + "testing" + + structpb "google.golang.org/protobuf/types/known/structpb" +) + +func TestGetExperiments_Nil(t *testing.T) { + if got := GetExperiments(nil); got != nil { + t.Errorf("GetExperiments(nil) = %v, want nil", got) + } +} + +func TestGetExperiments_Legacy(t *testing.T) { + options, _ := structpb.NewStruct(map[string]interface{}{ + "options": map[string]interface{}{ + "experiments": []interface{}{"exp1", "exp2"}, + }, + }) + exps := GetExperiments(options) + if len(exps) != 2 || exps[0] != "exp1" || exps[1] != "exp2" { + t.Errorf("GetExperiments() = %v, want [exp1 exp2]", exps) + } +} + +func TestGetExperiments_URN(t *testing.T) { + urnOptions, _ := structpb.NewStruct(map[string]interface{}{ + "beam:option:experiments:v1": []interface{}{"expA", "expB"}, + }) + expsURN := GetExperiments(urnOptions) + if len(expsURN) != 2 || expsURN[0] != "expA" || expsURN[1] != "expB" { + t.Errorf("GetExperiments() = %v, want [expA expB]", expsURN) + } +} + +func TestHasExperiment(t *testing.T) { + options, _ := structpb.NewStruct(map[string]interface{}{ + "options": map[string]interface{}{ + "experiments": []interface{}{"exp1", "exp2"}, + }, + }) + + if !HasExperiment(options, "exp1") { + t.Errorf("HasExperiment(exp1) = false, want true") + } + if HasExperiment(options, "exp3") { + t.Errorf("HasExperiment(exp3) = true, want false") + } +} + +func TestGetExperiments_Combined(t *testing.T) { + options, _ := structpb.NewStruct(map[string]interface{}{ + "options": map[string]interface{}{ + "experiments": []interface{}{"exp1", 
"exp2"}, + }, + "beam:option:experiments:v1": []interface{}{"expA", "expB"}, + }) + exps := GetExperiments(options) + if len(exps) != 4 || exps[0] != "exp1" || exps[1] != "exp2" || exps[2] != "expA" || exps[3] != "expB" { + t.Errorf("GetExperiments() = %v, want [exp1 exp2 expA expB]", exps) + } +} diff --git a/sdks/java/container/boot.go b/sdks/java/container/boot.go index 8c918f231797..3ce79e4927eb 100644 --- a/sdks/java/container/boot.go +++ b/sdks/java/container/boot.go @@ -105,6 +105,9 @@ func main() { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject artifact validation enabled state into context + ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks")) + // (2) Retrieve the staged user jars. We ignore any disk limit, // because the staged jars are mandatory. diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py index f38a2ee34bbf..097523a5131c 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py @@ -600,8 +600,9 @@ def _stage_resources(self, pipeline, options): else: remote_name = os.path.basename(type_payload.path) is_staged_role = False - - if self._enable_caching and not type_payload.sha256: + # compute sha256 even if caching is disabled. + # This is used to check the payload integrity along with caching. 
+ if not type_payload.sha256: type_payload.sha256 = self._compute_sha256(type_payload.path) if type_payload.sha256 and type_payload.sha256 in staged_hashes: diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py index 43f4c8a21513..43f51d0b39fd 100644 --- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py +++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py @@ -1375,13 +1375,19 @@ def test_stage_resources(self): ]) })) client = apiclient.DataflowApplicationClient(pipeline_options) - with mock.patch.object(apiclient._LegacyDataflowStager, - 'stage_job_resources') as mock_stager: - client._stage_resources(pipeline, pipeline_options) + with mock.patch.object(apiclient.DataflowApplicationClient, + '_compute_sha256', + side_effect=lambda path: 'hash' + path): + with mock.patch.object(apiclient._LegacyDataflowStager, + 'stage_job_resources') as mock_stager: + client._stage_resources(pipeline, pipeline_options) mock_stager.assert_called_once_with( - [('/tmp/foo1', 'foo1', ''), ('/tmp/bar1', 'bar1', ''), - ('/tmp/baz', 'baz1', ''), ('/tmp/renamed1', 'renamed1', 'abcdefg'), - ('/tmp/foo2', 'foo2', ''), ('/tmp/bar2', 'bar2', '')], + [('/tmp/foo1', 'foo1', 'hash/tmp/foo1'), + ('/tmp/bar1', 'bar1', 'hash/tmp/bar1'), + ('/tmp/baz', 'baz1', 'hash/tmp/baz'), + ('/tmp/renamed1', 'renamed1', 'abcdefg'), + ('/tmp/foo2', 'foo2', 'hash/tmp/foo2'), + ('/tmp/bar2', 'bar2', 'hash/tmp/bar2')], staging_location='gs://test-location/staging') pipeline_expected = beam_runner_api_pb2.Pipeline( @@ -1392,8 +1398,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/foo1' - ).SerializeToString(), + url='gs://test-location/staging/foo1', + sha256='hash/tmp/foo1').SerializeToString(), 
role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1401,8 +1407,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/bar1'). - SerializeToString(), + url='gs://test-location/staging/bar1', + sha256='hash/tmp/bar1').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1410,8 +1416,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/baz1'). - SerializeToString(), + url='gs://test-location/staging/baz1', + sha256='hash/tmp/baz').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1431,8 +1437,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/foo2'). - SerializeToString(), + url='gs://test-location/staging/foo2', + sha256='hash/tmp/foo2').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( @@ -1440,8 +1446,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/bar2'). - SerializeToString(), + url='gs://test-location/staging/bar2', + sha256='hash/tmp/bar2').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. 
ArtifactStagingToRolePayload( @@ -1449,8 +1455,8 @@ def test_stage_resources(self): beam_runner_api_pb2.ArtifactInformation( type_urn=common_urns.artifact_types.URL.urn, type_payload=beam_runner_api_pb2.ArtifactUrlPayload( - url='gs://test-location/staging/baz1'). - SerializeToString(), + url='gs://test-location/staging/baz1', + sha256='hash/tmp/baz').SerializeToString(), role_urn=common_urns.artifact_roles.STAGING_TO.urn, role_payload=beam_runner_api_pb2. ArtifactStagingToRolePayload( diff --git a/sdks/python/container/boot.go b/sdks/python/container/boot.go index a2655903a4b1..f5a37c9cca0a 100644 --- a/sdks/python/container/boot.go +++ b/sdks/python/container/boot.go @@ -184,6 +184,9 @@ func launchSDKProcess() error { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject artifact validation enabled state into context + ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks")) + experiments := getExperiments(options) logger.Printf(ctx, "Experiments=%v", experiments) diff --git a/sdks/typescript/container/boot.go b/sdks/typescript/container/boot.go index 44f94f804330..95e26124facc 100644 --- a/sdks/typescript/container/boot.go +++ b/sdks/typescript/container/boot.go @@ -91,6 +91,9 @@ func main() { logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err) } + // Inject artifact validation enabled state into context + ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks")) + // (2) Retrieve and install the staged packages. dir := filepath.Join(*semiPersistDir, *id, "staged")