diff --git a/.gitmodules b/.gitmodules index 98c3df68fd..1a19075141 100644 --- a/.gitmodules +++ b/.gitmodules @@ -10,3 +10,7 @@ [submodule "hls4ml/templates/catapult/ac_math"] path = hls4ml/templates/catapult/ac_math url = https://github.com/hlslibs/ac_math.git +[submodule "hls4ml/contrib/Coyote"] + path = hls4ml/contrib/Coyote + url = https://github.com/fpgasystems/Coyote.git + branch = integrations/hls4ml diff --git a/docs/backend/accelerator.rst b/docs/backend/accelerator.rst index 187bccaa2c..64c5d218fa 100644 --- a/docs/backend/accelerator.rst +++ b/docs/backend/accelerator.rst @@ -75,3 +75,70 @@ The ``predict`` method will send the input data to the PL and return the output nn = NeuralNetworkOverlay('hls4ml_nn.bit', X_test.shape, y_test.shape) y_hw, latency, throughput = nn.predict(X_test, profile=True) + + +================= +CoyoteAccelerator +================= + +The **CoyoteAccelerator** backend of ``hls4ml`` leverages the `Coyote shell `_ to easily deploy models on PCIe-attached Alveo FPGAs. +Coyote is an open-source research shell that facilitates the deployment of applications on FPGAs, as well as the integration of FPGAs into larger computer systems. +Some of its features include: +- Multi-tenancy +- Virtualized memory +- Optimized data movement +- Dynamic reconfiguration +- Automatic work scheduling and memory striping +- Networking for distributed applications + +The list of supported boards is available in the `Coyote documentation. `_ +The current Coyote backend can be used to deploy hls4ml models from both Python and C++. While the focus of the current backend is on inference, +it can easily be extended to support dynamic reconfiguration of models, as well as distributed inference across multiple FPGAs. + +CoyoteOverlay +================================ + +Similar to the VivadoAccelerator backend, the Coyote backend creates a custom **neural network overlay** that interacts with the FPGA. +This overlay can be used to provide inputs, run inference and retrieve the predictions. Additionally, the overlay provides a utility +function to load the model bitstream and driver for some clusters. On others, users need to manually load the bitstream and driver. +For guidance, see the `Coyote documentation. `_. + +C++ binary +================================ + +Additionally, the Coyote backend generates and compiles a C++ program that can be used to run inference on the FPGA. +The binary can be found in ``<output_dir>/build/<project_name>_cyt_sw/bin/test`` and, when launched, it will +run inference using the inputs from ``tb_data``. Similar to the Python overlay, the bitstream and driver must be loaded before running the inference. + +Example +====================== + +Similar to the ``VivadoAccelerator`` backend, we first generate a bitstream from a Keras model ``model`` and a config. + +.. code-block:: Python + + import hls4ml + config = hls4ml.utils.config_from_keras_model(model, granularity='name') + hls_model = hls4ml.converters.convert_from_keras_model(model, + hls_config=config, + output_dir='hls4ml_prj_coyote', + backend='CoyoteAccelerator', + board='u55c') + hls_model.build(bitfile=True) + +After this command completes, the FPGA must be programmed with the bitstream. Additionally, the Coyote driver must be loaded. +For some platforms, Coyote provides utility functions to load the bitstream and driver. For others, this can be achieved using +the Vivado hardware manager and Linux commands. More detail can be found in the `Coyote documentation. `_.
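On clusters where Coyote provides such utilities (for instance, the ETH Zurich HACC machines), the ``CoyoteOverlay`` described below also exposes a ``program_hacc_fpga()`` helper that wraps these steps. A minimal sketch, assuming the project was built in ``hls4ml_prj_coyote`` as above:

.. code-block:: Python

    from hls4ml.backends.coyote_accelerator.coyote_accelerator_overlay import CoyoteOverlay

    # Programs build/<project>_cyt_hw/bitstreams/cyt_top.bit and inserts the Coyote driver (HACC only)
    overlay = CoyoteOverlay('hls4ml_prj_coyote')
    overlay.program_hacc_fpga()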
+ +Finally, we can create a ``CoyoteOverlay`` object, which can be used to run inference on the FPGA. Additionally, the overlay provides a utility +function to load the model bitstream and driver for some clusters. +When running inference, we must provide the input tensor and the shape of the output tensor (to allocate the buffers for the data transfer). +Optionally, the batch size can be specified. +The ``predict`` method will send the input data to the FPGA and return the output data ``y_hw``. + +.. code-block:: Python + + from hls4ml.backends.coyote_accelerator.coyote_accelerator_overlay import CoyoteOverlay + + overlay = CoyoteOverlay('hls4ml_prj_coyote') + y_hw = overlay.predict(x, (1, ), BATCH_SIZE) diff --git a/hls4ml/backends/__init__.py b/hls4ml/backends/__init__.py index 4a48f072cd..2f2870b14a 100644 --- a/hls4ml/backends/__init__.py +++ b/hls4ml/backends/__init__.py @@ -11,6 +11,8 @@ from hls4ml.backends.vitis.vitis_backend import VitisBackend # isort: skip +from hls4ml.backends.coyote_accelerator.coyote_accelerator_backend import CoyoteAcceleratorBackend + register_backend('Vivado', VivadoBackend) register_backend('VivadoAccelerator', VivadoAcceleratorBackend) register_backend('Vitis', VitisBackend) @@ -18,3 +20,4 @@ register_backend('Catapult', CatapultBackend) register_backend('SymbolicExpression', SymbolicExpressionBackend) register_backend('oneAPI', OneAPIBackend) +register_backend('CoyoteAccelerator', CoyoteAcceleratorBackend) diff --git a/hls4ml/backends/coyote_accelerator/__init__.py b/hls4ml/backends/coyote_accelerator/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/backends/coyote_accelerator/coyote_accelerator_backend.py b/hls4ml/backends/coyote_accelerator/coyote_accelerator_backend.py new file mode 100644 index 0000000000..f5908f960e --- /dev/null +++ b/hls4ml/backends/coyote_accelerator/coyote_accelerator_backend.py @@ -0,0 +1,150 @@ +import os +import subprocess +from hls4ml.model.flow import get_flow, register_flow +from hls4ml.backends import VitisBackend, VivadoBackend + +class CoyoteAcceleratorBackend(VitisBackend): + """ + The CoyoteAccelerator backend, which deploys hls4ml models on a PCIe-attached Alveo FPGA. + Underneath, it uses the Coyote shell: https://github.com/fpgasystems/Coyote, + which offers high-performance data movement, networking capabilities, multi-tenancy, + partial reconfiguration, etc. This backend has some similarities with the VitisAccelerator + backend, but the underlying platforms are different. The implementation of this backend + remains mostly simple, inheriting most of the functionality from the Vitis backend and + providing the necessary infrastructure to run model inference on Alveo boards. + + Currently, this backend supports batched inference of a single model on hardware.
+ In the future, it can easily be extended with the following capabilities, leveraging + Coyote's features: + - Distributed inference + - Multiple parallel instances of hls4ml models (same or distinct models) + - Dynamic, run-time reconfiguration of models + + Generic examples of Coyote can be found in the above-mentioned repository, under examples/ + """ + + def __init__(self): + super(VivadoBackend, self).__init__(name='CoyoteAccelerator') + self._register_layer_attributes() + self._register_flows() + + def _register_flows(self): + writer_passes = ['make_stamp', 'coyoteaccelerator:write_hls'] + self._writer_flow = register_flow('write', writer_passes, requires=['vitis:ip'], backend=self.name) + + ip_flow_requirements = get_flow('vitis:ip').requires.copy() + self._default_flow = register_flow('ip', None, requires=ip_flow_requirements, backend=self.name) + + def compile(self, model): + """ + Compiles the hls4ml model for software emulation + + Args: + model (ModelGraph): hls4ml model to synthesize + + Returns: + lib_name (str): The name of the compiled library + """ + lib_name = None + ret_val = subprocess.run( + ['./build_lib.sh'], + shell=True, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd=model.config.get_output_dir(), + ) + if ret_val.returncode != 0: + print(ret_val.stdout) + raise Exception(f'Failed to compile project "{model.config.get_project_name()}"') + lib_name = '{}/build/{}-{}.so'.format( + model.config.get_output_dir(), model.config.get_project_name(), model.config.get_config_value('Stamp') + ) + + return lib_name + + def build( + self, + model, + device: str = 'u55c', + reset: bool = False, + csim: bool = True, + synth: bool = True, + cosim: bool = False, + validation: bool = False, + csynth: bool = False, + bitfile: bool = False, + timing_opt: bool = False, + hls_clock_period: float = 4, + hls_clock_uncertainty: float = 27 + ): + """ + Synthesizes the hls4ml model bitstream as part of the Coyote shell + and compiles the host-side software to control the FPGA and run model inference + + Args: + model (ModelGraph): hls4ml model to synthesize + device (str, optional): Target Alveo FPGA card; currently supported: u55c, u280 and u250 + reset (bool, optional): Reset HLS project, if a previous one is found + csim (bool, optional): Run C-Simulation of the HLS project + synth (bool, optional): Run HLS synthesis + cosim (bool, optional): Run HLS co-simulation + validation (bool, optional): Validate results between C-Sim and Co-Sim + csynth (bool, optional): Run Coyote synthesis using Vivado, which will synthesize the model in a vFPGA + bitfile (bool, optional): Generate Coyote bitstream + timing_opt (bool, optional): Run additional optimizations when running PnR during bitstream generation + hls_clock_period (float, optional): Clock period to be used for HLS synthesis + hls_clock_uncertainty (float, optional): Clock uncertainty to be used for HLS synthesis + + NOTE: Currently, the hardware will synthesize with a default clock period of 4ns / 250 MHz frequency, + since this is the default frequency of Coyote (since the XDMA core defaults to 250 MHz). Coyote allows + one to specify a different clock period for the model and use a clock-domain crossing (CDC) between the + XDMA region and the model. This option is currently not exposed as part of the hls4ml backend, but advanced + users can easily set it in the CMake configuration of Coyote.
+ + NOTE: While the hardware will synthesize at 250 MHz, users can optionally pass a different HLS clock period. + This is primarily a work-around for when HLS synthesizes a kernel that doesn't meet timing during PnR. + The "trick" is to run HLS synthesis at a higher clock frequency than the hardware target (or to provide a higher clock uncertainty). + + TODO: Add functionality to parse synthesis reports + """ + curr_dir = os.getcwd() + + # Synthesize hardware + cmake_cmd = ( + f'cmake ../../ ' + f'-DFLOW=hw ' + f'-DFDEV_NAME={device} ' + f'-DBUILD_OPT={int(timing_opt)} ' + f'-DEN_HLS_RESET={int(reset)} ' + f'-DEN_HLS_CSIM={int(csim)} ' + f'-DEN_HLS_SYNTH={int(synth)} ' + f'-DEN_HLS_COSIM={int(cosim)} ' + f'-DEN_HLS_VALIDATION={int(validation)} ' + f'-DHLS_CLOCK_PERIOD={hls_clock_period} ' + f'-DHLS_CLOCK_UNCERTAINTY="{str(hls_clock_uncertainty)}%"' + ) + + if not os.path.exists(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_hw'): + os.makedirs(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_hw') + os.chdir(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_hw') + os.system(cmake_cmd) + + if bitfile: + os.system('make project && make bitgen') + elif csynth: + os.system('make project && make synth') + else: + os.system('make project') + + os.chdir(curr_dir) + + # Compile host software + cmake_cmd = 'cmake ../../ -DFLOW=sw' + if not os.path.exists(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_sw'): + os.makedirs(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_sw') + os.chdir(f'{model.config.get_output_dir()}/build/{model.config.get_project_name()}_cyt_sw') + os.system(cmake_cmd) + os.system('make') + os.chdir(curr_dir) + diff --git a/hls4ml/backends/coyote_accelerator/coyote_accelerator_overlay.py b/hls4ml/backends/coyote_accelerator/coyote_accelerator_overlay.py new file mode 100644 index 0000000000..12b56bf762 --- /dev/null +++ b/hls4ml/backends/coyote_accelerator/coyote_accelerator_overlay.py @@ -0,0 +1,104 @@ +import os +import time +import ctypes +import logging +import numpy as np + +class CoyoteOverlay: + """ + CoyoteOverlay class, similar to NeuralNetworkOverlay for the VivadoAccelerator backend. + This class can be used to run model inference on the FPGA with the CoyoteAccelerator backend + """ + def __init__(self, path: str, project_name: str = 'myproject'): + """ + Default constructor + + Args: + path (str): Path to the hls4ml folder, as specified in convert_model(...)
+ project_name (str, optional): hls4ml model name, if different than myproject + """ + + self.path = path + self.project_name = project_name + + # Set up dynamic C library + self.coyote_lib = ctypes.cdll.LoadLibrary( + f'{self.path}/build/{self.project_name}_cyt_sw/lib/libCoyoteInference.so' + ) + + self.coyote_lib.init_model_inference.argtypes = [ctypes.c_uint, ctypes.c_uint, ctypes.c_uint] + self.coyote_lib.init_model_inference.restype = ctypes.POINTER(ctypes.c_void_p) + + self.coyote_lib.flush.argtypes = [ctypes.POINTER(ctypes.c_void_p)] + self.coyote_lib.predict.argtypes = [ctypes.POINTER(ctypes.c_void_p)] + + self.coyote_lib.get_inference_predictions.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_uint] + self.coyote_lib.get_inference_predictions.restype = ctypes.POINTER(ctypes.c_float) + + self.coyote_lib.free_model_inference.argtypes = [ctypes.POINTER(ctypes.c_void_p)] + + def program_hacc_fpga(self): + """ + Utility function for loading the Coyote-hls4ml bitstream and driver + on the ETH Zurich Heterogeneous Accelerated Compute Cluster (HACC). + On other clusters, users would need to manually load the bitstream and driver. + Guidance on this is given in the Coyote docs. + """ + os.system( + f'cd {self.path}/Coyote/driver && ' + f'make && ' + f'cd ../util && ' + f'bash program_hacc_local.sh ../../build/{self.project_name}_cyt_hw/bitstreams/cyt_top.bit ../driver/build/coyote_driver.ko' + ) + + def predict(self, X: np.array, y_shape: tuple, batch_size: int = 1): + """ + Run model inference + + Args: + X (np.array): Input data + y_shape (tuple): Shape of the output; used for allocating sufficient memory for the output + batch_size (int, optional): Inference batch size + """ + if len(X.shape) == 1: + X = np.array([X]) + if X.dtype != np.float32: + logging.warning('CoyoteOverlay only supports (for now) float32 inputs; casting input data to float32') + X = X.astype(np.float32) + y = np.empty((len(X), *y_shape)) + np_pointer_nd = np.ctypeslib.ndpointer(dtype=np.float32, ndim=len(X[0].shape), flags='C') + self.coyote_lib.set_inference_data.argtypes = [ctypes.POINTER(ctypes.c_void_p), np_pointer_nd, ctypes.c_uint] + + model = self.coyote_lib.init_model_inference(batch_size, int(np.prod(X[0].shape)), int(np.prod(y_shape))) + + cnt = 0 + avg_latency = 0 + avg_throughput = 0 + total_batches = 0 + for x in X: + self.coyote_lib.set_inference_data(model, x, cnt) + cnt += 1 + if cnt == batch_size: + self.coyote_lib.flush(model) + + ts = time.time_ns() + self.coyote_lib.predict(model) + te = time.time_ns() + + time_taken = te - ts + avg_latency += (time_taken / 1e3) + avg_throughput += (batch_size / (time_taken * 1e-9)) + + for j in range(batch_size): + tmp = self.coyote_lib.get_inference_predictions(model, j) + y[total_batches * batch_size + j] = np.ctypeslib.as_array(tmp, shape=y_shape) + + cnt = 0 + total_batches += 1 + + self.coyote_lib.free_model_inference(model) + print(f'Batch size: {batch_size}; batches processed: {total_batches}') + print(f'Mean latency: {round(avg_latency / total_batches, 3)}us (inference only)') + print(f'Mean throughput: {round(avg_throughput / total_batches, 1)} samples/s (inference only)') + + return y \ No newline at end of file diff --git a/hls4ml/backends/coyote_accelerator/passes/__init__.py b/hls4ml/backends/coyote_accelerator/passes/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hls4ml/contrib/Coyote b/hls4ml/contrib/Coyote new file mode 160000 index 0000000000..292ec1521c --- 
/dev/null +++ b/hls4ml/contrib/Coyote @@ -0,0 +1 @@ +Subproject commit 292ec1521c4a9a1cc9b1335dee6b99deabb38542 diff --git a/hls4ml/templates/coyote_accelerator/CMakeLists.txt b/hls4ml/templates/coyote_accelerator/CMakeLists.txt new file mode 100644 index 0000000000..63bfede764 --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.5) +set(CYT_DIR ${CMAKE_SOURCE_DIR}/Coyote/) +set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} ${CYT_DIR}/cmake) +find_package(CoyoteHW REQUIRED) +find_package(CoyoteSW REQUIRED) + +set(FLOW "hw" CACHE STRING "Synthesize hardware (hw) or host software (sw)") + +if(FLOW STREQUAL "hw") + project(myproject) + set(EN_STRM 1) + set(N_STRM_AXI 1) + set(N_REGIONS 1) + + validation_checks_hw() + load_apps ( + VFPGA_C0_0 "src" + ) + create_hw() +endif() + +if(FLOW STREQUAL "sw") + project( + CoyoteInference + VERSION 1.0.0 + DESCRIPTION "CoyoteInference library" + ) + set(CYT_INCLUDE_PATH ${CYT_DIR}/sw/include) + add_library(CoyoteInference SHARED "${CMAKE_SOURCE_DIR}/src/host_libs.cpp" "${CMAKE_SOURCE_DIR}/src/host_libs.hpp") + target_include_directories(CoyoteInference PUBLIC ${CYT_INCLUDE_PATH}) + target_link_libraries(CoyoteInference PUBLIC Coyote) + target_link_directories(CoyoteInference PUBLIC /usr/local/lib) + + project(myproject) + set(EXEC test) + set(TARGET_DIR "${CMAKE_SOURCE_DIR}/src/") + add_executable(${EXEC} ${TARGET_DIR}/myproject_host.cpp) + target_link_libraries(${EXEC} PUBLIC Coyote) + target_link_libraries(${EXEC} PUBLIC CoyoteInference) + target_link_directories(${EXEC} PUBLIC /usr/local/lib) + target_include_directories(${EXEC} PUBLIC src/hls/model_wrapper/firmware/) + target_include_directories(${EXEC} PUBLIC src/hls/model_wrapper/firmware/ap_types) + +endif() \ No newline at end of file diff --git a/hls4ml/templates/coyote_accelerator/build_lib.sh b/hls4ml/templates/coyote_accelerator/build_lib.sh new file mode 100755 index 0000000000..57ce75e2dc --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/build_lib.sh @@ -0,0 +1,23 @@ +#!/bin/bash +set -e + +CC=g++ +if [[ "$OSTYPE" == "linux-gnu" ]]; then + CFLAGS="-O3 -fPIC -std=c++11 -fno-gnu-unique" +elif [[ "$OSTYPE" == "darwin"* ]]; then + CFLAGS="-O3 -fPIC -std=c++11" +fi + +PROJECT=myproject +LIB_STAMP=mystamp + +BASE_DIR="$(cd "$(dirname "$0")" && pwd)"/src +BUILD_DIR="$(cd "$(dirname "$0")" && pwd)"/build +INC_FLAGS="-Isrc/hls/model_wrapper/firmware/ap_types/ -Isrc/hls/model_wrapper/" +WEIGHTS_DIR="\"${BASE_DIR}/hls/model_wrapper/firmware/weights\"" + +mkdir -p ${BUILD_DIR} +${CC} ${CFLAGS} ${INC_FLAGS} -D WEIGHTS_DIR="${WEIGHTS_DIR}" -c ${BASE_DIR}/hls/model_wrapper/firmware/${PROJECT}.cpp -o ${BUILD_DIR}/${PROJECT}.o +${CC} ${CFLAGS} ${INC_FLAGS} -D WEIGHTS_DIR="${WEIGHTS_DIR}" -c ${BASE_DIR}/${PROJECT}_bridge.cpp -o ${BUILD_DIR}/${PROJECT}_bridge.o +${CC} ${CFLAGS} ${INC_FLAGS} -shared ${BUILD_DIR}/${PROJECT}.o ${BUILD_DIR}/${PROJECT}_bridge.o -o ${BUILD_DIR}/${PROJECT}-${LIB_STAMP}.so +rm -f ${BUILD_DIR}/*.o diff --git a/hls4ml/templates/coyote_accelerator/host_libs.cpp b/hls4ml/templates/coyote_accelerator/host_libs.cpp new file mode 100644 index 0000000000..71f2f5c35c --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/host_libs.cpp @@ -0,0 +1,80 @@ +#include "host_libs.hpp" + +CoyoteInference::CoyoteInference(unsigned int batch_size, unsigned int in_size, unsigned int out_size): + batch_size(batch_size), in_size(in_size), out_size(out_size), + coyote_thread(DEFAULT_VFPGA_ID, getpid()) +{ + for (unsigned int i = 0; i < 
batch_size; i++) { + // Allocate memory using huge pages (HPF) for input and output tensors + src_mems.emplace_back((float *) coyote_thread.getMem({coyote::CoyoteAllocType::HPF, (uint) (in_size * sizeof(float))})); + dst_mems.emplace_back((float *) coyote_thread.getMem({coyote::CoyoteAllocType::HPF, (uint) (out_size * sizeof(float))})); + if (!src_mems[i] || !dst_mems[i]) { throw std::runtime_error("Could not allocate memory; exiting..."); } + + // Create scatter-gather entry for this input/output pair + coyote::localSg src_sg = { .addr = src_mems[i], .len = (uint) (in_size * sizeof(float))}; + coyote::localSg dst_sg = { .addr = dst_mems[i], .len = (uint) (out_size * sizeof(float))}; + src_sgs.emplace_back(src_sg); + dst_sgs.emplace_back(dst_sg); + } +} + +CoyoteInference::~CoyoteInference() {} + +void CoyoteInference::flush() { + // Reset output tensors to zero + for (unsigned int i = 0; i < batch_size; i++) { + memset(dst_mems[i], 0, out_size * sizeof(float)); + } + + // Clear completion counters + coyote_thread.clearCompleted(); +} + +void CoyoteInference::predict() { + // Coyote uses the so-called invoke function to run operations in vFPGAs. + // In this case, the operation is LOCAL_TRANSFER, and the flow of data is: + // host memory (input data) => vFPGA (hls4ml model) => host memory (output data) + for (unsigned int i = 0; i < batch_size; i++) { + coyote_thread.invoke(coyote::CoyoteOper::LOCAL_TRANSFER, src_sgs[i], dst_sgs[i]); + } + + // Poll on completion; each completed transfer increments the counter by one + while (coyote_thread.checkCompleted(coyote::CoyoteOper::LOCAL_TRANSFER) != batch_size) {} +} + +void CoyoteInference::set_data(float *x, unsigned int i) { + // Simply copy from one buffer to the other + for (unsigned int j = 0; j < in_size; j++) { + src_mems[i][j] = x[j]; + } +} + +float* CoyoteInference::get_predictions(unsigned int i) { return dst_mems[i]; } + +// C API for the CoyoteInference class, so that it can be used from Python or other languages +// A better option would be to use something like pybind11, but the implementation is simple enough for now.
+extern "C" { + CoyoteInference* init_model_inference(unsigned int batch_size, unsigned int in_size, unsigned int out_size) { + return new CoyoteInference(batch_size, in_size, out_size); + } + + void free_model_inference(CoyoteInference* obj) { + delete obj; + } + + void flush(CoyoteInference* obj) { + obj->flush(); + } + + void predict(CoyoteInference* obj) { + obj->predict(); + } + + void set_inference_data(CoyoteInference* obj, float *x, unsigned int i) { + obj->set_data(x, i); + } + + float* get_inference_predictions(CoyoteInference* obj, unsigned int i) { + return obj->get_predictions(i); + } +} diff --git a/hls4ml/templates/coyote_accelerator/host_libs.hpp b/hls4ml/templates/coyote_accelerator/host_libs.hpp new file mode 100644 index 0000000000..571cf2f72f --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/host_libs.hpp @@ -0,0 +1,107 @@ +#ifndef HOST_LIBS_HPP_ +#define HOST_LIBS_HPP_ + +#include <vector> +#include "cOps.hpp" +#include "cThread.hpp" + +// Coyote uses so-called vFPGAs: individual applications running in parallel on the FPGA +// Users can deploy multiple vFPGAs on the same hardware, each with its own application +// For now, the CoyoteAccelerator only supports a single vFPGA, though future extensions +// could easily allow multiple parallel instances of hls4ml models +#define DEFAULT_VFPGA_ID 0 + +/** + * @brief Utility class for running inference of an hls4ml model with the Coyote accelerator backend + * + * This class can be used to set up and execute the inference, by allocating memory for the tensors, + * running the inference, and retrieving predictions. It abstracts away all the interaction with the + * Coyote software library, which in turn abstracts away the interaction with the hardware. + * This class assumes some familiarity with the Coyote software library; examples of its use + * can be found in the GitHub examples: https://github.com/fpgasystems/Coyote/tree/master/examples. + * + * NOTE: This class can be linked into a shared library and called from the Python overlay (CoyoteOverlay) or + * it can be instantiated stand-alone in C++ code. + * + * NOTE: The functions set_data, predict and get_predictions are separated, simply to be able to obtain granular + * measurements of how long each step takes. One could easily combine them into a single function. + + * NOTE: There is a difference between XRT (VitisAccelerator backend) and Coyote: in XRT it is necessary + * to sync the input data from the host memory to device memory (HBM/DDR) before running the inference. + * On the other hand, Coyote implements a shared virtual memory model, and the shell will automatically + * fetch data from host memory and feed it to the model kernel, fully bypassing device memory. However, + * we still have a function set_data that essentially copies data from one host-side array (e.g., NumPy) to + * an array that's a member variable of this class. This is not necessary and Coyote could equally work + * with the NumPy array, but it makes it easier to manage multiple batches. Future optimizations could fix + * this, if desired.
For more details on Coyote's memory model, refer to the paper: https://arxiv.org/abs/2504.21538 + */ +class CoyoteInference { +public: + /** + * @brief Constructor for CoyoteInference + * @param batch_size Number of samples in a batch + * @param in_size Size of the input tensor (in elements) + * @param out_size Size of the output tensor (in elements) + * + * NOTE: The batch size is not a hardware/synthesis parameter, but rather a runtime parameter. + * Coyote supports asynchronous execution of requests, so the software can invoke multiple + * inputs, as specified by the batch size, and the hardware handles the scheduling, any back-pressure etc. + */ + CoyoteInference(unsigned int batch_size, unsigned int in_size, unsigned int out_size); + + /// Default destructor + ~CoyoteInference(); + + /** + * @brief Utility function, clears completion counters in Coyote and resets output tensors to zero + */ + void flush(); + + /** + * @brief Runs inference on the input tensors, specified by set_data + */ + void predict(); + + /** + * @brief Set the input data for a specific entry of the batch + * + * @param x Pointer to the input data (array of floats) + * @param i Index of the batch entry to set data for + */ + void set_data(float *x, unsigned int i); + + /** + * @brief Returns the i-th prediction of a batch + * + * @param i Index of the batch entry to get predictions for + * @return Pointer to the output predictions (array of floats) + */ + float* get_predictions(unsigned int i); + +private: + + unsigned int batch_size, in_size, out_size; + + /** + * @brief Coyote thread for inference + * + * Coyote uses so-called threads to interact with the FPGA, which include + * high-level functions for moving data, setting control registers, + * polling on completions etc. + */ + coyote::cThread coyote_thread; + + /** + * @brief Coyote scatter-gather entries + * + * Scatter-gather entries are used to specify the source and destination + * addresses and lengths for data transfers between host memory and the FPGA. + * In this case, they point to the input and output tensors for each batch entry. + */ + std::vector<coyote::localSg> src_sgs, dst_sgs; + + /// Memory pointers for input and output tensors (one per batch entry) + std::vector<float*> src_mems, dst_mems; +}; + +#endif diff --git a/hls4ml/templates/coyote_accelerator/model_wrapper.cpp b/hls4ml/templates/coyote_accelerator/model_wrapper.cpp new file mode 100644 index 0000000000..2cf950ba2a --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/model_wrapper.cpp @@ -0,0 +1,32 @@ +#include "model_wrapper.hpp" + +/** + * @brief A wrapper for an hls4ml model deployed with Coyote. + * + * In Coyote, data is passed through 512-bit AXI streams; the data can originate + * from host or card memory, or the network (from other nodes). The model wrapper + * encapsulates the hls4ml model and converter functions that convert beats from + * 512-bit AXI streams to the model's input format (depending on whether io_parallel or io_stream is used) + * and vice-versa for the output. Importantly, when running the Coyote accelerator backend and + * moving data from/to the host, it is packed as float32 to the 512-bit AXI stream. That is, + * each AXI beat (.tvalid asserted) contains 16 float32 values. The reason for this is two-fold: + * (1) the predict function inherently works with float32 data, and (2) when moving data between + * the host and the accelerator, one must specify the size of the buffer moved.
While it's perfectly + * possible to "emulate" ap_fixed on the host and convert the float32 to ap_fixed, it is unclear + * what the exact size/alignment etc. of the buffer will be on the host (e.g., ap_fixed cannot + * possibly be 1 bit in a "conventional" OS, so some padding would almost certainly be added; this + * padding will then have to be removed by the model_wrapper, which could be error-prone). + */ +void model_wrapper ( + hls::stream<axi_s> &data_in, + hls::stream<axi_s> &data_out +) { + #pragma HLS INTERFACE ap_ctrl_none port=return + #pragma HLS INTERFACE axis register port=data_in name=data_in + #pragma HLS INTERFACE axis register port=data_out name=data_out + + // hls-fpga-machine-learning insert data + + // hls-fpga-machine-learning insert top-level function + +} diff --git a/hls4ml/templates/coyote_accelerator/model_wrapper.hpp b/hls4ml/templates/coyote_accelerator/model_wrapper.hpp new file mode 100644 index 0000000000..e0813002ce --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/model_wrapper.hpp @@ -0,0 +1,19 @@ +#ifndef MODEL_WRAPPER_HPP_ +#define MODEL_WRAPPER_HPP_ + +#include "hls_stream.h" +#include "ap_axi_sdata.h" + +#define COYOTE_AXI_STREAM_BITS 512 +typedef ap_axiu axi_s; + +#include "firmware/myproject.h" +#include "firmware/nnet_utils/nnet_axi_utils.h" +#include "firmware/nnet_utils/nnet_axi_utils_stream.h" + +void model_wrapper ( + hls::stream<axi_s> &data_in, + hls::stream<axi_s> &data_out +); + +#endif diff --git a/hls4ml/templates/coyote_accelerator/myproject_host.cpp b/hls4ml/templates/coyote_accelerator/myproject_host.cpp new file mode 100644 index 0000000000..e17ada711b --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/myproject_host.cpp @@ -0,0 +1,113 @@ +/** + * @brief myproject_host.cpp + * + * This file is a stand-alone C++ program that can be used to run inference of an hls4ml + * model with Coyote. The alternative way is to use the CoyoteOverlay from Python. + * Both of these rely on the CoyoteInference class from the host_libs.hpp file. + * The format of this script is largely similar to myproject_test.cpp (i.e. it reads the + * inputs and outputs from some files and runs inference), but adapted to run on an FPGA.
+ */ + +#include +#include +#include +#include +#include + +#include "defines.h" +#include "host_libs.hpp" + +#include + +std::string default_path("../../tb_data/"); + +int main(int argc, char **argv) { + std::string data_path; + unsigned int batch_size; + + boost::program_options::options_description runtime_options("Coyote hls4ml run-time options"); + runtime_options.add_options() + ("batch_size,b", boost::program_options::value(&batch_size)->default_value(1), "Inference batch size") + ("data_path,p", boost::program_options::value(&data_path)->default_value(default_path), "Path to tb_data folder with input/output features for validation"); + boost::program_options::variables_map command_line_arguments; + boost::program_options::store(boost::program_options::parse_command_line(argc, argv, runtime_options), command_line_arguments); + boost::program_options::notify(command_line_arguments); + + // hls-fpga-machine-learning insert I/O size + + CoyoteInference model(batch_size, in_size, out_size); + + std::string iline; + std::string pline; + std::ifstream fin(data_path + "/tb_input_features.dat"); + std::ifstream fpr(data_path + "/csim_results.log"); + + if (fin.is_open() && fpr.is_open()) { + int cnt = 0; + int total_batches = 0; + double avg_latency = 0; + double avg_throughput = 0; + std::vector> labels; + + while (std::getline(fin, iline) && std::getline(fpr, pline)) { + // Read inputs and outputs from tb_data folder + char *current; + std::vector in, pr; + + char *cstr = const_cast(iline.c_str()); + current = strtok(cstr, " "); + while (current != NULL) { + in.push_back(atof(current)); + current = strtok(NULL, " "); + } + cstr = const_cast(pline.c_str()); + current = strtok(cstr, " "); + while (current != NULL) { + pr.push_back(atof(current)); + current = strtok(NULL, " "); + } + + // Set model data for the i-th point in the batch + model.set_data(&in[0], cnt); + labels.push_back(pr); + cnt++; + + // If batch is full, run inference, measuring time + if (cnt == batch_size) { + model.flush(); + + auto begin_time = std::chrono::high_resolution_clock::now(); + model.predict(); + auto end_time = std::chrono::high_resolution_clock::now(); + double time = std::chrono::duration_cast(end_time - begin_time).count(); + avg_latency += (time / 1e3); + avg_throughput += (batch_size / (time * 1e-9)); + + // Functional correctness + for (int i = 0; i < batch_size; i++) { + float *pred = model.get_predictions(i); + for (int j = 0; j < out_size; j++) { + assert(int(10000.0 * labels[i][j]) == int(10000.0 * pred[j])); + } + } + + // Reset for next batch + total_batches++; + labels.clear(); + cnt = 0; + } + + } + + std::cout << "Batches processed: " << total_batches << std::endl; + std::cout << "Average latency: " << avg_latency / (double) total_batches << " us" << std::endl; + std::cout << "Average throughput: " << avg_throughput / (double) total_batches << " inferences/s" << std::endl; + + fin.close(); + fpr.close(); + } else { + std::cout << "Couldn't open input/output file; make sure data_path is set correctly!" << std::endl; + } + + return EXIT_SUCCESS; +} diff --git a/hls4ml/templates/coyote_accelerator/myproject_test.cpp b/hls4ml/templates/coyote_accelerator/myproject_test.cpp new file mode 100644 index 0000000000..a48d6f571f --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/myproject_test.cpp @@ -0,0 +1,98 @@ +/** + * @brief myproject_test.cpp + * + * HLS CSim and CoSim testbench file. 
Largely similar to Vitis/Vivado backends testbench, + * but adapted to call the model_wrapper rather than the model directly. + */ + + +#include +#include +#include + +#include "hls_stream.h" +#include "ap_axi_sdata.h" + +#include "model_wrapper.hpp" +#include "firmware/myproject.h" +#include "firmware/nnet_utils/nnet_helpers.h" +#include "firmware/nnet_utils/nnet_axi_utils.h" +#include "firmware/nnet_utils/nnet_axi_utils_stream.h" + +#define CHECKPOINT 5000 + +#define COYOTE_AXI_STREAM_BITS 512 +typedef ap_axiu axi_s; + +int main(int argc, char **argv) { + std::ifstream fin("tb_data/tb_input_features.dat"); + std::ifstream fpr("tb_data/tb_output_predictions.dat"); + + #ifdef RTL_SIM + std::string RESULTS_LOG = "tb_data/rtl_cosim_results.log"; + #else + std::string RESULTS_LOG = "tb_data/csim_results.log"; + #endif + std::ofstream fout(RESULTS_LOG); + + std::string iline; + std::string pline; + int e = 0; + + if (fin.is_open() && fpr.is_open()) { + while (std::getline(fin, iline) && std::getline(fpr, pline)) { + if (e % CHECKPOINT == 0) { + std::cout << "Processing input " << e << std::endl; + } + char *cstr = const_cast(iline.c_str()); + char *current; + std::vector in; + current = strtok(cstr, " "); + while (current != NULL) { + in.push_back(atof(current)); + current = strtok(NULL, " "); + } + cstr = const_cast(pline.c_str()); + std::vector pr; + current = strtok(cstr, " "); + while (current != NULL) { + pr.push_back(atof(current)); + current = strtok(NULL, " "); + } + + // hls-fpga-machine-learning insert data + + // hls-fpga-machine-learning insert top-level-function + + if (e % CHECKPOINT == 0) { + std::cout << "Predictions" << std::endl; + // hls-fpga-machine-learning insert predictions + + std::cout << "Quantized predictions" << std::endl; + // hls-fpga-machine-learning insert quantized + } + e++; + + // hls-fpga-machine-learning insert tb-output + } + fin.close(); + fpr.close(); + } else { + std::cout << "INFO: Unable to open input/predictions file, using default input." 
<< std::endl; + const unsigned NUM_TEST_SAMPLES = 5; + for (unsigned i = 0; i < NUM_TEST_SAMPLES; i++) { + // hls-fpga-machine-learning insert zero + + // hls-fpga-machine-learning insert top-level-function + + // hls-fpga-machine-learning insert output + + // hls-fpga-machine-learning insert tb-output + } + } + + fout.close(); + std::cout << "INFO: Saved inference results to file: " << RESULTS_LOG << std::endl; + + return 0; +} diff --git a/hls4ml/templates/coyote_accelerator/nnet_utils/nnet_axi_utils.h b/hls4ml/templates/coyote_accelerator/nnet_utils/nnet_axi_utils.h new file mode 100644 index 0000000000..5a0b6ed152 --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/nnet_utils/nnet_axi_utils.h @@ -0,0 +1,90 @@ +#ifndef NNET_AXI_UTILS_H_ +#define NNET_AXI_UTILS_H_ + +#include "ap_axi_sdata.h" + +namespace nnet { + +// Converts an array of data (fixed-point numbers) into 512-bit AXI stream packets; see model_wrapper.hpp for usage +template +void data_to_axi_stream(array_T data_in[SIZE], hls::stream> &axi_out) { + #pragma HLS INLINE OFF + #pragma HLS PIPELINE + + constexpr const unsigned int ELEMENTS_PER_AXI = AXI_BITS / PRECISION; + constexpr const unsigned int NUM_BEATS = (SIZE + ELEMENTS_PER_AXI - 1) / ELEMENTS_PER_AXI; + + for (unsigned int i = 0; i < NUM_BEATS; i++) { + if (i == NUM_BEATS - 1) { + ap_axiu axi_packet; + unsigned int index = i * ELEMENTS_PER_AXI; + + for (unsigned int j = 0; j < SIZE - index; j++) { + #pragma HLS UNROLL + + axi_T axi_tmp = axi_T(data_in[index + j]); + ap_uint axi_bits = *reinterpret_cast*>(&axi_tmp); + axi_packet.data.range((j + 1) * PRECISION - 1, j * PRECISION) = axi_bits; + } + + axi_packet.last = 1; + axi_out.write(axi_packet); + + } else { + ap_axiu axi_packet; + unsigned int index = i * ELEMENTS_PER_AXI; + + for (unsigned int j = 0; j < ELEMENTS_PER_AXI; j++) { + #pragma HLS UNROLL + + axi_T axi_tmp = axi_T(data_in[index + j]); + ap_uint axi_bits = *reinterpret_cast*>(&axi_tmp); + axi_packet.data.range((j + 1) * PRECISION - 1, j * PRECISION) = axi_bits; + } + + axi_packet.last = 0; + axi_out.write(axi_packet); + } + } +} + +// Unpacks beats of 512-bit AXI beats into an array of data (fixed-point numbers) see model_wrapper.hpp for usage +template +void axi_stream_to_data(hls::stream> &axi_in, array_T data_out[SIZE]) { + #pragma HLS INLINE OFF + #pragma HLS PIPELINE + + constexpr const unsigned int ELEMENTS_PER_AXI = AXI_BITS / PRECISION; + constexpr const unsigned int NUM_BEATS = (SIZE + ELEMENTS_PER_AXI - 1) / ELEMENTS_PER_AXI; + + for (unsigned int i = 0; i < NUM_BEATS; i++) { + if (i == NUM_BEATS - 1) { + unsigned int index = i * ELEMENTS_PER_AXI; + ap_axiu axi_packet = axi_in.read(); + + for (unsigned int j = 0; j < SIZE - index; j++) { + #pragma HLS UNROLL + + ap_uint axi_bits = axi_packet.data.range((j + 1) * PRECISION - 1, j * PRECISION); + axi_T axi_tmp = *reinterpret_cast(&axi_bits); + data_out[index + j] = array_T(axi_tmp); + } + + } else { + unsigned int index = i * ELEMENTS_PER_AXI; + ap_axiu axi_packet = axi_in.read(); + + for (unsigned int j = 0; j < ELEMENTS_PER_AXI; j++) { + #pragma HLS UNROLL + + ap_uint axi_bits = axi_packet.data.range((j + 1) * PRECISION - 1, j * PRECISION); + axi_T axi_tmp = *reinterpret_cast(&axi_bits); + data_out[index + j] = array_T(axi_tmp); + } + } + } +} + +} + +#endif \ No newline at end of file diff --git a/hls4ml/templates/coyote_accelerator/nnet_utils/nnet_axi_utils_stream.h b/hls4ml/templates/coyote_accelerator/nnet_utils/nnet_axi_utils_stream.h new file mode 100644 index 0000000000..7b8d16a5ae 
--- /dev/null +++ b/hls4ml/templates/coyote_accelerator/nnet_utils/nnet_axi_utils_stream.h @@ -0,0 +1,77 @@ +#ifndef NNET_AXI_UTILS_STREAM_H +#define NNET_AXI_UTILS_STREAM_H + +#include "ap_axi_sdata.h" + +namespace nnet { + +// Converts an stream of data (fixed-point numbers) into 512-bit AXI stream packets; see model_wrapper.hpp for usage +template +void data_to_axi_stream(hls::stream &data_in, hls::stream> &axi_out) { + #pragma HLS INLINE OFF + #pragma HLS PIPELINE + + constexpr const unsigned int ELEMENTS_PER_AXI = (SIZE <= (AXI_BITS / PRECISION)) ? SIZE : (AXI_BITS / PRECISION); + constexpr const unsigned int NUM_BEATS = SIZE / ELEMENTS_PER_AXI + (SIZE % ELEMENTS_PER_AXI != 0); + + unsigned int index = 0; + ap_axiu axi_packet; + + for (int i = 0; i < SIZE / array_T::size; i++) { + array_T in_data = data_in.read(); + + for (int j = 0; j < array_T::size; j++) { + #pragma HLS UNROLL + axi_T axi_tmp = axi_T (in_data[j]); + ap_uint axi_bits = *reinterpret_cast*>(&axi_tmp); + axi_packet.data.range((index + 1) * PRECISION - 1, index * PRECISION) = axi_bits; + index++; + if (index == ELEMENTS_PER_AXI) { + axi_packet.last = 0; + axi_out.write(axi_packet); + index = 0; + } + } + } + + if (index != ELEMENTS_PER_AXI && index != 0) { + axi_packet.last = 1; + axi_out.write(axi_packet); + } + +} + +// Unpacks beats of 512-bit AXI beats into an stream of data (fixed-point numbers) see model_wrapper.hpp for usage +template +void axi_stream_to_data(hls::stream> &axi_in, hls::stream &data_out) { + #pragma HLS INLINE OFF + #pragma HLS PIPELINE + + constexpr const unsigned int ELEMENTS_PER_AXI = (SIZE <= (AXI_BITS / PRECISION)) ? SIZE : (AXI_BITS / PRECISION); + constexpr const unsigned int NUM_BEATS = SIZE / ELEMENTS_PER_AXI + (SIZE % ELEMENTS_PER_AXI != 0); + + array_T tmp; + unsigned int index = 0; + ap_axiu axi_packet; + + for (int i = 0; i < NUM_BEATS; i++) { + ap_axiu axi_packet = axi_in.read(); + + for (int j = 0; j < ELEMENTS_PER_AXI; j++) { + #pragma HLS UNROLL + ap_uint axi_bits = axi_packet.data.range((j + 1) * PRECISION - 1, j * PRECISION); + axi_T axi_tmp = *reinterpret_cast(&axi_bits); + tmp[index] = typename array_T::value_type(axi_tmp); + index++; + if (index == array_T::size) { + index = 0; + data_out.write(tmp); + + } + } + } +} + +} + +#endif \ No newline at end of file diff --git a/hls4ml/templates/coyote_accelerator/vfpga_top.svh b/hls4ml/templates/coyote_accelerator/vfpga_top.svh new file mode 100644 index 0000000000..c251e9ad68 --- /dev/null +++ b/hls4ml/templates/coyote_accelerator/vfpga_top.svh @@ -0,0 +1,32 @@ +// Each Coyote project needs a vfpga_top.svh file, which is a simple SystemVerilog header +// that provides the interface from/to the Coyote shell. If not provided, the synthesis +// process of Coyote will fail. In this case, the vfpga_top.svh simply instantiates the model_wrapper + +// Model wrapper; note the suffix _hls_ip which must be added for HLS kernels in Coyote. +// More details can be found in Example 2 of the Coyote repository. 
+model_wrapper_hls_ip inst_model( + .data_in_TDATA (axis_host_recv[0].tdata), + .data_in_TKEEP (axis_host_recv[0].tkeep), + .data_in_TLAST (axis_host_recv[0].tlast), + .data_in_TSTRB (0), + .data_in_TVALID (axis_host_recv[0].tvalid), + .data_in_TREADY (axis_host_recv[0].tready), + + .data_out_TDATA (axis_host_send[0].tdata), + .data_out_TKEEP (axis_host_send[0].tkeep), + .data_out_TLAST (axis_host_send[0].tlast), + .data_out_TSTRB (), + .data_out_TVALID (axis_host_send[0].tvalid), + .data_out_TREADY (axis_host_send[0].tready), + + .ap_clk (aclk), + .ap_rst_n (aresetn) +); + +// Tie-off unused signals to avoid synthesis problems +always_comb sq_rd.tie_off_m(); +always_comb sq_wr.tie_off_m(); +always_comb cq_rd.tie_off_s(); +always_comb cq_wr.tie_off_s(); +always_comb notify.tie_off_m(); +always_comb axi_ctrl.tie_off_s(); diff --git a/hls4ml/writer/__init__.py b/hls4ml/writer/__init__.py index 8de19fe1d2..f9ab76192d 100644 --- a/hls4ml/writer/__init__.py +++ b/hls4ml/writer/__init__.py @@ -5,6 +5,7 @@ from hls4ml.writer.vitis_writer import VitisWriter from hls4ml.writer.vivado_accelerator_writer import VivadoAcceleratorWriter from hls4ml.writer.vivado_writer import VivadoWriter +from hls4ml.writer.coyote_accelerator_writer import CoyoteAcceleratorWriter from hls4ml.writer.writers import Writer, get_writer, register_writer # noqa: F401 register_writer('Vivado', VivadoWriter) @@ -14,3 +15,4 @@ register_writer('oneAPI', OneAPIWriter) register_writer('Catapult', CatapultWriter) register_writer('SymbolicExpression', SymbolicExpressionWriter) +register_writer('CoyoteAccelerator', CoyoteAcceleratorWriter) diff --git a/hls4ml/writer/coyote_accelerator_writer.py b/hls4ml/writer/coyote_accelerator_writer.py new file mode 100644 index 0000000000..b1a0135ee1 --- /dev/null +++ b/hls4ml/writer/coyote_accelerator_writer.py @@ -0,0 +1,533 @@ +import os +import stat +import glob +import numpy as np +from pathlib import Path +from shutil import copyfile, copytree, move + +from hls4ml.writer.vitis_writer import VitisWriter + +class CoyoteAcceleratorWriter(VitisWriter): + def __init__(self): + super().__init__() + + def write_coyote(self, model): + """ + Copies the Coyote repository to the project folder + + Args: + model (ModelGraph): the hls4ml model + """ + filedir = os.path.dirname(os.path.abspath(__file__)) + srcpath = os.path.join(filedir, '../contrib/Coyote/') + dstpath = f'{model.config.get_output_dir()}/Coyote' + copytree(srcpath, dstpath) + + def restructure_dir(self, model): + """ + Simply moves around some files; these files were generated from the Vitis backend + For a cleaner integration with the rest of the Coyote library, these are + moved to the src/ folder + + Args: + model (ModelGraph): the hls4ml model + """ + srcpath = f'{model.config.get_output_dir()}/{model.config.get_project_name()}_bridge.cpp' + dstpath = f'{model.config.get_output_dir()}/src/{model.config.get_project_name()}_bridge.cpp' + move(srcpath, dstpath) + + srcpath = f'{model.config.get_output_dir()}/firmware' + dstpath = f'{model.config.get_output_dir()}/src/hls/model_wrapper/firmware' + move(srcpath, dstpath) + + def write_project_cpp(self, model): + """ + Write the main architecture source file (myproject.cpp) + Very similar to VivadoWriter, but with a different generation for I/O. + Since the myproject.cpp is no longer the top-level file (but model_wrapper is), + no need to specify interfaces. Additionally, inlining can cause issues here + when integrated with the model_wrapper, so it's disabled. 
+ + Args: + model (ModelGraph): the hls4ml model + """ + + filedir = os.path.dirname(os.path.abspath(__file__)) + + f = open(os.path.join(filedir, '../templates/vivado/firmware/myproject.cpp')) + fout = open(f'{model.config.get_output_dir()}/firmware/{model.config.get_project_name()}.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram'] + + indent = ' ' + + for line in f.readlines(): + # Add headers to weights and biases + if 'myproject' in line: + newline = line.replace('myproject', model.config.get_project_name()) + + elif '// hls-fpga-machine-learning insert header' in line: + inputs_str = ', '.join([i.definition_cpp(as_reference=True) for i in model_inputs]) + outputs_str = ', '.join([o.definition_cpp(as_reference=True) for o in model_outputs]) + brams_str = ', \n'.join([indent + b.definition_cpp(as_reference=False) for b in model_brams]) + + newline = '' + newline += indent + inputs_str + ',\n' + newline += indent + outputs_str + if len(model_brams) > 0: + newline += ',\n' + brams_str + newline += '\n' + + elif '// hls-fpga-machine-learning insert namespace-start' in line: + newline = '' + + namespace = model.config.get_writer_config().get('Namespace', None) + if namespace is not None: + newline += f'namespace {namespace} {{\n' + + elif '// hls-fpga-machine-learning insert namespace-end' in line: + newline = '' + + namespace = model.config.get_writer_config().get('Namespace', None) + if namespace is not None: + newline += '}\n' + + elif '// hls-fpga-machine-learning insert load weights' in line: + newline = line + if model.config.get_writer_config()['WriteWeightsTxt']: + + newline += '#ifndef __SYNTHESIS__\n' + newline += ' static bool loaded_weights = false;\n' + newline += ' if (!loaded_weights) {\n' + + for layer in model.get_layers(): + for w in layer.get_weights(): + if w.weight_class == 'CompressedWeightVariable': + newline += ( + indent + + ' nnet::load_compressed_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.nonzeros, w.name, w.name + ) + ) + elif w.weight_class == 'ExponentWeightVariable': + newline += ( + indent + + ' nnet::load_exponent_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.data_length, w.name, w.name + ) + ) + else: + newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format( + w.type.name, w.data_length, w.name, w.name + ) + + newline += ' loaded_weights = true;' + newline += ' }\n' + newline += '#endif' + + # Add input/output type + elif '// hls-fpga-machine-learning insert IO' in line: + newline = line + newline += indent + '#pragma HLS INLINE OFF\n' + + pipeline_style = model.config.pipeline_style + pipeline_ii = model.config.pipeline_ii + pipeline_pragma = indent + f'#pragma HLS {pipeline_style.upper()}' + if pipeline_style == 'pipeline' and pipeline_ii is not None: + pipeline_pragma += f' II={pipeline_ii}\n' + else: + pipeline_pragma += '\n' + newline += pipeline_pragma + + elif '// hls-fpga-machine-learning insert layers' in line: + newline = line + '\n' + for layer in model.get_layers(): + vars = layer.get_variables() + for var in vars: + if var not in model_inputs and var not in model_outputs: + def_cpp = var.definition_cpp() + if def_cpp is not None: + newline += ' ' + def_cpp + ';\n' + if var.pragma: + newline += ' ' + self._make_array_pragma(var) + '\n\n' + for layer in model.get_layers(): + func = layer.get_attr('function_cpp', None) + if 
func: + if not isinstance(func, (list, set)): + func = [func] + if len(func) == 1: + newline += ' ' + func[0] + ' // ' + layer.name + '\n' + else: + newline += ' // ' + layer.name + '\n' + for line in func: + newline += ' ' + line + '\n' + if model.config.trace_output and layer.get_attr('trace', False): + vars = layer.get_variables() + newline += '#ifndef __SYNTHESIS__\n' + for var in vars: + newline += ' nnet::save_layer_output<{}>({}, "{}", {});\n'.format( + var.type.name, var.name, layer.name, var.size_cpp() + ) + newline += '#endif\n' + newline += '\n' + + # Just copy line + else: + newline = line + + fout.write(newline) + + f.close() + fout.close() + + def write_nnet_utils_overrides(self, model): + """ + Writes the HLS templates, both from Vitis and from Coyote + + Args: + model (ModelGraph): the hls4ml model + """ + filedir = os.path.dirname(os.path.abspath(__file__)) + + # Vitis HLS overwrites, as done in VitisWriter + srcpath = os.path.join(filedir, '../templates/vitis/nnet_utils/') + dstpath = f'{model.config.get_output_dir()}/firmware/nnet_utils/' + headers = [os.path.basename(h) for h in glob.glob(srcpath + '*.h')] + for h in headers: + copyfile(srcpath + h, dstpath + h) + + # Coyote accelerator-specific overvwrites + srcpath = os.path.join(filedir, '../templates/coyote_accelerator/nnet_utils/') + dstpath = f'{model.config.get_output_dir()}/firmware/nnet_utils/' + headers = [os.path.basename(h) for h in glob.glob(srcpath + '*.h')] + for h in headers: + copyfile(srcpath + h, dstpath + h) + + def write_build_script(self, model): + """ + Generate the following build scripts: + - build_lib.sh --- used for software emulation (with gcc) of the model + - CMakeLists.txt --- for synthesizing the hardware with Coyote and the corresponding software library + + Args: + model (ModelGraph): the hls4ml model + """ + filedir = Path(__file__).parent + + # build_lib.sh + build_lib_src = (filedir / '../templates/coyote_accelerator/build_lib.sh').resolve() + build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve() + with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst: + for line in src.readlines(): + line = line.replace('myproject', model.config.get_project_name()) + line = line.replace('mystamp', model.config.get_config_value('Stamp')) + dst.write(line) + + build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC) + + # CMakeLists.txt + cmake_src = os.path.join(filedir, '../templates/coyote_accelerator/CMakeLists.txt') + cmake_dst = f'{model.config.get_output_dir()}/CMakeLists.txt' + with open(cmake_src) as src, open(cmake_dst, 'w') as dst: + for line in src.readlines(): + line = line.replace('myproject', model.config.get_project_name()) + dst.write(line) + + def write_model_wrapper(self, model): + """ + Generate the model_wrapper and vfpga_top + + model_wrapper encapsulates the hls4ml model kernel as well as AXI-to-data + and data-to-AXI converters. More details on the model_wrapper and these + converters can be found in model_wrapper.hpp. 
+ + vfpga_top.svh is a simple SystemVerilog header that is needed to synthesize + any Coyote project; see vfpga_top.svh and the Coyote examples for more details + + Args: + model (ModelGraph): the hls4ml model + """ + filedir = Path(__file__).parent + + if not os.path.isdir(f'{model.config.get_output_dir()}/src/hls/model_wrapper'): + os.makedirs(f'{model.config.get_output_dir()}/src/hls/model_wrapper') + + # model_wrapper.h + srcpath = (filedir / '../templates/coyote_accelerator/model_wrapper.hpp').resolve() + dstpath = f'{model.config.get_output_dir()}/src/hls/model_wrapper/model_wrapper.hpp' + copyfile(srcpath, dstpath) + + # model_wrapper.cpp + f = open(os.path.join(filedir, '../templates/coyote_accelerator/model_wrapper.cpp')) + fout = open(f'{model.config.get_output_dir()}/src/hls/model_wrapper/model_wrapper.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + if len(model_inputs) > 1 or len(model_outputs) > 1: + raise RuntimeError('CoyoteAccelerator backend currently only supports one input and one output') + + for line in f.readlines(): + indent = ' ' * (len(line) - len(line.lstrip(' '))) + if 'myproject' in line: + newline = line.replace('myproject', model.config.get_project_name()) + + elif '// hls-fpga-machine-learning insert data' in line: + newline = '' + io_type = model.config.get_config_value('IOType') + + for inp in model_inputs: + newline += indent + inp.definition_cpp() + ';\n' + newline += indent + self._make_array_pragma(inp) + '\n\n' + + for out in model_outputs: + newline += indent + out.definition_cpp() + ';\n' + newline += indent + self._make_array_pragma(out) + '\n\n' + + elif '// hls-fpga-machine-learning insert top-level function' in line: + newline = '' + + for inp in model_inputs: + newline += indent + f'nnet::axi_stream_to_data<{inp.type.name}, float, {inp.size_cpp()}, COYOTE_AXI_STREAM_BITS, 8 * sizeof(float)>(data_in, {inp.name});\n' + + input_vars = ','.join([i.name for i in model_inputs]) + output_vars = ','.join([o.name for o in model_outputs]) + all_vars = ','.join(filter(None, [input_vars, output_vars])) + top_level = indent + f'{model.config.get_project_name()}({all_vars});\n' + newline += top_level + + for out in model_outputs: + newline += indent + f'nnet::data_to_axi_stream<{out.type.name}, float, {out.size_cpp()}, COYOTE_AXI_STREAM_BITS, 8 * sizeof(float)>({out.name}, data_out);\n' + + else: + newline = line + + fout.write(newline) + + f.close() + fout.close() + + # vfpga_top.svh + srcpath = (filedir / '../templates/coyote_accelerator/vfpga_top.svh').resolve() + dstpath = f'{model.config.get_output_dir()}/src/vfpga_top.svh' + copyfile(srcpath, dstpath) + + # init_ip.tcl for any additional IPs that may be needed for the model (e.g., ILA for debugging) --- UNUSED FOR NOW + # srcpath = (filedir / '../templates/coyote_accelerator/init_ip.tcl').resolve() + # dstpath = f'{model.config.get_output_dir()}/src/init_ip.tcl' + + copyfile(srcpath, dstpath) + + def write_host_code(self, model): + """ + Generates the host code, namely myproject_host.cpp and host_libs.hpp + host_libs.hpp implements the "glue" logic which interacts with the Coyote + software library. myproject_host.cpp is a stand-alone program that can be + compiled and used to run model inference on an FPGA, with inputs from tb_data. 
+ + Args: + model (ModelGraph): the hls4ml model + """ + filedir = Path(__file__).parent + + if not os.path.isdir(f'{model.config.get_output_dir()}/src/'): + os.makedirs(f'{model.config.get_output_dir()}/src/') + + # myproject_host.cpp + f = open(os.path.join(filedir, '../templates/coyote_accelerator/myproject_host.cpp')) + fout = open(f'{model.config.get_output_dir()}/src/{model.config.get_project_name()}_host.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + if len(model_inputs) > 1 or len(model_outputs) > 1: + raise RuntimeError('CoyoteAccelerator backend currently only supports one input and one output') + + for line in f.readlines(): + indent = ' ' * (len(line) - len(line.lstrip(' '))) + + if '// hls-fpga-machine-learning insert I/O size' in line: + newline = '' + for inp in model_inputs: + newline += indent + f'constexpr const unsigned int in_size = {inp.size_cpp()};\n' + for out in model_outputs: + newline += indent + f'constexpr const unsigned int out_size = {out.size_cpp()};\n' + + else: + newline = line + + fout.write(newline) + + f.close() + fout.close() + + # host_libs.hpp + srcpath = os.path.join(filedir, '../templates/coyote_accelerator/host_libs.hpp') + dstpath = f'{model.config.get_output_dir()}/src/host_libs.hpp' + copyfile(srcpath, dstpath) + + # host_libs.cpp + srcpath = os.path.join(filedir, '../templates/coyote_accelerator/host_libs.cpp') + dstpath = f'{model.config.get_output_dir()}/src/host_libs.cpp' + copyfile(srcpath, dstpath) + + def __make_dat_file(self, original_path, project_path): + """ + Convert other input/output data types into a dat file, which is + a text file with the falttened matrix printed out. Note that ' ' is + assumed to be the delimiter. + + TODO: These seemed to be shared between many hls4ml writers; perhaps + these should be moved to some utility class + """ + + # Take in data from current supported data files + if original_path[-3:] == "npy": + data = np.load(original_path) + else: + raise Exception("Unsupported input/output data files.") + + # Faltten data, just keep first dimension + data = data.reshape(data.shape[0], -1) + + def print_data(f): + for i in range(data.shape[0]): + for j in range(data.shape[1]): + f.write(str(data[i][j]) + " ") + f.write("\n") + + # Print out in dat file + with open(project_path, "w") as f: + print_data(f) + + def write_test_bench(self, model): + """ + Generates the HLS testbench; very similar to the testbench in Vivado/Vitis backends + For differences, refer to the myproject_test.cpp file. 
+ + Args: + model (ModelGraph): the hls4ml model + """ + filedir = os.path.dirname(os.path.abspath(__file__)) + + if not os.path.exists(f'{model.config.get_output_dir()}/tb_data/'): + os.mkdir(f'{model.config.get_output_dir()}/tb_data/') + + input_data = model.config.get_config_value('InputData') + output_predictions = model.config.get_config_value('OutputPredictions') + + if input_data: + if input_data[-3:] == 'dat': + copyfile(input_data, f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat') + else: + self.__make_dat_file(input_data, f'{model.config.get_output_dir()}/tb_data/tb_input_features.dat') + + if output_predictions: + if output_predictions[-3:] == 'dat': + copyfile(output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat') + else: + self.__make_dat_file( + output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat' + ) + + f = open(os.path.join(filedir, '../templates/coyote_accelerator/myproject_test.cpp')) + fout = open(f'{model.config.get_output_dir()}/src/{model.config.get_project_name()}_test.cpp', 'w') + + model_inputs = model.get_input_variables() + model_outputs = model.get_output_variables() + if len(model_inputs) > 1 or len(model_outputs) > 1: + raise RuntimeError('CoyoteAccelerator backend currently only supports one input and one output') + + for line in f.readlines(): + indent = ' ' * (len(line) - len(line.lstrip(' '))) + + if 'myproject' in line: + newline = line.replace('myproject', model.config.get_project_name()) + + elif '// hls-fpga-machine-learning insert data' in line: + newline = line + offset = 0 + for inp in model_inputs: + newline += indent + f'float {inp.name}[{inp.size_cpp()}];\n' + newline += indent + f'nnet::copy_data(in, {inp.name});\n' + newline += indent + 'hls::stream data_in;\n' + newline += indent + f'nnet::data_to_axi_stream({inp.name}, data_in);\n' + offset += inp.size() + for out in model_outputs: + newline += indent + f'float {out.name}[{out.size_cpp()}];\n' + newline += indent + 'hls::stream data_out;\n' + + elif '// hls-fpga-machine-learning insert zero' in line: + newline = line + for inp in model_inputs: + newline += indent + f'float {inp.name}[{inp.size_cpp()}];\n' + newline += indent + f'nnet::fill_zero({inp.name});\n' + newline += indent + 'hls::stream data_in;\n' + newline += indent + f'nnet::data_to_axi_stream({inp.name}, data_in);\n' + + for out in model_outputs: + newline += indent + f'float {out.name}[{out.size_cpp()}];\n' + newline += indent + 'hls::stream data_out;\n' + + elif '// hls-fpga-machine-learning insert top-level-function' in line: + newline = line + newline += indent + 'model_wrapper(data_in, data_out);\n' + newline += indent + f'nnet::axi_stream_to_data(data_out, {out.name});\n' + + elif '// hls-fpga-machine-learning insert predictions' in line: + newline = line + for out in model_outputs: + newline += indent + f'for(int i = 0; i < {out.size_cpp()}; i++) {{\n' + newline += indent + ' std::cout << pr[i] << " ";\n' + newline += indent + '}\n' + newline += indent + 'std::cout << std::endl;\n' + + elif '// hls-fpga-machine-learning insert tb-output' in line: + newline = line + for out in model_outputs: + newline += indent + f'nnet::print_result({out.name}, fout);\n' + + elif ( + '// hls-fpga-machine-learning insert output' in line + or '// hls-fpga-machine-learning insert quantized' in line + ): + newline = line + for out in model_outputs: + newline += indent + f'nnet::print_result({out.name}, std::cout, true);\n' + + else: + newline = line + 
fout.write(newline) + f.close() + fout.close() + + def write_hls(self, model): + """ + Write the HLS project. Most of the functionality inherited from VitisWriter; + some additional functionality added for Coyote specifically. + + Args: + model (ModelGraph): the hls4ml model + """ + # General hls4ml write process, inherited from the Vitis writer + self.write_project_dir(model) + self.write_project_cpp(model) + self.write_project_header(model) + self.write_weights(model) + self.write_defines(model) + self.write_parameters(model) + self.write_bridge(model) + self.write_nnet_utils(model) + self.write_nnet_utils_overrides(model) + self.write_generated_code(model) + + # Coyote-specific writes, implemented in this file + self.write_coyote(model) + self.write_model_wrapper(model) + self.write_host_code(model) + self.write_test_bench(model) + self.write_build_script(model) + self.restructure_dir(model) + self.write_yml(model) + + print('Done')