Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,29 @@ message StateKey {
bytes key = 4;
}

// Represents a request for all of the entries of a multimap associated with a
// specified user key and window for a PTransform. See
// https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
// details.
//
// Can only be used to perform StateGetRequests and StateClearRequests on the
// user state.
//
// The response data stream will be a concatenation of pairs, where the first
// component is the map key and the second component is a concatenation of
// values associated with that map key.
message MultimapEntriesUserState {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should add a capability for Multimap support similar to ordered list state?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! I added a capability to beam_runner_api.proto and specified it for the Java SDK. I don't think Python or Go support it yet.

// (Required) The id of the PTransform containing user state.
string transform_id = 1;
// (Required) The id of the user state.
string user_state_id = 2;
// (Required) The window encoded in a nested context.
bytes window = 3;
// (Required) The key of the currently executing element encoded in a
// nested context.
bytes key = 4;
}

// Represents a request for the values of the map key associated with a
// specified user key and window for a PTransform. See
// https://s.apache.org/beam-fn-state-api-and-bundle-processing for further
Expand Down Expand Up @@ -1064,6 +1087,7 @@ message StateKey {
MultimapKeysSideInput multimap_keys_side_input = 5;
MultimapKeysValuesSideInput multimap_keys_values_side_input = 8;
MultimapKeysUserState multimap_keys_user_state = 6;
MultimapEntriesUserState multimap_entries_user_state = 10;
MultimapUserState multimap_user_state = 7;
OrderedListUserState ordered_list_user_state = 9;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ public PrefetchableIterator<T> createIterator() {
* constructed that ensures that {@link PrefetchableIterator#prefetch()} is a no-op and {@link
* PrefetchableIterator#isReady()} always returns true.
*/
private static <T> PrefetchableIterable<T> maybePrefetchable(Iterable<T> iterable) {
public static <T> PrefetchableIterable<T> maybePrefetchable(Iterable<T> iterable) {
if (iterable instanceof PrefetchableIterable) {
return (PrefetchableIterable<T>) iterable;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2917,6 +2917,73 @@ public void processElement(
pipeline.run();
}

/**
 * Exercises {@code MultimapState#entries()} inside a stateful {@code DoFn}.
 *
 * <p>The assertions check that the state has no entries before any element is stored, that the
 * number of entries equals the number of elements stored so far, that an iterable obtained from
 * {@code entries().read()} still has its original size after later {@code remove}/{@code put}
 * calls, and that a {@code ReadableState} obtained earlier but read afterwards does include those
 * modifications.
 */
@Test
@Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMultimapState.class})
public void testMultimapStateEntries() {
final String stateId = "foo:";
final String countStateId = "count";
DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>> fn =
new DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>>() {

// Multimap under test: String keys mapped to Integer values.
@StateId(stateId)
private final StateSpec<MultimapState<String, Integer>> multimapState =
StateSpecs.multimap(StringUtf8Coder.of(), VarIntCoder.of());

// Running count of elements stored so far; incremented by one per processed element.
@StateId(countStateId)
private final StateSpec<CombiningState<Integer, int[], Integer>> countState =
StateSpecs.combiningFromInputInternal(VarIntCoder.of(), Sum.ofIntegers());

@ProcessElement
public void processElement(
ProcessContext c,
@Element KV<String, KV<String, Integer>> element,
@StateId(stateId) MultimapState<String, Integer> state,
@StateId(countStateId) CombiningState<Integer, int[], Integer> count,
OutputReceiver<KV<String, Integer>> r) {
// Empty before we process any elements.
if (count.read() == 0) {
assertThat(state.entries().read(), emptyIterable());
}
// The number of entries must equal the number of elements stored so far.
assertEquals(count.read().intValue(), Iterables.size(state.entries().read()));

// Store the inner (map key, value) pair of the element and bump the counter.
KV<String, Integer> value = element.getValue();
state.put(value.getKey(), value.getValue());
count.add(1);

// With the four inputs created below (all under user key "hello"), this branch is
// entered once the fourth element has been stored.
if (count.read() >= 4) {
// This should be evaluated only when ReadableState.read is called.
ReadableState<Iterable<Entry<String, Integer>>> entriesView = state.entries();
Comment thread
acrites marked this conversation as resolved.

// This is evaluated immediately.
Iterable<Entry<String, Integer>> entries = state.entries().read();

// Removing "b" after the read must not change the already-read iterable: still 4 entries.
state.remove("b");
assertEquals(4, Iterables.size(entries));
state.put("a", 2);
state.put("a", 3);

// The view is read only now, so it sees the modifications: the three original "a"
// entries plus the two added above, with the "b" entry removed.
assertEquals(5, Iterables.size(entriesView.read()));
// Note we output the view of state before the modifications in this if statement.
for (Entry<String, Integer> entry : entries) {
r.output(KV.of(entry.getKey(), entry.getValue()));
}
}
}
};
// Four elements sharing the user key "hello": three for map key "a" and one for map key "b".
PCollection<KV<String, Integer>> output =
pipeline
.apply(
Create.of(
KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("a", 97)),
KV.of("hello", KV.of("a", 98)), KV.of("hello", KV.of("b", 33))))
.apply(ParDo.of(fn));
// The expected output is the pre-modification snapshot, i.e. exactly the four stored pairs.
PAssert.that(output)
.containsInAnyOrder(
KV.of("a", 97), KV.of("a", 97),
KV.of("a", 98), KV.of("b", 33));
pipeline.run();
}

@Test
@Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesMultimapState.class})
public void testMultimapStateRemove() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public static class Factory<K> {

public Factory(
PipelineOptions pipelineOptions,
Set<String> runnerCapabilites,
Set<String> runnerCapabilities,
String ptransformId,
Supplier<String> processBundleInstructionId,
Supplier<List<CacheToken>> cacheTokens,
Expand All @@ -128,7 +128,7 @@ public Factory(
Coder<K> keyCoder,
Coder<BoundedWindow> windowCoder) {
this.pipelineOptions = pipelineOptions;
this.runnerCapabilities = runnerCapabilites;
this.runnerCapabilities = runnerCapabilities;
this.ptransformId = ptransformId;
this.processBundleInstructionId = processBundleInstructionId;
this.cacheTokens = cacheTokens;
Expand Down Expand Up @@ -240,7 +240,7 @@ public FnApiStateAccessor<K> create() {
}

private final PipelineOptions pipelineOptions;
private final Set<String> runnerCapabilites;
private final Set<String> runnerCapabilities;
private final Map<StateKey, Object> stateKeyObjectCache;
private final Map<TupleTag<?>, SideInputSpec> sideInputSpecMap;
private final BeamFnStateClient beamFnStateClient;
Expand All @@ -259,7 +259,7 @@ public FnApiStateAccessor<K> create() {

public FnApiStateAccessor(
PipelineOptions pipelineOptions,
Set<String> runnerCapabilites,
Set<String> runnerCapabilities,
String ptransformId,
Supplier<String> processBundleInstructionId,
Supplier<List<CacheToken>> cacheTokens,
Expand All @@ -270,7 +270,7 @@ public FnApiStateAccessor(
Coder<K> keyCoder,
Coder<BoundedWindow> windowCoder) {
this.pipelineOptions = pipelineOptions;
this.runnerCapabilites = runnerCapabilites;
this.runnerCapabilities = runnerCapabilities;
this.stateKeyObjectCache = Maps.newHashMap();
this.sideInputSpecMap = sideInputSpecMap;
this.beamFnStateClient = beamFnStateClient;
Expand Down Expand Up @@ -414,7 +414,7 @@ public <T> T get(PCollectionView<T> view, BoundedWindow window) {
key,
((KvCoder) sideInputSpec.getCoder()).getKeyCoder(),
((KvCoder) sideInputSpec.getCoder()).getValueCoder(),
runnerCapabilites.contains(
runnerCapabilities.contains(
BeamUrns.getUrn(
RunnerApi.StandardRunnerProtocols.Enum
.MULTIMAP_KEYS_VALUES_SIDE_INPUT))));
Expand Down Expand Up @@ -762,8 +762,113 @@ public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
StateSpec<MultimapState<KeyT, ValueT>> spec,
Coder<KeyT> keyCoder,
Coder<ValueT> valueCoder) {
// Binds a MultimapState by looking it up in stateKeyObjectCache, creating it on first use.
// The cache key is the StateKey produced by createMultimapKeysUserStateKey(id).
// NOTE(review): `id` is not declared in this hunk; presumably it comes from the enclosing
// scope -- confirm against the full file.
// NOTE(review): the TODO/throw pair directly below looks like the removed side of this diff
// hunk, shown next to the added implementation -- confirm against the merged file.
// TODO(https://github.com/apache/beam/issues/23616)
throw new UnsupportedOperationException("Multimap is not currently supported with Fn API.");
return (MultimapState<KeyT, ValueT>)
stateKeyObjectCache.computeIfAbsent(
createMultimapKeysUserStateKey(id),
new Function<StateKey, Object>() {
@Override
public Object apply(StateKey stateKey) {
// Adapter exposing a MultimapUserState through the MultimapState interface.
return new MultimapState<KeyT, ValueT>() {
// Every operation below delegates to this instance.
private final MultimapUserState<KeyT, ValueT> impl =
createMultimapUserState(stateKey, keyCoder, valueCoder);

@Override
public void put(KeyT key, ValueT value) {
impl.put(key, value);
}

// Values for a single map key: read() returns impl.get(key); readLater() calls
// prefetch() on that iterable and returns this same ReadableState.
@Override
public ReadableState<Iterable<ValueT>> get(KeyT key) {
return new ReadableState<Iterable<ValueT>>() {
@Override
public Iterable<ValueT> read() {
return impl.get(key);
}

@Override
public ReadableState<Iterable<ValueT>> readLater() {
impl.get(key).prefetch();
return this;
}
};
}

@Override
public void remove(KeyT key) {
impl.remove(key);
}

// All map keys: read() returns impl.keys(); readLater() prefetches them.
@Override
public ReadableState<Iterable<KeyT>> keys() {
return new ReadableState<Iterable<KeyT>>() {
@Override
public Iterable<KeyT> read() {
return impl.keys();
}

@Override
public ReadableState<Iterable<KeyT>> readLater() {
impl.keys().prefetch();
return this;
}
};
}

// All (key, value) entries: read() returns impl.entries(); readLater() prefetches them.
@Override
public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() {
return new ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>() {
@Override
public Iterable<Map.Entry<KeyT, ValueT>> read() {
return impl.entries();
}

@Override
public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> readLater() {
impl.entries().prefetch();
return this;
}
};
}

// A key counts as present when the iterable of its values is non-empty.
@Override
public ReadableState<Boolean> containsKey(KeyT key) {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
return !Iterables.isEmpty(impl.get(key));
}

@Override
public ReadableState<Boolean> readLater() {
impl.get(key).prefetch();
return this;
}
};
}

// The multimap counts as empty when the iterable of its keys is empty.
@Override
public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
return Iterables.isEmpty(impl.keys());
}

@Override
public ReadableState<Boolean> readLater() {
impl.keys().prefetch();
return this;
}
};
}

@Override
public void clear() {
impl.clear();
}
};
}
});
}

@Override
Expand Down
Loading
Loading