Skip to content

Commit 400b114

Browse files
authored
Add Staged Artifact validations for RunnerV2 (#37974)
1 parent 1db67c7 commit 400b114

10 files changed

Lines changed: 316 additions & 25 deletions

File tree

sdks/go/container/boot.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ func main() {
149149
log.Fatalf("Endpoint not set: %v", err)
150150
}
151151
logger := &tools.Logger{Endpoint: *loggingEndpoint}
152+
log.SetOutput(tools.NewBufferedLoggerWithFlushInterval(ctx, logger, 0))
152153
logger.Printf(ctx, "Initializing Go harness: %v", strings.Join(os.Args, " "))
153154

154155
// (1) Obtain the pipeline options
@@ -158,6 +159,9 @@ func main() {
158159
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
159160
}
160161

162+
// Inject artifact validation enabled state into context
163+
ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks"))
164+
161165
// (2) Retrieve the staged files.
162166
//
163167
// The Go SDK harness downloads the worker binary and invokes

sdks/go/pkg/beam/artifact/materialize.go

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ const (
5151
NoArtifactsStaged = "__no_artifacts_staged__"
5252
)
5353

54+
type validationKey string
55+
56+
const artifactValidationKey validationKey = "artifact_validation_enabled"
57+
58+
// WithArtifactValidation returns a new context carrying the artifact validation enabled state.
59+
func WithArtifactValidation(ctx context.Context, enabled bool) context.Context {
60+
return context.WithValue(ctx, artifactValidationKey, enabled)
61+
}
62+
63+
// isArtifactValidationEnabled parses pipeline options to check if "disable_integrity_checks" is enabled.
64+
func isArtifactValidationEnabled(ctx context.Context) bool {
65+
if val, ok := ctx.Value(artifactValidationKey).(bool); ok {
66+
return val
67+
}
68+
return true
69+
}
70+
5471
// Materialize is a convenience helper for ensuring that all artifacts are
5572
// present and uncorrupted. It interprets each artifact name as a relative
5673
// path under the dest directory. It does not retrieve valid artifacts already
@@ -131,6 +148,7 @@ func newMaterializeWithClient(ctx context.Context, client jobpb.ArtifactRetrieva
131148
RoleUrn: URNStagingTo,
132149
RolePayload: rolePayload,
133150
},
151+
expectedSha256: filePayload.Sha256,
134152
})
135153
}
136154

@@ -183,8 +201,9 @@ func MustExtractFilePayload(artifact *pipepb.ArtifactInformation) (string, strin
183201
}
184202

185203
// artifact pairs a staged-artifact description with the client used to
// retrieve it.
type artifact struct {
	// client fetches the artifact bytes from the retrieval service.
	client jobpb.ArtifactRetrievalServiceClient
	// dep describes the artifact to retrieve.
	dep *pipepb.ArtifactInformation
	// expectedSha256 is the hash declared for the staged file; empty when
	// the runner did not supply one (validation is then skipped with a
	// warning).
	expectedSha256 string
}
189208

190209
func (a artifact) retrieve(ctx context.Context, dest string) error {
@@ -231,7 +250,19 @@ func (a artifact) retrieve(ctx context.Context, dest string) error {
231250
stat, _ := fd.Stat()
232251
log.Printf("Downloaded: %v (sha256: %v, size: %v)", filename, sha256Hash, stat.Size())
233252

234-
return fd.Close()
253+
if err := fd.Close(); err != nil {
254+
return err
255+
}
256+
257+
if isArtifactValidationEnabled(ctx) {
258+
if a.expectedSha256 == "" {
259+
log.Printf("WARN: Artifact validation skipped for file: %v", filename)
260+
} else if sha256Hash != a.expectedSha256 {
261+
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.expectedSha256)
262+
}
263+
}
264+
265+
return nil
235266
}
236267

237268
func writeChunks(stream jobpb.ArtifactRetrievalService_GetArtifactClient, w io.Writer) (string, error) {
@@ -442,8 +473,12 @@ func retrieve(ctx context.Context, client jobpb.LegacyArtifactRetrievalServiceCl
442473
}
443474

444475
// Artifact Sha256 hash is an optional field in metadata so we should only validate when it's present.
445-
if a.Sha256 != "" && sha256Hash != a.Sha256 {
446-
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256)
476+
if isArtifactValidationEnabled(ctx) {
477+
if a.Sha256 == "" {
478+
log.Printf("WARN: Artifact validation skipped for file: %v", filename)
479+
} else if sha256Hash != a.Sha256 {
480+
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256)
481+
}
447482
}
448483
return nil
449484
}

sdks/go/pkg/beam/artifact/materialize_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,52 @@ func TestMultiRetrieve(t *testing.T) {
8282
}
8383
}
8484

85+
// TestRetrieveWithBadShaFails verifies that legacy retrieval rejects a staged
// artifact whose metadata declares a SHA256 that cannot match the content.
func TestRetrieveWithBadShaFails(t *testing.T) {
	cc := startServer(t)
	defer cc.Close()

	ctx := grpcx.WriteWorkerID(context.Background(), "idA")
	keys := []string{"foo"}
	st := "whatever"
	rt, artifacts := populate(ctx, cc, t, keys, 300, st)

	dst := makeTempDir(t)
	defer os.RemoveAll(dst)

	client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc)
	for _, a := range artifacts {
		a.Sha256 = "badhash" // mutate hash so it cannot match the staged content
		if err := Retrieve(ctx, client, a, rt, dst); err == nil {
			t.Errorf("expected materialization to fail due to bad sha256 mismatch")
		}
	}
}
105+
106+
// TestRetrieveWithBadShaAndExperimentSucceeds verifies that a SHA256 mismatch
// is tolerated when artifact validation is disabled on the context, and that
// the file is still downloaded intact.
func TestRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) {
	cc := startServer(t)
	defer cc.Close()

	// Disable validation in the retrieval context.
	ctx := WithArtifactValidation(grpcx.WriteWorkerID(context.Background(), "idA"), false)
	keys := []string{"foo"}
	st := "whatever"
	rt, artifacts := populate(ctx, cc, t, keys, 300, st)

	dst := makeTempDir(t)
	defer os.RemoveAll(dst)

	client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc)
	for _, a := range artifacts {
		originalHash := a.Sha256
		a.Sha256 = "badhash" // mutate hash so validation would fail if enabled
		filename := makeFilename(dst, a.Name)
		if err := Retrieve(ctx, client, a, rt, dst); err != nil {
			t.Errorf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err)
			continue
		}
		// The downloaded content must still match the original (pre-mutation) hash.
		verifySHA256(t, filename, originalHash)
	}
}
130+
85131
// populate stages a set of artifacts with the given keys, each with
86132
// slightly different sizes and chunk sizes.
87133
func populate(ctx context.Context, cc *grpc.ClientConn, t *testing.T, keys []string, size int, st string) (string, []*jobpb.ArtifactMetadata) {
@@ -266,6 +312,55 @@ func TestNewRetrieveWithResolution(t *testing.T) {
266312
checkStagedFiles(mds, dest, expected, t)
267313
}
268314

315+
// TestIsArtifactValidationEnabled checks the context-flag defaults:
// validation is on for a plain context and off once explicitly disabled.
func TestIsArtifactValidationEnabled(t *testing.T) {
	ctx := context.Background()
	if !isArtifactValidationEnabled(ctx) {
		t.Errorf("empty context should have validation enabled")
	}

	ctx2 := WithArtifactValidation(ctx, false)
	if isArtifactValidationEnabled(ctx2) {
		t.Errorf("context with validation disabled should have validation disabled")
	}
}
326+
327+
// TestNewRetrieveWithBadShaFails verifies that materialization over the
// portable retrieval API rejects artifacts whose declared SHA256 does not
// match the downloaded content.
func TestNewRetrieveWithBadShaFails(t *testing.T) {
	expected := map[string]string{"a.txt": "a"}
	client := &fakeRetrievalService{artifacts: expected}
	dest := makeTempDir(t)
	defer os.RemoveAll(dest)
	ctx := grpcx.WriteWorkerID(context.Background(), "worker")

	_, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest)
	if err == nil {
		t.Fatalf("expected materialization to fail due to bad sha256 mismatch")
	}
}
339+
340+
// TestNewRetrieveWithBadShaAndExperimentSucceeds verifies that a SHA256
// mismatch is tolerated when validation is disabled on the context, and that
// the artifacts are still staged to disk.
func TestNewRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) {
	expected := map[string]string{"a.txt": "a"}
	client := &fakeRetrievalService{artifacts: expected}
	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	// Disable validation in the retrieval context.
	ctx := WithArtifactValidation(grpcx.WriteWorkerID(context.Background(), "worker"), false)

	mds, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest)
	if err != nil {
		t.Fatalf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err)
	}

	// Rebuild the expected staging-role payloads for the returned artifacts.
	generated := make(map[string]string)
	for _, md := range mds {
		name, _ := MustExtractFilePayload(md)
		payload, _ := proto.Marshal(&pipepb.ArtifactStagingToRolePayload{
			StagedName: name})
		generated[name] = string(payload)
	}

	checkStagedFiles(mds, dest, generated, t)
}
363+
269364
func checkStagedFiles(mds []*pipepb.ArtifactInformation, dest string, expected map[string]string, t *testing.T) {
270365
if len(mds) != len(expected) {
271366
t.Errorf("wrong number of artifacts staged %v vs %v", len(mds), len(expected))
@@ -323,6 +418,21 @@ func (fake *fakeRetrievalService) fileArtifactsWithoutStagingTo() []*pipepb.Arti
323418
return artifacts
324419
}
325420

421+
// fileArtifactsWithBadSha returns file artifacts for the fake service whose
// declared SHA256 ("badhash") deliberately cannot match any real content, to
// exercise validation-failure paths.
func (fake *fakeRetrievalService) fileArtifactsWithBadSha() []*pipepb.ArtifactInformation {
	var artifacts []*pipepb.ArtifactInformation
	for name := range fake.artifacts {
		payload, _ := proto.Marshal(&pipepb.ArtifactFilePayload{
			Path:   filepath.Join("/tmp", name),
			Sha256: "badhash",
		})
		artifacts = append(artifacts, &pipepb.ArtifactInformation{
			TypeUrn:     URNFileArtifact,
			TypePayload: payload,
		})
	}
	return artifacts
}
435+
326436
func (fake *fakeRetrievalService) urlArtifactsWithoutStagingTo() []*pipepb.ArtifactInformation {
327437
var artifacts []*pipepb.ArtifactInformation
328438
for name := range fake.artifacts {
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one or more
2+
// contributor license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright ownership.
4+
// The ASF licenses this file to You under the Apache License, Version 2.0
5+
// (the "License"); you may not use this file except in compliance with
6+
// the License. You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
package artifact
17+
18+
import (
19+
structpb "google.golang.org/protobuf/types/known/structpb"
20+
)
21+
22+
// GetExperiments extracts a list of experiments from the pipeline options.
23+
func GetExperiments(options *structpb.Struct) []string {
24+
if options == nil {
25+
return nil
26+
}
27+
28+
var exps []string
29+
// Try legacy style
30+
for _, v := range options.GetFields()["options"].GetStructValue().GetFields()["experiments"].GetListValue().GetValues() {
31+
exps = append(exps, v.GetStringValue())
32+
}
33+
// Try URN style
34+
for _, v := range options.GetFields()["beam:option:experiments:v1"].GetListValue().GetValues() {
35+
exps = append(exps, v.GetStringValue())
36+
}
37+
return exps
38+
}
39+
40+
// HasExperiment checks if a specific experiment is enabled in the pipeline options.
41+
func HasExperiment(options *structpb.Struct, experiment string) bool {
42+
for _, exp := range GetExperiments(options) {
43+
if exp == experiment {
44+
return true
45+
}
46+
}
47+
return false
48+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one or more
2+
// contributor license agreements. See the NOTICE file distributed with
3+
// this work for additional information regarding copyright ownership.
4+
// The ASF licenses this file to You under the Apache License, Version 2.0
5+
// (the "License"); you may not use this file except in compliance with
6+
// the License. You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
16+
package artifact
17+
18+
import (
19+
"testing"
20+
21+
structpb "google.golang.org/protobuf/types/known/structpb"
22+
)
23+
24+
// TestGetExperiments_Nil verifies that nil options yield a nil experiment list.
func TestGetExperiments_Nil(t *testing.T) {
	if got := GetExperiments(nil); got != nil {
		t.Errorf("GetExperiments(nil) = %v, want nil", got)
	}
}
29+
30+
// TestGetExperiments_Legacy verifies extraction of experiments nested under
// the legacy "options" key.
func TestGetExperiments_Legacy(t *testing.T) {
	options, _ := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"exp1", "exp2"},
		},
	})
	exps := GetExperiments(options)
	if len(exps) != 2 || exps[0] != "exp1" || exps[1] != "exp2" {
		t.Errorf("GetExperiments() = %v, want [exp1 exp2]", exps)
	}
}
41+
42+
// TestGetExperiments_URN verifies extraction of experiments stored under the
// top-level "beam:option:experiments:v1" key.
func TestGetExperiments_URN(t *testing.T) {
	urnOptions, _ := structpb.NewStruct(map[string]interface{}{
		"beam:option:experiments:v1": []interface{}{"expA", "expB"},
	})
	expsURN := GetExperiments(urnOptions)
	if len(expsURN) != 2 || expsURN[0] != "expA" || expsURN[1] != "expB" {
		t.Errorf("GetExperiments() = %v, want [expA expB]", expsURN)
	}
}
51+
52+
// TestHasExperiment verifies membership checks against the extracted
// experiment list: present experiments report true, absent ones false.
func TestHasExperiment(t *testing.T) {
	options, _ := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"exp1", "exp2"},
		},
	})

	if !HasExperiment(options, "exp1") {
		t.Errorf("HasExperiment(exp1) = false, want true")
	}
	if HasExperiment(options, "exp3") {
		t.Errorf("HasExperiment(exp3) = true, want false")
	}
}
66+
67+
// TestGetExperiments_Combined verifies that legacy-style and URN-style
// experiments are merged, with legacy entries appearing first.
func TestGetExperiments_Combined(t *testing.T) {
	options, _ := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"exp1", "exp2"},
		},
		"beam:option:experiments:v1": []interface{}{"expA", "expB"},
	})
	exps := GetExperiments(options)
	if len(exps) != 4 || exps[0] != "exp1" || exps[1] != "exp2" || exps[2] != "expA" || exps[3] != "expB" {
		t.Errorf("GetExperiments() = %v, want [exp1 exp2 expA expB]", exps)
	}
}

sdks/java/container/boot.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ func main() {
105105
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
106106
}
107107

108+
// Inject artifact validation enabled state into context
109+
ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks"))
110+
108111
// (2) Retrieve the staged user jars. We ignore any disk limit,
109112
// because the staged jars are mandatory.
110113

sdks/python/apache_beam/runners/dataflow/internal/apiclient.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -600,8 +600,9 @@ def _stage_resources(self, pipeline, options):
600600
else:
601601
remote_name = os.path.basename(type_payload.path)
602602
is_staged_role = False
603-
604-
if self._enable_caching and not type_payload.sha256:
603+
# compute sha256 even if caching is disabled.
604+
# This is used to check the payload integrity along with caching.
605+
if not type_payload.sha256:
605606
type_payload.sha256 = self._compute_sha256(type_payload.path)
606607

607608
if type_payload.sha256 and type_payload.sha256 in staged_hashes:

0 commit comments

Comments
 (0)