Skip to content

Commit fe06326

Browse files
[Qualcomm] Fix InsertIOQDQ KeyError for dequantize encodings and stabilize iteration
Summary: Fix two issues in InsertIOQDQ._insert(): 1. KeyError when a node with a dequantize encoding (e.g. pre-quantized LLM parameters) feeds the output node. q_dq_map only had quantize ops as keys, so looking up a dequantize encoding crashed. Use q_dq_map.get(encoding, encoding) to fall back to the encoding itself, which is already the correct dequantize target. 2. Iterate over list(graph_module.graph.nodes) instead of the live linked list. Inserting nodes during iteration can cause the iterator to revisit newly inserted nodes, which is the standard torch.fx footgun already addressed in other ExecuTorch backends. Fixes #17732 Note: this PR was co-authored with Claude Code (Anthropic).
1 parent 4c56d9b commit fe06326

File tree

2 files changed

+91
-2
lines changed

2 files changed

+91
-2
lines changed

backends/qualcomm/_passes/insert_io_qdq.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -118,7 +118,9 @@ def _insert_dequant_node(
118118
user.replace_input_with(node, inserted_node)
119119

120120
def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
121-
for n in graph_module.graph.nodes:
121+
# Snapshot nodes: inserting Q/DQ nodes mutates the graph's linked list,
122+
# so iterating the live list can revisit newly inserted nodes.
123+
for n in list(graph_module.graph.nodes):
122124
# do nothing when a node is expected to output a quant tensor
123125
if n.meta.get(QCOM_QUANTIZED_IO):
124126
continue
@@ -141,10 +143,11 @@ def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
141143
if n.meta.get(QCOM_QUANT_ATTRS) and any(
142144
user.op == "output" for user in users
143145
):
146+
encoding = n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]
144147
self._insert_dequant_node(
145148
graph_module,
146149
n,
147-
self.q_dq_map[n.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING]],
150+
self.q_dq_map.get(encoding, encoding),
148151
)
149152

150153
def call(self, graph_module: torch.fx.GraphModule):

backends/qualcomm/tests/test_passes.py

Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,22 @@
1+
import copy
12
import unittest
23

34
import torch
45
from executorch.backends.qualcomm._passes import (
6+
AnnotateQuantAttrs,
57
ConvertBmmToMatmul,
68
ConvertMhaToSha,
9+
FoldQDQ,
10+
InsertIOQDQ,
711
InsertReshapeForReduceOps,
812
RemoveRedundancy,
913
)
14+
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype
1015
from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset
1116
from executorch.backends.qualcomm.tests.models import TopKandIndex
17+
from executorch.backends.qualcomm.utils.constants import (
18+
QCOM_QUANT_ATTRS,
19+
)
1220
from executorch.backends.qualcomm.utils.utils import (
1321
generate_htp_compiler_spec,
1422
generate_qnn_executorch_compiler_spec,
@@ -17,9 +25,87 @@
1725
from executorch.exir import to_edge
1826
from executorch.exir.debug_handle_utils import DEBUG_HANDLE_KEY
1927
from executorch.exir.dialects._ops import ops as exir_ops
28+
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
2029

2130

2231
class TestPasses(unittest.TestCase):
32+
def _build_quantized_graph(self):
33+
"""Build a quantized graph through AnnotateQuantAttrs + FoldQDQ."""
34+
35+
class AddModule(torch.nn.Module):
36+
def forward(self, x):
37+
return x + 1
38+
39+
module = AddModule().eval()
40+
sample_input = (torch.randn(1, 4),)
41+
42+
exported = torch.export.export(module, sample_input, strict=True).module()
43+
quantizer = QnnQuantizer()
44+
quantizer.set_default_quant_config(quant_dtype=QuantDtype.use_8a8w)
45+
prepared = prepare_pt2e(exported, quantizer)
46+
prepared(*sample_input)
47+
qdq_module = convert_pt2e(prepared)
48+
49+
edge_program = to_edge(
50+
torch.export.export(qdq_module, sample_input, strict=True)
51+
)
52+
ep = edge_program.exported_program()
53+
gm = ep.graph_module
54+
55+
gm = AnnotateQuantAttrs(ep)(gm).graph_module
56+
gm = FoldQDQ(ep)(gm).graph_module
57+
return gm, ep
58+
59+
def test_insert_io_qdq_handles_dequant_encoding(self):
60+
"""InsertIOQDQ should not KeyError when a node with a dequantize
61+
encoding feeds the output node (e.g. pre-quantized LLM parameters)."""
62+
gm, ep = self._build_quantized_graph()
63+
64+
# Wire b__frozen_param0 (which has dequantize encoding) to output,
65+
# simulating the LLM topology from github issue #17732.
66+
param_node = None
67+
output_node = None
68+
for n in gm.graph.nodes:
69+
if n.name == "b__frozen_param0":
70+
param_node = n
71+
if n.op == "output":
72+
output_node = n
73+
74+
self.assertIsNotNone(param_node)
75+
old_args = output_node.args[0]
76+
output_node.args = (
77+
((old_args,) if not isinstance(old_args, tuple) else old_args)
78+
+ (param_node,),
79+
)
80+
gm.graph.lint()
81+
gm.recompile()
82+
83+
pass_instance = InsertIOQDQ(ep)
84+
pass_instance._insert(gm)
85+
86+
dq_nodes = [
87+
n
88+
for n in gm.graph.nodes
89+
if n.op == "call_function"
90+
and hasattr(n.target, "__name__")
91+
and "dequantize" in n.target.__name__
92+
and any(u.op == "output" for u in n.users.keys())
93+
]
94+
self.assertGreaterEqual(len(dq_nodes), 1)
95+
96+
def test_insert_io_qdq_no_revisit(self):
97+
"""InsertIOQDQ must not revisit newly inserted nodes."""
98+
gm, ep = self._build_quantized_graph()
99+
100+
node_count_before = len(list(gm.graph.nodes))
101+
pass_instance = InsertIOQDQ(ep)
102+
pass_instance._insert(gm)
103+
node_count_after = len(list(gm.graph.nodes))
104+
105+
# AddModule with one input and one output should insert exactly
106+
# one quantize (input) and one dequantize (output) = +2 nodes.
107+
self.assertEqual(node_count_after, node_count_before + 2)
108+
23109
def test_insert_reshape_for_argmax(self):
24110
class ArgmaxModule(torch.nn.Module):
25111
def forward(self, x):

0 commit comments

Comments (0)