|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include "sycl.h" |
| 4 | + |
| 5 | +#include <stdint.h> |
| 6 | + |
| 7 | +#ifdef __cplusplus |
| 8 | +extern "C" { |
| 9 | +#endif |
| 10 | + |
| 11 | +// Return codes (negative values indicate errors): |
| 12 | +// 0 : success |
| 13 | +// -1 : internal error / exception caught |
| 14 | +// -2 : invalid argument (null pointer, bad length, etc.) |
| 15 | +// -3 : invalid descriptor state (e.g. uninitialized desc->ptr) or size query failure |
| 16 | +#define ONEMKL_DFT_STATUS_SUCCESS 0 |
| 17 | +#define ONEMKL_DFT_STATUS_ERROR -1 |
| 18 | +#define ONEMKL_DFT_STATUS_INVALID_ARGUMENT -2 |
| 19 | +#define ONEMKL_DFT_STATUS_BAD_STATE -3 |
| 20 | + |
| 21 | +// DFT precision |
| 22 | +typedef enum { |
| 23 | + ONEMKL_DFT_PRECISION_SINGLE = 0, |
| 24 | + ONEMKL_DFT_PRECISION_DOUBLE = 1 |
| 25 | +} onemklDftPrecision; |
| 26 | + |
| 27 | +// DFT domain |
| 28 | +typedef enum { |
| 29 | + ONEMKL_DFT_DOMAIN_REAL = 0, |
| 30 | + ONEMKL_DFT_DOMAIN_COMPLEX = 1 |
| 31 | +} onemklDftDomain; |
| 32 | + |
| 33 | +// Configuration parameters (subset mirrors oneapi::mkl::dft::config_param) |
| 34 | +typedef enum { |
| 35 | + ONEMKL_DFT_PARAM_FORWARD_DOMAIN = 0, |
| 36 | + ONEMKL_DFT_PARAM_DIMENSION, |
| 37 | + ONEMKL_DFT_PARAM_LENGTHS, |
| 38 | + ONEMKL_DFT_PARAM_PRECISION, |
| 39 | + ONEMKL_DFT_PARAM_FORWARD_SCALE, |
| 40 | + ONEMKL_DFT_PARAM_BACKWARD_SCALE, |
| 41 | + ONEMKL_DFT_PARAM_NUMBER_OF_TRANSFORMS, |
| 42 | + ONEMKL_DFT_PARAM_COMPLEX_STORAGE, |
| 43 | + ONEMKL_DFT_PARAM_PLACEMENT, |
| 44 | + ONEMKL_DFT_PARAM_INPUT_STRIDES, |
| 45 | + ONEMKL_DFT_PARAM_OUTPUT_STRIDES, |
| 46 | + ONEMKL_DFT_PARAM_FWD_DISTANCE, |
| 47 | + ONEMKL_DFT_PARAM_BWD_DISTANCE, |
| 48 | + ONEMKL_DFT_PARAM_WORKSPACE, // size query / placement |
| 49 | + ONEMKL_DFT_PARAM_WORKSPACE_ESTIMATE_BYTES, |
| 50 | + ONEMKL_DFT_PARAM_WORKSPACE_BYTES, |
| 51 | + ONEMKL_DFT_PARAM_FWD_STRIDES, |
| 52 | + ONEMKL_DFT_PARAM_BWD_STRIDES, |
| 53 | + ONEMKL_DFT_PARAM_WORKSPACE_PLACEMENT, |
| 54 | + ONEMKL_DFT_PARAM_WORKSPACE_EXTERNAL_BYTES |
| 55 | +} onemklDftConfigParam; |
| 56 | + |
| 57 | +// Configuration values (mirrors oneapi::mkl::dft::config_value) |
| 58 | +typedef enum { |
| 59 | + ONEMKL_DFT_VALUE_COMMITTED = 0, |
| 60 | + ONEMKL_DFT_VALUE_UNCOMMITTED, |
| 61 | + ONEMKL_DFT_VALUE_COMPLEX_COMPLEX, |
| 62 | + ONEMKL_DFT_VALUE_REAL_REAL, |
| 63 | + ONEMKL_DFT_VALUE_INPLACE, |
| 64 | + ONEMKL_DFT_VALUE_NOT_INPLACE, |
| 65 | + ONEMKL_DFT_VALUE_WORKSPACE_AUTOMATIC, // internal |
| 66 | + ONEMKL_DFT_VALUE_ALLOW, |
| 67 | + ONEMKL_DFT_VALUE_AVOID, |
| 68 | + ONEMKL_DFT_VALUE_WORKSPACE_INTERNAL, |
| 69 | + ONEMKL_DFT_VALUE_WORKSPACE_EXTERNAL |
| 70 | +} onemklDftConfigValue; |
| 71 | + |
| 72 | +// Opaque descriptor handle |
| 73 | +struct onemklDftDescriptor_st; |
| 74 | +typedef struct onemklDftDescriptor_st *onemklDftDescriptor_t; |
| 75 | + |
| 76 | +// Creation / destruction |
| 77 | +int onemklDftCreate1D(onemklDftDescriptor_t *desc, |
| 78 | + onemklDftPrecision precision, |
| 79 | + onemklDftDomain domain, |
| 80 | + int64_t length); |
| 81 | + |
| 82 | +int onemklDftCreateND(onemklDftDescriptor_t *desc, |
| 83 | + onemklDftPrecision precision, |
| 84 | + onemklDftDomain domain, |
| 85 | + int64_t dim, |
| 86 | + const int64_t *lengths); |
| 87 | + |
| 88 | +int onemklDftDestroy(onemklDftDescriptor_t desc); |
| 89 | + |
| 90 | +// Commit descriptor to a queue |
| 91 | +int onemklDftCommit(onemklDftDescriptor_t desc, syclQueue_t queue); |
| 92 | + |
| 93 | +// Configuration set |
| 94 | +int onemklDftSetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t value); |
| 95 | +int onemklDftSetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double value); |
| 96 | +int onemklDftSetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, const int64_t *values, int64_t n); |
| 97 | +int onemklDftSetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue value); |
| 98 | + |
| 99 | +// Configuration get |
| 100 | +int onemklDftGetValueInt64(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *value); |
| 101 | +int onemklDftGetValueDouble(onemklDftDescriptor_t desc, onemklDftConfigParam param, double *value); |
| 102 | +// For array queries pass *n as available length; on return *n has elements written. |
| 103 | +int onemklDftGetValueInt64Array(onemklDftDescriptor_t desc, onemklDftConfigParam param, int64_t *values, int64_t *n); |
| 104 | +int onemklDftGetValueConfigValue(onemklDftDescriptor_t desc, onemklDftConfigParam param, onemklDftConfigValue *value); |
| 105 | + |
| 106 | +// Compute (USM) in-place/out-of-place. Pointers must reference memory |
| 107 | +// appropriate for precision/domain. No size checking is performed. |
| 108 | +int onemklDftComputeForward(onemklDftDescriptor_t desc, void *inout); |
| 109 | +int onemklDftComputeForwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out); |
| 110 | +int onemklDftComputeBackward(onemklDftDescriptor_t desc, void *inout); |
| 111 | +int onemklDftComputeBackwardOutOfPlace(onemklDftDescriptor_t desc, void *in, void *out); |
| 112 | + |
| 113 | +// Compute (buffer API) variants. Host pointers are wrapped in temporary 1D buffers. |
| 114 | +int onemklDftComputeForwardBuffer(onemklDftDescriptor_t desc, void *inout); |
| 115 | +int onemklDftComputeForwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out); |
| 116 | +int onemklDftComputeBackwardBuffer(onemklDftDescriptor_t desc, void *inout); |
| 117 | +int onemklDftComputeBackwardOutOfPlaceBuffer(onemklDftDescriptor_t desc, void *in, void *out); |
| 118 | + |
| 119 | +// Introspection: write out the integral values of selected config_param enums in |
| 120 | +// the same order as our public enum declaration above. Returns number written or |
| 121 | +// a negative error code if n is insufficient or arguments invalid. |
| 122 | +int onemklDftQueryParamIndices(int64_t *out, int64_t n); |
| 123 | + |
| 124 | +#ifdef __cplusplus |
| 125 | +} |
| 126 | +#endif |
0 commit comments