Skip to content

Commit 026a6d3

Browse files
committed
Move planning hints from global coordinator state to per-operation
Instead of setting hints globally on the Coordinator (which requires coordinating hint state across operations), hints are now passed with each copy/pull submission. This makes the API stateless and allows different operations to use different routing strategies.

Key changes:
- Add a hints parameter to SubmitCopyRequest and SubmitPullRequest.
- Add RoutingHint with a serializable topology path.
- Update Client.copy() and Client.pull() to accept hints.
- Remove the global add_hint/clear_hints bindings from the Coordinator pybind module.
- Add Topology serialization support for hint transport.
- Introduce PendingDispatch to encapsulate all per-CopyKey state (shard aggregation, hints, fingerprints) behind SubmitShard and CancelIf, eliminating the duplicated cleanup logic that was spread across HandleShardSubmission and HandleDeregisterShardsRequest.
1 parent 0833775 commit 026a6d3

15 files changed

Lines changed: 322 additions & 104 deletions

File tree

csrc/setu/client/Client.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ std::optional<TensorShardRef> Client::RegisterTensorShard(
154154
return response.shard_ref;
155155
}
156156

157-
std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
157+
std::optional<CopyOperationId> Client::SubmitCopy(
158+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints) {
158159
// Find all shards owned by this client that are involved in the copy
159160
// (either as source or destination)
160161
std::vector<ShardId> involved_shards;
@@ -175,10 +176,14 @@ std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
175176
"Client has no shards for src {} or dst {}",
176177
copy_spec.src_name, copy_spec.dst_name);
177178

179+
// Compute fingerprint once for all shard submissions
180+
const auto fingerprint = setu::planner::hints::Fingerprint(hints);
181+
178182
// Submit a request for each involved shard
179183
std::optional<CopyOperationId> copy_op_id;
180184
for (const auto& shard_id : involved_shards) {
181-
ClientRequest request = SubmitCopyRequest(shard_id, copy_spec);
185+
ClientRequest request =
186+
SubmitCopyRequest(shard_id, copy_spec, hints, fingerprint);
182187
Comm::Send(request_socket_, request);
183188

184189
auto response = Comm::Recv<SubmitCopyResponse>(request_socket_);
@@ -196,18 +201,23 @@ std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
196201
return copy_op_id;
197202
}
198203

199-
std::optional<CopyOperationId> Client::SubmitPull(const CopySpec& copy_spec) {
204+
std::optional<CopyOperationId> Client::SubmitPull(
205+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints) {
200206
// For Pull: only destination shards submit (one-sided operation)
201207
auto it = tensor_shards_.find(copy_spec.dst_name);
202208
ASSERT_VALID_RUNTIME(it != tensor_shards_.end(),
203209
"Client has no shards for dst {}", copy_spec.dst_name);
204210

211+
// Compute fingerprint once for all shard submissions
212+
const auto fingerprint = setu::planner::hints::Fingerprint(hints);
213+
205214
// Submit a request for each destination shard
206215
std::optional<CopyOperationId> copy_op_id;
207216
for (const auto& shard_ref : it->second) {
208217
const auto shard_id = shard_ref->shard_id;
209218

210-
ClientRequest request = SubmitPullRequest(shard_id, copy_spec);
219+
ClientRequest request =
220+
SubmitPullRequest(shard_id, copy_spec, hints, fingerprint);
211221
Comm::Send(request_socket_, request);
212222

213223
auto response = Comm::Recv<SubmitCopyResponse>(request_socket_);

csrc/setu/client/Client.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "commons/utils/TorchTensorIPC.h"
2727
#include "commons/utils/ZmqHelper.h"
2828
#include "messaging/GetTensorHandleResponse.h"
29+
#include "planner/hints/Hint.h"
2930

3031
namespace setu::client {
3132
using setu::commons::CopyOperationId;
@@ -41,6 +42,7 @@ using setu::commons::messages::GetTensorHandleResponse;
4142
using setu::commons::utils::TensorIPCSpec;
4243
using setu::commons::utils::ZmqContextPtr;
4344
using setu::commons::utils::ZmqSocketPtr;
45+
using setu::planner::hints::CompilerHint;
4446

4547
class Client {
4648
public:
@@ -58,9 +60,11 @@ class Client {
5860
std::optional<TensorShardRef> RegisterTensorShard(
5961
const TensorShardSpec& shard_spec);
6062

61-
std::optional<CopyOperationId> SubmitCopy(const CopySpec& copy_spec);
63+
std::optional<CopyOperationId> SubmitCopy(
64+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints = {});
6265

63-
std::optional<CopyOperationId> SubmitPull(const CopySpec& copy_spec);
66+
std::optional<CopyOperationId> SubmitPull(
67+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints = {});
6468

6569
void WaitForCopy(CopyOperationId copy_op_id);
6670

csrc/setu/client/Pybind.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,25 @@
1616
//==============================================================================
1717
#include "commons/utils/Pybind.h"
1818

19+
#include <pybind11/stl.h>
20+
1921
#include "client/Client.h"
2022
#include "commons/Logging.h"
2123
#include "commons/StdCommon.h"
2224
#include "commons/TorchCommon.h"
2325
#include "commons/datatypes/CopySpec.h"
2426
#include "commons/datatypes/TensorShardSpec.h"
2527
#include "commons/enums/Enums.h"
28+
#include "planner/hints/Hint.h"
2629
//==============================================================================
2730
namespace setu::client {
2831
//==============================================================================
2932
using setu::commons::CopyOperationId;
3033
using setu::commons::datatypes::CopySpec;
3134
using setu::commons::datatypes::TensorShardSpec;
3235
using setu::commons::enums::ErrorCode;
36+
using setu::planner::hints::CompilerHint;
37+
using setu::planner::hints::RoutingHint;
3338
//==============================================================================
3439
void InitClientPybindClass(py::module_& m) {
3540
py::class_<Client, std::shared_ptr<Client>>(m, "Client")
@@ -46,8 +51,10 @@ void InitClientPybindClass(py::module_& m) {
4651
py::arg("shard_spec"),
4752
"Register a tensor shard and return a reference to it")
4853
.def("submit_copy", &Client::SubmitCopy, py::arg("copy_spec"),
54+
py::arg("hints") = std::vector<CompilerHint>{},
4955
"Submit a copy operation and return an operation ID")
5056
.def("submit_pull", &Client::SubmitPull, py::arg("copy_spec"),
57+
py::arg("hints") = std::vector<CompilerHint>{},
5158
"Submit a pull operation and return an operation ID")
5259
.def("wait_for_copy", &Client::WaitForCopy, py::arg("copy_op_id"),
5360
"Wait for a copy operation to complete")

csrc/setu/commons/datatypes/Device.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ struct Device {
9898
return static_cast<std::int16_t>(torch_device.index());
9999
}
100100

101-
torch::Device torch_device; ///< PyTorch device (type + local index)
101+
torch::Device torch_device{
102+
torch::kCUDA}; ///< PyTorch device (type + local index)
102103
};
103104
//==============================================================================
104105
} // namespace setu::commons::datatypes

csrc/setu/coordinator/Coordinator.cpp

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ Coordinator::Coordinator(std::size_t port, PlannerPtr planner)
6060

6161
handler_ = std::make_unique<Handler>(inbox_queue_, outbox_queue_, metastore_,
6262
planner_queue_, outbox_notify);
63-
executor_ =
64-
std::make_unique<Executor>(planner_queue_, outbox_queue_, metastore_,
65-
*planner_, hint_store_, outbox_notify);
63+
executor_ = std::make_unique<Executor>(planner_queue_, outbox_queue_,
64+
metastore_, *planner_, outbox_notify);
6665
}
6766

6867
Coordinator::~Coordinator() {
@@ -108,12 +107,6 @@ std::optional<CopyOperationId> Coordinator::SubmitCopy(
108107
return std::nullopt;
109108
}
110109

111-
void Coordinator::AddHint(setu::planner::hints::CompilerHint hint) {
112-
hint_store_.AddHint(std::move(hint));
113-
}
114-
115-
void Coordinator::ClearHints() { hint_store_.Clear(); }
116-
117110
void Coordinator::PlanExecuted(CopyOperationId copy_op_id) {
118111
LOG_DEBUG("Plan executed for copy operation ID: {}", copy_op_id);
119112

@@ -387,7 +380,8 @@ void Coordinator::Handler::HandleSubmitCopyRequest(
387380
metastore_.GetNumShardsForTensor(request.copy_spec.dst_name);
388381

389382
HandleShardSubmission(node_agent_identity, request.request_id,
390-
request.shard_id, request.copy_spec, expected_shards);
383+
request.shard_id, request.copy_spec, expected_shards,
384+
std::vector(request.hints), request.hints_fingerprint);
391385
}
392386

393387
void Coordinator::Handler::HandleSubmitPullRequest(
@@ -413,18 +407,21 @@ void Coordinator::Handler::HandleSubmitPullRequest(
413407
metastore_.GetNumShardsForTensor(request.copy_spec.dst_name);
414408

415409
HandleShardSubmission(node_agent_identity, request.request_id,
416-
request.shard_id, request.copy_spec, expected_shards);
410+
request.shard_id, request.copy_spec, expected_shards,
411+
std::vector(request.hints), request.hints_fingerprint);
417412
}
418413

419414
void Coordinator::Handler::HandleShardSubmission(
420415
const Identity& node_agent_identity, const RequestId& request_id,
421416
const ShardId& shard_id, const CopySpec& copy_spec,
422-
std::size_t expected_shards) {
417+
std::size_t expected_shards,
418+
std::vector<setu::planner::hints::CompilerHint> hints,
419+
std::uint64_t hints_fingerprint) {
423420
using setu::commons::utils::AggregationParticipant;
424421

425422
CopyKey copy_key{copy_spec.src_name, copy_spec.dst_name};
426423

427-
auto result = shard_aggregator_.Submit(
424+
auto result = pending_dispatch_.SubmitShard(
428425
copy_key, shard_id, copy_spec,
429426
AggregationParticipant{node_agent_identity, request_id}, expected_shards,
430427
[](const CopySpec& stored, const CopySpec& incoming) {
@@ -437,7 +434,8 @@ void Coordinator::Handler::HandleShardSubmission(
437434
*incoming.dst_selection == *stored.dst_selection,
438435
"Shard submission {} -> {}: destination selection mismatch",
439436
incoming.src_name, incoming.dst_name);
440-
});
437+
},
438+
std::move(hints), hints_fingerprint);
441439

442440
if (!result.has_value()) {
443441
return;
@@ -459,14 +457,15 @@ void Coordinator::Handler::HandleShardSubmission(
459457
}
460458

461459
// Create shared state with submitter identities
462-
auto state = std::make_shared<CopyOperationState>(result->payload,
463-
std::move(submitters));
460+
auto state =
461+
std::make_shared<CopyOperationState>(result->spec, std::move(submitters));
464462

465463
// Store the shared state (will be accessed by HandleExecuteResponse)
466464
copy_operations_.emplace(copy_op_id, state);
467465

468-
// Add to planner queue with copy_op_id and shared state
469-
planner_queue_.push(PlannerTask{copy_op_id, result->payload, state});
466+
// Add to planner queue with copy_op_id, shared state, and per-op hints
467+
planner_queue_.push(PlannerTask{copy_op_id, result->spec, state,
468+
HintStore(std::move(result->hints))});
470469

471470
// Send responses to all waiting participants with copy_op_id
472471
for (const auto& participant : result->participants) {
@@ -588,15 +587,25 @@ void Coordinator::Handler::HandleDeregisterShardsRequest(
588587
metastore_.MarkTensorDeregistered(name);
589588
}
590589

591-
// Cancel partial entries in the shard aggregator for these tensors.
590+
// Cancel partial entries in the pending dispatch for these tensors.
592591
// This cleans up groups that will never complete because the shards are
593592
// going away.
594593
auto cancelled_participants =
595-
shard_aggregator_.CancelIf([&tensor_names](const CopyKey& key) {
594+
pending_dispatch_.CancelIf([&tensor_names](const CopyKey& key) {
596595
return tensor_names.contains(key.src_name) ||
597596
tensor_names.contains(key.dst_name);
598597
});
599598

599+
// Clean up per-operation hint tracking for cancelled operations
600+
std::erase_if(operation_hints_, [&tensor_names](const auto& entry) {
601+
return tensor_names.contains(entry.first.src_name) ||
602+
tensor_names.contains(entry.first.dst_name);
603+
});
604+
std::erase_if(operation_fingerprints_, [&tensor_names](const auto& entry) {
605+
return tensor_names.contains(entry.first.src_name) ||
606+
tensor_names.contains(entry.first.dst_name);
607+
});
608+
600609
// Send error responses to cancelled participants
601610
for (const auto& participant : cancelled_participants) {
602611
LOG_INFO(
@@ -651,13 +660,11 @@ void Coordinator::Handler::HandleDeregisterShardsRequest(
651660
Coordinator::Executor::Executor(Queue<PlannerTask>& planner_queue,
652661
Queue<OutboxMessage>& outbox_queue,
653662
MetaStore& metastore, Planner& planner,
654-
HintStore& hint_store,
655663
OutboxNotifyFn outbox_notify)
656664
: planner_queue_(planner_queue),
657665
outbox_queue_(outbox_queue),
658666
metastore_(metastore),
659667
planner_(planner),
660-
hint_store_(hint_store),
661668
outbox_notify_(std::move(outbox_notify)) {}
662669

663670
void Coordinator::Executor::PushOutbox(OutboxMessage msg) {
@@ -690,9 +697,8 @@ void Coordinator::Executor::Loop() {
690697

691698
LOG_DEBUG("Executor received task for copy_op_id: {}", task.copy_op_id);
692699

693-
auto hints = hint_store_.Snapshot();
694700
auto t_compile_start = std::chrono::steady_clock::now();
695-
Plan plan = planner_.Compile(task.copy_spec, metastore_, hints);
701+
Plan plan = planner_.Compile(task.copy_spec, metastore_, task.hints);
696702
auto t_compile_end = std::chrono::steady_clock::now();
697703

698704
LOG_DEBUG("Compiled plan:\n{}", plan);

0 commit comments

Comments (0)