Skip to content

Commit b8e99db

Browse files
committed
Move planning hints from global coordinator state to per-operation
Instead of setting hints globally on the Coordinator (which requires coordinating hint state across operations), hints are now passed with each copy/pull submission. This makes the API stateless and allows different operations to use different routing strategies. Key changes: - Add hints parameter to SubmitCopyRequest and SubmitPullRequest - Add RoutingHint with serializable topology path - Update Client.copy() and Client.pull() to accept hints - Remove global add_hint/clear_hints from Coordinator pybind - Add Topology serialization support for hint transport
1 parent 93f6338 commit b8e99db

14 files changed

Lines changed: 248 additions & 94 deletions

File tree

csrc/setu/client/Client.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ std::optional<TensorShardRef> Client::RegisterTensorShard(
150150
return response.shard_ref;
151151
}
152152

153-
std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
153+
std::optional<CopyOperationId> Client::SubmitCopy(
154+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints) {
154155
// Find all shards owned by this client that are involved in the copy
155156
// (either as source or destination)
156157
std::vector<ShardId> involved_shards;
@@ -171,10 +172,14 @@ std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
171172
"Client has no shards for src {} or dst {}",
172173
copy_spec.src_name, copy_spec.dst_name);
173174

175+
// Compute fingerprint once for all shard submissions
176+
const auto fingerprint = setu::planner::hints::Fingerprint(hints);
177+
174178
// Submit a request for each involved shard
175179
std::optional<CopyOperationId> copy_op_id;
176180
for (const auto& shard_id : involved_shards) {
177-
ClientRequest request = SubmitCopyRequest(shard_id, copy_spec);
181+
ClientRequest request =
182+
SubmitCopyRequest(shard_id, copy_spec, hints, fingerprint);
178183
Comm::Send(request_socket_, request);
179184

180185
auto response = Comm::Recv<SubmitCopyResponse>(request_socket_);
@@ -192,18 +197,23 @@ std::optional<CopyOperationId> Client::SubmitCopy(const CopySpec& copy_spec) {
192197
return copy_op_id;
193198
}
194199

195-
std::optional<CopyOperationId> Client::SubmitPull(const CopySpec& copy_spec) {
200+
std::optional<CopyOperationId> Client::SubmitPull(
201+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints) {
196202
// For Pull: only destination shards submit (one-sided operation)
197203
auto it = tensor_shards_.find(copy_spec.dst_name);
198204
ASSERT_VALID_RUNTIME(it != tensor_shards_.end(),
199205
"Client has no shards for dst {}", copy_spec.dst_name);
200206

207+
// Compute fingerprint once for all shard submissions
208+
const auto fingerprint = setu::planner::hints::Fingerprint(hints);
209+
201210
// Submit a request for each destination shard
202211
std::optional<CopyOperationId> copy_op_id;
203212
for (const auto& shard_ref : it->second) {
204213
const auto shard_id = shard_ref->shard_id;
205214

206-
ClientRequest request = SubmitPullRequest(shard_id, copy_spec);
215+
ClientRequest request =
216+
SubmitPullRequest(shard_id, copy_spec, hints, fingerprint);
207217
Comm::Send(request_socket_, request);
208218

209219
auto response = Comm::Recv<SubmitCopyResponse>(request_socket_);

csrc/setu/client/Client.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "commons/utils/TorchTensorIPC.h"
2727
#include "commons/utils/ZmqHelper.h"
2828
#include "messaging/GetTensorHandleResponse.h"
29+
#include "planner/hints/Hint.h"
2930

3031
namespace setu::client {
3132
using setu::commons::CopyOperationId;
@@ -41,6 +42,7 @@ using setu::commons::messages::GetTensorHandleResponse;
4142
using setu::commons::utils::TensorIPCSpec;
4243
using setu::commons::utils::ZmqContextPtr;
4344
using setu::commons::utils::ZmqSocketPtr;
45+
using setu::planner::hints::CompilerHint;
4446

4547
class Client {
4648
public:
@@ -58,9 +60,11 @@ class Client {
5860
std::optional<TensorShardRef> RegisterTensorShard(
5961
const TensorShardSpec& shard_spec);
6062

61-
std::optional<CopyOperationId> SubmitCopy(const CopySpec& copy_spec);
63+
std::optional<CopyOperationId> SubmitCopy(
64+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints = {});
6265

63-
std::optional<CopyOperationId> SubmitPull(const CopySpec& copy_spec);
66+
std::optional<CopyOperationId> SubmitPull(
67+
const CopySpec& copy_spec, const std::vector<CompilerHint>& hints = {});
6468

6569
void WaitForCopy(CopyOperationId copy_op_id);
6670

csrc/setu/client/Pybind.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,16 @@
2323
#include "commons/datatypes/CopySpec.h"
2424
#include "commons/datatypes/TensorShardSpec.h"
2525
#include "commons/enums/Enums.h"
26+
#include "planner/hints/Hint.h"
2627
//==============================================================================
2728
namespace setu::client {
2829
//==============================================================================
2930
using setu::commons::CopyOperationId;
3031
using setu::commons::datatypes::CopySpec;
3132
using setu::commons::datatypes::TensorShardSpec;
3233
using setu::commons::enums::ErrorCode;
34+
using setu::planner::hints::CompilerHint;
35+
using setu::planner::hints::RoutingHint;
3336
//==============================================================================
3437
void InitClientPybindClass(py::module_& m) {
3538
py::class_<Client, std::shared_ptr<Client>>(m, "Client")
@@ -45,10 +48,26 @@ void InitClientPybindClass(py::module_& m) {
4548
.def("register_tensor_shard", &Client::RegisterTensorShard,
4649
py::arg("shard_spec"),
4750
"Register a tensor shard and return a reference to it")
48-
.def("submit_copy", &Client::SubmitCopy, py::arg("copy_spec"),
49-
"Submit a copy operation and return an operation ID")
50-
.def("submit_pull", &Client::SubmitPull, py::arg("copy_spec"),
51-
"Submit a pull operation and return an operation ID")
51+
.def(
52+
"submit_copy",
53+
[](Client& self, const CopySpec& copy_spec,
54+
const std::vector<RoutingHint>& hints) {
55+
std::vector<CompilerHint> compiler_hints(hints.begin(),
56+
hints.end());
57+
return self.SubmitCopy(copy_spec, compiler_hints);
58+
},
59+
py::arg("copy_spec"), py::arg("hints") = std::vector<RoutingHint>{},
60+
"Submit a copy operation and return an operation ID")
61+
.def(
62+
"submit_pull",
63+
[](Client& self, const CopySpec& copy_spec,
64+
const std::vector<RoutingHint>& hints) {
65+
std::vector<CompilerHint> compiler_hints(hints.begin(),
66+
hints.end());
67+
return self.SubmitPull(copy_spec, compiler_hints);
68+
},
69+
py::arg("copy_spec"), py::arg("hints") = std::vector<RoutingHint>{},
70+
"Submit a pull operation and return an operation ID")
5271
.def("wait_for_copy", &Client::WaitForCopy, py::arg("copy_op_id"),
5372
"Wait for a copy operation to complete")
5473
.def("wait_for_shard_allocation", &Client::WaitForShardAllocation,

csrc/setu/coordinator/Coordinator.cpp

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,8 @@ Coordinator::Coordinator(std::size_t port, PlannerPtr planner)
6060

6161
handler_ = std::make_unique<Handler>(inbox_queue_, outbox_queue_, metastore_,
6262
planner_queue_, outbox_notify);
63-
executor_ =
64-
std::make_unique<Executor>(planner_queue_, outbox_queue_, metastore_,
65-
*planner_, hint_store_, outbox_notify);
63+
executor_ = std::make_unique<Executor>(planner_queue_, outbox_queue_,
64+
metastore_, *planner_, outbox_notify);
6665
}
6766

6867
Coordinator::~Coordinator() {
@@ -108,12 +107,6 @@ std::optional<CopyOperationId> Coordinator::SubmitCopy(
108107
return std::nullopt;
109108
}
110109

111-
void Coordinator::AddHint(setu::planner::hints::CompilerHint hint) {
112-
hint_store_.AddHint(std::move(hint));
113-
}
114-
115-
void Coordinator::ClearHints() { hint_store_.Clear(); }
116-
117110
void Coordinator::PlanExecuted(CopyOperationId copy_op_id) {
118111
LOG_DEBUG("Plan executed for copy operation ID: {}", copy_op_id);
119112

@@ -387,7 +380,8 @@ void Coordinator::Handler::HandleSubmitCopyRequest(
387380
metastore_.GetNumShardsForTensor(request.copy_spec.dst_name);
388381

389382
HandleShardSubmission(node_agent_identity, request.request_id,
390-
request.shard_id, request.copy_spec, expected_shards);
383+
request.shard_id, request.copy_spec, expected_shards,
384+
std::vector(request.hints), request.hints_fingerprint);
391385
}
392386

393387
void Coordinator::Handler::HandleSubmitPullRequest(
@@ -413,17 +407,38 @@ void Coordinator::Handler::HandleSubmitPullRequest(
413407
metastore_.GetNumShardsForTensor(request.copy_spec.dst_name);
414408

415409
HandleShardSubmission(node_agent_identity, request.request_id,
416-
request.shard_id, request.copy_spec, expected_shards);
410+
request.shard_id, request.copy_spec, expected_shards,
411+
std::vector(request.hints), request.hints_fingerprint);
417412
}
418413

419414
void Coordinator::Handler::HandleShardSubmission(
420415
const Identity& node_agent_identity, const RequestId& request_id,
421416
const ShardId& shard_id, const CopySpec& copy_spec,
422-
std::size_t expected_shards) {
417+
std::size_t expected_shards,
418+
std::vector<setu::planner::hints::CompilerHint> hints,
419+
std::uint64_t hints_fingerprint) {
423420
using setu::commons::utils::AggregationParticipant;
424421

425422
CopyKey copy_key{copy_spec.src_name, copy_spec.dst_name};
426423

424+
// First-writer-wins hint storage: the first shard submission's hints
425+
// become authoritative for this operation.
426+
if (operation_hints_.find(copy_key) == operation_hints_.end()) {
427+
// First shard for this operation — store its hints
428+
operation_hints_[copy_key] = std::move(hints);
429+
operation_fingerprints_[copy_key] = hints_fingerprint;
430+
} else {
431+
// Subsequent shard — verify fingerprint in debug mode
432+
if (setu::commons::Logger::log_level <= setu::commons::LogLevel::kDebug) {
433+
ASSERT_VALID_RUNTIME(
434+
hints_fingerprint == operation_fingerprints_[copy_key],
435+
"SPMD hint mismatch for {} -> {}: shard {} sent fingerprint {} but "
436+
"first submission had {}",
437+
copy_spec.src_name, copy_spec.dst_name, shard_id, hints_fingerprint,
438+
operation_fingerprints_[copy_key]);
439+
}
440+
}
441+
427442
auto result = shard_aggregator_.Submit(
428443
copy_key, shard_id, copy_spec,
429444
AggregationParticipant{node_agent_identity, request_id}, expected_shards,
@@ -465,8 +480,14 @@ void Coordinator::Handler::HandleShardSubmission(
465480
// Store the shared state (will be accessed by HandleExecuteResponse)
466481
copy_operations_.emplace(copy_op_id, state);
467482

468-
// Add to planner queue with copy_op_id and shared state
469-
planner_queue_.push(PlannerTask{copy_op_id, result->payload, state});
483+
// Extract per-operation hints and clean up tracking maps
484+
auto op_hints = std::move(operation_hints_[copy_key]);
485+
operation_hints_.erase(copy_key);
486+
operation_fingerprints_.erase(copy_key);
487+
488+
// Add to planner queue with copy_op_id, shared state, and per-op hints
489+
planner_queue_.push(PlannerTask{copy_op_id, result->payload, state,
490+
HintStore(std::move(op_hints))});
470491

471492
// Send responses to all waiting participants with copy_op_id
472493
for (const auto& participant : result->participants) {
@@ -583,6 +604,16 @@ void Coordinator::Handler::HandleDeregisterShardsRequest(
583604
tensor_names.contains(key.dst_name);
584605
});
585606

607+
// Clean up per-operation hint tracking for cancelled operations
608+
std::erase_if(operation_hints_, [&tensor_names](const auto& entry) {
609+
return tensor_names.contains(entry.first.src_name) ||
610+
tensor_names.contains(entry.first.dst_name);
611+
});
612+
std::erase_if(operation_fingerprints_, [&tensor_names](const auto& entry) {
613+
return tensor_names.contains(entry.first.src_name) ||
614+
tensor_names.contains(entry.first.dst_name);
615+
});
616+
586617
// Send error responses to cancelled participants
587618
for (const auto& participant : cancelled_participants) {
588619
LOG_INFO(
@@ -637,13 +668,11 @@ void Coordinator::Handler::HandleDeregisterShardsRequest(
637668
Coordinator::Executor::Executor(Queue<PlannerTask>& planner_queue,
638669
Queue<OutboxMessage>& outbox_queue,
639670
MetaStore& metastore, Planner& planner,
640-
HintStore& hint_store,
641671
OutboxNotifyFn outbox_notify)
642672
: planner_queue_(planner_queue),
643673
outbox_queue_(outbox_queue),
644674
metastore_(metastore),
645675
planner_(planner),
646-
hint_store_(hint_store),
647676
outbox_notify_(std::move(outbox_notify)) {}
648677

649678
void Coordinator::Executor::PushOutbox(OutboxMessage msg) {
@@ -676,9 +705,8 @@ void Coordinator::Executor::Loop() {
676705

677706
LOG_DEBUG("Executor received task for copy_op_id: {}", task.copy_op_id);
678707

679-
auto hints = hint_store_.Snapshot();
680708
auto t_compile_start = std::chrono::steady_clock::now();
681-
Plan plan = planner_.Compile(task.copy_spec, metastore_, hints);
709+
Plan plan = planner_.Compile(task.copy_spec, metastore_, task.hints);
682710
auto t_compile_end = std::chrono::steady_clock::now();
683711

684712
LOG_DEBUG("Compiled plan:\n{}", plan);

csrc/setu/coordinator/Coordinator.h

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,6 @@ class Coordinator {
106106

107107
void PlanExecuted(CopyOperationId copy_op_id);
108108

109-
void AddHint(setu::planner::hints::CompilerHint hint);
110-
void ClearHints();
111-
112109
void Start();
113110
void Stop();
114111

@@ -149,12 +146,13 @@ class Coordinator {
149146
CoordinatorMessage message;
150147
};
151148

152-
/// @brief Task for the planner containing CopyOperationId, CopySpec, and
153-
/// shared state
149+
/// @brief Task for the planner containing CopyOperationId, CopySpec,
150+
/// shared state, and per-operation hints from the first shard submission.
154151
struct PlannerTask {
155152
CopyOperationId copy_op_id;
156153
CopySpec copy_spec;
157154
CopyOperationStatePtr state; // Shared with Handler's copy_operations_ map
155+
HintStore hints; // Per-operation hints (first-writer-wins)
158156
};
159157

160158
//============================================================================
@@ -233,11 +231,12 @@ class Coordinator {
233231
const setu::commons::messages::DeregisterShardsRequest& request);
234232

235233
/// @brief Unified shard submission logic for both Copy and Pull.
236-
void HandleShardSubmission(const Identity& node_agent_identity,
237-
const RequestId& request_id,
238-
const ShardId& shard_id,
239-
const CopySpec& copy_spec,
240-
std::size_t expected_shards);
234+
void HandleShardSubmission(
235+
const Identity& node_agent_identity, const RequestId& request_id,
236+
const ShardId& shard_id, const CopySpec& copy_spec,
237+
std::size_t expected_shards,
238+
std::vector<setu::planner::hints::CompilerHint> hints,
239+
std::uint64_t hints_fingerprint);
241240

242241
/// Key for tracking copy operations by (src, dst) tensor pair
243242
struct CopyKey {
@@ -262,6 +261,13 @@ class Coordinator {
262261
/// address that
263262
setu::commons::utils::ShardAggregator<CopyKey, CopySpec> shard_aggregator_;
264263

264+
/// Per-operation hint storage: first shard's hints become authoritative
265+
/// (first-writer-wins). Keyed by CopyKey, cleared when aggregation
266+
/// completes.
267+
std::map<CopyKey, std::vector<setu::planner::hints::CompilerHint>>
268+
operation_hints_;
269+
std::map<CopyKey, std::uint64_t> operation_fingerprints_;
270+
265271
/// Maps CopyOperationId to shared CopyOperationState (includes submitters
266272
/// and completion tracking)
267273
std::map<CopyOperationId, CopyOperationStatePtr> copy_operations_;
@@ -285,8 +291,7 @@ class Coordinator {
285291
struct Executor {
286292
Executor(Queue<PlannerTask>& planner_queue,
287293
Queue<OutboxMessage>& outbox_queue, MetaStore& metastore,
288-
Planner& planner, HintStore& hint_store,
289-
OutboxNotifyFn outbox_notify);
294+
Planner& planner, OutboxNotifyFn outbox_notify);
290295

291296
void Start();
292297
void Stop();
@@ -300,7 +305,6 @@ class Coordinator {
300305
Queue<OutboxMessage>& outbox_queue_;
301306
MetaStore& metastore_;
302307
Planner& planner_;
303-
HintStore& hint_store_;
304308
OutboxNotifyFn outbox_notify_;
305309

306310
std::thread thread_;
@@ -313,7 +317,6 @@ class Coordinator {
313317

314318
MetaStore metastore_;
315319
PlannerPtr planner_;
316-
HintStore hint_store_;
317320

318321
// Internal message queues
319322
Queue<InboxMessage> inbox_queue_;

csrc/setu/coordinator/Pybind.cpp

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,18 @@
2121
#include "commons/TorchCommon.h"
2222
#include "coordinator/Coordinator.h"
2323
#include "metastore/Pybind.h"
24-
#include "planner/hints/Hint.h"
2524
//==============================================================================
2625
namespace setu::coordinator {
2726
//==============================================================================
2827
using setu::planner::PlannerPtr;
29-
using setu::planner::hints::CompilerHint;
30-
using setu::planner::hints::RoutingHint;
3128
//==============================================================================
3229
void InitCoordinatorPybindClass(py::module_& m) {
3330
py::class_<Coordinator, std::shared_ptr<Coordinator>>(m, "Coordinator")
3431
.def(py::init<std::size_t, PlannerPtr>(), py::arg("port"),
3532
py::arg("planner"),
3633
"Create a Coordinator with specified port and planner")
3734
.def("start", &Coordinator::Start, "Start the Coordinator loops")
38-
.def("stop", &Coordinator::Stop, "Stop the Coordinator loops")
39-
// TODO: Ideally we'd bind AddHint directly:
40-
// .def("add_hint", &Coordinator::AddHint)
41-
// and let pybind11/stl.h auto-cast RoutingHint -> CompilerHint
42-
// (std::variant). However, this fails at compile time —
43-
// pybind11 can't default-construct the type_caster tuple for
44-
// the member function pointer's arguments. Root cause is
45-
// unclear, investigate later as this is not a blocking concern.
46-
// Using a per-type lambda as a workaround; needs a new overload
47-
// for each hint type added to CompilerHint.
48-
.def(
49-
"add_hint",
50-
[](Coordinator& self, const RoutingHint& hint) {
51-
self.AddHint(CompilerHint{hint});
52-
},
53-
py::arg("hint"), "Add a compiler hint (e.g. RoutingHint)")
54-
.def("clear_hints", &Coordinator::ClearHints, "Clear all compiler hints");
35+
.def("stop", &Coordinator::Stop, "Stop the Coordinator loops");
5536
}
5637
//==============================================================================
5738
} // namespace setu::coordinator

0 commit comments

Comments
 (0)