Skip to content

Commit 8380e6e

Browse files
authored
Fix span lifecycle with smart pointers to prevent use-after-free in async RPC callbacks (#3140)
* Fix span lifecycle with smart pointers to prevent use-after-free in async RPC callbacks (#3068) * Refactor bthread span lifecycle management and optimize span API with smart pointer reuse (#3068) --------- Co-authored-by: lhh <lhh>
1 parent 625843e commit 8380e6e

34 files changed

Lines changed: 729 additions & 292 deletions

docs/cn/rpcz.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,14 @@ bthread_attr_t attr = { BTHREAD_STACKTYPE_NORMAL, BTHREAD_INHERIT_SPAN, NULL };
6666
bthread_start_urgent(&tid, &attr, thread_proc, arg);
6767
```
6868

69-
注意:使用这种方式创建子bthread来发送rpc,请确保rpc在server返回response之前完成,否则可能导致使用被释放的Span对象而出core。
69+
### Span生命周期管理
70+
71+
brpc使用智能指针(`std::shared_ptr`/`std::weak_ptr`)管理Span对象的生命周期,并通过自旋锁保护并发访问,解决了以下问题:
72+
73+
1. **Use-after-free防护**:父Span通过`shared_ptr`持有子Span的强引用,TLS中使用`weak_ptr`存储,确保Span对象在被访问时仍然有效。即使server在子bthread完成前返回response,也不会导致访问已释放的Span对象。
74+
75+
2. **线程安全**:使用自旋锁保护`_client_list``_info`的并发修改,支持多个bthread同时创建子span或添加annotation。
76+
77+
3. **自动生命周期管理**:当父Span销毁时,会自动清理所有子Span(通过`_client_list.clear()`),无需手动管理。
78+
79+
使用`BTHREAD_INHERIT_SPAN`创建子bthread时,不再需要担心Span对象的生命周期问题,可以安全地在异步场景中使用。

src/brpc/builtin/rpcz_service.cpp

Lines changed: 63 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -185,16 +185,43 @@ static void PrintElapse(std::ostream& os, int64_t cur_time,
185185

186186
static void PrintAnnotations(
187187
std::ostream& os, int64_t cur_time, int64_t* last_time,
188-
SpanInfoExtractor** extractors, int num_extr) {
188+
SpanInfoExtractor** extractors, int num_extr, const RpczSpan* span) {
189189
int64_t anno_time;
190190
std::string a;
191+
const char* span_type_str = "Span";
192+
if (span) {
193+
switch (span->type()) {
194+
case SPAN_TYPE_SERVER:
195+
span_type_str = "ServerSpan";
196+
break;
197+
case SPAN_TYPE_CLIENT:
198+
span_type_str = "ClientSpan";
199+
break;
200+
case SPAN_TYPE_BTHREAD:
201+
span_type_str = "BthreadSpan";
202+
break;
203+
}
204+
}
205+
191206
// TODO: Going through all extractors is not strictly correct because
192207
// later extractors may have earlier annotations.
193208
for (int i = 0; i < num_extr; ++i) {
194209
while (extractors[i]->PopAnnotation(cur_time, &anno_time, &a)) {
195210
PrintRealTime(os, anno_time);
196211
PrintElapse(os, anno_time, last_time);
197-
os << ' ' << WebEscape(a);
212+
os << ' ';
213+
if (span) {
214+
const char* short_type = "SPAN";
215+
if (span->type() == SPAN_TYPE_SERVER) {
216+
short_type = "Server";
217+
} else if (span->type() == SPAN_TYPE_CLIENT) {
218+
short_type = "Client";
219+
} else if (span->type() == SPAN_TYPE_BTHREAD) {
220+
short_type = "Bthread";
221+
}
222+
os << '[' << short_type << " SPAN#" << Hex(span->span_id()) << "] ";
223+
}
224+
os << WebEscape(a);
198225
if (a.empty() || butil::back_char(a) != '\n') {
199226
os << '\n';
200227
}
@@ -204,12 +231,12 @@ static void PrintAnnotations(
204231

205232
static bool PrintAnnotationsAndRealTimeSpan(
206233
std::ostream& os, int64_t cur_time, int64_t* last_time,
207-
SpanInfoExtractor** extr, int num_extr) {
234+
SpanInfoExtractor** extr, int num_extr, const RpczSpan* span) {
208235
if (cur_time == 0) {
209236
// the field was not set.
210237
return false;
211238
}
212-
PrintAnnotations(os, cur_time, last_time, extr, num_extr);
239+
PrintAnnotations(os, cur_time, last_time, extr, num_extr, span);
213240
PrintRealTime(os, cur_time);
214241
PrintElapse(os, cur_time, last_time);
215242
return true;
@@ -239,9 +266,10 @@ static void PrintClientSpan(
239266
extr[num_extr++] = server_extr;
240267
}
241268
extr[num_extr++] = &client_extr;
242-
// start_send_us is always set for client spans.
243-
CHECK(PrintAnnotationsAndRealTimeSpan(os, span.start_send_real_us(),
244-
last_time, extr, num_extr));
269+
if (!PrintAnnotationsAndRealTimeSpan(os, span.start_send_real_us(),
270+
last_time, extr, num_extr, &span)) {
271+
os << " start_send_real_us:not-set";
272+
}
245273
const Protocol* protocol = FindProtocol(span.protocol());
246274
const char* protocol_name = (protocol ? protocol->name : "Unknown");
247275
const butil::EndPoint remote_side(butil::int2ip(span.remote_ip()), span.remote_port());
@@ -271,12 +299,12 @@ static void PrintClientSpan(
271299
os << std::endl;
272300

273301
if (PrintAnnotationsAndRealTimeSpan(os, span.sent_real_us(),
274-
last_time, extr, num_extr)) {
275-
os << " Requested(" << span.request_size() << ") [1]" << std::endl;
302+
last_time, extr, num_extr, &span)) {
303+
os << " [Client SPAN#" << Hex(span.span_id()) << "] Requested(" << span.request_size() << ") [1]" << std::endl;
276304
}
277305
if (PrintAnnotationsAndRealTimeSpan(os, span.received_real_us(),
278-
last_time, extr, num_extr)) {
279-
os << " Received response(" << span.response_size() << ")";
306+
last_time, extr, num_extr, &span)) {
307+
os << " [Client SPAN#" << Hex(span.span_id()) << "] Received response(" << span.response_size() << ")";
280308
if (span.base_cid() != 0 && span.ending_cid() != 0) {
281309
int64_t ver = span.ending_cid() - span.base_cid();
282310
if (ver >= 1) {
@@ -289,18 +317,18 @@ static void PrintClientSpan(
289317
}
290318

291319
if (PrintAnnotationsAndRealTimeSpan(os, span.start_parse_real_us(),
292-
last_time, extr, num_extr)) {
293-
os << " Processing the response in a new bthread" << std::endl;
320+
last_time, extr, num_extr, &span)) {
321+
os << " [Client SPAN#" << Hex(span.span_id()) << "] Processing the response in a new bthread" << std::endl;
294322
}
295323

296324
if (PrintAnnotationsAndRealTimeSpan(
297325
os, span.start_callback_real_us(),
298-
last_time, extr, num_extr)) {
299-
os << (span.async() ? " Enter user's done" : " Back to user's callsite") << std::endl;
326+
last_time, extr, num_extr, &span)) {
327+
os << " [Client SPAN#" << Hex(span.span_id()) << "] " << (span.async() ? " Enter user's done" : " Back to user's callsite") << std::endl;
300328
}
301329

302330
PrintAnnotations(os, std::numeric_limits<int64_t>::max(),
303-
last_time, extr, num_extr);
331+
last_time, extr, num_extr, &span);
304332
}
305333

306334
static void PrintClientSpan(std::ostream& os,const RpczSpan& span,
@@ -318,7 +346,15 @@ static void PrintBthreadSpan(std::ostream& os, const RpczSpan& span, int64_t* la
318346
extr[num_extr++] = server_extr;
319347
}
320348
extr[num_extr++] = &client_extr;
321-
PrintAnnotations(os, std::numeric_limits<int64_t>::max(), last_time, extr, num_extr);
349+
350+
// Print span id for bthread span context identification
351+
os << " [Bthread SPAN#" << Hex(span.span_id());
352+
if (span.parent_span_id() != 0) {
353+
os << " parent#" << Hex(span.parent_span_id());
354+
}
355+
os << "] ";
356+
357+
PrintAnnotations(os, std::numeric_limits<int64_t>::max(), last_time, extr, num_extr, &span);
322358
}
323359

324360
static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
@@ -348,16 +384,16 @@ static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
348384
os << std::endl;
349385
if (PrintAnnotationsAndRealTimeSpan(
350386
os, span.start_parse_real_us(),
351-
&last_time, extr, ARRAY_SIZE(extr))) {
352-
os << " Processing the request in a new bthread" << std::endl;
387+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
388+
os << " [Server SPAN#" << Hex(span.span_id()) << "] Processing the request in a new bthread" << std::endl;
353389
}
354390

355391
bool entered_user_method = false;
356392
if (PrintAnnotationsAndRealTimeSpan(
357393
os, span.start_callback_real_us(),
358-
&last_time, extr, ARRAY_SIZE(extr))) {
394+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
359395
entered_user_method = true;
360-
os << " Enter " << WebEscape(span.full_method_name()) << std::endl;
396+
os << " [Server SPAN#" << Hex(span.span_id()) << "] Enter " << WebEscape(span.full_method_name()) << std::endl;
361397
}
362398

363399
const int nclient = span.client_spans_size();
@@ -372,22 +408,22 @@ static void PrintServerSpan(std::ostream& os, const RpczSpan& span,
372408

373409
if (PrintAnnotationsAndRealTimeSpan(
374410
os, span.start_send_real_us(),
375-
&last_time, extr, ARRAY_SIZE(extr))) {
411+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
376412
if (entered_user_method) {
377-
os << " Leave " << WebEscape(span.full_method_name()) << std::endl;
413+
os << " [Server SPAN#" << Hex(span.span_id()) << "] Leave " << WebEscape(span.full_method_name()) << std::endl;
378414
} else {
379-
os << " Responding" << std::endl;
415+
os << " [Server SPAN#" << Hex(span.span_id()) << "] Responding" << std::endl;
380416
}
381417
}
382418

383419
if (PrintAnnotationsAndRealTimeSpan(
384420
os, span.sent_real_us(),
385-
&last_time, extr, ARRAY_SIZE(extr))) {
386-
os << " Responded(" << span.response_size() << ')' << std::endl;
421+
&last_time, extr, ARRAY_SIZE(extr), &span)) {
422+
os << " [Server SPAN#" << Hex(span.span_id()) << "] Responded(" << span.response_size() << ')' << std::endl;
387423
}
388424

389425
PrintAnnotations(os, std::numeric_limits<int64_t>::max(),
390-
&last_time, extr, ARRAY_SIZE(extr));
426+
&last_time, extr, ARRAY_SIZE(extr), &span);
391427
}
392428

393429
class RpczSpanFilter : public SpanFilter {

src/brpc/channel.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "brpc/rdma/rdma_helper.h"
3939
#include "brpc/policy/esp_authenticator.h"
4040
#include "brpc/transport_factory.h"
41+
#include "brpc/details/controller_private_accessor.h"
4142

4243
namespace brpc {
4344

@@ -502,7 +503,7 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
502503
}
503504
cntl->set_used_by_rpc();
504505

505-
if (cntl->_sender == NULL && IsTraceable(Span::tls_parent())) {
506+
if (cntl->_sender == NULL && IsTraceable(Span::tls_parent().get())) {
506507
const int64_t start_send_us = butil::cpuwide_time_us();
507508
std::string method_name;
508509
if (_get_method_name) {
@@ -513,13 +514,16 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
513514
const static std::string NULL_METHOD_STR = "null-method";
514515
method_name = NULL_METHOD_STR;
515516
}
516-
Span* span = Span::CreateClientSpan(
517+
std::shared_ptr<Span> span = Span::CreateClientSpan(
517518
method_name, start_send_real_us - start_send_us);
518-
span->set_log_id(cntl->log_id());
519-
span->set_base_cid(correlation_id);
520-
span->set_protocol(_options.protocol);
521-
span->set_start_send_us(start_send_us);
522-
cntl->_span = span;
519+
if (span) {
520+
ControllerPrivateAccessor accessor(cntl);
521+
span->set_log_id(cntl->log_id());
522+
span->set_base_cid(correlation_id);
523+
span->set_protocol(_options.protocol);
524+
span->set_start_send_us(start_send_us);
525+
accessor.set_span(span);
526+
}
523527
}
524528
// Override some options if they haven't been set by Controller
525529
if (cntl->timeout_ms() == UNSET_MAGIC_NUM) {
@@ -620,9 +624,7 @@ void Channel::CallMethod(const google::protobuf::MethodDescriptor* method,
620624
// be woken up by callback when RPC finishes (succeeds or still
621625
// fails after retry)
622626
Join(correlation_id);
623-
if (cntl->_span) {
624-
cntl->SubmitSpan();
625-
}
627+
cntl->SubmitSpan();
626628
cntl->OnRPCEnd(butil::gettimeofday_us());
627629
}
628630
}

0 commit comments

Comments
 (0)