diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index b375de661885..ecc231ab825e 100644
--- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -1368,6 +1368,9 @@ public DataflowPipelineJob run(Pipeline pipeline) {
     options.getStager().stageToFile(serializedProtoPipeline, PIPELINE_FILE_NAME);
     dataflowOptions.setPipelineUrl(stagedPipeline.getLocation());
 
+    String pipelineProtoHash = Hashing.sha256().hashBytes(serializedProtoPipeline).toString();
+    options.as(SdkHarnessOptions.class).setPipelineProtoHash(pipelineProtoHash);
+
     if (useUnifiedWorker(options)) {
       LOG.info("Skipping v1 transform replacements since job will run on v2.");
     } else {
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflow.go b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
index ecbfe53939ec..e968911fcca1 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflow.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflow.go
@@ -24,6 +24,8 @@ package dataflow
 
 import (
 	"context"
+	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"flag"
 	"fmt"
@@ -40,6 +42,7 @@ import (
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/hooks"
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/options/gcpopts"
@@ -239,7 +242,10 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error)
 
 		log.Info(ctx, "Dry-run: not submitting job!")
 		log.Info(ctx, model.String())
-		job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL)
+		modelBytes := protox.MustEncode(model)
+		hash := sha256.Sum256(modelBytes)
+		pipelineProtoHash := hex.EncodeToString(hash[:])
+		job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL, pipelineProtoHash)
 		if err != nil {
 			return nil, err
 		}
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go
index 806b8940ae99..396eefab7318 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go
@@ -19,6 +19,8 @@ package dataflowlib
 
 import (
 	"context"
+	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"os"
 	"strings"
@@ -83,14 +85,17 @@ func Execute(ctx context.Context, raw *pipepb.Pipeline, opts *JobOptions, worker
 
 	// (2) Upload model to GCS
 	log.Info(ctx, raw.String())
-	if err := StageModel(ctx, opts.Project, modelURL, protox.MustEncode(raw)); err != nil {
+	modelBytes := protox.MustEncode(raw)
+	modelHash := sha256.Sum256(modelBytes)
+	pipelineProtoHash := hex.EncodeToString(modelHash[:])
+	if err := StageModel(ctx, opts.Project, modelURL, modelBytes); err != nil {
 		return presult, err
 	}
 
 	log.Infof(ctx, "Staged model pipeline: %v", modelURL)
 
 	// (3) Translate to v1b3 and submit
-	job, err := Translate(ctx, raw, opts, workerURL, modelURL)
+	job, err := Translate(ctx, raw, opts, workerURL, modelURL, pipelineProtoHash)
 	if err != nil {
 		return presult, err
 	}
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
index 2f8057b6d506..f0adb21cf714 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
@@ -117,7 +117,7 @@ func containerImages(p *pipepb.Pipeline) ([]*df.SdkHarnessContainerImage, []stri
 }
 
 // Translate translates a pipeline to a Dataflow job.
-func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string) (*df.Job, error) {
+func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string, pipelineProtoHash string) (*df.Job, error) {
 	// (1) Translate pipeline to v1b3 speak.
 
 	jobType := "JOB_TYPE_BATCH"
@@ -181,10 +181,11 @@ func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, worker
 		SdkPipelineOptions: newMsg(pipelineOptions{
 			DisplayData: printOptions(opts, images),
 			Options: dataflowOptions{
-				PipelineURL:  modelURL,
-				Region:       opts.Region,
-				Experiments:  opts.Experiments,
-				TempLocation: opts.TempLocation,
+				PipelineURL:       modelURL,
+				PipelineProtoHash: pipelineProtoHash,
+				Region:            opts.Region,
+				Experiments:       opts.Experiments,
+				TempLocation:      opts.TempLocation,
 			},
 			GoOptions: opts.Options,
 		}),
@@ -359,6 +360,7 @@ func GetMetrics(ctx context.Context, client *df.Service, project, region, jobID
 type dataflowOptions struct {
 	Experiments         []string `json:"experiments,omitempty"`
 	PipelineURL         string   `json:"pipelineUrl"`
+	PipelineProtoHash   string   `json:"pipelineProtoHash,omitempty"`
 	Region              string   `json:"region"`
 	TempLocation        string   `json:"tempLocation"`
 	DiskProvisionedIops int64    `json:"diskProvisionedIops"`
diff --git a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go
index 303fcb776bff..901adb6c7b72 100644
--- a/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go
+++ b/sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go
@@ -21,6 +21,9 @@ import (
 	"reflect"
 	"testing"
 
+	"encoding/json"
+
+	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
 	"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
 	pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
@@ -293,7 +296,7 @@ func TestTranslate(t *testing.T) {
 	workerURL := "gs://any-location/temp"
 	modelURL := "gs://any-location/temp"
 
-	job, err := Translate(ctx, p, opts, workerURL, modelURL)
+	job, err := Translate(ctx, p, opts, workerURL, modelURL, "dummy-hash-12345")
 	if err != nil {
 		t.Fatalf("Translate(...) error = %v, want nil", err)
 	}
@@ -310,3 +313,49 @@
 		t.Errorf("DiskProvisionedThroughputMibps = %v, want 200", wp.DiskProvisionedThroughputMibps)
 	}
 }
+
+func TestTranslateWithPipelineHash(t *testing.T) {
+	p := &pipepb.Pipeline{
+		Components: &pipepb.Components{
+			Environments: map[string]*pipepb.Environment{
+				"env1": {
+					Payload: protox.MustEncode(&pipepb.DockerPayload{
+						ContainerImage: "dummy_image",
+					}),
+				},
+			},
+		},
+	}
+	opts := &JobOptions{
+		Name:    "test-job",
+		Project: "test-project",
+		Region:  "test-region",
+		Options: runtime.RawOptions{
+			Options: make(map[string]string),
+		},
+	}
+
+	expectedHashStr := "dummy-hash-12345"
+
+	job, err := Translate(context.Background(), p, opts, "worker-url", "model-url", expectedHashStr)
+	if err != nil {
+		t.Fatalf("Translate failed: %v", err)
+	}
+
+	// Verify PipelineProtoHash
+	var recoveredOptions struct {
+		Options struct {
+			PipelineURL       string `json:"pipelineUrl"`
+			PipelineProtoHash string `json:"pipelineProtoHash"`
+		} `json:"options"`
+	}
+
+	rawOpts := job.Environment.SdkPipelineOptions
+	if err := json.Unmarshal(rawOpts, &recoveredOptions); err != nil {
+		t.Fatalf("Failed to unmarshal SdkPipelineOptions: %v", err)
+	}
+
+	if recoveredOptions.Options.PipelineProtoHash != expectedHashStr {
+		t.Errorf("Expected PipelineProtoHash %v, got %v", expectedHashStr, recoveredOptions.Options.PipelineProtoHash)
+	}
+}
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java
index 831dd69ec95f..7267dda9ed0b 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/options/SdkHarnessOptions.java
@@ -481,4 +481,10 @@ public OpenTelemetry create(PipelineOptions options) {
       return GlobalOpenTelemetry.get();
     }
   }
+
+  /** The hex-encoded SHA256 hash of the staged portable pipeline proto. */
+  @Description("The hex-encoded SHA256 hash of the staged portable pipeline proto")
+  String getPipelineProtoHash();
+
+  void setPipelineProtoHash(String hash);
 }
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
index 29cb36071488..f38a2ee34bbf 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
@@ -97,7 +97,8 @@ def __init__(
       options,
       environment_version,
       proto_pipeline_staged_url,
-      proto_pipeline=None):
+      proto_pipeline=None,
+      pipeline_proto_hash=None):
     self.standard_options = options.view_as(StandardOptions)
     self.google_cloud_options = options.view_as(GoogleCloudOptions)
     self.worker_options = options.view_as(WorkerOptions)
@@ -279,6 +280,8 @@ def __init__(
           for k, v in sdk_pipeline_options.items() if v is not None
       }
      options_dict["pipelineUrl"] = proto_pipeline_staged_url
+      if pipeline_proto_hash:
+        options_dict["pipelineProtoHash"] = pipeline_proto_hash
       # Don't pass impersonate_service_account through to the harness.
       # Though impersonation should start a job, the workers should
       # not try to modify their credentials.
@@ -831,10 +834,13 @@ def create_job_description(self, job):
     resources = self._stage_resources(job.proto_pipeline, job.options)
 
     # Stage proto pipeline.
+    serialized_pipeline = job.proto_pipeline.SerializeToString()
+    pipeline_proto_hash = hashlib.sha256(serialized_pipeline).hexdigest()
+
     self.stage_file_with_retry(
         job.google_cloud_options.staging_location,
         shared_names.STAGED_PIPELINE_FILENAME,
-        io.BytesIO(job.proto_pipeline.SerializeToString()))
+        io.BytesIO(serialized_pipeline))
 
     job.proto.environment = Environment(
         proto_pipeline_staged_url=FileSystems.join(
@@ -843,7 +849,8 @@
         packages=resources,
         options=job.options,
         environment_version=self.environment_version,
-        proto_pipeline=job.proto_pipeline).proto
+        proto_pipeline=job.proto_pipeline,
+        pipeline_proto_hash=pipeline_proto_hash).proto
     _LOGGER.debug('JOB: %s', job)
 
     @retry.with_exponential_backoff(num_retries=3, initial_delay_secs=3)
diff --git a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
index 66b1c8e1e5bb..43f4c8a21513 100644
--- a/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/internal/apiclient_test.py
@@ -19,6 +19,7 @@
 
 # pytype: skip-file
 
+import hashlib
 import io
 import itertools
 import json
@@ -97,6 +98,40 @@ def test_pipeline_url(self):
 
     self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL)
 
+  def test_pipeline_proto_hash(self):
+    pipeline_options = PipelineOptions(
+        ['--temp_location', 'gs://any-location/temp'])
+    proto_pipeline = beam_runner_api_pb2.Pipeline()
+    proto_pipeline.components.transforms['dummy'].unique_name = 'dummy'
+
+    expected_hash = hashlib.sha256(
+        proto_pipeline.SerializeToString()).hexdigest()
+
+    env = apiclient.Environment([],
+                                pipeline_options,
+                                '2.0.0',
+                                FAKE_PIPELINE_URL,
+                                proto_pipeline,
+                                pipeline_proto_hash=expected_hash)
+
+    recovered_options = None
+    for additionalProperty in env.proto.sdkPipelineOptions.additionalProperties:
+      if additionalProperty.key == 'options':
+        recovered_options = additionalProperty.value
+        break
+    else:
+      self.fail('No pipeline options found')
+
+    pipeline_proto_hash = None
+    for property in recovered_options.object_value.properties:
+      if property.key == 'pipelineProtoHash':
+        pipeline_proto_hash = property.value
+        break
+    else:
+      self.fail('No pipelineProtoHash found')
+
+    self.assertEqual(pipeline_proto_hash.string_value, expected_hash)
+
   def test_set_network(self):
     pipeline_options = PipelineOptions([
         '--network',
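
Note on the hashing convention shared by the three SDK changes above: each computes the same value, the lowercase hex-encoded SHA-256 digest of the serialized portable pipeline proto, taken over the exact bytes staged to pipelineUrl (Java: Hashing.sha256().hashBytes(...).toString(); Go: sha256.Sum256 plus hex.EncodeToString; Python: hashlib.sha256(...).hexdigest()). Below is a minimal Python sketch of the round-trip check a downstream consumer of pipelineProtoHash could perform; verify_staged_pipeline is a hypothetical helper, not an API introduced by this change:

  import hashlib

  def verify_staged_pipeline(staged_pipeline_bytes, pipeline_proto_hash):
    # Recompute the digest over the bytes fetched from pipelineUrl and
    # compare it with the hash recorded in the job's pipeline options.
    actual = hashlib.sha256(staged_pipeline_bytes).hexdigest()
    return actual == pipeline_proto_hash

  # Example: the staged bytes should hash to the submitted value.
  payload = b'serialized-pipeline-proto'
  assert verify_staged_pipeline(payload, hashlib.sha256(payload).hexdigest())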