1 change: 1 addition & 0 deletions docs/ops.md
@@ -22,6 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ |
| CEIL | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | 🟡 | ✅ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ❌ |
4 changes: 4 additions & 0 deletions docs/ops/CPU.csv
@@ -1,6 +1,8 @@
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
"CPU","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
@@ -61,6 +63,8 @@
"CPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","CPU"
"CPU","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
"CPU","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","CPU"
"CPU","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","CPU"
4 changes: 4 additions & 0 deletions docs/ops/SYCL.csv
@@ -1,6 +1,8 @@
"backend_name","op_name","op_params","test_mode","supported","error_message","backend_reg_name"
"SYCL0","ABS","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","ABS","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","SGN","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","SGN","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","NEG","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
@@ -61,6 +63,8 @@
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","ABS","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","ABS","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","SGN","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
"SYCL0","SGN","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","SYCL"
"SYCL0","NEG","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","SYCL"
9 changes: 9 additions & 0 deletions ggml/include/ggml.h
@@ -574,6 +574,7 @@ extern "C" {
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
GGML_UNARY_OP_CEIL,

GGML_UNARY_OP_COUNT,
};
@@ -1148,6 +1149,14 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_ceil(
struct ggml_context * ctx,
struct ggml_tensor * a);

GGML_API struct ggml_tensor * ggml_ceil_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);

// gated linear unit ops
// A: n columns, r rows,
// result is n / 2 columns, r rows,
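
A minimal usage sketch of the new public API above, for context (the helper name, tensor shape, and type are illustrative, not part of this PR):

```cpp
#include "ggml.h"

// Build a graph node that rounds every element of `a` up to the nearest
// integer; ggml_ceil_inplace(ctx, a) would overwrite `a` instead.
static struct ggml_tensor * build_ceil_graph(struct ggml_context * ctx) {
    struct ggml_tensor * a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * out = ggml_ceil(ctx, a);
    return out;
}
```
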
4 changes: 4 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
@@ -9767,6 +9767,10 @@ void ggml_compute_forward_unary(
{
ggml_compute_forward_exp(params, dst);
} break;
case GGML_UNARY_OP_CEIL:
{
ggml_compute_forward_ceil(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
9 changes: 9 additions & 0 deletions ggml/src/ggml-cpu/unary-ops.cpp
@@ -64,6 +64,10 @@ static inline float op_log(float x) {
return logf(x);
}

static inline float op_ceil(float x) {
return ceilf(x);
}

template <float (*op)(float), typename src0_t, typename dst_t>
static inline void vec_unary_op(int64_t n, dst_t * y, const src0_t * x) {
constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
@@ -184,3 +188,8 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * dst) {
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_log>(params, dst);
}

void ggml_compute_forward_ceil(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_ceil>(params, dst);
}
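
The `unary_op<op_ceil>` instantiation above reuses this file's generic element-wise machinery; stripped of the `type_conversion_table` handling for f16 sources and the per-thread row partitioning, the loop it runs over a contiguous row is equivalent to this sketch:

```cpp
#include <cmath>
#include <cstdint>

// Simplified equivalent of vec_unary_op<op_ceil, float, float>:
// apply ceilf to each of the n contiguous elements of x.
static void vec_ceil_f32(int64_t n, float * y, const float * x) {
    for (int64_t i = 0; i < n; ++i) {
        y[i] = ceilf(x[i]);
    }
}
```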

1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/unary-ops.h
@@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_ceil(const struct ggml_compute_params * params, struct ggml_tensor * dst);

#ifdef __cplusplus
}
29 changes: 29 additions & 0 deletions ggml/src/ggml-sycl/element_wise.cpp
@@ -156,6 +156,11 @@ static void unary_op_sgn_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
dst[i] = op_sgn(x[i]);
}
}

template<typename T>
static __dpct_inline__ T op_ceil(T x) {
return sycl::ceil(x);
}

template<typename T>
static void unary_op_abs_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
@@ -304,6 +309,13 @@ static void unary_op_clamp_kernel(const T * x, T * dst, const int k, const sycl:
}
}

template<typename T>
static void unary_op_ceil_kernel(const T * x, T * dst, const int k, const sycl::nd_item<1> &item_ct1) {
SYCL_GLOBAL_ID_LOOP(k, item_ct1) {
dst[i] = op_ceil(x[i]);
}
}

template<typename T>
static void upscale(const T *x, T *dst, const int nb00, const int nb01,
const int nb02, const int nb03, const int ne10, const int ne11,
@@ -944,6 +956,19 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
}, min_val, max_val);
}

static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
}

static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32);
@@ -1168,3 +1193,7 @@ void ggml_sycl_geglu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_geglu_quick(ctx, dst);
}

void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
ggml_sycl_op_ceil(ctx, dst);
}
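
The dispatch in `ggml_sycl_op_ceil` launches one work-item per element, rounding the element count up to whole work-groups of 256 via `ceil_div(k_elements, 256)`, with `SYCL_GLOBAL_ID_LOOP` guarding the tail work-items past `k`. Outside the ggml wrappers, the same launch looks roughly like this sketch (assuming a plain `sycl::queue` and USM pointers):

```cpp
#include <sycl/sycl.hpp>

// Standalone sketch of the ceil kernel launch: global size is the element
// count rounded up to a multiple of the work-group size, so work-items in
// the last group must bounds-check against k.
void ceil_f32(sycl::queue & q, const float * x, float * dst, int k) {
    const int wg  = 256;
    const int wgs = (k + wg - 1) / wg;  // ceil_div(k, 256)
    q.parallel_for(
        sycl::nd_range<1>(sycl::range<1>(wgs * wg), sycl::range<1>(wg)),
        [=](sycl::nd_item<1> it) {
            const int i = static_cast<int>(it.get_global_linear_id());
            if (i < k) {
                dst[i] = sycl::ceil(x[i]);
            }
        });
}
```
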
2 changes: 2 additions & 0 deletions ggml/src/ggml-sycl/element_wise.hpp
@@ -75,6 +75,8 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4 changes: 4 additions & 0 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -3636,6 +3636,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
case GGML_UNARY_OP_ELU:
ggml_sycl_elu(ctx, dst);
break;
case GGML_UNARY_OP_CEIL:
ggml_sycl_ceil(ctx, dst);
break;
default:
return false;
}
@@ -4190,6 +4193,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_CEIL:
#if defined (GGML_SYCL_F16)
return ggml_is_contiguous(op->src[0]) && (op->type == op->src[0]->type);
#else
18 changes: 17 additions & 1 deletion ggml/src/ggml.c
@@ -1143,9 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSIGMOID",
"EXP",
"GELU_ERF",
"CEIL"
};

static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");


static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
@@ -5743,6 +5744,21 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
return result;
}

// ggml_ceil

struct ggml_tensor * ggml_ceil(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary(ctx, a, GGML_UNARY_OP_CEIL);
}

struct ggml_tensor * ggml_ceil_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a) {
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_CEIL);
}

// opt_step_adamw

struct ggml_tensor * ggml_opt_step_adamw(
40 changes: 40 additions & 0 deletions tests/test-backend-ops.cpp
@@ -3524,6 +3524,45 @@ struct test_log : public test_case {
}
};

// GGML_UNARY_OP_CEIL
struct test_ceil : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_ceil(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");

ggml_tensor * out = ggml_ceil(ctx, a);
ggml_set_name(out, "out");

return out;
}

void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.0f, 10.0f);
}
}

float grad_eps() override {
return 1.0f;
}

bool grad_precise() override {
return true;
}
};
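
The element-wise behaviour the test checks against the CPU reference is just IEEE ceil, which rounds toward +infinity and is flat between integers (so its derivative is zero almost everywhere). A quick reference for the expected values:

```cpp
#include <cassert>
#include <cmath>

int main() {
    // ceil rounds toward +infinity, which differs from "away from zero"
    // on negative inputs; exact integers are fixed points.
    assert(ceilf( 2.3f) ==  3.0f);
    assert(ceilf(-2.3f) == -2.0f);
    assert(ceilf( 5.0f) ==  5.0f);
    return 0;
}
```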

// GGML_OP_SIN
struct test_sin : public test_case {
const ggml_type type;
@@ -6328,6 +6367,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_sqr(type));
test_cases.emplace_back(new test_sqrt(type));
test_cases.emplace_back(new test_log(type));
test_cases.emplace_back(new test_ceil(type));
test_cases.emplace_back(new test_sin(type));
test_cases.emplace_back(new test_cos(type));
test_cases.emplace_back(new test_clamp(type));
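
Once registered here, the new cases run with the rest of the element-wise suite; assuming the harness's usual op filter flag (not shown in this diff), an invocation along the lines of `test-backend-ops test -o CEIL` would exercise CEIL in isolation, and the per-backend support CSVs under docs/ops/ above are regenerated from the same harness's support mode.
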
6 changes: 3 additions & 3 deletions vendor/miniaudio/miniaudio.h
@@ -28227,7 +28227,7 @@ static ma_result ma_device_start__alsa(ma_device* pDevice)
}

if (pDevice->type == ma_device_type_playback || pDevice->type == ma_device_type_duplex) {
/*
/*
When data is written to the device we wait for the device to get ready to receive data with poll(). In my testing
I have observed that poll() can sometimes block forever unless the device is started explicitly with snd_pcm_start()
or some data is written with snd_pcm_writei().
@@ -34520,7 +34520,7 @@ static ma_result ma_device_init_internal__coreaudio(ma_context* pContext, ma_dev
#endif
}


status = ((ma_AudioUnitSetProperty_proc)pContext->coreaudio.AudioUnitSetProperty)(pData->audioUnit, kAudioUnitProperty_StreamFormat, formatScope, formatElement, &bestFormat, sizeof(bestFormat));
if (status != noErr) {
((ma_AudioComponentInstanceDispose_proc)pContext->coreaudio.AudioComponentInstanceDispose)(pData->audioUnit);
@@ -38526,7 +38526,7 @@ static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type dev
ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */
}
}

result = MA_SUCCESS;
}
done: