Skip to content

Commit f24fa2c

Browse files
committed
Move planning hints from global coordinator state to per-operation
Instead of setting hints globally on the Coordinator (which requires coordinating hint state across operations), hints are now passed with each copy/pull submission. This makes the API stateless and allows different operations to use different routing strategies.

Key changes:
- Add hints parameter to SubmitCopyRequest and SubmitPullRequest
- Add RoutingHint with serializable topology path
- Update Client.copy() and Client.pull() to accept hints
- Remove global add_hint/clear_hints from Coordinator pybind
- Add Topology serialization support for hint transport
- Extract Handler, Executor, Gateway, DispatchManager, and Types from monolithic Coordinator into separate files
- DispatchManager encapsulates CopyKey as private implementation detail
- Introduce ShardSubmission struct to reduce SubmitShard parameter count
- Merge operation_hints_ and operation_fingerprints_ into single map
- Add FinalizeAggregation to move post-completion bookkeeping into DispatchManager
- Replace CancelPendingIf with CancelPendingByTensors public API
- Assert on unknown copy_op_id in RecordResponse
- Use std::visit for SubmitResult handling instead of std::get_if
1 parent 0833775 commit f24fa2c

24 files changed

Lines changed: 1512 additions & 971 deletions

csrc/setu/client/Client.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ std::optional<TensorShardRef> Client::RegisterTensorShard(
154154
return response.shard_ref;
155155
}
156156

157-
std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
157+
std::optional<CopyOperationId> Client::SubmitCopy(
158+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints) {
158159
// Find all shards owned by this client that are involved in the copy
159160
// (either as source or destination)
160161
std::vector<ShardId> involved_shards;
@@ -175,10 +176,14 @@ std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
175176
"Client has no shards for src {} or dst {}",
176177
copy_spec.src_name, copy_spec.dst_name);
177178

179+
// Compute fingerprint once for all shard submissions
180+
const auto fingerprint = setu::planner::hints::Fingerprint(hints);
181+
178182
// Submit a request for each involved shard
179183
std::optional<CopyOperationId> copy_op_id;
180184
for (const auto& shard_id : involved_shards) {
181-
ClientRequest request = SubmitCopyRequest(shard_id, copy_spec);
185+
ClientRequest request =
186+
SubmitCopyRequest(shard_id, copy_spec, hints, fingerprint);
182187
Comm::Send(request_socket_, request);
183188

184189
auto response = Comm::Recv<SubmitCopyResponse>(request_socket_);
@@ -196,18 +201,23 @@ std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
196201
return copy_op_id;
197202
}
198203

199-
std::optional<CopyOperationId> Client::SubmitPull(const CopySpec& copy_spec) {
204+
std::optional<CopyOperationId> Client::SubmitPull(
205+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints) {
200206
// For Pull: only destination shards submit (one-sided operation)
201207
auto it = tensor_shards_.find(copy_spec.dst_name);
202208
ASSERT_VALID_RUNTIME(it != tensor_shards_.end(),
203209
"Client has no shards for dst {}", copy_spec.dst_name);
204210

211+
// Compute fingerprint once for all shard submissions
212+
const auto fingerprint = setu::planner::hints::Fingerprint(hints);
213+
205214
// Submit a request for each destination shard
206215
std::optional<CopyOperationId> copy_op_id;
207216
for (const auto& shard_ref : it->second) {
208217
const auto shard_id = shard_ref->shard_id;
209218

210-
ClientRequest request = SubmitPullRequest(shard_id, copy_spec);
219+
ClientRequest request =
220+
SubmitPullRequest(shard_id, copy_spec, hints, fingerprint);
211221
Comm::Send(request_socket_, request);
212222

213223
auto response = Comm::Recv<SubmitCopyResponse>(request_socket_);

csrc/setu/client/Client.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "commons/utils/TorchTensorIPC.h"
2727
#include "commons/utils/ZmqHelper.h"
2828
#include "messaging/GetTensorHandleResponse.h"
29+
#include "planner/hints/Hint.h"
2930

3031
namespace setu::client {
3132
using setu::commons::CopyOperationId;
@@ -41,6 +42,7 @@ using setu::commons::messages::GetTensorHandleResponse;
4142
using setu::commons::utils::TensorIPCSpec;
4243
using setu::commons::utils::ZmqContextPtr;
4344
using setu::commons::utils::ZmqSocketPtr;
45+
using setu::planner::hints::CompilerHint;
4446

4547
class Client {
4648
public:
@@ -58,9 +60,11 @@ class Client {
5860
std::optional<TensorShardRef> RegisterTensorShard(
5961
const TensorShardSpec& shard_spec);
6062

61-
std::optional<CopyOperationId> SubmitCopy(const CopySpec& copy_spec);
63+
std::optional<CopyOperationId> SubmitCopy(
64+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints = {});
6265

63-
std::optional<CopyOperationId> SubmitPull(const CopySpec& copy_spec);
66+
std::optional<CopyOperationId> SubmitPull(
67+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints = {});
6468

6569
void WaitForCopy(CopyOperationId copy_op_id);
6670

csrc/setu/client/Pybind.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,25 @@
1616
//==============================================================================
1717
#include "commons/utils/Pybind.h"
1818

19+
#include <pybind11/stl.h>
20+
1921
#include "client/Client.h"
2022
#include "commons/Logging.h"
2123
#include "commons/StdCommon.h"
2224
#include "commons/TorchCommon.h"
2325
#include "commons/datatypes/CopySpec.h"
2426
#include "commons/datatypes/TensorShardSpec.h"
2527
#include "commons/enums/Enums.h"
28+
#include "planner/hints/Hint.h"
2629
//==============================================================================
2730
namespace setu::client {
2831
//==============================================================================
2932
using setu::commons::CopyOperationId;
3033
using setu::commons::datatypes::CopySpec;
3134
using setu::commons::datatypes::TensorShardSpec;
3235
using setu::commons::enums::ErrorCode;
36+
using setu::planner::hints::CompilerHint;
37+
using setu::planner::hints::RoutingHint;
3338
//==============================================================================
3439
void InitClientPybindClass(py::module_& m) {
3540
py::class_<Client, std::shared_ptr<Client>>(m, "Client")
@@ -46,8 +51,10 @@ void InitClientPybindClass(py::module_& m) {
4651
py::arg("shard_spec"),
4752
"Register a tensor shard and return a reference to it")
4853
.def("submit_copy", &Client::SubmitCopy, py::arg("copy_spec"),
54+
py::arg("hints") = std::vector<CompilerHint>{},
4955
"Submit a copy operation and return an operation ID")
5056
.def("submit_pull", &Client::SubmitPull, py::arg("copy_spec"),
57+
py::arg("hints") = std::vector<CompilerHint>{},
5158
"Submit a pull operation and return an operation ID")
5259
.def("wait_for_copy", &Client::WaitForCopy, py::arg("copy_op_id"),
5360
"Wait for a copy operation to complete")

csrc/setu/commons/datatypes/Device.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ struct Device {
9898
return static_cast<std::int16_t>(torch_device.index());
9999
}
100100

101-
torch::Device torch_device; ///< PyTorch device (type + local index)
101+
torch::Device torch_device{
102+
torch::kCUDA}; ///< PyTorch device (type + local index)
102103
};
103104
//==============================================================================
104105
} // namespace setu::commons::datatypes

csrc/setu/commons/utils/ShardAggregator.h

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@ struct CompletedGroup {
4747
/// validation, and participant info. When all expected shards have arrived, the
4848
/// completed group is returned and the internal state is cleaned up.
4949
///
50-
/// @tparam KeyType The group key type (must support operator<).
50+
/// @tparam KeyType The group key type (must be hashable via KeyHash).
5151
/// @tparam PayloadType The payload type stored per group. Must support
5252
/// operator== for validation of consistency across submissions.
53-
template <typename KeyType, typename PayloadType>
53+
/// @tparam KeyHash Hash function object for KeyType.
54+
/// @tparam KeyEqual Equality function object for KeyType.
55+
template <typename KeyType, typename PayloadType,
56+
typename KeyHash = boost::hash<KeyType>,
57+
typename KeyEqual = std::equal_to<KeyType>>
5458
class ShardAggregator {
5559
public:
5660
/// @brief Submit a shard for aggregation.
@@ -63,7 +67,8 @@ class ShardAggregator {
6367
/// @param participant [in] The identity and request_id of the submitter.
6468
/// @param expected_count [in] Total number of shards expected for this group.
6569
/// @param validate_fn [in] Callable(const PayloadType& stored, const
66-
/// PayloadType& incoming) that asserts payload consistency.
70+
/// PayloadType& incoming) → bool. Returns true if payloads are consistent,
71+
/// false to reject and cancel the group.
6772
/// @return CompletedGroup if this submission completes the group, nullopt
6873
/// otherwise.
6974
template <typename ValidateFn>
@@ -82,8 +87,8 @@ class ShardAggregator {
8287
// Store or validate the payload
8388
if (!group.payload.has_value()) {
8489
group.payload.emplace(payload);
85-
} else {
86-
validate_fn(group.payload.value(), payload);
90+
} else if (!validate_fn(group.payload.value(), payload)) {
91+
return std::nullopt;
8792
}
8893

8994
group.shards_received.insert(shard_id);
@@ -100,6 +105,21 @@ class ShardAggregator {
100105
return std::nullopt;
101106
}
102107

108+
/// @brief Cancel and remove the group for a specific key.
109+
///
110+
/// @param key [in] The group key to cancel.
111+
/// @return All participants from the cancelled group.
112+
[[nodiscard]] std::vector<AggregationParticipant> Cancel(
113+
const KeyType& key /*[in]*/) {
114+
std::vector<AggregationParticipant> cancelled_participants;
115+
auto it = groups_.find(key);
116+
if (it != groups_.end()) {
117+
cancelled_participants = std::move(it->second.participants);
118+
groups_.erase(it);
119+
}
120+
return cancelled_participants;
121+
}
122+
103123
/// @brief Cancel and remove all groups whose key matches the predicate.
104124
///
105125
/// This is used to clean up partially-aggregated groups when the shards
@@ -134,7 +154,7 @@ class ShardAggregator {
134154
std::vector<AggregationParticipant> participants;
135155
};
136156

137-
std::map<KeyType, PendingGroup> groups_;
157+
std::unordered_map<KeyType, PendingGroup, KeyHash, KeyEqual> groups_;
138158
};
139159

140160
//==============================================================================

0 commit comments

Comments
 (0)