30 changes: 20 additions & 10 deletions python/cudaq/kernel/kernel_decorator.py
@@ -451,6 +451,25 @@ def __convertStringsToPauli__(self, arg):

return arg

def processCallableArg(self, arg):
"""
Process a callable argument, ensuring its MLIR is merged into this kernel's module.
"""
if not isinstance(arg, PyKernelDecorator):
emitFatalError(
"Callable argument provided is not a cudaq.kernel decorated function."
)
# The provided callable kernel may not yet be present in the
# ModuleOp. If so, add it by revisiting its AST so that it
# shares self.module's MLIR Context.
symbols = SymbolTable(self.module.operation)
if nvqppPrefix + arg.name not in symbols:
tmpBridge = PyASTBridge(self.capturedDataStorage,
existingModule=self.module,
disableEntryPointTag=True)
tmpBridge.visit(globalAstRegistry[arg.name][0])

def __call__(self, *args):
"""
Invoke the CUDA-Q kernel. JIT compilation of the kernel AST to MLIR
@@ -498,16 +517,7 @@ def __call__(self, *args):
if cc.CallableType.isinstance(mlirType):
# Assume this is a PyKernelDecorator
callableNames.append(arg.name)
# It may be that the provided input callable kernel
# is not currently in the ModuleOp. Need to add it
# if that is the case, we have to use the AST
# so that it shares self.module's MLIR Context
symbols = SymbolTable(self.module.operation)
if nvqppPrefix + arg.name not in symbols:
tmpBridge = PyASTBridge(self.capturedDataStorage,
existingModule=self.module,
disableEntryPointTag=True)
tmpBridge.visit(globalAstRegistry[arg.name][0])
self.processCallableArg(arg)

# Convert `numpy` arrays to lists
if cc.StdvecType.isinstance(mlirType) and hasattr(arg, "tolist"):
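The extracted `processCallableArg` lets the C++ launch paths reuse the same module-merging step that `__call__` performs. A minimal sketch of what this enables from Python (`flip` and `takes_op` are hypothetical kernel names, not part of this PR):

from typing import Callable
import cudaq

@cudaq.kernel
def flip(q: cudaq.qubit):
    x(q)

@cudaq.kernel
def takes_op(op: Callable[[cudaq.qubit], None]):
    q = cudaq.qubit()
    op(q)

# Passing a kernel as an argument routes through processCallableArg, which
# merges flip's MLIR into takes_op's module if it is not already present.
takes_op(flip)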
10 changes: 6 additions & 4 deletions python/runtime/cudaq/algorithms/py_observe_async.cpp
@@ -80,18 +80,20 @@ async_observe_result pyObserveAsync(py::object &kernel,
if (py::len(kernelBlockArgs) != args.size())
throw std::runtime_error(
"Invalid number of arguments passed to observe_async.");

// Process any callable args
const auto callableNames = getCallableNames(kernel, args);
auto &platform = cudaq::get_platform();
auto kernelName = kernel.attr("name").cast<std::string>();
auto kernelMod = kernel.attr("module").cast<MlirModule>();
args = simplifiedValidateInputArguments(args);
auto *argData = toOpaqueArgs(args, kernelMod, kernelName);
auto *argData =
toOpaqueArgs(args, kernelMod, kernelName, getCallableArgHandler());

// Launch the asynchronous execution.
py::gil_scoped_release release;
return details::runObservationAsync(
[argData, kernelName, kernelMod]() mutable {
pyAltLaunchKernel(kernelName, kernelMod, *argData, {});
[argData, kernelName, kernelMod, callableNames]() mutable {
pyAltLaunchKernel(kernelName, kernelMod, *argData, callableNames);
delete argData;
},
spin_operator, platform, shots, kernelName, qpu_id);
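`getCallableNames` reproduces on the C++ side the argument scan that `PyKernelDecorator.__call__` performs, and the collected names are captured by the launch functor. A rough Python rendering of that scan, as a sketch rather than the shipped implementation:

def get_callable_names(kernel, args):
    """Collect names of kernel-valued arguments, compiling each one and
    merging it into the target kernel's module as a side effect."""
    callable_names = []
    for arg in args:
        # A PyKernelDecorator is callable and carries `module` and `name`.
        if hasattr(arg, "__call__") and hasattr(arg, "module") \
                and hasattr(arg, "name"):
            if hasattr(arg, "compile"):
                arg.compile()
            if hasattr(kernel, "processCallableArg"):
                kernel.processCallableArg(arg)
            callable_names.append(arg.name)
    return callable_names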
32 changes: 20 additions & 12 deletions python/runtime/cudaq/algorithms/py_run.cpp
@@ -39,7 +39,8 @@ static std::vector<py::object> readRunResults(mlir::ModuleOp module,
}

static std::tuple<std::string, MlirModule, OpaqueArguments *,
mlir::func::FuncOp, std::string, mlir::func::FuncOp>
mlir::func::FuncOp, std::string, mlir::func::FuncOp,
std::vector<std::string>>
getKernelLaunchParameters(py::object &kernel, py::args args) {
if (!py::hasattr(kernel, "arguments"))
throw std::runtime_error(
@@ -52,6 +53,9 @@ getKernelLaunchParameters(py::object &kernel, py::args args) {
if (py::hasattr(kernel, "compile"))
kernel.attr("compile")();

// Process any callable args
const auto callableNames = getCallableNames(kernel, args);

auto origKernName = kernel.attr("name").cast<std::string>();
auto kernelName = origKernName + ".run";
if (!py::hasattr(kernel, "module") || kernel.attr("module").is_none())
@@ -76,16 +80,19 @@ getKernelLaunchParameters(py::object &kernel, py::args args) {
throw std::runtime_error(
"failed to autogenerate the runnable variant of the kernel.");
}
auto *argData = toOpaqueArgs(args, kernelMod, kernelName);
auto *argData =
toOpaqueArgs(args, kernelMod, kernelName, getCallableArgHandler());
auto funcOp = getKernelFuncOp(kernelMod, kernelName);
return {kernelName, kernelMod, argData, funcOp, origKernName, origKern};
return {kernelName, kernelMod, argData, funcOp,
origKernName, origKern, callableNames};
}

static details::RunResultSpan
pyRunTheKernel(const std::string &name, const std::string &origName,
MlirModule module, mlir::func::FuncOp funcOp,
mlir::func::FuncOp origKernel, OpaqueArguments &runtimeArgs,
quantum_platform &platform, std::size_t shots_count,
const std::vector<std::string> &callableNames,
std::size_t qpu_id = 0) {
auto returnTypes = origKernel.getResultTypes();
if (returnTypes.empty() || returnTypes.size() > 1)
@@ -101,13 +108,13 @@ pyRunTheKernel(const std::string &name, const std::string &origName,

auto mod = unwrap(module);

auto [rawArgs, size, returnOffset, thunk] =
pyAltLaunchKernelBase(name, module, returnTy, runtimeArgs, {}, 0, false);
auto [rawArgs, size, returnOffset, thunk] = pyAltLaunchKernelBase(
name, module, returnTy, runtimeArgs, callableNames, 0, false);

auto results = details::runTheKernel(
[&]() mutable {
pyLaunchKernel(name, thunk, mod, runtimeArgs, rawArgs, size,
returnOffset, {});
returnOffset, callableNames);
},
platform, name, origName, shots_count, qpu_id);

@@ -133,7 +140,7 @@ std::vector<py::object> pyRun(py::object &kernel, py::args args,
if (shots_count == 0)
return {};

auto [name, module, argData, func, origName, origKern] =
auto [name, module, argData, func, origName, origKern, callableNames] =
getKernelLaunchParameters(kernel, args);

auto mod = unwrap(module);
@@ -149,7 +156,7 @@ std::vector<py::object> pyRun(py::object &kernel, py::args args,
}

auto span = pyRunTheKernel(name, origName, module, func, origKern, *argData,
platform, shots_count);
platform, shots_count, callableNames);
delete argData;
auto results = pyReadResults(span, module, func, origKern, shots_count);

@@ -184,7 +191,7 @@ async_run_result pyRunAsync(py::object &kernel, py::args args,
") exceeds the number of available QPUs (" +
std::to_string(numQPUs) + ")");

auto [name, module, argData, func, origName, origKern] =
auto [name, module, argData, func, origName, origKern, callableNames] =
getKernelLaunchParameters(kernel, args);

auto mod = unwrap(module);
@@ -219,16 +226,17 @@ async_run_result pyRunAsync(py::object &kernel, py::args args,
QuantumTask wrapped = detail::make_copyable_function(
[sp = std::move(spanPromise), ep = std::move(errorPromise), shots_count,
qpu_id, argData, name, module, func, origKern, origName,
noise_model = std::move(noise_model)]() mutable {
noise_model = std::move(noise_model), callableNames]() mutable {
auto &platform = get_platform();

// Launch the kernel in the appropriate context.
if (noise_model.has_value())
platform.set_noise(&noise_model.value());

try {
auto span = pyRunTheKernel(name, origName, module, func, origKern,
*argData, platform, shots_count, qpu_id);
auto span =
pyRunTheKernel(name, origName, module, func, origKern, *argData,
platform, shots_count, callableNames, qpu_id);
delete argData;
sp.set_value(span);
ep.set_value("");
11 changes: 7 additions & 4 deletions python/runtime/cudaq/algorithms/py_sample_async.cpp
@@ -88,6 +88,8 @@ for more information on this programming pattern.)#")
auto &platform = cudaq::get_platform();
if (py::hasattr(kernel, "compile"))
kernel.attr("compile")();
// Process any callable args
const auto callableNames = getCallableNames(kernel, args);
auto kernelName = kernel.attr("name").cast<std::string>();
// Clone the kernel module
auto kernelMod = mlirModuleFromOperation(
@@ -118,7 +120,7 @@ for more information on this programming pattern.)#")
// Hence, pass it as a unique_ptr for the functor to manage its
// lifetime.
std::unique_ptr<OpaqueArguments> argData(
toOpaqueArgs(args, kernelMod, kernelName));
toOpaqueArgs(args, kernelMod, kernelName, getCallableArgHandler()));

// Should only have C++ going on here, safe to release the GIL
py::gil_scoped_release release;
@@ -129,9 +131,10 @@ for more information on this programming pattern.)#")
// (2) This lambda might be executed multiple times, e.g, when
// the kernel contains measurement feedback.
cudaq::detail::make_copyable_function(
[argData = std::move(argData), kernelName,
kernelMod]() mutable {
pyAltLaunchKernel(kernelName, kernelMod, *argData, {});
[argData = std::move(argData), kernelName, kernelMod,
callableNames]() mutable {
pyAltLaunchKernel(kernelName, kernelMod, *argData,
callableNames);
}),
platform, kernelName, shots, explicitMeasurements, qpu_id),
std::move(mlirCtx));
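The `sample_async` path follows the same pattern: `callableNames` is captured by value into the launch functor so it remains valid after the GIL is released and across repeated launches. A minimal usage sketch with hypothetical kernel names:

from typing import Callable
import cudaq

@cudaq.kernel
def rot(q: cudaq.qubit):
    ry(0.59, q)

@cudaq.kernel
def bell_like(op: Callable[[cudaq.qubit], None]):
    q = cudaq.qvector(2)
    op(q[0])
    x.ctrl(q[0], q[1])
    mz(q)

counts = cudaq.sample_async(bell_like, rot, qpu_id=0).get()
print(counts)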
69 changes: 48 additions & 21 deletions python/runtime/cudaq/platform/py_alt_launch_kernel.cpp
@@ -109,16 +109,37 @@ void setDataLayout(MlirModule module) {
}
}

std::function<bool(OpaqueArguments &argData, py::object &arg)>
getCallableArgHandler() {
return [](cudaq::OpaqueArguments &argData, py::object &arg) {
if (py::hasattr(arg, "module")) {
// Give it dummy data that will never be used: all callables are
// synthesized away, so the block argument remains but is unused.
// Hand argsCreator a placeholder and make sure it is cleaned up.
long *ourAllocatedArg = new long();
argData.emplace_back(ourAllocatedArg,
[](void *ptr) { delete static_cast<long *>(ptr); });
return true;
}
return false;
};
}

/// @brief Create a new OpaqueArguments pointer and pack the python arguments
/// in it. Clients must delete the memory.
OpaqueArguments *toOpaqueArgs(py::args &args, MlirModule mod,
const std::string &name) {
OpaqueArguments *
toOpaqueArgs(py::args &args, MlirModule mod, const std::string &name,
const std::optional<
std::function<bool(OpaqueArguments &argData, py::object &arg)>>
&optionalBackupHandler) {
auto kernelFunc = getKernelFuncOp(mod, name);
auto *argData = new cudaq::OpaqueArguments();
args = simplifiedValidateInputArguments(args);
setDataLayout(mod);
cudaq::packArgs(*argData, args, kernelFunc,
[](OpaqueArguments &, py::object &) { return false; });
auto backupHandler = optionalBackupHandler.value_or(
[](OpaqueArguments &, py::object &) { return false; });
cudaq::packArgs(*argData, args, kernelFunc, backupHandler);
return argData;
}

@@ -998,26 +1019,32 @@ std::string getASM(const std::string &name, MlirModule module,
return str;
}

std::vector<std::string> getCallableNames(py::object &kernel, py::args &args) {
// Handle callable arguments, if any, the same way `PyKernelDecorator.__call__`
// does, so that the callable arguments are packed for `pyAltLaunchKernel`
// exactly as if the kernel were launched from Python.
std::vector<std::string> callableNames;
for (std::size_t i = 0; i < args.size(); ++i) {
auto arg = args[i];
// If this is a `PyKernelDecorator` callable:
if (py::hasattr(arg, "__call__") && py::hasattr(arg, "module") &&
py::hasattr(arg, "name")) {
if (py::hasattr(arg, "compile"))
arg.attr("compile")();

if (py::hasattr(kernel, "processCallableArg"))
kernel.attr("processCallableArg")(arg);
callableNames.push_back(arg.attr("name").cast<std::string>());
}
}
return callableNames;
}

void bindAltLaunchKernel(py::module &mod,
std::function<std::string()> &&getTL) {
jitCache = std::make_unique<JITExecutionCache>();
getTransportLayer = std::move(getTL);

auto callableArgHandler = [](cudaq::OpaqueArguments &argData,
py::object &arg) {
if (py::hasattr(arg, "module")) {
// Just give it some dummy data that will not be used.
// We synthesize away all callables, the block argument
// remains but it is not used, so just give argsCreator
// something, and we'll make sure its cleaned up.
long *ourAllocatedArg = new long();
argData.emplace_back(ourAllocatedArg,
[](void *ptr) { delete static_cast<long *>(ptr); });
return true;
}
return false;
};

mod.def(
"pyAltLaunchKernel",
[&](const std::string &kernelName, MlirModule module,
@@ -1026,7 +1053,7 @@ void bindAltLaunchKernel(py::module &mod,

cudaq::OpaqueArguments args;
setDataLayout(module);
cudaq::packArgs(args, runtimeArgs, kernelFunc, callableArgHandler);
cudaq::packArgs(args, runtimeArgs, kernelFunc, getCallableArgHandler());
pyAltLaunchKernel(kernelName, module, args, callable_names);
},
py::arg("kernelName"), py::arg("module"), py::kw_only(),
@@ -1040,7 +1067,7 @@ void bindAltLaunchKernel(py::module &mod,

cudaq::OpaqueArguments args;
setDataLayout(module);
cudaq::packArgs(args, runtimeArgs, kernelFunc, callableArgHandler);
cudaq::packArgs(args, runtimeArgs, kernelFunc, getCallableArgHandler());
return pyAltLaunchKernelR(kernelName, module, returnType, args,
callable_names);
},
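`getCallableArgHandler` packs a throwaway allocation for each callable parameter because synthesis inlines the callable before launch, leaving the block argument unused. A loose Python analogue of that packing rule (a sketch under the assumption that callables are fully synthesized away at compile time):

def pack_args(arg_data, args):
    """Pack runtime arguments; kernel-valued arguments get a placeholder
    since the callable itself is synthesized away before launch."""
    for arg in args:
        if hasattr(arg, "module"):  # a PyKernelDecorator
            arg_data.append(0)      # dummy value, never read by the thunk
        else:
            arg_data.append(arg)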
17 changes: 15 additions & 2 deletions python/runtime/cudaq/platform/py_alt_launch_kernel.h
@@ -26,10 +26,23 @@ namespace cudaq {
/// @brief Set current architecture's data layout attribute on a module.
void setDataLayout(MlirModule module);

/// @brief Get the default callable argument handler for packing arguments.
std::function<bool(OpaqueArguments &argData, py::object &arg)>
getCallableArgHandler();

/// @brief Get the names of callable arguments from the given kernel and
/// arguments. As the arguments are scanned, any extra processing required
/// for callable arguments is also performed.
std::vector<std::string> getCallableNames(py::object &kernel, py::args &args);

/// @brief Create a new OpaqueArguments pointer and pack the
/// python arguments in it. Clients must delete the memory.
OpaqueArguments *toOpaqueArgs(py::args &args, MlirModule mod,
const std::string &name);
OpaqueArguments *
toOpaqueArgs(py::args &args, MlirModule mod, const std::string &name,
const std::optional<
std::function<bool(OpaqueArguments &argData, py::object &arg)>>
&optionalBackupHandler = std::nullopt);

inline std::size_t byteSize(mlir::Type ty) {
if (isa<mlir::ComplexType>(ty)) {
30 changes: 29 additions & 1 deletion python/tests/kernel/test_observe_kernel.py
@@ -10,7 +10,7 @@

import pytest
import numpy as np
from typing import List
from typing import List, Callable

import cudaq
from cudaq import spin
@@ -343,3 +343,31 @@ def gqeCirc2(N: int, thetas: list[float], paulis: list[cudaq.pauli_word]):
exp_val2 = cudaq.observe_async(gqeCirc2, obs, numQubits, list(ts),
pauliStings).get().expectation()
print('observe_async exp_val2', exp_val2)


def test_observe_callable():
"""Test that we can observe kernels with callable arguments."""

@cudaq.kernel
def ansatz_callable(angle: float, rotate: Callable[[cudaq.qubit, float],
None]):
q = cudaq.qvector(2)
x(q[0])
rotate(q[1], angle)
x.ctrl(q[1], q[0])

@cudaq.kernel
def ry_rotate(qubit: cudaq.qubit, angle: float):
ry(angle, qubit)

hamiltonian = 5.907 - 2.1433 * spin.x(0) * spin.x(1) - 2.1433 * spin.y(
0) * spin.y(1) + .21829 * spin.z(0) - 6.125 * spin.z(1)

result = cudaq.observe(ansatz_callable, hamiltonian, .59, ry_rotate)
print(result.expectation())
assert np.isclose(result.expectation(), -1.74, atol=1e-2)

result_async = cudaq.observe_async(ansatz_callable, hamiltonian, .59,
ry_rotate).get()
print(result_async.expectation())
assert np.isclose(result_async.expectation(), -1.74, atol=1e-2)