diff --git a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto index 9360522ab409..fde3657b667d 100644 --- a/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto +++ b/model/fn-execution/src/main/proto/org/apache/beam/model/fn_execution/v1/beam_fn_api.proto @@ -1009,6 +1009,29 @@ message StateKey { bytes key = 4; } + // Represents a request for all of the entries of a multimap associated with a + // specified user key and window for a PTransform. See + // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further + // details. + // + // Can only be used to perform StateGetRequests and StateClearRequests on the + // user state. + // + // The response data stream will be a concatenation of pairs, where the first + // component is the map key and the second component is a concatenation of + // values associated with that map key. + message MultimapEntriesUserState { + // (Required) The id of the PTransform containing user state. + string transform_id = 1; + // (Required) The id of the user state. + string user_state_id = 2; + // (Required) The window encoded in a nested context. + bytes window = 3; + // (Required) The key of the currently executing element encoded in a + // nested context. + bytes key = 4; + } + // Represents a request for the values of the map key associated with a // specified user key and window for a PTransform. 
See // https://s.apache.org/beam-fn-state-api-and-bundle-processing for further @@ -1064,6 +1087,7 @@ message StateKey { MultimapKeysSideInput multimap_keys_side_input = 5; MultimapKeysValuesSideInput multimap_keys_values_side_input = 8; MultimapKeysUserState multimap_keys_user_state = 6; + MultimapEntriesUserState multimap_entries_user_state = 10; MultimapUserState multimap_user_state = 7; OrderedListUserState ordered_list_user_state = 9; } diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index c615b2a5279a..0bdc4f69aab6 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -1621,13 +1621,13 @@ message AnyOfEnvironmentPayload { // environment understands. message StandardProtocols { enum Enum { - // Indicates suport for progress reporting via the legacy Metrics proto. + // Indicates support for progress reporting via the legacy Metrics proto. LEGACY_PROGRESS_REPORTING = 0 [(beam_urn) = "beam:protocol:progress_reporting:v0"]; - // Indicates suport for progress reporting via the new MonitoringInfo proto. + // Indicates support for progress reporting via the new MonitoringInfo proto. PROGRESS_REPORTING = 1 [(beam_urn) = "beam:protocol:progress_reporting:v1"]; - // Indicates suport for worker status protocol defined at + // Indicates support for worker status protocol defined at // https://s.apache.org/beam-fn-api-harness-status. WORKER_STATUS = 2 [(beam_urn) = "beam:protocol:worker_status:v1"]; @@ -1681,6 +1681,10 @@ message StandardProtocols { // Indicates support for reading, writing and propagating Element's metadata ELEMENT_METADATA = 11 [(beam_urn) = "beam:protocol:element_metadata:v1"]; + + // Indicates whether the SDK supports multimap state. 
+ MULTIMAP_STATE = 12 + [(beam_urn) = "beam:protocol:multimap_state:v1"]; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java index dd7ec6b0f65a..1f7451e72a21 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/fn/stream/PrefetchableIterables.java @@ -94,7 +94,7 @@ public PrefetchableIterator createIterator() { * constructed that ensures that {@link PrefetchableIterator#prefetch()} is a no-op and {@link * PrefetchableIterator#isReady()} always returns true. */ - private static PrefetchableIterable maybePrefetchable(Iterable iterable) { + public static PrefetchableIterable maybePrefetchable(Iterable iterable) { if (iterable instanceof PrefetchableIterable) { return (PrefetchableIterable) iterable; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java index 05ecb21fd956..7a7725cc26d2 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/Environments.java @@ -495,6 +495,7 @@ public static Set getJavaCapabilities() { capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.DATA_SAMPLING)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.SDK_CONSUMING_RECEIVED_DATA)); capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.ORDERED_LIST_STATE)); + capabilities.add(BeamUrns.getUrn(StandardProtocols.Enum.MULTIMAP_STATE)); return capabilities.build(); } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 8409133772eb..8a273127b4fc 100644 --- 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -2917,6 +2917,73 @@ public void processElement( pipeline.run(); } + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMultimapState.class}) + public void testMultimapStateEntries() { + final String stateId = "foo:"; + final String countStateId = "count"; + DoFn>, KV> fn = + new DoFn>, KV>() { + + @StateId(stateId) + private final StateSpec> multimapState = + StateSpecs.multimap(StringUtf8Coder.of(), VarIntCoder.of()); + + @StateId(countStateId) + private final StateSpec> countState = + StateSpecs.combiningFromInputInternal(VarIntCoder.of(), Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, + @Element KV> element, + @StateId(stateId) MultimapState state, + @StateId(countStateId) CombiningState count, + OutputReceiver> r) { + // Empty before we process any elements. + if (count.read() == 0) { + assertThat(state.entries().read(), emptyIterable()); + } + assertEquals(count.read().intValue(), Iterables.size(state.entries().read())); + + KV value = element.getValue(); + state.put(value.getKey(), value.getValue()); + count.add(1); + + if (count.read() >= 4) { + // This should be evaluated only when ReadableState.read is called. + ReadableState>> entriesView = state.entries(); + + // This is evaluated immediately. + Iterable> entries = state.entries().read(); + + state.remove("b"); + assertEquals(4, Iterables.size(entries)); + state.put("a", 2); + state.put("a", 3); + + assertEquals(5, Iterables.size(entriesView.read())); + // Note we output the view of state before the modifications in this if statement. 
+ for (Entry entry : entries) { + r.output(KV.of(entry.getKey(), entry.getValue())); + } + } + } + }; + PCollection> output = + pipeline + .apply( + Create.of( + KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("a", 97)), + KV.of("hello", KV.of("a", 98)), KV.of("hello", KV.of("b", 33)))) + .apply(ParDo.of(fn)); + PAssert.that(output) + .containsInAnyOrder( + KV.of("a", 97), KV.of("a", 97), + KV.of("a", 98), KV.of("b", 33)); + pipeline.run(); + } + @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMultimapState.class}) public void testMultimapStateRemove() { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java index 410b52cba23b..46aaeebd1cdf 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/EnvironmentsTest.java @@ -219,6 +219,9 @@ public void testCapabilities() { assertThat( Environments.getJavaCapabilities(), hasItem(BeamUrns.getUrn(RunnerApi.StandardProtocols.Enum.ORDERED_LIST_STATE))); + assertThat( + Environments.getJavaCapabilities(), + hasItem(BeamUrns.getUrn(RunnerApi.StandardProtocols.Enum.MULTIMAP_STATE))); // Check that SDF truncation is supported assertThat( Environments.getJavaCapabilities(), diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java index e06a82c8e25f..6913c75a5f2d 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java @@ -117,7 +117,7 @@ public static class Factory { public Factory( PipelineOptions pipelineOptions, - Set runnerCapabilites, + Set runnerCapabilities, String 
ptransformId, Supplier processBundleInstructionId, Supplier> cacheTokens, @@ -128,7 +128,7 @@ public Factory( Coder keyCoder, Coder windowCoder) { this.pipelineOptions = pipelineOptions; - this.runnerCapabilities = runnerCapabilites; + this.runnerCapabilities = runnerCapabilities; this.ptransformId = ptransformId; this.processBundleInstructionId = processBundleInstructionId; this.cacheTokens = cacheTokens; @@ -240,7 +240,7 @@ public FnApiStateAccessor create() { } private final PipelineOptions pipelineOptions; - private final Set runnerCapabilites; + private final Set runnerCapabilities; private final Map stateKeyObjectCache; private final Map, SideInputSpec> sideInputSpecMap; private final BeamFnStateClient beamFnStateClient; @@ -259,7 +259,7 @@ public FnApiStateAccessor create() { public FnApiStateAccessor( PipelineOptions pipelineOptions, - Set runnerCapabilites, + Set runnerCapabilities, String ptransformId, Supplier processBundleInstructionId, Supplier> cacheTokens, @@ -270,7 +270,7 @@ public FnApiStateAccessor( Coder keyCoder, Coder windowCoder) { this.pipelineOptions = pipelineOptions; - this.runnerCapabilites = runnerCapabilites; + this.runnerCapabilities = runnerCapabilities; this.stateKeyObjectCache = Maps.newHashMap(); this.sideInputSpecMap = sideInputSpecMap; this.beamFnStateClient = beamFnStateClient; @@ -414,7 +414,7 @@ public T get(PCollectionView view, BoundedWindow window) { key, ((KvCoder) sideInputSpec.getCoder()).getKeyCoder(), ((KvCoder) sideInputSpec.getCoder()).getValueCoder(), - runnerCapabilites.contains( + runnerCapabilities.contains( BeamUrns.getUrn( RunnerApi.StandardRunnerProtocols.Enum .MULTIMAP_KEYS_VALUES_SIDE_INPUT)))); @@ -762,8 +762,113 @@ public MultimapState bindMultimap( StateSpec> spec, Coder keyCoder, Coder valueCoder) { - // TODO(https://github.com/apache/beam/issues/23616) - throw new UnsupportedOperationException("Multimap is not currently supported with Fn API."); + return (MultimapState) + 
stateKeyObjectCache.computeIfAbsent( + createMultimapKeysUserStateKey(id), + new Function() { + @Override + public Object apply(StateKey stateKey) { + return new MultimapState() { + private final MultimapUserState impl = + createMultimapUserState(stateKey, keyCoder, valueCoder); + + @Override + public void put(KeyT key, ValueT value) { + impl.put(key, value); + } + + @Override + public ReadableState> get(KeyT key) { + return new ReadableState>() { + @Override + public Iterable read() { + return impl.get(key); + } + + @Override + public ReadableState> readLater() { + impl.get(key).prefetch(); + return this; + } + }; + } + + @Override + public void remove(KeyT key) { + impl.remove(key); + } + + @Override + public ReadableState> keys() { + return new ReadableState>() { + @Override + public Iterable read() { + return impl.keys(); + } + + @Override + public ReadableState> readLater() { + impl.keys().prefetch(); + return this; + } + }; + } + + @Override + public ReadableState>> entries() { + return new ReadableState>>() { + @Override + public Iterable> read() { + return impl.entries(); + } + + @Override + public ReadableState>> readLater() { + impl.entries().prefetch(); + return this; + } + }; + } + + @Override + public ReadableState containsKey(KeyT key) { + return new ReadableState() { + @Override + public Boolean read() { + return !Iterables.isEmpty(impl.get(key)); + } + + @Override + public ReadableState readLater() { + impl.get(key).prefetch(); + return this; + } + }; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + return Iterables.isEmpty(impl.keys()); + } + + @Override + public ReadableState readLater() { + impl.keys().prefetch(); + return this; + } + }; + } + + @Override + public void clear() { + impl.clear(); + } + }; + } + }); } @Override diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java index 617faba87cc0..8e3d76f5fc8f 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java @@ -29,6 +29,7 @@ import java.util.List; import java.util.Map; import java.util.NoSuchElementException; +import java.util.Objects; import java.util.Set; import org.apache.beam.fn.harness.Cache; import org.apache.beam.fn.harness.Caches; @@ -38,13 +39,19 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.fn.stream.PrefetchableIterable; import org.apache.beam.sdk.fn.stream.PrefetchableIterables; import org.apache.beam.sdk.fn.stream.PrefetchableIterator; import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.values.KV; import org.apache.beam.vendor.grpc.v1p69p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; /** * An implementation of a multimap user state that utilizes the Beam Fn State API to fetch, clear @@ -52,9 +59,6 @@ * *

Calling {@link #asyncClose()} schedules any required persistence changes. This object should * no longer be used after it is closed. - * - *

TODO: Move to an async persist model where persistence is signalled based upon cache memory - * pressure and its need to flush. */ public class MultimapUserState { @@ -63,8 +67,10 @@ public class MultimapUserState { private final Coder mapKeyCoder; private final Coder valueCoder; private final StateRequest keysStateRequest; + private final StateRequest entriesStateRequest; private final StateRequest userStateRequest; private final CachingStateIterable persistedKeys; + private final CachingStateIterable>> persistedEntries; private boolean isClosed; private boolean isCleared; @@ -90,6 +96,8 @@ public MultimapUserState( this.mapKeyCoder = mapKeyCoder; this.valueCoder = valueCoder; + // Note: These StateRequest protos are constructed even if we never try to read the + // corresponding state type. Consider constructing them lazily, as needed. this.keysStateRequest = StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); this.persistedKeys = @@ -106,6 +114,23 @@ public MultimapUserState( .setWindow(stateKey.getMultimapKeysUserState().getWindow()) .setKey(stateKey.getMultimapKeysUserState().getKey()); this.userStateRequest = userStateRequestBuilder.build(); + + StateRequest.Builder entriesStateRequestBuilder = StateRequest.newBuilder(); + entriesStateRequestBuilder + .setInstructionId(instructionId) + .getStateKeyBuilder() + .getMultimapEntriesUserStateBuilder() + .setTransformId(stateKey.getMultimapKeysUserState().getTransformId()) + .setUserStateId(stateKey.getMultimapKeysUserState().getUserStateId()) + .setWindow(stateKey.getMultimapKeysUserState().getWindow()) + .setKey(stateKey.getMultimapKeysUserState().getKey()); + this.entriesStateRequest = entriesStateRequestBuilder.build(); + this.persistedEntries = + StateFetchingIterators.readAllAndDecodeStartingFrom( + Caches.subCache(this.cache, "AllEntries"), + beamFnStateClient, + entriesStateRequest, + KvCoder.of(mapKeyCoder, IterableCoder.of(valueCoder))); } public void clear() { @@ 
-200,7 +225,7 @@ public boolean hasNext() { nextKey = persistedKeysIterator.next(); Object nextKeyStructuralValue = mapKeyCoder.structuralValue(nextKey); if (!pendingRemovesNow.contains(nextKeyStructuralValue)) { - // Remove all keys that we will visit when passing over the persistedKeysIterator + // Remove all keys that we will visit when passing over the persistedKeysIterator, // so we do not revisit them when passing over the pendingAddsNowIterator if (pendingAddsNow.containsKey(nextKeyStructuralValue)) { pendingAddsNow.remove(nextKeyStructuralValue); @@ -235,6 +260,112 @@ public K next() { }; } + @SuppressWarnings({ + "nullness" // TODO(https://github.com/apache/beam/issues/21068) + }) + /* + * Returns an Iterable containing all entries in this multimap. + */ + public PrefetchableIterable> entries() { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + // Copy pendingAdds so this iterator represents a snapshot of state at the time it was created. + // NOTE(review): ImmutableMap.copyOf is shallow, so per-key values are shared with pendingAdds. + Map>> pendingAddsNow = ImmutableMap.copyOf(pendingAdds); + if (isCleared) { + return PrefetchableIterables.maybePrefetchable( + Iterables.concat( + Iterables.transform( + pendingAddsNow.entrySet(), + entry -> + Iterables.transform( + entry.getValue().getValue(), + value -> Maps.immutableEntry(entry.getValue().getKey(), value))))); + } + + Set pendingRemovesNow = ImmutableSet.copyOf(pendingRemoves.keySet()); + return new PrefetchableIterables.Default>() { + @Override + public PrefetchableIterator> createIterator() { + return new PrefetchableIterator>() { + // We can get the same key multiple times from persistedEntries in the case that its + // values are paginated across multiple pages. Keep track of which keys we've seen, so we + // only add in pendingAdds once (with the first page). We'll also use it to emit the + // pending adds for keys not on the backend at the end of the iterator. 
+ Set seenKeys = Sets.newHashSet(); + final PrefetchableIterator> allEntries = + PrefetchableIterables.concat( + Iterables.concat( + Iterables.filter( + Iterables.transform( + persistedEntries, + entry -> { + final Object structuralKey = + mapKeyCoder.structuralValue(entry.getKey()); + if (pendingRemovesNow.contains(structuralKey)) { + return null; + } + // add returns true if we haven't seen this key yet. + if (seenKeys.add(structuralKey) + && pendingAddsNow.containsKey(structuralKey)) { + return PrefetchableIterables.concat( + Iterables.transform( + pendingAddsNow.get(structuralKey).getValue(), + pendingAdd -> + Maps.immutableEntry(entry.getKey(), pendingAdd)), + Iterables.transform( + entry.getValue(), + value -> Maps.immutableEntry(entry.getKey(), value))); + } + return Iterables.transform( + entry.getValue(), + value -> Maps.immutableEntry(entry.getKey(), value)); + }), + Objects::nonNull)), + Iterables.concat( + Iterables.filter( + Iterables.transform( + pendingAddsNow.entrySet(), + entry -> { + if (seenKeys.contains(entry.getKey())) { + return null; + } + return Iterables.transform( + entry.getValue().getValue(), + value -> + Maps.immutableEntry(entry.getValue().getKey(), value)); + }), + Objects::nonNull))) + .iterator(); + + @Override + public boolean isReady() { + return allEntries.isReady(); + } + + @Override + public void prefetch() { + if (!isReady()) { + allEntries.prefetch(); + } + } + + @Override + public boolean hasNext() { + return allEntries.hasNext(); + } + + @Override + public Map.Entry next() { + return allEntries.next(); + } + }; + } + }; + } + /* * Store a key-value pair in the multimap. * Allows duplicate key-value pairs. 
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java index 48c9ce43bdf0..679307321826 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.collection.ArrayMatching.arrayContainingInAnyOrder; +import static org.hamcrest.collection.IsIterableContainingInAnyOrder.containsInAnyOrder; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -34,11 +35,15 @@ import java.util.Collections; import java.util.Iterator; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import org.apache.beam.fn.harness.Cache; import org.apache.beam.fn.harness.Caches; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.IterableCoder; +import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.fn.stream.PrefetchableIterable; @@ -179,6 +184,81 @@ public void testKeys() throws Exception { assertThrows(IllegalStateException.class, () -> userState.keys()); } + @Test + public void testEntries() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapEntriesStateKey(), + KV.of( + KvCoder.of(ByteArrayCoder.of(), IterableCoder.of(StringUtf8Coder.of())), + asList(KV.of(A1, asList("V1", "V2")), KV.of(A2, asList("V3")))))); + 
MultimapUserState userState = + new MultimapUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createMultimapKeyStateKey(), + ByteArrayCoder.of(), + StringUtf8Coder.of()); + + assertArrayEquals(A1, userState.entries().iterator().next().getKey()); + assertThat( + StreamSupport.stream(userState.entries().spliterator(), false) + .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()), + containsInAnyOrder( + KV.of(ByteString.copyFrom(A1), "V1"), + KV.of(ByteString.copyFrom(A1), "V2"), + KV.of(ByteString.copyFrom(A2), "V3"))); + + userState.put(A1, "V4"); + // Iterable is a snapshot of the entries at this time. + PrefetchableIterable> entriesBeforeOperations = userState.entries(); + + assertThat( + StreamSupport.stream(userState.entries().spliterator(), false) + .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()), + containsInAnyOrder( + KV.of(ByteString.copyFrom(A1), "V1"), + KV.of(ByteString.copyFrom(A1), "V2"), + KV.of(ByteString.copyFrom(A2), "V3"), + KV.of(ByteString.copyFrom(A1), "V4"))); + + userState.remove(A1); + assertThat( + StreamSupport.stream(userState.entries().spliterator(), false) + .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()), + containsInAnyOrder(KV.of(ByteString.copyFrom(A2), "V3"))); + + userState.put(A1, "V5"); + assertThat( + StreamSupport.stream(userState.entries().spliterator(), false) + .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()), + containsInAnyOrder( + KV.of(ByteString.copyFrom(A2), "V3"), KV.of(ByteString.copyFrom(A1), "V5"))); + + userState.clear(); + assertThat(userState.entries(), emptyIterable()); + // Check that after applying all these operations, our original entries Iterable contains a + // snapshot of state from when it was created. 
+ assertThat( + StreamSupport.stream(entriesBeforeOperations.spliterator(), false) + .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()), + containsInAnyOrder( + KV.of(ByteString.copyFrom(A1), "V1"), + KV.of(ByteString.copyFrom(A1), "V2"), + KV.of(ByteString.copyFrom(A1), "V4"), + KV.of(ByteString.copyFrom(A2), "V3"))); + + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.entries()); + } + @Test public void testPut() throws Exception { FakeBeamFnStateClient fakeClient = @@ -620,6 +700,44 @@ public void testRemoveKeysPrefetch() throws Exception { assertEquals(0, fakeClient.getCallCount()); } + @Test + public void testEntriesPrefetched() throws Exception { + // Use a really large chunk size so all elements get returned in a single page. This makes it + // easier to count how many get calls we should expect. + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapEntriesStateKey(), + KV.of( + KvCoder.of(ByteArrayCoder.of(), IterableCoder.of(StringUtf8Coder.of())), + asList(KV.of(A1, asList("V1", "V2")), KV.of(A2, asList("V3"))))), + 1000000); + MultimapUserState userState = + new MultimapUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createMultimapKeyStateKey(), + ByteArrayCoder.of(), + StringUtf8Coder.of()); + + userState.put(A1, "V4"); + PrefetchableIterable> entries = userState.entries(); + assertEquals(0, fakeClient.getCallCount()); + entries.prefetch(); + assertEquals(1, fakeClient.getCallCount()); + assertThat( + StreamSupport.stream(entries.spliterator(), false) + .map(entry -> KV.of(ByteString.copyFrom(entry.getKey()), entry.getValue())) + .collect(Collectors.toList()), + containsInAnyOrder( + KV.of(ByteString.copyFrom(A1), "V1"), + KV.of(ByteString.copyFrom(A1), "V2"), + KV.of(ByteString.copyFrom(A1), "V4"), + KV.of(ByteString.copyFrom(A2), "V3"))); + assertEquals(1, fakeClient.getCallCount()); + } + @Test 
public void testClearPrefetch() throws Exception { FakeBeamFnStateClient fakeClient = @@ -1053,6 +1171,17 @@ private StateKey createMultimapKeyStateKey() throws IOException { .build(); } + private StateKey createMultimapEntriesStateKey() throws IOException { + return StateKey.newBuilder() + .setMultimapEntriesUserState( + StateKey.MultimapEntriesUserState.newBuilder() + .setWindow(encode(encodedWindow)) + .setKey(encode(encodedKey)) + .setTransformId(pTransformId) + .setUserStateId(stateId)) + .build(); + } + private StateKey createMultimapValueStateKey(byte[] key) throws IOException { return StateKey.newBuilder() .setMultimapUserState(