Skip to content

Commit eb69bd1

Browse files
committed
Made input name consistent across all surgeries and files
Signed-off-by: Hrishith Thadicherla <hthadicherla@nvidia.com>
1 parent babe66c commit eb69bd1

4 files changed

Lines changed: 23 additions & 27 deletions

File tree

modelopt/onnx/graph_surgery/__init__.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@
4747
... )
4848
>>> # Add cross-attention KV cache outputs to encoder (GenAI compatible)
4949
>>> add_cross_kv_to_encoder(
50-
... encoder_path="encoder_model.onnx",
50+
... model_path="encoder_model.onnx",
5151
... output_path="encoder_with_kv.onnx",
5252
... hf_model_id="openai/whisper-large-v3-turbo",
5353
... )
5454
>>> # Standalone FP16 to BF16 conversion
5555
>>> convert_fp16_to_bf16(
56-
... input_path="model_fp16.onnx",
56+
... model_path="model_fp16.onnx",
5757
... output_path="model_bf16.onnx",
5858
... )
5959
>>>
@@ -69,15 +69,11 @@
6969
from .gqa_replacement import replace_attention_with_gqa
7070
from .utils.dtype_conversion import convert_fp16_to_bf16
7171

72-
# Registry of available graph surgeries.
73-
# Maps surgery name -> (function, input_path_param_name)
74-
# input_path_param_name is the keyword argument name for the input model path
75-
# (different surgeries use different names: model_path, encoder_path, input_path)
7672
_SURGERY_REGISTRY = {
77-
"replace-gqa": (replace_attention_with_gqa, "model_path"),
78-
"add-cross-kv": (add_cross_kv_to_encoder, "encoder_path"),
79-
"convert-bf16": (convert_fp16_to_bf16, "input_path"),
80-
"transpose-dq": (transpose_dequantize_linear_weights, "model_path"),
73+
"replace-gqa": replace_attention_with_gqa,
74+
"add-cross-kv": add_cross_kv_to_encoder,
75+
"convert-bf16": convert_fp16_to_bf16,
76+
"transpose-dq": transpose_dequantize_linear_weights,
8177
}
8278

8379

@@ -88,7 +84,7 @@ def get_available_surgeries() -> list[str]:
8884

8985
def run_graph_surgery(
9086
surgery_name: str,
91-
input_path: str,
87+
model_path: str,
9288
output_path: str,
9389
**kwargs,
9490
):
@@ -103,7 +99,7 @@ def run_graph_surgery(
10399
Args:
104100
surgery_name: Name of the surgery to run (e.g. 'replace-gqa', 'transpose-dq').
105101
Use get_available_surgeries() to see all available options.
106-
input_path: Path to the input ONNX model.
102+
model_path: Path to the input ONNX model.
107103
output_path: Path to save the output ONNX model.
108104
**kwargs: Surgery-specific parameters. Passed directly to the surgery function.
109105
@@ -119,7 +115,7 @@ def run_graph_surgery(
119115
['replace-gqa', 'add-cross-kv', 'convert-bf16', 'transpose-dq']
120116
>>> run_graph_surgery(
121117
... "replace-gqa",
122-
... input_path="model.onnx",
118+
... model_path="model.onnx",
123119
... output_path="model_gqa.onnx",
124120
... hf_model_id="meta-llama/Llama-2-7b-hf",
125121
... )
@@ -128,8 +124,8 @@ def run_graph_surgery(
128124
available = ", ".join(f"'{s}'" for s in _SURGERY_REGISTRY)
129125
raise ValueError(f"Unknown surgery: '{surgery_name}'. Available surgeries: {available}")
130126

131-
func, input_param_name = _SURGERY_REGISTRY[surgery_name]
132-
return func(**{input_param_name: input_path, "output_path": output_path}, **kwargs)
127+
func = _SURGERY_REGISTRY[surgery_name]
128+
return func(model_path=model_path, output_path=output_path, **kwargs)
133129

134130

135131
__all__ = [

modelopt/onnx/graph_surgery/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def main():
271271
from .encoder_cross_kv import add_cross_kv_to_encoder
272272

273273
add_cross_kv_to_encoder(
274-
encoder_path=args.input,
274+
model_path=args.input,
275275
output_path=args.output,
276276
hf_model_id=args.model_id,
277277
hidden_state_output_name=args.hidden_state_name,
@@ -288,7 +288,7 @@ def main():
288288
from .utils.dtype_conversion import convert_fp16_to_bf16
289289

290290
convert_fp16_to_bf16(
291-
input_path=args.input,
291+
model_path=args.input,
292292
output_path=args.output,
293293
external_data=not args.no_external_data,
294294
verbose=not args.quiet,

modelopt/onnx/graph_surgery/encoder_cross_kv.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _add_cross_kv_outputs(
320320

321321

322322
def add_cross_kv_to_encoder(
323-
encoder_path: str,
323+
model_path: str,
324324
output_path: str,
325325
hf_model_id: str,
326326
hidden_state_output_name: str = "last_hidden_state",
@@ -349,7 +349,7 @@ def add_cross_kv_to_encoder(
349349
6. Generates genai_config.json and audio_processor_config.json (optional)
350350
351351
Args:
352-
encoder_path: Path to encoder ONNX model.
352+
model_path: Path to encoder ONNX model.
353353
output_path: Path to save modified encoder.
354354
hf_model_id: HuggingFace model ID for loading cross-attention weights.
355355
hidden_state_output_name: Name of encoder hidden state output.
@@ -369,7 +369,7 @@ def add_cross_kv_to_encoder(
369369
Example:
370370
>>> from modelopt.onnx.graph_surgery import add_cross_kv_to_encoder
371371
>>> model = add_cross_kv_to_encoder(
372-
... encoder_path="encoder_model.onnx",
372+
... model_path="encoder_model.onnx",
373373
... output_path="encoder_model_with_kv.onnx",
374374
... hf_model_id="openai/whisper-large-v3-turbo",
375375
... )
@@ -380,9 +380,9 @@ def add_cross_kv_to_encoder(
380380
)
381381

382382
if verbose:
383-
logger.info(f"Loading encoder model from: {encoder_path}")
383+
logger.info(f"Loading encoder model from: {model_path}")
384384

385-
encoder_model = onnx.load(encoder_path, load_external_data=True)
385+
encoder_model = onnx.load(model_path, load_external_data=True)
386386

387387
# Detect model dtype
388388
onnx_dtype, np_dtype = detect_model_dtype(encoder_model)

modelopt/onnx/graph_surgery/utils/dtype_conversion.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _convert_constant_node_to_bf16(node: onnx.NodeProto) -> bool:
133133

134134

135135
def convert_fp16_to_bf16(
136-
input_path: str,
136+
model_path: str,
137137
output_path: str,
138138
external_data: bool = True,
139139
verbose: bool = True,
@@ -147,7 +147,7 @@ def convert_fp16_to_bf16(
147147
4. All Cast nodes that target FP16 to target BF16
148148
149149
Args:
150-
input_path: Path to input FP16 ONNX model.
150+
model_path: Path to input FP16 ONNX model.
151151
output_path: Path to output BF16 ONNX model.
152152
external_data: Whether to save weights as external data.
153153
verbose: Whether to print progress messages.
@@ -157,16 +157,16 @@ def convert_fp16_to_bf16(
157157
158158
Example:
159159
>>> stats = convert_fp16_to_bf16(
160-
... input_path="model_fp16.onnx",
160+
... model_path="model_fp16.onnx",
161161
... output_path="model_bf16.onnx",
162162
... )
163163
>>> logger.info(f"Converted {stats['initializers_converted']} initializers")
164164
"""
165165
if verbose:
166-
logger.info(f"Loading model from: {input_path}")
166+
logger.info(f"Loading model from: {model_path}")
167167

168168
# Load model with external data
169-
model = onnx.load(input_path, load_external_data=True)
169+
model = onnx.load(model_path, load_external_data=True)
170170
graph = model.graph
171171

172172
# Statistics

0 commit comments

Comments (0)