Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ chrono = { version = "0.4", optional = true, features = ["serde"] }
bstr = { version = "1.11.0", default-features = false }
quanta = { version = "0.12", optional = true }
replace_with = { version = "0.1.7" }
polonius-the-crab = "0.4.2"

[dev-dependencies]
clickhouse-derive = { version = "0.2.0", path = "derive" }
Expand Down
23 changes: 23 additions & 0 deletions examples/usage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ async fn fetch(client: &Client) -> Result<()> {
Ok(())
}

#[cfg(feature = "futures03")]
async fn fetch_stream(client: &Client) -> Result<()> {
use futures_util::TryStreamExt;

client
.query("SELECT ?fields FROM some WHERE name = ? AND no BETWEEN ? AND ?")
.bind("foo")
.bind(500)
.bind(504)
.fetch::<MyRowOwned>()?
.try_for_each(|row| {
println!("{row:?}");
futures_util::future::ready(Ok(()))
})
.await?;

Ok(())
}

async fn fetch_all(client: &Client) -> Result<()> {
let vec = client
.query("SELECT ?fields FROM ? WHERE no BETWEEN ? AND ?")
Expand Down Expand Up @@ -117,6 +136,10 @@ async fn main() -> Result<()> {
inserter(&client).await?;
select_count(&client).await?;
fetch(&client).await?;
#[cfg(feature = "futures03")]
{
fetch_stream(&client).await?;
}
fetch_all(&client).await?;
delete(&client).await?;
select_count(&client).await?;
Expand Down
72 changes: 4 additions & 68 deletions src/bytes_ext.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
use bytes::{Bytes, BytesMut};
use std::cell::{Cell, UnsafeCell};
use std::cell::Cell;

#[derive(Default)]
pub(crate) struct BytesExt {
// The only reason we use `UnsafeCell` here is to provide `extend_by_ref` method
// in the sound way. After stabilization of the polonius borrow checker, it
// will be replaced with simple `Bytes`. See `RowCursor::next()` for details.
bytes: UnsafeCell<Bytes>,
bytes: Bytes,

// Points to the real start of the remaining slice.
// `Cell` allows us to mutate this value while keeping references to `bytes`.
Expand Down Expand Up @@ -44,31 +41,12 @@ impl BytesExt {
/// Adds the provided chunk into available bytes.
#[inline(always)]
pub(crate) fn extend(&mut self, chunk: Bytes) {
*self.bytes.get_mut() = merge_bytes(self.slice(), chunk);
self.cursor.set(0);
}

/// Adds the provided chunk into available bytes.
///
/// See `RowCursor::next()` for details on why this method exists.
///
/// # Safety
///
/// The caller MUST ensure that there are no active references from `slice()` calls.
#[inline(always)]
pub(crate) unsafe fn extend_by_ref(&self, chunk: Bytes) {
let new_bytes = merge_bytes(self.slice(), chunk);

// SAFETY: no active references to `bytes` are held at this point (ensured by the caller).
unsafe {
*self.bytes.get() = new_bytes;
}
self.bytes = merge_bytes(self.slice(), chunk);
self.cursor.set(0);
}

fn bytes(&self) -> &Bytes {
// SAFETY: all possible incorrect accesses are ensured by caller's of `extend_by_ref()`.
unsafe { &*self.bytes.get() }
&self.bytes
}
}

Expand Down Expand Up @@ -131,46 +109,4 @@ mod tests_miri {
assert_eq!(bytes.slice(), b"l");
assert_eq!(bytes.remaining(), 1);
}

// Unfortunately, we cannot run miri against async code in order to check
// the unsafe code in `RowCursor::next()`. However, we can at least
// check that the valid usage of `extend_by_ref()` is free of UB.
#[test]
fn extend_by_ref() {
fn next(buffer: &mut BytesExt) -> &[u8] {
loop {
if let Some(slice) = decode(buffer.slice()) {
buffer.set_remaining(buffer.remaining() - 3);
return slice;
}

let more = read_more();

// Compilation error:
/*
buffer.extend(more);
*/

// SAFETY: we're checking it right now in miri =)
unsafe { buffer.extend_by_ref(more) };
}
}

fn decode(buffer: &[u8]) -> Option<&[u8]> {
if buffer.len() > 3 {
Some(&buffer[..3])
} else {
None
}
}

fn read_more() -> Bytes {
Bytes::from_static(b"aaaa")
}

let mut buffer = BytesExt::default();
for _ in 0..10 {
assert_eq!(next(&mut buffer), b"aaa");
}
}
}
135 changes: 97 additions & 38 deletions src/cursors/row.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "futures03")]
use crate::RowOwned;
use crate::row_metadata::RowMetadata;
use crate::{
RowRead,
Expand All @@ -9,7 +11,10 @@ use crate::{
};
use clickhouse_types::error::TypesError;
use clickhouse_types::parse_rbwnat_columns_header;
use polonius_the_crab::prelude::*;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll, ready};

/// A cursor that emits rows deserialized as structures from RowBinary.
#[must_use]
Expand All @@ -20,7 +25,7 @@ pub struct RowCursor<T> {
/// [`None`] until the first call to [`RowCursor::next()`],
/// as [`RowCursor::new`] is not `async`, so it loads lazily.
row_metadata: Option<RowMetadata>,
_marker: PhantomData<T>,
_marker: PhantomData<fn() -> T>,
}

impl<T> RowCursor<T> {
Expand All @@ -36,7 +41,7 @@ impl<T> RowCursor<T> {

#[cold]
#[inline(never)]
async fn read_columns(&mut self) -> Result<()>
fn poll_read_columns(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>>
where
T: RowRead,
{
Expand All @@ -47,32 +52,32 @@ impl<T> RowCursor<T> {
Ok(columns) if !columns.is_empty() => {
self.bytes.set_remaining(slice.len());
self.row_metadata = Some(RowMetadata::new_for_cursor::<T>(columns));
return Ok(());
return Poll::Ready(Ok(()));
}
Ok(_) => {
// This does not panic, as it could be a network issue
// or a malformed response from the server or LB,
// and a simple retry might help in certain cases.
return Err(Error::BadResponse(
return Poll::Ready(Err(Error::BadResponse(
"Expected at least one column in the header".to_string(),
));
)));
}
Err(TypesError::NotEnoughData(_)) => {}
Err(err) => {
return Err(Error::InvalidColumnsHeader(err.into()));
return Poll::Ready(Err(Error::InvalidColumnsHeader(err.into())));
}
}
}
match self.raw.next().await? {
match ready!(self.raw.poll_next(cx))? {
Some(chunk) => self.bytes.extend(chunk),
None if self.row_metadata.is_none() => {
// Similar to the other BadResponse branch above
return Err(Error::BadResponse(
return Poll::Ready(Err(Error::BadResponse(
"Could not read columns header".to_string(),
));
)));
}
// if the result set is empty, there is only the columns header
None => return Ok(()),
None => return Poll::Ready(Ok(())),
}
}
}
Expand All @@ -85,49 +90,52 @@ impl<T> RowCursor<T> {
///
/// This method is cancellation safe.
pub async fn next(&mut self) -> Result<Option<T::Value<'_>>>
where
T: RowRead,
{
Next::new(self).await
}

#[inline]
fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Option<T::Value<'_>>>>
where
T: RowRead,
{
if self.validation && self.row_metadata.is_none() {
self.read_columns().await?;
ready!(self.poll_read_columns(cx))?;
debug_assert!(self.row_metadata.is_some());
}

let mut bytes = &mut self.bytes;

loop {
if self.bytes.remaining() > 0 {
let mut slice = self.bytes.slice();
let result = rowbinary::deserialize_row::<T::Value<'_>>(
&mut slice,
self.row_metadata.as_ref(),
);
polonius!(|bytes| -> Poll<Result<Option<T::Value<'polonius>>>> {
if bytes.remaining() > 0 {
let mut slice = bytes.slice();
let result = rowbinary::deserialize_row::<T::Value<'_>>(
&mut slice,
self.row_metadata.as_ref(),
);

match result {
Ok(value) => {
self.bytes.set_remaining(slice.len());
return Ok(Some(value));
match result {
Ok(value) => {
bytes.set_remaining(slice.len());
polonius_return!(Poll::Ready(Ok(Some(value))))
}
Err(Error::NotEnoughData) => {}
Err(err) => polonius_return!(Poll::Ready(Err(err))),
}
Err(Error::NotEnoughData) => {}
Err(err) => return Err(err),
}
}
});

match self.raw.next().await? {
Some(chunk) => {
// SAFETY: we actually don't have active immutable references at this point.
//
// The borrow checker prior to polonius thinks we still have ones.
// This is a pretty common restriction that can be fixed by using
// the polonius-the-crab crate, which cannot be used in async code.
//
// See https://github.com/rust-lang/rust/issues/51132
unsafe { self.bytes.extend_by_ref(chunk) }
}
None if self.bytes.remaining() > 0 => {
match ready!(self.raw.poll_next(cx))? {
Some(chunk) => bytes.extend(chunk),
None if bytes.remaining() > 0 => {
// If some data is left, we have an incomplete row in the buffer.
// This is usually a schema mismatch on the client side.
return Err(Error::NotEnoughData);
return Poll::Ready(Err(Error::NotEnoughData));
}
None => return Ok(None),
None => return Poll::Ready(Ok(None)),
}
}
}
Expand All @@ -148,3 +156,54 @@ impl<T> RowCursor<T> {
self.raw.decoded_bytes()
}
}

#[cfg(feature = "futures03")]
impl<T> futures_util::stream::Stream for RowCursor<T>
where
T: RowOwned + RowRead,
{
type Item = Result<T>;

fn poll_next(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
Self::poll_next(self.get_mut(), cx).map(Result::transpose)
}
}

struct Next<'a, T> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what returns this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RowCursor::next basically does, except that it immediately awaits it

cursor: Option<&'a mut RowCursor<T>>,
}

impl<'a, T> Next<'a, T> {
fn new(cursor: &'a mut RowCursor<T>) -> Self {
Self {
cursor: Some(cursor),
}
}
}

impl<'a, T> std::future::Future for Next<'a, T>
where
T: RowRead,
{
type Output = Result<Option<T::Value<'a>>>;

#[inline]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Temporarily take the cursor out in order for `cursor.poll_next` to return a value with
// the correct lifetime `'a` rather than the unnamed lifetime of `&mut self`.
let mut cursor = self.cursor.take().expect("Future polled after completion");

polonius!(|cursor| -> Poll<Result<Option<T::Value<'polonius>>>> {
match cursor.poll_next(cx) {
Poll::Ready(value) => polonius_return!(Poll::Ready(value)),
Poll::Pending => {}
}
});

self.cursor = Some(cursor);
Poll::Pending
}
}