Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions csrc/setu/coordinator/Coordinator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ Coordinator::Coordinator(std::size_t port, PlannerPtr planner,

auto outbox_notify = [this]() { gateway_->NotifyOutbox(); };

handler_ =
std::make_unique<Handler>(inbox_queue_, outbox_queue_, metastore_,
planner_queue_, outbox_notify, handler_sink);
handler_ = std::make_unique<Handler>(inbox_queue_, outbox_queue_, metastore_,
planner_queue_, outbox_notify,
handler_sink);
executor_ =
std::make_unique<Executor>(planner_queue_, outbox_queue_, metastore_,
*planner_, outbox_notify, executor_sink);
Expand Down
6 changes: 3 additions & 3 deletions csrc/setu/coordinator/Coordinator.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ class Coordinator {
Queue<InboxMessage> inbox_queue_;
Queue<OutboxMessage> outbox_queue_;

/// Queue of PlannerTasks (CopyOperationId + CopySpec) for the Executor to
/// compile and dispatch
Queue<PlannerTask> planner_queue_;
/// Queue of ExecutorTasks for the Executor to process (compile+dispatch or
/// onboarding)
Queue<ExecutorTask> planner_queue_;

std::unique_ptr<Gateway> gateway_;
std::unique_ptr<Handler> handler_;
Expand Down
96 changes: 60 additions & 36 deletions csrc/setu/coordinator/Executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
//==============================================================================
namespace setu::coordinator {
//==============================================================================
using setu::commons::enums::ErrorCode;
using setu::commons::messages::ExecuteRequest;
using setu::commons::messages::OnboardNodeAgentResponse;
using setu::planner::Plan;
//==============================================================================
Executor::Executor(Queue<PlannerTask>& planner_queue,
Executor::Executor(Queue<ExecutorTask>& planner_queue,
Queue<OutboxMessage>& outbox_queue, MetaStore& metastore,
Planner& planner, OutboxNotifyFn outbox_notify,
setu::telemetry::MetricsSinkPtr metrics_sink)
Expand Down Expand Up @@ -62,52 +64,74 @@ void Executor::Loop() {
running_ = true;
while (running_) {
try {
PlannerTask task = planner_queue_.pull();
auto t_after_dequeue = std::chrono::steady_clock::now();
ExecutorTask task = planner_queue_.pull();

std::visit(
[&](auto&& alt) {
using T = std::decay_t<decltype(alt)>;
if constexpr (std::is_same_v<T, PlannerTask>) {
HandlePlannerTask(std::move(alt));
} else if constexpr (std::is_same_v<T, OnboardingTask>) {
HandleOnboardingTask(std::move(alt));
}
},
std::move(task));
} catch (const boost::concurrent::sync_queue_is_closed&) {
return;
}
}
}

LOG_DEBUG("Executor received task for copy_op_id: {}", task.copy_op_id);
void Executor::HandlePlannerTask(PlannerTask task) {
auto t_after_dequeue = std::chrono::steady_clock::now();

auto result = planner_.Compile(task.copy_spec, metastore_, task.hints,
task.copy_op_id);
Plan plan = std::move(result.plan);
LOG_DEBUG("Executor received task for copy_op_id: {}", task.copy_op_id);

// Submit compilation metrics
if (metrics_sink_ && metrics_sink_->IsEnabled()) {
metrics_sink_->Submit(
setu::telemetry::MetricsMessage{std::move(result.metrics)});
}
auto result = planner_.Compile(task.copy_spec, metastore_, task.hints,
task.copy_op_id);
Plan plan = std::move(result.plan);

LOG_DEBUG("Compiled plan:\n{}", plan);
// Submit compilation metrics
if (metrics_sink_ && metrics_sink_->IsEnabled()) {
metrics_sink_->Submit(
setu::telemetry::MetricsMessage{std::move(result.metrics)});
}

// Fragment the plan to into per-node fragments
auto fragments = plan.Fragments();
LOG_DEBUG("Compiled plan:\n{}", plan);

// Send ExecuteRequest to each node agent
for (auto& [node_id, node_plan] : fragments) {
Identity node_identity = boost::uuids::to_string(node_id) + "_dealer";
// Fragment the plan to into per-node fragments
auto fragments = plan.Fragments();

ExecuteRequest execute_request(task.copy_op_id, std::move(node_plan));
// Send ExecuteRequest to each node agent
for (auto& [node_id, node_plan] : fragments) {
Identity node_identity = boost::uuids::to_string(node_id) + "_dealer";

PushOutbox(OutboxMessage{node_identity, execute_request});
}
ExecuteRequest execute_request(task.copy_op_id, std::move(node_plan));

// Set expected responses
// memory order release so Handler thread can pick it up (using memory
// order aqcuire)
task.state->expected_responses.store(fragments.size(),
std::memory_order_release);
PushOutbox(OutboxMessage{node_identity, execute_request});
}

auto t_end = std::chrono::steady_clock::now();
auto to_us = [](auto d) {
return std::chrono::duration_cast<std::chrono::microseconds>(d).count();
};
LOG_INFO("Executor: copy_op_id={}, total={}us", task.copy_op_id,
to_us(t_end - t_after_dequeue));
// Set expected responses
// memory order release so Handler thread can pick it up (using memory
// order aqcuire)
task.state->expected_responses.store(fragments.size(),
std::memory_order_release);

auto t_end = std::chrono::steady_clock::now();
auto to_us = [](auto d) {
return std::chrono::duration_cast<std::chrono::microseconds>(d).count();
};
LOG_INFO("Executor: copy_op_id={}, total={}us", task.copy_op_id,
to_us(t_end - t_after_dequeue));
}

} catch (const boost::concurrent::sync_queue_is_closed&) {
return;
}
}
void Executor::HandleOnboardingTask(OnboardingTask task) {
LOG_INFO("Executor processing OnboardingTask ({} devices)",
task.register_sets.size());
planner_.AddBackendRegisterSets(task.register_sets);

OnboardNodeAgentResponse response(task.request_id, ErrorCode::kSuccess);
PushOutbox(OutboxMessage{task.node_agent_identity, response});
}
//==============================================================================
} // namespace setu::coordinator
Expand Down
7 changes: 5 additions & 2 deletions csrc/setu/coordinator/Executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ using setu::planner::Planner;
/// ExecuteRequests through the outbox queue.
class Executor {
public:
Executor(Queue<PlannerTask>& planner_queue,
Executor(Queue<ExecutorTask>& planner_queue,
Queue<OutboxMessage>& outbox_queue, MetaStore& metastore,
Planner& planner, OutboxNotifyFn outbox_notify,
setu::telemetry::MetricsSinkPtr metrics_sink);
Expand All @@ -50,7 +50,10 @@ class Executor {

void PushOutbox(OutboxMessage msg);

Queue<PlannerTask>& planner_queue_;
void HandlePlannerTask(PlannerTask task);
void HandleOnboardingTask(OnboardingTask task);

Queue<ExecutorTask>& planner_queue_;
Queue<OutboxMessage>& outbox_queue_;
MetaStore& metastore_;
Planner& planner_;
Expand Down
17 changes: 16 additions & 1 deletion csrc/setu/coordinator/Handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ using setu::commons::utils::AggregationParticipant;
//==============================================================================
Handler::Handler(Queue<InboxMessage>& inbox_queue,
Queue<OutboxMessage>& outbox_queue, MetaStore& metastore,
Queue<PlannerTask>& planner_queue,
Queue<ExecutorTask>& planner_queue,
OutboxNotifyFn outbox_notify,
setu::telemetry::MetricsSinkPtr metrics_sink)
: inbox_queue_(inbox_queue),
Expand Down Expand Up @@ -88,6 +88,8 @@ void Handler::Loop() {
HandleGetTensorSpecRequest(inbox_msg.node_agent_identity, msg);
} else if constexpr (std::is_same_v<T, DeregisterShardsRequest>) {
HandleDeregisterShardsRequest(inbox_msg.node_agent_identity, msg);
} else if constexpr (std::is_same_v<T, OnboardNodeAgentRequest>) {
HandleOnboardNodeAgentRequest(inbox_msg.node_agent_identity, msg);
} else {
LOG_WARNING("Handler: Unknown message type (index={})",
inbox_msg.request.index());
Expand Down Expand Up @@ -434,6 +436,19 @@ void Handler::HandleDeregisterShardsRequest(
tensor_names.size(), node_agent_identity);
}
}

void Handler::HandleOnboardNodeAgentRequest(
const Identity& node_agent_identity,
const OnboardNodeAgentRequest& request) {
LOG_INFO("Coordinator received OnboardNodeAgentRequest from {} ({} devices)",
node_agent_identity, request.register_sets.size());

// Route to Executor thread — Planner is only accessed from Executor.
// Executor sends the response after processing.
planner_queue_.push(
OnboardingTask{node_agent_identity, request.request_id,
request.register_sets});
}
//==============================================================================
} // namespace setu::coordinator
//==============================================================================
7 changes: 5 additions & 2 deletions csrc/setu/coordinator/Handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using setu::commons::NodeId;
using setu::commons::messages::DeregisterShardsRequest;
using setu::commons::messages::ExecuteResponse;
using setu::commons::messages::GetTensorSpecRequest;
using setu::commons::messages::OnboardNodeAgentRequest;
using setu::commons::messages::RegisterTensorShardRequest;
using setu::commons::messages::SubmitCopyRequest;
using setu::commons::messages::SubmitPullRequest;
Expand All @@ -50,7 +51,7 @@ using setu::metastore::MetaStore;
class Handler {
public:
Handler(Queue<InboxMessage>& inbox_queue, Queue<OutboxMessage>& outbox_queue,
MetaStore& metastore, Queue<PlannerTask>& planner_queue,
MetaStore& metastore, Queue<ExecutorTask>& planner_queue,
OutboxNotifyFn outbox_notify,
setu::telemetry::MetricsSinkPtr metrics_sink);

Expand All @@ -75,14 +76,16 @@ class Handler {
const GetTensorSpecRequest& request);
void HandleDeregisterShardsRequest(const Identity& node_agent_identity,
const DeregisterShardsRequest& request);
void HandleOnboardNodeAgentRequest(const Identity& node_agent_identity,
const OnboardNodeAgentRequest& request);

/// @brief Unified shard submission logic for both Copy and Pull.
void HandleShardSubmission(DispatchManager::ShardSubmission submission);

Queue<InboxMessage>& inbox_queue_;
Queue<OutboxMessage>& outbox_queue_;
MetaStore& metastore_;
Queue<PlannerTask>& planner_queue_;
Queue<ExecutorTask>& planner_queue_;
OutboxNotifyFn outbox_notify_;
setu::telemetry::MetricsSinkPtr metrics_sink_;

Expand Down
16 changes: 16 additions & 0 deletions csrc/setu/coordinator/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
//==============================================================================
#include "commons/datatypes/CopySpec.h"
#include "messaging/Messages.h"
#include "planner/Participant.h"
#include "planner/RegisterSet.h"
#include "planner/hints/HintStore.h"
//==============================================================================
namespace setu::coordinator {
Expand Down Expand Up @@ -108,5 +110,19 @@ struct PlannerTask {
HintStore hints; // Per-operation hints (first-writer-wins)
};
//==============================================================================
// Onboarding task
//==============================================================================

/// @brief Task to add register sets to the planner backend.
struct OnboardingTask {
Identity node_agent_identity;
RequestId request_id;
std::unordered_map<setu::planner::Participant, setu::planner::RegisterSet>
register_sets;
};

/// @brief Variant of tasks the Executor can process.
using ExecutorTask = std::variant<PlannerTask, OnboardingTask>;
//==============================================================================
} // namespace setu::coordinator
//==============================================================================
6 changes: 4 additions & 2 deletions csrc/setu/messaging/Messages.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "messaging/GetTensorSelectionResponse.h"
#include "messaging/GetTensorSpecRequest.h"
#include "messaging/GetTensorSpecResponse.h"
#include "messaging/OnboardNodeAgentRequest.h"
#include "messaging/OnboardNodeAgentResponse.h"
#include "messaging/RegisterTensorShardCoordinatorResponse.h"
#include "messaging/RegisterTensorShardNodeAgentResponse.h"
#include "messaging/RegisterTensorShardRequest.h"
Expand All @@ -60,14 +62,14 @@ using ClientRequest =
using NodeAgentRequest =
std::variant<RegisterTensorShardRequest, SubmitCopyRequest,
SubmitPullRequest, ExecuteResponse, GetTensorSpecRequest,
DeregisterShardsRequest>;
DeregisterShardsRequest, OnboardNodeAgentRequest>;

/// @brief All messages from Coordinator to NodeAgent
using CoordinatorMessage =
std::variant<AllocateTensorRequest, CopyOperationFinishedRequest,
ExecuteRequest, RegisterTensorShardCoordinatorResponse,
SubmitCopyResponse, WaitForCopyResponse, GetTensorSpecResponse,
DeregisterShardsResponse>;
DeregisterShardsResponse, OnboardNodeAgentResponse>;

using Request =
std::variant<RegisterTensorShardRequest, SubmitCopyRequest,
Expand Down
41 changes: 41 additions & 0 deletions csrc/setu/messaging/OnboardNodeAgentRequest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
//==============================================================================
// Copyright (c) 2025 Vajra Team; Georgia Institute of Technology; Microsoft
// Corporation.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//==============================================================================
#include "messaging/OnboardNodeAgentRequest.h"
//==============================================================================
namespace setu::commons::messages {
//==============================================================================
using setu::commons::utils::BinaryReader;
using setu::commons::utils::BinaryWriter;
//==============================================================================

void OnboardNodeAgentRequest::Serialize(BinaryBuffer& buffer) const {
BinaryWriter writer(buffer);
writer.WriteFields(request_id, register_sets);
}

OnboardNodeAgentRequest OnboardNodeAgentRequest::Deserialize(
const BinaryRange& range) {
BinaryReader reader(range);
auto [request_id_val, register_sets_val] =
reader.ReadFields<RequestId,
std::unordered_map<Participant, RegisterSet>>();
return OnboardNodeAgentRequest(request_id_val, std::move(register_sets_val));
}

//==============================================================================
} // namespace setu::commons::messages
//==============================================================================
Loading