Skip to content

Commit b899d0c

Browse files
committed
feat: add support for opaque pointer
1 parent b1c1789 commit b899d0c

7 files changed

Lines changed: 140 additions & 13 deletions

File tree

pybind11_weaver/entity/entity_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ def _inject_docstring(code: str, cursor: cindex.Cursor, insert_mode: str):
1313
if not cursor.raw_comment:
1414
return code
1515
if insert_mode == "append":
16-
code += f',R"({cursor.raw_comment})"'
16+
code += f',R"_pb11_weaver({cursor.raw_comment})_pb11_weaver"'
1717
if insert_mode == "last_arg":
1818
pos = code.rfind(")")
19-
code = code[:pos] + f',R"({cursor.raw_comment})"' + code[pos:]
19+
code = code[:pos] + f',R"_pb11_weaver({cursor.raw_comment})_pb11_weaver"' + code[pos:]
2020
return code
2121

2222

pybind11_weaver/entity/funktion.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ def update_stmts(self, pybind11_obj_sym: str) -> List[str]:
2626
code = []
2727
targets = [self.cursor] + self.overloads
2828
for t in targets:
29-
fn_pointer_type = fn.get_fn_pointer_type(t)
3029
code.append(
31-
f"{pybind11_obj_sym}.def(\"{self.name}\",static_cast<{fn_pointer_type}>(&{self.qualified_name()}));")
30+
f"{pybind11_obj_sym}.def(\"{self.name}\",{fn.get_fn_value_expr(t)});")
3231
if self.gu.io_config.gen_docstring:
3332
code[-1] = entity_base._inject_docstring(code[-1], t, "last_arg")
3433
return code

pybind11_weaver/entity/klass.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ def is_static(cursor):
3636
return cursor.is_static_method()
3737

3838
# to support overload, we will cast member function to function pointer
39-
pointer = f"&{scope_full_name}::{cursor.spelling}"
40-
method_pointer_type = fn.get_fn_pointer_type(cursor)
41-
casted_pointer = f"static_cast<{method_pointer_type}>({pointer})"
39+
40+
casted_pointer = fn.get_fn_value_expr(cursor)
4241
if is_static(cursor):
4342
self.body.append(
4443
f"obj.def_static(\"{cursor.spelling}\",{casted_pointer});")
@@ -53,6 +52,15 @@ def is_static(cursor):
5352
self.body[-1], cursor, "last_arg")
5453

5554

55+
def _is_bindable_type(type: cindex.Type):
56+
type = type.get_canonical()
57+
if type.kind in [cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY, cindex.TypeKind.VARIABLEARRAY]:
58+
return False
59+
if fn.warp_type(type, "")[0] is not None:
60+
return False
61+
return True
62+
63+
5664
class ClassEntity(entity_base.Entity):
5765

5866
def __init__(self, gu: gen_unit.GenUnit, cursor: cindex.Cursor):
@@ -126,7 +134,7 @@ def not_operator(cursor):
126134
for cursor in self.cursor.get_children():
127135
if cursor.kind == cindex.CursorKind.FIELD_DECL and \
128136
is_pubic(cursor) and \
129-
cursor.type.kind != cindex.TypeKind.CONSTANTARRAY:
137+
_is_bindable_type(cursor.type):
130138
filed_binder = "def_readwrite"
131139
if cursor.type.is_const_qualified():
132140
filed_binder = "def_readonly"

pybind11_weaver/gen_code.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from .entity import entity_base
66
from . import entity_tree
77
from . import gen_unit
8+
from .utils import fn
89

910
entity_template = """
10-
11+
#ifndef PB11_WEAVER_DISABLE_{entity_struct_name}
1112
template <class Pybind11T> struct {bind_struct_name} : public EntityBase {{
1213
using Pybind11Type = Pybind11T;
1314
@@ -45,6 +46,7 @@
4546
{handle_type} handle;
4647
4748
}};
49+
#endif // PB11_WEAVER_DISABLE_{entity_struct_name}
4850
"""
4951

5052
file_template = """
@@ -67,6 +69,8 @@
6769
* If the returned guard is not called, the guard will call the update function on its destruction.
6870
**/
6971
[[nodiscard]] pybind11_weaver::CallUpdateGuard {decl_fn_name}(pybind11::module & m, const pybind11_weaver::CustomBindingRegistry & registry){{
72+
{create_warped_pointer_bindings}
73+
7074
{create_entity_var_stmts}
7175
7276
auto update_fn = [=](){{
@@ -147,6 +151,7 @@ def gen_code(config_file: str):
147151
pybind11_weaver_header=pybind11_weaver_header,
148152
decl_fn_name=gu.io_config.decl_fn_name,
149153
entity_struct_decls="\n".join(entity_struct_decls),
154+
create_warped_pointer_bindings=gen_wrapped_pointer_code(),
150155
create_entity_var_stmts="\n".join(create_entity_var_stmts),
151156
update_entity_var_stmts="\n".join(update_entity_var_stmts),
152157
)
@@ -156,3 +161,19 @@ def gen_code(config_file: str):
156161
# format file if clang-format found
157162
if shutil.which("clang-format") is not None:
158163
os.system(f"clang-format -i {gu.io_config.output} --style=LLVM")
164+
165+
166+
def gen_wrapped_pointer_code():
167+
wrapped_types = fn.get_wrapped_types()
168+
create_warped_pointer_bindings = []
169+
for type in sorted(wrapped_types):
170+
wrapped_type_binding_code_template = "pybind11_weaver::PointerWrapper<{type}>::FastBind(m,\"{safe_type_name}\");"
171+
safe_type_name = type.replace(" ", "")
172+
safe_type_name = safe_type_name.replace(",", "_")
173+
safe_type_name = safe_type_name.replace("*", "p")
174+
safe_type_name = safe_type_name.replace("(", "6")
175+
safe_type_name = safe_type_name.replace(")", "9")
176+
create_warped_pointer_bindings.append(
177+
wrapped_type_binding_code_template.format(type=type, safe_type_name=safe_type_name))
178+
create_warped_pointer_bindings = "\n".join(create_warped_pointer_bindings)
179+
return create_warped_pointer_bindings

pybind11_weaver/include/pybind11_weaver/pybind11_weaver.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,23 @@
77

88
namespace pybind11_weaver {
99

10+
template <class T> struct PointerWrapper {
11+
static_assert(std::is_pointer<T>::value, "T must be a pointer type");
12+
T ptr;
13+
PointerWrapper(T ptr) : ptr(ptr) {}
14+
operator T() { return ptr; }
15+
static void FastBind(pybind11::module &m, const std::string &name) {
16+
pybind11::class_<PointerWrapper> handle(m, name.c_str());
17+
handle.def("get_ptr", [](PointerWrapper &self) {
18+
return reinterpret_cast<intptr_t>(self.ptr);
19+
});
20+
handle.def("set_ptr", [](PointerWrapper &self, intptr_t ptr) {
21+
self.ptr = reinterpret_cast<T>(ptr);
22+
});
23+
}
24+
};
25+
26+
1027
class CallUpdateGuard {
1128
public:
1229
using Fn = std::function<void(void)>;

pybind11_weaver/utils/fn.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,99 @@
1+
from typing import List
2+
13
from clang import cindex
24

35
from . import scope_list
46

57

6-
def fn_arg_type(cursor: cindex.Cursor):
8+
def fn_arg_type(cursor: cindex.Cursor) -> List[str]:
79
return [param.type.get_canonical().spelling for param in cursor.get_arguments()]
810

911

10-
def fn_ret_type(cursor: cindex.Cursor):
12+
def fn_ret_type(cursor: cindex.Cursor) -> str:
1113
return cursor.result_type.get_canonical().spelling
1214

1315

14-
def get_fn_pointer_type(cursor: cindex.Cursor):
16+
def _get_fn_pointer_type(cursor: cindex.Cursor) -> str:
1517
if cursor.kind == cindex.CursorKind.CXX_METHOD and not cursor.is_static_method():
1618
const_mark = "const" if cursor.is_const_method() else ""
1719
return f"{fn_ret_type(cursor)} ({scope_list.get_full_qualified_name(cursor.semantic_parent)}::*)({','.join(fn_arg_type(cursor))}) {const_mark}"
1820
else:
1921
return f"{fn_ret_type(cursor)} (*)({','.join(fn_arg_type(cursor))})"
22+
23+
24+
def get_fn_pointer(cursor: cindex.Cursor) -> str:
25+
pointer = f"&{scope_list.get_full_qualified_name(cursor)}"
26+
method_pointer_type = _get_fn_pointer_type(cursor)
27+
return f"static_cast<{method_pointer_type}>({pointer})"
28+
29+
30+
__wrapped_db = set()
31+
32+
33+
def get_wrapped_types():
34+
return __wrapped_db
35+
36+
37+
def warp_type(type: cindex.Type, param_name: str):
38+
type = type.get_canonical()
39+
ret = None, param_name
40+
if type.kind == cindex.TypeKind.POINTER:
41+
pointee = type.get_pointee().get_canonical()
42+
# if pointee is a pointer, warp it
43+
if pointee.kind in [cindex.TypeKind.POINTER]:
44+
ret = f"pybind11_weaver::PointerWrapper<{type.spelling}>", param_name
45+
if pointee.kind in [cindex.TypeKind.FUNCTIONPROTO]:
46+
ret = f"std::function<{pointee.spelling}>", f"{param_name}.target<{pointee.spelling}>()"
47+
# if pointee is a incompelete type, warp it
48+
pointee_decl = pointee.get_declaration()
49+
if pointee_decl.kind in [cindex.CursorKind.STRUCT_DECL,
50+
cindex.CursorKind.CLASS_DECL] and not pointee_decl.is_definition():
51+
ret = f"pybind11_weaver::PointerWrapper<{type.spelling}>", param_name
52+
if ret[0] is not None:
53+
__wrapped_db.add(type.spelling)
54+
return ret
55+
56+
57+
__warpper_template = """[]({params}){{
58+
return {ret_expr};
59+
}}"""
60+
61+
62+
def get_fn_wrapper(cursor: cindex.Cursor):
63+
params = []
64+
if cursor.kind != cindex.CursorKind.FUNCTION_DECL:
65+
params.append(f"{cursor.semantic_parent.spelling}& self")
66+
args = []
67+
warp = False
68+
for param in cursor.get_arguments():
69+
param_t = param.type.get_canonical()
70+
param_spelling = param.spelling
71+
if param_spelling == "":
72+
param_spelling = "arg" + str(len(args))
73+
warp_t, arg_use = warp_type(param_t, param_spelling)
74+
if warp_t:
75+
warp = True
76+
params.append(f"{warp_t} {param_spelling}")
77+
else:
78+
params.append(f"{param_t.spelling} {param_spelling}")
79+
args.append(arg_use)
80+
ret_t = cursor.result_type.get_canonical()
81+
ret_expr = f"{cursor.spelling}({','.join(args)})"
82+
if cursor.kind != cindex.CursorKind.FUNCTION_DECL:
83+
ret_expr = f"self.{ret_expr}"
84+
warp_t, _ = warp_type(ret_t, "")
85+
if warp_t:
86+
warp = True
87+
ret_expr = f"{warp_t}({ret_expr})"
88+
if warp:
89+
return __warpper_template.format(params=','.join(params), ret_expr=ret_expr)
90+
else:
91+
return None
92+
93+
94+
def get_fn_value_expr(cursor: cindex.Cursor) -> str:
95+
wrapper = get_fn_wrapper(cursor)
96+
if wrapper:
97+
return wrapper
98+
else:
99+
return get_fn_pointer(cursor)

pybind11_weaver/utils/scope_list.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ def get_full_qualified_scopes(cursor: cindex.Cursor):
88
if cursor.kind == cindex.CursorKind.TRANSLATION_UNIT:
99
return []
1010
values = []
11+
# extern C seems to be a scope with kind of CursorKind.UNEXPOSED_DECL
1112
cursor = cursor.semantic_parent
12-
while cursor is not None and cursor.kind != cindex.CursorKind.TRANSLATION_UNIT:
13+
while cursor is not None and cursor.kind not in [cindex.CursorKind.TRANSLATION_UNIT,
14+
cindex.CursorKind.UNEXPOSED_DECL]:
1315
values.append(cursor.spelling)
1416
cursor = cursor.semantic_parent
1517
values.reverse()

0 commit comments

Comments
 (0)