Skip to content

Commit 91b6eec

Browse files
hunhoffeclaude
andcommitted
Integrate Phoenix (NPU1) support with correct cross-compilation
## What changed **Phoenix (NPU1/aie2) support:** - `ChanneledUnaryOperator` now auto-detects the target kernel directory (`aie2` vs `aie2p`) at compile time via `get_kernel_dir()`, replacing the hardcoded `kernel_subdir` ClassVar - Adds `needs_lut_ops: ClassVar[bool] = False` to `ChanneledUnaryOperator`; subclasses that require `lut_based_ops.o` on aie2 (gelu, sigmoid, tanh, silu, softmax) set it to `True` - `PeanoCompilationRule` derives the compiler target triple and runtime lib dir from the device at compile time instead of a hardcoded `DEVICE_CONFIGS` dict **Cross-compilation:** - All compilation-path `DefaultNPURuntime.device()` calls replaced with `aie_utils.get_current_device()`, which respects `set_current_device()` - Cross-compile for a specific target with `aie_utils.set_current_device(NPU1())` before compilation; execution paths (`DefaultNPURuntime.load/run`) are unchanged - Artifact names now include the device architecture (e.g. `GELU_..._npu2.xclbin`, `gelu_npu2.o`) so NPU1 and NPU2 artifacts coexist safely in the same build directory without cache collisions - Device detection uses `isinstance(dev, NPU2)` rather than string matching **Cleanup:** - Deleted `DEVICE_CONFIGS`, `get_device_name`, `get_device_type` from `device_utils.py`; replaced with a single `get_kernel_dir()` helper - Removed `AIEDeviceManager` imports from all test files (class was deleted in this branch) - Removed dead `get_device_name`/`get_device_type` imports from `gemm/design.py` and `mha/design.py` - Fixed `kernel_archive` NameError in gelu/sigmoid/tanh design files - Fixed bare `kernel_dir` NameError in `mha/op.py` - Fixed hardcoded `"aie2p"` in `gemm/op.py` and `mha/op.py` (mm.cc) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 43fced8 commit 91b6eec

46 files changed

Lines changed: 144 additions & 211 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

conftest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from iron.common import AIEContext
1414
import aie.utils as aie_utils
15-
from iron.common.aie_device_manager import AIEDeviceManager
1615

1716

1817
@pytest.fixture
@@ -158,7 +157,7 @@ def pytest_configure(config):
158157

159158

160159
def pytest_collection_modifyitems(config, items):
161-
device = AIEDeviceManager().device_str()
160+
device = aie_utils.DefaultNPURuntime.device().resolve().name
162161
for item in items:
163162
marker = item.get_closest_marker("supported_devices")
164163
if marker and device not in marker.args:

iron/common/base.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,20 @@ def name(self) -> str:
128128
for f in dataclasses.fields(self)
129129
if f.repr and getattr(self, f.name) is not None
130130
)
131-
return type(self).__name__ + "_" + "_".join(parts)
132-
params = self._params
133-
if params is not None:
131+
base = type(self).__name__ + "_" + "_".join(parts)
132+
elif self._params is not None:
134133
parts = (
135-
f"{k}{_serialize_param(v)}" for k, v in params.items() if v is not None
134+
f"{k}{_serialize_param(v)}"
135+
for k, v in self._params.items()
136+
if v is not None
136137
)
137-
return type(self).__name__ + "_" + "_".join(parts)
138-
raise NotImplementedError(
139-
f"{type(self).__name__} must be a @dataclass or define a _params property"
140-
)
138+
base = type(self).__name__ + "_" + "_".join(parts)
139+
else:
140+
raise NotImplementedError(
141+
f"{type(self).__name__} must be a @dataclass or define a _params property"
142+
)
143+
dev = aie_utils.get_current_device()
144+
return f"{base}_{dev.resolve().name}"
141145

142146
@abstractmethod
143147
def get_mlir_artifact(self) -> CompilationArtifact:
@@ -151,8 +155,13 @@ def get_artifacts(
151155
self, prefix: str = "", dynamic_obj_fifos: bool = False
152156
) -> tuple[XclbinArtifact, InstsBinArtifact]:
153157
operator_name = prefix + self.name
158+
arch = self.name.rsplit("_", 1)[-1]
154159
mlir_artifact = self.get_mlir_artifact()
155160
kernel_deps = self.get_kernel_artifacts()
161+
for dep in kernel_deps:
162+
if isinstance(dep, KernelObjectArtifact):
163+
p = Path(dep.filename)
164+
dep.filename = str(p.with_stem(f"{p.stem}_{arch}"))
156165
extra_flags = ["--dynamic-objFifos"] if dynamic_obj_fifos else []
157166
xclbin_artifact = XclbinArtifact(
158167
f"{operator_name}.xclbin",

iron/common/compilation/base.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
from typing import Any, Callable
4848
import sys
4949

50+
import aie.utils as aie_utils
51+
from aie.iron.device import NPU2
52+
5053
# Global Functions
5154
# ##########################################################################
5255

@@ -574,10 +577,9 @@ def compile(self, graph):
574577

575578

576579
class PeanoCompilationRule(CompilationRule):
577-
def __init__(self, peano_dir, mlir_aie_dir, device_type=None, *args, **kwargs):
580+
def __init__(self, peano_dir, mlir_aie_dir, *args, **kwargs):
578581
self.peano_dir = peano_dir
579582
self.mlir_aie_dir = mlir_aie_dir
580-
self.device_type = device_type
581583
super().__init__(*args, **kwargs)
582584

583585
def matches(self, artifacts):
@@ -589,15 +591,13 @@ def compile(self, artifacts):
589591
worklist = artifacts.get_worklist(KernelObjectArtifact)
590592
commands = []
591593

592-
if self.device_type not in DEVICE_CONFIGS:
593-
raise ValueError(
594-
f"Unsupported device type: {self.device_type!r} "
595-
f"(supported: {', '.join(DEVICE_CONFIGS)})"
596-
)
597-
config = DEVICE_CONFIGS[self.device_type]
598-
target = config["target"]
594+
dev = aie_utils.get_current_device()
595+
is_npu2 = isinstance(dev, NPU2)
596+
target = "aie2p-none-unknown-elf" if is_npu2 else "aie2-none-unknown-elf"
599597
runtime_lib_include_path = (
600-
Path(self.mlir_aie_dir) / "aie_runtime_lib" / config["runtime_lib_dir"]
598+
Path(self.mlir_aie_dir)
599+
/ "aie_runtime_lib"
600+
/ ("AIE2P" if is_npu2 else "AIE2")
601601
)
602602

603603
for artifact in worklist:

iron/common/device_utils.py

Lines changed: 7 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,12 @@
11
# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""Device utility functions for handling NPU device types and configurations."""
4+
import aie.utils as aie_utils
5+
from aie.iron.device import NPU2
56

6-
from aie.iron.device import NPU1, NPU1Col1, NPU1Col2, NPU2
77

8-
9-
def get_device_name(dev):
10-
"""Get device name string for looking up microkernel dimensions.
11-
12-
Returns "npu1" for Phoenix/NPU1 devices, "npu2" for Strix/NPU2 devices.
13-
14-
Args:
15-
dev: Either a string ("npu", "npu1", "npu2"), a device object with .resolve(),
16-
or a device type object.
17-
18-
Returns:
19-
str: "npu1" or "npu2"
20-
"""
21-
if isinstance(dev, str):
22-
if dev in ("npu", "npu1"):
23-
return "npu1"
24-
else:
25-
return "npu2"
26-
elif hasattr(dev, "resolve"):
27-
# Device object from device_manager.device_type
28-
return dev.resolve().name
29-
else:
30-
# Assume it's a device type object - check class name
31-
name = type(dev).__name__
32-
if "NPU1" in name:
33-
return "npu1"
34-
else:
35-
return "npu2"
36-
37-
38-
DEVICE_CONFIGS = {
39-
"npu1": {
40-
"target": "aie2-none-unknown-elf",
41-
"runtime_lib_dir": "AIE2",
42-
"kernel_dir": "aie2",
43-
"max_columns": NPU1().cols,
44-
},
45-
"npu2": {
46-
"target": "aie2p-none-unknown-elf",
47-
"runtime_lib_dir": "AIE2P",
48-
"kernel_dir": "aie2p",
49-
"max_columns": NPU2().cols,
50-
},
51-
}
52-
53-
54-
def get_device_type(dev, n_aie_cols):
55-
"""Resolve device type to appropriate NPU device instance.
56-
57-
Handles both string inputs ("npu", "npu1", "npu2") and device objects.
58-
59-
Args:
60-
dev: Either a string ("npu", "npu1", "npu2"), a device object with .resolve(),
61-
or a device type object.
62-
n_aie_cols: Number of AIE columns to use (1, 2, or 4).
63-
64-
Returns:
65-
NPU device instance (NPU1, NPU1Col1, NPU1Col2, or NPU2)
66-
"""
67-
dev_name = get_device_name(dev)
68-
if dev_name == "npu1":
69-
if n_aie_cols == 1:
70-
return NPU1Col1()
71-
elif n_aie_cols == 2:
72-
return NPU1Col2()
73-
else:
74-
return NPU1()
75-
else:
76-
return NPU2()
8+
def get_kernel_dir(dev=None) -> str:
9+
"""Returns 'aie2p' for NPU2 (Strix), 'aie2' for NPU1 (Phoenix)."""
10+
if dev is None:
11+
dev = aie_utils.get_current_device()
12+
return "aie2p" if isinstance(dev, NPU2) else "aie2"

iron/common/operator_bases.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,19 @@
44
from __future__ import annotations
55

66
from dataclasses import dataclass, field
7+
from pathlib import Path
78
from typing import Any, ClassVar, Dict
89

910
import aie.utils as aie_utils
11+
1012
from .base import MLIROperator, AIERuntimeArgSpec
1113
from .compilation import (
1214
KernelObjectArtifact,
1315
SourceArtifact,
1416
PythonGeneratedMLIRArtifact,
1517
DesignGenerator,
1618
)
19+
from .device_utils import get_kernel_dir
1720
from .utils import get_shim_dma_limit
1821

1922

@@ -24,10 +27,10 @@ class ChanneledUnaryOperator(MLIROperator):
2427
Assumes a single kernel source file and a standard design.py callback
2528
with args [device, size, num_aie_columns, num_channels, tile_size, trace_size].
2629
27-
Subclasses must define three ClassVar attributes:
30+
Subclasses must define ClassVar attributes:
2831
kernel_name: name of the kernel object file (e.g. "gelu" → gelu.o / gelu.cc)
29-
kernel_subdir: subdirectory under aie_kernels/ (e.g. "aie2p", "generic")
3032
callback_fn: design.py callback function name (e.g. "my_gelu")
33+
needs_lut_ops: set True for operators that require lut_based_ops.o on aie2
3134
3235
Customization points:
3336
- For operators with extra parameters (e.g. alpha, trace_size), add
@@ -45,8 +48,8 @@ class ChanneledUnaryOperator(MLIROperator):
4548
context: object = field(default=None, repr=False)
4649

4750
kernel_name: ClassVar[str]
48-
kernel_subdir: ClassVar[str]
4951
callback_fn: ClassVar[str]
52+
needs_lut_ops: ClassVar[bool] = False
5053

5154
def __post_init__(self) -> None:
5255
max_multiple = self.num_aie_columns * self.tile_size
@@ -55,7 +58,7 @@ def __post_init__(self) -> None:
5558
f"size ({self.size}) must be a multiple of "
5659
f"num_aie_columns * tile_size ({max_multiple})"
5760
)
58-
dev = aie_utils.DefaultNPURuntime.device()
61+
dev = aie_utils.get_current_device()
5962
shim_dma_limit = get_shim_dma_limit(dev)
6063
total_shimdma_channels = self.num_aie_columns * self.num_channels
6164
if total_shimdma_channels > shim_dma_limit:
@@ -78,7 +81,7 @@ def _mlir_callback_args(self) -> list[Any]:
7881
override this method.
7982
"""
8083
return [
81-
aie_utils.DefaultNPURuntime.device(),
84+
aie_utils.get_current_device(),
8285
self.size,
8386
self.num_aie_columns,
8487
self.num_channels,
@@ -97,19 +100,37 @@ def get_mlir_artifact(self) -> PythonGeneratedMLIRArtifact:
97100
)
98101

99102
def get_kernel_artifacts(self) -> list[KernelObjectArtifact]:
100-
return [
103+
dev = aie_utils.get_current_device()
104+
kernel_dir = get_kernel_dir(dev)
105+
artifacts = [
101106
KernelObjectArtifact(
102107
f"{self.kernel_name}.o",
103108
dependencies=[
104109
SourceArtifact(
105110
self.context.base_dir
106111
/ "aie_kernels"
107-
/ self.kernel_subdir
112+
/ kernel_dir
108113
/ f"{self.kernel_name}.cc"
109114
)
110115
],
111-
),
116+
)
112117
]
118+
if self.needs_lut_ops and kernel_dir == "aie2":
119+
mlir_aie_dir = Path(aie_utils.config.root_path())
120+
artifacts.append(
121+
KernelObjectArtifact(
122+
"lut_based_ops.o",
123+
dependencies=[
124+
SourceArtifact(
125+
mlir_aie_dir
126+
/ "aie_runtime_lib"
127+
/ "AIE2"
128+
/ "lut_based_ops.cpp"
129+
)
130+
],
131+
)
132+
)
133+
return artifacts
113134

114135

115136
@dataclass
@@ -123,7 +144,7 @@ class BinaryElementwiseOperator(MLIROperator):
123144
parameter — each core uses 2 DMA channels (one per input), so the ShimDMA
124145
limit is enforced as num_aie_columns * 2 <= 16.
125146
126-
Subclasses must define three ClassVar attributes:
147+
Subclasses must define ClassVar attributes:
127148
kernel_name: name of the kernel object file (e.g. "add" → add.o / add.cc)
128149
kernel_subdir: subdirectory under aie_kernels/ (e.g. "generic")
129150
callback_fn: design.py callback function name (e.g. "my_eltwise_add")
@@ -148,7 +169,7 @@ def __post_init__(self) -> None:
148169
f"size ({self.size}) must be a multiple of "
149170
f"num_aie_columns * tile_size ({self.num_aie_columns * self.tile_size})"
150171
)
151-
dev = aie_utils.DefaultNPURuntime.device()
172+
dev = aie_utils.get_current_device()
152173
shim_dma_limit = get_shim_dma_limit(dev)
153174
# Binary operators use 2 ShimDMA channels per column (one per input).
154175
total_shimdma_channels = self.num_aie_columns * 2
@@ -168,7 +189,7 @@ def get_arg_spec(self) -> list[AIERuntimeArgSpec]:
168189

169190
def _mlir_callback_args(self) -> list[Any]:
170191
return [
171-
aie_utils.DefaultNPURuntime.device(),
192+
aie_utils.get_current_device(),
172193
self.size,
173194
self.num_aie_columns,
174195
self.tile_size,

iron/operators/axpy/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_mlir_artifact(self):
3939
self.operator_dir / "design.py",
4040
"my_axpy",
4141
(
42-
aie_utils.DefaultNPURuntime.device(),
42+
aie_utils.get_current_device(),
4343
self.size,
4444
self.num_aie_columns,
4545
self.tile_size,

iron/operators/axpy/test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import pytest
6+
import aie.utils as aie_utils
67

78
from iron.operators.axpy.op import AXPY
89
from iron.operators.axpy.reference import generate_golden_reference
910
from iron.common.test_utils import run_test
10-
from iron.common.aie_device_manager import AIEDeviceManager
11-
from iron.common.device_utils import DEVICE_CONFIGS
1211

1312

1413
def get_params():
15-
device_type = AIEDeviceManager().device_str()
16-
max_aie_columns = DEVICE_CONFIGS[device_type]["max_columns"]
14+
max_aie_columns = aie_utils.get_current_device().cols
1715
input_lengths = [1024, 2048, 4096, 8192]
1816
scalar_factors = [3.0, 10.0]
1917

iron/operators/dequant/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_mlir_artifact(self):
5050
self.operator_dir / "design.py",
5151
"my_dequant_kernel",
5252
(
53-
aie_utils.DefaultNPURuntime.device(),
53+
aie_utils.get_current_device(),
5454
self.size,
5555
self.num_aie_columns,
5656
self.num_channels,

iron/operators/dequant/test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import pytest
6+
import aie.utils as aie_utils
67

78
from iron.operators.dequant.op import Dequant
89
from iron.operators.dequant.reference import generate_golden_reference
910
from iron.common.test_utils import run_test
10-
from iron.common.aie_device_manager import AIEDeviceManager
11-
from iron.common.device_utils import DEVICE_CONFIGS
1211

1312

1413
def get_params():
15-
device_type = AIEDeviceManager().device_str()
16-
max_aie_columns = DEVICE_CONFIGS[device_type]["max_columns"]
14+
max_aie_columns = aie_utils.get_current_device().cols
1715

1816
input_lengths = [1024, 2048, 4096, 8192]
1917
group_size = 32

iron/operators/elementwise_add/test.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,15 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import pytest
6+
import aie.utils as aie_utils
67

7-
from iron.common.aie_device_manager import AIEDeviceManager
8-
from iron.common.device_utils import DEVICE_CONFIGS
98
from iron.operators.elementwise_add.op import ElementwiseAdd
109
from iron.operators.elementwise_add.reference import generate_golden_reference
1110
from iron.common.test_utils import run_test
1211

1312

1413
def get_params():
15-
device_type = AIEDeviceManager().device_str()
16-
max_aie_columns = DEVICE_CONFIGS[device_type]["max_columns"]
14+
max_aie_columns = aie_utils.get_current_device().cols
1715
# Combine all lengths
1816
input_lengths = [1024, 2048, 4096, 8192]
1917

0 commit comments

Comments
 (0)