Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,12 @@ std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> comp

auto dims = core::util::toVec(out_shape);
auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
auto options = torch::TensorOptions()
.dtype(type)
.layout(at::kStrided)
.device(at::kCUDA, compiled_engine->device_info.id)
.requires_grad(false);
outputs[pyt_idx] = std::move(at::empty(dims, options).contiguous());
}

return outputs;
Expand Down
28 changes: 15 additions & 13 deletions toolchains/local_torch.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,22 @@ Discovery order:

def _find_python(ctx):
"""Return a path-like object for a Python that has torch importable, or None."""

# Use a Label anchor to get the real workspace root reliably under Bzlmod.
ws = ctx.path(Label("//:MODULE.bazel")).dirname

candidates = []

# virtualenv / uv venv
# Workspace-relative venv first — the project's own .venv always wins.
for rel in [
".venv/bin/python3",
".venv/bin/python",
"venv/bin/python3",
"venv/bin/python",
]:
candidates.append(ctx.path(str(ws) + "/" + rel))

# Active virtualenv / uv venv (only if it's not a different project's venv)
virtual_env = ctx.os.environ.get("VIRTUAL_ENV", "")
if virtual_env:
candidates.append(ctx.path(virtual_env + "/bin/python3"))
Expand All @@ -28,17 +41,6 @@ def _find_python(ctx):
candidates.append(ctx.path(conda_prefix + "/bin/python3"))
candidates.append(ctx.path(conda_prefix + "/bin/python"))

# Common relative-to-workspace venv locations
# ctx.workspace_root is the real workspace root (not the synthetic repo root)
ws = ctx.workspace_root
for rel in [
".venv/bin/python3",
".venv/bin/python",
"venv/bin/python3",
"venv/bin/python",
]:
candidates.append(ws.get_child(rel.replace("/", ws.sep if hasattr(ws, "sep") else "/")))

# System Python last
for name in ["python3", "python"]:
p = ctx.which(name)
Expand Down Expand Up @@ -95,5 +97,5 @@ def _local_torch_impl(ctx):

local_torch = repository_rule(
implementation = _local_torch_impl,
environ = ["TORCH_PATH", "VIRTUAL_ENV", "CONDA_PREFIX"],
environ = ["TORCH_PATH", "VIRTUAL_ENV", "CONDA_PREFIX", "PATH", "HOME"],
)
Loading
Loading