Skip to content

Commit b138e26

Browse files
committed
feat: add trampoline support
1 parent cfaadbf commit b138e26

12 files changed

Lines changed: 656 additions & 107 deletions

File tree

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
_# Pybind11 Weaver: Python Binding Code Generator
1+
# Pybind11 Weaver: Python Binding Code Generator
22

33
**Pybind11 Weaver** is a powerful code generator designed to automate the generation of pybind11 code from C++ header
44
files. It streamlines the process of creating Python bindings, enabling users to focus on writing critical pybind11 code
@@ -47,14 +47,14 @@ the capabilities of Pybind11 Weaver when working with large C++ only libraries.
4747
- [ ] Binding for Operator overloading
4848
- [x] Binding for Class method, method overloading, static method, static method overloading, constructor, constructor
4949
overloading, class field
50+
- [x] Trampoline class for virtual function
5051
- [x] Binding for concreate template instance, that includes: implicit(explicit) class(struct) template instantiation,
5152
full class(struct) template specialization, extern function template instance declaration.
5253
- [x] Support class inheritance hierarchy
5354
- [x] Auto ignore symbols by : Linkage (e.g. `static`), visibility (e.g. `visibility=hidden`), member access
5455
control (e.g. `private`, `protected`)
5556
- [x] Docstring generation from c++ doxygen style comment
5657
- [x] Namespace hierarchy to Python module hierarchy
57-
- [ ] Trampoline class for virtual function
5858
- [x] Dynamic update/disable binding by API call.
5959
- [x] Static update/disable binding by define macro (Mainly used to disable wrong binding code to avoid compilation
6060
error)

pybind11_weaver/entity/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,6 @@
1515
_KIND = cindex.CursorKind
1616

1717

18-
def _is_concreate_template(cursor: cindex.Cursor):
19-
if common.is_concreate_template(cursor):
20-
_logger.warning(f"Concreate template not supported `{cursor.canonical.displayname}` ")
21-
return True
22-
return False
23-
24-
2518
def create_entity(gu: gen_unit.GenUnit, cursor: cindex.Cursor):
2619
"""Create an entity without parent.
2720

pybind11_weaver/entity/entity_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,7 @@ def dependency(self) -> List[str]:
112112
The dependency is the reference name of the entity.
113113
"""
114114
return []
115+
116+
def top_level_extra_code(self) -> str:
117+
"""Entity may inject extra code into the generated binding struct."""
118+
return ""

pybind11_weaver/entity/klass/klass.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pybind11_weaver.entity import entity_base
99
from pybind11_weaver.utils import common, scope_list
1010
from pybind11_weaver import gen_unit
11-
from pybind11_weaver.entity.klass import method, field
11+
from pybind11_weaver.entity.klass import method, field, trampoline
1212

1313
_logger = logging.getLogger(__name__)
1414

@@ -19,6 +19,7 @@ def __init__(self, gu: gen_unit.GenUnit, cursor: cindex.Cursor):
1919
entity_base.Entity.__init__(self, gu, cursor)
2020
assert cursor.kind in [cindex.CursorKind.CXCursor_ClassDecl, cindex.CursorKind.CXCursor_StructDecl]
2121
self.extra_methods_codes = []
22+
self._top_level_extra = []
2223
self._dependency = set()
2324

2425
@property
@@ -97,6 +98,10 @@ def default_pybind11_type_str(self) -> str:
9798

9899
if not common.is_type_deletable(self.cursor.type):
99100
t_param_list.append(f"std::unique_ptr<{self.reference_name()},pybind11::nodelete>")
101+
tramp = trampoline.Trampoline(self)
102+
if tramp is not None:
103+
t_param_list.append(tramp.get_trampoline_cls_name())
104+
self._top_level_extra.append(tramp.get_defs())
100105
base_cursor = None
101106
for cursor in self.cursor.get_children():
102107
if cursor.kind == cindex.CursorKind.CXCursor_CXXBaseSpecifier:
@@ -136,3 +141,8 @@ def could_user_class_export(self, type: cindex.Type):
136141
def dependency(self) -> List[str]:
137142
self.default_pybind11_type_str() # force update dependency
138143
return list(self._dependency)
144+
145+
def top_level_extra_code(self) -> str:
146+
"""Entity may inject extra code into the generated binding struct."""
147+
self.default_pybind11_type_str() # force update possible trampoline
148+
return ",".join(self._top_level_extra)

pybind11_weaver/entity/klass/method.py

Lines changed: 4 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -57,72 +57,23 @@ def get_call_stmt(self):
5757
return _call_bind_method.format(method_identifier=self.identifier_name)
5858

5959

60-
def is_explicit_instantiation(cursor: cindex.Cursor):
61-
source_file = cursor.location.file
62-
source_line = cursor.location.line
63-
# read the source_line-th line's content in the source file
64-
with open(source_file.name, 'r') as f:
65-
source_code = f.read()
66-
line_content = source_code.splitlines()[source_line - 1]
67-
68-
return 'template class' in line_content or 'template struct' in line_content
69-
70-
71-
def _get_template_param_arg_pair(template_cursor: cindex.Cursor, specialized_cursor: cindex.Cursor) -> Optional[
72-
List[str]]:
73-
decls = []
74-
index = 0
75-
for cursor in template_cursor.get_children():
76-
if cursor.kind == cindex.CursorKind.CXCursor_TemplateTypeParameter:
77-
assert specialized_cursor.get_template_argument_kind(
78-
index) == cindex.TemplateArgumentKind.CXTemplateArgumentKind_Type
79-
decls.append(f"using {cursor.spelling} = {specialized_cursor.get_template_argument_type(index).spelling};")
80-
index += 1
81-
elif cursor.kind == cindex.CursorKind.CXCursor_NonTypeTemplateParameter:
82-
if specialized_cursor.get_template_argument_kind(
83-
index) == cindex.TemplateArgumentKind.CXTemplateArgumentKind_Integral:
84-
decls.append(
85-
f"static constexpr int {cursor.spelling} = {specialized_cursor.get_template_argument_value(index)};")
86-
index += 1
87-
else:
88-
_logger.warning("Only Type and int template parameter supported for now")
89-
return None
90-
return decls
91-
92-
9360
class GenMethod:
9461

9562
def __init__(self, kls_entity: "klass.KlassEntity"):
9663
self.kls_entity = kls_entity
9764
self.added_method: Dict[str, List[Method]] = collections.defaultdict(list)
9865

99-
@staticmethod
100-
def is_virtual(cursor: cindex.Cursor):
101-
if cursor.is_virtual_method() or cursor.is_pure_virtual_method():
102-
_logger.info(
103-
f"virtual method {fn.fn_ref_name(cursor)} at at {cursor.location.file}:{cursor.location.line} is not fully supported yet.")
104-
return True
105-
return False
106-
10766
def run(self, pybind11_obj_sym: str) -> Tuple[List[str], List[str]]:
10867
"""Return [binding_codes,extra_codes]"""
10968
codes = []
11069
extra_codes: List[str] = []
11170
kls_entity = self.kls_entity
11271
methods = []
11372

114-
root_cursor = kls_entity.cursor
115-
if common.is_concreate_template(kls_entity.cursor):
116-
template_cursor = pylibclang._C.clang_getSpecializedCursorTemplate(kls_entity.cursor)
117-
template_cursor._tu = kls_entity.cursor._tu # keep compatible with cindex and keep tu alive
118-
using_decls = _get_template_param_arg_pair(template_cursor, kls_entity.cursor)
119-
if using_decls is None:
120-
return [], []
121-
else:
122-
extra_codes.append("\n".join(using_decls))
123-
if root_cursor.location.file.name == kls_entity.gu.unsaved_file[0] or is_explicit_instantiation(
124-
root_cursor):
125-
root_cursor = template_cursor
73+
root_cursor, using_decls, _ = common.get_def_cls_cursor(kls_entity.cursor)
74+
if using_decls is None:
75+
return [], []
76+
extra_codes.append("\n".join(using_decls))
12677
for cursor in root_cursor.get_children():
12778
if cursor.kind == cindex.CursorKind.CXCursor_CXXMethod and kls_entity.could_member_export(
12879
cursor) and not common.is_operator_overload(cursor):
@@ -138,7 +89,6 @@ def run(self, pybind11_obj_sym: str) -> Tuple[List[str], List[str]]:
13889
else:
13990
unique_name = bind_name + str(len(self.added_method[bind_name]))
14091
break
141-
self.is_virtual(cursor) # print warning
14292
disable_mark = f"PB11_WEAVER_DISABLE_{self.kls_entity.get_pb11weaver_struct_name()}_{unique_name}"
14393
methods.append(Method(cursor, kls_entity.gu.io_config.gen_docstring, bind_name, unique_name,
14494
disable_mark))
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
from typing import Union, Set, List
2+
3+
from pylibclang import cindex
4+
5+
from pybind11_weaver.utils import common
6+
7+
_tramp_method = """
8+
#ifndef PYBIND11_DISABLE_OVERRIDE_{disable_mark}
9+
{ret_t} {method_name}({params}) {qualifier} override {{
10+
using _PB11_WR_RET_TYPE = {ret_t};
11+
using _PB11_WR_CONCREATE_TYPE = {concreate_ref};
12+
{override_type}(
13+
_PB11_WR_RET_TYPE,
14+
_PB11_WR_CONCREATE_TYPE,
15+
{method_name},
16+
{args}
17+
);
18+
}}
19+
#endif // PYBIND11_DISABLE_OVERRIDE_{disable_mark}
20+
"""
21+
22+
_trampoline = """
23+
template<class = void>
24+
class PyTramp{python_name}{nest_level} : public {base_ref_name} {{
25+
public:
26+
using _PB11_WR_BaseT = {base_ref_name};
27+
using _PB11_WR_BaseT::_PB11_WR_BaseT;
28+
{decls}
29+
{tramp_methods}
30+
}};
31+
"""
32+
33+
34+
class Virtuals:
35+
def __init__(self):
36+
self.virtuals: List[cindex.Cursor] = []
37+
self.pure_virtuals: List[cindex.Cursor] = []
38+
39+
self.sigs: Set[str] = set()
40+
self.template_decls: List[str] = []
41+
self.subs = dict()
42+
43+
self.base: Union[Virtuals, None] = None
44+
45+
def create_base(self):
46+
assert self.base is None
47+
self.base = Virtuals()
48+
self.base.sigs = self.sigs
49+
return self.base
50+
51+
def is_empty(self):
52+
if self.base is not None:
53+
base_empty = self.base.is_empty()
54+
else:
55+
base_empty = True
56+
return base_empty and len(self.virtuals) == 0 and len(self.pure_virtuals) == 0
57+
58+
def add_pure_virtual(self, cursor: cindex.Cursor):
59+
if self._try_add_sig(cursor):
60+
self.pure_virtuals.append(cursor)
61+
62+
def add_virtual(self, cursor: cindex.Cursor):
63+
if self._try_add_sig(cursor):
64+
self.virtuals.append(cursor)
65+
66+
def force_add_sig(self, cursor: cindex.Cursor):
67+
self.sigs.add(self._get_sig(cursor))
68+
69+
def _get_sig(self, cursor: cindex.Cursor):
70+
ret_t = common.safe_type_reference(cursor.result_type, self.subs)
71+
arg_t = [common.safe_type_reference(arg.type, self.subs) for arg in cursor.get_arguments()]
72+
fn_name = cursor.spelling
73+
sig = f"{ret_t} {fn_name}({','.join(arg_t)})"
74+
return sig
75+
76+
def _try_add_sig(self, cursor: cindex.Cursor):
77+
sig = self._get_sig(cursor)
78+
if sig in self.sigs:
79+
return False
80+
self.sigs.add(sig)
81+
return True
82+
83+
84+
class Trampoline:
85+
86+
def __new__(cls, entity):
87+
cursor = entity.cursor
88+
if common.is_marked_final(entity.cursor):
89+
return None
90+
virt = Virtuals()
91+
cls.detect_all_virtual_methods(cursor, virt)
92+
if virt.is_empty():
93+
return None
94+
obj = super().__new__(cls)
95+
obj._virt = virt
96+
return obj
97+
98+
def __init__(self, entity):
99+
self.entity = entity
100+
101+
def _get_method(self, cursor: cindex.Cursor, override_type: str, concreate_ref: str):
102+
ret_t = common.safe_type_reference(cursor.result_type)
103+
method_name = cursor.spelling
104+
params_t = [f"{common.safe_type_reference(p.type)}" for p in cursor.get_arguments()]
105+
args = [p.spelling if p.spelling != "" else f"arg{i}" for i, p in enumerate(cursor.get_arguments())]
106+
last_right_paren = cursor.type.spelling.rfind(")")
107+
qualifier = cursor.type.spelling[last_right_paren + 1:]
108+
return _tramp_method.format(
109+
disable_mark=common.type_python_name(concreate_ref + cursor.type.spelling),
110+
ret_t=ret_t,
111+
method_name=method_name,
112+
params=", ".join(f"{p_t} {a}" for p_t, a in zip(params_t, args)),
113+
qualifier=qualifier,
114+
override_type=override_type,
115+
concreate_ref=concreate_ref,
116+
args=", ".join(args)
117+
)
118+
119+
def get_virt_def(self, virt: Virtuals, nest_level: int) -> str:
120+
ret = ""
121+
python_name = self.entity.name
122+
concreate_ref = self.entity.reference_name()
123+
124+
def get_nest_str(nest_level: int) -> str:
125+
return "" if nest_level == 0 else str(nest_level)
126+
127+
if virt.base is not None and not virt.base.is_empty():
128+
ret = self.get_virt_def(virt.base, nest_level + 1)
129+
base_ref_name = f"PyTramp{self.entity.name}{get_nest_str(nest_level + 1)}<>"
130+
else:
131+
base_ref_name = self.entity.reference_name()
132+
methods = []
133+
for cursor in virt.virtuals:
134+
methods.append(self._get_method(cursor, "PYBIND11_OVERRIDE", concreate_ref))
135+
for cursor in virt.pure_virtuals:
136+
methods.append(self._get_method(cursor, "PYBIND11_OVERRIDE_PURE", concreate_ref))
137+
ret = ret + _trampoline.format(
138+
python_name=python_name,
139+
base_ref_name=base_ref_name,
140+
decls="\n".join(virt.template_decls),
141+
tramp_methods="\n".join(methods),
142+
nest_level=get_nest_str(nest_level)
143+
)
144+
return ret
145+
146+
def get_defs(self) -> str:
147+
return self.get_virt_def(self._virt, 0)
148+
149+
def get_trampoline_cls_name(self):
150+
return f"PyTramp{self.entity.name}<>"
151+
152+
@staticmethod
153+
def detect_all_virtual_methods(cursor: cindex.Cursor, to_update: Virtuals):
154+
cursor, decls, subs = common.get_def_cls_cursor(cursor)
155+
if decls is None:
156+
return
157+
to_update.template_decls = decls
158+
to_update.subs = subs
159+
base_found = None
160+
for c in cursor.get_children():
161+
if c.kind == cindex.CursorKind.CXCursor_CXXMethod:
162+
if c.is_virtual_method():
163+
if common.is_marked_final(c) or not common.could_member_accessed(c):
164+
to_update.force_add_sig(c)
165+
continue
166+
if not common.could_member_accessed(c):
167+
continue
168+
if c.is_pure_virtual_method():
169+
to_update.add_pure_virtual(c)
170+
elif c.is_virtual_method():
171+
to_update.add_virtual(c)
172+
# recurse into base class
173+
elif c.kind == cindex.CursorKind.CXCursor_CXXBaseSpecifier:
174+
assert base_found is None, "Multiple inheritance not supported"
175+
base_found = c.type.get_declaration()
176+
if base_found is not None:
177+
base_virt = to_update.create_base()
178+
Trampoline.detect_all_virtual_methods(base_found, base_virt)

pybind11_weaver/gen_code.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
_logger = logging.getLogger(__name__)
1313

1414
entity_template = """
15+
{top_level_extra}
1516
1617
template <class Pybind11T={handle_type}> struct {bind_struct_name} : public EntityBase {{
1718
using Pybind11Type = Pybind11T;
@@ -126,7 +127,8 @@ def gen_binding_codes(entities: Dict[str, entity_base.Entity], parent_sym: str,
126127
init_handle_expr=entity.init_default_pybind11_value("parent_h"),
127128
binding_stmts="\n".join(entity.update_stmts("handle")),
128129
unique_struct_key=f"\"{entity.get_pb11weaver_struct_name()}\"",
129-
extra_code=entity.extra_code())
130+
extra_code=entity.extra_code(),
131+
top_level_extra=entity.top_level_extra_code())
130132
entity_struct_decls.append(struct_decl)
131133

132134
# generate decl
@@ -199,11 +201,7 @@ def gen_wrapped_pointer_code() -> str:
199201
create_warped_pointer_bindings = []
200202
for type in sorted(wrapped_types):
201203
wrapped_type_binding_code_template = "pybind11_weaver::PointerWrapper<{type}>::FastBind(m,\"{safe_type_name}\");"
202-
safe_type_name = type.replace(" ", "")
203-
safe_type_name = safe_type_name.replace(",", "_")
204-
safe_type_name = safe_type_name.replace("*", "p")
205-
safe_type_name = safe_type_name.replace("(", "6")
206-
safe_type_name = safe_type_name.replace(")", "9")
204+
safe_type_name = common.type_python_name(type)
207205
create_warped_pointer_bindings.append(
208206
wrapped_type_binding_code_template.format(type=type, safe_type_name=safe_type_name))
209207
create_warped_pointer_bindings = "\n".join(create_warped_pointer_bindings)

0 commit comments

Comments
 (0)