Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdks/go/container/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ func main() {
log.Fatalf("Endpoint not set: %v", err)
}
logger := &tools.Logger{Endpoint: *loggingEndpoint}
log.SetOutput(tools.NewBufferedLoggerWithFlushInterval(ctx, logger, 0))
logger.Printf(ctx, "Initializing Go harness: %v", strings.Join(os.Args, " "))

// (1) Obtain the pipeline options
Expand All @@ -158,6 +159,9 @@ func main() {
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
}

// Inject artifact validation enabled state into context
ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks"))

// (2) Retrieve the staged files.
//
// The Go SDK harness downloads the worker binary and invokes
Expand Down
45 changes: 40 additions & 5 deletions sdks/go/pkg/beam/artifact/materialize.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ const (
NoArtifactsStaged = "__no_artifacts_staged__"
)

type validationKey string

const artifactValidationKey validationKey = "artifact_validation_enabled"

// WithArtifactValidation returns a new context carrying the artifact validation enabled state.
func WithArtifactValidation(ctx context.Context, enabled bool) context.Context {
return context.WithValue(ctx, artifactValidationKey, enabled)
}

// isArtifactValidationEnabled parses pipeline options to check if "disable_integrity_checks" is enabled.
func isArtifactValidationEnabled(ctx context.Context) bool {
if val, ok := ctx.Value(artifactValidationKey).(bool); ok {
return val
}
return true
}

// Materialize is a convenience helper for ensuring that all artifacts are
// present and uncorrupted. It interprets each artifact name as a relative
// path under the dest directory. It does not retrieve valid artifacts already
Expand Down Expand Up @@ -131,6 +148,7 @@ func newMaterializeWithClient(ctx context.Context, client jobpb.ArtifactRetrieva
RoleUrn: URNStagingTo,
RolePayload: rolePayload,
},
expectedSha256: filePayload.Sha256,
})
}

Expand Down Expand Up @@ -183,8 +201,9 @@ func MustExtractFilePayload(artifact *pipepb.ArtifactInformation) (string, strin
}

// artifact pairs a staged dependency with the client used to retrieve it and
// the SHA256 the downloaded bytes are expected to hash to.
type artifact struct {
	client jobpb.ArtifactRetrievalServiceClient // retrieval service used to fetch the artifact bytes
	dep    *pipepb.ArtifactInformation          // metadata describing the artifact to stage
	// expectedSha256 is the hash advertised in the artifact's file payload.
	// Empty means no hash was provided; validation is then skipped with a
	// logged warning rather than failing.
	expectedSha256 string
}

func (a artifact) retrieve(ctx context.Context, dest string) error {
Expand Down Expand Up @@ -231,7 +250,19 @@ func (a artifact) retrieve(ctx context.Context, dest string) error {
stat, _ := fd.Stat()
log.Printf("Downloaded: %v (sha256: %v, size: %v)", filename, sha256Hash, stat.Size())

return fd.Close()
if err := fd.Close(); err != nil {
return err
}

if isArtifactValidationEnabled(ctx) {
if a.expectedSha256 == "" {
log.Printf("WARN: Artifact validation skipped for file: %v", filename)
} else if sha256Hash != a.expectedSha256 {
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.expectedSha256)
}
}

return nil
}

func writeChunks(stream jobpb.ArtifactRetrievalService_GetArtifactClient, w io.Writer) (string, error) {
Expand Down Expand Up @@ -442,8 +473,12 @@ func retrieve(ctx context.Context, client jobpb.LegacyArtifactRetrievalServiceCl
}

// Artifact Sha256 hash is an optional field in metadata so we should only validate when its present.
if a.Sha256 != "" && sha256Hash != a.Sha256 {
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256)
if isArtifactValidationEnabled(ctx) {
if a.Sha256 == "" {
log.Printf("WARN: Artifact validation skipped for file: %v", filename)
} else if sha256Hash != a.Sha256 {
return errors.Errorf("bad SHA256 for %v: %v, want %v", filename, sha256Hash, a.Sha256)
}
}
return nil
}
Expand Down
110 changes: 110 additions & 0 deletions sdks/go/pkg/beam/artifact/materialize_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,52 @@ func TestMultiRetrieve(t *testing.T) {
}
}

// TestRetrieveWithBadShaFails verifies that legacy retrieval rejects an
// artifact whose advertised SHA256 does not match the downloaded bytes.
func TestRetrieveWithBadShaFails(t *testing.T) {
	cc := startServer(t)
	defer cc.Close()

	ctx := grpcx.WriteWorkerID(context.Background(), "idA")
	rt, artifacts := populate(ctx, cc, t, []string{"foo"}, 300, "whatever")

	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc)
	for _, md := range artifacts {
		// Corrupt the advertised hash; retrieval must now fail.
		md.Sha256 = "badhash"
		if err := Retrieve(ctx, client, md, rt, dest); err == nil {
			t.Errorf("expected materialization to fail due to bad sha256 mismatch")
		}
	}
}

// TestRetrieveWithBadShaAndExperimentSucceeds verifies that a corrupted
// SHA256 is ignored when validation is disabled on the context, and that the
// staged bytes still materialize intact.
func TestRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) {
	cc := startServer(t)
	defer cc.Close()

	// Validation disabled, mirroring the disable_staged_file_integrity_checks experiment.
	ctx := WithArtifactValidation(grpcx.WriteWorkerID(context.Background(), "idA"), false)
	rt, artifacts := populate(ctx, cc, t, []string{"foo"}, 300, "whatever")

	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	client := jobpb.NewLegacyArtifactRetrievalServiceClient(cc)
	for _, md := range artifacts {
		goodHash := md.Sha256
		// Corrupt the advertised hash; retrieval should still succeed.
		md.Sha256 = "badhash"
		path := makeFilename(dest, md.Name)
		if err := Retrieve(ctx, client, md, rt, dest); err != nil {
			t.Errorf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err)
			continue
		}
		// The downloaded content must still match the original hash.
		verifySHA256(t, path, goodHash)
	}
}

// populate stages a set of artifacts with the given keys, each with
// slightly different sizes and chucksizes.
func populate(ctx context.Context, cc *grpc.ClientConn, t *testing.T, keys []string, size int, st string) (string, []*jobpb.ArtifactMetadata) {
Expand Down Expand Up @@ -266,6 +312,55 @@ func TestNewRetrieveWithResolution(t *testing.T) {
checkStagedFiles(mds, dest, expected, t)
}

// TestIsArtifactValidationEnabled checks the context plumbing: validation
// defaults to on, and WithArtifactValidation(ctx, false) turns it off.
func TestIsArtifactValidationEnabled(t *testing.T) {
	base := context.Background()
	if !isArtifactValidationEnabled(base) {
		t.Errorf("empty context should have validation enabled")
	}

	disabled := WithArtifactValidation(base, false)
	if isArtifactValidationEnabled(disabled) {
		t.Errorf("context with validation disabled should have validation disabled")
	}
}

// TestNewRetrieveWithBadShaFails verifies that the materializer rejects file
// artifacts whose payload advertises a wrong SHA256.
func TestNewRetrieveWithBadShaFails(t *testing.T) {
	client := &fakeRetrievalService{artifacts: map[string]string{"a.txt": "a"}}
	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	ctx := grpcx.WriteWorkerID(context.Background(), "worker")
	if _, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest); err == nil {
		t.Fatalf("expected materialization to fail due to bad sha256 mismatch")
	}
}

// TestNewRetrieveWithBadShaAndExperimentSucceeds ensures that a SHA256
// mismatch is ignored when validation is disabled on the context, and that
// the artifacts are still staged successfully.
func TestNewRetrieveWithBadShaAndExperimentSucceeds(t *testing.T) {
	expected := map[string]string{"a.txt": "a"}
	client := &fakeRetrievalService{artifacts: expected}
	dest := makeTempDir(t)
	defer os.RemoveAll(dest)

	// Disable validation via the context flag (normally derived from the
	// disable_staged_file_integrity_checks experiment by the boot harness).
	ctx := WithArtifactValidation(grpcx.WriteWorkerID(context.Background(), "worker"), false)

	mds, err := newMaterializeWithClient(ctx, client, client.fileArtifactsWithBadSha(), dest)
	if err != nil {
		t.Fatalf("materialize failed but should have succeeded because validation was disabled via experiment: %v", err)
	}

	// Rebuild the expected role payloads from the returned metadata so the
	// staged files can be checked independently of the bad hashes.
	generated := make(map[string]string)
	for _, md := range mds {
		name, _ := MustExtractFilePayload(md)
		payload, _ := proto.Marshal(&pipepb.ArtifactStagingToRolePayload{
			StagedName: name})
		generated[name] = string(payload)
	}

	checkStagedFiles(mds, dest, generated, t)
}

func checkStagedFiles(mds []*pipepb.ArtifactInformation, dest string, expected map[string]string, t *testing.T) {
if len(mds) != len(expected) {
t.Errorf("wrong number of artifacts staged %v vs %v", len(mds), len(expected))
Expand Down Expand Up @@ -323,6 +418,21 @@ func (fake *fakeRetrievalService) fileArtifactsWithoutStagingTo() []*pipepb.Arti
return artifacts
}

// fileArtifactsWithBadSha returns one file artifact per fake entry, each
// carrying a deliberately wrong SHA256 in its file payload.
func (fake *fakeRetrievalService) fileArtifactsWithBadSha() []*pipepb.ArtifactInformation {
	var infos []*pipepb.ArtifactInformation
	for name := range fake.artifacts {
		// Marshal of a freshly built, static payload cannot realistically fail.
		payload, _ := proto.Marshal(&pipepb.ArtifactFilePayload{
			Path:   filepath.Join("/tmp", name),
			Sha256: "badhash",
		})
		infos = append(infos, &pipepb.ArtifactInformation{
			TypeUrn:     URNFileArtifact,
			TypePayload: payload,
		})
	}
	return infos
}

func (fake *fakeRetrievalService) urlArtifactsWithoutStagingTo() []*pipepb.ArtifactInformation {
var artifacts []*pipepb.ArtifactInformation
for name := range fake.artifacts {
Expand Down
48 changes: 48 additions & 0 deletions sdks/go/pkg/beam/artifact/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package artifact

import (
structpb "google.golang.org/protobuf/types/known/structpb"
)

// GetExperiments extracts a list of experiments from the pipeline options.
func GetExperiments(options *structpb.Struct) []string {
if options == nil {
return nil
}

var exps []string
// Try legacy style
for _, v := range options.GetFields()["options"].GetStructValue().GetFields()["experiments"].GetListValue().GetValues() {
exps = append(exps, v.GetStringValue())
}
// Try URN style
for _, v := range options.GetFields()["beam:option:experiments:v1"].GetListValue().GetValues() {
exps = append(exps, v.GetStringValue())
}
return exps
Comment thread
tarun-google marked this conversation as resolved.
}

// HasExperiment checks if a specific experiment is enabled in the pipeline options.
func HasExperiment(options *structpb.Struct, experiment string) bool {
for _, exp := range GetExperiments(options) {
if exp == experiment {
return true
}
}
return false
Comment thread
tarun-google marked this conversation as resolved.
}
78 changes: 78 additions & 0 deletions sdks/go/pkg/beam/artifact/options_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package artifact

import (
"testing"

structpb "google.golang.org/protobuf/types/known/structpb"
)

// TestGetExperiments_Nil verifies that nil pipeline options yield no experiments.
func TestGetExperiments_Nil(t *testing.T) {
	got := GetExperiments(nil)
	if got != nil {
		t.Errorf("GetExperiments(nil) = %v, want nil", got)
	}
}

// TestGetExperiments_Legacy verifies extraction from the legacy
// {"options": {"experiments": [...]}} layout.
func TestGetExperiments_Legacy(t *testing.T) {
	opts, _ := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"exp1", "exp2"},
		},
	})

	got := GetExperiments(opts)
	if len(got) != 2 || got[0] != "exp1" || got[1] != "exp2" {
		t.Errorf("GetExperiments() = %v, want [exp1 exp2]", got)
	}
}

// TestGetExperiments_URN verifies extraction from the URN-keyed
// beam:option:experiments:v1 layout.
func TestGetExperiments_URN(t *testing.T) {
	opts, _ := structpb.NewStruct(map[string]interface{}{
		"beam:option:experiments:v1": []interface{}{"expA", "expB"},
	})

	got := GetExperiments(opts)
	if len(got) != 2 || got[0] != "expA" || got[1] != "expB" {
		t.Errorf("GetExperiments() = %v, want [expA expB]", got)
	}
}

// TestHasExperiment verifies membership checks against the experiment list:
// present entries are found, absent ones are not.
func TestHasExperiment(t *testing.T) {
	opts, _ := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"exp1", "exp2"},
		},
	})

	if got := HasExperiment(opts, "exp1"); !got {
		t.Errorf("HasExperiment(exp1) = false, want true")
	}
	if got := HasExperiment(opts, "exp3"); got {
		t.Errorf("HasExperiment(exp3) = true, want false")
	}
}

// TestGetExperiments_Combined verifies that legacy-style experiments are
// returned first, followed by URN-style ones.
func TestGetExperiments_Combined(t *testing.T) {
	opts, _ := structpb.NewStruct(map[string]interface{}{
		"options": map[string]interface{}{
			"experiments": []interface{}{"exp1", "exp2"},
		},
		"beam:option:experiments:v1": []interface{}{"expA", "expB"},
	})

	got := GetExperiments(opts)
	if len(got) != 4 || got[0] != "exp1" || got[1] != "exp2" || got[2] != "expA" || got[3] != "expB" {
		t.Errorf("GetExperiments() = %v, want [exp1 exp2 expA expB]", got)
	}
}
3 changes: 3 additions & 0 deletions sdks/java/container/boot.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ func main() {
logger.Fatalf(ctx, "Failed to convert pipeline options: %v", err)
}

// Inject artifact validation enabled state into context
ctx = artifact.WithArtifactValidation(ctx, !artifact.HasExperiment(info.GetPipelineOptions(), "disable_staged_file_integrity_checks"))

// (2) Retrieve the staged user jars. We ignore any disk limit,
// because the staged jars are mandatory.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,8 +600,9 @@ def _stage_resources(self, pipeline, options):
else:
remote_name = os.path.basename(type_payload.path)
is_staged_role = False

if self._enable_caching and not type_payload.sha256:
# Compute the sha256 even when caching is disabled: it is used both for
# caching and to verify the integrity of the staged payload.
if not type_payload.sha256:
Comment thread
Abacn marked this conversation as resolved.
type_payload.sha256 = self._compute_sha256(type_payload.path)

if type_payload.sha256 and type_payload.sha256 in staged_hashes:
Expand Down
Loading
Loading