Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,9 @@ public DataflowPipelineJob run(Pipeline pipeline) {
options.getStager().stageToFile(serializedProtoPipeline, PIPELINE_FILE_NAME);
dataflowOptions.setPipelineUrl(stagedPipeline.getLocation());

String pipelineProtoHash = Hashing.sha256().hashBytes(serializedProtoPipeline).toString();
options.as(SdkHarnessOptions.class).setPipelineProtoHash(pipelineProtoHash);

if (useUnifiedWorker(options)) {
LOG.info("Skipping v1 transform replacements since job will run on v2.");
} else {
Expand Down
8 changes: 7 additions & 1 deletion sdks/go/pkg/beam/runners/dataflow/dataflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ package dataflow

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
Expand All @@ -40,6 +42,7 @@ import (
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/pipelinex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/hooks"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
"github.com/apache/beam/sdks/v2/go/pkg/beam/log"
"github.com/apache/beam/sdks/v2/go/pkg/beam/options/gcpopts"
Expand Down Expand Up @@ -239,7 +242,10 @@ func Execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, error)
log.Info(ctx, "Dry-run: not submitting job!")

log.Info(ctx, model.String())
job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL)
modelBytes := protox.MustEncode(model)
hash := sha256.Sum256(modelBytes)
pipelineProtoHash := hex.EncodeToString(hash[:])
job, err := dataflowlib.Translate(ctx, model, opts, workerURL, modelURL, pipelineProtoHash)
if err != nil {
return nil, err
}
Expand Down
9 changes: 7 additions & 2 deletions sdks/go/pkg/beam/runners/dataflow/dataflowlib/execute.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package dataflowlib

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"strings"
Expand Down Expand Up @@ -83,14 +85,17 @@ func Execute(ctx context.Context, raw *pipepb.Pipeline, opts *JobOptions, worker
// (2) Upload model to GCS
log.Info(ctx, raw.String())

if err := StageModel(ctx, opts.Project, modelURL, protox.MustEncode(raw)); err != nil {
modelBytes := protox.MustEncode(raw)
modelHash := sha256.Sum256(modelBytes)
pipelineProtoHash := hex.EncodeToString(modelHash[:])
if err := StageModel(ctx, opts.Project, modelURL, modelBytes); err != nil {
return presult, err
}
log.Infof(ctx, "Staged model pipeline: %v", modelURL)

// (3) Translate to v1b3 and submit

job, err := Translate(ctx, raw, opts, workerURL, modelURL)
job, err := Translate(ctx, raw, opts, workerURL, modelURL, pipelineProtoHash)
if err != nil {
return presult, err
}
Expand Down
12 changes: 7 additions & 5 deletions sdks/go/pkg/beam/runners/dataflow/dataflowlib/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func containerImages(p *pipepb.Pipeline) ([]*df.SdkHarnessContainerImage, []stri
}

// Translate translates a pipeline to a Dataflow job.
func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string) (*df.Job, error) {
func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, workerURL, modelURL string, pipelineProtoHash string) (*df.Job, error) {
// (1) Translate pipeline to v1b3 speak.

jobType := "JOB_TYPE_BATCH"
Expand Down Expand Up @@ -181,10 +181,11 @@ func Translate(ctx context.Context, p *pipepb.Pipeline, opts *JobOptions, worker
SdkPipelineOptions: newMsg(pipelineOptions{
DisplayData: printOptions(opts, images),
Options: dataflowOptions{
PipelineURL: modelURL,
Region: opts.Region,
Experiments: opts.Experiments,
TempLocation: opts.TempLocation,
PipelineURL: modelURL,
PipelineProtoHash: pipelineProtoHash,
Region: opts.Region,
Experiments: opts.Experiments,
TempLocation: opts.TempLocation,
},
GoOptions: opts.Options,
}),
Expand Down Expand Up @@ -359,6 +360,7 @@ func GetMetrics(ctx context.Context, client *df.Service, project, region, jobID
type dataflowOptions struct {
Experiments []string `json:"experiments,omitempty"`
PipelineURL string `json:"pipelineUrl"`
PipelineProtoHash string `json:"pipelineProtoHash,omitempty"`
Region string `json:"region"`
TempLocation string `json:"tempLocation"`
DiskProvisionedIops int64 `json:"diskProvisionedIops"`
Expand Down
51 changes: 50 additions & 1 deletion sdks/go/pkg/beam/runners/dataflow/dataflowlib/job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ import (
"reflect"
"testing"

"encoding/json"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/graphx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/protox"
pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
Expand Down Expand Up @@ -293,7 +296,7 @@ func TestTranslate(t *testing.T) {
workerURL := "gs://any-location/temp"
modelURL := "gs://any-location/temp"

job, err := Translate(ctx, p, opts, workerURL, modelURL)
job, err := Translate(ctx, p, opts, workerURL, modelURL, "dummy-hash-12345")
if err != nil {
t.Fatalf("Translate(...) error = %v, want nil", err)
}
Expand All @@ -310,3 +313,49 @@ func TestTranslate(t *testing.T) {
t.Errorf("DiskProvisionedThroughputMibps = %v, want 200", wp.DiskProvisionedThroughputMibps)
}
}

// TestTranslateWithPipelineHash verifies that the pipeline proto hash passed to
// Translate is propagated into the job's serialized SDK pipeline options.
func TestTranslateWithPipelineHash(t *testing.T) {
	pipeline := &pipepb.Pipeline{
		Components: &pipepb.Components{
			Environments: map[string]*pipepb.Environment{
				"env1": {
					Payload: protox.MustEncode(&pipepb.DockerPayload{
						ContainerImage: "dummy_image",
					}),
				},
			},
		},
	}
	jobOpts := &JobOptions{
		Name:    "test-job",
		Project: "test-project",
		Region:  "test-region",
		Options: runtime.RawOptions{
			Options: make(map[string]string),
		},
	}

	const wantHash = "dummy-hash-12345"

	job, err := Translate(context.Background(), pipeline, jobOpts, "worker-url", "model-url", wantHash)
	if err != nil {
		t.Fatalf("Translate failed: %v", err)
	}

	// Decode the submitted SDK pipeline options and check the hash field.
	var decoded struct {
		Options struct {
			PipelineURL       string `json:"pipelineUrl"`
			PipelineProtoHash string `json:"pipelineProtoHash"`
		} `json:"options"`
	}
	if err := json.Unmarshal(job.Environment.SdkPipelineOptions, &decoded); err != nil {
		t.Fatalf("Failed to unmarshal SdkPipelineOptions: %v", err)
	}

	if got := decoded.Options.PipelineProtoHash; got != wantHash {
		t.Errorf("Expected PipelineProtoHash %v, got %v", wantHash, got)
	}
}
Original file line number Diff line number Diff line change
Expand Up @@ -481,4 +481,11 @@ public OpenTelemetry create(PipelineOptions options) {
return GlobalOpenTelemetry.get();
}
}

/** The hex-encoded SHA256 hash of the staged portable pipeline proto. */
@Description("The hex-encoded SHA256 hash of the staged portable pipeline proto")
String getPipelineProtoHash();

/** Sets the hex-encoded SHA256 hash of the staged portable pipeline proto. */
void setPipelineProtoHash(String hash);
}

13 changes: 10 additions & 3 deletions sdks/python/apache_beam/runners/dataflow/internal/apiclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def __init__(
options,
environment_version,
proto_pipeline_staged_url,
proto_pipeline=None):
proto_pipeline=None,
pipeline_proto_hash=None):
self.standard_options = options.view_as(StandardOptions)
self.google_cloud_options = options.view_as(GoogleCloudOptions)
self.worker_options = options.view_as(WorkerOptions)
Expand Down Expand Up @@ -279,6 +280,8 @@ def __init__(
for k, v in sdk_pipeline_options.items() if v is not None
}
options_dict["pipelineUrl"] = proto_pipeline_staged_url
if pipeline_proto_hash:
options_dict["pipelineProtoHash"] = pipeline_proto_hash
# Don't pass impersonate_service_account through to the harness.
# Though impersonation should start a job, the workers should
# not try to modify their credentials.
Expand Down Expand Up @@ -831,10 +834,13 @@ def create_job_description(self, job):
resources = self._stage_resources(job.proto_pipeline, job.options)

# Stage proto pipeline.
serialized_pipeline = job.proto_pipeline.SerializeToString()
pipeline_proto_hash = hashlib.sha256(serialized_pipeline).hexdigest()

self.stage_file_with_retry(
job.google_cloud_options.staging_location,
shared_names.STAGED_PIPELINE_FILENAME,
io.BytesIO(job.proto_pipeline.SerializeToString()))
io.BytesIO(serialized_pipeline))

job.proto.environment = Environment(
proto_pipeline_staged_url=FileSystems.join(
Expand All @@ -843,7 +849,8 @@ def create_job_description(self, job):
packages=resources,
options=job.options,
environment_version=self.environment_version,
proto_pipeline=job.proto_pipeline).proto
proto_pipeline=job.proto_pipeline,
pipeline_proto_hash=pipeline_proto_hash).proto
_LOGGER.debug('JOB: %s', job)

@retry.with_exponential_backoff(num_retries=3, initial_delay_secs=3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pytype: skip-file

import hashlib
import io
import itertools
import json
Expand Down Expand Up @@ -97,6 +98,40 @@ def test_pipeline_url(self):

self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL)

def test_pipeline_proto_hash(self):
  """Verifies pipelineProtoHash is surfaced in the Environment proto options.

  Builds a minimal pipeline proto, computes its SHA256 hex digest, constructs
  an apiclient.Environment with that hash, and asserts the hash round-trips
  into the serialized sdkPipelineOptions 'options' map.
  """
  pipeline_options = PipelineOptions(
      ['--temp_location', 'gs://any-location/temp'])
  proto_pipeline = beam_runner_api_pb2.Pipeline()
  proto_pipeline.components.transforms['dummy'].unique_name = 'dummy'

  expected_hash = hashlib.sha256(
      proto_pipeline.SerializeToString()).hexdigest()

  env = apiclient.Environment([],
                              pipeline_options,
                              '2.0.0',
                              FAKE_PIPELINE_URL,
                              proto_pipeline,
                              pipeline_proto_hash=expected_hash)

  # Locate the nested 'options' dict inside the serialized pipeline options.
  recovered_options = None
  for additionalProperty in env.proto.sdkPipelineOptions.additionalProperties:
    if additionalProperty.key == 'options':
      recovered_options = additionalProperty.value
      break
  else:
    self.fail('No pipeline options found')

  # Renamed from 'property' to avoid shadowing the Python builtin.
  pipeline_proto_hash = None
  for prop in recovered_options.object_value.properties:
    if prop.key == 'pipelineProtoHash':
      pipeline_proto_hash = prop.value
      break
  else:
    self.fail('No pipelineProtoHash found')

  self.assertEqual(pipeline_proto_hash.string_value, expected_hash)

def test_set_network(self):
pipeline_options = PipelineOptions([
'--network',
Expand Down
Loading