Skip to content

Commit 5674c33

Browse files
pchxcopybara-github
authored andcommitted
Replace CLIF SbsWriter with pybind-based gcpp extension
Maintains compatibility with previous version. PiperOrigin-RevId: 696181603
1 parent 719699f commit 5674c33

File tree

6 files changed

+47
-21
lines changed

6 files changed

+47
-21
lines changed

MODULE.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ bazel_dep(name = "googletest", version = "1.15.2")
99
bazel_dep(name = "highway", version = "1.1.0")
1010
bazel_dep(name = "nlohmann_json", version = "3.11.3")
1111
bazel_dep(name = "platforms", version = "0.0.10")
12+
bazel_dep(name = "pybind11_bazel", version = "2.12.0")
1213
bazel_dep(name = "rules_cc", version = "0.0.9")
1314
bazel_dep(name = "rules_license", version = "0.0.7")
1415
bazel_dep(name = "google_benchmark", version = "1.8.5")

compression/python/BUILD.bazel

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
load("//devtools/clif/python:clif_build_rule.bzl", "py_clif_cc")
21
# [internal] load strict.bzl
2+
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
33

44
package(
55
default_applicable_licenses = [
@@ -12,21 +12,22 @@ cc_library(
1212
name = "compression_clif_aux",
1313
srcs = ["compression_clif_aux.cc"],
1414
hdrs = ["compression_clif_aux.h"],
15+
visibility = ["//visibility:private"],
1516
deps = [
16-
"//third_party/absl/types:span",
17+
"@abseil-cpp//absl/types:span",
1718
"//compression:compress",
1819
"//compression:io",
1920
"@highway//:hwy",
2021
"@highway//:thread_pool",
2122
],
2223
)
2324

24-
py_clif_cc(
25+
pybind_extension(
2526
name = "compression",
26-
srcs = ["compression.clif"],
27+
srcs = ["compression_extension.cc"],
2728
deps = [
2829
":compression_clif_aux",
29-
"//third_party/absl/python/numpy:span_clif_lib",
30+
"@abseil-cpp//absl/types:span",
3031
],
3132
)
3233

compression/python/compression.clif

Lines changed: 0 additions & 14 deletions
This file was deleted.

compression/python/compression_clif_aux.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#ifndef GEMMA_ONCE
2121
#define GEMMA_ONCE
2222

23-
#include "third_party/absl/types/span.h"
23+
#include "absl/types/span.h"
2424
#include "compression/io.h"
2525
#include "hwy/base.h"
2626
#include "hwy/contrib/thread_pool/thread_pool.h"

compression/python/compression_clif_aux.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <string>
66
#include <vector>
77

8-
#include "third_party/absl/types/span.h"
8+
#include "absl/types/span.h"
99

1010
namespace gcpp {
1111

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include <pybind11/pybind11.h>
2+
3+
#include <exception>
4+
#include <stdexcept>
5+
#include <string>
6+
7+
#include "absl/types/span.h"
8+
#include "compression/python/compression_clif_aux.h"
9+
#include "pybind11/numpy.h"
10+
#include "pybind11/pybind11.h"
11+
#include "pybind11/stl.h"
12+
13+
using gcpp::SbsWriter;
14+
15+
namespace py = pybind11;
16+
17+
namespace {
18+
template <auto Func>
19+
void wrap_span(SbsWriter& writer, std::string name, py::array_t<float> data) {
20+
if (data.ndim() != 1 || data.strides(0) != sizeof(float)) {
21+
throw std::domain_error("Input array must be 1D and densely packed.");
22+
}
23+
std::invoke(Func, writer, name, absl::MakeSpan(data.data(0), data.size()));
24+
}
25+
} // namespace
26+
27+
PYBIND11_MODULE(compression, m) {
28+
py::class_<SbsWriter>(m, "SbsWriter")
29+
.def(py::init<>())
30+
// NOTE: Individual compression backends may impose constraints on the
31+
// array length, such as a minimum of (say) 32 elements.
32+
.def("insert", wrap_span<&SbsWriter::Insert>)
33+
.def("insert_nuq", wrap_span<&SbsWriter::InsertNUQ>)
34+
.def("insert_bf16", wrap_span<&SbsWriter::InsertBfloat16>)
35+
.def("insert_float", wrap_span<&SbsWriter::InsertFloat>)
36+
.def("add_scales", &SbsWriter::AddScales)
37+
.def("write", &SbsWriter::Write);
38+
}

0 commit comments

Comments
 (0)