From 31d109a9d42fe4b34d991c9559e64cd3bdf50837 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 7 May 2025 21:00:27 +0200 Subject: [PATCH 01/54] Draft RowBinaryWNAT/Native header parser --- rowbinary/Cargo.toml | 17 + rowbinary/src/error.rs | 20 + rowbinary/src/leb128.rs | 111 ++++ rowbinary/src/lib.rs | 3 + rowbinary/src/types.rs | 1116 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 1267 insertions(+) create mode 100644 rowbinary/Cargo.toml create mode 100644 rowbinary/src/error.rs create mode 100644 rowbinary/src/leb128.rs create mode 100644 rowbinary/src/lib.rs create mode 100644 rowbinary/src/types.rs diff --git a/rowbinary/Cargo.toml b/rowbinary/Cargo.toml new file mode 100644 index 00000000..6eb40fb0 --- /dev/null +++ b/rowbinary/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "clickhouse-rowbinary" +version = "0.0.1" +description = "RowBinaryWithNamesAndTypes format utils" +authors = ["ClickHouse"] +repository = "https://github.com/ClickHouse/clickhouse-rs" +homepage = "https://clickhouse.com" +edition = "2021" +license = "MIT OR Apache-2.0" +# update `Cargo.toml` and CI if changed +rust-version = "1.73.0" + +[lib] +#proc-macro = true + +[dependencies] +thiserror = "1.0.16" diff --git a/rowbinary/src/error.rs b/rowbinary/src/error.rs new file mode 100644 index 00000000..00cacafa --- /dev/null +++ b/rowbinary/src/error.rs @@ -0,0 +1,20 @@ +#[derive(Debug, thiserror::Error)] +pub enum ColumnsParserError { + #[error("IO error: {0}")] + IoError(#[from] std::io::Error), + + #[error("Expected LF at position {0}")] + ExpectedLF(usize), + + #[error("Invalid integer encoding at position {0}")] + InvalidIntegerEncoding(usize), + + #[error("Incomplete column data at position {0}")] + IncompleteColumnData(usize), + + #[error("Invalid column spec at position {0}: {1}")] + InvalidColumnSpec(usize, String), + + #[error("Type parsing error: {0}")] + TypeParsingError(String), +} diff --git a/rowbinary/src/leb128.rs b/rowbinary/src/leb128.rs new file mode 100644 index 00000000..4046b443 --- /dev/null +++ b/rowbinary/src/leb128.rs @@ -0,0 +1,111 @@ +use std::io::{ErrorKind, Read}; + +use crate::error::ColumnsParserError; + +pub fn decode_leb128(pos: &mut usize, reader: &mut R) -> Result { + let mut result: u64 = 0; + let mut shift: u32 = 0; + let mut buf = [0u8; 1]; + + loop { + reader.read_exact(&mut buf).map_err(|e| { + if e.kind() == ErrorKind::UnexpectedEof { + ColumnsParserError::InvalidIntegerEncoding(*pos) + } else { + ColumnsParserError::IoError(e) + } + })?; + + *pos += 1; + + let byte = buf[0]; + result |= ((byte & 0x7f) as u64) << shift; + + if byte & 0x80 == 0 { + break; + } + + shift += 7; + + if shift > 63 { + return Err(ColumnsParserError::InvalidIntegerEncoding(*pos)); + } + } + + Ok(result) +} + +pub fn encode_leb128(value: u64) -> Vec { + let mut result = Vec::new(); + let mut val = value; + + loop { + let mut byte = (val & 0x7f) as u8; + val >>= 7; + + if val != 0 { + byte |= 0x80; // Set high bit to indicate more bytes follow + } + + result.push(byte); + + if val == 0 { + break; + } + } + + result +} + +mod tests { + #[test] + fn test_decode_leb128() { + let test_cases = vec![ + // (input bytes, expected value) + (vec![0], 0), + (vec![1], 1), + (vec![127], 127), + (vec![128, 1], 128), + (vec![255, 1], 255), + (vec![0x85, 0x91, 0x26], 624773), + (vec![0xE5, 0x8E, 0x26], 624485), + ]; + + for (input, expected) in test_cases { + let mut cursor = std::io::Cursor::new(input.clone()); + let mut pos = 0; + let result = super::decode_leb128(&mut pos, &mut cursor).unwrap(); + assert_eq!(result, expected, "Failed decoding {:?}", input); + } + } + + #[test] + fn test_encode_decode_leb128() { + let test_values = vec![ + 0u64, + 1, + 127, + 128, + 255, + 624773, + 624485, + 300_000, + 10_000_000, + u32::MAX as u64, + (u32::MAX as u64) + 1, + ]; + + for value in test_values { + let encoded = super::encode_leb128(value); + let mut cursor = std::io::Cursor::new(&encoded); + let mut pos = 0; + let decoded = super::decode_leb128(&mut pos, &mut cursor).unwrap(); + + assert_eq!( + decoded, value, + "Failed round trip for {}: encoded as {:?}, decoded as {}", + value, encoded, decoded + ); + } + } +} diff --git a/rowbinary/src/lib.rs b/rowbinary/src/lib.rs new file mode 100644 index 00000000..1a6b89ff --- /dev/null +++ b/rowbinary/src/lib.rs @@ -0,0 +1,3 @@ +mod error; +mod leb128; +mod types; diff --git a/rowbinary/src/types.rs b/rowbinary/src/types.rs new file mode 100644 index 00000000..84b44600 --- /dev/null +++ b/rowbinary/src/types.rs @@ -0,0 +1,1116 @@ +use crate::error::ColumnsParserError; +use std::collections::HashMap; +use std::fmt::Display; + +#[derive(Debug, Clone, PartialEq)] +pub struct ColumnSpec { + name: String, + data_type: DataType, +} + +#[derive(Debug, Clone, PartialEq)] +pub enum DataType { + Bool, + UInt8, + UInt16, + UInt32, + UInt64, + UInt128, + UInt256, + Int8, + Int16, + Int32, + Int64, + Int128, + Int256, + Float32, + Float64, + BFloat16, + String, + UUID, + Date, + Date32, + DateTime(Option), // Optional timezone + DateTime64(DateTimePrecision, Option), // Precision and optional timezone + IPv4, + IPv6, + + Nullable(Box), + Array(Box), + Tuple(Vec), + Map(Box, Box), + LowCardinality(Box), + Decimal(u8, u8, DecimalSize), + Enum(EnumType, HashMap), + AggregateFunction(String, Vec), + FixedString(usize), + + Variant(Vec), + Dynamic, + JSON, + // TODO: Nested, Geo +} + +impl DataType { + pub fn new(name: &str) -> Result { + match name { + "UInt8" => Ok(Self::UInt8), + "UInt16" => Ok(Self::UInt16), + "UInt32" => Ok(Self::UInt32), + "UInt64" => Ok(Self::UInt64), + "UInt128" => Ok(Self::UInt128), + "UInt256" => Ok(Self::UInt256), + "Int8" => Ok(Self::Int8), + "Int16" => Ok(Self::Int16), + "Int32" => Ok(Self::Int32), + "Int64" => Ok(Self::Int64), + "Int128" => Ok(Self::Int128), + "Int256" => Ok(Self::Int256), + "Float32" => Ok(Self::Float32), + "Float64" => Ok(Self::Float64), + "BFloat16" => Ok(Self::BFloat16), + "String" => Ok(Self::String), + "UUID" => Ok(Self::UUID), + "Date" => Ok(Self::Date), + "Date32" => Ok(Self::Date32), + "IPv4" => Ok(Self::IPv4), + "IPv6" => Ok(Self::IPv6), + "Bool" => Ok(Self::Bool), + "Dynamic" => Ok(Self::Dynamic), + "JSON" => Ok(Self::JSON), + + str if str.starts_with("Decimal") => parse_decimal(str), + str if str.starts_with("DateTime64") => parse_datetime64(str), + str if str.starts_with("DateTime") => parse_datetime(str), + + str if str.starts_with("Nullable") => parse_nullable(str), + str if str.starts_with("LowCardinality") => parse_low_cardinality(str), + str if str.starts_with("FixedString") => parse_fixed_string(str), + + str if str.starts_with("Array") => parse_array(str), + str if str.starts_with("Enum") => parse_enum(str), + str if str.starts_with("Map") => parse_map(str), + str if str.starts_with("Tuple") => parse_tuple(str), + str if str.starts_with("Variant") => parse_variant(str), + + // ... + str => Err(ColumnsParserError::TypeParsingError(format!( + "Unknown data type: {}", + str + ))), + } + } +} + +impl Into for DataType { + fn into(self) -> String { + self.to_string() + } +} + +impl Display for DataType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use DataType::*; + let str = match self { + UInt8 => "UInt8".to_string(), + UInt16 => "UInt16".to_string(), + UInt32 => "UInt32".to_string(), + UInt64 => "UInt64".to_string(), + UInt128 => "UInt128".to_string(), + UInt256 => "UInt256".to_string(), + Int8 => "Int8".to_string(), + Int16 => "Int16".to_string(), + Int32 => "Int32".to_string(), + Int64 => "Int64".to_string(), + Int128 => "Int128".to_string(), + Int256 => "Int256".to_string(), + Float32 => "Float32".to_string(), + Float64 => "Float64".to_string(), + BFloat16 => "BFloat16".to_string(), + String => "String".to_string(), + UUID => "UUID".to_string(), + Date => "Date".to_string(), + Date32 => "Date32".to_string(), + DateTime(None) => "DateTime".to_string(), + DateTime(Some(tz)) => format!("DateTime('{}')", tz), + DateTime64(precision, None) => format!("DateTime64({})", precision), + DateTime64(precision, Some(tz)) => format!("DateTime64({}, '{}')", precision, tz), + IPv4 => "IPv4".to_string(), + IPv6 => "IPv6".to_string(), + Bool => "Bool".to_string(), + Nullable(inner) => format!("Nullable({})", inner.to_string()), + Array(inner) => format!("Array({})", inner.to_string()), + Tuple(elements) => { + let elements_str = data_types_to_string(elements); + format!("Tuple({})", elements_str) + } + Map(key, value) => { + format!("Map({}, {})", key.to_string(), value.to_string()) + } + LowCardinality(inner) => { + format!("LowCardinality({})", inner.to_string()) + } + Decimal(precision, scale, _) => { + format!("Decimal({}, {})", precision, scale) + } + Enum(enum_type, values) => { + let mut values_vec = values.iter().collect::>(); + values_vec.sort_by(|(i1, _), (i2, _)| (*i1).cmp(*i2)); + let values_str = values_vec + .iter() + .map(|(index, name)| format!("'{}' = {}", name, index)) + .collect::>() + .join(", "); + format!("{}({})", enum_type, values_str) + } + AggregateFunction(func_name, args) => { + let args_str = data_types_to_string(args); + format!("AggregateFunction({}, {})", func_name, args_str) + } + FixedString(size) => { + format!("FixedString({})", size) + } + Variant(types) => { + let types_str = data_types_to_string(types); + format!("Variant({})", types_str) + } + JSON => "JSON".to_string(), + Dynamic => "Dynamic".to_string(), + }; + write!(f, "{}", str) + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum EnumType { + Enum8, + Enum16, +} + +impl Display for EnumType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EnumType::Enum8 => write!(f, "Enum8"), + EnumType::Enum16 => write!(f, "Enum16"), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum DateTimePrecision { + Precision0, + Precision1, + Precision2, + Precision3, + Precision4, + Precision5, + Precision6, + Precision7, + Precision8, + Precision9, +} + +impl DateTimePrecision { + pub(crate) fn new(char: char) -> Result { + match char { + '0' => Ok(DateTimePrecision::Precision0), + '1' => Ok(DateTimePrecision::Precision1), + '2' => Ok(DateTimePrecision::Precision2), + '3' => Ok(DateTimePrecision::Precision3), + '4' => Ok(DateTimePrecision::Precision4), + '5' => Ok(DateTimePrecision::Precision5), + '6' => Ok(DateTimePrecision::Precision6), + '7' => Ok(DateTimePrecision::Precision7), + '8' => Ok(DateTimePrecision::Precision8), + '9' => Ok(DateTimePrecision::Precision9), + _ => Err(ColumnsParserError::TypeParsingError(format!( + "Invalid DateTime64 precision, expected to be within [0, 9] interval, got {}", + char + ))), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum DecimalSize { + Int32, + Int64, + Int128, + Int256, +} + +impl DecimalSize { + pub(crate) fn new(precision: u8) -> Result { + if precision <= 9 { + Ok(DecimalSize::Int32) + } else if precision <= 18 { + Ok(DecimalSize::Int64) + } else if precision <= 38 { + Ok(DecimalSize::Int128) + } else if precision <= 76 { + Ok(DecimalSize::Int256) + } else { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Decimal precision: {}", + precision + ))); + } + } +} + +impl Display for DateTimePrecision { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + DateTimePrecision::Precision0 => write!(f, "0"), + DateTimePrecision::Precision1 => write!(f, "1"), + DateTimePrecision::Precision2 => write!(f, "2"), + DateTimePrecision::Precision3 => write!(f, "3"), + DateTimePrecision::Precision4 => write!(f, "4"), + DateTimePrecision::Precision5 => write!(f, "5"), + DateTimePrecision::Precision6 => write!(f, "6"), + DateTimePrecision::Precision7 => write!(f, "7"), + DateTimePrecision::Precision8 => write!(f, "8"), + DateTimePrecision::Precision9 => write!(f, "9"), + } + } +} + +fn data_types_to_string(elements: &[DataType]) -> String { + elements + .iter() + .map(|a| a.to_string()) + .collect::>() + .join(", ") +} + +fn parse_fixed_string(input: &str) -> Result { + if input.len() >= 14 { + let size_str = &input[12..input.len() - 1]; + let size = size_str.parse::().map_err(|err| { + ColumnsParserError::TypeParsingError(format!( + "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", + err, input, size_str + )) + })?; + if size == 0 { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid FixedString size, expected a positive number, got zero. Input: {}", + input + ))); + } + return Ok(DataType::FixedString(size)); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid FixedString format, expected FixedString(N), got {}", + input + ))) +} + +fn parse_array(input: &str) -> Result { + if input.len() >= 8 { + let inner_type_str = &input[6..input.len() - 1]; + let inner_type = DataType::new(inner_type_str)?; + return Ok(DataType::Array(Box::new(inner_type))); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Array format, expected Array(InnerType), got {}", + input + ))) +} + +fn parse_enum(input: &str) -> Result { + if input.len() >= 9 { + let (enum_type, prefix_len) = if input.starts_with("Enum8") { + (EnumType::Enum8, 6) + } else if input.starts_with("Enum16") { + (EnumType::Enum16, 7) + } else { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Enum type, expected Enum8 or Enum16, got {}", + input + ))); + }; + let enum_values_map_str = &input[prefix_len..input.len() - 1]; + let enum_values_map = parse_enum_values_map(enum_values_map_str)?; + return Ok(DataType::Enum(enum_type, enum_values_map)); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Enum format, expected Enum8('name' = value), got {}", + input + ))) +} + +fn parse_datetime(input: &str) -> Result { + if input == "DateTime" { + return Ok(DataType::DateTime(None)); + } + if input.len() >= 12 { + let timezone = (&input[10..input.len() - 2]).to_string(); + return Ok(DataType::DateTime(Some(timezone))); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid DateTime format, expected DateTime('timezone'), got {}", + input + ))) +} + +fn parse_decimal(input: &str) -> Result { + if input.len() >= 10 { + let precision_and_scale_str = (&input[8..input.len() - 1]).split(", ").collect::>(); + if precision_and_scale_str.len() != 2 { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S), got {}", + input + ))); + } + let parsed = precision_and_scale_str + .iter() + .map(|s| s.parse::()) + .collect::, _>>() + .map_err(|err| { + ColumnsParserError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S), got {}. Underlying error: {}", + input, err + )) + })?; + let precision = parsed[0]; + let scale = parsed[1]; + if scale < 1 || precision < 1 { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S) with P > 0 and S > 0, got {}", + input + ))); + } + if precision < scale { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S) with P >= S, got {}", + input + ))); + } + let size = DecimalSize::new(parsed[0])?; + return Ok(DataType::Decimal(precision, scale, size)); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P), got {}", + input + ))) +} + +fn parse_datetime64(input: &str) -> Result { + if input.len() >= 13 { + let mut chars = (&input[11..input.len() - 1]).chars(); + let precision_char = chars + .next() + .ok_or(ColumnsParserError::TypeParsingError(format!( + "Invalid DateTime64 precision, expected a positive number. Input: {}", + input + )))?; + let precision = DateTimePrecision::new(precision_char)?; + let maybe_tz = match chars.as_str() { + str if str.len() > 2 => Some((&str[3..str.len() - 1]).to_string()), + _ => None, + }; + return Ok(DataType::DateTime64(precision, maybe_tz)); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid DateTime format, expected DateTime('timezone'), got {}", + input + ))) +} + +fn parse_low_cardinality(input: &str) -> Result { + if input.len() >= 16 { + let inner_type_str = &input[15..input.len() - 1]; + let inner_type = DataType::new(inner_type_str)?; + return Ok(DataType::LowCardinality(Box::new(inner_type))); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid LowCardinality format, expected LowCardinality(InnerType), got {}", + input + ))) +} + +fn parse_nullable(input: &str) -> Result { + if input.len() >= 10 { + let inner_type_str = &input[9..input.len() - 1]; + let inner_type = DataType::new(inner_type_str)?; + return Ok(DataType::Nullable(Box::new(inner_type))); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Nullable format, expected Nullable(InnerType), got {}", + input + ))) +} + +fn parse_map(input: &str) -> Result { + if input.len() >= 5 { + let inner_types_str = &input[4..input.len() - 1]; + let inner_types = parse_inner_types(inner_types_str)?; + if inner_types.len() != 2 { + return Err(ColumnsParserError::TypeParsingError(format!( + "Expected two inner elements in a Map from input {}", + input + ))); + } + return Ok(DataType::Map( + Box::new(inner_types[0].clone()), + Box::new(inner_types[1].clone()), + )); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Map format, expected Map(KeyType, ValueType), got {}", + input + ))) +} + +fn parse_tuple(input: &str) -> Result { + if input.len() > 7 { + let inner_types_str = &input[6..input.len() - 1]; + let inner_types = parse_inner_types(inner_types_str)?; + if inner_types.is_empty() { + return Err(ColumnsParserError::TypeParsingError(format!( + "Expected at least one inner element in a Tuple from input {}", + input + ))); + } + return Ok(DataType::Tuple(inner_types)); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Tuple format, expected Tuple(Type1, Type2, ...), got {}", + input + ))) +} + +fn parse_variant(input: &str) -> Result { + if input.len() >= 9 { + let inner_types_str = &input[8..input.len() - 1]; + let inner_types = parse_inner_types(inner_types_str)?; + return Ok(DataType::Variant(inner_types)); + } + Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Variant format, expected Variant(Type1, Type2, ...), got {}", + input + ))) +} + +/// Considers the element type parsed once we reach a comma outside of parens AND after an unescaped tick. +/// The most complicated cases are values names in the self-defined Enum types: +/// ``` +/// let input1 = "Tuple(Enum8('f\'()' = 1))`"; // the result is `f\'()` +/// let input2 = "Tuple(Enum8('(' = 1))"; // the result is `(` +/// ``` +fn parse_inner_types(input: &str) -> Result, ColumnsParserError> { + let mut inner_types: Vec = Vec::new(); + + let input_bytes = input.as_bytes(); + + let mut open_parens = 0; + let mut quote_open = false; + let mut char_escaped = false; + let mut last_element_index = 0; + + let mut i = 0; + while i < input_bytes.len() { + if char_escaped { + char_escaped = false; + } else if input_bytes[i] == b'\\' { + char_escaped = true; + } else if input_bytes[i] == b'\'' { + quote_open = !quote_open; // unescaped quote + } else { + if !quote_open { + if input_bytes[i] == b'(' { + open_parens += 1; + } else if input_bytes[i] == b')' { + open_parens -= 1; + } else if input_bytes[i] == b',' { + if open_parens == 0 { + let data_type_str = + String::from_utf8(input_bytes[last_element_index..i].to_vec()) + .map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the inner data type: {}", + &input[last_element_index..] + )) + })?; + let data_type = DataType::new(&data_type_str)?; + inner_types.push(data_type); + // Skip ', ' (comma and space) + if i + 2 <= input_bytes.len() && input_bytes[i + 1] == b' ' { + i += 2; + } else { + i += 1; + } + last_element_index = i; + continue; // Skip the normal increment at the end of the loop + } + } + } + } + i += 1; + } + + // Push the remaining part of the type if it seems to be valid (at least all parentheses are closed) + if open_parens == 0 && last_element_index < input_bytes.len() { + let data_type_str = + String::from_utf8(input_bytes[last_element_index..].to_vec()).map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the inner data type: {}", + &input[last_element_index..] + )) + })?; + let data_type = DataType::new(&data_type_str)?; + inner_types.push(data_type); + } + + Ok(inner_types) +} + +fn parse_enum_values_map(input: &str) -> Result, ColumnsParserError> { + let mut names: Vec = Vec::new(); + let mut indices: Vec = Vec::new(); + let mut parsing_name = true; // false when parsing the index + let mut char_escaped = false; // we should ignore escaped ticks + let mut start_index = 1; // Skip the first ' + + let mut i = 1; + let input_bytes = input.as_bytes(); + while i < input_bytes.len() { + if parsing_name { + if char_escaped { + char_escaped = false; + } else { + if input_bytes[i] == b'\\' { + char_escaped = true; + } else if input_bytes[i] == b'\'' { + // non-escaped closing tick - push the name + let name_bytes = &input_bytes[start_index..i]; + let name = String::from_utf8(name_bytes.to_vec()).map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the enum name: {}", + &input[start_index..i] + )) + })?; + names.push(name); + + // Skip ` = ` and the first digit, as it will always have at least one + if i + 4 >= input_bytes.len() { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Enum format - expected ` = ` after name, input: {}", + input, + ))); + } + i += 4; + start_index = i; + parsing_name = false; + } + } + } + // Parsing the index, skipping next iterations until the first non-digit one + else if input_bytes[i] < b'0' || input_bytes[i] > b'9' { + let index = String::from_utf8(input_bytes[start_index..i].to_vec()) + .map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the enum index: {}", + &input[start_index..i] + )) + })? + .parse::() + .map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid Enum index, expected a valid number. Input: {}", + input + )) + })?; + indices.push(index); + + // the char at this index should be comma + // Skip `, '`, but not the first char - ClickHouse allows something like Enum8('foo' = 0, '' = 42) + if i + 2 >= input_bytes.len() { + break; // At the end of the enum, no more entries + } + i += 2; + start_index = i + 1; + parsing_name = true; + char_escaped = false; + } + + i += 1; + } + + let index = String::from_utf8(input_bytes[start_index..i].to_vec()) + .map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the enum index: {}", + &input[start_index..i] + )) + })? + .parse::() + .map_err(|_| { + ColumnsParserError::TypeParsingError(format!( + "Invalid Enum index, expected a valid number. Input: {}", + input + )) + })?; + indices.push(index); + + if names.len() != indices.len() { + return Err(ColumnsParserError::TypeParsingError(format!( + "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", + names.join(", "), + indices.iter().map(|index| index.to_string()).collect::>().join(", "), + ))); + } + + Ok(indices + .into_iter() + .zip(names) + .collect::>()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_data_type_new_simple() { + assert_eq!(DataType::new("UInt8").unwrap(), DataType::UInt8); + assert_eq!(DataType::new("UInt16").unwrap(), DataType::UInt16); + assert_eq!(DataType::new("UInt32").unwrap(), DataType::UInt32); + assert_eq!(DataType::new("UInt64").unwrap(), DataType::UInt64); + assert_eq!(DataType::new("UInt128").unwrap(), DataType::UInt128); + assert_eq!(DataType::new("UInt256").unwrap(), DataType::UInt256); + assert_eq!(DataType::new("Int8").unwrap(), DataType::Int8); + assert_eq!(DataType::new("Int16").unwrap(), DataType::Int16); + assert_eq!(DataType::new("Int32").unwrap(), DataType::Int32); + assert_eq!(DataType::new("Int64").unwrap(), DataType::Int64); + assert_eq!(DataType::new("Int128").unwrap(), DataType::Int128); + assert_eq!(DataType::new("Int256").unwrap(), DataType::Int256); + assert_eq!(DataType::new("Float32").unwrap(), DataType::Float32); + assert_eq!(DataType::new("Float64").unwrap(), DataType::Float64); + assert_eq!(DataType::new("BFloat16").unwrap(), DataType::BFloat16); + assert_eq!(DataType::new("String").unwrap(), DataType::String); + assert_eq!(DataType::new("UUID").unwrap(), DataType::UUID); + assert_eq!(DataType::new("Date").unwrap(), DataType::Date); + assert_eq!(DataType::new("Date32").unwrap(), DataType::Date32); + assert_eq!(DataType::new("IPv4").unwrap(), DataType::IPv4); + assert_eq!(DataType::new("IPv6").unwrap(), DataType::IPv6); + assert_eq!(DataType::new("Bool").unwrap(), DataType::Bool); + assert_eq!(DataType::new("Dynamic").unwrap(), DataType::Dynamic); + assert_eq!(DataType::new("JSON").unwrap(), DataType::JSON); + assert!(DataType::new("SomeUnknownType").is_err(),); + } + + #[test] + fn test_data_type_new_fixed_string() { + assert_eq!( + DataType::new("FixedString(1)").unwrap(), + DataType::FixedString(1) + ); + assert_eq!( + DataType::new("FixedString(16)").unwrap(), + DataType::FixedString(16) + ); + assert_eq!( + DataType::new("FixedString(255)").unwrap(), + DataType::FixedString(255) + ); + assert_eq!( + DataType::new("FixedString(65535)").unwrap(), + DataType::FixedString(65_535) + ); + assert!(DataType::new("FixedString()").is_err()); + assert!(DataType::new("FixedString(0)").is_err()); + assert!(DataType::new("FixedString(-1)").is_err()); + assert!(DataType::new("FixedString(abc)").is_err()); + } + + #[test] + fn test_data_type_new_array() { + assert_eq!( + DataType::new("Array(UInt8)").unwrap(), + DataType::Array(Box::new(DataType::UInt8)) + ); + assert_eq!( + DataType::new("Array(String)").unwrap(), + DataType::Array(Box::new(DataType::String)) + ); + assert_eq!( + DataType::new("Array(FixedString(16))").unwrap(), + DataType::Array(Box::new(DataType::FixedString(16))) + ); + assert_eq!( + DataType::new("Array(Nullable(Int32))").unwrap(), + DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::Int32)))) + ); + assert!(DataType::new("Array()").is_err()); + assert!(DataType::new("Array(abc)").is_err()); + } + + #[test] + fn test_data_type_new_decimal() { + assert_eq!( + DataType::new("Decimal(7, 2)").unwrap(), + DataType::Decimal(7, 2, DecimalSize::Int32) + ); + assert_eq!( + DataType::new("Decimal(12, 4)").unwrap(), + DataType::Decimal(12, 4, DecimalSize::Int64) + ); + assert_eq!( + DataType::new("Decimal(27, 6)").unwrap(), + DataType::Decimal(27, 6, DecimalSize::Int128) + ); + assert_eq!( + DataType::new("Decimal(42, 8)").unwrap(), + DataType::Decimal(42, 8, DecimalSize::Int256) + ); + assert!(DataType::new("Decimal").is_err()); + assert!(DataType::new("Decimal(").is_err()); + assert!(DataType::new("Decimal()").is_err()); + assert!(DataType::new("Decimal(1)").is_err()); + assert!(DataType::new("Decimal(1,)").is_err()); + assert!(DataType::new("Decimal(1, )").is_err()); + assert!(DataType::new("Decimal(0, 0)").is_err()); // Precision must be > 0 + assert!(DataType::new("Decimal(x, 0)").is_err()); // Non-numeric precision + assert!(DataType::new("Decimal(', ')").is_err()); + assert!(DataType::new("Decimal(77, 1)").is_err()); // Max precision is 76 + assert!(DataType::new("Decimal(1, 2)").is_err()); // Scale must be less than precision + assert!(DataType::new("Decimal(1, x)").is_err()); // Non-numeric scale + assert!(DataType::new("Decimal(42, ,)").is_err()); + assert!(DataType::new("Decimal(42, ')").is_err()); + assert!(DataType::new("Decimal(foobar)").is_err()); + } + + #[test] + fn test_data_type_new_datetime() { + assert_eq!(DataType::new("DateTime").unwrap(), DataType::DateTime(None)); + assert_eq!( + DataType::new("DateTime('UTC')").unwrap(), + DataType::DateTime(Some("UTC".to_string())) + ); + assert_eq!( + DataType::new("DateTime('America/New_York')").unwrap(), + DataType::DateTime(Some("America/New_York".to_string())) + ); + assert!(DataType::new("DateTime()").is_err()); + } + + #[test] + fn test_data_type_new_datetime64() { + assert_eq!( + DataType::new("DateTime64(0)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision0, None) + ); + assert_eq!( + DataType::new("DateTime64(1)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision1, None) + ); + assert_eq!( + DataType::new("DateTime64(2)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision2, None) + ); + assert_eq!( + DataType::new("DateTime64(3)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision3, None) + ); + assert_eq!( + DataType::new("DateTime64(4)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision4, None) + ); + assert_eq!( + DataType::new("DateTime64(5)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision5, None) + ); + assert_eq!( + DataType::new("DateTime64(6)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision6, None) + ); + assert_eq!( + DataType::new("DateTime64(7)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision7, None) + ); + assert_eq!( + DataType::new("DateTime64(8)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision8, None) + ); + assert_eq!( + DataType::new("DateTime64(9)").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision9, None) + ); + assert_eq!( + DataType::new("DateTime64(0, 'UTC')").unwrap(), + DataType::DateTime64(DateTimePrecision::Precision0, Some("UTC".to_string())) + ); + assert_eq!( + DataType::new("DateTime64(3, 'America/New_York')").unwrap(), + DataType::DateTime64( + DateTimePrecision::Precision3, + Some("America/New_York".to_string()) + ) + ); + assert_eq!( + DataType::new("DateTime64(6, 'America/New_York')").unwrap(), + DataType::DateTime64( + DateTimePrecision::Precision6, + Some("America/New_York".to_string()) + ) + ); + assert_eq!( + DataType::new("DateTime64(9, 'Europe/Amsterdam')").unwrap(), + DataType::DateTime64( + DateTimePrecision::Precision9, + Some("Europe/Amsterdam".to_string()) + ) + ); + assert!(DataType::new("DateTime64()").is_err()); + } + + #[test] + fn test_data_type_new_low_cardinality() { + assert_eq!( + DataType::new("LowCardinality(UInt8)").unwrap(), + DataType::LowCardinality(Box::new(DataType::UInt8)) + ); + assert_eq!( + DataType::new("LowCardinality(String)").unwrap(), + DataType::LowCardinality(Box::new(DataType::String)) + ); + assert_eq!( + DataType::new("LowCardinality(Array(Int32))").unwrap(), + DataType::LowCardinality(Box::new(DataType::Array(Box::new(DataType::Int32)))) + ); + assert!(DataType::new("LowCardinality()").is_err()); + } + + #[test] + fn test_data_type_new_nullable() { + assert_eq!( + DataType::new("Nullable(UInt8)").unwrap(), + DataType::Nullable(Box::new(DataType::UInt8)) + ); + assert_eq!( + DataType::new("Nullable(String)").unwrap(), + DataType::Nullable(Box::new(DataType::String)) + ); + assert!(DataType::new("Nullable()").is_err()); + } + + #[test] + fn test_data_type_new_map() { + assert_eq!( + DataType::new("Map(UInt8, String)").unwrap(), + DataType::Map(Box::new(DataType::UInt8), Box::new(DataType::String)) + ); + assert_eq!( + DataType::new("Map(String, Int32)").unwrap(), + DataType::Map(Box::new(DataType::String), Box::new(DataType::Int32)) + ); + assert_eq!( + DataType::new("Map(String, Map(Int32, Array(Nullable(String))))").unwrap(), + DataType::Map( + Box::new(DataType::String), + Box::new(DataType::Map( + Box::new(DataType::Int32), + Box::new(DataType::Array(Box::new(DataType::Nullable(Box::new( + DataType::String + ))))) + )) + ) + ); + assert!(DataType::new("Map()").is_err()); + } + + #[test] + fn test_data_type_new_variant() { + assert_eq!( + DataType::new("Variant(UInt8, String)").unwrap(), + DataType::Variant(vec![DataType::UInt8, DataType::String]) + ); + assert_eq!( + DataType::new("Variant(String, Int32)").unwrap(), + DataType::Variant(vec![DataType::String, DataType::Int32]) + ); + assert_eq!( + DataType::new("Variant(Int32, Array(Nullable(String)), Map(Int32, String))").unwrap(), + DataType::Variant(vec![ + DataType::Int32, + DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::String)))), + DataType::Map(Box::new(DataType::Int32), Box::new(DataType::String)) + ]) + ); + assert!(DataType::new("Variant").is_err()); + } + + #[test] + fn test_data_type_new_tuple() { + assert_eq!( + DataType::new("Tuple(UInt8, String)").unwrap(), + DataType::Tuple(vec![DataType::UInt8, DataType::String]) + ); + assert_eq!( + DataType::new("Tuple(String, Int32)").unwrap(), + DataType::Tuple(vec![DataType::String, DataType::Int32]) + ); + assert_eq!( + DataType::new( + "Tuple(Int32, Array(Nullable(String)), Map(Int32, Tuple(String, Array(UInt8))))" + ) + .unwrap(), + DataType::Tuple(vec![ + DataType::Int32, + DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::String)))), + DataType::Map( + Box::new(DataType::Int32), + Box::new(DataType::Tuple(vec![ + DataType::String, + DataType::Array(Box::new(DataType::UInt8)) + ])) + ) + ]) + ); + assert!(DataType::new("Tuple").is_err()); + } + + #[test] + fn test_data_type_new_enum() { + assert_eq!( + DataType::new("Enum8('A' = -42)").unwrap(), + DataType::Enum(EnumType::Enum8, HashMap::from([(-42, "A".to_string())])) + ); + assert_eq!( + DataType::new("Enum16('A' = -144)").unwrap(), + DataType::Enum(EnumType::Enum16, HashMap::from([(-144, "A".to_string())])) + ); + + assert_eq!( + DataType::new("Enum8('A' = 1, 'B' = 2)").unwrap(), + DataType::Enum( + EnumType::Enum8, + HashMap::from([(1, "A".to_string()), (2, "B".to_string())]) + ) + ); + assert_eq!( + DataType::new("Enum16('A' = 1, 'B' = 2)").unwrap(), + DataType::Enum( + EnumType::Enum16, + HashMap::from([(1, "A".to_string()), (2, "B".to_string())]) + ) + ); + assert_eq!( + DataType::new("Enum8('f\\'' = 1, 'x =' = 2, 'b\\'\\'' = 3, '\\'c=4=' = 42, '4' = 100)") + .unwrap(), + DataType::Enum( + EnumType::Enum8, + HashMap::from([ + (1, "f\\'".to_string()), + (2, "x =".to_string()), + (3, "b\\'\\'".to_string()), + (42, "\\'c=4=".to_string()), + (100, "4".to_string()) + ]) + ) + ); + assert_eq!( + DataType::new("Enum8('foo' = 0, '' = 42)").unwrap(), + DataType::Enum( + EnumType::Enum8, + HashMap::from([(0, "foo".to_string()), (42, "".to_string())]) + ) + ); + + assert!(DataType::new("Enum()").is_err()); + assert!(DataType::new("Enum8()").is_err()); + assert!(DataType::new("Enum16()").is_err()); + } + + #[test] + fn test_data_type_to_string_simple() { + // Simple types + assert_eq!(DataType::UInt8.to_string(), "UInt8"); + assert_eq!(DataType::UInt16.to_string(), "UInt16"); + assert_eq!(DataType::UInt32.to_string(), "UInt32"); + assert_eq!(DataType::UInt64.to_string(), "UInt64"); + assert_eq!(DataType::UInt128.to_string(), "UInt128"); + assert_eq!(DataType::UInt256.to_string(), "UInt256"); + assert_eq!(DataType::Int8.to_string(), "Int8"); + assert_eq!(DataType::Int16.to_string(), "Int16"); + assert_eq!(DataType::Int32.to_string(), "Int32"); + assert_eq!(DataType::Int64.to_string(), "Int64"); + assert_eq!(DataType::Int128.to_string(), "Int128"); + assert_eq!(DataType::Int256.to_string(), "Int256"); + assert_eq!(DataType::Float32.to_string(), "Float32"); + assert_eq!(DataType::Float64.to_string(), "Float64"); + assert_eq!(DataType::BFloat16.to_string(), "BFloat16"); + assert_eq!(DataType::UUID.to_string(), "UUID"); + assert_eq!(DataType::Date.to_string(), "Date"); + assert_eq!(DataType::Date32.to_string(), "Date32"); + assert_eq!(DataType::IPv4.to_string(), "IPv4"); + assert_eq!(DataType::IPv6.to_string(), "IPv6"); + assert_eq!(DataType::Bool.to_string(), "Bool"); + assert_eq!(DataType::Dynamic.to_string(), "Dynamic"); + assert_eq!(DataType::JSON.to_string(), "JSON"); + assert_eq!(DataType::String.to_string(), "String"); + } + + #[test] + fn test_data_types_to_string_complex() { + assert_eq!(DataType::DateTime(None).to_string(), "DateTime"); + assert_eq!( + DataType::DateTime(Some("UTC".to_string())).to_string(), + "DateTime('UTC')" + ); + assert_eq!( + DataType::DateTime(Some("America/New_York".to_string())).to_string(), + "DateTime('America/New_York')" + ); + + assert_eq!( + DataType::Nullable(Box::new(DataType::UInt64)).to_string(), + "Nullable(UInt64)" + ); + assert_eq!( + DataType::Array(Box::new(DataType::String)).to_string(), + "Array(String)" + ); + assert_eq!( + DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::String)))).to_string(), + "Array(Nullable(String))" + ); + assert_eq!( + DataType::Tuple(vec![DataType::String, DataType::UInt32, DataType::Float64]) + .to_string(), + "Tuple(String, UInt32, Float64)" + ); + assert_eq!( + DataType::Map(Box::new(DataType::String), Box::new(DataType::UInt32)).to_string(), + "Map(String, UInt32)" + ); + assert_eq!( + DataType::Decimal(10, 2, DecimalSize::Int32).to_string(), + "Decimal(10, 2)" + ); + assert_eq!( + DataType::Enum( + EnumType::Enum8, + HashMap::from([(1, "A".to_string()), (2, "B".to_string())]), + ) + .to_string(), + "Enum8('A' = 1, 'B' = 2)" + ); + assert_eq!( + DataType::AggregateFunction("sum".to_string(), vec![DataType::UInt64]).to_string(), + "AggregateFunction(sum, UInt64)" + ); + assert_eq!(DataType::FixedString(16).to_string(), "FixedString(16)"); + assert_eq!( + DataType::Variant(vec![DataType::UInt8, DataType::Bool]).to_string(), + "Variant(UInt8, Bool)" + ); + assert_eq!( + DataType::DateTime64(DateTimePrecision::Precision3, Some("UTC".to_string())) + .to_string(), + "DateTime64(3, 'UTC')" + ); + } +} From 3a66d7ab64a6163e06ba7319a821d048e5626a3b Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 8 May 2025 22:20:59 +0200 Subject: [PATCH 02/54] Add RBWNAT header parser --- Cargo.toml | 1 + rowbinary/Cargo.toml | 1 + rowbinary/src/decoders.rs | 20 +++++++ rowbinary/src/error.rs | 19 ++---- rowbinary/src/header.rs | 31 ++++++++++ rowbinary/src/leb128.rs | 62 ++++++++------------ rowbinary/src/lib.rs | 7 ++- rowbinary/src/types.rs | 120 ++++++++++++++++++++------------------ tests/it/main.rs | 1 + tests/it/rbwnat.rs | 94 +++++++++++++++++++++++++++++ 10 files changed, 243 insertions(+), 113 deletions(-) create mode 100644 rowbinary/src/decoders.rs create mode 100644 rowbinary/src/header.rs create mode 100644 tests/it/rbwnat.rs diff --git a/Cargo.toml b/Cargo.toml index 931f77fe..ce1c84c1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -131,6 +131,7 @@ quanta = { version = "0.12", optional = true } replace_with = { version = "0.1.7" } [dev-dependencies] +clickhouse-rowbinary = { version = "*", path = "./rowbinary" } criterion = "0.5.0" serde = { version = "1.0.106", features = ["derive"] } tokio = { version = "1.0.1", features = ["full", "test-util"] } diff --git a/rowbinary/Cargo.toml b/rowbinary/Cargo.toml index 6eb40fb0..59bebd3b 100644 --- a/rowbinary/Cargo.toml +++ b/rowbinary/Cargo.toml @@ -15,3 +15,4 @@ rust-version = "1.73.0" [dependencies] thiserror = "1.0.16" +bytes = "1.10.1" diff --git a/rowbinary/src/decoders.rs b/rowbinary/src/decoders.rs new file mode 100644 index 00000000..745c2263 --- /dev/null +++ b/rowbinary/src/decoders.rs @@ -0,0 +1,20 @@ +use crate::error::ParserError; +use crate::leb128::decode_leb128; +use bytes::{Buf, Bytes}; + +pub(crate) fn decode_string(buffer: &mut Bytes) -> Result { + let length = decode_leb128(buffer)? as usize; + if length == 0 { + return Ok("".to_string()); + } + if buffer.remaining() < length { + return Err(ParserError::NotEnoughData(format!( + "decoding string, {} bytes remaining, {} bytes required, pos {}", + buffer.remaining(), + length, + buffer.remaining(), + ))); + } + let result = String::from_utf8_lossy(&buffer.copy_to_bytes(length)).to_string(); + Ok(result) +} diff --git a/rowbinary/src/error.rs b/rowbinary/src/error.rs index 00cacafa..eb10af4f 100644 --- a/rowbinary/src/error.rs +++ b/rowbinary/src/error.rs @@ -1,19 +1,10 @@ #[derive(Debug, thiserror::Error)] -pub enum ColumnsParserError { - #[error("IO error: {0}")] - IoError(#[from] std::io::Error), +pub enum ParserError { + #[error("Not enough data: {0}")] + NotEnoughData(String), - #[error("Expected LF at position {0}")] - ExpectedLF(usize), - - #[error("Invalid integer encoding at position {0}")] - InvalidIntegerEncoding(usize), - - #[error("Incomplete column data at position {0}")] - IncompleteColumnData(usize), - - #[error("Invalid column spec at position {0}: {1}")] - InvalidColumnSpec(usize, String), + #[error("Header parsing error: {0}")] + HeaderParsingError(String), #[error("Type parsing error: {0}")] TypeParsingError(String), diff --git a/rowbinary/src/header.rs b/rowbinary/src/header.rs new file mode 100644 index 00000000..d0a9744e --- /dev/null +++ b/rowbinary/src/header.rs @@ -0,0 +1,31 @@ +use crate::decoders::decode_string; +use crate::error::ParserError; +use crate::leb128::decode_leb128; +use crate::types::{Column, DataType}; +use bytes::Bytes; + +pub fn parse_names_and_types_header(bytes: &mut Bytes) -> Result, ParserError> { + let num_columns = decode_leb128(bytes)?; + if num_columns == 0 { + return Err(ParserError::HeaderParsingError( + "Expected at least one column in the header".to_string(), + )); + } + let mut columns_names: Vec = Vec::with_capacity(num_columns as usize); + for _ in 0..num_columns { + let column_name = decode_string(bytes)?; + columns_names.push(column_name); + } + let mut column_data_types: Vec = Vec::with_capacity(num_columns as usize); + for _ in 0..num_columns { + let column_type = decode_string(bytes)?; + let data_type = DataType::new(&column_type)?; + column_data_types.push(data_type); + } + let columns = columns_names + .into_iter() + .zip(column_data_types) + .map(|(name, data_type)| Column { name, data_type }) + .collect(); + Ok(columns) +} diff --git a/rowbinary/src/leb128.rs b/rowbinary/src/leb128.rs index 4046b443..cdc33e7d 100644 --- a/rowbinary/src/leb128.rs +++ b/rowbinary/src/leb128.rs @@ -1,63 +1,49 @@ -use std::io::{ErrorKind, Read}; - -use crate::error::ColumnsParserError; - -pub fn decode_leb128(pos: &mut usize, reader: &mut R) -> Result { - let mut result: u64 = 0; - let mut shift: u32 = 0; - let mut buf = [0u8; 1]; +use crate::error::ParserError; +use crate::error::ParserError::NotEnoughData; +use bytes::{Buf, Bytes}; +pub fn decode_leb128(buffer: &mut Bytes) -> Result { + let mut value = 0u64; + let mut shift = 0; loop { - reader.read_exact(&mut buf).map_err(|e| { - if e.kind() == ErrorKind::UnexpectedEof { - ColumnsParserError::InvalidIntegerEncoding(*pos) - } else { - ColumnsParserError::IoError(e) - } - })?; - - *pos += 1; - - let byte = buf[0]; - result |= ((byte & 0x7f) as u64) << shift; - + if buffer.remaining() < 1 { + return Err(NotEnoughData( + "decoding LEB128, 0 bytes remaining".to_string(), + )); + } + let byte = buffer.get_u8(); + value |= (byte as u64 & 0x7f) << shift; if byte & 0x80 == 0 { break; } - shift += 7; - - if shift > 63 { - return Err(ColumnsParserError::InvalidIntegerEncoding(*pos)); + if shift > 57 { + return Err(NotEnoughData("decoding LEB128, invalid shift".to_string())); } } - - Ok(result) + Ok(value) } pub fn encode_leb128(value: u64) -> Vec { let mut result = Vec::new(); let mut val = value; - loop { let mut byte = (val & 0x7f) as u8; val >>= 7; - if val != 0 { - byte |= 0x80; // Set high bit to indicate more bytes follow + byte |= 0x80; } - result.push(byte); - if val == 0 { break; } } - result } mod tests { + use bytes::Bytes; + #[test] fn test_decode_leb128() { let test_cases = vec![ @@ -72,9 +58,8 @@ mod tests { ]; for (input, expected) in test_cases { - let mut cursor = std::io::Cursor::new(input.clone()); - let mut pos = 0; - let result = super::decode_leb128(&mut pos, &mut cursor).unwrap(); + let mut input_bytes = Bytes::from(input.clone()); + let result = super::decode_leb128(&mut input_bytes).unwrap(); assert_eq!(result, expected, "Failed decoding {:?}", input); } } @@ -97,9 +82,8 @@ mod tests { for value in test_values { let encoded = super::encode_leb128(value); - let mut cursor = std::io::Cursor::new(&encoded); - let mut pos = 0; - let decoded = super::decode_leb128(&mut pos, &mut cursor).unwrap(); + let mut bytes = Bytes::from(encoded.clone()); + let decoded = super::decode_leb128(&mut bytes).unwrap(); assert_eq!( decoded, value, diff --git a/rowbinary/src/lib.rs b/rowbinary/src/lib.rs index 1a6b89ff..b43660bf 100644 --- a/rowbinary/src/lib.rs +++ b/rowbinary/src/lib.rs @@ -1,3 +1,6 @@ -mod error; +pub mod error; +pub mod header; +pub mod types; + +mod decoders; mod leb128; -mod types; diff --git a/rowbinary/src/types.rs b/rowbinary/src/types.rs index 84b44600..8b42944a 100644 --- a/rowbinary/src/types.rs +++ b/rowbinary/src/types.rs @@ -1,11 +1,17 @@ -use crate::error::ColumnsParserError; +use crate::error::ParserError; use std::collections::HashMap; use std::fmt::Display; #[derive(Debug, Clone, PartialEq)] -pub struct ColumnSpec { - name: String, - data_type: DataType, +pub struct Column { + pub name: String, + pub data_type: DataType, +} + +impl Column { + pub fn new(name: String, data_type: DataType) -> Self { + Self { name, data_type } + } } #[derive(Debug, Clone, PartialEq)] @@ -52,7 +58,7 @@ pub enum DataType { } impl DataType { - pub fn new(name: &str) -> Result { + pub fn new(name: &str) -> Result { match name { "UInt8" => Ok(Self::UInt8), "UInt16" => Ok(Self::UInt16), @@ -94,7 +100,7 @@ impl DataType { str if str.starts_with("Variant") => parse_variant(str), // ... - str => Err(ColumnsParserError::TypeParsingError(format!( + str => Err(ParserError::TypeParsingError(format!( "Unknown data type: {}", str ))), @@ -211,7 +217,7 @@ pub enum DateTimePrecision { } impl DateTimePrecision { - pub(crate) fn new(char: char) -> Result { + pub(crate) fn new(char: char) -> Result { match char { '0' => Ok(DateTimePrecision::Precision0), '1' => Ok(DateTimePrecision::Precision1), @@ -223,7 +229,7 @@ impl DateTimePrecision { '7' => Ok(DateTimePrecision::Precision7), '8' => Ok(DateTimePrecision::Precision8), '9' => Ok(DateTimePrecision::Precision9), - _ => Err(ColumnsParserError::TypeParsingError(format!( + _ => Err(ParserError::TypeParsingError(format!( "Invalid DateTime64 precision, expected to be within [0, 9] interval, got {}", char ))), @@ -240,7 +246,7 @@ pub enum DecimalSize { } impl DecimalSize { - pub(crate) fn new(precision: u8) -> Result { + pub(crate) fn new(precision: u8) -> Result { if precision <= 9 { Ok(DecimalSize::Int32) } else if precision <= 18 { @@ -250,7 +256,7 @@ impl DecimalSize { } else if precision <= 76 { Ok(DecimalSize::Int256) } else { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Decimal precision: {}", precision ))); @@ -283,49 +289,49 @@ fn data_types_to_string(elements: &[DataType]) -> String { .join(", ") } -fn parse_fixed_string(input: &str) -> Result { +fn parse_fixed_string(input: &str) -> Result { if input.len() >= 14 { let size_str = &input[12..input.len() - 1]; let size = size_str.parse::().map_err(|err| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", err, input, size_str )) })?; if size == 0 { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid FixedString size, expected a positive number, got zero. Input: {}", input ))); } return Ok(DataType::FixedString(size)); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid FixedString format, expected FixedString(N), got {}", input ))) } -fn parse_array(input: &str) -> Result { +fn parse_array(input: &str) -> Result { if input.len() >= 8 { let inner_type_str = &input[6..input.len() - 1]; let inner_type = DataType::new(inner_type_str)?; return Ok(DataType::Array(Box::new(inner_type))); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Array format, expected Array(InnerType), got {}", input ))) } -fn parse_enum(input: &str) -> Result { +fn parse_enum(input: &str) -> Result { if input.len() >= 9 { let (enum_type, prefix_len) = if input.starts_with("Enum8") { (EnumType::Enum8, 6) } else if input.starts_with("Enum16") { (EnumType::Enum16, 7) } else { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Enum type, expected Enum8 or Enum16, got {}", input ))); @@ -334,13 +340,13 @@ fn parse_enum(input: &str) -> Result { let enum_values_map = parse_enum_values_map(enum_values_map_str)?; return Ok(DataType::Enum(enum_type, enum_values_map)); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Enum format, expected Enum8('name' = value), got {}", input ))) } -fn parse_datetime(input: &str) -> Result { +fn parse_datetime(input: &str) -> Result { if input == "DateTime" { return Ok(DataType::DateTime(None)); } @@ -348,17 +354,17 @@ fn parse_datetime(input: &str) -> Result { let timezone = (&input[10..input.len() - 2]).to_string(); return Ok(DataType::DateTime(Some(timezone))); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid DateTime format, expected DateTime('timezone'), got {}", input ))) } -fn parse_decimal(input: &str) -> Result { +fn parse_decimal(input: &str) -> Result { if input.len() >= 10 { let precision_and_scale_str = (&input[8..input.len() - 1]).split(", ").collect::>(); if precision_and_scale_str.len() != 2 { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S), got {}", input ))); @@ -368,7 +374,7 @@ fn parse_decimal(input: &str) -> Result { .map(|s| s.parse::()) .collect::, _>>() .map_err(|err| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S), got {}. Underlying error: {}", input, err )) @@ -376,13 +382,13 @@ fn parse_decimal(input: &str) -> Result { let precision = parsed[0]; let scale = parsed[1]; if scale < 1 || precision < 1 { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S) with P > 0 and S > 0, got {}", input ))); } if precision < scale { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S) with P >= S, got {}", input ))); @@ -390,21 +396,19 @@ fn parse_decimal(input: &str) -> Result { let size = DecimalSize::new(parsed[0])?; return Ok(DataType::Decimal(precision, scale, size)); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P), got {}", input ))) } -fn parse_datetime64(input: &str) -> Result { +fn parse_datetime64(input: &str) -> Result { if input.len() >= 13 { let mut chars = (&input[11..input.len() - 1]).chars(); - let precision_char = chars - .next() - .ok_or(ColumnsParserError::TypeParsingError(format!( - "Invalid DateTime64 precision, expected a positive number. Input: {}", - input - )))?; + let precision_char = chars.next().ok_or(ParserError::TypeParsingError(format!( + "Invalid DateTime64 precision, expected a positive number. Input: {}", + input + )))?; let precision = DateTimePrecision::new(precision_char)?; let maybe_tz = match chars.as_str() { str if str.len() > 2 => Some((&str[3..str.len() - 1]).to_string()), @@ -412,42 +416,42 @@ fn parse_datetime64(input: &str) -> Result { }; return Ok(DataType::DateTime64(precision, maybe_tz)); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid DateTime format, expected DateTime('timezone'), got {}", input ))) } -fn parse_low_cardinality(input: &str) -> Result { +fn parse_low_cardinality(input: &str) -> Result { if input.len() >= 16 { let inner_type_str = &input[15..input.len() - 1]; let inner_type = DataType::new(inner_type_str)?; return Ok(DataType::LowCardinality(Box::new(inner_type))); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid LowCardinality format, expected LowCardinality(InnerType), got {}", input ))) } -fn parse_nullable(input: &str) -> Result { +fn parse_nullable(input: &str) -> Result { if input.len() >= 10 { let inner_type_str = &input[9..input.len() - 1]; let inner_type = DataType::new(inner_type_str)?; return Ok(DataType::Nullable(Box::new(inner_type))); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Nullable format, expected Nullable(InnerType), got {}", input ))) } -fn parse_map(input: &str) -> Result { +fn parse_map(input: &str) -> Result { if input.len() >= 5 { let inner_types_str = &input[4..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; if inner_types.len() != 2 { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Expected two inner elements in a Map from input {}", input ))); @@ -457,37 +461,37 @@ fn parse_map(input: &str) -> Result { Box::new(inner_types[1].clone()), )); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Map format, expected Map(KeyType, ValueType), got {}", input ))) } -fn parse_tuple(input: &str) -> Result { +fn parse_tuple(input: &str) -> Result { if input.len() > 7 { let inner_types_str = &input[6..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; if inner_types.is_empty() { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Expected at least one inner element in a Tuple from input {}", input ))); } return Ok(DataType::Tuple(inner_types)); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Tuple format, expected Tuple(Type1, Type2, ...), got {}", input ))) } -fn parse_variant(input: &str) -> Result { +fn parse_variant(input: &str) -> Result { if input.len() >= 9 { let inner_types_str = &input[8..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; return Ok(DataType::Variant(inner_types)); } - Err(ColumnsParserError::TypeParsingError(format!( + Err(ParserError::TypeParsingError(format!( "Invalid Variant format, expected Variant(Type1, Type2, ...), got {}", input ))) @@ -499,7 +503,7 @@ fn parse_variant(input: &str) -> Result { /// let input1 = "Tuple(Enum8('f\'()' = 1))`"; // the result is `f\'()` /// let input2 = "Tuple(Enum8('(' = 1))"; // the result is `(` /// ``` -fn parse_inner_types(input: &str) -> Result, ColumnsParserError> { +fn parse_inner_types(input: &str) -> Result, ParserError> { let mut inner_types: Vec = Vec::new(); let input_bytes = input.as_bytes(); @@ -528,7 +532,7 @@ fn parse_inner_types(input: &str) -> Result, ColumnsParserError> { let data_type_str = String::from_utf8(input_bytes[last_element_index..i].to_vec()) .map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the inner data type: {}", &input[last_element_index..] )) @@ -554,7 +558,7 @@ fn parse_inner_types(input: &str) -> Result, ColumnsParserError> { if open_parens == 0 && last_element_index < input_bytes.len() { let data_type_str = String::from_utf8(input_bytes[last_element_index..].to_vec()).map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the inner data type: {}", &input[last_element_index..] )) @@ -566,7 +570,7 @@ fn parse_inner_types(input: &str) -> Result, ColumnsParserError> { Ok(inner_types) } -fn parse_enum_values_map(input: &str) -> Result, ColumnsParserError> { +fn parse_enum_values_map(input: &str) -> Result, ParserError> { let mut names: Vec = Vec::new(); let mut indices: Vec = Vec::new(); let mut parsing_name = true; // false when parsing the index @@ -586,7 +590,7 @@ fn parse_enum_values_map(input: &str) -> Result, ColumnsPar // non-escaped closing tick - push the name let name_bytes = &input_bytes[start_index..i]; let name = String::from_utf8(name_bytes.to_vec()).map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the enum name: {}", &input[start_index..i] )) @@ -595,7 +599,7 @@ fn parse_enum_values_map(input: &str) -> Result, ColumnsPar // Skip ` = ` and the first digit, as it will always have at least one if i + 4 >= input_bytes.len() { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Enum format - expected ` = ` after name, input: {}", input, ))); @@ -610,14 +614,14 @@ fn parse_enum_values_map(input: &str) -> Result, ColumnsPar else if input_bytes[i] < b'0' || input_bytes[i] > b'9' { let index = String::from_utf8(input_bytes[start_index..i].to_vec()) .map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the enum index: {}", &input[start_index..i] )) })? .parse::() .map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid Enum index, expected a valid number. Input: {}", input )) @@ -640,14 +644,14 @@ fn parse_enum_values_map(input: &str) -> Result, ColumnsPar let index = String::from_utf8(input_bytes[start_index..i].to_vec()) .map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the enum index: {}", &input[start_index..i] )) })? .parse::() .map_err(|_| { - ColumnsParserError::TypeParsingError(format!( + ParserError::TypeParsingError(format!( "Invalid Enum index, expected a valid number. Input: {}", input )) @@ -655,7 +659,7 @@ fn parse_enum_values_map(input: &str) -> Result, ColumnsPar indices.push(index); if names.len() != indices.len() { - return Err(ColumnsParserError::TypeParsingError(format!( + return Err(ParserError::TypeParsingError(format!( "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", names.join(", "), indices.iter().map(|index| index.to_string()).collect::>().join(", "), diff --git a/tests/it/main.rs b/tests/it/main.rs index 5e0385db..b868e988 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -65,6 +65,7 @@ mod ip; mod mock; mod nested; mod query; +mod rbwnat; mod time; mod user_agent; mod uuid; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs new file mode 100644 index 00000000..7c132abf --- /dev/null +++ b/tests/it/rbwnat.rs @@ -0,0 +1,94 @@ +use clickhouse_rowbinary::header::parse_names_and_types_header; +use clickhouse_rowbinary::types::{Column, DataType}; + +#[tokio::test] +async fn test_header_parsing() { + let client = prepare_database!(); + client + .query( + " + CREATE OR REPLACE TABLE visits + ( + CounterID UInt32, + StartDate Date, + Sign Int8, + IsNew UInt8, + VisitID UInt64, + UserID UInt64, + Goals Nested + ( + ID UInt32, + Serial UInt32, + EventTime DateTime, + Price Int64, + OrderID String, + CurrencyID UInt32 + ) + ) ENGINE = MergeTree ORDER BY () + ", + ) + .execute() + .await + .unwrap(); + + let mut cursor = client + .query("SELECT * FROM visits LIMIT 0") + .fetch_bytes("RowBinaryWithNamesAndTypes") + .unwrap(); + + let mut data = cursor.collect().await.unwrap(); + let result = parse_names_and_types_header(&mut data).unwrap(); + assert_eq!( + result, + vec![ + Column { + name: "CounterID".to_string(), + data_type: DataType::UInt32 + }, + Column { + name: "StartDate".to_string(), + data_type: DataType::Date + }, + Column { + name: "Sign".to_string(), + data_type: DataType::Int8 + }, + Column { + name: "IsNew".to_string(), + data_type: DataType::UInt8 + }, + Column { + name: "VisitID".to_string(), + data_type: DataType::UInt64 + }, + Column { + name: "UserID".to_string(), + data_type: DataType::UInt64 + }, + Column { + name: "Goals.ID".to_string(), + data_type: DataType::Array(Box::new(DataType::UInt32)) + }, + Column { + name: "Goals.Serial".to_string(), + data_type: DataType::Array(Box::new(DataType::UInt32)) + }, + Column { + name: "Goals.EventTime".to_string(), + data_type: DataType::Array(Box::new(DataType::DateTime(None))) + }, + Column { + name: "Goals.Price".to_string(), + data_type: DataType::Array(Box::new(DataType::Int64)) + }, + Column { + name: "Goals.OrderID".to_string(), + data_type: DataType::Array(Box::new(DataType::String)) + }, + Column { + name: "Goals.CurrencyID".to_string(), + data_type: DataType::Array(Box::new(DataType::UInt32)) + } + ] + ); +} From cf72759eff6c53a6177a7172734d2d35deaf465c Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 14 May 2025 01:01:17 +0200 Subject: [PATCH 03/54] RBWNAT deserializer WIP --- Cargo.toml | 2 +- rowbinary/src/decoders.rs | 7 +- rowbinary/src/header.rs | 31 -- rowbinary/src/leb128.rs | 12 +- rowbinary/src/lib.rs | 35 +- rowbinary/src/types.rs | 610 ++++++++++++++++++------------- src/cursors/row.rs | 48 ++- src/error.rs | 27 +- src/lib.rs | 13 + src/output_format.rs | 12 + src/query.rs | 12 +- src/rowbinary/de.rs | 61 +--- src/rowbinary/de_rbwnat.rs | 712 +++++++++++++++++++++++++++++++++++++ src/rowbinary/mod.rs | 3 + src/rowbinary/utils.rs | 41 +++ tests/it/main.rs | 2 +- tests/it/rbwnat.rs | 94 ----- tests/it/rbwnat_smoke.rs | 334 +++++++++++++++++ 18 files changed, 1588 insertions(+), 468 deletions(-) delete mode 100644 rowbinary/src/header.rs create mode 100644 src/output_format.rs create mode 100644 src/rowbinary/de_rbwnat.rs create mode 100644 src/rowbinary/utils.rs delete mode 100644 tests/it/rbwnat.rs create mode 100644 tests/it/rbwnat_smoke.rs diff --git a/Cargo.toml b/Cargo.toml index ce1c84c1..a29171a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,7 +98,7 @@ rustls-tls-native-roots = [ [dependencies] clickhouse-derive = { version = "0.2.0", path = "derive" } - +clickhouse-rowbinary = { version = "*", path = "rowbinary" } thiserror = "1.0.16" serde = "1.0.106" bytes = "1.5.0" diff --git a/rowbinary/src/decoders.rs b/rowbinary/src/decoders.rs index 745c2263..dcd2f9ea 100644 --- a/rowbinary/src/decoders.rs +++ b/rowbinary/src/decoders.rs @@ -1,18 +1,17 @@ use crate::error::ParserError; use crate::leb128::decode_leb128; -use bytes::{Buf, Bytes}; +use bytes::Buf; -pub(crate) fn decode_string(buffer: &mut Bytes) -> Result { +pub(crate) fn decode_string(buffer: &mut &[u8]) -> Result { let length = decode_leb128(buffer)? as usize; if length == 0 { return Ok("".to_string()); } if buffer.remaining() < length { return Err(ParserError::NotEnoughData(format!( - "decoding string, {} bytes remaining, {} bytes required, pos {}", + "decoding string, {} bytes remaining, {} bytes required", buffer.remaining(), length, - buffer.remaining(), ))); } let result = String::from_utf8_lossy(&buffer.copy_to_bytes(length)).to_string(); diff --git a/rowbinary/src/header.rs b/rowbinary/src/header.rs deleted file mode 100644 index d0a9744e..00000000 --- a/rowbinary/src/header.rs +++ /dev/null @@ -1,31 +0,0 @@ -use crate::decoders::decode_string; -use crate::error::ParserError; -use crate::leb128::decode_leb128; -use crate::types::{Column, DataType}; -use bytes::Bytes; - -pub fn parse_names_and_types_header(bytes: &mut Bytes) -> Result, ParserError> { - let num_columns = decode_leb128(bytes)?; - if num_columns == 0 { - return Err(ParserError::HeaderParsingError( - "Expected at least one column in the header".to_string(), - )); - } - let mut columns_names: Vec = Vec::with_capacity(num_columns as usize); - for _ in 0..num_columns { - let column_name = decode_string(bytes)?; - columns_names.push(column_name); - } - let mut column_data_types: Vec = Vec::with_capacity(num_columns as usize); - for _ in 0..num_columns { - let column_type = decode_string(bytes)?; - let data_type = DataType::new(&column_type)?; - column_data_types.push(data_type); - } - let columns = columns_names - .into_iter() - .zip(column_data_types) - .map(|(name, data_type)| Column { name, data_type }) - .collect(); - Ok(columns) -} diff --git a/rowbinary/src/leb128.rs b/rowbinary/src/leb128.rs index cdc33e7d..dd03148f 100644 --- a/rowbinary/src/leb128.rs +++ b/rowbinary/src/leb128.rs @@ -1,8 +1,8 @@ use crate::error::ParserError; use crate::error::ParserError::NotEnoughData; -use bytes::{Buf, Bytes}; +use bytes::Buf; -pub fn decode_leb128(buffer: &mut Bytes) -> Result { +pub fn decode_leb128(buffer: &mut &[u8]) -> Result { let mut value = 0u64; let mut shift = 0; loop { @@ -42,8 +42,6 @@ pub fn encode_leb128(value: u64) -> Vec { } mod tests { - use bytes::Bytes; - #[test] fn test_decode_leb128() { let test_cases = vec![ @@ -58,8 +56,7 @@ mod tests { ]; for (input, expected) in test_cases { - let mut input_bytes = Bytes::from(input.clone()); - let result = super::decode_leb128(&mut input_bytes).unwrap(); + let result = super::decode_leb128(&mut input.as_slice()).unwrap(); assert_eq!(result, expected, "Failed decoding {:?}", input); } } @@ -82,8 +79,7 @@ mod tests { for value in test_values { let encoded = super::encode_leb128(value); - let mut bytes = Bytes::from(encoded.clone()); - let decoded = super::decode_leb128(&mut bytes).unwrap(); + let decoded = super::decode_leb128(&mut encoded.as_slice()).unwrap(); assert_eq!( decoded, value, diff --git a/rowbinary/src/lib.rs b/rowbinary/src/lib.rs index b43660bf..d0b79740 100644 --- a/rowbinary/src/lib.rs +++ b/rowbinary/src/lib.rs @@ -1,6 +1,35 @@ +use crate::decoders::decode_string; +use crate::error::ParserError; +use crate::leb128::decode_leb128; +use crate::types::{Column, DataTypeNode}; + +pub mod decoders; pub mod error; -pub mod header; +pub mod leb128; pub mod types; -mod decoders; -mod leb128; +pub fn parse_columns_header(bytes: &mut &[u8]) -> Result, ParserError> { + let num_columns = decode_leb128(bytes)?; + if num_columns == 0 { + return Err(ParserError::HeaderParsingError( + "Expected at least one column in the header".to_string(), + )); + } + let mut columns_names: Vec = Vec::with_capacity(num_columns as usize); + for _ in 0..num_columns { + let column_name = decode_string(bytes)?; + columns_names.push(column_name); + } + let mut column_data_types: Vec = Vec::with_capacity(num_columns as usize); + for _ in 0..num_columns { + let column_type = decode_string(bytes)?; + let data_type = DataTypeNode::new(&column_type)?; + column_data_types.push(data_type); + } + let columns = columns_names + .into_iter() + .zip(column_data_types) + .map(|(name, data_type)| Column { name, data_type }) + .collect(); + Ok(columns) +} diff --git a/rowbinary/src/types.rs b/rowbinary/src/types.rs index 8b42944a..0a477905 100644 --- a/rowbinary/src/types.rs +++ b/rowbinary/src/types.rs @@ -1,21 +1,28 @@ use crate::error::ParserError; use std::collections::HashMap; -use std::fmt::Display; +use std::fmt::{Display, Formatter}; #[derive(Debug, Clone, PartialEq)] pub struct Column { pub name: String, - pub data_type: DataType, + pub data_type: DataTypeNode, } impl Column { - pub fn new(name: String, data_type: DataType) -> Self { + pub fn new(name: String, data_type: DataTypeNode) -> Self { Self { name, data_type } } } +impl Display for Column { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}: {}", self.name, self.data_type) + } +} + #[derive(Debug, Clone, PartialEq)] -pub enum DataType { +#[non_exhaustive] +pub enum DataTypeNode { Bool, UInt8, UInt16, @@ -41,23 +48,39 @@ pub enum DataType { IPv4, IPv6, - Nullable(Box), - Array(Box), - Tuple(Vec), - Map(Box, Box), - LowCardinality(Box), + Nullable(Box), + Array(Box), + Tuple(Vec), + Map(Box, Box), + LowCardinality(Box), Decimal(u8, u8, DecimalSize), Enum(EnumType, HashMap), - AggregateFunction(String, Vec), + AggregateFunction(String, Vec), FixedString(usize), - Variant(Vec), + Variant(Vec), Dynamic, JSON, - // TODO: Nested, Geo + // TODO: Geo +} + +macro_rules! data_type_is { + ($method:ident, $pattern:pat) => { + #[inline] + pub fn $method(&self) -> Result<(), ParserError> { + match self { + $pattern => Ok(()), + _ => Err(ParserError::TypeParsingError(format!( + "Expected {}, got {}", + stringify!($pattern), + self + ))), + } + } + }; } -impl DataType { +impl DataTypeNode { pub fn new(name: &str) -> Result { match name { "UInt8" => Ok(Self::UInt8), @@ -106,17 +129,54 @@ impl DataType { ))), } } + + data_type_is!(is_bool, DataTypeNode::Bool); + data_type_is!(is_uint8, DataTypeNode::UInt8); + data_type_is!(is_uint16, DataTypeNode::UInt16); + data_type_is!(is_uint32, DataTypeNode::UInt32); + data_type_is!(is_uint64, DataTypeNode::UInt64); + data_type_is!(is_uint128, DataTypeNode::UInt128); + data_type_is!(is_uint256, DataTypeNode::UInt256); + data_type_is!(is_int8, DataTypeNode::Int8); + data_type_is!(is_int16, DataTypeNode::Int16); + data_type_is!(is_int32, DataTypeNode::Int32); + data_type_is!(is_int64, DataTypeNode::Int64); + data_type_is!(is_int128, DataTypeNode::Int128); + data_type_is!(is_int256, DataTypeNode::Int256); + data_type_is!(is_float32, DataTypeNode::Float32); + data_type_is!(is_float64, DataTypeNode::Float64); + data_type_is!(is_bfloat16, DataTypeNode::BFloat16); + data_type_is!(is_string, DataTypeNode::String); + data_type_is!(is_uuid, DataTypeNode::UUID); + data_type_is!(is_date, DataTypeNode::Date); + data_type_is!(is_date32, DataTypeNode::Date32); + data_type_is!(is_datetime, DataTypeNode::DateTime(_)); + data_type_is!(is_datetime64, DataTypeNode::DateTime64(_, _)); + data_type_is!(is_ipv4, DataTypeNode::IPv4); + data_type_is!(is_ipv6, DataTypeNode::IPv6); + data_type_is!(is_nullable, DataTypeNode::Nullable(_)); + data_type_is!(is_array, DataTypeNode::Array(_)); + data_type_is!(is_tuple, DataTypeNode::Tuple(_)); + data_type_is!(is_map, DataTypeNode::Map(_, _)); + data_type_is!(is_low_cardinality, DataTypeNode::LowCardinality(_)); + data_type_is!(is_decimal, DataTypeNode::Decimal(_, _, _)); + data_type_is!(is_enum, DataTypeNode::Enum(_, _)); + data_type_is!(is_aggregate_function, DataTypeNode::AggregateFunction(_, _)); + data_type_is!(is_fixed_string, DataTypeNode::FixedString(_)); + data_type_is!(is_variant, DataTypeNode::Variant(_)); + data_type_is!(is_dynamic, DataTypeNode::Dynamic); + data_type_is!(is_json, DataTypeNode::JSON); } -impl Into for DataType { +impl Into for DataTypeNode { fn into(self) -> String { self.to_string() } } -impl Display for DataType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - use DataType::*; +impl Display for DataTypeNode { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + use DataTypeNode::*; let str = match self { UInt8 => "UInt8".to_string(), UInt16 => "UInt16".to_string(), @@ -281,7 +341,7 @@ impl Display for DateTimePrecision { } } -fn data_types_to_string(elements: &[DataType]) -> String { +fn data_types_to_string(elements: &[DataTypeNode]) -> String { elements .iter() .map(|a| a.to_string()) @@ -289,7 +349,7 @@ fn data_types_to_string(elements: &[DataType]) -> String { .join(", ") } -fn parse_fixed_string(input: &str) -> Result { +fn parse_fixed_string(input: &str) -> Result { if input.len() >= 14 { let size_str = &input[12..input.len() - 1]; let size = size_str.parse::().map_err(|err| { @@ -304,7 +364,7 @@ fn parse_fixed_string(input: &str) -> Result { input ))); } - return Ok(DataType::FixedString(size)); + return Ok(DataTypeNode::FixedString(size)); } Err(ParserError::TypeParsingError(format!( "Invalid FixedString format, expected FixedString(N), got {}", @@ -312,11 +372,11 @@ fn parse_fixed_string(input: &str) -> Result { ))) } -fn parse_array(input: &str) -> Result { +fn parse_array(input: &str) -> Result { if input.len() >= 8 { let inner_type_str = &input[6..input.len() - 1]; - let inner_type = DataType::new(inner_type_str)?; - return Ok(DataType::Array(Box::new(inner_type))); + let inner_type = DataTypeNode::new(inner_type_str)?; + return Ok(DataTypeNode::Array(Box::new(inner_type))); } Err(ParserError::TypeParsingError(format!( "Invalid Array format, expected Array(InnerType), got {}", @@ -324,7 +384,7 @@ fn parse_array(input: &str) -> Result { ))) } -fn parse_enum(input: &str) -> Result { +fn parse_enum(input: &str) -> Result { if input.len() >= 9 { let (enum_type, prefix_len) = if input.starts_with("Enum8") { (EnumType::Enum8, 6) @@ -338,7 +398,7 @@ fn parse_enum(input: &str) -> Result { }; let enum_values_map_str = &input[prefix_len..input.len() - 1]; let enum_values_map = parse_enum_values_map(enum_values_map_str)?; - return Ok(DataType::Enum(enum_type, enum_values_map)); + return Ok(DataTypeNode::Enum(enum_type, enum_values_map)); } Err(ParserError::TypeParsingError(format!( "Invalid Enum format, expected Enum8('name' = value), got {}", @@ -346,13 +406,13 @@ fn parse_enum(input: &str) -> Result { ))) } -fn parse_datetime(input: &str) -> Result { +fn parse_datetime(input: &str) -> Result { if input == "DateTime" { - return Ok(DataType::DateTime(None)); + return Ok(DataTypeNode::DateTime(None)); } if input.len() >= 12 { let timezone = (&input[10..input.len() - 2]).to_string(); - return Ok(DataType::DateTime(Some(timezone))); + return Ok(DataTypeNode::DateTime(Some(timezone))); } Err(ParserError::TypeParsingError(format!( "Invalid DateTime format, expected DateTime('timezone'), got {}", @@ -360,7 +420,7 @@ fn parse_datetime(input: &str) -> Result { ))) } -fn parse_decimal(input: &str) -> Result { +fn parse_decimal(input: &str) -> Result { if input.len() >= 10 { let precision_and_scale_str = (&input[8..input.len() - 1]).split(", ").collect::>(); if precision_and_scale_str.len() != 2 { @@ -394,7 +454,7 @@ fn parse_decimal(input: &str) -> Result { ))); } let size = DecimalSize::new(parsed[0])?; - return Ok(DataType::Decimal(precision, scale, size)); + return Ok(DataTypeNode::Decimal(precision, scale, size)); } Err(ParserError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P), got {}", @@ -402,7 +462,7 @@ fn parse_decimal(input: &str) -> Result { ))) } -fn parse_datetime64(input: &str) -> Result { +fn parse_datetime64(input: &str) -> Result { if input.len() >= 13 { let mut chars = (&input[11..input.len() - 1]).chars(); let precision_char = chars.next().ok_or(ParserError::TypeParsingError(format!( @@ -414,7 +474,7 @@ fn parse_datetime64(input: &str) -> Result { str if str.len() > 2 => Some((&str[3..str.len() - 1]).to_string()), _ => None, }; - return Ok(DataType::DateTime64(precision, maybe_tz)); + return Ok(DataTypeNode::DateTime64(precision, maybe_tz)); } Err(ParserError::TypeParsingError(format!( "Invalid DateTime format, expected DateTime('timezone'), got {}", @@ -422,11 +482,11 @@ fn parse_datetime64(input: &str) -> Result { ))) } -fn parse_low_cardinality(input: &str) -> Result { +fn parse_low_cardinality(input: &str) -> Result { if input.len() >= 16 { let inner_type_str = &input[15..input.len() - 1]; - let inner_type = DataType::new(inner_type_str)?; - return Ok(DataType::LowCardinality(Box::new(inner_type))); + let inner_type = DataTypeNode::new(inner_type_str)?; + return Ok(DataTypeNode::LowCardinality(Box::new(inner_type))); } Err(ParserError::TypeParsingError(format!( "Invalid LowCardinality format, expected LowCardinality(InnerType), got {}", @@ -434,11 +494,11 @@ fn parse_low_cardinality(input: &str) -> Result { ))) } -fn parse_nullable(input: &str) -> Result { +fn parse_nullable(input: &str) -> Result { if input.len() >= 10 { let inner_type_str = &input[9..input.len() - 1]; - let inner_type = DataType::new(inner_type_str)?; - return Ok(DataType::Nullable(Box::new(inner_type))); + let inner_type = DataTypeNode::new(inner_type_str)?; + return Ok(DataTypeNode::Nullable(Box::new(inner_type))); } Err(ParserError::TypeParsingError(format!( "Invalid Nullable format, expected Nullable(InnerType), got {}", @@ -446,7 +506,7 @@ fn parse_nullable(input: &str) -> Result { ))) } -fn parse_map(input: &str) -> Result { +fn parse_map(input: &str) -> Result { if input.len() >= 5 { let inner_types_str = &input[4..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; @@ -456,7 +516,7 @@ fn parse_map(input: &str) -> Result { input ))); } - return Ok(DataType::Map( + return Ok(DataTypeNode::Map( Box::new(inner_types[0].clone()), Box::new(inner_types[1].clone()), )); @@ -467,7 +527,7 @@ fn parse_map(input: &str) -> Result { ))) } -fn parse_tuple(input: &str) -> Result { +fn parse_tuple(input: &str) -> Result { if input.len() > 7 { let inner_types_str = &input[6..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; @@ -477,7 +537,7 @@ fn parse_tuple(input: &str) -> Result { input ))); } - return Ok(DataType::Tuple(inner_types)); + return Ok(DataTypeNode::Tuple(inner_types)); } Err(ParserError::TypeParsingError(format!( "Invalid Tuple format, expected Tuple(Type1, Type2, ...), got {}", @@ -485,11 +545,11 @@ fn parse_tuple(input: &str) -> Result { ))) } -fn parse_variant(input: &str) -> Result { +fn parse_variant(input: &str) -> Result { if input.len() >= 9 { let inner_types_str = &input[8..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; - return Ok(DataType::Variant(inner_types)); + return Ok(DataTypeNode::Variant(inner_types)); } Err(ParserError::TypeParsingError(format!( "Invalid Variant format, expected Variant(Type1, Type2, ...), got {}", @@ -503,8 +563,8 @@ fn parse_variant(input: &str) -> Result { /// let input1 = "Tuple(Enum8('f\'()' = 1))`"; // the result is `f\'()` /// let input2 = "Tuple(Enum8('(' = 1))"; // the result is `(` /// ``` -fn parse_inner_types(input: &str) -> Result, ParserError> { - let mut inner_types: Vec = Vec::new(); +fn parse_inner_types(input: &str) -> Result, ParserError> { + let mut inner_types: Vec = Vec::new(); let input_bytes = input.as_bytes(); @@ -537,7 +597,7 @@ fn parse_inner_types(input: &str) -> Result, ParserError> { &input[last_element_index..] )) })?; - let data_type = DataType::new(&data_type_str)?; + let data_type = DataTypeNode::new(&data_type_str)?; inner_types.push(data_type); // Skip ', ' (comma and space) if i + 2 <= input_bytes.len() && input_bytes[i + 1] == b' ' { @@ -563,7 +623,7 @@ fn parse_inner_types(input: &str) -> Result, ParserError> { &input[last_element_index..] )) })?; - let data_type = DataType::new(&data_type_str)?; + let data_type = DataTypeNode::new(&data_type_str)?; inner_types.push(data_type); } @@ -678,333 +738,359 @@ mod tests { #[test] fn test_data_type_new_simple() { - assert_eq!(DataType::new("UInt8").unwrap(), DataType::UInt8); - assert_eq!(DataType::new("UInt16").unwrap(), DataType::UInt16); - assert_eq!(DataType::new("UInt32").unwrap(), DataType::UInt32); - assert_eq!(DataType::new("UInt64").unwrap(), DataType::UInt64); - assert_eq!(DataType::new("UInt128").unwrap(), DataType::UInt128); - assert_eq!(DataType::new("UInt256").unwrap(), DataType::UInt256); - assert_eq!(DataType::new("Int8").unwrap(), DataType::Int8); - assert_eq!(DataType::new("Int16").unwrap(), DataType::Int16); - assert_eq!(DataType::new("Int32").unwrap(), DataType::Int32); - assert_eq!(DataType::new("Int64").unwrap(), DataType::Int64); - assert_eq!(DataType::new("Int128").unwrap(), DataType::Int128); - assert_eq!(DataType::new("Int256").unwrap(), DataType::Int256); - assert_eq!(DataType::new("Float32").unwrap(), DataType::Float32); - assert_eq!(DataType::new("Float64").unwrap(), DataType::Float64); - assert_eq!(DataType::new("BFloat16").unwrap(), DataType::BFloat16); - assert_eq!(DataType::new("String").unwrap(), DataType::String); - assert_eq!(DataType::new("UUID").unwrap(), DataType::UUID); - assert_eq!(DataType::new("Date").unwrap(), DataType::Date); - assert_eq!(DataType::new("Date32").unwrap(), DataType::Date32); - assert_eq!(DataType::new("IPv4").unwrap(), DataType::IPv4); - assert_eq!(DataType::new("IPv6").unwrap(), DataType::IPv6); - assert_eq!(DataType::new("Bool").unwrap(), DataType::Bool); - assert_eq!(DataType::new("Dynamic").unwrap(), DataType::Dynamic); - assert_eq!(DataType::new("JSON").unwrap(), DataType::JSON); + assert_eq!(DataTypeNode::new("UInt8").unwrap(), DataTypeNode::UInt8); + assert_eq!(DataTypeNode::new("UInt16").unwrap(), DataTypeNode::UInt16); + assert_eq!(DataTypeNode::new("UInt32").unwrap(), DataTypeNode::UInt32); + assert_eq!(DataTypeNode::new("UInt64").unwrap(), DataTypeNode::UInt64); + assert_eq!(DataTypeNode::new("UInt128").unwrap(), DataTypeNode::UInt128); + assert_eq!(DataTypeNode::new("UInt256").unwrap(), DataTypeNode::UInt256); + assert_eq!(DataTypeNode::new("Int8").unwrap(), DataTypeNode::Int8); + assert_eq!(DataTypeNode::new("Int16").unwrap(), DataTypeNode::Int16); + assert_eq!(DataTypeNode::new("Int32").unwrap(), DataTypeNode::Int32); + assert_eq!(DataTypeNode::new("Int64").unwrap(), DataTypeNode::Int64); + assert_eq!(DataTypeNode::new("Int128").unwrap(), DataTypeNode::Int128); + assert_eq!(DataTypeNode::new("Int256").unwrap(), DataTypeNode::Int256); + assert_eq!(DataTypeNode::new("Float32").unwrap(), DataTypeNode::Float32); + assert_eq!(DataTypeNode::new("Float64").unwrap(), DataTypeNode::Float64); + assert_eq!( + DataTypeNode::new("BFloat16").unwrap(), + DataTypeNode::BFloat16 + ); + assert_eq!(DataTypeNode::new("String").unwrap(), DataTypeNode::String); + assert_eq!(DataTypeNode::new("UUID").unwrap(), DataTypeNode::UUID); + assert_eq!(DataTypeNode::new("Date").unwrap(), DataTypeNode::Date); + assert_eq!(DataTypeNode::new("Date32").unwrap(), DataTypeNode::Date32); + assert_eq!(DataTypeNode::new("IPv4").unwrap(), DataTypeNode::IPv4); + assert_eq!(DataTypeNode::new("IPv6").unwrap(), DataTypeNode::IPv6); + assert_eq!(DataTypeNode::new("Bool").unwrap(), DataTypeNode::Bool); + assert_eq!(DataTypeNode::new("Dynamic").unwrap(), DataTypeNode::Dynamic); + assert_eq!(DataTypeNode::new("JSON").unwrap(), DataTypeNode::JSON); assert!(DataType::new("SomeUnknownType").is_err(),); } #[test] fn test_data_type_new_fixed_string() { assert_eq!( - DataType::new("FixedString(1)").unwrap(), - DataType::FixedString(1) + DataTypeNode::new("FixedString(1)").unwrap(), + DataTypeNode::FixedString(1) ); assert_eq!( - DataType::new("FixedString(16)").unwrap(), - DataType::FixedString(16) + DataTypeNode::new("FixedString(16)").unwrap(), + DataTypeNode::FixedString(16) ); assert_eq!( - DataType::new("FixedString(255)").unwrap(), - DataType::FixedString(255) + DataTypeNode::new("FixedString(255)").unwrap(), + DataTypeNode::FixedString(255) ); assert_eq!( - DataType::new("FixedString(65535)").unwrap(), - DataType::FixedString(65_535) + DataTypeNode::new("FixedString(65535)").unwrap(), + DataTypeNode::FixedString(65_535) ); - assert!(DataType::new("FixedString()").is_err()); - assert!(DataType::new("FixedString(0)").is_err()); - assert!(DataType::new("FixedString(-1)").is_err()); - assert!(DataType::new("FixedString(abc)").is_err()); + assert!(DataTypeNode::new("FixedString()").is_err()); + assert!(DataTypeNode::new("FixedString(0)").is_err()); + assert!(DataTypeNode::new("FixedString(-1)").is_err()); + assert!(DataTypeNode::new("FixedString(abc)").is_err()); } #[test] fn test_data_type_new_array() { assert_eq!( - DataType::new("Array(UInt8)").unwrap(), - DataType::Array(Box::new(DataType::UInt8)) + DataTypeNode::new("Array(UInt8)").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::UInt8)) ); assert_eq!( - DataType::new("Array(String)").unwrap(), - DataType::Array(Box::new(DataType::String)) + DataTypeNode::new("Array(String)").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::String)) ); assert_eq!( - DataType::new("Array(FixedString(16))").unwrap(), - DataType::Array(Box::new(DataType::FixedString(16))) + DataTypeNode::new("Array(FixedString(16))").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::FixedString(16))) ); assert_eq!( - DataType::new("Array(Nullable(Int32))").unwrap(), - DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::Int32)))) + DataTypeNode::new("Array(Nullable(Int32))").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::Int32 + )))) ); - assert!(DataType::new("Array()").is_err()); - assert!(DataType::new("Array(abc)").is_err()); + assert!(DataTypeNode::new("Array()").is_err()); + assert!(DataTypeNode::new("Array(abc)").is_err()); } #[test] fn test_data_type_new_decimal() { assert_eq!( - DataType::new("Decimal(7, 2)").unwrap(), - DataType::Decimal(7, 2, DecimalSize::Int32) + DataTypeNode::new("Decimal(7, 2)").unwrap(), + DataTypeNode::Decimal(7, 2, DecimalSize::Int32) ); assert_eq!( - DataType::new("Decimal(12, 4)").unwrap(), - DataType::Decimal(12, 4, DecimalSize::Int64) + DataTypeNode::new("Decimal(12, 4)").unwrap(), + DataTypeNode::Decimal(12, 4, DecimalSize::Int64) ); assert_eq!( - DataType::new("Decimal(27, 6)").unwrap(), - DataType::Decimal(27, 6, DecimalSize::Int128) + DataTypeNode::new("Decimal(27, 6)").unwrap(), + DataTypeNode::Decimal(27, 6, DecimalSize::Int128) ); assert_eq!( - DataType::new("Decimal(42, 8)").unwrap(), - DataType::Decimal(42, 8, DecimalSize::Int256) + DataTypeNode::new("Decimal(42, 8)").unwrap(), + DataTypeNode::Decimal(42, 8, DecimalSize::Int256) ); - assert!(DataType::new("Decimal").is_err()); - assert!(DataType::new("Decimal(").is_err()); - assert!(DataType::new("Decimal()").is_err()); - assert!(DataType::new("Decimal(1)").is_err()); - assert!(DataType::new("Decimal(1,)").is_err()); - assert!(DataType::new("Decimal(1, )").is_err()); - assert!(DataType::new("Decimal(0, 0)").is_err()); // Precision must be > 0 - assert!(DataType::new("Decimal(x, 0)").is_err()); // Non-numeric precision - assert!(DataType::new("Decimal(', ')").is_err()); - assert!(DataType::new("Decimal(77, 1)").is_err()); // Max precision is 76 - assert!(DataType::new("Decimal(1, 2)").is_err()); // Scale must be less than precision - assert!(DataType::new("Decimal(1, x)").is_err()); // Non-numeric scale - assert!(DataType::new("Decimal(42, ,)").is_err()); - assert!(DataType::new("Decimal(42, ')").is_err()); - assert!(DataType::new("Decimal(foobar)").is_err()); + assert!(DataTypeNode::new("Decimal").is_err()); + assert!(DataTypeNode::new("Decimal(").is_err()); + assert!(DataTypeNode::new("Decimal()").is_err()); + assert!(DataTypeNode::new("Decimal(1)").is_err()); + assert!(DataTypeNode::new("Decimal(1,)").is_err()); + assert!(DataTypeNode::new("Decimal(1, )").is_err()); + assert!(DataTypeNode::new("Decimal(0, 0)").is_err()); // Precision must be > 0 + assert!(DataTypeNode::new("Decimal(x, 0)").is_err()); // Non-numeric precision + assert!(DataTypeNode::new("Decimal(', ')").is_err()); + assert!(DataTypeNode::new("Decimal(77, 1)").is_err()); // Max precision is 76 + assert!(DataTypeNode::new("Decimal(1, 2)").is_err()); // Scale must be less than precision + assert!(DataTypeNode::new("Decimal(1, x)").is_err()); // Non-numeric scale + assert!(DataTypeNode::new("Decimal(42, ,)").is_err()); + assert!(DataTypeNode::new("Decimal(42, ')").is_err()); + assert!(DataTypeNode::new("Decimal(foobar)").is_err()); } #[test] fn test_data_type_new_datetime() { - assert_eq!(DataType::new("DateTime").unwrap(), DataType::DateTime(None)); assert_eq!( - DataType::new("DateTime('UTC')").unwrap(), - DataType::DateTime(Some("UTC".to_string())) + DataTypeNode::new("DateTime").unwrap(), + DataTypeNode::DateTime(None) ); assert_eq!( - DataType::new("DateTime('America/New_York')").unwrap(), - DataType::DateTime(Some("America/New_York".to_string())) + DataTypeNode::new("DateTime('UTC')").unwrap(), + DataTypeNode::DateTime(Some("UTC".to_string())) ); - assert!(DataType::new("DateTime()").is_err()); + assert_eq!( + DataTypeNode::new("DateTime('America/New_York')").unwrap(), + DataTypeNode::DateTime(Some("America/New_York".to_string())) + ); + assert!(DataTypeNode::new("DateTime()").is_err()); } #[test] fn test_data_type_new_datetime64() { assert_eq!( - DataType::new("DateTime64(0)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision0, None) + DataTypeNode::new("DateTime64(0)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision0, None) ); assert_eq!( - DataType::new("DateTime64(1)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision1, None) + DataTypeNode::new("DateTime64(1)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision1, None) ); assert_eq!( - DataType::new("DateTime64(2)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision2, None) + DataTypeNode::new("DateTime64(2)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision2, None) ); assert_eq!( - DataType::new("DateTime64(3)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision3, None) + DataTypeNode::new("DateTime64(3)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision3, None) ); assert_eq!( - DataType::new("DateTime64(4)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision4, None) + DataTypeNode::new("DateTime64(4)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision4, None) ); assert_eq!( - DataType::new("DateTime64(5)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision5, None) + DataTypeNode::new("DateTime64(5)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision5, None) ); assert_eq!( - DataType::new("DateTime64(6)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision6, None) + DataTypeNode::new("DateTime64(6)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision6, None) ); assert_eq!( - DataType::new("DateTime64(7)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision7, None) + DataTypeNode::new("DateTime64(7)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision7, None) ); assert_eq!( - DataType::new("DateTime64(8)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision8, None) + DataTypeNode::new("DateTime64(8)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision8, None) ); assert_eq!( - DataType::new("DateTime64(9)").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision9, None) + DataTypeNode::new("DateTime64(9)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision9, None) ); assert_eq!( - DataType::new("DateTime64(0, 'UTC')").unwrap(), - DataType::DateTime64(DateTimePrecision::Precision0, Some("UTC".to_string())) + DataTypeNode::new("DateTime64(0, 'UTC')").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision0, Some("UTC".to_string())) ); assert_eq!( - DataType::new("DateTime64(3, 'America/New_York')").unwrap(), - DataType::DateTime64( + DataTypeNode::new("DateTime64(3, 'America/New_York')").unwrap(), + DataTypeNode::DateTime64( DateTimePrecision::Precision3, Some("America/New_York".to_string()) ) ); assert_eq!( - DataType::new("DateTime64(6, 'America/New_York')").unwrap(), - DataType::DateTime64( + DataTypeNode::new("DateTime64(6, 'America/New_York')").unwrap(), + DataTypeNode::DateTime64( DateTimePrecision::Precision6, Some("America/New_York".to_string()) ) ); assert_eq!( - DataType::new("DateTime64(9, 'Europe/Amsterdam')").unwrap(), - DataType::DateTime64( + DataTypeNode::new("DateTime64(9, 'Europe/Amsterdam')").unwrap(), + DataTypeNode::DateTime64( DateTimePrecision::Precision9, Some("Europe/Amsterdam".to_string()) ) ); - assert!(DataType::new("DateTime64()").is_err()); + assert!(DataTypeNode::new("DateTime64()").is_err()); } #[test] fn test_data_type_new_low_cardinality() { assert_eq!( - DataType::new("LowCardinality(UInt8)").unwrap(), - DataType::LowCardinality(Box::new(DataType::UInt8)) + DataTypeNode::new("LowCardinality(UInt8)").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::UInt8)) ); assert_eq!( - DataType::new("LowCardinality(String)").unwrap(), - DataType::LowCardinality(Box::new(DataType::String)) + DataTypeNode::new("LowCardinality(String)").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::String)) ); assert_eq!( - DataType::new("LowCardinality(Array(Int32))").unwrap(), - DataType::LowCardinality(Box::new(DataType::Array(Box::new(DataType::Int32)))) + DataTypeNode::new("LowCardinality(Array(Int32))").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::Array(Box::new( + DataTypeNode::Int32 + )))) ); - assert!(DataType::new("LowCardinality()").is_err()); + assert!(DataTypeNode::new("LowCardinality()").is_err()); } #[test] fn test_data_type_new_nullable() { assert_eq!( - DataType::new("Nullable(UInt8)").unwrap(), - DataType::Nullable(Box::new(DataType::UInt8)) + DataTypeNode::new("Nullable(UInt8)").unwrap(), + DataTypeNode::Nullable(Box::new(DataTypeNode::UInt8)) ); assert_eq!( - DataType::new("Nullable(String)").unwrap(), - DataType::Nullable(Box::new(DataType::String)) + DataTypeNode::new("Nullable(String)").unwrap(), + DataTypeNode::Nullable(Box::new(DataTypeNode::String)) ); - assert!(DataType::new("Nullable()").is_err()); + assert!(DataTypeNode::new("Nullable()").is_err()); } #[test] fn test_data_type_new_map() { assert_eq!( - DataType::new("Map(UInt8, String)").unwrap(), - DataType::Map(Box::new(DataType::UInt8), Box::new(DataType::String)) + DataTypeNode::new("Map(UInt8, String)").unwrap(), + DataTypeNode::Map( + Box::new(DataTypeNode::UInt8), + Box::new(DataTypeNode::String) + ) ); assert_eq!( - DataType::new("Map(String, Int32)").unwrap(), - DataType::Map(Box::new(DataType::String), Box::new(DataType::Int32)) + DataTypeNode::new("Map(String, Int32)").unwrap(), + DataTypeNode::Map( + Box::new(DataTypeNode::String), + Box::new(DataTypeNode::Int32) + ) ); assert_eq!( - DataType::new("Map(String, Map(Int32, Array(Nullable(String))))").unwrap(), - DataType::Map( - Box::new(DataType::String), - Box::new(DataType::Map( - Box::new(DataType::Int32), - Box::new(DataType::Array(Box::new(DataType::Nullable(Box::new( - DataType::String - ))))) + DataTypeNode::new("Map(String, Map(Int32, Array(Nullable(String))))").unwrap(), + DataTypeNode::Map( + Box::new(DataTypeNode::String), + Box::new(DataTypeNode::Map( + Box::new(DataTypeNode::Int32), + Box::new(DataTypeNode::Array(Box::new(DataTypeNode::Nullable( + Box::new(DataTypeNode::String) + )))) )) ) ); - assert!(DataType::new("Map()").is_err()); + assert!(DataTypeNode::new("Map()").is_err()); } #[test] fn test_data_type_new_variant() { assert_eq!( - DataType::new("Variant(UInt8, String)").unwrap(), - DataType::Variant(vec![DataType::UInt8, DataType::String]) + DataTypeNode::new("Variant(UInt8, String)").unwrap(), + DataTypeNode::Variant(vec![DataTypeNode::UInt8, DataTypeNode::String]) ); assert_eq!( - DataType::new("Variant(String, Int32)").unwrap(), - DataType::Variant(vec![DataType::String, DataType::Int32]) + DataTypeNode::new("Variant(String, Int32)").unwrap(), + DataTypeNode::Variant(vec![DataTypeNode::String, DataTypeNode::Int32]) ); assert_eq!( - DataType::new("Variant(Int32, Array(Nullable(String)), Map(Int32, String))").unwrap(), - DataType::Variant(vec![ - DataType::Int32, - DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::String)))), - DataType::Map(Box::new(DataType::Int32), Box::new(DataType::String)) + DataTypeNode::new("Variant(Int32, Array(Nullable(String)), Map(Int32, String))") + .unwrap(), + DataTypeNode::Variant(vec![ + DataTypeNode::Int32, + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::String + )))), + DataTypeNode::Map( + Box::new(DataTypeNode::Int32), + Box::new(DataTypeNode::String) + ) ]) ); - assert!(DataType::new("Variant").is_err()); + assert!(DataTypeNode::new("Variant").is_err()); } #[test] fn test_data_type_new_tuple() { assert_eq!( - DataType::new("Tuple(UInt8, String)").unwrap(), - DataType::Tuple(vec![DataType::UInt8, DataType::String]) + DataTypeNode::new("Tuple(UInt8, String)").unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::UInt8, DataTypeNode::String]) ); assert_eq!( - DataType::new("Tuple(String, Int32)").unwrap(), - DataType::Tuple(vec![DataType::String, DataType::Int32]) + DataTypeNode::new("Tuple(String, Int32)").unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::String, DataTypeNode::Int32]) ); assert_eq!( - DataType::new( + DataTypeNode::new( "Tuple(Int32, Array(Nullable(String)), Map(Int32, Tuple(String, Array(UInt8))))" ) .unwrap(), - DataType::Tuple(vec![ - DataType::Int32, - DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::String)))), - DataType::Map( - Box::new(DataType::Int32), - Box::new(DataType::Tuple(vec![ - DataType::String, - DataType::Array(Box::new(DataType::UInt8)) + DataTypeNode::Tuple(vec![ + DataTypeNode::Int32, + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::String + )))), + DataTypeNode::Map( + Box::new(DataTypeNode::Int32), + Box::new(DataTypeNode::Tuple(vec![ + DataTypeNode::String, + DataTypeNode::Array(Box::new(DataTypeNode::UInt8)) ])) ) ]) ); - assert!(DataType::new("Tuple").is_err()); + assert!(DataTypeNode::new("Tuple").is_err()); } #[test] fn test_data_type_new_enum() { assert_eq!( - DataType::new("Enum8('A' = -42)").unwrap(), - DataType::Enum(EnumType::Enum8, HashMap::from([(-42, "A".to_string())])) + DataTypeNode::new("Enum8('A' = -42)").unwrap(), + DataTypeNode::Enum(EnumType::Enum8, HashMap::from([(-42, "A".to_string())])) ); assert_eq!( - DataType::new("Enum16('A' = -144)").unwrap(), - DataType::Enum(EnumType::Enum16, HashMap::from([(-144, "A".to_string())])) + DataTypeNode::new("Enum16('A' = -144)").unwrap(), + DataTypeNode::Enum(EnumType::Enum16, HashMap::from([(-144, "A".to_string())])) ); assert_eq!( - DataType::new("Enum8('A' = 1, 'B' = 2)").unwrap(), - DataType::Enum( + DataTypeNode::new("Enum8('A' = 1, 'B' = 2)").unwrap(), + DataTypeNode::Enum( EnumType::Enum8, HashMap::from([(1, "A".to_string()), (2, "B".to_string())]) ) ); assert_eq!( - DataType::new("Enum16('A' = 1, 'B' = 2)").unwrap(), - DataType::Enum( + DataTypeNode::new("Enum16('A' = 1, 'B' = 2)").unwrap(), + DataTypeNode::Enum( EnumType::Enum16, HashMap::from([(1, "A".to_string()), (2, "B".to_string())]) ) ); assert_eq!( - DataType::new("Enum8('f\\'' = 1, 'x =' = 2, 'b\\'\\'' = 3, '\\'c=4=' = 42, '4' = 100)") - .unwrap(), - DataType::Enum( + DataTypeNode::new( + "Enum8('f\\'' = 1, 'x =' = 2, 'b\\'\\'' = 3, '\\'c=4=' = 42, '4' = 100)" + ) + .unwrap(), + DataTypeNode::Enum( EnumType::Enum8, HashMap::from([ (1, "f\\'".to_string()), @@ -1016,86 +1102,97 @@ mod tests { ) ); assert_eq!( - DataType::new("Enum8('foo' = 0, '' = 42)").unwrap(), - DataType::Enum( + DataTypeNode::new("Enum8('foo' = 0, '' = 42)").unwrap(), + DataTypeNode::Enum( EnumType::Enum8, HashMap::from([(0, "foo".to_string()), (42, "".to_string())]) ) ); - assert!(DataType::new("Enum()").is_err()); - assert!(DataType::new("Enum8()").is_err()); - assert!(DataType::new("Enum16()").is_err()); + assert!(DataTypeNode::new("Enum()").is_err()); + assert!(DataTypeNode::new("Enum8()").is_err()); + assert!(DataTypeNode::new("Enum16()").is_err()); } #[test] fn test_data_type_to_string_simple() { // Simple types - assert_eq!(DataType::UInt8.to_string(), "UInt8"); - assert_eq!(DataType::UInt16.to_string(), "UInt16"); - assert_eq!(DataType::UInt32.to_string(), "UInt32"); - assert_eq!(DataType::UInt64.to_string(), "UInt64"); - assert_eq!(DataType::UInt128.to_string(), "UInt128"); - assert_eq!(DataType::UInt256.to_string(), "UInt256"); - assert_eq!(DataType::Int8.to_string(), "Int8"); - assert_eq!(DataType::Int16.to_string(), "Int16"); - assert_eq!(DataType::Int32.to_string(), "Int32"); - assert_eq!(DataType::Int64.to_string(), "Int64"); - assert_eq!(DataType::Int128.to_string(), "Int128"); - assert_eq!(DataType::Int256.to_string(), "Int256"); - assert_eq!(DataType::Float32.to_string(), "Float32"); - assert_eq!(DataType::Float64.to_string(), "Float64"); - assert_eq!(DataType::BFloat16.to_string(), "BFloat16"); - assert_eq!(DataType::UUID.to_string(), "UUID"); - assert_eq!(DataType::Date.to_string(), "Date"); - assert_eq!(DataType::Date32.to_string(), "Date32"); - assert_eq!(DataType::IPv4.to_string(), "IPv4"); - assert_eq!(DataType::IPv6.to_string(), "IPv6"); - assert_eq!(DataType::Bool.to_string(), "Bool"); - assert_eq!(DataType::Dynamic.to_string(), "Dynamic"); - assert_eq!(DataType::JSON.to_string(), "JSON"); - assert_eq!(DataType::String.to_string(), "String"); + assert_eq!(DataTypeNode::UInt8.to_string(), "UInt8"); + assert_eq!(DataTypeNode::UInt16.to_string(), "UInt16"); + assert_eq!(DataTypeNode::UInt32.to_string(), "UInt32"); + assert_eq!(DataTypeNode::UInt64.to_string(), "UInt64"); + assert_eq!(DataTypeNode::UInt128.to_string(), "UInt128"); + assert_eq!(DataTypeNode::UInt256.to_string(), "UInt256"); + assert_eq!(DataTypeNode::Int8.to_string(), "Int8"); + assert_eq!(DataTypeNode::Int16.to_string(), "Int16"); + assert_eq!(DataTypeNode::Int32.to_string(), "Int32"); + assert_eq!(DataTypeNode::Int64.to_string(), "Int64"); + assert_eq!(DataTypeNode::Int128.to_string(), "Int128"); + assert_eq!(DataTypeNode::Int256.to_string(), "Int256"); + assert_eq!(DataTypeNode::Float32.to_string(), "Float32"); + assert_eq!(DataTypeNode::Float64.to_string(), "Float64"); + assert_eq!(DataTypeNode::BFloat16.to_string(), "BFloat16"); + assert_eq!(DataTypeNode::UUID.to_string(), "UUID"); + assert_eq!(DataTypeNode::Date.to_string(), "Date"); + assert_eq!(DataTypeNode::Date32.to_string(), "Date32"); + assert_eq!(DataTypeNode::IPv4.to_string(), "IPv4"); + assert_eq!(DataTypeNode::IPv6.to_string(), "IPv6"); + assert_eq!(DataTypeNode::Bool.to_string(), "Bool"); + assert_eq!(DataTypeNode::Dynamic.to_string(), "Dynamic"); + assert_eq!(DataTypeNode::JSON.to_string(), "JSON"); + assert_eq!(DataTypeNode::String.to_string(), "String"); } #[test] fn test_data_types_to_string_complex() { - assert_eq!(DataType::DateTime(None).to_string(), "DateTime"); + assert_eq!(DataTypeNode::DateTime(None).to_string(), "DateTime"); assert_eq!( - DataType::DateTime(Some("UTC".to_string())).to_string(), + DataTypeNode::DateTime(Some("UTC".to_string())).to_string(), "DateTime('UTC')" ); assert_eq!( - DataType::DateTime(Some("America/New_York".to_string())).to_string(), + DataTypeNode::DateTime(Some("America/New_York".to_string())).to_string(), "DateTime('America/New_York')" ); assert_eq!( - DataType::Nullable(Box::new(DataType::UInt64)).to_string(), + DataTypeNode::Nullable(Box::new(DataTypeNode::UInt64)).to_string(), "Nullable(UInt64)" ); assert_eq!( - DataType::Array(Box::new(DataType::String)).to_string(), + DataTypeNode::Array(Box::new(DataTypeNode::String)).to_string(), "Array(String)" ); assert_eq!( - DataType::Array(Box::new(DataType::Nullable(Box::new(DataType::String)))).to_string(), + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::String + )))) + .to_string(), "Array(Nullable(String))" ); assert_eq!( - DataType::Tuple(vec![DataType::String, DataType::UInt32, DataType::Float64]) - .to_string(), + DataTypeNode::Tuple(vec![ + DataTypeNode::String, + DataTypeNode::UInt32, + DataTypeNode::Float64 + ]) + .to_string(), "Tuple(String, UInt32, Float64)" ); assert_eq!( - DataType::Map(Box::new(DataType::String), Box::new(DataType::UInt32)).to_string(), + DataTypeNode::Map( + Box::new(DataTypeNode::String), + Box::new(DataTypeNode::UInt32) + ) + .to_string(), "Map(String, UInt32)" ); assert_eq!( - DataType::Decimal(10, 2, DecimalSize::Int32).to_string(), + DataTypeNode::Decimal(10, 2, DecimalSize::Int32).to_string(), "Decimal(10, 2)" ); assert_eq!( - DataType::Enum( + DataTypeNode::Enum( EnumType::Enum8, HashMap::from([(1, "A".to_string()), (2, "B".to_string())]), ) @@ -1103,16 +1200,17 @@ mod tests { "Enum8('A' = 1, 'B' = 2)" ); assert_eq!( - DataType::AggregateFunction("sum".to_string(), vec![DataType::UInt64]).to_string(), + DataTypeNode::AggregateFunction("sum".to_string(), vec![DataTypeNode::UInt64]) + .to_string(), "AggregateFunction(sum, UInt64)" ); - assert_eq!(DataType::FixedString(16).to_string(), "FixedString(16)"); + assert_eq!(DataTypeNode::FixedString(16).to_string(), "FixedString(16)"); assert_eq!( - DataType::Variant(vec![DataType::UInt8, DataType::Bool]).to_string(), + DataTypeNode::Variant(vec![DataTypeNode::UInt8, DataTypeNode::Bool]).to_string(), "Variant(UInt8, Bool)" ); assert_eq!( - DataType::DateTime64(DateTimePrecision::Precision3, Some("UTC".to_string())) + DataTypeNode::DateTime64(DateTimePrecision::Precision3, Some("UTC".to_string())) .to_string(), "DateTime64(3, 'UTC')" ); diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 6f17cfcc..20ec0f5e 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,3 +1,4 @@ +use crate::output_format::OutputFormat; use crate::{ bytes_ext::BytesExt, cursors::RawCursor, @@ -5,6 +6,8 @@ use crate::{ response::Response, rowbinary, }; +use clickhouse_rowbinary::parse_columns_header; +use clickhouse_rowbinary::types::Column; use serde::Deserialize; use std::marker::PhantomData; @@ -13,15 +16,19 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, + format: OutputFormat, + columns: Option>, _marker: PhantomData, } impl RowCursor { - pub(crate) fn new(response: Response) -> Self { + pub(crate) fn new(response: Response, format: OutputFormat) -> Self { Self { + _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), - _marker: PhantomData, + columns: None, + format, } } @@ -37,15 +44,36 @@ impl RowCursor { T: Deserialize<'b>, { loop { - let mut slice = super::workaround_51132(self.bytes.slice()); - - match rowbinary::deserialize_from(&mut slice) { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - return Ok(Some(value)); + if self.bytes.remaining() > 0 { + let mut slice = super::workaround_51132(self.bytes.slice()); + match self.format { + OutputFormat::RowBinary => match rowbinary::deserialize_from(&mut slice) { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); + } + Err(Error::NotEnoughData) => {} + Err(err) => return Err(err), + }, + OutputFormat::RowBinaryWithNamesAndTypes => match self.columns.as_ref() { + // FIXME: move this branch to new? + None => { + let columns = parse_columns_header(&mut slice)?; + self.bytes.set_remaining(slice.len()); + self.columns = Some(columns); + } + Some(columns) => { + match rowbinary::deserialize_from_rbwnat(&mut slice, columns) { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); + } + Err(Error::NotEnoughData) => {} + Err(err) => return Err(err), + } + } + }, } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), } match self.raw.next().await? { diff --git a/src/error.rs b/src/error.rs index f4bde3c4..852f3d5b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,8 @@ //! Contains [`Error`] and corresponding [`Result`]. -use std::{error::Error as StdError, fmt, io, result, str::Utf8Error}; - +use clickhouse_rowbinary::types::Column; use serde::{de, ser}; +use std::{error::Error as StdError, fmt, io, result, str::Utf8Error}; /// A result with a specified [`Error`] type. pub type Result = result::Result; @@ -44,12 +44,35 @@ pub enum Error { TimedOut, #[error("unsupported: {0}")] Unsupported(String), + #[error("error while parsing data from the response: {0}")] + ParserError(BoxedError), + #[error("struct mismatches the database definition; field {field_name} has unexpected type {unexpected_type}; allowed types for {field_name}: {allowed_types}; database columns: {all_columns:?}")] + DataTypeMismatch { + field_name: String, + allowed_types: String, + unexpected_type: String, + all_columns: String, + }, + #[error( + "too many struct fields: trying to read more columns than expected {0}. All columns: {1:?}" + )] + TooManyStructFields(usize, Vec), + #[error("deserialization error: {0}")] + DeserializationError(#[source] BoxedError), + #[error("deserialize is called for more fields than a struct has")] + DeserializeCallAfterEndOfStruct, #[error("{0}")] Other(BoxedError), } assert_impl_all!(Error: StdError, Send, Sync); +impl From for Error { + fn from(err: clickhouse_rowbinary::error::ParserError) -> Self { + Self::ParserError(Box::new(err)) + } +} + impl From for Error { fn from(error: hyper::Error) -> Self { Self::Network(Box::new(error)) diff --git a/src/lib.rs b/src/lib.rs index 7d02cdca..2eee5f78 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,12 +9,14 @@ use self::{error::Result, http_client::HttpClient}; use std::{collections::HashMap, fmt::Display, sync::Arc}; pub use self::{compression::Compression, row::Row}; +use crate::output_format::OutputFormat; pub use clickhouse_derive::Row; pub mod error; pub mod insert; #[cfg(feature = "inserter")] pub mod inserter; +pub mod output_format; pub mod query; pub mod serde; pub mod sql; @@ -48,6 +50,7 @@ pub struct Client { options: HashMap, headers: HashMap, products_info: Vec, + fetch_format: OutputFormat, } #[derive(Clone)] @@ -83,6 +86,7 @@ impl Client { options: HashMap::new(), headers: HashMap::new(), products_info: Vec::default(), + fetch_format: OutputFormat::default(), } } @@ -222,6 +226,15 @@ impl Client { self } + pub fn with_fetch_format(mut self, format: OutputFormat) -> Self { + self.fetch_format = format; + self + } + + pub fn get_fetch_format(&self) -> OutputFormat { + self.fetch_format.clone() + } + /// Starts a new INSERT statement. /// /// # Panics diff --git a/src/output_format.rs b/src/output_format.rs new file mode 100644 index 00000000..68dbd10f --- /dev/null +++ b/src/output_format.rs @@ -0,0 +1,12 @@ +#[non_exhaustive] +#[derive(Clone)] +pub enum OutputFormat { + RowBinary, + RowBinaryWithNamesAndTypes, +} + +impl Default for OutputFormat { + fn default() -> Self { + Self::RowBinary + } +} diff --git a/src/query.rs b/src/query.rs index 7a76dbd5..af8743c7 100644 --- a/src/query.rs +++ b/src/query.rs @@ -16,6 +16,7 @@ use crate::{ const MAX_QUERY_LEN_TO_USE_GET: usize = 8192; pub use crate::cursors::{BytesCursor, RowCursor}; +use crate::output_format::OutputFormat; #[must_use] #[derive(Clone)] @@ -43,7 +44,7 @@ impl Query { /// [`Identifier`], will be appropriately escaped. /// /// All possible errors will be returned as [`Error::InvalidParams`] - /// during query execution (`execute()`, `fetch()` etc). + /// during query execution (`execute()`, `fetch()`, etc.). /// /// WARNING: This means that the query must not have any extra `?`, even if /// they are in a string literal! Use `??` to have plain `?` in query. @@ -83,11 +84,16 @@ impl Query { /// # Ok(()) } /// ``` pub fn fetch(mut self) -> Result> { + let fetch_format = self.client.get_fetch_format(); + self.sql.bind_fields::(); - self.sql.set_output_format("RowBinary"); + self.sql.set_output_format(match fetch_format { + OutputFormat::RowBinary => "RowBinary", + OutputFormat::RowBinaryWithNamesAndTypes => "RowBinaryWithNamesAndTypes", + }); let response = self.do_execute(true)?; - Ok(RowCursor::new(response)) + Ok(RowCursor::new(response, fetch_format)) } /// Executes the query and returns just a single row. diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index c7c41392..4a5a2beb 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -1,6 +1,7 @@ use std::{convert::TryFrom, mem, str}; use crate::error::{Error, Result}; +use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use bytes::Buf; use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, @@ -20,38 +21,29 @@ pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data /// A deserializer for the RowBinary format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. -struct RowBinaryDeserializer<'cursor, 'data> { - input: &'cursor mut &'data [u8], +pub(crate) struct RowBinaryDeserializer<'cursor, 'data> { + pub(crate) input: &'cursor mut &'data [u8], } impl<'data> RowBinaryDeserializer<'_, 'data> { - fn read_vec(&mut self, size: usize) -> Result> { + pub(crate) fn read_vec(&mut self, size: usize) -> Result> { Ok(self.read_slice(size)?.to_vec()) } - fn read_slice(&mut self, size: usize) -> Result<&'data [u8]> { + pub(crate) fn read_slice(&mut self, size: usize) -> Result<&'data [u8]> { ensure_size(&mut self.input, size)?; let slice = &self.input[..size]; self.input.advance(size); Ok(slice) } - fn read_size(&mut self) -> Result { + pub(crate) fn read_size(&mut self) -> Result { let size = get_unsigned_leb128(&mut self.input)?; // TODO: what about another error? usize::try_from(size).map_err(|_| Error::NotEnoughData) } } -#[inline] -fn ensure_size(buffer: impl Buf, size: usize) -> Result<()> { - if buffer.remaining() < size { - Err(Error::NotEnoughData) - } else { - Ok(()) - } -} - macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident) => { #[inline] @@ -67,27 +59,16 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { type Error = Error; impl_num!(i8, deserialize_i8, visit_i8, get_i8); - impl_num!(i16, deserialize_i16, visit_i16, get_i16_le); - impl_num!(i32, deserialize_i32, visit_i32, get_i32_le); - impl_num!(i64, deserialize_i64, visit_i64, get_i64_le); - impl_num!(i128, deserialize_i128, visit_i128, get_i128_le); - impl_num!(u8, deserialize_u8, visit_u8, get_u8); - impl_num!(u16, deserialize_u16, visit_u16, get_u16_le); - impl_num!(u32, deserialize_u32, visit_u32, get_u32_le); - impl_num!(u64, deserialize_u64, visit_u64, get_u64_le); - impl_num!(u128, deserialize_u128, visit_u128, get_u128_le); - impl_num!(f32, deserialize_f32, visit_f32, get_f32_le); - impl_num!(f64, deserialize_f64, visit_f64, get_f64_le); #[inline] @@ -318,33 +299,3 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { false } } - -fn get_unsigned_leb128(mut buffer: impl Buf) -> Result { - let mut value = 0u64; - let mut shift = 0; - - loop { - ensure_size(&mut buffer, 1)?; - - let byte = buffer.get_u8(); - value |= (byte as u64 & 0x7f) << shift; - - if byte & 0x80 == 0 { - break; - } - - shift += 7; - if shift > 57 { - // TODO: what about another error? - return Err(Error::NotEnoughData); - } - } - - Ok(value) -} - -#[test] -fn it_deserializes_unsigned_leb128() { - let buf = &[0xe5, 0x8e, 0x26][..]; - assert_eq!(get_unsigned_leb128(buf).unwrap(), 624_485); -} diff --git a/src/rowbinary/de_rbwnat.rs b/src/rowbinary/de_rbwnat.rs new file mode 100644 index 00000000..7a5bed63 --- /dev/null +++ b/src/rowbinary/de_rbwnat.rs @@ -0,0 +1,712 @@ +use crate::error::{Error, Result}; +use crate::rowbinary::de::RowBinaryDeserializer; +use clickhouse_rowbinary::types::{Column, DataTypeNode}; +use serde::de::{DeserializeSeed, SeqAccess, Visitor}; +use serde::{Deserialize, Deserializer}; +use std::fmt::Display; +use std::ops::Deref; +use std::rc::Rc; + +pub(crate) fn deserialize_from_rbwnat<'data, 'cursor, T: Deserialize<'data>>( + input: &mut &'data [u8], + columns: &'cursor [Column], +) -> Result { + println!("[RBWNAT] deserializing with names and types: {:?}", columns); + let mut deserializer = RowBinaryWithNamesAndTypesDeserializer::new(input, columns)?; + T::deserialize(&mut deserializer) +} + +/// Serde method that delegated the value deserialization to [`Deserializer::deserialize_any`]. +#[derive(Clone, Debug, PartialEq)] +enum DelegatedFrom { + Bool, + I8, + I16, + I32, + I64, + I128, + U8, + U16, + U32, + U64, + U128, + F32, + F64, + Char, + Str, + String, + Bytes, + ByteBuf, + Option, + Unit, + UnitStruct, + NewtypeStruct, + Seq, + Tuple, + TupleStruct, + Map, + Struct, + Enum, + Identifier, + IgnoredAny, +} + +impl Default for DelegatedFrom { + fn default() -> Self { + DelegatedFrom::Struct + } +} + +impl Display for DelegatedFrom { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let type_name = match self { + DelegatedFrom::Bool => "bool", + DelegatedFrom::I8 => "i8", + DelegatedFrom::I16 => "i16", + DelegatedFrom::I32 => "i32", + DelegatedFrom::I64 => "i64", + DelegatedFrom::I128 => "i128", + DelegatedFrom::U8 => "u8", + DelegatedFrom::U16 => "u16", + DelegatedFrom::U32 => "u32", + DelegatedFrom::U64 => "u64", + DelegatedFrom::U128 => "u128", + DelegatedFrom::F32 => "f32", + DelegatedFrom::F64 => "f64", + DelegatedFrom::Char => "char", + DelegatedFrom::Str => "&str", + DelegatedFrom::String => "String", + DelegatedFrom::Bytes => "&[u8]", + DelegatedFrom::ByteBuf => "Vec", + DelegatedFrom::Option => "Option", + DelegatedFrom::Unit => "()", + DelegatedFrom::UnitStruct => "unit struct", + DelegatedFrom::NewtypeStruct => "newtype struct", + DelegatedFrom::Seq => "Vec", + DelegatedFrom::Tuple => "tuple", + DelegatedFrom::TupleStruct => "tuple struct", + DelegatedFrom::Map => "map", + DelegatedFrom::Struct => "struct", + DelegatedFrom::Enum => "enum", + DelegatedFrom::Identifier => "identifier", + DelegatedFrom::IgnoredAny => "ignored any", + }; + write!(f, "{}", type_name) + } +} + +#[derive(Clone, Debug)] +enum DeserializerState<'cursor> { + /// At this point, we are either processing a "simple" column (e.g., `UInt32`, `String`, etc.), + /// or starting to process a more complex one (e.g., `Array(T)`, `Map(K, V)`, etc.). + TopLevelColumn(&'cursor Column), + /// Processing a column with a complex type (e.g., `Array(T)`), and we've got what `T` is. + /// We can use this to verify the inner type definition in the struct. + InnerDataType { + column: &'cursor Column, + prev_state: Rc>, + inner_data_type: &'cursor DataTypeNode, + }, + /// Verifying struct fields usually does not make sense more than once. + VerifiedInnerType { + inner_data_type: &'cursor DataTypeNode, + prev_state: Rc>, + }, + /// We are done with all columns and should not try to deserialize anything else. + EndOfStruct, +} + +struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { + row_binary: RowBinaryDeserializer<'cursor, 'data>, + state: DeserializerState<'cursor>, + columns: &'cursor [Column], + current_column_idx: usize, + // main usage is to check if the struct field definition is compatible with the expected one + last_delegated_from: DelegatedFrom, + // every deserialization begins from a struct with some name + struct_name: Option<&'static str>, +} + +impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { + #[inline] + fn new( + input: &'cursor mut &'data [u8], + columns: &'cursor [Column], + ) -> Result> { + if columns.is_empty() { + // unlikely - should be validated by the columns parser already + panic!("Zero columns definitions in the response"); + } + Ok(RowBinaryWithNamesAndTypesDeserializer { + row_binary: crate::rowbinary::de::RowBinaryDeserializer { input }, + state: DeserializerState::TopLevelColumn(&columns[0]), + last_delegated_from: DelegatedFrom::default(), + current_column_idx: 0, + struct_name: None, + columns, + }) + } + + #[inline] + fn set_last_delegated_from(&mut self, from: DelegatedFrom) { + if self.last_delegated_from != from { + self.last_delegated_from = from; + } + } + + #[inline] + fn set_struct_name(&mut self, name: &'static str) { + self.struct_name = Some(name); + } + + #[inline] + fn advance_state(&mut self) -> Result<()> { + match self.state { + DeserializerState::TopLevelColumn { .. } => { + self.current_column_idx += 1; + if self.current_column_idx >= self.columns.len() { + self.state = DeserializerState::EndOfStruct; + } else { + let current_col = self.get_current_column()?; + self.state = DeserializerState::TopLevelColumn(current_col); + } + } + DeserializerState::InnerDataType { + inner_data_type, .. + } => { + self.state = DeserializerState::VerifiedInnerType { + prev_state: Rc::new(self.state.clone()), + inner_data_type, + }; + } + DeserializerState::EndOfStruct => { + panic!("trying to advance the current column index after full deserialization"); + } + // skipping this when processing inner data types with more than one nesting level + _ => {} + } + Ok(()) + } + + #[inline] + fn set_inner_data_type_state( + &self, + inner_data_type: &'cursor DataTypeNode, + ) -> DeserializerState<'cursor> { + match self.state { + DeserializerState::TopLevelColumn(column, ..) + | DeserializerState::InnerDataType { column, .. } => DeserializerState::InnerDataType { + prev_state: Rc::new(self.state.clone()), + inner_data_type, + column, + }, + _ => { + panic!("to_inner called on invalid state"); + } + } + } + + #[inline] + fn set_previous_state(&mut self) { + match &self.state { + DeserializerState::InnerDataType { prev_state, .. } + | DeserializerState::VerifiedInnerType { prev_state, .. } => { + self.state = prev_state.deref().clone() + } + _ => panic!("to_prev_state called on invalid state"), + } + } + + #[inline] + fn get_current_data_type(&self) -> Result<&'cursor DataTypeNode> { + match self.state { + DeserializerState::TopLevelColumn(col, ..) => Ok(&col.data_type), + DeserializerState::InnerDataType { + inner_data_type, .. + } => Ok(inner_data_type), + DeserializerState::VerifiedInnerType { + inner_data_type, .. + } => Ok(inner_data_type), + DeserializerState::EndOfStruct => Err(Error::DeserializeCallAfterEndOfStruct), + } + } + + #[inline] + fn get_current_column(&mut self) -> Result<&'cursor Column> { + if self.current_column_idx >= self.columns.len() { + return Err(Error::TooManyStructFields( + self.current_column_idx, + Vec::from(self.columns), + )); + } + let col = &self.columns[self.current_column_idx]; + Ok(col) + } +} + +impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer<'_, 'data> { + type Error = Error; + + #[inline] + fn deserialize_any(self, visitor: V) -> Result + where + V: Visitor<'data>, + { + macro_rules! rbwnat_de_simple_with_type_check { + ($delegate:ident, $compatible:expr) => {{ + if !$compatible.contains(&self.last_delegated_from) { + let column = self.get_current_column()?; + let field_name = match self.struct_name { + Some(struct_name) => format!("{}.{}", struct_name, column.name), + None => column.name.to_string(), + }; + let allowed_types = $compatible.map(|x| x.to_string()).join(", "); + let all_columns = self + .columns + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "); + let unexpected_type = self.last_delegated_from.to_string(); + return Err(Error::DataTypeMismatch { + field_name, + allowed_types, + unexpected_type, + all_columns, + }); + } + self.row_binary.$delegate(visitor) + }}; + } + + let data_type = self.get_current_data_type()?; + let result = match data_type { + DataTypeNode::Bool => rbwnat_de_simple_with_type_check!( + deserialize_bool, + [DelegatedFrom::Bool, DelegatedFrom::U8, DelegatedFrom::I8] + ), + DataTypeNode::UInt8 => { + rbwnat_de_simple_with_type_check!(deserialize_u8, [DelegatedFrom::U8]) + } + DataTypeNode::Int8 => { + rbwnat_de_simple_with_type_check!(deserialize_i8, [DelegatedFrom::I8]) + } + DataTypeNode::Int16 => { + rbwnat_de_simple_with_type_check!(deserialize_i16, [DelegatedFrom::I16]) + } + DataTypeNode::Int32 => { + rbwnat_de_simple_with_type_check!(deserialize_i32, [DelegatedFrom::I32]) + } + DataTypeNode::Int64 => { + rbwnat_de_simple_with_type_check!(deserialize_i64, [DelegatedFrom::I64]) + } + DataTypeNode::Int128 => { + rbwnat_de_simple_with_type_check!(deserialize_i128, [DelegatedFrom::I128]) + } + DataTypeNode::UInt16 => { + rbwnat_de_simple_with_type_check!(deserialize_u16, [DelegatedFrom::U16]) + } + DataTypeNode::UInt32 => { + rbwnat_de_simple_with_type_check!(deserialize_u32, [DelegatedFrom::U32]) + } + DataTypeNode::UInt64 => { + rbwnat_de_simple_with_type_check!(deserialize_u64, [DelegatedFrom::U64]) + } + DataTypeNode::UInt128 => { + rbwnat_de_simple_with_type_check!(deserialize_u128, [DelegatedFrom::U128]) + } + DataTypeNode::Float32 => { + rbwnat_de_simple_with_type_check!(deserialize_f32, [DelegatedFrom::F32]) + } + DataTypeNode::Float64 => { + rbwnat_de_simple_with_type_check!(deserialize_f64, [DelegatedFrom::F64]) + } + DataTypeNode::String => { + rbwnat_de_simple_with_type_check!( + deserialize_str, + [DelegatedFrom::Str, DelegatedFrom::String] + ) + } + DataTypeNode::Array(inner_type) => { + let len = self.row_binary.read_size()?; + self.set_inner_data_type_state(inner_type); + + struct AnyArrayAccess<'de, 'cursor, 'data> { + deserializer: &'de mut RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data>, + remaining: usize, + } + + impl<'data> SeqAccess<'data> for AnyArrayAccess<'_, '_, 'data> { + type Error = Error; + + fn next_element_seed( + &mut self, + seed: T, + ) -> Result, Self::Error> + where + T: DeserializeSeed<'data>, + { + if self.remaining == 0 { + return Ok(None); + } + + self.remaining -= 1; + seed.deserialize(&mut *self.deserializer).map(Some) + } + + fn size_hint(&self) -> Option { + Some(self.remaining) + } + } + + let result = visitor.visit_seq(AnyArrayAccess { + deserializer: self, + remaining: len, + }); + // if we are processing `Array(String)`, the state has `String` as expected type + // revert it back to `Array(String)` + self.set_previous_state(); + result + } + _ => panic!("unsupported type for deserialize_any: {:?}", self.columns), + }; + result + .map_err(|e| Error::DeserializationError(Box::new(e))) + .and_then(|value| { + self.advance_state()?; + Ok(value) + }) + } + + #[inline] + fn deserialize_bool(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Bool); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_i8(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::I8); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_i16(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::I16); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_i32(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::I32); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_i64(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::I64); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_i128(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::I128); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_u8(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::U8); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_u16(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::U16); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_u32(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::U32); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_u64(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::U64); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_u128(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::U128); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_f32(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::F32); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_f64(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::F64); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_char(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Char); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_str(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Str); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_string(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::String); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_bytes(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Bytes); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_byte_buf(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::ByteBuf); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_option(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Option); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_unit(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Unit); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_unit_struct( + self, + _name: &'static str, + visitor: V, + ) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::UnitStruct); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::NewtypeStruct); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_seq(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Seq); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_tuple( + self, + _len: usize, + visitor: V, + ) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Tuple); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::TupleStruct); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_map(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Map); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_struct( + self, + name: &'static str, + fields: &'static [&'static str], + visitor: V, + ) -> std::result::Result + where + V: Visitor<'data>, + { + struct StructAccess<'de, 'cursor, 'data> { + deserializer: &'de mut RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data>, + len: usize, + } + + impl<'data> SeqAccess<'data> for StructAccess<'_, '_, 'data> { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'data>, + { + if self.len > 0 { + self.len -= 1; + let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.len) + } + } + + self.set_struct_name(name); + self.set_last_delegated_from(DelegatedFrom::Struct); + visitor.visit_seq(StructAccess { + deserializer: self, + len: fields.len(), + }) + } + + #[inline] + fn deserialize_enum( + self, + _name: &'static str, + _variants: &'static [&'static str], + visitor: V, + ) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Enum); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_identifier(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::Identifier); + self.deserialize_any(visitor) + } + + #[inline] + fn deserialize_ignored_any(self, visitor: V) -> std::result::Result + where + V: Visitor<'data>, + { + self.set_last_delegated_from(DelegatedFrom::IgnoredAny); + self.deserialize_any(visitor) + } + + #[inline] + fn is_human_readable(&self) -> bool { + false + } +} diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index dbdb672e..0bba1973 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,7 +1,10 @@ pub(crate) use de::deserialize_from; +pub(crate) use de_rbwnat::deserialize_from_rbwnat; pub(crate) use ser::serialize_into; mod de; +mod de_rbwnat; mod ser; #[cfg(test)] mod tests; +mod utils; diff --git a/src/rowbinary/utils.rs b/src/rowbinary/utils.rs new file mode 100644 index 00000000..fc2db7e9 --- /dev/null +++ b/src/rowbinary/utils.rs @@ -0,0 +1,41 @@ +use crate::error::Error; +use bytes::Buf; + +#[inline] +pub(crate) fn ensure_size(buffer: impl Buf, size: usize) -> crate::error::Result<()> { + if buffer.remaining() < size { + Err(Error::NotEnoughData) + } else { + Ok(()) + } +} + +pub(crate) fn get_unsigned_leb128(mut buffer: impl Buf) -> crate::error::Result { + let mut value = 0u64; + let mut shift = 0; + + loop { + ensure_size(&mut buffer, 1)?; + + let byte = buffer.get_u8(); + value |= (byte as u64 & 0x7f) << shift; + + if byte & 0x80 == 0 { + break; + } + + shift += 7; + if shift > 57 { + // TODO: what about another error? + return Err(Error::NotEnoughData); + } + } + + Ok(value) +} + +#[test] +fn it_deserializes_unsigned_leb128() { + let buf = &[0xe5, 0x8e, 0x26][..]; + assert_eq!(get_unsigned_leb128(buf).unwrap(), 624_485); +} diff --git a/tests/it/main.rs b/tests/it/main.rs index b868e988..93ebbff2 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -65,7 +65,7 @@ mod ip; mod mock; mod nested; mod query; -mod rbwnat; +mod rbwnat_smoke; mod time; mod user_agent; mod uuid; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs deleted file mode 100644 index 7c132abf..00000000 --- a/tests/it/rbwnat.rs +++ /dev/null @@ -1,94 +0,0 @@ -use clickhouse_rowbinary::header::parse_names_and_types_header; -use clickhouse_rowbinary::types::{Column, DataType}; - -#[tokio::test] -async fn test_header_parsing() { - let client = prepare_database!(); - client - .query( - " - CREATE OR REPLACE TABLE visits - ( - CounterID UInt32, - StartDate Date, - Sign Int8, - IsNew UInt8, - VisitID UInt64, - UserID UInt64, - Goals Nested - ( - ID UInt32, - Serial UInt32, - EventTime DateTime, - Price Int64, - OrderID String, - CurrencyID UInt32 - ) - ) ENGINE = MergeTree ORDER BY () - ", - ) - .execute() - .await - .unwrap(); - - let mut cursor = client - .query("SELECT * FROM visits LIMIT 0") - .fetch_bytes("RowBinaryWithNamesAndTypes") - .unwrap(); - - let mut data = cursor.collect().await.unwrap(); - let result = parse_names_and_types_header(&mut data).unwrap(); - assert_eq!( - result, - vec![ - Column { - name: "CounterID".to_string(), - data_type: DataType::UInt32 - }, - Column { - name: "StartDate".to_string(), - data_type: DataType::Date - }, - Column { - name: "Sign".to_string(), - data_type: DataType::Int8 - }, - Column { - name: "IsNew".to_string(), - data_type: DataType::UInt8 - }, - Column { - name: "VisitID".to_string(), - data_type: DataType::UInt64 - }, - Column { - name: "UserID".to_string(), - data_type: DataType::UInt64 - }, - Column { - name: "Goals.ID".to_string(), - data_type: DataType::Array(Box::new(DataType::UInt32)) - }, - Column { - name: "Goals.Serial".to_string(), - data_type: DataType::Array(Box::new(DataType::UInt32)) - }, - Column { - name: "Goals.EventTime".to_string(), - data_type: DataType::Array(Box::new(DataType::DateTime(None))) - }, - Column { - name: "Goals.Price".to_string(), - data_type: DataType::Array(Box::new(DataType::Int64)) - }, - Column { - name: "Goals.OrderID".to_string(), - data_type: DataType::Array(Box::new(DataType::String)) - }, - Column { - name: "Goals.CurrencyID".to_string(), - data_type: DataType::Array(Box::new(DataType::UInt32)) - } - ] - ); -} diff --git a/tests/it/rbwnat_smoke.rs b/tests/it/rbwnat_smoke.rs new file mode 100644 index 00000000..4061963b --- /dev/null +++ b/tests/it/rbwnat_smoke.rs @@ -0,0 +1,334 @@ +use clickhouse::error::Error; +use clickhouse::output_format::OutputFormat; +use clickhouse_derive::Row; +use clickhouse_rowbinary::parse_columns_header; +use clickhouse_rowbinary::types::{Column, DataTypeNode}; +use serde::{Deserialize, Serialize}; +use time::OffsetDateTime; + +#[tokio::test] +async fn test_header_parsing() { + let client = prepare_database!(); + client + .query( + " + CREATE OR REPLACE TABLE visits + ( + CounterID UInt32, + StartDate Date, + Sign Int8, + IsNew UInt8, + VisitID UInt64, + UserID UInt64, + Goals Nested + ( + ID UInt32, + Serial UInt32, + EventTime DateTime, + Price Int64, + OrderID String, + CurrencyID UInt32 + ) + ) ENGINE = MergeTree ORDER BY () + ", + ) + .execute() + .await + .unwrap(); + + let mut cursor = client + .query("SELECT * FROM visits LIMIT 0") + .fetch_bytes("RowBinaryWithNamesAndTypes") + .unwrap(); + + let data = cursor.collect().await.unwrap(); + let result = parse_columns_header(&mut &data[..]).unwrap(); + assert_eq!( + result, + vec![ + Column { + name: "CounterID".to_string(), + data_type: DataTypeNode::UInt32 + }, + Column { + name: "StartDate".to_string(), + data_type: DataTypeNode::Date + }, + Column { + name: "Sign".to_string(), + data_type: DataTypeNode::Int8 + }, + Column { + name: "IsNew".to_string(), + data_type: DataTypeNode::UInt8 + }, + Column { + name: "VisitID".to_string(), + data_type: DataTypeNode::UInt64 + }, + Column { + name: "UserID".to_string(), + data_type: DataTypeNode::UInt64 + }, + Column { + name: "Goals.ID".to_string(), + data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)) + }, + Column { + name: "Goals.Serial".to_string(), + data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)) + }, + Column { + name: "Goals.EventTime".to_string(), + data_type: DataTypeNode::Array(Box::new(DataTypeNode::DateTime(None))) + }, + Column { + name: "Goals.Price".to_string(), + data_type: DataTypeNode::Array(Box::new(DataTypeNode::Int64)) + }, + Column { + name: "Goals.OrderID".to_string(), + data_type: DataTypeNode::Array(Box::new(DataTypeNode::String)) + }, + Column { + name: "Goals.CurrencyID".to_string(), + data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)) + } + ] + ); +} + +#[tokio::test] +async fn test_basic_types_deserialization() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + uint8_val: u8, + uint16_val: u16, + uint32_val: u32, + uint64_val: u64, + uint128_val: u128, + int8_val: i8, + int16_val: i16, + int32_val: i32, + int64_val: i64, + int128_val: i128, + float32_val: f32, + float64_val: f64, + string_val: String, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query( + " + SELECT + 255 :: UInt8 AS uint8_val, + 65535 :: UInt16 AS uint16_val, + 4294967295 :: UInt32 AS uint32_val, + 18446744073709551615 :: UInt64 AS uint64_val, + 340282366920938463463374607431768211455 :: UInt128 AS uint128_val, + -128 :: Int8 AS int8_val, + -32768 :: Int16 AS int16_val, + -2147483648 :: Int32 AS int32_val, + -9223372036854775808 :: Int64 AS int64_val, + -170141183460469231731687303715884105728 :: Int128 AS int128_val, + 42.0 :: Float32 AS float32_val, + 144.0 :: Float64 AS float64_val, + 'test' :: String AS string_val + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + uint8_val: 255, + uint16_val: 65535, + uint32_val: 4294967295, + uint64_val: 18446744073709551615, + uint128_val: 340282366920938463463374607431768211455, + int8_val: -128, + int16_val: -32768, + int32_val: -2147483648, + int64_val: -9223372036854775808, + int128_val: -170141183460469231731687303715884105728, + float32_val: 42.0, + float64_val: 144.0, + string_val: "test".to_string(), + } + ); +} + +#[tokio::test] +async fn test_array_deserialization() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u16, + one_dim_array: Vec, + two_dim_array: Vec>, + three_dim_array: Vec>>, + description: String, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query( + " + SELECT + 42 :: UInt16 AS id, + [1, 2] :: Array(UInt32) AS one_dim_array, + [[1, 2], [3, 4]] :: Array(Array(Int64)) AS two_dim_array, + [[[1.1, 2.2], [3.3, 4.4]], [], [[5.5, 6.6], [7.7, 8.8]]] :: Array(Array(Array(Float64))) AS three_dim_array, + 'foobar' :: String AS description + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + one_dim_array: vec![1, 2], + two_dim_array: vec![vec![1, 2], vec![3, 4]], + three_dim_array: vec![ + vec![vec![1.1, 2.2], vec![3.3, 4.4]], + vec![], + vec![vec![5.5, 6.6], vec![7.7, 8.8]] + ], + description: "foobar".to_string(), + } + ); +} + +#[tokio::test] +async fn test_default_types_validation_nullable() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: Option, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query("SELECT true AS b, 144 :: Int32 AS n2") + .fetch_one::() + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::DataTypeMismatch { .. } + )); + + // FIXME: lack of derive PartialEq for Error prevents proper assertion + // assert_eq!(result, Error::DataTypeMismatch { + // column_name: "n".to_string(), + // expected_type: "Nullable".to_string(), + // actual_type: "Bool".to_string(), + // columns: vec![...], + // }); +} + +#[tokio::test] +#[cfg(feature = "time")] +async fn test_default_types_validation_custom_serde() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + #[serde(with = "clickhouse::serde::time::datetime64::millis")] + n1: OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query("SELECT 42 :: UInt32 AS n1, 144 :: Int32 AS n2") + .fetch_one::() + .await; + + assert!(result.is_err()); + println!("{:?}", result); + assert!(matches!( + result.unwrap_err(), + Error::DataTypeMismatch { .. } + )); + + // FIXME: lack of derive PartialEq for Error prevents proper assertion + // assert_eq!(result, Error::DataTypeMismatch { + // column_name: "n1".to_string(), + // expected_type: "Int64".to_string(), + // actual_type: "Int32".to_string(), + // columns: vec![...], + // }); +} + +#[tokio::test] +async fn test_too_many_struct_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: u32, + c: u32, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b") + .fetch_one::() + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + Error::DeserializeCallAfterEndOfStruct { .. } + )); +} + +#[tokio::test] +async fn test_serde_skip_deserializing() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + #[serde(skip_deserializing)] + b: u32, + c: u32, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS c") + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + a: 42, + b: 0, // default value + c: 144, + } + ); +} + +// FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! +#[tokio::test] +async fn test_different_struct_field_order() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + c: String, + a: u32, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS c") + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + a: 42, + c: "foo".to_string(), + } + ); +} From 5a6029582829a07ba8e864458fdf34ad81d89bbb Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 14 May 2025 19:04:51 +0200 Subject: [PATCH 04/54] RBWNAT deserializer - more types WIP --- Cargo.toml | 2 +- rowbinary/src/decoders.rs | 2 + rowbinary/src/types.rs | 14 +- src/cursors/row.rs | 49 +++---- src/error.rs | 2 - src/output_format.rs | 11 ++ src/rowbinary/de_rbwnat.rs | 257 ++++++++++++++++++++----------------- tests/it/rbwnat_smoke.rs | 165 ++++++++++++++++++++++-- 8 files changed, 345 insertions(+), 157 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a29171a6..90d0f39d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -140,6 +140,6 @@ serde_bytes = "0.11.4" serde_json = "1" serde_repr = "0.1.7" uuid = { version = "1", features = ["v4", "serde"] } -time = { version = "0.3.17", features = ["macros", "rand"] } +time = { version = "0.3.17", features = ["macros", "rand", "parsing"] } fixnum = { version = "0.9.2", features = ["serde", "i32", "i64", "i128"] } rand = { version = "0.8.5", features = ["small_rng"] } diff --git a/rowbinary/src/decoders.rs b/rowbinary/src/decoders.rs index dcd2f9ea..61f2f974 100644 --- a/rowbinary/src/decoders.rs +++ b/rowbinary/src/decoders.rs @@ -2,7 +2,9 @@ use crate::error::ParserError; use crate::leb128::decode_leb128; use bytes::Buf; +#[inline] pub(crate) fn decode_string(buffer: &mut &[u8]) -> Result { + // println!("[decode_string] buffer: {:?}", buffer); let length = decode_leb128(buffer)? as usize; if length == 0 { return Ok("".to_string()); diff --git a/rowbinary/src/types.rs b/rowbinary/src/types.rs index 0a477905..bcbdbc34 100644 --- a/rowbinary/src/types.rs +++ b/rowbinary/src/types.rs @@ -24,39 +24,47 @@ impl Display for Column { #[non_exhaustive] pub enum DataTypeNode { Bool, + UInt8, UInt16, UInt32, UInt64, UInt128, UInt256, + Int8, Int16, Int32, Int64, Int128, Int256, + Float32, Float64, BFloat16, + Decimal(u8, u8, DecimalSize), // Scale, Precision, 32 | 64 | 128 | 256 + String, + FixedString(usize), UUID, + Date, Date32, DateTime(Option), // Optional timezone DateTime64(DateTimePrecision, Option), // Precision and optional timezone + IPv4, IPv6, Nullable(Box), + LowCardinality(Box), + Array(Box), Tuple(Vec), Map(Box, Box), - LowCardinality(Box), - Decimal(u8, u8, DecimalSize), Enum(EnumType, HashMap), + AggregateFunction(String, Vec), - FixedString(usize), Variant(Vec), Dynamic, diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 20ec0f5e..0e8e8218 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -44,37 +44,38 @@ impl RowCursor { T: Deserialize<'b>, { loop { - if self.bytes.remaining() > 0 { - let mut slice = super::workaround_51132(self.bytes.slice()); - match self.format { - OutputFormat::RowBinary => match rowbinary::deserialize_from(&mut slice) { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - return Ok(Some(value)); - } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), - }, - OutputFormat::RowBinaryWithNamesAndTypes => match self.columns.as_ref() { - // FIXME: move this branch to new? - None => { + let mut slice = super::workaround_51132(self.bytes.slice()); + match self.format { + OutputFormat::RowBinary => match rowbinary::deserialize_from(&mut slice) { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); + } + Err(Error::NotEnoughData) => {} + Err(err) => return Err(err), + }, + OutputFormat::RowBinaryWithNamesAndTypes => match self.columns.as_ref() { + // FIXME: move this branch to new? + None => { + if slice.len() > 0 { let columns = parse_columns_header(&mut slice)?; self.bytes.set_remaining(slice.len()); self.columns = Some(columns); } - Some(columns) => { - match rowbinary::deserialize_from_rbwnat(&mut slice, columns) { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - return Ok(Some(value)); - } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), + } + Some(columns) => { + match rowbinary::deserialize_from_rbwnat(&mut slice, columns) { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); } + Err(Error::NotEnoughData) => {} + Err(err) => return Err(err), } - }, - } + } + }, } + // } match self.raw.next().await? { Some(chunk) => self.bytes.extend(chunk), diff --git a/src/error.rs b/src/error.rs index 852f3d5b..acd5927e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -57,8 +57,6 @@ pub enum Error { "too many struct fields: trying to read more columns than expected {0}. All columns: {1:?}" )] TooManyStructFields(usize, Vec), - #[error("deserialization error: {0}")] - DeserializationError(#[source] BoxedError), #[error("deserialize is called for more fields than a struct has")] DeserializeCallAfterEndOfStruct, #[error("{0}")] diff --git a/src/output_format.rs b/src/output_format.rs index 68dbd10f..4f2121ca 100644 --- a/src/output_format.rs +++ b/src/output_format.rs @@ -10,3 +10,14 @@ impl Default for OutputFormat { Self::RowBinary } } + +impl std::fmt::Display for OutputFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OutputFormat::RowBinary => write!(f, "RowBinary"), + OutputFormat::RowBinaryWithNamesAndTypes => { + write!(f, "RowBinaryWithNamesAndTypes") + } + } + } +} diff --git a/src/rowbinary/de_rbwnat.rs b/src/rowbinary/de_rbwnat.rs index 7a5bed63..9d92f2cd 100644 --- a/src/rowbinary/de_rbwnat.rs +++ b/src/rowbinary/de_rbwnat.rs @@ -11,9 +11,11 @@ pub(crate) fn deserialize_from_rbwnat<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], columns: &'cursor [Column], ) -> Result { - println!("[RBWNAT] deserializing with names and types: {:?}", columns); + // println!("[RBWNAT] deserializing with names and types: {:?}, input size: {}", columns, input.len()); let mut deserializer = RowBinaryWithNamesAndTypesDeserializer::new(input, columns)?; - T::deserialize(&mut deserializer) + let value = T::deserialize(&mut deserializer); + // println!("Remaining input size: {}", input.len()); + value } /// Serde method that delegated the value deserialization to [`Deserializer::deserialize_any`]. @@ -107,16 +109,11 @@ enum DeserializerState<'cursor> { prev_state: Rc>, inner_data_type: &'cursor DataTypeNode, }, - /// Verifying struct fields usually does not make sense more than once. - VerifiedInnerType { - inner_data_type: &'cursor DataTypeNode, - prev_state: Rc>, - }, /// We are done with all columns and should not try to deserialize anything else. EndOfStruct, } -struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { +pub(crate) struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { row_binary: RowBinaryDeserializer<'cursor, 'data>, state: DeserializerState<'cursor>, columns: &'cursor [Column], @@ -125,6 +122,7 @@ struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { last_delegated_from: DelegatedFrom, // every deserialization begins from a struct with some name struct_name: Option<&'static str>, + struct_fields: Option<&'static [&'static str]>, } impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { @@ -138,11 +136,12 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { panic!("Zero columns definitions in the response"); } Ok(RowBinaryWithNamesAndTypesDeserializer { - row_binary: crate::rowbinary::de::RowBinaryDeserializer { input }, + row_binary: RowBinaryDeserializer { input }, state: DeserializerState::TopLevelColumn(&columns[0]), last_delegated_from: DelegatedFrom::default(), current_column_idx: 0, struct_name: None, + struct_fields: None, columns, }) } @@ -156,12 +155,23 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { #[inline] fn set_struct_name(&mut self, name: &'static str) { - self.struct_name = Some(name); + // TODO: nested structs support? + if self.struct_name.is_none() { + self.struct_name = Some(name); + } + } + + #[inline] + fn set_struct_fields(&mut self, fields: &'static [&'static str]) { + // TODO: nested structs support? + if self.struct_fields.is_none() { + self.struct_fields = Some(fields); + } } #[inline] fn advance_state(&mut self) -> Result<()> { - match self.state { + match &self.state { DeserializerState::TopLevelColumn { .. } => { self.current_column_idx += 1; if self.current_column_idx >= self.columns.len() { @@ -171,14 +181,6 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { self.state = DeserializerState::TopLevelColumn(current_col); } } - DeserializerState::InnerDataType { - inner_data_type, .. - } => { - self.state = DeserializerState::VerifiedInnerType { - prev_state: Rc::new(self.state.clone()), - inner_data_type, - }; - } DeserializerState::EndOfStruct => { panic!("trying to advance the current column index after full deserialization"); } @@ -189,17 +191,16 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { } #[inline] - fn set_inner_data_type_state( - &self, - inner_data_type: &'cursor DataTypeNode, - ) -> DeserializerState<'cursor> { + fn set_inner_data_type_state(&mut self, inner_data_type: &'cursor DataTypeNode) { match self.state { DeserializerState::TopLevelColumn(column, ..) - | DeserializerState::InnerDataType { column, .. } => DeserializerState::InnerDataType { - prev_state: Rc::new(self.state.clone()), - inner_data_type, - column, - }, + | DeserializerState::InnerDataType { column, .. } => { + self.state = DeserializerState::InnerDataType { + prev_state: Rc::new(self.state.clone()), + inner_data_type, + column, + } + } _ => { panic!("to_inner called on invalid state"); } @@ -209,8 +210,7 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { #[inline] fn set_previous_state(&mut self) { match &self.state { - DeserializerState::InnerDataType { prev_state, .. } - | DeserializerState::VerifiedInnerType { prev_state, .. } => { + DeserializerState::InnerDataType { prev_state, .. } => { self.state = prev_state.deref().clone() } _ => panic!("to_prev_state called on invalid state"), @@ -224,9 +224,6 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { DeserializerState::InnerDataType { inner_data_type, .. } => Ok(inner_data_type), - DeserializerState::VerifiedInnerType { - inner_data_type, .. - } => Ok(inner_data_type), DeserializerState::EndOfStruct => Err(Error::DeserializeCallAfterEndOfStruct), } } @@ -242,6 +239,44 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { let col = &self.columns[self.current_column_idx]; Ok(col) } + + #[inline] + fn check_data_type_is_allowed(&mut self, allowed: &[DelegatedFrom]) -> Result<()> { + if !allowed.contains(&self.last_delegated_from) { + let column = self.get_current_column()?; + let field_name = match self.struct_name { + Some(struct_name) => format!("{}.{}", struct_name, column.name), + None => column.name.to_string(), + }; + let allowed_types = allowed + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "); + let all_columns = self + .columns + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "); + let unexpected_type = self.last_delegated_from.to_string(); + Err(Error::DataTypeMismatch { + field_name, + allowed_types, + unexpected_type, + all_columns, + }) + } else { + Ok(()) + } + } +} + +macro_rules! rbwnat_deserialize_any { + ($self:ident, $delegated_from:expr, $visitor:ident) => {{ + $self.set_last_delegated_from($delegated_from); + $self.deserialize_any($visitor) + }}; } impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer<'_, 'data> { @@ -254,31 +289,12 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< { macro_rules! rbwnat_de_simple_with_type_check { ($delegate:ident, $compatible:expr) => {{ - if !$compatible.contains(&self.last_delegated_from) { - let column = self.get_current_column()?; - let field_name = match self.struct_name { - Some(struct_name) => format!("{}.{}", struct_name, column.name), - None => column.name.to_string(), - }; - let allowed_types = $compatible.map(|x| x.to_string()).join(", "); - let all_columns = self - .columns - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "); - let unexpected_type = self.last_delegated_from.to_string(); - return Err(Error::DataTypeMismatch { - field_name, - allowed_types, - unexpected_type, - all_columns, - }); - } + self.check_data_type_is_allowed(&$compatible)?; self.row_binary.$delegate(visitor) }}; } + println!("{} state: {:?}", self.last_delegated_from, self.state); let data_type = self.get_current_data_type()?; let result = match data_type { DataTypeNode::Bool => rbwnat_de_simple_with_type_check!( @@ -327,7 +343,36 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< [DelegatedFrom::Str, DelegatedFrom::String] ) } + DataTypeNode::FixedString(len) => match self.last_delegated_from { + DelegatedFrom::Bytes => visitor.visit_bytes(self.row_binary.read_slice(*len)?), + DelegatedFrom::ByteBuf => visitor.visit_byte_buf(self.row_binary.read_vec(*len)?), + _ => unreachable!(), + }, + DataTypeNode::UUID => { + rbwnat_de_simple_with_type_check!( + deserialize_str, + [DelegatedFrom::Str, DelegatedFrom::String] + ) + } + DataTypeNode::Date => { + rbwnat_de_simple_with_type_check!(deserialize_u16, [DelegatedFrom::U16]) + } + DataTypeNode::Date32 => { + rbwnat_de_simple_with_type_check!(deserialize_i32, [DelegatedFrom::I32]) + } + DataTypeNode::DateTime { .. } => { + rbwnat_de_simple_with_type_check!(deserialize_u32, [DelegatedFrom::U32]) + } + DataTypeNode::DateTime64 { .. } => { + rbwnat_de_simple_with_type_check!(deserialize_i64, [DelegatedFrom::I64]) + } + DataTypeNode::IPv4 => { + rbwnat_de_simple_with_type_check!(deserialize_u32, [DelegatedFrom::U32]) + } + DataTypeNode::IPv6 => self.row_binary.deserialize_tuple(16, visitor), + DataTypeNode::Array(inner_type) => { + self.check_data_type_is_allowed(&[DelegatedFrom::Seq])?; let len = self.row_binary.read_size()?; self.set_inner_data_type_state(inner_type); @@ -368,14 +413,17 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< self.set_previous_state(); result } + // DataTypeNode::Nullable(inner_type) => { + // self.check_data_type_is_allowed(&[DelegatedFrom::Option])?; + // self.set_inner_data_type_state(inner_type); + // self.row_binary.deserialize_option(visitor) + // }, _ => panic!("unsupported type for deserialize_any: {:?}", self.columns), }; - result - .map_err(|e| Error::DeserializationError(Box::new(e))) - .and_then(|value| { - self.advance_state()?; - Ok(value) - }) + result.and_then(|value| { + self.advance_state()?; + Ok(value) + }) } #[inline] @@ -383,8 +431,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Bool); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Bool, visitor) } #[inline] @@ -392,8 +439,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::I8); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::I8, visitor) } #[inline] @@ -401,8 +447,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::I16); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::I16, visitor) } #[inline] @@ -410,8 +455,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::I32); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::I32, visitor) } #[inline] @@ -419,8 +463,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::I64); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::I64, visitor) } #[inline] @@ -428,8 +471,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::I128); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::I128, visitor) } #[inline] @@ -437,8 +479,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::U8); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::U8, visitor) } #[inline] @@ -446,8 +487,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::U16); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::U16, visitor) } #[inline] @@ -455,8 +495,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::U32); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::U32, visitor) } #[inline] @@ -464,8 +503,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::U64); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::U64, visitor) } #[inline] @@ -473,8 +511,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::U128); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::U128, visitor) } #[inline] @@ -482,8 +519,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::F32); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::F32, visitor) } #[inline] @@ -491,8 +527,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::F64); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::F64, visitor) } #[inline] @@ -500,8 +535,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Char); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Char, visitor) } #[inline] @@ -509,8 +543,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Str); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Str, visitor) } #[inline] @@ -518,8 +551,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::String); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::String, visitor) } #[inline] @@ -527,8 +559,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Bytes); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Bytes, visitor) } #[inline] @@ -536,8 +567,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::ByteBuf); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::ByteBuf, visitor) } #[inline] @@ -545,8 +575,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Option); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Option, visitor) } #[inline] @@ -554,8 +583,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Unit); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Unit, visitor) } #[inline] @@ -567,8 +595,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::UnitStruct); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::UnitStruct, visitor) } #[inline] @@ -580,8 +607,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::NewtypeStruct); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::NewtypeStruct, visitor) } #[inline] @@ -589,8 +615,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Seq); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Seq, visitor) } #[inline] @@ -602,8 +627,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Tuple); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Tuple, visitor) } #[inline] @@ -616,8 +640,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::TupleStruct); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::TupleStruct, visitor) } #[inline] @@ -625,8 +648,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Map); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Map, visitor) } #[inline] @@ -666,7 +688,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< } self.set_struct_name(name); - self.set_last_delegated_from(DelegatedFrom::Struct); + self.set_struct_fields(fields); visitor.visit_seq(StructAccess { deserializer: self, len: fields.len(), @@ -683,8 +705,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Enum); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Enum, visitor) } #[inline] @@ -692,8 +713,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::Identifier); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::Identifier, visitor) } #[inline] @@ -701,8 +721,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - self.set_last_delegated_from(DelegatedFrom::IgnoredAny); - self.deserialize_any(visitor) + rbwnat_deserialize_any!(self, DelegatedFrom::IgnoredAny, visitor) } #[inline] diff --git a/tests/it/rbwnat_smoke.rs b/tests/it/rbwnat_smoke.rs index 4061963b..128b2e13 100644 --- a/tests/it/rbwnat_smoke.rs +++ b/tests/it/rbwnat_smoke.rs @@ -4,6 +4,9 @@ use clickhouse_derive::Row; use clickhouse_rowbinary::parse_columns_header; use clickhouse_rowbinary::types::{Column, DataTypeNode}; use serde::{Deserialize, Serialize}; +use std::str::FromStr; +use time::format_description::well_known::Iso8601; +use time::Month::{February, January}; use time::OffsetDateTime; #[tokio::test] @@ -160,6 +163,59 @@ async fn test_basic_types_deserialization() { ); } +#[tokio::test] +async fn test_several_simple_rows() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + num: u64, + str: String, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query("SELECT number AS num, toString(number) AS str FROM system.numbers LIMIT 3") + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { + num: 0, + str: "0".to_string(), + }, + Data { + num: 1, + str: "1".to_string(), + }, + Data { + num: 2, + str: "2".to_string(), + }, + ] + ); +} + +#[tokio::test] +async fn test_many_numbers() { + #[derive(Row, Deserialize)] + struct Data { + no: u64, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let mut cursor = client + .query("SELECT number FROM system.numbers_mt LIMIT 2000") + .fetch::() + .unwrap(); + + let mut sum = 0; + while let Some(row) = cursor.next().await.unwrap() { + sum += row.no; + } + assert_eq!(sum, (0..2000).sum::()); +} + #[tokio::test] async fn test_array_deserialization() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -176,11 +232,11 @@ async fn test_array_deserialization() { .query( " SELECT - 42 :: UInt16 AS id, - [1, 2] :: Array(UInt32) AS one_dim_array, - [[1, 2], [3, 4]] :: Array(Array(Int64)) AS two_dim_array, + 42 :: UInt16 AS id, + [1, 2] :: Array(UInt32) AS one_dim_array, + [[1, 2], [3, 4]] :: Array(Array(Int64)) AS two_dim_array, [[[1.1, 2.2], [3.3, 4.4]], [], [[5.5, 6.6], [7.7, 8.8]]] :: Array(Array(Array(Float64))) AS three_dim_array, - 'foobar' :: String AS description + 'foobar' :: String AS description ", ) .fetch_one::() @@ -309,26 +365,119 @@ async fn test_serde_skip_deserializing() { ); } +#[tokio::test] +#[cfg(feature = "time")] +async fn test_date_time_types() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + #[serde(with = "clickhouse::serde::time::date")] + date: time::Date, + #[serde(with = "clickhouse::serde::time::date32")] + date32: time::Date, + #[serde(with = "clickhouse::serde::time::datetime")] + date_time: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::secs")] + date_time64_0: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::millis")] + date_time64_3: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::micros")] + date_time64_6: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::nanos")] + date_time64_9: OffsetDateTime, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query( + " + SELECT + '2023-01-01' :: Date AS date, + '2023-02-02' :: Date32 AS date32, + '2023-01-03 12:00:00' :: DateTime AS date_time, + '2023-01-04 13:00:00' :: DateTime64(0) AS date_time64_0, + '2023-01-05 14:00:00.123' :: DateTime64(3) AS date_time64_3, + '2023-01-06 15:00:00.123456' :: DateTime64(6) AS date_time64_6, + '2023-01-07 16:00:00.123456789' :: DateTime64(9) AS date_time64_9 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + date: time::Date::from_calendar_date(2023, January, 1).unwrap(), + date32: time::Date::from_calendar_date(2023, February, 2).unwrap(), + date_time: OffsetDateTime::parse("2023-01-03T12:00:00Z", &Iso8601::DEFAULT).unwrap(), + date_time64_0: OffsetDateTime::parse("2023-01-04T13:00:00Z", &Iso8601::DEFAULT) + .unwrap(), + date_time64_3: OffsetDateTime::parse("2023-01-05T14:00:00.123Z", &Iso8601::DEFAULT) + .unwrap(), + date_time64_6: OffsetDateTime::parse("2023-01-06T15:00:00.123456Z", &Iso8601::DEFAULT) + .unwrap(), + date_time64_9: OffsetDateTime::parse( + "2023-01-07T16:00:00.123456789Z", + &Iso8601::DEFAULT + ) + .unwrap(), + } + ); +} + +#[tokio::test] +async fn test_ipv4_ipv6() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u16, + #[serde(with = "clickhouse::serde::ipv4")] + ipv4: std::net::Ipv4Addr, + ipv6: std::net::Ipv6Addr, + } + + let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let result = client + .query( + " + SELECT + 42 :: UInt16 AS id, + '192.168.0.1' :: IPv4 AS ipv4, + '2001:db8:3333:4444:5555:6666:7777:8888' :: IPv6 AS ipv6 + ", + ) + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![Data { + id: 42, + ipv4: std::net::Ipv4Addr::new(192, 168, 0, 1), + ipv6: std::net::Ipv6Addr::from_str("2001:db8:3333:4444:5555:6666:7777:8888").unwrap(), + }] + ) +} + // FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! #[tokio::test] +#[ignore] async fn test_different_struct_field_order() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { c: String, - a: u32, + a: String, } let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); let result = client - .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS c") + .query("SELECT 'foo' AS a, 'bar' :: String AS c") .fetch_one::() .await; assert_eq!( result.unwrap(), Data { - a: 42, - c: "foo".to_string(), + a: "foo".to_string(), + c: "bar".to_string(), } ); } From b338d88348306fa5e9b7908ce5fc30276827268b Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Sun, 18 May 2025 18:51:36 +0200 Subject: [PATCH 05/54] RBWNAT deserializer - validation WIP --- rowbinary/Cargo.toml | 2 +- rowbinary/src/lib.rs | 4 +- rowbinary/src/types.rs | 247 ++++++++++++++++++++++++-- src/cursors/row.rs | 74 ++++---- src/error.rs | 2 + src/lib.rs | 16 +- src/output_format.rs | 23 --- src/query.rs | 10 +- src/rowbinary/de.rs | 354 +++++++++++++++++++++++++++++++++---- src/rowbinary/de_rbwnat.rs | 204 +++++++-------------- src/rowbinary/mod.rs | 83 ++++++++- src/validation_mode.rs | 23 +++ tests/it/rbwnat_smoke.rs | 104 +++++++---- 13 files changed, 858 insertions(+), 288 deletions(-) delete mode 100644 src/output_format.rs create mode 100644 src/validation_mode.rs diff --git a/rowbinary/Cargo.toml b/rowbinary/Cargo.toml index 59bebd3b..b1dd76c1 100644 --- a/rowbinary/Cargo.toml +++ b/rowbinary/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "clickhouse-rowbinary" version = "0.0.1" -description = "RowBinaryWithNamesAndTypes format utils" +description = "Native and RowBinary(WithNamesAndTypes) format utils" authors = ["ClickHouse"] repository = "https://github.com/ClickHouse/clickhouse-rs" homepage = "https://clickhouse.com" diff --git a/rowbinary/src/lib.rs b/rowbinary/src/lib.rs index d0b79740..9e793f48 100644 --- a/rowbinary/src/lib.rs +++ b/rowbinary/src/lib.rs @@ -8,7 +8,7 @@ pub mod error; pub mod leb128; pub mod types; -pub fn parse_columns_header(bytes: &mut &[u8]) -> Result, ParserError> { +pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, ParserError> { let num_columns = decode_leb128(bytes)?; if num_columns == 0 { return Err(ParserError::HeaderParsingError( @@ -29,7 +29,7 @@ pub fn parse_columns_header(bytes: &mut &[u8]) -> Result, ParserErro let columns = columns_names .into_iter() .zip(column_data_types) - .map(|(name, data_type)| Column { name, data_type }) + .map(|(name, data_type)| Column::new(name, data_type)) .collect(); Ok(columns) } diff --git a/rowbinary/src/types.rs b/rowbinary/src/types.rs index bcbdbc34..241c5174 100644 --- a/rowbinary/src/types.rs +++ b/rowbinary/src/types.rs @@ -6,11 +6,17 @@ use std::fmt::{Display, Formatter}; pub struct Column { pub name: String, pub data_type: DataTypeNode, + pub type_hints: Vec, } impl Column { pub fn new(name: String, data_type: DataTypeNode) -> Self { - Self { name, data_type } + let type_hints = data_type.get_type_hints(); + Self { + name, + data_type, + type_hints, + } } } @@ -72,6 +78,125 @@ pub enum DataTypeNode { // TODO: Geo } +// TODO - should be the same top-levels as DataTypeNode; +// gen from DataTypeNode via macro maybe? +#[derive(Debug, Clone, PartialEq)] +#[non_exhaustive] +pub enum DataTypeHint { + Bool, + + UInt8, + UInt16, + UInt32, + UInt64, + UInt128, + UInt256, + + Int8, + Int16, + Int32, + Int64, + Int128, + Int256, + + Float32, + Float64, + BFloat16, + Decimal(DecimalSize), + + String, + FixedString(usize), + UUID, + + Date, + Date32, + DateTime, + DateTime64, + + IPv4, + IPv6, + + Nullable, + LowCardinality, + + Array, + Tuple, + Map, + Enum, + + AggregateFunction, + + Variant, + Dynamic, + JSON, + // TODO: Geo +} + +impl Display for DataTypeHint { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DataTypeHint::Bool => write!(f, "Bool"), + DataTypeHint::UInt8 => write!(f, "UInt8"), + DataTypeHint::UInt16 => write!(f, "UInt16"), + DataTypeHint::UInt32 => write!(f, "UInt32"), + DataTypeHint::UInt64 => write!(f, "UInt64"), + DataTypeHint::UInt128 => write!(f, "UInt128"), + DataTypeHint::UInt256 => write!(f, "UInt256"), + DataTypeHint::Int8 => write!(f, "Int8"), + DataTypeHint::Int16 => write!(f, "Int16"), + DataTypeHint::Int32 => write!(f, "Int32"), + DataTypeHint::Int64 => write!(f, "Int64"), + DataTypeHint::Int128 => write!(f, "Int128"), + DataTypeHint::Int256 => write!(f, "Int256"), + DataTypeHint::Float32 => write!(f, "Float32"), + DataTypeHint::Float64 => write!(f, "Float64"), + DataTypeHint::BFloat16 => write!(f, "BFloat16"), + DataTypeHint::Decimal(size) => write!(f, "Decimal{}", size), + DataTypeHint::String => write!(f, "String"), + DataTypeHint::FixedString(size) => write!(f, "FixedString({})", size), + DataTypeHint::UUID => write!(f, "UUID"), + DataTypeHint::Date => write!(f, "Date"), + DataTypeHint::Date32 => write!(f, "Date32"), + DataTypeHint::DateTime => write!(f, "DateTime"), + DataTypeHint::DateTime64 => write!(f, "DateTime64"), + DataTypeHint::IPv4 => write!(f, "IPv4"), + DataTypeHint::IPv6 => write!(f, "IPv6"), + DataTypeHint::Nullable => write!(f, "Nullable"), + DataTypeHint::LowCardinality => write!(f, "LowCardinality"), + DataTypeHint::Array => { + write!(f, "Array") + } + DataTypeHint::Tuple => { + write!(f, "Tuple") + } + DataTypeHint::Map => { + write!(f, "Map") + } + DataTypeHint::Enum => { + write!(f, "Enum") + } + DataTypeHint::AggregateFunction => { + write!(f, "AggregateFunction") + } + DataTypeHint::Variant => { + write!(f, "Variant") + } + DataTypeHint::Dynamic => { + write!(f, "Dynamic") + } + DataTypeHint::JSON => { + write!(f, "JSON") + } + } + } +} + +impl Into for DataTypeHint { + fn into(self) -> String { + self.to_string() + } +} + macro_rules! data_type_is { ($method:ident, $pattern:pat) => { #[inline] @@ -138,6 +263,91 @@ impl DataTypeNode { } } + pub fn get_type_hints_internal(&self, hints: &mut Vec) { + match self { + DataTypeNode::Bool => hints.push(DataTypeHint::Bool), + DataTypeNode::UInt8 => hints.push(DataTypeHint::UInt8), + DataTypeNode::UInt16 => hints.push(DataTypeHint::UInt16), + DataTypeNode::UInt32 => hints.push(DataTypeHint::UInt32), + DataTypeNode::UInt64 => hints.push(DataTypeHint::UInt64), + DataTypeNode::UInt128 => hints.push(DataTypeHint::UInt128), + DataTypeNode::UInt256 => hints.push(DataTypeHint::UInt256), + DataTypeNode::Int8 => hints.push(DataTypeHint::Int8), + DataTypeNode::Int16 => hints.push(DataTypeHint::Int16), + DataTypeNode::Int32 => hints.push(DataTypeHint::Int32), + DataTypeNode::Int64 => hints.push(DataTypeHint::Int64), + DataTypeNode::Int128 => hints.push(DataTypeHint::Int128), + DataTypeNode::Int256 => hints.push(DataTypeHint::Int256), + DataTypeNode::Float32 => hints.push(DataTypeHint::Float32), + DataTypeNode::Float64 => hints.push(DataTypeHint::Float64), + DataTypeNode::BFloat16 => hints.push(DataTypeHint::BFloat16), + DataTypeNode::Decimal(_, _, size) => { + hints.push(DataTypeHint::Decimal(size.clone())); + } + DataTypeNode::String => hints.push(DataTypeHint::String), + DataTypeNode::FixedString(size) => hints.push(DataTypeHint::FixedString(*size)), + DataTypeNode::UUID => hints.push(DataTypeHint::UUID), + DataTypeNode::Date => hints.push(DataTypeHint::Date), + DataTypeNode::Date32 => hints.push(DataTypeHint::Date32), + DataTypeNode::DateTime(_) => hints.push(DataTypeHint::DateTime), + DataTypeNode::DateTime64(_, _) => hints.push(DataTypeHint::DateTime64), + DataTypeNode::IPv4 => hints.push(DataTypeHint::IPv4), + DataTypeNode::IPv6 => hints.push(DataTypeHint::IPv6), + DataTypeNode::Nullable(inner) => { + hints.push(DataTypeHint::Nullable); + inner.get_type_hints_internal(hints); + } + DataTypeNode::LowCardinality(inner) => { + hints.push(DataTypeHint::LowCardinality); + inner.get_type_hints_internal(hints); + } + DataTypeNode::Array(inner) => { + hints.push(DataTypeHint::Array); + inner.get_type_hints_internal(hints); + } + DataTypeNode::Tuple(elements) => { + hints.push(DataTypeHint::Tuple); + for element in elements { + element.get_type_hints_internal(hints); + } + } + DataTypeNode::Map(key, value) => { + hints.push(DataTypeHint::Map); + key.get_type_hints_internal(hints); + value.get_type_hints_internal(hints); + } + DataTypeNode::Enum(_, _) => hints.push(DataTypeHint::Enum), + DataTypeNode::AggregateFunction(_, args) => { + hints.push(DataTypeHint::AggregateFunction); + for arg in args { + arg.get_type_hints_internal(hints); + } + } + DataTypeNode::Variant(types) => { + hints.push(DataTypeHint::Variant); + for ty in types { + ty.get_type_hints_internal(hints); + } + } + DataTypeNode::Dynamic => hints.push(DataTypeHint::Dynamic), + DataTypeNode::JSON => hints.push(DataTypeHint::JSON), + } + } + + pub fn get_type_hints(&self) -> Vec { + let capacity = match self { + DataTypeNode::Tuple(elements) | DataTypeNode::Variant(elements) => elements.len() + 1, + DataTypeNode::Map(_, _) => 3, + DataTypeNode::Nullable(_) + | DataTypeNode::LowCardinality(_) + | DataTypeNode::Array(_) => 2, + _ => 1, + }; + let mut vec = Vec::with_capacity(capacity); + self.get_type_hints_internal(&mut vec); + vec + } + data_type_is!(is_bool, DataTypeNode::Bool); data_type_is!(is_uint8, DataTypeNode::UInt8); data_type_is!(is_uint16, DataTypeNode::UInt16); @@ -201,6 +411,9 @@ impl Display for DataTypeNode { Float32 => "Float32".to_string(), Float64 => "Float64".to_string(), BFloat16 => "BFloat16".to_string(), + Decimal(precision, scale, _) => { + format!("Decimal({}, {})", precision, scale) + } String => "String".to_string(), UUID => "UUID".to_string(), Date => "Date".to_string(), @@ -224,9 +437,6 @@ impl Display for DataTypeNode { LowCardinality(inner) => { format!("LowCardinality({})", inner.to_string()) } - Decimal(precision, scale, _) => { - format!("Decimal({}, {})", precision, scale) - } Enum(enum_type, values) => { let mut values_vec = values.iter().collect::>(); values_vec.sort_by(|(i1, _), (i2, _)| (*i1).cmp(*i2)); @@ -313,6 +523,17 @@ pub enum DecimalSize { Int256, } +impl Display for DecimalSize { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DecimalSize::Int32 => write!(f, "32"), + DecimalSize::Int64 => write!(f, "64"), + DecimalSize::Int128 => write!(f, "128"), + DecimalSize::Int256 => write!(f, "256"), + } + } +} + impl DecimalSize { pub(crate) fn new(precision: u8) -> Result { if precision <= 9 { @@ -361,11 +582,11 @@ fn parse_fixed_string(input: &str) -> Result { if input.len() >= 14 { let size_str = &input[12..input.len() - 1]; let size = size_str.parse::().map_err(|err| { - ParserError::TypeParsingError(format!( - "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", - err, input, size_str - )) - })?; + ParserError::TypeParsingError(format!( + "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", + err, input, size_str + )) + })?; if size == 0 { return Err(ParserError::TypeParsingError(format!( "Invalid FixedString size, expected a positive number, got zero. Input: {}", @@ -728,10 +949,10 @@ fn parse_enum_values_map(input: &str) -> Result, ParserErro if names.len() != indices.len() { return Err(ParserError::TypeParsingError(format!( - "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", - names.join(", "), - indices.iter().map(|index| index.to_string()).collect::>().join(", "), - ))); + "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", + names.join(", "), + indices.iter().map(|index| index.to_string()).collect::>().join(", "), + ))); } Ok(indices diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 0e8e8218..d53b4ed3 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,4 +1,4 @@ -use crate::output_format::OutputFormat; +use crate::validation_mode::StructValidationMode; use crate::{ bytes_ext::BytesExt, cursors::RawCursor, @@ -6,7 +6,7 @@ use crate::{ response::Response, rowbinary, }; -use clickhouse_rowbinary::parse_columns_header; +use clickhouse_rowbinary::parse_rbwnat_columns_header; use clickhouse_rowbinary::types::Column; use serde::Deserialize; use std::marker::PhantomData; @@ -16,19 +16,21 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, - format: OutputFormat, + validation_mode: StructValidationMode, columns: Option>, + rows_emitted: u64, _marker: PhantomData, } impl RowCursor { - pub(crate) fn new(response: Response, format: OutputFormat) -> Self { + pub(crate) fn new(response: Response, format: StructValidationMode) -> Self { Self { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), columns: None, - format, + rows_emitted: 0, + validation_mode: format, } } @@ -43,39 +45,42 @@ impl RowCursor { where T: Deserialize<'b>, { + let should_validate = match self.validation_mode { + StructValidationMode::Disabled => false, + StructValidationMode::EachRow => true, + StructValidationMode::FirstRow => self.rows_emitted == 0, + }; + loop { - let mut slice = super::workaround_51132(self.bytes.slice()); - match self.format { - OutputFormat::RowBinary => match rowbinary::deserialize_from(&mut slice) { + if self.bytes.remaining() > 0 { + let mut slice = super::workaround_51132(self.bytes.slice()); + let deserialize_result = if should_validate { + match &self.columns { + None => { + let columns = parse_rbwnat_columns_header(&mut slice)?; + self.bytes.set_remaining(slice.len()); + self.columns = Some(columns); + let columns = self.columns.as_ref().unwrap(); + rowbinary::deserialize_from_and_validate(&mut slice, columns) + } + Some(columns) => { + rowbinary::deserialize_from_and_validate(&mut slice, &columns) + } + } + } else { + rowbinary::deserialize_from(&mut slice) + }; + + match deserialize_result { Ok(value) => { self.bytes.set_remaining(slice.len()); + self.rows_emitted += 1; return Ok(Some(value)); } Err(Error::NotEnoughData) => {} Err(err) => return Err(err), - }, - OutputFormat::RowBinaryWithNamesAndTypes => match self.columns.as_ref() { - // FIXME: move this branch to new? - None => { - if slice.len() > 0 { - let columns = parse_columns_header(&mut slice)?; - self.bytes.set_remaining(slice.len()); - self.columns = Some(columns); - } - } - Some(columns) => { - match rowbinary::deserialize_from_rbwnat(&mut slice, columns) { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - return Ok(Some(value)); - } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), - } - } - }, + } } - // } match self.raw.next().await? { Some(chunk) => self.bytes.extend(chunk), @@ -99,10 +104,15 @@ impl RowCursor { self.raw.received_bytes() } - /// Returns the total size in bytes decompressed since the cursor was - /// created. + /// Returns the total size in bytes decompressed since the cursor was created. #[inline] pub fn decoded_bytes(&self) -> u64 { self.raw.decoded_bytes() } + + /// Returns the number of rows emitted via [`Self::next`] since the cursor was created. + #[inline] + pub fn rows_emitted(&self) -> u64 { + self.rows_emitted + } } diff --git a/src/error.rs b/src/error.rs index acd5927e..5d191f60 100644 --- a/src/error.rs +++ b/src/error.rs @@ -53,6 +53,8 @@ pub enum Error { unexpected_type: String, all_columns: String, }, + #[error("invalid column data type: {0}")] + InvalidColumnDataType(String), #[error( "too many struct fields: trying to read more columns than expected {0}. All columns: {1:?}" )] diff --git a/src/lib.rs b/src/lib.rs index 2eee5f78..a6863f5d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,19 +9,19 @@ use self::{error::Result, http_client::HttpClient}; use std::{collections::HashMap, fmt::Display, sync::Arc}; pub use self::{compression::Compression, row::Row}; -use crate::output_format::OutputFormat; +use crate::validation_mode::StructValidationMode; pub use clickhouse_derive::Row; pub mod error; pub mod insert; #[cfg(feature = "inserter")] pub mod inserter; -pub mod output_format; pub mod query; pub mod serde; pub mod sql; #[cfg(feature = "test-util")] pub mod test; +pub mod validation_mode; #[cfg(feature = "watch")] pub mod watch; @@ -50,7 +50,7 @@ pub struct Client { options: HashMap, headers: HashMap, products_info: Vec, - fetch_format: OutputFormat, + struct_validation_mode: StructValidationMode, } #[derive(Clone)] @@ -86,7 +86,7 @@ impl Client { options: HashMap::new(), headers: HashMap::new(), products_info: Vec::default(), - fetch_format: OutputFormat::default(), + struct_validation_mode: StructValidationMode::default(), } } @@ -226,15 +226,11 @@ impl Client { self } - pub fn with_fetch_format(mut self, format: OutputFormat) -> Self { - self.fetch_format = format; + pub fn with_struct_validation_mode(mut self, mode: StructValidationMode) -> Self { + self.struct_validation_mode = mode; self } - pub fn get_fetch_format(&self) -> OutputFormat { - self.fetch_format.clone() - } - /// Starts a new INSERT statement. /// /// # Panics diff --git a/src/output_format.rs b/src/output_format.rs deleted file mode 100644 index 4f2121ca..00000000 --- a/src/output_format.rs +++ /dev/null @@ -1,23 +0,0 @@ -#[non_exhaustive] -#[derive(Clone)] -pub enum OutputFormat { - RowBinary, - RowBinaryWithNamesAndTypes, -} - -impl Default for OutputFormat { - fn default() -> Self { - Self::RowBinary - } -} - -impl std::fmt::Display for OutputFormat { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - OutputFormat::RowBinary => write!(f, "RowBinary"), - OutputFormat::RowBinaryWithNamesAndTypes => { - write!(f, "RowBinaryWithNamesAndTypes") - } - } - } -} diff --git a/src/query.rs b/src/query.rs index af8743c7..dfdff889 100644 --- a/src/query.rs +++ b/src/query.rs @@ -16,7 +16,7 @@ use crate::{ const MAX_QUERY_LEN_TO_USE_GET: usize = 8192; pub use crate::cursors::{BytesCursor, RowCursor}; -use crate::output_format::OutputFormat; +use crate::validation_mode::StructValidationMode; #[must_use] #[derive(Clone)] @@ -84,12 +84,14 @@ impl Query { /// # Ok(()) } /// ``` pub fn fetch(mut self) -> Result> { - let fetch_format = self.client.get_fetch_format(); + let fetch_format = self.client.struct_validation_mode.clone(); self.sql.bind_fields::(); self.sql.set_output_format(match fetch_format { - OutputFormat::RowBinary => "RowBinary", - OutputFormat::RowBinaryWithNamesAndTypes => "RowBinaryWithNamesAndTypes", + StructValidationMode::FirstRow | StructValidationMode::EachRow => { + "RowBinaryWithNamesAndTypes" + } + StructValidationMode::Disabled => "RowBinary", }); let response = self.do_execute(true)?; diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 4a5a2beb..8015636b 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -1,12 +1,13 @@ -use std::{convert::TryFrom, mem, str}; - use crate::error::{Error, Result}; use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; +use crate::rowbinary::SerdeType; use bytes::Buf; +use clickhouse_rowbinary::types::{Column, DataTypeHint}; use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, Deserialize, }; +use std::{convert::TryFrom, mem, str}; /// Deserializes a value from `input` with a row encoded in `RowBinary`. /// @@ -14,18 +15,171 @@ use serde::{ /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even /// `(&[u8], &mut Option) -> Result`. pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data [u8]) -> Result { - let mut deserializer = RowBinaryDeserializer { input }; + let mut deserializer = RowBinaryDeserializer { + input, + columns_validator: (), + }; T::deserialize(&mut deserializer) } +/// Similar to [`deserialize_from`], but expects a slice of [`Column`] objects +/// parsed from the beginning of `RowBinaryWithNamesAndTypes` data stream. +/// After the header, the rows format is the same as `RowBinary`. +pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data>>( + input: &mut &'data [u8], + columns: &'cursor [Column], +) -> Result { + let mut deserializer = RowBinaryDeserializer { + input, + columns_validator: ColumnsValidator { + columns, + col_idx: 0, + type_hint_idx: 0, + nesting_level: 0, + }, + }; + T::deserialize(&mut deserializer) +} + +struct ColumnsValidator<'cursor> { + columns: &'cursor [Column], + col_idx: usize, + type_hint_idx: usize, + nesting_level: usize, +} + +impl<'cursor> ColumnsValidator<'cursor> { + #[inline] + fn advance(&mut self) { + self.col_idx += 1; + self.type_hint_idx = 0; + } +} + +pub(crate) trait ValidateDataType { + fn validate( + &mut self, + serde_type: &'static SerdeType, + allowed: &'static [DataTypeHint], + has_inner_type: bool, + ) -> Result<()>; + fn skip_next(&mut self) -> (); + fn increase_nesting(&mut self) -> (); + fn decrease_nesting(&mut self) -> (); +} + +impl ValidateDataType for () { + #[inline] + fn validate( + &mut self, + _serde_type: &'static SerdeType, + _allowed: &'static [DataTypeHint], + _has_inner_type: bool, + ) -> Result<()> { + Ok(()) + } + #[inline] + fn skip_next(&mut self) -> () {} + #[inline] + fn increase_nesting(&mut self) -> () {} + #[inline] + fn decrease_nesting(&mut self) -> () {} +} + +impl<'cursor> ValidateDataType for ColumnsValidator<'cursor> { + #[inline] + fn validate( + &mut self, + serde_type: &'static SerdeType, + allowed: &'static [DataTypeHint], + has_inner_type: bool, + ) -> Result<()> { + println!( + "Validating column {}, type hint {}, serde type {}, allowed {:?}, nesting level {}", + self.col_idx, self.type_hint_idx, serde_type, allowed, self.nesting_level + ); + if self.col_idx >= self.columns.len() { + return Err(Error::TooManyStructFields( + self.columns.len(), + self.columns.into(), + )); + } + if has_inner_type { + self.nesting_level += 1; + println!("Increased nesting level to {}", self.nesting_level); + } + let current_column = &self.columns[self.col_idx]; + if self.type_hint_idx >= current_column.type_hints.len() { + // if self.nesting_level == 0 { + // println!("Advancing #1"); + // self.advance(); + // } + println!( + "Skipping check for column {}, type hint {}, nesting level {}", + current_column.name, self.type_hint_idx, self.nesting_level + ); + return Ok(()); + } + let db_type_hint = ¤t_column.type_hints[self.type_hint_idx]; + if allowed.contains(db_type_hint) { + // self.type_hint_idx += 1; + if self.nesting_level == 0 { + println!("Advancing #2"); + self.advance(); + } else { + self.type_hint_idx += 1; + } + println!( + "Checked column {}, type hint {}, nesting level {}", + current_column.name, self.type_hint_idx, self.nesting_level + ); + Ok(()) + } else { + Err(Error::InvalidColumnDataType( + format!("column {} {} requires deserialization from {} as {}, but serde call allowed only for {:?}", + current_column.name, current_column.data_type, db_type_hint, serde_type, allowed))) + } + } + + #[inline] + fn skip_next(&mut self) { + self.type_hint_idx += 1; + } + + #[inline] + fn increase_nesting(&mut self) { + self.nesting_level += 1; + } + + #[inline] + fn decrease_nesting(&mut self) { + println!("Decreasing nesting level from {}", self.nesting_level); + if self.nesting_level == 1 { + self.advance(); + } else if self.nesting_level > 0 { + self.nesting_level -= 1; + } else { + panic!("decrease_nesting called when nesting level is already 0, current column index {}, type hint index {}, all columns: {:?}", + self.col_idx, self.type_hint_idx, self.columns); + } + } +} + /// A deserializer for the RowBinary format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. -pub(crate) struct RowBinaryDeserializer<'cursor, 'data> { +pub(crate) struct RowBinaryDeserializer<'cursor, 'data, Columns = ()> +where + Columns: ValidateDataType, +{ + pub(crate) columns_validator: Columns, pub(crate) input: &'cursor mut &'data [u8], } -impl<'data> RowBinaryDeserializer<'_, 'data> { +impl<'data, Columns> RowBinaryDeserializer<'_, 'data, Columns> +where + Columns: ValidateDataType, +{ pub(crate) fn read_vec(&mut self, size: usize) -> Result> { Ok(self.read_slice(size)?.to_vec()) } @@ -45,9 +199,11 @@ impl<'data> RowBinaryDeserializer<'_, 'data> { } macro_rules! impl_num { - ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident) => { + ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr, $type_hints:expr) => { #[inline] fn $deser_method>(self, visitor: V) -> Result { + self.columns_validator + .validate($serde_type, $type_hints, false)?; ensure_size(&mut self.input, mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -55,21 +211,110 @@ macro_rules! impl_num { }; } -impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { +impl<'data, Columns> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data, Columns> +where + Columns: ValidateDataType, +{ type Error = Error; - impl_num!(i8, deserialize_i8, visit_i8, get_i8); - impl_num!(i16, deserialize_i16, visit_i16, get_i16_le); - impl_num!(i32, deserialize_i32, visit_i32, get_i32_le); - impl_num!(i64, deserialize_i64, visit_i64, get_i64_le); - impl_num!(i128, deserialize_i128, visit_i128, get_i128_le); - impl_num!(u8, deserialize_u8, visit_u8, get_u8); - impl_num!(u16, deserialize_u16, visit_u16, get_u16_le); - impl_num!(u32, deserialize_u32, visit_u32, get_u32_le); - impl_num!(u64, deserialize_u64, visit_u64, get_u64_le); - impl_num!(u128, deserialize_u128, visit_u128, get_u128_le); - impl_num!(f32, deserialize_f32, visit_f32, get_f32_le); - impl_num!(f64, deserialize_f64, visit_f64, get_f64_le); + impl_num!( + i8, + deserialize_i8, + visit_i8, + get_i8, + &SerdeType::I8, + // TODO: shall we allow deserialization from boolean? + &[DataTypeHint::Int8, DataTypeHint::Bool] + ); + impl_num!( + i16, + deserialize_i16, + visit_i16, + get_i16_le, + &SerdeType::I16, + &[DataTypeHint::Int16] + ); + impl_num!( + i32, + deserialize_i32, + visit_i32, + get_i32_le, + &SerdeType::I32, + &[DataTypeHint::Int32] + ); + impl_num!( + i64, + deserialize_i64, + visit_i64, + get_i64_le, + &SerdeType::I64, + &[DataTypeHint::Int64] + ); + impl_num!( + i128, + deserialize_i128, + visit_i128, + get_i128_le, + &SerdeType::I128, + &[DataTypeHint::Int128] + ); + impl_num!( + u8, + deserialize_u8, + visit_u8, + get_u8, + &SerdeType::U8, + // TODO: shall we allow deserialization from boolean? + &[DataTypeHint::Bool, DataTypeHint::UInt8] + ); + impl_num!( + u16, + deserialize_u16, + visit_u16, + get_u16_le, + &SerdeType::U16, + &[DataTypeHint::UInt16] + ); + impl_num!( + u32, + deserialize_u32, + visit_u32, + get_u32_le, + &SerdeType::U32, + &[DataTypeHint::UInt32] + ); + impl_num!( + u64, + deserialize_u64, + visit_u64, + get_u64_le, + &SerdeType::U64, + &[DataTypeHint::UInt64] + ); + impl_num!( + u128, + deserialize_u128, + visit_u128, + get_u128_le, + &SerdeType::U128, + &[DataTypeHint::UInt128] + ); + impl_num!( + f32, + deserialize_f32, + visit_f32, + get_f32_le, + &SerdeType::F32, + &[DataTypeHint::Float32] + ); + impl_num!( + f64, + deserialize_f64, + visit_f64, + get_f64_le, + &SerdeType::F64, + &[DataTypeHint::Float64] + ); #[inline] fn deserialize_any>(self, _: V) -> Result { @@ -89,6 +334,12 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { #[inline] fn deserialize_bool>(self, visitor: V) -> Result { + self.columns_validator.validate( + &SerdeType::Bool, + // TODO: shall we allow deserialization from integers? + &[DataTypeHint::Bool, DataTypeHint::Int8, DataTypeHint::UInt8], + false, + )?; ensure_size(&mut self.input, 1)?; match self.input.get_u8() { 0 => visitor.visit_bool(false), @@ -99,6 +350,9 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { #[inline] fn deserialize_str>(self, visitor: V) -> Result { + // TODO - which types to allow? + self.columns_validator + .validate(&SerdeType::String, &[DataTypeHint::String], false)?; let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -107,6 +361,9 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { #[inline] fn deserialize_string>(self, visitor: V) -> Result { + // TODO - which types to allow? + self.columns_validator + .validate(&SerdeType::String, &[DataTypeHint::String], false)?; let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -115,6 +372,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { #[inline] fn deserialize_bytes>(self, visitor: V) -> Result { + // TODO - which types to allow? let size = self.read_size()?; let slice = self.read_slice(size)?; visitor.visit_borrowed_bytes(slice) @@ -122,12 +380,14 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { #[inline] fn deserialize_byte_buf>(self, visitor: V) -> Result { + // TODO - which types to allow? let size = self.read_size()?; visitor.visit_byte_buf(self.read_vec(size)?) } #[inline] fn deserialize_identifier>(self, visitor: V) -> Result { + // TODO - which types to allow? self.deserialize_u8(visitor) } @@ -138,13 +398,22 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { _variants: &'static [&'static str], visitor: V, ) -> Result { - struct Access<'de, 'cursor, 'data> { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>, + struct Access<'de, 'cursor, 'data, Columns> + where + Columns: ValidateDataType, + { + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Columns>, } - struct VariantDeserializer<'de, 'cursor, 'data> { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>, + struct VariantDeserializer<'de, 'cursor, 'data, Columns> + where + Columns: ValidateDataType, + { + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Columns>, } - impl<'data> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data> { + impl<'data, Columns> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data, Columns> + where + Columns: ValidateDataType, + { type Error = Error; fn unit_variant(self) -> Result<()> { @@ -177,9 +446,12 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { } } - impl<'de, 'cursor, 'data> EnumAccess<'data> for Access<'de, 'cursor, 'data> { + impl<'de, 'cursor, 'data, Columns> EnumAccess<'data> for Access<'de, 'cursor, 'data, Columns> + where + Columns: ValidateDataType, + { type Error = Error; - type Variant = VariantDeserializer<'de, 'cursor, 'data>; + type Variant = VariantDeserializer<'de, 'cursor, 'data, Columns>; fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error> where @@ -192,17 +464,25 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { Ok((value, deserializer)) } } + self.columns_validator + .validate(&SerdeType::Enum, &[DataTypeHint::Enum], false)?; visitor.visit_enum(Access { deserializer: self }) } #[inline] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - struct Access<'de, 'cursor, 'data> { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>, + struct Access<'de, 'cursor, 'data, Columns> + where + Columns: ValidateDataType, + { + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Columns>, len: usize, } - impl<'data> SeqAccess<'data> for Access<'_, '_, 'data> { + impl<'data, Columns> SeqAccess<'data> for Access<'_, '_, 'data, Columns> + where + Columns: ValidateDataType, + { type Error = Error; fn next_element_seed(&mut self, seed: T) -> Result> @@ -231,19 +511,28 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { #[inline] fn deserialize_option>(self, visitor: V) -> Result { + self.columns_validator + .validate(&SerdeType::Option, &[DataTypeHint::Nullable], true)?; ensure_size(&mut self.input, 1)?; - match self.input.get_u8() { 0 => visitor.visit_some(&mut *self), - 1 => visitor.visit_none(), + 1 => { + self.columns_validator.skip_next(); + self.columns_validator.decrease_nesting(); + visitor.visit_none() + } v => Err(Error::InvalidTagEncoding(v as usize)), } } #[inline] fn deserialize_seq>(self, visitor: V) -> Result { + self.columns_validator + .validate(&SerdeType::Seq, &[DataTypeHint::Array], true)?; let len = self.read_size()?; - self.deserialize_tuple(len, visitor) + let result = self.deserialize_tuple(len, visitor); + self.columns_validator.decrease_nesting(); + result } #[inline] @@ -267,6 +556,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> { _name: &str, visitor: V, ) -> Result { + // TODO - skip validation? visitor.visit_newtype_struct(self) } diff --git a/src/rowbinary/de_rbwnat.rs b/src/rowbinary/de_rbwnat.rs index 9d92f2cd..7894d8e9 100644 --- a/src/rowbinary/de_rbwnat.rs +++ b/src/rowbinary/de_rbwnat.rs @@ -1,102 +1,23 @@ use crate::error::{Error, Result}; use crate::rowbinary::de::RowBinaryDeserializer; +use crate::rowbinary::SerdeType; use clickhouse_rowbinary::types::{Column, DataTypeNode}; use serde::de::{DeserializeSeed, SeqAccess, Visitor}; use serde::{Deserialize, Deserializer}; -use std::fmt::Display; use std::ops::Deref; use std::rc::Rc; -pub(crate) fn deserialize_from_rbwnat<'data, 'cursor, T: Deserialize<'data>>( +pub(crate) fn _deserialize_from_rbwnat<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], columns: &'cursor [Column], ) -> Result { // println!("[RBWNAT] deserializing with names and types: {:?}, input size: {}", columns, input.len()); - let mut deserializer = RowBinaryWithNamesAndTypesDeserializer::new(input, columns)?; + let mut deserializer = RowBinaryWithNamesAndTypesDeserializer::_new(input, columns)?; let value = T::deserialize(&mut deserializer); // println!("Remaining input size: {}", input.len()); value } -/// Serde method that delegated the value deserialization to [`Deserializer::deserialize_any`]. -#[derive(Clone, Debug, PartialEq)] -enum DelegatedFrom { - Bool, - I8, - I16, - I32, - I64, - I128, - U8, - U16, - U32, - U64, - U128, - F32, - F64, - Char, - Str, - String, - Bytes, - ByteBuf, - Option, - Unit, - UnitStruct, - NewtypeStruct, - Seq, - Tuple, - TupleStruct, - Map, - Struct, - Enum, - Identifier, - IgnoredAny, -} - -impl Default for DelegatedFrom { - fn default() -> Self { - DelegatedFrom::Struct - } -} - -impl Display for DelegatedFrom { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let type_name = match self { - DelegatedFrom::Bool => "bool", - DelegatedFrom::I8 => "i8", - DelegatedFrom::I16 => "i16", - DelegatedFrom::I32 => "i32", - DelegatedFrom::I64 => "i64", - DelegatedFrom::I128 => "i128", - DelegatedFrom::U8 => "u8", - DelegatedFrom::U16 => "u16", - DelegatedFrom::U32 => "u32", - DelegatedFrom::U64 => "u64", - DelegatedFrom::U128 => "u128", - DelegatedFrom::F32 => "f32", - DelegatedFrom::F64 => "f64", - DelegatedFrom::Char => "char", - DelegatedFrom::Str => "&str", - DelegatedFrom::String => "String", - DelegatedFrom::Bytes => "&[u8]", - DelegatedFrom::ByteBuf => "Vec", - DelegatedFrom::Option => "Option", - DelegatedFrom::Unit => "()", - DelegatedFrom::UnitStruct => "unit struct", - DelegatedFrom::NewtypeStruct => "newtype struct", - DelegatedFrom::Seq => "Vec", - DelegatedFrom::Tuple => "tuple", - DelegatedFrom::TupleStruct => "tuple struct", - DelegatedFrom::Map => "map", - DelegatedFrom::Struct => "struct", - DelegatedFrom::Enum => "enum", - DelegatedFrom::Identifier => "identifier", - DelegatedFrom::IgnoredAny => "ignored any", - }; - write!(f, "{}", type_name) - } -} - #[derive(Clone, Debug)] enum DeserializerState<'cursor> { /// At this point, we are either processing a "simple" column (e.g., `UInt32`, `String`, etc.), @@ -119,7 +40,7 @@ pub(crate) struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { columns: &'cursor [Column], current_column_idx: usize, // main usage is to check if the struct field definition is compatible with the expected one - last_delegated_from: DelegatedFrom, + last_delegated_from: SerdeType, // every deserialization begins from a struct with some name struct_name: Option<&'static str>, struct_fields: Option<&'static [&'static str]>, @@ -127,7 +48,7 @@ pub(crate) struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { #[inline] - fn new( + fn _new( input: &'cursor mut &'data [u8], columns: &'cursor [Column], ) -> Result> { @@ -136,9 +57,12 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { panic!("Zero columns definitions in the response"); } Ok(RowBinaryWithNamesAndTypesDeserializer { - row_binary: RowBinaryDeserializer { input }, + row_binary: RowBinaryDeserializer { + input, + columns_validator: (), + }, state: DeserializerState::TopLevelColumn(&columns[0]), - last_delegated_from: DelegatedFrom::default(), + last_delegated_from: SerdeType::default(), current_column_idx: 0, struct_name: None, struct_fields: None, @@ -147,7 +71,7 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { } #[inline] - fn set_last_delegated_from(&mut self, from: DelegatedFrom) { + fn set_last_delegated_from(&mut self, from: SerdeType) { if self.last_delegated_from != from { self.last_delegated_from = from; } @@ -241,7 +165,7 @@ impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { } #[inline] - fn check_data_type_is_allowed(&mut self, allowed: &[DelegatedFrom]) -> Result<()> { + fn check_data_type_is_allowed(&mut self, allowed: &[SerdeType]) -> Result<()> { if !allowed.contains(&self.last_delegated_from) { let column = self.get_current_column()?; let field_name = match self.struct_name { @@ -299,80 +223,80 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< let result = match data_type { DataTypeNode::Bool => rbwnat_de_simple_with_type_check!( deserialize_bool, - [DelegatedFrom::Bool, DelegatedFrom::U8, DelegatedFrom::I8] + [SerdeType::Bool, SerdeType::U8, SerdeType::I8] ), DataTypeNode::UInt8 => { - rbwnat_de_simple_with_type_check!(deserialize_u8, [DelegatedFrom::U8]) + rbwnat_de_simple_with_type_check!(deserialize_u8, [SerdeType::U8]) } DataTypeNode::Int8 => { - rbwnat_de_simple_with_type_check!(deserialize_i8, [DelegatedFrom::I8]) + rbwnat_de_simple_with_type_check!(deserialize_i8, [SerdeType::I8]) } DataTypeNode::Int16 => { - rbwnat_de_simple_with_type_check!(deserialize_i16, [DelegatedFrom::I16]) + rbwnat_de_simple_with_type_check!(deserialize_i16, [SerdeType::I16]) } DataTypeNode::Int32 => { - rbwnat_de_simple_with_type_check!(deserialize_i32, [DelegatedFrom::I32]) + rbwnat_de_simple_with_type_check!(deserialize_i32, [SerdeType::I32]) } DataTypeNode::Int64 => { - rbwnat_de_simple_with_type_check!(deserialize_i64, [DelegatedFrom::I64]) + rbwnat_de_simple_with_type_check!(deserialize_i64, [SerdeType::I64]) } DataTypeNode::Int128 => { - rbwnat_de_simple_with_type_check!(deserialize_i128, [DelegatedFrom::I128]) + rbwnat_de_simple_with_type_check!(deserialize_i128, [SerdeType::I128]) } DataTypeNode::UInt16 => { - rbwnat_de_simple_with_type_check!(deserialize_u16, [DelegatedFrom::U16]) + rbwnat_de_simple_with_type_check!(deserialize_u16, [SerdeType::U16]) } DataTypeNode::UInt32 => { - rbwnat_de_simple_with_type_check!(deserialize_u32, [DelegatedFrom::U32]) + rbwnat_de_simple_with_type_check!(deserialize_u32, [SerdeType::U32]) } DataTypeNode::UInt64 => { - rbwnat_de_simple_with_type_check!(deserialize_u64, [DelegatedFrom::U64]) + rbwnat_de_simple_with_type_check!(deserialize_u64, [SerdeType::U64]) } DataTypeNode::UInt128 => { - rbwnat_de_simple_with_type_check!(deserialize_u128, [DelegatedFrom::U128]) + rbwnat_de_simple_with_type_check!(deserialize_u128, [SerdeType::U128]) } DataTypeNode::Float32 => { - rbwnat_de_simple_with_type_check!(deserialize_f32, [DelegatedFrom::F32]) + rbwnat_de_simple_with_type_check!(deserialize_f32, [SerdeType::F32]) } DataTypeNode::Float64 => { - rbwnat_de_simple_with_type_check!(deserialize_f64, [DelegatedFrom::F64]) + rbwnat_de_simple_with_type_check!(deserialize_f64, [SerdeType::F64]) } DataTypeNode::String => { rbwnat_de_simple_with_type_check!( deserialize_str, - [DelegatedFrom::Str, DelegatedFrom::String] + [SerdeType::Str, SerdeType::String] ) } DataTypeNode::FixedString(len) => match self.last_delegated_from { - DelegatedFrom::Bytes => visitor.visit_bytes(self.row_binary.read_slice(*len)?), - DelegatedFrom::ByteBuf => visitor.visit_byte_buf(self.row_binary.read_vec(*len)?), + SerdeType::Bytes => visitor.visit_bytes(self.row_binary.read_slice(*len)?), + SerdeType::ByteBuf => visitor.visit_byte_buf(self.row_binary.read_vec(*len)?), _ => unreachable!(), }, DataTypeNode::UUID => { rbwnat_de_simple_with_type_check!( deserialize_str, - [DelegatedFrom::Str, DelegatedFrom::String] + [SerdeType::Str, SerdeType::String] ) } DataTypeNode::Date => { - rbwnat_de_simple_with_type_check!(deserialize_u16, [DelegatedFrom::U16]) + rbwnat_de_simple_with_type_check!(deserialize_u16, [SerdeType::U16]) } DataTypeNode::Date32 => { - rbwnat_de_simple_with_type_check!(deserialize_i32, [DelegatedFrom::I32]) + rbwnat_de_simple_with_type_check!(deserialize_i32, [SerdeType::I32]) } DataTypeNode::DateTime { .. } => { - rbwnat_de_simple_with_type_check!(deserialize_u32, [DelegatedFrom::U32]) + rbwnat_de_simple_with_type_check!(deserialize_u32, [SerdeType::U32]) } DataTypeNode::DateTime64 { .. } => { - rbwnat_de_simple_with_type_check!(deserialize_i64, [DelegatedFrom::I64]) + rbwnat_de_simple_with_type_check!(deserialize_i64, [SerdeType::I64]) } DataTypeNode::IPv4 => { - rbwnat_de_simple_with_type_check!(deserialize_u32, [DelegatedFrom::U32]) + rbwnat_de_simple_with_type_check!(deserialize_u32, [SerdeType::U32]) } DataTypeNode::IPv6 => self.row_binary.deserialize_tuple(16, visitor), DataTypeNode::Array(inner_type) => { - self.check_data_type_is_allowed(&[DelegatedFrom::Seq])?; + self.check_data_type_is_allowed(&[SerdeType::Seq])?; let len = self.row_binary.read_size()?; self.set_inner_data_type_state(inner_type); @@ -431,7 +355,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Bool, visitor) + rbwnat_deserialize_any!(self, SerdeType::Bool, visitor) } #[inline] @@ -439,7 +363,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::I8, visitor) + rbwnat_deserialize_any!(self, SerdeType::I8, visitor) } #[inline] @@ -447,7 +371,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::I16, visitor) + rbwnat_deserialize_any!(self, SerdeType::I16, visitor) } #[inline] @@ -455,7 +379,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::I32, visitor) + rbwnat_deserialize_any!(self, SerdeType::I32, visitor) } #[inline] @@ -463,7 +387,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::I64, visitor) + rbwnat_deserialize_any!(self, SerdeType::I64, visitor) } #[inline] @@ -471,7 +395,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::I128, visitor) + rbwnat_deserialize_any!(self, SerdeType::I128, visitor) } #[inline] @@ -479,7 +403,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::U8, visitor) + rbwnat_deserialize_any!(self, SerdeType::U8, visitor) } #[inline] @@ -487,7 +411,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::U16, visitor) + rbwnat_deserialize_any!(self, SerdeType::U16, visitor) } #[inline] @@ -495,7 +419,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::U32, visitor) + rbwnat_deserialize_any!(self, SerdeType::U32, visitor) } #[inline] @@ -503,7 +427,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::U64, visitor) + rbwnat_deserialize_any!(self, SerdeType::U64, visitor) } #[inline] @@ -511,7 +435,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::U128, visitor) + rbwnat_deserialize_any!(self, SerdeType::U128, visitor) } #[inline] @@ -519,7 +443,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::F32, visitor) + rbwnat_deserialize_any!(self, SerdeType::F32, visitor) } #[inline] @@ -527,7 +451,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::F64, visitor) + rbwnat_deserialize_any!(self, SerdeType::F64, visitor) } #[inline] @@ -535,7 +459,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Char, visitor) + rbwnat_deserialize_any!(self, SerdeType::Char, visitor) } #[inline] @@ -543,7 +467,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Str, visitor) + rbwnat_deserialize_any!(self, SerdeType::Str, visitor) } #[inline] @@ -551,7 +475,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::String, visitor) + rbwnat_deserialize_any!(self, SerdeType::String, visitor) } #[inline] @@ -559,7 +483,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Bytes, visitor) + rbwnat_deserialize_any!(self, SerdeType::Bytes, visitor) } #[inline] @@ -567,7 +491,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::ByteBuf, visitor) + rbwnat_deserialize_any!(self, SerdeType::ByteBuf, visitor) } #[inline] @@ -575,7 +499,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Option, visitor) + rbwnat_deserialize_any!(self, SerdeType::Option, visitor) } #[inline] @@ -583,7 +507,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Unit, visitor) + rbwnat_deserialize_any!(self, SerdeType::Unit, visitor) } #[inline] @@ -595,7 +519,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::UnitStruct, visitor) + rbwnat_deserialize_any!(self, SerdeType::UnitStruct, visitor) } #[inline] @@ -607,7 +531,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::NewtypeStruct, visitor) + rbwnat_deserialize_any!(self, SerdeType::NewtypeStruct, visitor) } #[inline] @@ -615,7 +539,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Seq, visitor) + rbwnat_deserialize_any!(self, SerdeType::Seq, visitor) } #[inline] @@ -627,7 +551,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Tuple, visitor) + rbwnat_deserialize_any!(self, SerdeType::Tuple, visitor) } #[inline] @@ -640,7 +564,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::TupleStruct, visitor) + rbwnat_deserialize_any!(self, SerdeType::TupleStruct, visitor) } #[inline] @@ -648,7 +572,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Map, visitor) + rbwnat_deserialize_any!(self, SerdeType::Map, visitor) } #[inline] @@ -705,7 +629,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Enum, visitor) + rbwnat_deserialize_any!(self, SerdeType::Enum, visitor) } #[inline] @@ -713,7 +637,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::Identifier, visitor) + rbwnat_deserialize_any!(self, SerdeType::Identifier, visitor) } #[inline] @@ -721,7 +645,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer< where V: Visitor<'data>, { - rbwnat_deserialize_any!(self, DelegatedFrom::IgnoredAny, visitor) + rbwnat_deserialize_any!(self, SerdeType::IgnoredAny, visitor) } #[inline] diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index 0bba1973..67540808 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,6 +1,7 @@ pub(crate) use de::deserialize_from; -pub(crate) use de_rbwnat::deserialize_from_rbwnat; +pub(crate) use de::deserialize_from_and_validate; pub(crate) use ser::serialize_into; +use std::fmt::Display; mod de; mod de_rbwnat; @@ -8,3 +9,83 @@ mod ser; #[cfg(test)] mod tests; mod utils; + +/// Which Serde data type (De)serializer used for the given type. +/// Displays into Rust types for convenience in errors reporting. +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum SerdeType { + Bool, + I8, + I16, + I32, + I64, + I128, + U8, + U16, + U32, + U64, + U128, + F32, + F64, + Char, + Str, + String, + Bytes, + ByteBuf, + Option, + Unit, + UnitStruct, + NewtypeStruct, + Seq, + Tuple, + TupleStruct, + Map, + Struct, + Enum, + Identifier, + IgnoredAny, +} + +impl Default for SerdeType { + fn default() -> Self { + SerdeType::Struct + } +} + +impl Display for SerdeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let type_name = match self { + SerdeType::Bool => "bool", + SerdeType::I8 => "i8", + SerdeType::I16 => "i16", + SerdeType::I32 => "i32", + SerdeType::I64 => "i64", + SerdeType::I128 => "i128", + SerdeType::U8 => "u8", + SerdeType::U16 => "u16", + SerdeType::U32 => "u32", + SerdeType::U64 => "u64", + SerdeType::U128 => "u128", + SerdeType::F32 => "f32", + SerdeType::F64 => "f64", + SerdeType::Char => "char", + SerdeType::Str => "&str", + SerdeType::String => "String", + SerdeType::Bytes => "&[u8]", + SerdeType::ByteBuf => "Vec", + SerdeType::Option => "Option", + SerdeType::Unit => "()", + SerdeType::UnitStruct => "unit struct", + SerdeType::NewtypeStruct => "newtype struct", + SerdeType::Seq => "Vec", + SerdeType::Tuple => "tuple", + SerdeType::TupleStruct => "tuple struct", + SerdeType::Map => "map", + SerdeType::Struct => "struct", + SerdeType::Enum => "enum", + SerdeType::Identifier => "identifier", + SerdeType::IgnoredAny => "ignored any", + }; + write!(f, "{}", type_name) + } +} diff --git a/src/validation_mode.rs b/src/validation_mode.rs new file mode 100644 index 00000000..a161b057 --- /dev/null +++ b/src/validation_mode.rs @@ -0,0 +1,23 @@ +#[non_exhaustive] +#[derive(Clone)] +pub enum StructValidationMode { + FirstRow, + EachRow, + Disabled, +} + +impl Default for StructValidationMode { + fn default() -> Self { + Self::FirstRow + } +} + +impl std::fmt::Display for StructValidationMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::FirstRow => write!(f, "FirstRow"), + Self::EachRow => write!(f, "EachRow"), + Self::Disabled => write!(f, "Disabled"), + } + } +} diff --git a/tests/it/rbwnat_smoke.rs b/tests/it/rbwnat_smoke.rs index 128b2e13..ace27da8 100644 --- a/tests/it/rbwnat_smoke.rs +++ b/tests/it/rbwnat_smoke.rs @@ -1,8 +1,8 @@ use clickhouse::error::Error; -use clickhouse::output_format::OutputFormat; +use clickhouse::validation_mode::StructValidationMode; use clickhouse_derive::Row; -use clickhouse_rowbinary::parse_columns_header; -use clickhouse_rowbinary::types::{Column, DataTypeNode}; +use clickhouse_rowbinary::parse_rbwnat_columns_header; +use clickhouse_rowbinary::types::{Column, DataTypeHint, DataTypeNode}; use serde::{Deserialize, Serialize}; use std::str::FromStr; use time::format_description::well_known::Iso8601; @@ -45,57 +45,69 @@ async fn test_header_parsing() { .unwrap(); let data = cursor.collect().await.unwrap(); - let result = parse_columns_header(&mut &data[..]).unwrap(); + let result = parse_rbwnat_columns_header(&mut &data[..]).unwrap(); assert_eq!( result, vec![ Column { name: "CounterID".to_string(), - data_type: DataTypeNode::UInt32 + data_type: DataTypeNode::UInt32, + type_hints: vec![DataTypeHint::UInt32] }, Column { name: "StartDate".to_string(), - data_type: DataTypeNode::Date + data_type: DataTypeNode::Date, + type_hints: vec![DataTypeHint::Date] }, Column { name: "Sign".to_string(), - data_type: DataTypeNode::Int8 + data_type: DataTypeNode::Int8, + type_hints: vec![DataTypeHint::Int8] }, Column { name: "IsNew".to_string(), - data_type: DataTypeNode::UInt8 + data_type: DataTypeNode::UInt8, + type_hints: vec![DataTypeHint::UInt8] }, Column { name: "VisitID".to_string(), - data_type: DataTypeNode::UInt64 + data_type: DataTypeNode::UInt64, + type_hints: vec![DataTypeHint::UInt64] }, Column { name: "UserID".to_string(), - data_type: DataTypeNode::UInt64 + data_type: DataTypeNode::UInt64, + type_hints: vec![DataTypeHint::UInt64] }, Column { name: "Goals.ID".to_string(), - data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)) + data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)), + type_hints: vec![DataTypeHint::Array, DataTypeHint::UInt32] }, Column { name: "Goals.Serial".to_string(), - data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)) + data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)), + type_hints: vec![DataTypeHint::Array, DataTypeHint::UInt32] }, Column { name: "Goals.EventTime".to_string(), - data_type: DataTypeNode::Array(Box::new(DataTypeNode::DateTime(None))) + data_type: DataTypeNode::Array(Box::new(DataTypeNode::DateTime(None))), + type_hints: vec![DataTypeHint::Array, DataTypeHint::DateTime] }, Column { name: "Goals.Price".to_string(), - data_type: DataTypeNode::Array(Box::new(DataTypeNode::Int64)) + data_type: DataTypeNode::Array(Box::new(DataTypeNode::Int64)), + type_hints: vec![DataTypeHint::Array, DataTypeHint::Int64] }, Column { name: "Goals.OrderID".to_string(), - data_type: DataTypeNode::Array(Box::new(DataTypeNode::String)) + data_type: DataTypeNode::Array(Box::new(DataTypeNode::String)), + type_hints: vec![DataTypeHint::Array, DataTypeHint::String] }, Column { name: "Goals.CurrencyID".to_string(), - data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)) + data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)), + type_hints: vec![DataTypeHint::Array, DataTypeHint::UInt32] } ] ); @@ -120,7 +132,7 @@ async fn test_basic_types_deserialization() { string_val: String, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query( " @@ -171,7 +183,7 @@ async fn test_several_simple_rows() { str: String, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query("SELECT number AS num, toString(number) AS str FROM system.numbers LIMIT 3") .fetch_all::() @@ -203,7 +215,7 @@ async fn test_many_numbers() { no: u64, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let mut cursor = client .query("SELECT number FROM system.numbers_mt LIMIT 2000") .fetch::() @@ -227,7 +239,7 @@ async fn test_array_deserialization() { description: String, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query( " @@ -258,6 +270,39 @@ async fn test_array_deserialization() { ); } +#[tokio::test] +async fn test_multi_dimensional_array_deserialization() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + three_dim_array: Vec>>, + id: u16, + } + + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let result = client + .query( + " + SELECT + [[[1.1, 2.2], [3.3, 4.4]], [], [[5.5, 6.6], [7.7, 8.8]]] :: Array(Array(Array(Float64))) AS three_dim_array, + 42 :: UInt16 AS id + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + three_dim_array: vec![ + vec![vec![1.1, 2.2], vec![3.3, 4.4]], + vec![], + vec![vec![5.5, 6.6], vec![7.7, 8.8]] + ], + } + ); +} + #[tokio::test] async fn test_default_types_validation_nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -265,7 +310,7 @@ async fn test_default_types_validation_nullable() { n: Option, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query("SELECT true AS b, 144 :: Int32 AS n2") .fetch_one::() @@ -274,7 +319,7 @@ async fn test_default_types_validation_nullable() { assert!(result.is_err()); assert!(matches!( result.unwrap_err(), - Error::DataTypeMismatch { .. } + Error::InvalidColumnDataType { .. } )); // FIXME: lack of derive PartialEq for Error prevents proper assertion @@ -295,17 +340,16 @@ async fn test_default_types_validation_custom_serde() { n1: OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query("SELECT 42 :: UInt32 AS n1, 144 :: Int32 AS n2") .fetch_one::() .await; assert!(result.is_err()); - println!("{:?}", result); assert!(matches!( result.unwrap_err(), - Error::DataTypeMismatch { .. } + Error::InvalidColumnDataType { .. } )); // FIXME: lack of derive PartialEq for Error prevents proper assertion @@ -326,7 +370,7 @@ async fn test_too_many_struct_fields() { c: u32, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b") .fetch_one::() @@ -349,7 +393,7 @@ async fn test_serde_skip_deserializing() { c: u32, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS c") .fetch_one::() @@ -386,7 +430,7 @@ async fn test_date_time_types() { date_time64_9: OffsetDateTime, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query( " @@ -434,7 +478,7 @@ async fn test_ipv4_ipv6() { ipv6: std::net::Ipv6Addr, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query( " @@ -467,7 +511,7 @@ async fn test_different_struct_field_order() { a: String, } - let client = prepare_database!().with_fetch_format(OutputFormat::RowBinaryWithNamesAndTypes); + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query("SELECT 'foo' AS a, 'bar' :: String AS c") .fetch_one::() From 8ae362993f5976bad3238c43c55156f2e41450c7 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 20 May 2025 00:44:06 +0200 Subject: [PATCH 06/54] RBWNAT deserializer - validation WIP --- rowbinary/src/{types.rs => data_types.rs} | 8 +- rowbinary/src/decoders.rs | 1 - rowbinary/src/error.rs | 1 + rowbinary/src/leb128.rs | 1 + rowbinary/src/lib.rs | 4 +- src/cursors/row.rs | 3 +- src/error.rs | 23 +- src/rowbinary/de.rs | 429 +++++++------- src/rowbinary/de_rbwnat.rs | 655 ---------------------- src/rowbinary/mod.rs | 84 +-- src/rowbinary/validation.rs | 380 +++++++++++++ tests/it/main.rs | 2 +- tests/it/{rbwnat_smoke.rs => rbwnat.rs} | 178 ++++-- 13 files changed, 781 insertions(+), 988 deletions(-) rename rowbinary/src/{types.rs => data_types.rs} (99%) delete mode 100644 src/rowbinary/de_rbwnat.rs create mode 100644 src/rowbinary/validation.rs rename tests/it/{rbwnat_smoke.rs => rbwnat.rs} (78%) diff --git a/rowbinary/src/types.rs b/rowbinary/src/data_types.rs similarity index 99% rename from rowbinary/src/types.rs rename to rowbinary/src/data_types.rs index 241c5174..8fa204f0 100644 --- a/rowbinary/src/types.rs +++ b/rowbinary/src/data_types.rs @@ -6,17 +6,11 @@ use std::fmt::{Display, Formatter}; pub struct Column { pub name: String, pub data_type: DataTypeNode, - pub type_hints: Vec, } impl Column { pub fn new(name: String, data_type: DataTypeNode) -> Self { - let type_hints = data_type.get_type_hints(); - Self { - name, - data_type, - type_hints, - } + Self { name, data_type } } } diff --git a/rowbinary/src/decoders.rs b/rowbinary/src/decoders.rs index 61f2f974..02de935f 100644 --- a/rowbinary/src/decoders.rs +++ b/rowbinary/src/decoders.rs @@ -4,7 +4,6 @@ use bytes::Buf; #[inline] pub(crate) fn decode_string(buffer: &mut &[u8]) -> Result { - // println!("[decode_string] buffer: {:?}", buffer); let length = decode_leb128(buffer)? as usize; if length == 0 { return Ok("".to_string()); diff --git a/rowbinary/src/error.rs b/rowbinary/src/error.rs index eb10af4f..1ca0215d 100644 --- a/rowbinary/src/error.rs +++ b/rowbinary/src/error.rs @@ -1,3 +1,4 @@ +// FIXME: better errors #[derive(Debug, thiserror::Error)] pub enum ParserError { #[error("Not enough data: {0}")] diff --git a/rowbinary/src/leb128.rs b/rowbinary/src/leb128.rs index dd03148f..93ec92b2 100644 --- a/rowbinary/src/leb128.rs +++ b/rowbinary/src/leb128.rs @@ -24,6 +24,7 @@ pub fn decode_leb128(buffer: &mut &[u8]) -> Result { Ok(value) } +// FIXME: do not use Vec pub fn encode_leb128(value: u64) -> Vec { let mut result = Vec::new(); let mut val = value; diff --git a/rowbinary/src/lib.rs b/rowbinary/src/lib.rs index 9e793f48..3ec51c7d 100644 --- a/rowbinary/src/lib.rs +++ b/rowbinary/src/lib.rs @@ -1,12 +1,12 @@ +use crate::data_types::{Column, DataTypeNode}; use crate::decoders::decode_string; use crate::error::ParserError; use crate::leb128::decode_leb128; -use crate::types::{Column, DataTypeNode}; +pub mod data_types; pub mod decoders; pub mod error; pub mod leb128; -pub mod types; pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, ParserError> { let num_columns = decode_leb128(bytes)?; diff --git a/src/cursors/row.rs b/src/cursors/row.rs index d53b4ed3..e0da7cb9 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -6,8 +6,8 @@ use crate::{ response::Response, rowbinary, }; +use clickhouse_rowbinary::data_types::Column; use clickhouse_rowbinary::parse_rbwnat_columns_header; -use clickhouse_rowbinary::types::Column; use serde::Deserialize; use std::marker::PhantomData; @@ -56,6 +56,7 @@ impl RowCursor { let mut slice = super::workaround_51132(self.bytes.slice()); let deserialize_result = if should_validate { match &self.columns { + // TODO: can it be moved to `new` instead? None => { let columns = parse_rbwnat_columns_header(&mut slice)?; self.bytes.set_remaining(slice.len()); diff --git a/src/error.rs b/src/error.rs index 5d191f60..cc1ab550 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,7 +1,9 @@ //! Contains [`Error`] and corresponding [`Result`]. -use clickhouse_rowbinary::types::Column; +use crate::rowbinary::SerdeType; +use clickhouse_rowbinary::data_types::{DataTypeHint, DataTypeNode}; use serde::{de, ser}; +use std::fmt::Display; use std::{error::Error as StdError, fmt, io, result, str::Utf8Error}; /// A result with a specified [`Error`] type. @@ -53,18 +55,21 @@ pub enum Error { unexpected_type: String, all_columns: String, }, - #[error("invalid column data type: {0}")] - InvalidColumnDataType(String), - #[error( - "too many struct fields: trying to read more columns than expected {0}. All columns: {1:?}" - )] - TooManyStructFields(usize, Vec), - #[error("deserialize is called for more fields than a struct has")] - DeserializeCallAfterEndOfStruct, + #[error("deserializing field: {0}; serde type: {1} expected to be deserialized as: {}", join_seq(.2))] + InvalidColumnDataType(DataTypeNode, &'static SerdeType, &'static [DataTypeHint]), + #[error("too many struct fields: trying to read more columns than expected")] + TooManyStructFields, #[error("{0}")] Other(BoxedError), } +fn join_seq(vec: &[T]) -> String { + vec.iter() + .map(|x| x.to_string()) + .collect::>() + .join(", ") +} + assert_impl_all!(Error: StdError, Send, Sync); impl From for Error { diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 8015636b..033676d9 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -1,8 +1,10 @@ use crate::error::{Error, Result}; use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; -use crate::rowbinary::SerdeType; +use crate::rowbinary::validation::SerdeType; +use crate::rowbinary::validation::{DataTypeValidator, ValidateDataType}; use bytes::Buf; -use clickhouse_rowbinary::types::{Column, DataTypeHint}; +use clickhouse_rowbinary::data_types::{Column, DataTypeHint}; +use serde::de::MapAccess; use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, Deserialize, @@ -15,9 +17,11 @@ use std::{convert::TryFrom, mem, str}; /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even /// `(&[u8], &mut Option) -> Result`. pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data [u8]) -> Result { + println!("deserialize_from call"); + let mut deserializer = RowBinaryDeserializer { input, - columns_validator: (), + validator: (), }; T::deserialize(&mut deserializer) } @@ -31,154 +35,27 @@ pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data ) -> Result { let mut deserializer = RowBinaryDeserializer { input, - columns_validator: ColumnsValidator { - columns, - col_idx: 0, - type_hint_idx: 0, - nesting_level: 0, - }, + validator: DataTypeValidator::new(columns), }; - T::deserialize(&mut deserializer) -} - -struct ColumnsValidator<'cursor> { - columns: &'cursor [Column], - col_idx: usize, - type_hint_idx: usize, - nesting_level: usize, -} - -impl<'cursor> ColumnsValidator<'cursor> { - #[inline] - fn advance(&mut self) { - self.col_idx += 1; - self.type_hint_idx = 0; - } -} - -pub(crate) trait ValidateDataType { - fn validate( - &mut self, - serde_type: &'static SerdeType, - allowed: &'static [DataTypeHint], - has_inner_type: bool, - ) -> Result<()>; - fn skip_next(&mut self) -> (); - fn increase_nesting(&mut self) -> (); - fn decrease_nesting(&mut self) -> (); -} - -impl ValidateDataType for () { - #[inline] - fn validate( - &mut self, - _serde_type: &'static SerdeType, - _allowed: &'static [DataTypeHint], - _has_inner_type: bool, - ) -> Result<()> { - Ok(()) - } - #[inline] - fn skip_next(&mut self) -> () {} - #[inline] - fn increase_nesting(&mut self) -> () {} - #[inline] - fn decrease_nesting(&mut self) -> () {} -} - -impl<'cursor> ValidateDataType for ColumnsValidator<'cursor> { - #[inline] - fn validate( - &mut self, - serde_type: &'static SerdeType, - allowed: &'static [DataTypeHint], - has_inner_type: bool, - ) -> Result<()> { - println!( - "Validating column {}, type hint {}, serde type {}, allowed {:?}, nesting level {}", - self.col_idx, self.type_hint_idx, serde_type, allowed, self.nesting_level - ); - if self.col_idx >= self.columns.len() { - return Err(Error::TooManyStructFields( - self.columns.len(), - self.columns.into(), - )); - } - if has_inner_type { - self.nesting_level += 1; - println!("Increased nesting level to {}", self.nesting_level); - } - let current_column = &self.columns[self.col_idx]; - if self.type_hint_idx >= current_column.type_hints.len() { - // if self.nesting_level == 0 { - // println!("Advancing #1"); - // self.advance(); - // } - println!( - "Skipping check for column {}, type hint {}, nesting level {}", - current_column.name, self.type_hint_idx, self.nesting_level - ); - return Ok(()); - } - let db_type_hint = ¤t_column.type_hints[self.type_hint_idx]; - if allowed.contains(db_type_hint) { - // self.type_hint_idx += 1; - if self.nesting_level == 0 { - println!("Advancing #2"); - self.advance(); - } else { - self.type_hint_idx += 1; - } - println!( - "Checked column {}, type hint {}, nesting level {}", - current_column.name, self.type_hint_idx, self.nesting_level - ); - Ok(()) - } else { - Err(Error::InvalidColumnDataType( - format!("column {} {} requires deserialization from {} as {}, but serde call allowed only for {:?}", - current_column.name, current_column.data_type, db_type_hint, serde_type, allowed))) - } - } - - #[inline] - fn skip_next(&mut self) { - self.type_hint_idx += 1; - } - - #[inline] - fn increase_nesting(&mut self) { - self.nesting_level += 1; - } - - #[inline] - fn decrease_nesting(&mut self) { - println!("Decreasing nesting level from {}", self.nesting_level); - if self.nesting_level == 1 { - self.advance(); - } else if self.nesting_level > 0 { - self.nesting_level -= 1; - } else { - panic!("decrease_nesting called when nesting level is already 0, current column index {}, type hint index {}, all columns: {:?}", - self.col_idx, self.type_hint_idx, self.columns); - } - } + T::deserialize(&mut deserializer).inspect_err(|e| { + println!("deserialize_from_and_validate call failed: {:?}", e); + }) } -/// A deserializer for the RowBinary format. +/// A deserializer for the RowBinary(WithNamesAndTypes) format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. -pub(crate) struct RowBinaryDeserializer<'cursor, 'data, Columns = ()> +pub(crate) struct RowBinaryDeserializer<'cursor, 'data, Validator = ()> where - Columns: ValidateDataType, + Validator: ValidateDataType, { - pub(crate) columns_validator: Columns, + pub(crate) validator: Validator, pub(crate) input: &'cursor mut &'data [u8], } -impl<'data, Columns> RowBinaryDeserializer<'_, 'data, Columns> +impl<'data, Validator> RowBinaryDeserializer<'_, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { pub(crate) fn read_vec(&mut self, size: usize) -> Result> { Ok(self.read_slice(size)?.to_vec()) @@ -202,8 +79,7 @@ macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr, $type_hints:expr) => { #[inline] fn $deser_method>(self, visitor: V) -> Result { - self.columns_validator - .validate($serde_type, $type_hints, false)?; + self.validator.validate($serde_type, $type_hints)?; ensure_size(&mut self.input, mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -211,9 +87,9 @@ macro_rules! impl_num { }; } -impl<'data, Columns> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data, Columns> +impl<'data, Validator> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { type Error = Error; @@ -318,27 +194,34 @@ where #[inline] fn deserialize_any>(self, _: V) -> Result { + println!("deserialize_any call"); + Err(Error::DeserializeAnyNotSupported) } #[inline] fn deserialize_unit>(self, visitor: V) -> Result { + println!("deserialize_unit call"); + // TODO: revise this. visitor.visit_unit() } #[inline] fn deserialize_char>(self, _: V) -> Result { + println!("deserialize_char call"); + panic!("character types are unsupported: `char`"); } #[inline] fn deserialize_bool>(self, visitor: V) -> Result { - self.columns_validator.validate( + println!("deserialize_bool call"); + + self.validator.validate( &SerdeType::Bool, // TODO: shall we allow deserialization from integers? &[DataTypeHint::Bool, DataTypeHint::Int8, DataTypeHint::UInt8], - false, )?; ensure_size(&mut self.input, 1)?; match self.input.get_u8() { @@ -350,9 +233,11 @@ where #[inline] fn deserialize_str>(self, visitor: V) -> Result { + println!("deserialize_str call"); + // TODO - which types to allow? - self.columns_validator - .validate(&SerdeType::String, &[DataTypeHint::String], false)?; + self.validator + .validate(&SerdeType::String, &[DataTypeHint::String])?; let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -361,9 +246,11 @@ where #[inline] fn deserialize_string>(self, visitor: V) -> Result { + println!("deserialize_string call"); + // TODO - which types to allow? - self.columns_validator - .validate(&SerdeType::String, &[DataTypeHint::String], false)?; + self.validator + .validate(&SerdeType::String, &[DataTypeHint::String])?; let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -372,6 +259,8 @@ where #[inline] fn deserialize_bytes>(self, visitor: V) -> Result { + println!("deserialize_bytes call"); + // TODO - which types to allow? let size = self.read_size()?; let slice = self.read_slice(size)?; @@ -380,6 +269,8 @@ where #[inline] fn deserialize_byte_buf>(self, visitor: V) -> Result { + println!("deserialize_byte_buf call"); + // TODO - which types to allow? let size = self.read_size()?; visitor.visit_byte_buf(self.read_vec(size)?) @@ -387,6 +278,8 @@ where #[inline] fn deserialize_identifier>(self, visitor: V) -> Result { + println!("deserialize_identifier call"); + // TODO - which types to allow? self.deserialize_u8(visitor) } @@ -398,21 +291,25 @@ where _variants: &'static [&'static str], visitor: V, ) -> Result { - struct Access<'de, 'cursor, 'data, Columns> + println!("deserialize_enum call"); + + struct RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Columns>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, } - struct VariantDeserializer<'de, 'cursor, 'data, Columns> + + struct VariantDeserializer<'de, 'cursor, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Columns>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, } - impl<'data, Columns> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data, Columns> + + impl<'data, Validator> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { type Error = Error; @@ -446,12 +343,13 @@ where } } - impl<'de, 'cursor, 'data, Columns> EnumAccess<'data> for Access<'de, 'cursor, 'data, Columns> + impl<'de, 'cursor, 'data, Validator> EnumAccess<'data> + for RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { type Error = Error; - type Variant = VariantDeserializer<'de, 'cursor, 'data, Columns>; + type Variant = VariantDeserializer<'de, 'cursor, 'data, Validator>; fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error> where @@ -464,24 +362,28 @@ where Ok((value, deserializer)) } } - self.columns_validator - .validate(&SerdeType::Enum, &[DataTypeHint::Enum], false)?; - visitor.visit_enum(Access { deserializer: self }) + + // FIXME + self.validator + .validate(&SerdeType::Enum, &[DataTypeHint::Enum])?; + visitor.visit_enum(RowBinaryEnumAccess { deserializer: self }) } #[inline] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - struct Access<'de, 'cursor, 'data, Columns> + println!("deserialize_tuple call, len {}", len); + + struct RowBinaryTupleAccess<'de, 'cursor, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Columns>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, len: usize, } - impl<'data, Columns> SeqAccess<'data> for Access<'_, '_, 'data, Columns> + impl<'data, Validator> SeqAccess<'data> for RowBinaryTupleAccess<'_, '_, 'data, Validator> where - Columns: ValidateDataType, + Validator: ValidateDataType, { type Error = Error; @@ -503,59 +405,202 @@ where } } - visitor.visit_seq(Access { - deserializer: self, + let len = self.read_size()?; + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Seq, &[DataTypeHint::Array, DataTypeHint::IPv6])?; + visitor.visit_seq(RowBinaryTupleAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: inner_data_type_validator, + }, len, }) } #[inline] fn deserialize_option>(self, visitor: V) -> Result { - self.columns_validator - .validate(&SerdeType::Option, &[DataTypeHint::Nullable], true)?; + println!("deserialize_option call"); + ensure_size(&mut self.input, 1)?; + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Option, &[DataTypeHint::Nullable])?; match self.input.get_u8() { - 0 => visitor.visit_some(&mut *self), - 1 => { - self.columns_validator.skip_next(); - self.columns_validator.decrease_nesting(); - visitor.visit_none() - } + 0 => visitor.visit_some(&mut RowBinaryDeserializer { + input: self.input, + validator: inner_data_type_validator, + }), + 1 => visitor.visit_none(), v => Err(Error::InvalidTagEncoding(v as usize)), } } #[inline] fn deserialize_seq>(self, visitor: V) -> Result { - self.columns_validator - .validate(&SerdeType::Seq, &[DataTypeHint::Array], true)?; + println!("deserialize_seq call"); + + struct RowBinarySeqAccess<'de, 'cursor, 'data, Validator> + where + Validator: ValidateDataType, + { + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + len: usize, + } + + impl<'data, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, Validator> + where + Validator: ValidateDataType, + { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'data>, + { + if self.len > 0 { + self.len -= 1; + let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.len) + } + } + let len = self.read_size()?; - let result = self.deserialize_tuple(len, visitor); - self.columns_validator.decrease_nesting(); - result + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Seq, &[DataTypeHint::Array])?; + visitor.visit_seq(RowBinarySeqAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: inner_data_type_validator, + }, + len, + }) } #[inline] - fn deserialize_map>(self, _visitor: V) -> Result { - panic!("maps are unsupported, use `Vec<(A, B)>` instead"); + fn deserialize_map>(self, visitor: V) -> Result { + println!("deserialize_map call"); + + struct RowBinaryMapAccess<'de, 'cursor, 'data, Validator> + where + Validator: ValidateDataType, + { + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + entries_visited: usize, + len: usize, + } + + impl<'data, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, Validator> + where + Validator: ValidateDataType, + { + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'data>, + { + if self.entries_visited >= self.len { + return Ok(None); + } + self.entries_visited += 1; + seed.deserialize(&mut *self.deserializer).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'data>, + { + seed.deserialize(&mut *self.deserializer) + } + + fn size_hint(&self) -> Option { + Some(self.len) + } + } + + let len = self.read_size()?; + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Map, &[DataTypeHint::Map])?; + visitor.visit_map(RowBinaryMapAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: inner_data_type_validator, + }, + entries_visited: 0, + len, + }) } #[inline] fn deserialize_struct>( self, - _name: &str, + name: &str, fields: &'static [&'static str], visitor: V, ) -> Result { - self.deserialize_tuple(fields.len(), visitor) + println!("deserialize_struct call - {}", name); + + // FIXME use &'_ str, fix lifetimes + self.validator.set_struct_name(name.to_string()); + + // TODO: it should also support using HashMap to deserialize + // currently just copy-pasted to prevent former `deserialize_tuple` delegation + struct RowBinaryStructAccess<'de, 'cursor, 'data, Validator> + where + Validator: ValidateDataType, + { + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + len: usize, + } + + impl<'data, Validator> SeqAccess<'data> for RowBinaryStructAccess<'_, '_, 'data, Validator> + where + Validator: ValidateDataType, + { + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'data>, + { + if self.len > 0 { + self.len -= 1; + let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.len) + } + } + + visitor.visit_seq(RowBinaryStructAccess { + deserializer: self, + len: fields.len(), + }) } #[inline] fn deserialize_newtype_struct>( self, - _name: &str, + name: &str, visitor: V, ) -> Result { + println!("deserialize_newtype_struct call - {}", name); + // TODO - skip validation? visitor.visit_newtype_struct(self) } @@ -566,6 +611,8 @@ where name: &'static str, _visitor: V, ) -> Result { + println!("deserialize_unit_struct call"); + panic!("unit types are unsupported: `{name}`"); } @@ -576,11 +623,15 @@ where _len: usize, _visitor: V, ) -> Result { + println!("deserialize_tuple_struct call"); + panic!("tuple struct types are unsupported: `{name}`"); } #[inline] fn deserialize_ignored_any>(self, _visitor: V) -> Result { + println!("deserialize_ignored_any call"); + panic!("ignored types are unsupported"); } diff --git a/src/rowbinary/de_rbwnat.rs b/src/rowbinary/de_rbwnat.rs deleted file mode 100644 index 7894d8e9..00000000 --- a/src/rowbinary/de_rbwnat.rs +++ /dev/null @@ -1,655 +0,0 @@ -use crate::error::{Error, Result}; -use crate::rowbinary::de::RowBinaryDeserializer; -use crate::rowbinary::SerdeType; -use clickhouse_rowbinary::types::{Column, DataTypeNode}; -use serde::de::{DeserializeSeed, SeqAccess, Visitor}; -use serde::{Deserialize, Deserializer}; -use std::ops::Deref; -use std::rc::Rc; - -pub(crate) fn _deserialize_from_rbwnat<'data, 'cursor, T: Deserialize<'data>>( - input: &mut &'data [u8], - columns: &'cursor [Column], -) -> Result { - // println!("[RBWNAT] deserializing with names and types: {:?}, input size: {}", columns, input.len()); - let mut deserializer = RowBinaryWithNamesAndTypesDeserializer::_new(input, columns)?; - let value = T::deserialize(&mut deserializer); - // println!("Remaining input size: {}", input.len()); - value -} - -#[derive(Clone, Debug)] -enum DeserializerState<'cursor> { - /// At this point, we are either processing a "simple" column (e.g., `UInt32`, `String`, etc.), - /// or starting to process a more complex one (e.g., `Array(T)`, `Map(K, V)`, etc.). - TopLevelColumn(&'cursor Column), - /// Processing a column with a complex type (e.g., `Array(T)`), and we've got what `T` is. - /// We can use this to verify the inner type definition in the struct. - InnerDataType { - column: &'cursor Column, - prev_state: Rc>, - inner_data_type: &'cursor DataTypeNode, - }, - /// We are done with all columns and should not try to deserialize anything else. - EndOfStruct, -} - -pub(crate) struct RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { - row_binary: RowBinaryDeserializer<'cursor, 'data>, - state: DeserializerState<'cursor>, - columns: &'cursor [Column], - current_column_idx: usize, - // main usage is to check if the struct field definition is compatible with the expected one - last_delegated_from: SerdeType, - // every deserialization begins from a struct with some name - struct_name: Option<&'static str>, - struct_fields: Option<&'static [&'static str]>, -} - -impl<'cursor, 'data> RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data> { - #[inline] - fn _new( - input: &'cursor mut &'data [u8], - columns: &'cursor [Column], - ) -> Result> { - if columns.is_empty() { - // unlikely - should be validated by the columns parser already - panic!("Zero columns definitions in the response"); - } - Ok(RowBinaryWithNamesAndTypesDeserializer { - row_binary: RowBinaryDeserializer { - input, - columns_validator: (), - }, - state: DeserializerState::TopLevelColumn(&columns[0]), - last_delegated_from: SerdeType::default(), - current_column_idx: 0, - struct_name: None, - struct_fields: None, - columns, - }) - } - - #[inline] - fn set_last_delegated_from(&mut self, from: SerdeType) { - if self.last_delegated_from != from { - self.last_delegated_from = from; - } - } - - #[inline] - fn set_struct_name(&mut self, name: &'static str) { - // TODO: nested structs support? - if self.struct_name.is_none() { - self.struct_name = Some(name); - } - } - - #[inline] - fn set_struct_fields(&mut self, fields: &'static [&'static str]) { - // TODO: nested structs support? - if self.struct_fields.is_none() { - self.struct_fields = Some(fields); - } - } - - #[inline] - fn advance_state(&mut self) -> Result<()> { - match &self.state { - DeserializerState::TopLevelColumn { .. } => { - self.current_column_idx += 1; - if self.current_column_idx >= self.columns.len() { - self.state = DeserializerState::EndOfStruct; - } else { - let current_col = self.get_current_column()?; - self.state = DeserializerState::TopLevelColumn(current_col); - } - } - DeserializerState::EndOfStruct => { - panic!("trying to advance the current column index after full deserialization"); - } - // skipping this when processing inner data types with more than one nesting level - _ => {} - } - Ok(()) - } - - #[inline] - fn set_inner_data_type_state(&mut self, inner_data_type: &'cursor DataTypeNode) { - match self.state { - DeserializerState::TopLevelColumn(column, ..) - | DeserializerState::InnerDataType { column, .. } => { - self.state = DeserializerState::InnerDataType { - prev_state: Rc::new(self.state.clone()), - inner_data_type, - column, - } - } - _ => { - panic!("to_inner called on invalid state"); - } - } - } - - #[inline] - fn set_previous_state(&mut self) { - match &self.state { - DeserializerState::InnerDataType { prev_state, .. } => { - self.state = prev_state.deref().clone() - } - _ => panic!("to_prev_state called on invalid state"), - } - } - - #[inline] - fn get_current_data_type(&self) -> Result<&'cursor DataTypeNode> { - match self.state { - DeserializerState::TopLevelColumn(col, ..) => Ok(&col.data_type), - DeserializerState::InnerDataType { - inner_data_type, .. - } => Ok(inner_data_type), - DeserializerState::EndOfStruct => Err(Error::DeserializeCallAfterEndOfStruct), - } - } - - #[inline] - fn get_current_column(&mut self) -> Result<&'cursor Column> { - if self.current_column_idx >= self.columns.len() { - return Err(Error::TooManyStructFields( - self.current_column_idx, - Vec::from(self.columns), - )); - } - let col = &self.columns[self.current_column_idx]; - Ok(col) - } - - #[inline] - fn check_data_type_is_allowed(&mut self, allowed: &[SerdeType]) -> Result<()> { - if !allowed.contains(&self.last_delegated_from) { - let column = self.get_current_column()?; - let field_name = match self.struct_name { - Some(struct_name) => format!("{}.{}", struct_name, column.name), - None => column.name.to_string(), - }; - let allowed_types = allowed - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "); - let all_columns = self - .columns - .iter() - .map(|x| x.to_string()) - .collect::>() - .join(", "); - let unexpected_type = self.last_delegated_from.to_string(); - Err(Error::DataTypeMismatch { - field_name, - allowed_types, - unexpected_type, - all_columns, - }) - } else { - Ok(()) - } - } -} - -macro_rules! rbwnat_deserialize_any { - ($self:ident, $delegated_from:expr, $visitor:ident) => {{ - $self.set_last_delegated_from($delegated_from); - $self.deserialize_any($visitor) - }}; -} - -impl<'data> Deserializer<'data> for &mut RowBinaryWithNamesAndTypesDeserializer<'_, 'data> { - type Error = Error; - - #[inline] - fn deserialize_any(self, visitor: V) -> Result - where - V: Visitor<'data>, - { - macro_rules! rbwnat_de_simple_with_type_check { - ($delegate:ident, $compatible:expr) => {{ - self.check_data_type_is_allowed(&$compatible)?; - self.row_binary.$delegate(visitor) - }}; - } - - println!("{} state: {:?}", self.last_delegated_from, self.state); - let data_type = self.get_current_data_type()?; - let result = match data_type { - DataTypeNode::Bool => rbwnat_de_simple_with_type_check!( - deserialize_bool, - [SerdeType::Bool, SerdeType::U8, SerdeType::I8] - ), - DataTypeNode::UInt8 => { - rbwnat_de_simple_with_type_check!(deserialize_u8, [SerdeType::U8]) - } - DataTypeNode::Int8 => { - rbwnat_de_simple_with_type_check!(deserialize_i8, [SerdeType::I8]) - } - DataTypeNode::Int16 => { - rbwnat_de_simple_with_type_check!(deserialize_i16, [SerdeType::I16]) - } - DataTypeNode::Int32 => { - rbwnat_de_simple_with_type_check!(deserialize_i32, [SerdeType::I32]) - } - DataTypeNode::Int64 => { - rbwnat_de_simple_with_type_check!(deserialize_i64, [SerdeType::I64]) - } - DataTypeNode::Int128 => { - rbwnat_de_simple_with_type_check!(deserialize_i128, [SerdeType::I128]) - } - DataTypeNode::UInt16 => { - rbwnat_de_simple_with_type_check!(deserialize_u16, [SerdeType::U16]) - } - DataTypeNode::UInt32 => { - rbwnat_de_simple_with_type_check!(deserialize_u32, [SerdeType::U32]) - } - DataTypeNode::UInt64 => { - rbwnat_de_simple_with_type_check!(deserialize_u64, [SerdeType::U64]) - } - DataTypeNode::UInt128 => { - rbwnat_de_simple_with_type_check!(deserialize_u128, [SerdeType::U128]) - } - DataTypeNode::Float32 => { - rbwnat_de_simple_with_type_check!(deserialize_f32, [SerdeType::F32]) - } - DataTypeNode::Float64 => { - rbwnat_de_simple_with_type_check!(deserialize_f64, [SerdeType::F64]) - } - DataTypeNode::String => { - rbwnat_de_simple_with_type_check!( - deserialize_str, - [SerdeType::Str, SerdeType::String] - ) - } - DataTypeNode::FixedString(len) => match self.last_delegated_from { - SerdeType::Bytes => visitor.visit_bytes(self.row_binary.read_slice(*len)?), - SerdeType::ByteBuf => visitor.visit_byte_buf(self.row_binary.read_vec(*len)?), - _ => unreachable!(), - }, - DataTypeNode::UUID => { - rbwnat_de_simple_with_type_check!( - deserialize_str, - [SerdeType::Str, SerdeType::String] - ) - } - DataTypeNode::Date => { - rbwnat_de_simple_with_type_check!(deserialize_u16, [SerdeType::U16]) - } - DataTypeNode::Date32 => { - rbwnat_de_simple_with_type_check!(deserialize_i32, [SerdeType::I32]) - } - DataTypeNode::DateTime { .. } => { - rbwnat_de_simple_with_type_check!(deserialize_u32, [SerdeType::U32]) - } - DataTypeNode::DateTime64 { .. } => { - rbwnat_de_simple_with_type_check!(deserialize_i64, [SerdeType::I64]) - } - DataTypeNode::IPv4 => { - rbwnat_de_simple_with_type_check!(deserialize_u32, [SerdeType::U32]) - } - DataTypeNode::IPv6 => self.row_binary.deserialize_tuple(16, visitor), - - DataTypeNode::Array(inner_type) => { - self.check_data_type_is_allowed(&[SerdeType::Seq])?; - let len = self.row_binary.read_size()?; - self.set_inner_data_type_state(inner_type); - - struct AnyArrayAccess<'de, 'cursor, 'data> { - deserializer: &'de mut RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data>, - remaining: usize, - } - - impl<'data> SeqAccess<'data> for AnyArrayAccess<'_, '_, 'data> { - type Error = Error; - - fn next_element_seed( - &mut self, - seed: T, - ) -> Result, Self::Error> - where - T: DeserializeSeed<'data>, - { - if self.remaining == 0 { - return Ok(None); - } - - self.remaining -= 1; - seed.deserialize(&mut *self.deserializer).map(Some) - } - - fn size_hint(&self) -> Option { - Some(self.remaining) - } - } - - let result = visitor.visit_seq(AnyArrayAccess { - deserializer: self, - remaining: len, - }); - // if we are processing `Array(String)`, the state has `String` as expected type - // revert it back to `Array(String)` - self.set_previous_state(); - result - } - // DataTypeNode::Nullable(inner_type) => { - // self.check_data_type_is_allowed(&[DelegatedFrom::Option])?; - // self.set_inner_data_type_state(inner_type); - // self.row_binary.deserialize_option(visitor) - // }, - _ => panic!("unsupported type for deserialize_any: {:?}", self.columns), - }; - result.and_then(|value| { - self.advance_state()?; - Ok(value) - }) - } - - #[inline] - fn deserialize_bool(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Bool, visitor) - } - - #[inline] - fn deserialize_i8(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::I8, visitor) - } - - #[inline] - fn deserialize_i16(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::I16, visitor) - } - - #[inline] - fn deserialize_i32(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::I32, visitor) - } - - #[inline] - fn deserialize_i64(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::I64, visitor) - } - - #[inline] - fn deserialize_i128(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::I128, visitor) - } - - #[inline] - fn deserialize_u8(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::U8, visitor) - } - - #[inline] - fn deserialize_u16(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::U16, visitor) - } - - #[inline] - fn deserialize_u32(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::U32, visitor) - } - - #[inline] - fn deserialize_u64(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::U64, visitor) - } - - #[inline] - fn deserialize_u128(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::U128, visitor) - } - - #[inline] - fn deserialize_f32(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::F32, visitor) - } - - #[inline] - fn deserialize_f64(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::F64, visitor) - } - - #[inline] - fn deserialize_char(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Char, visitor) - } - - #[inline] - fn deserialize_str(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Str, visitor) - } - - #[inline] - fn deserialize_string(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::String, visitor) - } - - #[inline] - fn deserialize_bytes(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Bytes, visitor) - } - - #[inline] - fn deserialize_byte_buf(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::ByteBuf, visitor) - } - - #[inline] - fn deserialize_option(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Option, visitor) - } - - #[inline] - fn deserialize_unit(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Unit, visitor) - } - - #[inline] - fn deserialize_unit_struct( - self, - _name: &'static str, - visitor: V, - ) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::UnitStruct, visitor) - } - - #[inline] - fn deserialize_newtype_struct( - self, - _name: &'static str, - visitor: V, - ) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::NewtypeStruct, visitor) - } - - #[inline] - fn deserialize_seq(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Seq, visitor) - } - - #[inline] - fn deserialize_tuple( - self, - _len: usize, - visitor: V, - ) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Tuple, visitor) - } - - #[inline] - fn deserialize_tuple_struct( - self, - _name: &'static str, - _len: usize, - visitor: V, - ) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::TupleStruct, visitor) - } - - #[inline] - fn deserialize_map(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Map, visitor) - } - - #[inline] - fn deserialize_struct( - self, - name: &'static str, - fields: &'static [&'static str], - visitor: V, - ) -> std::result::Result - where - V: Visitor<'data>, - { - struct StructAccess<'de, 'cursor, 'data> { - deserializer: &'de mut RowBinaryWithNamesAndTypesDeserializer<'cursor, 'data>, - len: usize, - } - - impl<'data> SeqAccess<'data> for StructAccess<'_, '_, 'data> { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'data>, - { - if self.len > 0 { - self.len -= 1; - let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; - Ok(Some(value)) - } else { - Ok(None) - } - } - - fn size_hint(&self) -> Option { - Some(self.len) - } - } - - self.set_struct_name(name); - self.set_struct_fields(fields); - visitor.visit_seq(StructAccess { - deserializer: self, - len: fields.len(), - }) - } - - #[inline] - fn deserialize_enum( - self, - _name: &'static str, - _variants: &'static [&'static str], - visitor: V, - ) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Enum, visitor) - } - - #[inline] - fn deserialize_identifier(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::Identifier, visitor) - } - - #[inline] - fn deserialize_ignored_any(self, visitor: V) -> std::result::Result - where - V: Visitor<'data>, - { - rbwnat_deserialize_any!(self, SerdeType::IgnoredAny, visitor) - } - - #[inline] - fn is_human_readable(&self) -> bool { - false - } -} diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index 67540808..9107c391 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,91 +1,11 @@ pub(crate) use de::deserialize_from; pub(crate) use de::deserialize_from_and_validate; pub(crate) use ser::serialize_into; -use std::fmt::Display; +pub(crate) use validation::SerdeType; mod de; -mod de_rbwnat; mod ser; #[cfg(test)] mod tests; mod utils; - -/// Which Serde data type (De)serializer used for the given type. -/// Displays into Rust types for convenience in errors reporting. -#[derive(Clone, Debug, PartialEq)] -pub(crate) enum SerdeType { - Bool, - I8, - I16, - I32, - I64, - I128, - U8, - U16, - U32, - U64, - U128, - F32, - F64, - Char, - Str, - String, - Bytes, - ByteBuf, - Option, - Unit, - UnitStruct, - NewtypeStruct, - Seq, - Tuple, - TupleStruct, - Map, - Struct, - Enum, - Identifier, - IgnoredAny, -} - -impl Default for SerdeType { - fn default() -> Self { - SerdeType::Struct - } -} - -impl Display for SerdeType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let type_name = match self { - SerdeType::Bool => "bool", - SerdeType::I8 => "i8", - SerdeType::I16 => "i16", - SerdeType::I32 => "i32", - SerdeType::I64 => "i64", - SerdeType::I128 => "i128", - SerdeType::U8 => "u8", - SerdeType::U16 => "u16", - SerdeType::U32 => "u32", - SerdeType::U64 => "u64", - SerdeType::U128 => "u128", - SerdeType::F32 => "f32", - SerdeType::F64 => "f64", - SerdeType::Char => "char", - SerdeType::Str => "&str", - SerdeType::String => "String", - SerdeType::Bytes => "&[u8]", - SerdeType::ByteBuf => "Vec", - SerdeType::Option => "Option", - SerdeType::Unit => "()", - SerdeType::UnitStruct => "unit struct", - SerdeType::NewtypeStruct => "newtype struct", - SerdeType::Seq => "Vec", - SerdeType::Tuple => "tuple", - SerdeType::TupleStruct => "tuple struct", - SerdeType::Map => "map", - SerdeType::Struct => "struct", - SerdeType::Enum => "enum", - SerdeType::Identifier => "identifier", - SerdeType::IgnoredAny => "ignored any", - }; - write!(f, "{}", type_name) - } -} +mod validation; diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs new file mode 100644 index 00000000..c718ab2b --- /dev/null +++ b/src/rowbinary/validation.rs @@ -0,0 +1,380 @@ +use crate::error::{Error, Result}; +use clickhouse_rowbinary::data_types::{Column, DataTypeHint, DataTypeNode, DecimalSize, EnumType}; +use std::collections::HashMap; +use std::fmt::Display; + +pub(crate) trait ValidateDataType: Sized { + fn validate( + &mut self, + serde_type: &'static SerdeType, + compatible_db_types: &'static [DataTypeHint], + ) -> Result>>; + fn set_struct_name(&mut self, name: String) -> (); +} + +pub(crate) struct DataTypeValidator<'cursor> { + columns: &'cursor [Column], + struct_name: Option, +} + +impl<'cursor> DataTypeValidator<'cursor> { + pub(crate) fn new(columns: &'cursor [Column]) -> Self { + Self { + columns, + struct_name: None, + } + } +} + +impl<'cursor> Default for DataTypeValidator<'cursor> { + fn default() -> Self { + Self { + columns: &[], + struct_name: None, + } + } +} + +pub(crate) enum MapValidatorState { + Key, + Value, + Validated, +} + +pub(crate) enum ArrayValidatorState { + Pending, + Validated, +} + +pub(crate) enum InnerDataTypeValidator<'cursor> { + Array(&'cursor DataTypeNode, ArrayValidatorState), + Map( + &'cursor DataTypeNode, + &'cursor DataTypeNode, + MapValidatorState, + ), + Tuple(&'cursor [DataTypeNode]), + Enum(&'cursor HashMap), + Variant(&'cursor [DataTypeNode]), + Nullable(&'cursor DataTypeNode), +} + +impl ValidateDataType for () { + #[inline] + fn validate( + &mut self, + _serde_type: &'static SerdeType, + _compatible_db_types: &'static [DataTypeHint], + ) -> Result>> { + Ok(None) + } + #[inline] + fn set_struct_name(&mut self, _name: String) { + () + } +} + +impl<'cursor> ValidateDataType for Option> { + fn validate( + &mut self, + serde_type: &'static SerdeType, + compatible_db_types: &'static [DataTypeHint], + ) -> Result>> { + match self { + None => Ok(None), + Some(InnerDataTypeValidator::Map(key_type, value_type, state)) => match state { + MapValidatorState::Key => { + let result = validate_impl(key_type, serde_type, compatible_db_types); + *state = MapValidatorState::Value; + result + } + MapValidatorState::Value => { + let result = validate_impl(value_type, serde_type, compatible_db_types); + *state = MapValidatorState::Validated; + result + } + MapValidatorState::Validated => Ok(None), + }, + Some(InnerDataTypeValidator::Array(inner_type, state)) => match state { + ArrayValidatorState::Pending => { + println!( + "ArrayValidatorState::Pending; serde_type: {}; compatible_db_types: {:?}", + serde_type, compatible_db_types, + ); + let result = validate_impl(inner_type, serde_type, compatible_db_types); + *state = ArrayValidatorState::Validated; + result + } + // TODO: perhaps we can allow to validate the inner type more than once + // avoiding e.g. issues with Array(Nullable(T)) when the first element in NULL + ArrayValidatorState::Validated => Ok(None), + }, + Some(InnerDataTypeValidator::Nullable(inner_type)) => { + validate_impl(inner_type, serde_type, compatible_db_types) + } + Some(InnerDataTypeValidator::Tuple(elements_types)) => { + match elements_types.split_first() { + None => Ok(None), + Some((first, rest)) => { + let result = validate_impl(first, serde_type, compatible_db_types); + *elements_types = rest; + result + } + } + } + Some(InnerDataTypeValidator::Variant(_possible_types)) => { + Ok(None) // TODO - check type index in the parsed types vec + } + Some(InnerDataTypeValidator::Enum(_values_map)) => { + Ok(None) // TODO - check value correctness in the hashmap + } + } + } + + #[inline] + fn set_struct_name(&mut self, _name: String) { + unreachable!("it should never be called for inner validators") + } +} + +#[inline] +fn validate_impl<'cursor>( + data_type: &'cursor DataTypeNode, + serde_type: &'static SerdeType, + compatible_db_types: &'static [DataTypeHint], +) -> Result>> { + println!( + "validate_impl call from Serde {}; compatible types: {:?}, db type: {:?}", + serde_type, compatible_db_types, data_type, + ); + // FIXME: multiple branches with similar patterns + match data_type { + DataTypeNode::Bool if compatible_db_types.contains(&DataTypeHint::Bool) => Ok(None), + + DataTypeNode::Int8 if compatible_db_types.contains(&DataTypeHint::Int8) => Ok(None), + DataTypeNode::Int16 if compatible_db_types.contains(&DataTypeHint::Int16) => Ok(None), + DataTypeNode::Int32 + | DataTypeNode::Date32 + | DataTypeNode::Decimal(_, _, DecimalSize::Int32) + if compatible_db_types.contains(&DataTypeHint::Int32) => + { + Ok(None) + } + DataTypeNode::Int64 + | DataTypeNode::DateTime64(_, _) + | DataTypeNode::Decimal(_, _, DecimalSize::Int64) + if compatible_db_types.contains(&DataTypeHint::Int64) => + { + Ok(None) + } + DataTypeNode::Int128 | DataTypeNode::Decimal(_, _, DecimalSize::Int128) + if compatible_db_types.contains(&DataTypeHint::Int128) => + { + Ok(None) + } + + DataTypeNode::UInt8 if compatible_db_types.contains(&DataTypeHint::UInt8) => Ok(None), + DataTypeNode::UInt16 | DataTypeNode::Date + if compatible_db_types.contains(&DataTypeHint::UInt16) => + { + Ok(None) + } + DataTypeNode::UInt32 | DataTypeNode::DateTime(_) | DataTypeNode::IPv4 + if compatible_db_types.contains(&DataTypeHint::UInt32) => + { + Ok(None) + } + DataTypeNode::UInt64 if compatible_db_types.contains(&DataTypeHint::UInt64) => Ok(None), + DataTypeNode::UInt128 if compatible_db_types.contains(&DataTypeHint::UInt128) => Ok(None), + + DataTypeNode::Float32 if compatible_db_types.contains(&DataTypeHint::Float32) => Ok(None), + DataTypeNode::Float64 if compatible_db_types.contains(&DataTypeHint::Float64) => Ok(None), + + // Currently, we allow new JSON type only with `output_format_binary_write_json_as_string` + DataTypeNode::String | DataTypeNode::JSON + if compatible_db_types.contains(&DataTypeHint::String) => + { + Ok(None) + } + + DataTypeNode::FixedString(n) + if compatible_db_types.contains(&DataTypeHint::FixedString(*n)) => + { + Ok(None) + } + + // Deserialized as a sequence of 16 bytes + DataTypeNode::IPv6 if compatible_db_types.contains(&DataTypeHint::Array) => Ok(Some( + InnerDataTypeValidator::Array(&DataTypeNode::UInt8, ArrayValidatorState::Pending), + )), + + DataTypeNode::UUID => todo!(), + + DataTypeNode::Array(inner_type) if compatible_db_types.contains(&DataTypeHint::Array) => { + Ok(Some(InnerDataTypeValidator::Array( + inner_type, + ArrayValidatorState::Pending, + ))) + } + + DataTypeNode::Map(key_type, value_type) + if compatible_db_types.contains(&DataTypeHint::Map) => + { + Ok(Some(InnerDataTypeValidator::Map( + key_type, + value_type, + MapValidatorState::Key, + ))) + } + + DataTypeNode::Tuple(elements) if compatible_db_types.contains(&DataTypeHint::Tuple) => { + Ok(Some(InnerDataTypeValidator::Tuple(elements))) + } + + DataTypeNode::Nullable(inner_type) + if compatible_db_types.contains(&DataTypeHint::Nullable) => + { + Ok(Some(InnerDataTypeValidator::Nullable(inner_type))) + } + + // LowCardinality is completely transparent on the client side + DataTypeNode::LowCardinality(inner_type) => { + validate_impl(inner_type, serde_type, compatible_db_types) + } + + DataTypeNode::Enum(EnumType::Enum8, values_map) + if compatible_db_types.contains(&DataTypeHint::Int8) => + { + Ok(Some(InnerDataTypeValidator::Enum(values_map))) + } + DataTypeNode::Enum(EnumType::Enum16, values_map) + if compatible_db_types.contains(&DataTypeHint::Int16) => + { + Ok(Some(InnerDataTypeValidator::Enum(values_map))) + } + + DataTypeNode::Variant(possible_types) => { + Ok(Some(InnerDataTypeValidator::Variant(possible_types))) + } + + DataTypeNode::AggregateFunction(_, _) => panic!("AggregateFunction is not supported yet"), + DataTypeNode::Int256 => panic!("Int256 is not supported yet"), + DataTypeNode::UInt256 => panic!("UInt256 is not supported yet"), + DataTypeNode::BFloat16 => panic!("BFloat16 is not supported yet"), + DataTypeNode::Dynamic => panic!("Dynamic is not supported yet"), + + _ => Err(Error::InvalidColumnDataType( + data_type.clone(), + serde_type, + compatible_db_types, + )), + } +} + +impl<'cursor> ValidateDataType for DataTypeValidator<'cursor> { + #[inline] + fn validate( + &mut self, + serde_type: &'static SerdeType, + compatible_db_types: &'static [DataTypeHint], + ) -> Result>> { + println!( + "validate call {}; compatible: {:?}, db types: {:?}", + serde_type, compatible_db_types, self.columns, + ); + match self.columns.split_first() { + None => Err(Error::TooManyStructFields), + Some((first, rest)) => { + self.columns = rest; + validate_impl(&first.data_type, serde_type, compatible_db_types) + } + } + } + + // FIXME: remove copy of a String and use &str instead; but lifetimes are tricky here + #[inline] + fn set_struct_name(&mut self, name: String) { + self.struct_name = Some(name); + } +} + +/// Which Serde data type (De)serializer used for the given type. +/// Displays into Rust types for convenience in errors reporting. +#[derive(Clone, Debug, PartialEq)] +#[non_exhaustive] +pub enum SerdeType { + Bool, + I8, + I16, + I32, + I64, + I128, + U8, + U16, + U32, + U64, + U128, + F32, + F64, + Char, + Str, + String, + Bytes, + ByteBuf, + Option, + Unit, + UnitStruct, + NewtypeStruct, + Seq, + Tuple, + TupleStruct, + Map, + Struct, + Enum, + Identifier, + IgnoredAny, +} + +impl Default for SerdeType { + fn default() -> Self { + SerdeType::Struct + } +} + +impl Display for SerdeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let type_name = match self { + SerdeType::Bool => "bool", + SerdeType::I8 => "i8", + SerdeType::I16 => "i16", + SerdeType::I32 => "i32", + SerdeType::I64 => "i64", + SerdeType::I128 => "i128", + SerdeType::U8 => "u8", + SerdeType::U16 => "u16", + SerdeType::U32 => "u32", + SerdeType::U64 => "u64", + SerdeType::U128 => "u128", + SerdeType::F32 => "f32", + SerdeType::F64 => "f64", + SerdeType::Char => "char", + SerdeType::Str => "&str", + SerdeType::String => "String", + SerdeType::Bytes => "&[u8]", + SerdeType::ByteBuf => "Vec", + SerdeType::Option => "Option", + SerdeType::Unit => "()", + SerdeType::UnitStruct => "unit struct", + SerdeType::NewtypeStruct => "newtype struct", + SerdeType::Seq => "Vec", + SerdeType::Tuple => "tuple", + SerdeType::TupleStruct => "tuple struct", + SerdeType::Map => "map", + SerdeType::Struct => "struct", + SerdeType::Enum => "enum", + SerdeType::Identifier => "identifier", + SerdeType::IgnoredAny => "ignored any", + }; + write!(f, "{}", type_name) + } +} diff --git a/tests/it/main.rs b/tests/it/main.rs index 93ebbff2..b868e988 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -65,7 +65,7 @@ mod ip; mod mock; mod nested; mod query; -mod rbwnat_smoke; +mod rbwnat; mod time; mod user_agent; mod uuid; diff --git a/tests/it/rbwnat_smoke.rs b/tests/it/rbwnat.rs similarity index 78% rename from tests/it/rbwnat_smoke.rs rename to tests/it/rbwnat.rs index ace27da8..a41059fb 100644 --- a/tests/it/rbwnat_smoke.rs +++ b/tests/it/rbwnat.rs @@ -1,9 +1,12 @@ use clickhouse::error::Error; +use clickhouse::sql::Identifier; use clickhouse::validation_mode::StructValidationMode; use clickhouse_derive::Row; +use clickhouse_rowbinary::data_types::{Column, DataTypeNode}; use clickhouse_rowbinary::parse_rbwnat_columns_header; -use clickhouse_rowbinary::types::{Column, DataTypeHint, DataTypeNode}; use serde::{Deserialize, Serialize}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use std::collections::HashMap; use std::str::FromStr; use time::format_description::well_known::Iso8601; use time::Month::{February, January}; @@ -52,69 +55,57 @@ async fn test_header_parsing() { Column { name: "CounterID".to_string(), data_type: DataTypeNode::UInt32, - type_hints: vec![DataTypeHint::UInt32] }, Column { name: "StartDate".to_string(), data_type: DataTypeNode::Date, - type_hints: vec![DataTypeHint::Date] }, Column { name: "Sign".to_string(), data_type: DataTypeNode::Int8, - type_hints: vec![DataTypeHint::Int8] }, Column { name: "IsNew".to_string(), data_type: DataTypeNode::UInt8, - type_hints: vec![DataTypeHint::UInt8] }, Column { name: "VisitID".to_string(), data_type: DataTypeNode::UInt64, - type_hints: vec![DataTypeHint::UInt64] }, Column { name: "UserID".to_string(), data_type: DataTypeNode::UInt64, - type_hints: vec![DataTypeHint::UInt64] }, Column { name: "Goals.ID".to_string(), data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)), - type_hints: vec![DataTypeHint::Array, DataTypeHint::UInt32] }, Column { name: "Goals.Serial".to_string(), data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)), - type_hints: vec![DataTypeHint::Array, DataTypeHint::UInt32] }, Column { name: "Goals.EventTime".to_string(), data_type: DataTypeNode::Array(Box::new(DataTypeNode::DateTime(None))), - type_hints: vec![DataTypeHint::Array, DataTypeHint::DateTime] }, Column { name: "Goals.Price".to_string(), data_type: DataTypeNode::Array(Box::new(DataTypeNode::Int64)), - type_hints: vec![DataTypeHint::Array, DataTypeHint::Int64] }, Column { name: "Goals.OrderID".to_string(), data_type: DataTypeNode::Array(Box::new(DataTypeNode::String)), - type_hints: vec![DataTypeHint::Array, DataTypeHint::String] }, Column { name: "Goals.CurrencyID".to_string(), data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)), - type_hints: vec![DataTypeHint::Array, DataTypeHint::UInt32] } ] ); } #[tokio::test] -async fn test_basic_types_deserialization() { +async fn test_basic_types() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { uint8_val: u8, @@ -229,7 +220,7 @@ async fn test_many_numbers() { } #[tokio::test] -async fn test_array_deserialization() { +async fn test_arrays() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { id: u16, @@ -271,20 +262,21 @@ async fn test_array_deserialization() { } #[tokio::test] -async fn test_multi_dimensional_array_deserialization() { +async fn test_maps() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { - three_dim_array: Vec>>, - id: u16, + map1: HashMap, + map2: HashMap>, } let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); let result = client .query( " - SELECT - [[[1.1, 2.2], [3.3, 4.4]], [], [[5.5, 6.6], [7.7, 8.8]]] :: Array(Array(Array(Float64))) AS three_dim_array, - 42 :: UInt16 AS id + SELECT + map('key1', 'value1', 'key2', 'value2') :: Map(String, String) AS m1, + map(42, map('foo', 100, 'bar', 200), + 144, map('qaz', 300, 'qux', 400)) :: Map(UInt16, Map(String, Int32)) AS m2 ", ) .fetch_one::() @@ -293,18 +285,119 @@ async fn test_multi_dimensional_array_deserialization() { assert_eq!( result.unwrap(), Data { - id: 42, - three_dim_array: vec![ - vec![vec![1.1, 2.2], vec![3.3, 4.4]], - vec![], - vec![vec![5.5, 6.6], vec![7.7, 8.8]] - ], + map1: vec![ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ] + .into_iter() + .collect(), + map2: vec![ + ( + 42, + vec![("foo".to_string(), 100), ("bar".to_string(), 200)] + .into_iter() + .collect() + ), + ( + 144, + vec![("qaz".to_string(), 300), ("qux".to_string(), 400)] + .into_iter() + .collect() + ) + ] + .into_iter() + .collect::>>(), } ); } +#[tokio::test] +async fn test_enum() { + #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] + #[repr(i8)] + enum MyEnum8 { + Winter = -128, + Spring = 0, + Summer = 100, + Autumn = 127, + } + + #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] + #[repr(i16)] + enum MyEnum16 { + North = -32768, + East = 0, + South = 144, + West = 32767, + } + + #[derive(Debug, PartialEq, Row, Serialize, Deserialize)] + struct Data { + id: u16, + enum8: MyEnum8, + enum16: MyEnum16, + } + + let table_name = "test_rbwnat_enum"; + + let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + client + .query( + " + CREATE OR REPLACE TABLE ? + ( + id UInt16, + enum8 Enum8 ('Winter' = -128, 'Spring' = 0, 'Summer' = 100, 'Autumn' = 127), + enum16 Enum16('North' = -32768, 'East' = 0, 'South' = 144, 'West' = 32767) + ) ENGINE MergeTree ORDER BY id + ", + ) + .bind(Identifier(table_name)) + .execute() + .await + .unwrap(); + + let expected = vec![ + Data { + id: 1, + + enum8: MyEnum8::Spring, + enum16: MyEnum16::East, + }, + Data { + id: 2, + enum8: MyEnum8::Autumn, + enum16: MyEnum16::North, + }, + Data { + id: 3, + enum8: MyEnum8::Winter, + enum16: MyEnum16::South, + }, + Data { + id: 4, + enum8: MyEnum8::Summer, + enum16: MyEnum16::West, + }, + ]; + + let mut insert = client.insert(table_name).unwrap(); + for row in &expected { + insert.write(row).await.unwrap() + } + insert.end().await.unwrap(); + + let result = client + .query("SELECT * FROM ? ORDER BY id ASC") + .bind(Identifier(table_name)) + .fetch_all::() + .await + .unwrap(); + + assert_eq!(result, expected); +} #[tokio::test] -async fn test_default_types_validation_nullable() { +async fn test_nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { n: Option, @@ -333,7 +426,7 @@ async fn test_default_types_validation_nullable() { #[tokio::test] #[cfg(feature = "time")] -async fn test_default_types_validation_custom_serde() { +async fn test_serde_with() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::datetime64::millis")] @@ -377,10 +470,12 @@ async fn test_too_many_struct_fields() { .await; assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - Error::DeserializeCallAfterEndOfStruct { .. } - )); + let err = result.unwrap_err(); + assert!( + matches!(err, Error::TooManyStructFields { .. }), + "{:?} should be an instance of TooManyStructFields", + err + ); } #[tokio::test] @@ -411,7 +506,7 @@ async fn test_serde_skip_deserializing() { #[tokio::test] #[cfg(feature = "time")] -async fn test_date_time_types() { +async fn test_date_and_time() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::date")] @@ -472,9 +567,9 @@ async fn test_date_time_types() { async fn test_ipv4_ipv6() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { - id: u16, - #[serde(with = "clickhouse::serde::ipv4")] - ipv4: std::net::Ipv4Addr, + // id: u16, + // #[serde(with = "clickhouse::serde::ipv4")] + // ipv4: std::net::Ipv4Addr, ipv6: std::net::Ipv6Addr, } @@ -483,8 +578,8 @@ async fn test_ipv4_ipv6() { .query( " SELECT - 42 :: UInt16 AS id, - '192.168.0.1' :: IPv4 AS ipv4, + -- 42 :: UInt16 AS id, + -- '192.168.0.1' :: IPv4 AS ipv4, '2001:db8:3333:4444:5555:6666:7777:8888' :: IPv6 AS ipv6 ", ) @@ -494,14 +589,15 @@ async fn test_ipv4_ipv6() { assert_eq!( result.unwrap(), vec![Data { - id: 42, - ipv4: std::net::Ipv4Addr::new(192, 168, 0, 1), + // id: 42, + // ipv4: std::net::Ipv4Addr::new(192, 168, 0, 1), ipv6: std::net::Ipv6Addr::from_str("2001:db8:3333:4444:5555:6666:7777:8888").unwrap(), }] ) } // FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! +// it is possible to use HashMap to deserialize the struct instead of Tuple visitor #[tokio::test] #[ignore] async fn test_different_struct_field_order() { From c20af7741a93a69d85cc6068b59de2e0be5a8bc4 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 21 May 2025 20:15:07 +0200 Subject: [PATCH 07/54] RBWNAT deserializer - validation, benches WIP --- benches/select_numbers.rs | 24 +++- docker-compose.yml | 3 +- rowbinary/src/data_types.rs | 6 +- src/cursors/row.rs | 23 ++-- src/error.rs | 14 -- src/lib.rs | 28 +++- src/query.rs | 14 +- src/rowbinary/de.rs | 113 ++++++++-------- src/rowbinary/mod.rs | 1 - src/rowbinary/validation.rs | 121 +++++++++-------- src/validation_mode.rs | 46 +++++-- tests/it/rbwnat.rs | 257 ++++++++++++++++++++++++++++-------- 12 files changed, 422 insertions(+), 228 deletions(-) diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index 869d6ba5..45d0b769 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -1,5 +1,6 @@ use serde::Deserialize; +use clickhouse::validation_mode::ValidationMode; use clickhouse::{Client, Compression, Row}; #[derive(Row, Deserialize)] @@ -7,18 +8,21 @@ struct Data { no: u64, } -async fn bench(name: &str, compression: Compression) { +async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) { let start = std::time::Instant::now(); - let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression)).await.unwrap(); + let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression, validation_mode)) + .await + .unwrap(); assert_eq!(sum, 124999999750000000); let elapsed = start.elapsed(); let throughput = dec_mbytes / elapsed.as_secs_f64(); - println!("{name:>8} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB"); + println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB"); } -async fn do_bench(compression: Compression) -> (u64, f64, f64) { +async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) { let client = Client::default() .with_compression(compression) + .with_validation_mode(validation_mode) .with_url("http://localhost:8123"); let mut cursor = client @@ -40,8 +44,14 @@ async fn do_bench(compression: Compression) -> (u64, f64, f64) { #[tokio::main] async fn main() { - println!("compress elapsed throughput received"); - bench("none", Compression::None).await; + println!("compress validation elapsed throughput received"); + bench("none", Compression::None, ValidationMode::Disabled).await; + bench("none", Compression::None, ValidationMode::First(1)).await; + bench("none", Compression::None, ValidationMode::Each).await; #[cfg(feature = "lz4")] - bench("lz4", Compression::Lz4).await; + { + bench("lz4", Compression::Lz4, ValidationMode::Disabled).await; + bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; + bench("lz4", Compression::Lz4, ValidationMode::Each).await; + } } diff --git a/docker-compose.yml b/docker-compose.yml index bfa26365..d3b99f0f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,8 @@ +name: clickhouse-rs services: clickhouse: image: 'clickhouse/clickhouse-server:${CLICKHOUSE_VERSION-24.10-alpine}' - container_name: 'clickhouse-rs-clickhouse-server' + container_name: clickhouse-rs-clickhouse-server ports: - '8123:8123' - '9000:9000' diff --git a/rowbinary/src/data_types.rs b/rowbinary/src/data_types.rs index 8fa204f0..6032d35a 100644 --- a/rowbinary/src/data_types.rs +++ b/rowbinary/src/data_types.rs @@ -99,7 +99,7 @@ pub enum DataTypeHint { Decimal(DecimalSize), String, - FixedString(usize), + FixedString, UUID, Date, @@ -147,7 +147,7 @@ impl Display for DataTypeHint { DataTypeHint::BFloat16 => write!(f, "BFloat16"), DataTypeHint::Decimal(size) => write!(f, "Decimal{}", size), DataTypeHint::String => write!(f, "String"), - DataTypeHint::FixedString(size) => write!(f, "FixedString({})", size), + DataTypeHint::FixedString => write!(f, "FixedString"), DataTypeHint::UUID => write!(f, "UUID"), DataTypeHint::Date => write!(f, "Date"), DataTypeHint::Date32 => write!(f, "Date32"), @@ -279,7 +279,7 @@ impl DataTypeNode { hints.push(DataTypeHint::Decimal(size.clone())); } DataTypeNode::String => hints.push(DataTypeHint::String), - DataTypeNode::FixedString(size) => hints.push(DataTypeHint::FixedString(*size)), + DataTypeNode::FixedString(_) => hints.push(DataTypeHint::FixedString), DataTypeNode::UUID => hints.push(DataTypeHint::UUID), DataTypeNode::Date => hints.push(DataTypeHint::Date), DataTypeNode::Date32 => hints.push(DataTypeHint::Date32), diff --git a/src/cursors/row.rs b/src/cursors/row.rs index e0da7cb9..4187cc50 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,4 +1,4 @@ -use crate::validation_mode::StructValidationMode; +use crate::validation_mode::ValidationMode; use crate::{ bytes_ext::BytesExt, cursors::RawCursor, @@ -16,21 +16,21 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, - validation_mode: StructValidationMode, + validation_mode: ValidationMode, columns: Option>, rows_emitted: u64, _marker: PhantomData, } impl RowCursor { - pub(crate) fn new(response: Response, format: StructValidationMode) -> Self { + pub(crate) fn new(response: Response, validation_mode: ValidationMode) -> Self { Self { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), columns: None, rows_emitted: 0, - validation_mode: format, + validation_mode, } } @@ -46,9 +46,9 @@ impl RowCursor { T: Deserialize<'b>, { let should_validate = match self.validation_mode { - StructValidationMode::Disabled => false, - StructValidationMode::EachRow => true, - StructValidationMode::FirstRow => self.rows_emitted == 0, + ValidationMode::Disabled => false, + ValidationMode::Each => true, + ValidationMode::First(n) => self.rows_emitted < (n as u64), }; loop { @@ -62,10 +62,15 @@ impl RowCursor { self.bytes.set_remaining(slice.len()); self.columns = Some(columns); let columns = self.columns.as_ref().unwrap(); - rowbinary::deserialize_from_and_validate(&mut slice, columns) + // usually, the header arrives as a separate first chunk + if self.bytes.remaining() > 0 { + rowbinary::deserialize_from_and_validate(&mut slice, columns) + } else { + Err(Error::NotEnoughData) + } } Some(columns) => { - rowbinary::deserialize_from_and_validate(&mut slice, &columns) + rowbinary::deserialize_from_and_validate(&mut slice, columns) } } } else { diff --git a/src/error.rs b/src/error.rs index cc1ab550..0377ea67 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,6 @@ //! Contains [`Error`] and corresponding [`Result`]. -use crate::rowbinary::SerdeType; -use clickhouse_rowbinary::data_types::{DataTypeHint, DataTypeNode}; use serde::{de, ser}; -use std::fmt::Display; use std::{error::Error as StdError, fmt, io, result, str::Utf8Error}; /// A result with a specified [`Error`] type. @@ -55,21 +52,10 @@ pub enum Error { unexpected_type: String, all_columns: String, }, - #[error("deserializing field: {0}; serde type: {1} expected to be deserialized as: {}", join_seq(.2))] - InvalidColumnDataType(DataTypeNode, &'static SerdeType, &'static [DataTypeHint]), - #[error("too many struct fields: trying to read more columns than expected")] - TooManyStructFields, #[error("{0}")] Other(BoxedError), } -fn join_seq(vec: &[T]) -> String { - vec.iter() - .map(|x| x.to_string()) - .collect::>() - .join(", ") -} - assert_impl_all!(Error: StdError, Send, Sync); impl From for Error { diff --git a/src/lib.rs b/src/lib.rs index dad24fc0..f842d12a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,11 +5,10 @@ #[macro_use] extern crate static_assertions; -use self::{error::Result, http_client::HttpClient}; +use self::{error::Result, http_client::HttpClient, validation_mode::ValidationMode}; use std::{collections::HashMap, fmt::Display, sync::Arc}; pub use self::{compression::Compression, row::Row}; -use crate::validation_mode::StructValidationMode; pub use clickhouse_derive::Row; pub mod error; @@ -49,7 +48,7 @@ pub struct Client { options: HashMap, headers: HashMap, products_info: Vec, - struct_validation_mode: StructValidationMode, + validation_mode: ValidationMode, } #[derive(Clone)] @@ -104,7 +103,7 @@ impl Client { options: HashMap::new(), headers: HashMap::new(), products_info: Vec::default(), - struct_validation_mode: StructValidationMode::default(), + validation_mode: ValidationMode::default(), } } @@ -298,8 +297,12 @@ impl Client { self } - pub fn with_struct_validation_mode(mut self, mode: StructValidationMode) -> Self { - self.struct_validation_mode = mode; + /// Specifies the struct validation mode that will be used when calling + /// [`Query::fetch`], [`Query::fetch_one`], [`Query::fetch_all`], + /// and [`Query::fetch_optional`] methods. + /// See [`ValidationMode`] for more details. + pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self { + self.validation_mode = mode; self } @@ -350,6 +353,7 @@ pub mod _priv { #[cfg(test)] mod client_tests { + use crate::validation_mode::ValidationMode; use crate::{Authentication, Client}; #[test] @@ -467,4 +471,16 @@ mod client_tests { .with_access_token("my_jwt") .with_password("secret"); } + + #[test] + fn it_sets_validation_mode() { + let client = Client::default(); + assert_eq!(client.validation_mode, ValidationMode::First(1)); + let client = client.with_validation_mode(ValidationMode::Each); + assert_eq!(client.validation_mode, ValidationMode::Each); + let client = client.with_validation_mode(ValidationMode::Disabled); + assert_eq!(client.validation_mode, ValidationMode::Disabled); + let client = client.with_validation_mode(ValidationMode::First(10)); + assert_eq!(client.validation_mode, ValidationMode::First(10)); + } } diff --git a/src/query.rs b/src/query.rs index 8066c74f..b9ccebde 100644 --- a/src/query.rs +++ b/src/query.rs @@ -16,8 +16,8 @@ use crate::{ const MAX_QUERY_LEN_TO_USE_GET: usize = 8192; pub use crate::cursors::{BytesCursor, RowCursor}; -use crate::validation_mode::StructValidationMode; use crate::headers::with_authentication; +use crate::validation_mode::ValidationMode; #[must_use] #[derive(Clone)] @@ -85,18 +85,16 @@ impl Query { /// # Ok(()) } /// ``` pub fn fetch(mut self) -> Result> { - let fetch_format = self.client.struct_validation_mode.clone(); + let validation_mode = self.client.validation_mode; self.sql.bind_fields::(); - self.sql.set_output_format(match fetch_format { - StructValidationMode::FirstRow | StructValidationMode::EachRow => { - "RowBinaryWithNamesAndTypes" - } - StructValidationMode::Disabled => "RowBinary", + self.sql.set_output_format(match validation_mode { + ValidationMode::First(_) | ValidationMode::Each => "RowBinaryWithNamesAndTypes", + ValidationMode::Disabled => "RowBinary", }); let response = self.do_execute(true)?; - Ok(RowCursor::new(response, fetch_format)) + Ok(RowCursor::new(response, validation_mode)) } /// Executes the query and returns just a single row. diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 033676d9..52b246ed 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -17,7 +17,7 @@ use std::{convert::TryFrom, mem, str}; /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even /// `(&[u8], &mut Option) -> Result`. pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data [u8]) -> Result { - println!("deserialize_from call"); + // println!("deserialize_from call"); let mut deserializer = RowBinaryDeserializer { input, @@ -37,9 +37,7 @@ pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data input, validator: DataTypeValidator::new(columns), }; - T::deserialize(&mut deserializer).inspect_err(|e| { - println!("deserialize_from_and_validate call failed: {:?}", e); - }) + T::deserialize(&mut deserializer) } /// A deserializer for the RowBinary(WithNamesAndTypes) format. @@ -79,7 +77,7 @@ macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr, $type_hints:expr) => { #[inline] fn $deser_method>(self, visitor: V) -> Result { - self.validator.validate($serde_type, $type_hints)?; + self.validator.validate($serde_type, $type_hints, 0)?; ensure_size(&mut self.input, mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -194,34 +192,30 @@ where #[inline] fn deserialize_any>(self, _: V) -> Result { - println!("deserialize_any call"); - Err(Error::DeserializeAnyNotSupported) } #[inline] fn deserialize_unit>(self, visitor: V) -> Result { - println!("deserialize_unit call"); - // TODO: revise this. visitor.visit_unit() } #[inline] fn deserialize_char>(self, _: V) -> Result { - println!("deserialize_char call"); - panic!("character types are unsupported: `char`"); } #[inline] fn deserialize_bool>(self, visitor: V) -> Result { - println!("deserialize_bool call"); - self.validator.validate( &SerdeType::Bool, - // TODO: shall we allow deserialization from integers? - &[DataTypeHint::Bool, DataTypeHint::Int8, DataTypeHint::UInt8], + &[ + DataTypeHint::Bool, + // it is possible to deserialize from UInt8 0 or 1 as Boolean + DataTypeHint::UInt8, + ], + 0, )?; ensure_size(&mut self.input, 1)?; match self.input.get_u8() { @@ -233,11 +227,11 @@ where #[inline] fn deserialize_str>(self, visitor: V) -> Result { - println!("deserialize_str call"); + // println!("deserialize_str call"); // TODO - which types to allow? self.validator - .validate(&SerdeType::String, &[DataTypeHint::String])?; + .validate(&SerdeType::String, &[DataTypeHint::String], 0)?; let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -246,11 +240,11 @@ where #[inline] fn deserialize_string>(self, visitor: V) -> Result { - println!("deserialize_string call"); + // println!("deserialize_string call"); // TODO - which types to allow? self.validator - .validate(&SerdeType::String, &[DataTypeHint::String])?; + .validate(&SerdeType::String, &[DataTypeHint::String], 0)?; let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -259,7 +253,7 @@ where #[inline] fn deserialize_bytes>(self, visitor: V) -> Result { - println!("deserialize_bytes call"); + // println!("deserialize_bytes call"); // TODO - which types to allow? let size = self.read_size()?; @@ -269,7 +263,7 @@ where #[inline] fn deserialize_byte_buf>(self, visitor: V) -> Result { - println!("deserialize_byte_buf call"); + // println!("deserialize_byte_buf call"); // TODO - which types to allow? let size = self.read_size()?; @@ -278,7 +272,7 @@ where #[inline] fn deserialize_identifier>(self, visitor: V) -> Result { - println!("deserialize_identifier call"); + // println!("deserialize_identifier call"); // TODO - which types to allow? self.deserialize_u8(visitor) @@ -291,7 +285,7 @@ where _variants: &'static [&'static str], visitor: V, ) -> Result { - println!("deserialize_enum call"); + // println!("deserialize_enum call"); struct RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> where @@ -365,13 +359,13 @@ where // FIXME self.validator - .validate(&SerdeType::Enum, &[DataTypeHint::Enum])?; + .validate(&SerdeType::Enum, &[DataTypeHint::Enum], 0)?; visitor.visit_enum(RowBinaryEnumAccess { deserializer: self }) } #[inline] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - println!("deserialize_tuple call, len {}", len); + // println!("deserialize_tuple call, len {}", len); struct RowBinaryTupleAccess<'de, 'cursor, 'data, Validator> where @@ -391,6 +385,7 @@ where where T: DeserializeSeed<'data>, { + // println!("Processing value, len: {}", self.len); if self.len > 0 { self.len -= 1; let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; @@ -405,27 +400,37 @@ where } } - let len = self.read_size()?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Seq, &[DataTypeHint::Array, DataTypeHint::IPv6])?; - visitor.visit_seq(RowBinaryTupleAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: inner_data_type_validator, - }, + let inner_data_type_validator = self.validator.validate( + &SerdeType::Tuple, + &[ + DataTypeHint::Tuple, + DataTypeHint::Array, + DataTypeHint::FixedString, + // FIXME: uncomment when there is a way to implement ReverseSeqAccess + // DataTypeHint::IPv4, + DataTypeHint::IPv6, + ], len, - }) + )?; + let mut new_self = RowBinaryDeserializer { + input: self.input, + validator: inner_data_type_validator, + }; + let access = RowBinaryTupleAccess { + deserializer: &mut new_self, + len, + }; + visitor.visit_seq(access) } #[inline] fn deserialize_option>(self, visitor: V) -> Result { - println!("deserialize_option call"); + // println!("deserialize_option call"); ensure_size(&mut self.input, 1)?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Option, &[DataTypeHint::Nullable])?; + let inner_data_type_validator = + self.validator + .validate(&SerdeType::Option, &[DataTypeHint::Nullable], 0)?; match self.input.get_u8() { 0 => visitor.visit_some(&mut RowBinaryDeserializer { input: self.input, @@ -438,7 +443,7 @@ where #[inline] fn deserialize_seq>(self, visitor: V) -> Result { - println!("deserialize_seq call"); + // println!("deserialize_seq call"); struct RowBinarySeqAccess<'de, 'cursor, 'data, Validator> where @@ -473,9 +478,9 @@ where } let len = self.read_size()?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Seq, &[DataTypeHint::Array])?; + let inner_data_type_validator = + self.validator + .validate(&SerdeType::Seq, &[DataTypeHint::Array], len)?; visitor.visit_seq(RowBinarySeqAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, @@ -487,7 +492,9 @@ where #[inline] fn deserialize_map>(self, visitor: V) -> Result { - println!("deserialize_map call"); + // println!( + // "deserialize_map call", + // ); struct RowBinaryMapAccess<'de, 'cursor, 'data, Validator> where @@ -528,9 +535,9 @@ where } let len = self.read_size()?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Map, &[DataTypeHint::Map])?; + let inner_data_type_validator = + self.validator + .validate(&SerdeType::Map, &[DataTypeHint::Map], len)?; visitor.visit_map(RowBinaryMapAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, @@ -548,7 +555,7 @@ where fields: &'static [&'static str], visitor: V, ) -> Result { - println!("deserialize_struct call - {}", name); + // println!("deserialize_struct: {} (fields: {:?})", name, fields,); // FIXME use &'_ str, fix lifetimes self.validator.set_struct_name(name.to_string()); @@ -596,11 +603,9 @@ where #[inline] fn deserialize_newtype_struct>( self, - name: &str, + _name: &str, visitor: V, ) -> Result { - println!("deserialize_newtype_struct call - {}", name); - // TODO - skip validation? visitor.visit_newtype_struct(self) } @@ -611,8 +616,6 @@ where name: &'static str, _visitor: V, ) -> Result { - println!("deserialize_unit_struct call"); - panic!("unit types are unsupported: `{name}`"); } @@ -623,15 +626,11 @@ where _len: usize, _visitor: V, ) -> Result { - println!("deserialize_tuple_struct call"); - panic!("tuple struct types are unsupported: `{name}`"); } #[inline] fn deserialize_ignored_any>(self, _visitor: V) -> Result { - println!("deserialize_ignored_any call"); - panic!("ignored types are unsupported"); } diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index 9107c391..6b864023 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,7 +1,6 @@ pub(crate) use de::deserialize_from; pub(crate) use de::deserialize_from_and_validate; pub(crate) use ser::serialize_into; -pub(crate) use validation::SerdeType; mod de; mod ser; diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index c718ab2b..c7c96b81 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,4 +1,4 @@ -use crate::error::{Error, Result}; +use crate::error::Result; use clickhouse_rowbinary::data_types::{Column, DataTypeHint, DataTypeNode, DecimalSize, EnumType}; use std::collections::HashMap; use std::fmt::Display; @@ -8,10 +8,14 @@ pub(crate) trait ValidateDataType: Sized { &mut self, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], + // TODO: currently used only for FixedString validation. + // Is there a better way, avoiding passing it? + seq_len: usize, ) -> Result>>; - fn set_struct_name(&mut self, name: String) -> (); + fn set_struct_name(&mut self, name: String); } +#[derive(Default)] pub(crate) struct DataTypeValidator<'cursor> { columns: &'cursor [Column], struct_name: Option, @@ -26,15 +30,6 @@ impl<'cursor> DataTypeValidator<'cursor> { } } -impl<'cursor> Default for DataTypeValidator<'cursor> { - fn default() -> Self { - Self { - columns: &[], - struct_name: None, - } - } -} - pub(crate) enum MapValidatorState { Key, Value, @@ -65,31 +60,33 @@ impl ValidateDataType for () { &mut self, _serde_type: &'static SerdeType, _compatible_db_types: &'static [DataTypeHint], + _len: usize, ) -> Result>> { Ok(None) } #[inline] - fn set_struct_name(&mut self, _name: String) { - () - } + fn set_struct_name(&mut self, _name: String) {} } impl<'cursor> ValidateDataType for Option> { + #[inline] fn validate( &mut self, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], + seq_len: usize, ) -> Result>> { match self { None => Ok(None), Some(InnerDataTypeValidator::Map(key_type, value_type, state)) => match state { MapValidatorState::Key => { - let result = validate_impl(key_type, serde_type, compatible_db_types); + let result = validate_impl(key_type, serde_type, compatible_db_types, seq_len); *state = MapValidatorState::Value; result } MapValidatorState::Value => { - let result = validate_impl(value_type, serde_type, compatible_db_types); + let result = + validate_impl(value_type, serde_type, compatible_db_types, seq_len); *state = MapValidatorState::Validated; result } @@ -97,11 +94,8 @@ impl<'cursor> ValidateDataType for Option> { }, Some(InnerDataTypeValidator::Array(inner_type, state)) => match state { ArrayValidatorState::Pending => { - println!( - "ArrayValidatorState::Pending; serde_type: {}; compatible_db_types: {:?}", - serde_type, compatible_db_types, - ); - let result = validate_impl(inner_type, serde_type, compatible_db_types); + let result = + validate_impl(inner_type, serde_type, compatible_db_types, seq_len); *state = ArrayValidatorState::Validated; result } @@ -110,30 +104,45 @@ impl<'cursor> ValidateDataType for Option> { ArrayValidatorState::Validated => Ok(None), }, Some(InnerDataTypeValidator::Nullable(inner_type)) => { - validate_impl(inner_type, serde_type, compatible_db_types) + validate_impl(inner_type, serde_type, compatible_db_types, 0) } Some(InnerDataTypeValidator::Tuple(elements_types)) => { match elements_types.split_first() { - None => Ok(None), Some((first, rest)) => { - let result = validate_impl(first, serde_type, compatible_db_types); *elements_types = rest; - result + validate_impl(first, serde_type, compatible_db_types, 0) } + None => panic!( + "Struct tries to deserialize {} as a tuple element, but there are no more allowed elements in the database schema", + serde_type, + ) } } Some(InnerDataTypeValidator::Variant(_possible_types)) => { - Ok(None) // TODO - check type index in the parsed types vec + todo!() // TODO - check type index in the parsed types vec } Some(InnerDataTypeValidator::Enum(_values_map)) => { - Ok(None) // TODO - check value correctness in the hashmap + todo!() // TODO - check value correctness in the hashmap } } } #[inline] fn set_struct_name(&mut self, _name: String) { - unreachable!("it should never be called for inner validators") + unreachable!("`set_struct_name` should never be called for inner validators") + } +} + +impl Drop for InnerDataTypeValidator<'_> { + fn drop(&mut self) { + if let InnerDataTypeValidator::Tuple(elements_types) = self { + if !elements_types.is_empty() { + panic!( + "Tuple was not fully deserialized, remaining elements: {:?}", + elements_types + ); + } + } } } @@ -142,11 +151,12 @@ fn validate_impl<'cursor>( data_type: &'cursor DataTypeNode, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], + seq_len: usize, ) -> Result>> { - println!( - "validate_impl call from Serde {}; compatible types: {:?}, db type: {:?}", - serde_type, compatible_db_types, data_type, - ); + // println!( + // "Validating data type: {:?} against serde type: {} with compatible db types: {:?}", + // data_type, serde_type, compatible_db_types + // ); // FIXME: multiple branches with similar patterns match data_type { DataTypeNode::Bool if compatible_db_types.contains(&DataTypeHint::Bool) => Ok(None), @@ -198,17 +208,28 @@ fn validate_impl<'cursor>( } DataTypeNode::FixedString(n) - if compatible_db_types.contains(&DataTypeHint::FixedString(*n)) => + if compatible_db_types.contains(&DataTypeHint::FixedString) && *n == seq_len => { Ok(None) } - // Deserialized as a sequence of 16 bytes + // FIXME: IPv4 from ClickHouse ends up reversed. + // Ideally, requires a ReversedSeqAccess implementation. Perhaps memoize IPv4 col index? + // IPv4 = [u8; 4] + // DataTypeNode::IPv4 if compatible_db_types.contains(&DataTypeHint::IPv4) => Ok(Some( + // InnerDataTypeValidator::Array(&DataTypeNode::UInt8, ArrayValidatorState::Pending(4)), + // )), + + // IPv6 = [u8; 16] DataTypeNode::IPv6 if compatible_db_types.contains(&DataTypeHint::Array) => Ok(Some( InnerDataTypeValidator::Array(&DataTypeNode::UInt8, ArrayValidatorState::Pending), )), - DataTypeNode::UUID => todo!(), + // UUID = [u64; 2] + DataTypeNode::UUID => Ok(Some(InnerDataTypeValidator::Tuple(&[ + DataTypeNode::UInt64, + DataTypeNode::UInt64, + ]))), DataTypeNode::Array(inner_type) if compatible_db_types.contains(&DataTypeHint::Array) => { Ok(Some(InnerDataTypeValidator::Array( @@ -239,7 +260,7 @@ fn validate_impl<'cursor>( // LowCardinality is completely transparent on the client side DataTypeNode::LowCardinality(inner_type) => { - validate_impl(inner_type, serde_type, compatible_db_types) + validate_impl(inner_type, serde_type, compatible_db_types, seq_len) } DataTypeNode::Enum(EnumType::Enum8, values_map) @@ -263,11 +284,10 @@ fn validate_impl<'cursor>( DataTypeNode::BFloat16 => panic!("BFloat16 is not supported yet"), DataTypeNode::Dynamic => panic!("Dynamic is not supported yet"), - _ => Err(Error::InvalidColumnDataType( - data_type.clone(), - serde_type, - compatible_db_types, - )), + _ => panic!( + "Database type is {}, but struct field is deserialized as {}, which is compatible only with {:?}", + data_type, serde_type, compatible_db_types + ), } } @@ -277,17 +297,14 @@ impl<'cursor> ValidateDataType for DataTypeValidator<'cursor> { &mut self, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], + len: usize, ) -> Result>> { - println!( - "validate call {}; compatible: {:?}, db types: {:?}", - serde_type, compatible_db_types, self.columns, - ); match self.columns.split_first() { - None => Err(Error::TooManyStructFields), Some((first, rest)) => { self.columns = rest; - validate_impl(&first.data_type, serde_type, compatible_db_types) + validate_impl(&first.data_type, serde_type, compatible_db_types, len) } + None => panic!("Struct has more fields than columns in the database schema"), } } @@ -301,8 +318,8 @@ impl<'cursor> ValidateDataType for DataTypeValidator<'cursor> { /// Which Serde data type (De)serializer used for the given type. /// Displays into Rust types for convenience in errors reporting. #[derive(Clone, Debug, PartialEq)] -#[non_exhaustive] -pub enum SerdeType { +#[allow(dead_code)] +pub(crate) enum SerdeType { Bool, I8, I16, @@ -335,12 +352,6 @@ pub enum SerdeType { IgnoredAny, } -impl Default for SerdeType { - fn default() -> Self { - SerdeType::Struct - } -} - impl Display for SerdeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let type_name = match self { diff --git a/src/validation_mode.rs b/src/validation_mode.rs index a161b057..0774b129 100644 --- a/src/validation_mode.rs +++ b/src/validation_mode.rs @@ -1,23 +1,49 @@ #[non_exhaustive] -#[derive(Clone)] -pub enum StructValidationMode { - FirstRow, - EachRow, +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +/// The preferred mode of validation for struct (de)serialization. +/// It also affects which format is used by the client when sending queries. +/// +/// - [`ValidationMode::First`] enables validation _only for the first `N` rows_ +/// emitted by a cursor. For the following rows, validation is skipped. +/// Format: `RowBinaryWithNamesAndTypes`. +/// - [`ValidationMode::Each`] enables validation _for all rows_ emitted by a cursor. +/// This is the slowest mode. Format: `RowBinaryWithNamesAndTypes`. +/// - [`ValidationMode::Disabled`] means that no validation will be performed. +/// At the same time, this is the fastest mode. Format: `RowBinary`. +/// +/// # Default +/// +/// By default, [`ValidationMode::First`] with value `1` is used, +/// meaning that only the first row will be validated against the database schema, +/// which is extracted from the `RowBinaryWithNamesAndTypes` format header. +/// It is done to minimize the performance impact of the validation, +/// while still providing reasonable safety guarantees by default. +/// +/// # Safety +/// +/// While it is expected that the default validation mode is sufficient for most use cases, +/// in certain corner case scenarios there still can be schema mismatches after the first rows, +/// e.g., when a field is `Nullable(T)`, and the first value is `NULL`. In that case, +/// consider increasing the number of rows in [`ValidationMode::First`], +/// or even using [`ValidationMode::Each`] instead. +pub enum ValidationMode { + First(usize), + Each, Disabled, } -impl Default for StructValidationMode { +impl Default for ValidationMode { fn default() -> Self { - Self::FirstRow + Self::First(1) } } -impl std::fmt::Display for StructValidationMode { +impl std::fmt::Display for ValidationMode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::FirstRow => write!(f, "FirstRow"), - Self::EachRow => write!(f, "EachRow"), - Self::Disabled => write!(f, "Disabled"), + Self::First(n) => f.pad(&format!("FirstN({})", n)), + Self::Each => f.pad("Each"), + Self::Disabled => f.pad("Disabled"), } } } diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index a41059fb..531f2d28 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -1,6 +1,6 @@ -use clickhouse::error::Error; +use crate::get_client; use clickhouse::sql::Identifier; -use clickhouse::validation_mode::StructValidationMode; +use clickhouse::validation_mode::ValidationMode; use clickhouse_derive::Row; use clickhouse_rowbinary::data_types::{Column, DataTypeNode}; use clickhouse_rowbinary::parse_rbwnat_columns_header; @@ -8,9 +8,6 @@ use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::collections::HashMap; use std::str::FromStr; -use time::format_description::well_known::Iso8601; -use time::Month::{February, January}; -use time::OffsetDateTime; #[tokio::test] async fn test_header_parsing() { @@ -123,7 +120,7 @@ async fn test_basic_types() { string_val: String, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query( " @@ -174,7 +171,7 @@ async fn test_several_simple_rows() { str: String, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query("SELECT number AS num, toString(number) AS str FROM system.numbers LIMIT 3") .fetch_all::() @@ -206,7 +203,7 @@ async fn test_many_numbers() { no: u64, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let mut cursor = client .query("SELECT number FROM system.numbers_mt LIMIT 2000") .fetch::() @@ -230,7 +227,7 @@ async fn test_arrays() { description: String, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query( " @@ -269,7 +266,7 @@ async fn test_maps() { map2: HashMap>, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query( " @@ -339,7 +336,7 @@ async fn test_enum() { let table_name = "test_rbwnat_enum"; - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = prepare_database!().with_validation_mode(ValidationMode::Each); client .query( " @@ -397,54 +394,36 @@ async fn test_enum() { } #[tokio::test] +#[should_panic] async fn test_nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { n: Option, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); - let result = client + let client = get_client().with_validation_mode(ValidationMode::Each); + let _ = client .query("SELECT true AS b, 144 :: Int32 AS n2") .fetch_one::() .await; - - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - Error::InvalidColumnDataType { .. } - )); - - // FIXME: lack of derive PartialEq for Error prevents proper assertion - // assert_eq!(result, Error::DataTypeMismatch { - // column_name: "n".to_string(), - // expected_type: "Nullable".to_string(), - // actual_type: "Bool".to_string(), - // columns: vec![...], - // }); } #[tokio::test] +#[should_panic] #[cfg(feature = "time")] async fn test_serde_with() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::datetime64::millis")] - n1: OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 + n1: time::OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); - let result = client + let client = get_client().with_validation_mode(ValidationMode::Each); + let _ = client .query("SELECT 42 :: UInt32 AS n1, 144 :: Int32 AS n2") .fetch_one::() .await; - assert!(result.is_err()); - assert!(matches!( - result.unwrap_err(), - Error::InvalidColumnDataType { .. } - )); - // FIXME: lack of derive PartialEq for Error prevents proper assertion // assert_eq!(result, Error::DataTypeMismatch { // column_name: "n1".to_string(), @@ -455,6 +434,7 @@ async fn test_serde_with() { } #[tokio::test] +#[should_panic] async fn test_too_many_struct_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { @@ -463,19 +443,11 @@ async fn test_too_many_struct_fields() { c: u32, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); - let result = client + let client = get_client().with_validation_mode(ValidationMode::Each); + let _ = client .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b") .fetch_one::() .await; - - assert!(result.is_err()); - let err = result.unwrap_err(); - assert!( - matches!(err, Error::TooManyStructFields { .. }), - "{:?} should be an instance of TooManyStructFields", - err - ); } #[tokio::test] @@ -488,7 +460,7 @@ async fn test_serde_skip_deserializing() { c: u32, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS c") .fetch_one::() @@ -507,6 +479,10 @@ async fn test_serde_skip_deserializing() { #[tokio::test] #[cfg(feature = "time")] async fn test_date_and_time() { + use time::format_description::well_known::Iso8601; + use time::Month::{February, January}; + use time::OffsetDateTime; + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::date")] @@ -525,7 +501,7 @@ async fn test_date_and_time() { date_time64_9: OffsetDateTime, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query( " @@ -563,23 +539,54 @@ async fn test_date_and_time() { ); } +#[tokio::test] +#[cfg(feature = "uuid")] +async fn test_uuid() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u16, + #[serde(with = "clickhouse::serde::uuid")] + uuid: uuid::Uuid, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT + 42 :: UInt16 AS id, + '550e8400-e29b-41d4-a716-446655440000' :: UUID AS uuid + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + uuid: uuid::Uuid::from_str("550e8400-e29b-41d4-a716-446655440000").unwrap(), + } + ); +} + #[tokio::test] async fn test_ipv4_ipv6() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { - // id: u16, - // #[serde(with = "clickhouse::serde::ipv4")] - // ipv4: std::net::Ipv4Addr, + id: u16, + #[serde(with = "clickhouse::serde::ipv4")] + ipv4: std::net::Ipv4Addr, ipv6: std::net::Ipv6Addr, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query( " SELECT - -- 42 :: UInt16 AS id, - -- '192.168.0.1' :: IPv4 AS ipv4, + 42 :: UInt16 AS id, + '192.168.0.1' :: IPv4 AS ipv4, '2001:db8:3333:4444:5555:6666:7777:8888' :: IPv6 AS ipv6 ", ) @@ -589,13 +596,149 @@ async fn test_ipv4_ipv6() { assert_eq!( result.unwrap(), vec![Data { - // id: 42, - // ipv4: std::net::Ipv4Addr::new(192, 168, 0, 1), + id: 42, + ipv4: std::net::Ipv4Addr::new(192, 168, 0, 1), ipv6: std::net::Ipv6Addr::from_str("2001:db8:3333:4444:5555:6666:7777:8888").unwrap(), }] ) } +#[tokio::test] +async fn test_fixed_str() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: [u8; 4], + b: [u8; 3], + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query("SELECT '1234' :: FixedString(4) AS a, '777' :: FixedString(3) AS b") + .fetch_one::() + .await; + + let data = result.unwrap(); + assert_eq!(String::from_utf8_lossy(&data.a), "1234",); + assert_eq!(String::from_utf8_lossy(&data.b), "777",); +} + +#[tokio::test] +#[should_panic] +async fn test_fixed_str_too_long() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: [u8; 4], + b: [u8; 3], + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let _ = client + .query("SELECT '12345' :: FixedString(5) AS a, '777' :: FixedString(3) AS b") + .fetch_one::() + .await; +} + +#[tokio::test] +async fn test_tuple() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String), + b: (i128, HashMap), + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + a: (42, "foo".to_string()), + b: (144, vec![(255, "bar".to_string())].into_iter().collect()), + } + ); +} + +#[tokio::test] +#[should_panic] +async fn test_tuple_invalid_definition() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String), + b: (i128, HashMap), + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + + // Map key is UInt64 instead of UInt16 requested in the struct + let _ = client + .query( + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt64, String)) AS b + ", + ) + .fetch_one::() + .await; +} + +#[tokio::test] +#[should_panic] +async fn test_tuple_too_many_elements_in_the_schema() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String), + b: (i128, HashMap), + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + + // too many elements in the db type definition + let _ = client + .query( + " + SELECT + (42, 'foo', true) :: Tuple(UInt32, String, Bool) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + ", + ) + .fetch_one::() + .await; +} + +#[tokio::test] +#[should_panic] +async fn test_tuple_too_many_elements_in_the_struct() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String, bool), + b: (i128, HashMap), + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + + // too many elements in the struct enum + let _ = client + .query( + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + ", + ) + .fetch_one::() + .await; +} + // FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! // it is possible to use HashMap to deserialize the struct instead of Tuple visitor #[tokio::test] @@ -607,7 +750,7 @@ async fn test_different_struct_field_order() { a: String, } - let client = prepare_database!().with_struct_validation_mode(StructValidationMode::EachRow); + let client = get_client().with_validation_mode(ValidationMode::Each); let result = client .query("SELECT 'foo' AS a, 'bar' :: String AS c") .fetch_one::() From c4a608ec282ff11548f76d6c1411a9a27aa9bdfe Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 22 May 2025 17:58:40 +0200 Subject: [PATCH 08/54] RBWNAT deserializer - improve performance --- benches/select_numbers.rs | 12 +- src/cursors/row.rs | 70 ++++----- src/error.rs | 11 +- src/lib.rs | 2 - src/query.rs | 6 +- src/rowbinary/de.rs | 290 +++++++++++++++--------------------- src/rowbinary/utils.rs | 1 + src/rowbinary/validation.rs | 207 +++++++++++++++---------- src/validation_mode.rs | 4 - tests/it/rbwnat.rs | 92 +++++++++++- 10 files changed, 378 insertions(+), 317 deletions(-) diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index 45d0b769..b05bd8d3 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -45,13 +45,11 @@ async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> #[tokio::main] async fn main() { println!("compress validation elapsed throughput received"); - bench("none", Compression::None, ValidationMode::Disabled).await; bench("none", Compression::None, ValidationMode::First(1)).await; bench("none", Compression::None, ValidationMode::Each).await; - #[cfg(feature = "lz4")] - { - bench("lz4", Compression::Lz4, ValidationMode::Disabled).await; - bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; - bench("lz4", Compression::Lz4, ValidationMode::Each).await; - } + // #[cfg(feature = "lz4")] + // { + // bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; + // bench("lz4", Compression::Lz4, ValidationMode::Each).await; + // } } diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 4187cc50..8398a898 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -16,9 +16,8 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, - validation_mode: ValidationMode, + rows_to_check: u64, columns: Option>, - rows_emitted: u64, _marker: PhantomData, } @@ -28,9 +27,11 @@ impl RowCursor { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), + rows_to_check: match validation_mode { + ValidationMode::First(n) => n as u64, + ValidationMode::Each => u64::MAX, + }, columns: None, - rows_emitted: 0, - validation_mode, } } @@ -45,42 +46,26 @@ impl RowCursor { where T: Deserialize<'b>, { - let should_validate = match self.validation_mode { - ValidationMode::Disabled => false, - ValidationMode::Each => true, - ValidationMode::First(n) => self.rows_emitted < (n as u64), - }; - loop { if self.bytes.remaining() > 0 { let mut slice = super::workaround_51132(self.bytes.slice()); - let deserialize_result = if should_validate { - match &self.columns { - // TODO: can it be moved to `new` instead? - None => { - let columns = parse_rbwnat_columns_header(&mut slice)?; - self.bytes.set_remaining(slice.len()); - self.columns = Some(columns); - let columns = self.columns.as_ref().unwrap(); - // usually, the header arrives as a separate first chunk - if self.bytes.remaining() > 0 { - rowbinary::deserialize_from_and_validate(&mut slice, columns) - } else { - Err(Error::NotEnoughData) - } - } - Some(columns) => { - rowbinary::deserialize_from_and_validate(&mut slice, columns) - } + let deserialize_result = match &self.columns { + None => self.extract_columns_and_deserialize_from(slice), + Some(columns) if self.rows_to_check > 0 => { + rowbinary::deserialize_from_and_validate(&mut slice, columns) + } + Some(_) => { + // Schema is validated already, skipping for better performance + rowbinary::deserialize_from(&mut slice) } - } else { - rowbinary::deserialize_from(&mut slice) }; match deserialize_result { Ok(value) => { self.bytes.set_remaining(slice.len()); - self.rows_emitted += 1; + if self.rows_to_check > 0 { + self.rows_to_check -= 1; + } return Ok(Some(value)); } Err(Error::NotEnoughData) => {} @@ -116,9 +101,24 @@ impl RowCursor { self.raw.decoded_bytes() } - /// Returns the number of rows emitted via [`Self::next`] since the cursor was created. - #[inline] - pub fn rows_emitted(&self) -> u64 { - self.rows_emitted + #[cold] + #[inline(never)] + fn extract_columns_and_deserialize_from<'a, 'b: 'a>( + &'a mut self, + mut slice: &'b [u8], + ) -> Result + where + T: Deserialize<'b>, + { + let columns = parse_rbwnat_columns_header(&mut slice)?; + self.bytes.set_remaining(slice.len()); + self.columns = Some(columns); + let columns = self.columns.as_ref().unwrap(); + // usually, the header arrives as a separate first chunk + if self.bytes.remaining() > 0 { + rowbinary::deserialize_from_and_validate(&mut slice, columns) + } else { + Err(Error::NotEnoughData) + } } } diff --git a/src/error.rs b/src/error.rs index 0377ea67..7eacb759 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,15 +43,10 @@ pub enum Error { TimedOut, #[error("unsupported: {0}")] Unsupported(String), + #[error("error while deserializing data: {0}")] + DeserializationError(String), #[error("error while parsing data from the response: {0}")] - ParserError(BoxedError), - #[error("struct mismatches the database definition; field {field_name} has unexpected type {unexpected_type}; allowed types for {field_name}: {allowed_types}; database columns: {all_columns:?}")] - DataTypeMismatch { - field_name: String, - allowed_types: String, - unexpected_type: String, - all_columns: String, - }, + ParserError(#[source] BoxedError), #[error("{0}")] Other(BoxedError), } diff --git a/src/lib.rs b/src/lib.rs index f842d12a..4641ce8f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -478,8 +478,6 @@ mod client_tests { assert_eq!(client.validation_mode, ValidationMode::First(1)); let client = client.with_validation_mode(ValidationMode::Each); assert_eq!(client.validation_mode, ValidationMode::Each); - let client = client.with_validation_mode(ValidationMode::Disabled); - assert_eq!(client.validation_mode, ValidationMode::Disabled); let client = client.with_validation_mode(ValidationMode::First(10)); assert_eq!(client.validation_mode, ValidationMode::First(10)); } diff --git a/src/query.rs b/src/query.rs index b9ccebde..2a1036fa 100644 --- a/src/query.rs +++ b/src/query.rs @@ -17,7 +17,6 @@ const MAX_QUERY_LEN_TO_USE_GET: usize = 8192; pub use crate::cursors::{BytesCursor, RowCursor}; use crate::headers::with_authentication; -use crate::validation_mode::ValidationMode; #[must_use] #[derive(Clone)] @@ -88,10 +87,7 @@ impl Query { let validation_mode = self.client.validation_mode; self.sql.bind_fields::(); - self.sql.set_output_format(match validation_mode { - ValidationMode::First(_) | ValidationMode::Each => "RowBinaryWithNamesAndTypes", - ValidationMode::Disabled => "RowBinary", - }); + self.sql.set_output_format("RowBinaryWithNamesAndTypes"); let response = self.do_execute(true)?; Ok(RowCursor::new(response, validation_mode)) diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 52b246ed..54eb4e7c 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -19,10 +19,7 @@ use std::{convert::TryFrom, mem, str}; pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data [u8]) -> Result { // println!("deserialize_from call"); - let mut deserializer = RowBinaryDeserializer { - input, - validator: (), - }; + let mut deserializer = RowBinaryDeserializer::new(input, ()); T::deserialize(&mut deserializer) } @@ -33,10 +30,7 @@ pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data input: &mut &'data [u8], columns: &'cursor [Column], ) -> Result { - let mut deserializer = RowBinaryDeserializer { - input, - validator: DataTypeValidator::new(columns), - }; + let mut deserializer = RowBinaryDeserializer::new(input, DataTypeValidator::new(columns)); T::deserialize(&mut deserializer) } @@ -51,22 +45,30 @@ where pub(crate) input: &'cursor mut &'data [u8], } -impl<'data, Validator> RowBinaryDeserializer<'_, 'data, Validator> +impl<'cursor, 'data, Validator> RowBinaryDeserializer<'cursor, 'data, Validator> where Validator: ValidateDataType, { - pub(crate) fn read_vec(&mut self, size: usize) -> Result> { + #[inline] + fn new(input: &'cursor mut &'data [u8], validator: Validator) -> Self { + Self { input, validator } + } + + #[inline] + fn read_vec(&mut self, size: usize) -> Result> { Ok(self.read_slice(size)?.to_vec()) } - pub(crate) fn read_slice(&mut self, size: usize) -> Result<&'data [u8]> { + #[inline] + fn read_slice(&mut self, size: usize) -> Result<&'data [u8]> { ensure_size(&mut self.input, size)?; let slice = &self.input[..size]; self.input.advance(size); Ok(slice) } - pub(crate) fn read_size(&mut self) -> Result { + #[inline] + fn read_size(&mut self) -> Result { let size = get_unsigned_leb128(&mut self.input)?; // TODO: what about another error? usize::try_from(size).map_err(|_| Error::NotEnoughData) @@ -75,9 +77,9 @@ where macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr, $type_hints:expr) => { - #[inline] + #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { - self.validator.validate($serde_type, $type_hints, 0)?; + self.validator.validate($serde_type, $type_hints)?; ensure_size(&mut self.input, mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -91,23 +93,28 @@ where { type Error = Error; - impl_num!( - i8, - deserialize_i8, - visit_i8, - get_i8, - &SerdeType::I8, - // TODO: shall we allow deserialization from boolean? - &[DataTypeHint::Int8, DataTypeHint::Bool] - ); - impl_num!( - i16, - deserialize_i16, - visit_i16, - get_i16_le, - &SerdeType::I16, - &[DataTypeHint::Int16] - ); + #[inline(always)] + fn deserialize_i8>(self, visitor: V) -> Result { + let mut maybe_enum_validator = self + .validator + .validate(&SerdeType::I8, &[DataTypeHint::Int8, DataTypeHint::Bool])?; + ensure_size(&mut self.input, size_of::())?; + let value = self.input.get_i8(); + maybe_enum_validator.validate_enum8(value); + visitor.visit_i8(value) + } + + #[inline(always)] + fn deserialize_i16>(self, visitor: V) -> Result { + let mut maybe_enum_validator = self + .validator + .validate(&SerdeType::I16, &[DataTypeHint::Int16])?; + ensure_size(&mut self.input, size_of::())?; + let value = self.input.get_i16_le(); + maybe_enum_validator.validate_enum16(value); + visitor.visit_i16(value) + } + impl_num!( i32, deserialize_i32, @@ -190,23 +197,23 @@ where &[DataTypeHint::Float64] ); - #[inline] + #[inline(always)] fn deserialize_any>(self, _: V) -> Result { Err(Error::DeserializeAnyNotSupported) } - #[inline] + #[inline(always)] fn deserialize_unit>(self, visitor: V) -> Result { // TODO: revise this. visitor.visit_unit() } - #[inline] + #[inline(always)] fn deserialize_char>(self, _: V) -> Result { panic!("character types are unsupported: `char`"); } - #[inline] + #[inline(always)] fn deserialize_bool>(self, visitor: V) -> Result { self.validator.validate( &SerdeType::Bool, @@ -215,7 +222,6 @@ where // it is possible to deserialize from UInt8 0 or 1 as Boolean DataTypeHint::UInt8, ], - 0, )?; ensure_size(&mut self.input, 1)?; match self.input.get_u8() { @@ -225,33 +231,33 @@ where } } - #[inline] + #[inline(always)] fn deserialize_str>(self, visitor: V) -> Result { // println!("deserialize_str call"); // TODO - which types to allow? self.validator - .validate(&SerdeType::String, &[DataTypeHint::String], 0)?; + .validate(&SerdeType::String, &[DataTypeHint::String])?; let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; visitor.visit_borrowed_str(str) } - #[inline] + #[inline(always)] fn deserialize_string>(self, visitor: V) -> Result { // println!("deserialize_string call"); // TODO - which types to allow? self.validator - .validate(&SerdeType::String, &[DataTypeHint::String], 0)?; + .validate(&SerdeType::String, &[DataTypeHint::String])?; let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; visitor.visit_string(string) } - #[inline] + #[inline(always)] fn deserialize_bytes>(self, visitor: V) -> Result { // println!("deserialize_bytes call"); @@ -261,7 +267,7 @@ where visitor.visit_borrowed_bytes(slice) } - #[inline] + #[inline(always)] fn deserialize_byte_buf>(self, visitor: V) -> Result { // println!("deserialize_byte_buf call"); @@ -270,7 +276,7 @@ where visitor.visit_byte_buf(self.read_vec(size)?) } - #[inline] + #[inline(always)] fn deserialize_identifier>(self, visitor: V) -> Result { // println!("deserialize_identifier call"); @@ -278,7 +284,7 @@ where self.deserialize_u8(visitor) } - #[inline] + #[inline(always)] fn deserialize_enum>( self, _name: &'static str, @@ -357,50 +363,23 @@ where } } - // FIXME - self.validator - .validate(&SerdeType::Enum, &[DataTypeHint::Enum], 0)?; - visitor.visit_enum(RowBinaryEnumAccess { deserializer: self }) + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Enum, &[DataTypeHint::Variant])?; + let mut new_self = RowBinaryDeserializer { + input: self.input, + validator: inner_data_type_validator, + }; + visitor.visit_enum(RowBinaryEnumAccess { + deserializer: &mut new_self, + }) } - #[inline] + #[inline(always)] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { // println!("deserialize_tuple call, len {}", len); - struct RowBinaryTupleAccess<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, - len: usize, - } - - impl<'data, Validator> SeqAccess<'data> for RowBinaryTupleAccess<'_, '_, 'data, Validator> - where - Validator: ValidateDataType, - { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'data>, - { - // println!("Processing value, len: {}", self.len); - if self.len > 0 { - self.len -= 1; - let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; - Ok(Some(value)) - } else { - Ok(None) - } - } - - fn size_hint(&self) -> Option { - Some(self.len) - } - } - - let inner_data_type_validator = self.validator.validate( + let mut inner_data_type_validator = self.validator.validate( &SerdeType::Tuple, &[ DataTypeHint::Tuple, @@ -410,27 +389,27 @@ where // DataTypeHint::IPv4, DataTypeHint::IPv6, ], - len, )?; + inner_data_type_validator.validate_fixed_string(len); let mut new_self = RowBinaryDeserializer { input: self.input, validator: inner_data_type_validator, }; - let access = RowBinaryTupleAccess { + let access = RowBinarySeqAccess { deserializer: &mut new_self, len, }; visitor.visit_seq(access) } - #[inline] + #[inline(always)] fn deserialize_option>(self, visitor: V) -> Result { // println!("deserialize_option call"); ensure_size(&mut self.input, 1)?; - let inner_data_type_validator = - self.validator - .validate(&SerdeType::Option, &[DataTypeHint::Nullable], 0)?; + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Option, &[DataTypeHint::Nullable])?; match self.input.get_u8() { 0 => visitor.visit_some(&mut RowBinaryDeserializer { input: self.input, @@ -441,46 +420,14 @@ where } } - #[inline] + #[inline(always)] fn deserialize_seq>(self, visitor: V) -> Result { // println!("deserialize_seq call"); - struct RowBinarySeqAccess<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, - len: usize, - } - - impl<'data, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, Validator> - where - Validator: ValidateDataType, - { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'data>, - { - if self.len > 0 { - self.len -= 1; - let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; - Ok(Some(value)) - } else { - Ok(None) - } - } - - fn size_hint(&self) -> Option { - Some(self.len) - } - } - let len = self.read_size()?; - let inner_data_type_validator = - self.validator - .validate(&SerdeType::Seq, &[DataTypeHint::Array], len)?; + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Seq, &[DataTypeHint::Array])?; visitor.visit_seq(RowBinarySeqAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, @@ -490,7 +437,7 @@ where }) } - #[inline] + #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { // println!( // "deserialize_map call", @@ -535,9 +482,9 @@ where } let len = self.read_size()?; - let inner_data_type_validator = - self.validator - .validate(&SerdeType::Map, &[DataTypeHint::Map], len)?; + let inner_data_type_validator = self + .validator + .validate(&SerdeType::Map, &[DataTypeHint::Map])?; visitor.visit_map(RowBinaryMapAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, @@ -548,59 +495,22 @@ where }) } - #[inline] + #[inline(always)] fn deserialize_struct>( self, - name: &str, + _name: &str, fields: &'static [&'static str], visitor: V, ) -> Result { // println!("deserialize_struct: {} (fields: {:?})", name, fields,); - // FIXME use &'_ str, fix lifetimes - self.validator.set_struct_name(name.to_string()); - - // TODO: it should also support using HashMap to deserialize - // currently just copy-pasted to prevent former `deserialize_tuple` delegation - struct RowBinaryStructAccess<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, - len: usize, - } - - impl<'data, Validator> SeqAccess<'data> for RowBinaryStructAccess<'_, '_, 'data, Validator> - where - Validator: ValidateDataType, - { - type Error = Error; - - fn next_element_seed(&mut self, seed: T) -> Result> - where - T: DeserializeSeed<'data>, - { - if self.len > 0 { - self.len -= 1; - let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; - Ok(Some(value)) - } else { - Ok(None) - } - } - - fn size_hint(&self) -> Option { - Some(self.len) - } - } - - visitor.visit_seq(RowBinaryStructAccess { + visitor.visit_seq(RowBinarySeqAccess { deserializer: self, len: fields.len(), }) } - #[inline] + #[inline(always)] fn deserialize_newtype_struct>( self, _name: &str, @@ -610,7 +520,7 @@ where visitor.visit_newtype_struct(self) } - #[inline] + #[inline(always)] fn deserialize_unit_struct>( self, name: &'static str, @@ -619,7 +529,7 @@ where panic!("unit types are unsupported: `{name}`"); } - #[inline] + #[inline(always)] fn deserialize_tuple_struct>( self, name: &'static str, @@ -629,13 +539,45 @@ where panic!("tuple struct types are unsupported: `{name}`"); } - #[inline] + #[inline(always)] fn deserialize_ignored_any>(self, _visitor: V) -> Result { panic!("ignored types are unsupported"); } - #[inline] + #[inline(always)] fn is_human_readable(&self) -> bool { false } } + +struct RowBinarySeqAccess<'de, 'cursor, 'data, Validator> +where + Validator: ValidateDataType, +{ + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + len: usize, +} + +impl<'data, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, Validator> +where + Validator: ValidateDataType, +{ + type Error = Error; + + fn next_element_seed(&mut self, seed: T) -> Result> + where + T: DeserializeSeed<'data>, + { + if self.len > 0 { + self.len -= 1; + let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?; + Ok(Some(value)) + } else { + Ok(None) + } + } + + fn size_hint(&self) -> Option { + Some(self.len) + } +} diff --git a/src/rowbinary/utils.rs b/src/rowbinary/utils.rs index fc2db7e9..3e9a3dc7 100644 --- a/src/rowbinary/utils.rs +++ b/src/rowbinary/utils.rs @@ -10,6 +10,7 @@ pub(crate) fn ensure_size(buffer: impl Buf, size: usize) -> crate::error::Result } } +#[inline] pub(crate) fn get_unsigned_leb128(mut buffer: impl Buf) -> crate::error::Result { let mut value = 0u64; let mut shift = 0; diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index c7c96b81..cff6a285 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -8,41 +8,41 @@ pub(crate) trait ValidateDataType: Sized { &mut self, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], - // TODO: currently used only for FixedString validation. - // Is there a better way, avoiding passing it? - seq_len: usize, ) -> Result>>; - fn set_struct_name(&mut self, name: String); + fn validate_enum8(&mut self, value: i8); + fn validate_enum16(&mut self, value: i16); + fn validate_fixed_string(&mut self, len: usize); } #[derive(Default)] pub(crate) struct DataTypeValidator<'cursor> { columns: &'cursor [Column], - struct_name: Option, } impl<'cursor> DataTypeValidator<'cursor> { + #[inline(always)] pub(crate) fn new(columns: &'cursor [Column]) -> Self { - Self { - columns, - struct_name: None, - } + Self { columns } } } +#[derive(Debug)] pub(crate) enum MapValidatorState { Key, Value, Validated, } +#[derive(Debug)] pub(crate) enum ArrayValidatorState { Pending, Validated, } +#[derive(Debug)] pub(crate) enum InnerDataTypeValidator<'cursor> { Array(&'cursor DataTypeNode, ArrayValidatorState), + FixedString(usize), Map( &'cursor DataTypeNode, &'cursor DataTypeNode, @@ -55,17 +55,24 @@ pub(crate) enum InnerDataTypeValidator<'cursor> { } impl ValidateDataType for () { - #[inline] + #[inline(always)] fn validate( &mut self, _serde_type: &'static SerdeType, _compatible_db_types: &'static [DataTypeHint], - _len: usize, + // _len: usize, ) -> Result>> { Ok(None) } - #[inline] - fn set_struct_name(&mut self, _name: String) {} + + #[inline(always)] + fn validate_enum8(&mut self, _enum_value: i8) {} + + #[inline(always)] + fn validate_enum16(&mut self, _enum_value: i16) {} + + #[inline(always)] + fn validate_fixed_string(&mut self, _len: usize) {} } impl<'cursor> ValidateDataType for Option> { @@ -74,19 +81,20 @@ impl<'cursor> ValidateDataType for Option> { &mut self, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], - seq_len: usize, + // seq_len: usize, ) -> Result>> { + // println!("Validating inner data type: {:?} against serde type: {} with compatible db types: {:?}", + // self, serde_type, compatible_db_types); match self { None => Ok(None), Some(InnerDataTypeValidator::Map(key_type, value_type, state)) => match state { MapValidatorState::Key => { - let result = validate_impl(key_type, serde_type, compatible_db_types, seq_len); + let result = validate_impl(key_type, serde_type, compatible_db_types); *state = MapValidatorState::Value; result } MapValidatorState::Value => { - let result = - validate_impl(value_type, serde_type, compatible_db_types, seq_len); + let result = validate_impl(value_type, serde_type, compatible_db_types); *state = MapValidatorState::Validated; result } @@ -94,8 +102,7 @@ impl<'cursor> ValidateDataType for Option> { }, Some(InnerDataTypeValidator::Array(inner_type, state)) => match state { ArrayValidatorState::Pending => { - let result = - validate_impl(inner_type, serde_type, compatible_db_types, seq_len); + let result = validate_impl(inner_type, serde_type, compatible_db_types); *state = ArrayValidatorState::Validated; result } @@ -104,13 +111,13 @@ impl<'cursor> ValidateDataType for Option> { ArrayValidatorState::Validated => Ok(None), }, Some(InnerDataTypeValidator::Nullable(inner_type)) => { - validate_impl(inner_type, serde_type, compatible_db_types, 0) + validate_impl(inner_type, serde_type, compatible_db_types) } Some(InnerDataTypeValidator::Tuple(elements_types)) => { match elements_types.split_first() { Some((first, rest)) => { *elements_types = rest; - validate_impl(first, serde_type, compatible_db_types, 0) + validate_impl(first, serde_type, compatible_db_types) } None => panic!( "Struct tries to deserialize {} as a tuple element, but there are no more allowed elements in the database schema", @@ -118,8 +125,11 @@ impl<'cursor> ValidateDataType for Option> { ) } } + Some(InnerDataTypeValidator::FixedString(_len)) => { + Ok(None) // actually unreachable + } Some(InnerDataTypeValidator::Variant(_possible_types)) => { - todo!() // TODO - check type index in the parsed types vec + Ok(None) // FIXME: requires comparing DataTypeNode vs TypeHint... } Some(InnerDataTypeValidator::Enum(_values_map)) => { todo!() // TODO - check value correctness in the hashmap @@ -127,9 +137,34 @@ impl<'cursor> ValidateDataType for Option> { } } - #[inline] - fn set_struct_name(&mut self, _name: String) { - unreachable!("`set_struct_name` should never be called for inner validators") + #[inline(always)] + fn validate_fixed_string(&mut self, len: usize) { + if let Some(InnerDataTypeValidator::FixedString(expected_len)) = self { + if *expected_len != len { + panic!( + "FixedString byte length mismatch: expected {}, got {}", + expected_len, len + ); + } + } + } + + #[inline(always)] + fn validate_enum8(&mut self, value: i8) { + if let Some(InnerDataTypeValidator::Enum(values_map)) = self { + if !values_map.contains_key(&(value as i16)) { + panic!("Enum8 value `{value}` is not present in the database schema"); + } + } + } + + #[inline(always)] + fn validate_enum16(&mut self, enum_value: i16) { + if let Some(InnerDataTypeValidator::Enum(value)) = self { + if !value.contains_key(&enum_value) { + panic!("Enum16 value `{enum_value}` is not present in the database schema"); + } + } } } @@ -151,7 +186,6 @@ fn validate_impl<'cursor>( data_type: &'cursor DataTypeNode, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], - seq_len: usize, ) -> Result>> { // println!( // "Validating data type: {:?} against serde type: {} with compatible db types: {:?}", @@ -166,34 +200,34 @@ fn validate_impl<'cursor>( DataTypeNode::Int32 | DataTypeNode::Date32 | DataTypeNode::Decimal(_, _, DecimalSize::Int32) - if compatible_db_types.contains(&DataTypeHint::Int32) => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::Int32) => + { + Ok(None) + } DataTypeNode::Int64 | DataTypeNode::DateTime64(_, _) | DataTypeNode::Decimal(_, _, DecimalSize::Int64) - if compatible_db_types.contains(&DataTypeHint::Int64) => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::Int64) => + { + Ok(None) + } DataTypeNode::Int128 | DataTypeNode::Decimal(_, _, DecimalSize::Int128) - if compatible_db_types.contains(&DataTypeHint::Int128) => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::Int128) => + { + Ok(None) + } DataTypeNode::UInt8 if compatible_db_types.contains(&DataTypeHint::UInt8) => Ok(None), DataTypeNode::UInt16 | DataTypeNode::Date - if compatible_db_types.contains(&DataTypeHint::UInt16) => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::UInt16) => + { + Ok(None) + } DataTypeNode::UInt32 | DataTypeNode::DateTime(_) | DataTypeNode::IPv4 - if compatible_db_types.contains(&DataTypeHint::UInt32) => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::UInt32) => + { + Ok(None) + } DataTypeNode::UInt64 if compatible_db_types.contains(&DataTypeHint::UInt64) => Ok(None), DataTypeNode::UInt128 if compatible_db_types.contains(&DataTypeHint::UInt128) => Ok(None), @@ -202,16 +236,16 @@ fn validate_impl<'cursor>( // Currently, we allow new JSON type only with `output_format_binary_write_json_as_string` DataTypeNode::String | DataTypeNode::JSON - if compatible_db_types.contains(&DataTypeHint::String) => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::String) => + { + Ok(None) + } DataTypeNode::FixedString(n) - if compatible_db_types.contains(&DataTypeHint::FixedString) && *n == seq_len => - { - Ok(None) - } + if compatible_db_types.contains(&DataTypeHint::FixedString) => + { + Ok(Some(InnerDataTypeValidator::FixedString(*n))) + } // FIXME: IPv4 from ClickHouse ends up reversed. // Ideally, requires a ReversedSeqAccess implementation. Perhaps memoize IPv4 col index? @@ -239,40 +273,40 @@ fn validate_impl<'cursor>( } DataTypeNode::Map(key_type, value_type) - if compatible_db_types.contains(&DataTypeHint::Map) => - { - Ok(Some(InnerDataTypeValidator::Map( - key_type, - value_type, - MapValidatorState::Key, - ))) - } + if compatible_db_types.contains(&DataTypeHint::Map) => + { + Ok(Some(InnerDataTypeValidator::Map( + key_type, + value_type, + MapValidatorState::Key, + ))) + } DataTypeNode::Tuple(elements) if compatible_db_types.contains(&DataTypeHint::Tuple) => { Ok(Some(InnerDataTypeValidator::Tuple(elements))) } DataTypeNode::Nullable(inner_type) - if compatible_db_types.contains(&DataTypeHint::Nullable) => - { - Ok(Some(InnerDataTypeValidator::Nullable(inner_type))) - } + if compatible_db_types.contains(&DataTypeHint::Nullable) => + { + Ok(Some(InnerDataTypeValidator::Nullable(inner_type))) + } // LowCardinality is completely transparent on the client side DataTypeNode::LowCardinality(inner_type) => { - validate_impl(inner_type, serde_type, compatible_db_types, seq_len) + validate_impl(inner_type, serde_type, compatible_db_types) } DataTypeNode::Enum(EnumType::Enum8, values_map) - if compatible_db_types.contains(&DataTypeHint::Int8) => - { - Ok(Some(InnerDataTypeValidator::Enum(values_map))) - } + if compatible_db_types.contains(&DataTypeHint::Int8) => + { + Ok(Some(InnerDataTypeValidator::Enum(values_map))) + } DataTypeNode::Enum(EnumType::Enum16, values_map) - if compatible_db_types.contains(&DataTypeHint::Int16) => - { - Ok(Some(InnerDataTypeValidator::Enum(values_map))) - } + if compatible_db_types.contains(&DataTypeHint::Int16) => + { + Ok(Some(InnerDataTypeValidator::Enum(values_map))) + } DataTypeNode::Variant(possible_types) => { Ok(Some(InnerDataTypeValidator::Variant(possible_types))) @@ -297,21 +331,32 @@ impl<'cursor> ValidateDataType for DataTypeValidator<'cursor> { &mut self, serde_type: &'static SerdeType, compatible_db_types: &'static [DataTypeHint], - len: usize, ) -> Result>> { match self.columns.split_first() { Some((first, rest)) => { self.columns = rest; - validate_impl(&first.data_type, serde_type, compatible_db_types, len) + validate_impl(&first.data_type, serde_type, compatible_db_types) } None => panic!("Struct has more fields than columns in the database schema"), } } - // FIXME: remove copy of a String and use &str instead; but lifetimes are tricky here - #[inline] - fn set_struct_name(&mut self, name: String) { - self.struct_name = Some(name); + #[cold] + #[inline(never)] + fn validate_enum8(&mut self, _value: i8) { + unreachable!() + } + + #[cold] + #[inline(never)] + fn validate_enum16(&mut self, _value: i16) { + unreachable!() + } + + #[cold] + #[inline(never)] + fn validate_fixed_string(&mut self, _len: usize) { + unreachable!() } } diff --git a/src/validation_mode.rs b/src/validation_mode.rs index 0774b129..1755bd3f 100644 --- a/src/validation_mode.rs +++ b/src/validation_mode.rs @@ -8,8 +8,6 @@ /// Format: `RowBinaryWithNamesAndTypes`. /// - [`ValidationMode::Each`] enables validation _for all rows_ emitted by a cursor. /// This is the slowest mode. Format: `RowBinaryWithNamesAndTypes`. -/// - [`ValidationMode::Disabled`] means that no validation will be performed. -/// At the same time, this is the fastest mode. Format: `RowBinary`. /// /// # Default /// @@ -29,7 +27,6 @@ pub enum ValidationMode { First(usize), Each, - Disabled, } impl Default for ValidationMode { @@ -43,7 +40,6 @@ impl std::fmt::Display for ValidationMode { match self { Self::First(n) => f.pad(&format!("FirstN({})", n)), Self::Each => f.pad("Each"), - Self::Disabled => f.pad("Disabled"), } } } diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 531f2d28..bffb4801 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -356,7 +356,6 @@ async fn test_enum() { let expected = vec![ Data { id: 1, - enum8: MyEnum8::Spring, enum16: MyEnum16::East, }, @@ -739,6 +738,97 @@ async fn test_tuple_too_many_elements_in_the_struct() { .await; } +#[tokio::test] +async fn test_variant() { + #[derive(Debug, Deserialize, PartialEq)] + enum MyVariant { + Str(String), + U16(u16), + } + + #[derive(Debug, Row, Deserialize, PartialEq)] + struct Data { + id: u8, + var: MyVariant, + } + + let client = get_client() + .with_validation_mode(ValidationMode::Each) + .with_option("allow_experimental_variant_type", "1"); + let result = client + .query( + " + SELECT * FROM ( + SELECT 0 :: UInt8 AS id, 'foo' :: Variant(String, UInt16) AS var + UNION ALL + SELECT 1 :: UInt8 AS id, 144 :: Variant(String, UInt16) AS var + ) ORDER BY id ASC + ", + ) + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { + id: 0, + var: MyVariant::Str("foo".to_string()) + }, + Data { + id: 1, + var: MyVariant::U16(144) + }, + ] + ); +} + +#[tokio::test] +#[ignore] // this is currently disabled, see validation todo +async fn test_variant_wrong_definition() { + #[derive(Debug, Deserialize, PartialEq)] + enum MyVariant { + Str(String), + U32(u32), + } + + #[derive(Debug, Row, Deserialize, PartialEq)] + struct Data { + id: u8, + var: MyVariant, + } + + let client = get_client() + .with_validation_mode(ValidationMode::Each) + .with_option("allow_experimental_variant_type", "1"); + let result = client + .query( + " + SELECT * FROM ( + SELECT 0 :: UInt8 AS id, 'foo' :: Variant(String, UInt16) AS var + UNION ALL + SELECT 1 :: UInt8 AS id, 144 :: Variant(String, UInt16) AS var + ) ORDER BY id ASC + ", + ) + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { + id: 0, + var: MyVariant::Str("foo".to_string()) + }, + Data { + id: 1, + var: MyVariant::U32(144) + }, + ] + ); +} + // FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! // it is possible to use HashMap to deserialize the struct instead of Tuple visitor #[tokio::test] From 0d416cfa5ca336a2b2270ebcfa1902b73ef6e4fd Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Fri, 23 May 2025 21:49:37 +0200 Subject: [PATCH 09/54] RBWNAT deserializer - clearer error messages on panics --- src/error.rs | 10 +- src/rowbinary/de.rs | 189 +++------- src/rowbinary/ser.rs | 4 +- src/rowbinary/validation.rs | 663 +++++++++++++++++++++--------------- tests/it/main.rs | 18 + tests/it/rbwnat.rs | 149 ++++---- 6 files changed, 522 insertions(+), 511 deletions(-) diff --git a/src/error.rs b/src/error.rs index 7eacb759..d25545f9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -41,12 +41,8 @@ pub enum Error { BadResponse(String), #[error("timeout expired")] TimedOut, - #[error("unsupported: {0}")] - Unsupported(String), - #[error("error while deserializing data: {0}")] - DeserializationError(String), - #[error("error while parsing data from the response: {0}")] - ParserError(#[source] BoxedError), + #[error("error while parsing columns header from the response: {0}")] + ColumnsHeaderParserError(#[source] BoxedError), #[error("{0}")] Other(BoxedError), } @@ -55,7 +51,7 @@ assert_impl_all!(Error: StdError, Send, Sync); impl From for Error { fn from(err: clickhouse_rowbinary::error::ParserError) -> Self { - Self::ParserError(Box::new(err)) + Self::ColumnsHeaderParserError(Box::new(err)) } } diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 54eb4e7c..90d5eae0 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -3,7 +3,7 @@ use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use crate::rowbinary::validation::SerdeType; use crate::rowbinary::validation::{DataTypeValidator, ValidateDataType}; use bytes::Buf; -use clickhouse_rowbinary::data_types::{Column, DataTypeHint}; +use clickhouse_rowbinary::data_types::Column; use serde::de::MapAccess; use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, @@ -37,12 +37,12 @@ pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data /// A deserializer for the RowBinary(WithNamesAndTypes) format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. -pub(crate) struct RowBinaryDeserializer<'cursor, 'data, Validator = ()> +struct RowBinaryDeserializer<'cursor, 'data, Validator = ()> where Validator: ValidateDataType, { - pub(crate) validator: Validator, - pub(crate) input: &'cursor mut &'data [u8], + validator: Validator, + input: &'cursor mut &'data [u8], } impl<'cursor, 'data, Validator> RowBinaryDeserializer<'cursor, 'data, Validator> @@ -76,10 +76,10 @@ where } macro_rules! impl_num { - ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr, $type_hints:expr) => { + ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => { #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { - self.validator.validate($serde_type, $type_hints)?; + self.validator.validate($serde_type)?; ensure_size(&mut self.input, mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -95,9 +95,7 @@ where #[inline(always)] fn deserialize_i8>(self, visitor: V) -> Result { - let mut maybe_enum_validator = self - .validator - .validate(&SerdeType::I8, &[DataTypeHint::Int8, DataTypeHint::Bool])?; + let mut maybe_enum_validator = self.validator.validate(SerdeType::I8)?; ensure_size(&mut self.input, size_of::())?; let value = self.input.get_i8(); maybe_enum_validator.validate_enum8(value); @@ -106,96 +104,36 @@ where #[inline(always)] fn deserialize_i16>(self, visitor: V) -> Result { - let mut maybe_enum_validator = self - .validator - .validate(&SerdeType::I16, &[DataTypeHint::Int16])?; + let mut maybe_enum_validator = self.validator.validate(SerdeType::I16)?; ensure_size(&mut self.input, size_of::())?; let value = self.input.get_i16_le(); + // TODO: is there a better way to validate that the deserialized value matches the schema? maybe_enum_validator.validate_enum16(value); visitor.visit_i16(value) } - impl_num!( - i32, - deserialize_i32, - visit_i32, - get_i32_le, - &SerdeType::I32, - &[DataTypeHint::Int32] - ); - impl_num!( - i64, - deserialize_i64, - visit_i64, - get_i64_le, - &SerdeType::I64, - &[DataTypeHint::Int64] - ); + impl_num!(i32, deserialize_i32, visit_i32, get_i32_le, SerdeType::I32); + impl_num!(i64, deserialize_i64, visit_i64, get_i64_le, SerdeType::I64); impl_num!( i128, deserialize_i128, visit_i128, get_i128_le, - &SerdeType::I128, - &[DataTypeHint::Int128] - ); - impl_num!( - u8, - deserialize_u8, - visit_u8, - get_u8, - &SerdeType::U8, - // TODO: shall we allow deserialization from boolean? - &[DataTypeHint::Bool, DataTypeHint::UInt8] - ); - impl_num!( - u16, - deserialize_u16, - visit_u16, - get_u16_le, - &SerdeType::U16, - &[DataTypeHint::UInt16] - ); - impl_num!( - u32, - deserialize_u32, - visit_u32, - get_u32_le, - &SerdeType::U32, - &[DataTypeHint::UInt32] - ); - impl_num!( - u64, - deserialize_u64, - visit_u64, - get_u64_le, - &SerdeType::U64, - &[DataTypeHint::UInt64] + SerdeType::I128 ); + impl_num!(u8, deserialize_u8, visit_u8, get_u8, SerdeType::U8); + impl_num!(u16, deserialize_u16, visit_u16, get_u16_le, SerdeType::U16); + impl_num!(u32, deserialize_u32, visit_u32, get_u32_le, SerdeType::U32); + impl_num!(u64, deserialize_u64, visit_u64, get_u64_le, SerdeType::U64); impl_num!( u128, deserialize_u128, visit_u128, get_u128_le, - &SerdeType::U128, - &[DataTypeHint::UInt128] - ); - impl_num!( - f32, - deserialize_f32, - visit_f32, - get_f32_le, - &SerdeType::F32, - &[DataTypeHint::Float32] - ); - impl_num!( - f64, - deserialize_f64, - visit_f64, - get_f64_le, - &SerdeType::F64, - &[DataTypeHint::Float64] + SerdeType::U128 ); + impl_num!(f32, deserialize_f32, visit_f32, get_f32_le, SerdeType::F32); + impl_num!(f64, deserialize_f64, visit_f64, get_f64_le, SerdeType::F64); #[inline(always)] fn deserialize_any>(self, _: V) -> Result { @@ -205,24 +143,13 @@ where #[inline(always)] fn deserialize_unit>(self, visitor: V) -> Result { // TODO: revise this. + // TODO - skip validation? visitor.visit_unit() } - #[inline(always)] - fn deserialize_char>(self, _: V) -> Result { - panic!("character types are unsupported: `char`"); - } - #[inline(always)] fn deserialize_bool>(self, visitor: V) -> Result { - self.validator.validate( - &SerdeType::Bool, - &[ - DataTypeHint::Bool, - // it is possible to deserialize from UInt8 0 or 1 as Boolean - DataTypeHint::UInt8, - ], - )?; + self.validator.validate(SerdeType::Bool)?; ensure_size(&mut self.input, 1)?; match self.input.get_u8() { 0 => visitor.visit_bool(false), @@ -235,9 +162,7 @@ where fn deserialize_str>(self, visitor: V) -> Result { // println!("deserialize_str call"); - // TODO - which types to allow? - self.validator - .validate(&SerdeType::String, &[DataTypeHint::String])?; + self.validator.validate(SerdeType::Str)?; let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -248,9 +173,7 @@ where fn deserialize_string>(self, visitor: V) -> Result { // println!("deserialize_string call"); - // TODO - which types to allow? - self.validator - .validate(&SerdeType::String, &[DataTypeHint::String])?; + self.validator.validate(SerdeType::String)?; let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -261,8 +184,8 @@ where fn deserialize_bytes>(self, visitor: V) -> Result { // println!("deserialize_bytes call"); - // TODO - which types to allow? let size = self.read_size()?; + self.validator.validate(SerdeType::Bytes(size))?; let slice = self.read_slice(size)?; visitor.visit_borrowed_bytes(slice) } @@ -271,8 +194,8 @@ where fn deserialize_byte_buf>(self, visitor: V) -> Result { // println!("deserialize_byte_buf call"); - // TODO - which types to allow? let size = self.read_size()?; + self.validator.validate(SerdeType::ByteBuf(size))?; visitor.visit_byte_buf(self.read_vec(size)?) } @@ -280,7 +203,7 @@ where fn deserialize_identifier>(self, visitor: V) -> Result { // println!("deserialize_identifier call"); - // TODO - which types to allow? + self.validator.validate(SerdeType::Identifier)?; self.deserialize_u8(visitor) } @@ -314,7 +237,7 @@ where type Error = Error; fn unit_variant(self) -> Result<()> { - Err(Error::Unsupported("unit variants".to_string())) + panic!("unit variants are unsupported"); } fn newtype_variant_seed(self, seed: T) -> Result @@ -363,15 +286,11 @@ where } } - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Enum, &[DataTypeHint::Variant])?; - let mut new_self = RowBinaryDeserializer { - input: self.input, - validator: inner_data_type_validator, - }; visitor.visit_enum(RowBinaryEnumAccess { - deserializer: &mut new_self, + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(SerdeType::Enum)?, + }, }) } @@ -379,24 +298,11 @@ where fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { // println!("deserialize_tuple call, len {}", len); - let mut inner_data_type_validator = self.validator.validate( - &SerdeType::Tuple, - &[ - DataTypeHint::Tuple, - DataTypeHint::Array, - DataTypeHint::FixedString, - // FIXME: uncomment when there is a way to implement ReverseSeqAccess - // DataTypeHint::IPv4, - DataTypeHint::IPv6, - ], - )?; - inner_data_type_validator.validate_fixed_string(len); - let mut new_self = RowBinaryDeserializer { - input: self.input, - validator: inner_data_type_validator, - }; let access = RowBinarySeqAccess { - deserializer: &mut new_self, + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(SerdeType::Tuple(len))?, + }, len, }; visitor.visit_seq(access) @@ -407,13 +313,11 @@ where // println!("deserialize_option call"); ensure_size(&mut self.input, 1)?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Option, &[DataTypeHint::Nullable])?; + let inner_validator = self.validator.validate(SerdeType::Option)?; match self.input.get_u8() { 0 => visitor.visit_some(&mut RowBinaryDeserializer { input: self.input, - validator: inner_data_type_validator, + validator: inner_validator, }), 1 => visitor.visit_none(), v => Err(Error::InvalidTagEncoding(v as usize)), @@ -425,13 +329,10 @@ where // println!("deserialize_seq call"); let len = self.read_size()?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Seq, &[DataTypeHint::Array])?; visitor.visit_seq(RowBinarySeqAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, - validator: inner_data_type_validator, + validator: self.validator.validate(SerdeType::Seq(len))?, }, len, }) @@ -482,13 +383,10 @@ where } let len = self.read_size()?; - let inner_data_type_validator = self - .validator - .validate(&SerdeType::Map, &[DataTypeHint::Map])?; visitor.visit_map(RowBinaryMapAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, - validator: inner_data_type_validator, + validator: self.validator.validate(SerdeType::Map(len))?, }, entries_visited: 0, len, @@ -498,12 +396,14 @@ where #[inline(always)] fn deserialize_struct>( self, - _name: &str, + name: &'static str, fields: &'static [&'static str], visitor: V, ) -> Result { // println!("deserialize_struct: {} (fields: {:?})", name, fields,); + // TODO - skip validation? + self.validator.set_struct_name(name); visitor.visit_seq(RowBinarySeqAccess { deserializer: self, len: fields.len(), @@ -520,6 +420,11 @@ where visitor.visit_newtype_struct(self) } + #[inline(always)] + fn deserialize_char>(self, _: V) -> Result { + panic!("character types are unsupported: `char`"); + } + #[inline(always)] fn deserialize_unit_struct>( self, diff --git a/src/rowbinary/ser.rs b/src/rowbinary/ser.rs index 68fec881..47682f0a 100644 --- a/src/rowbinary/ser.rs +++ b/src/rowbinary/ser.rs @@ -148,9 +148,7 @@ impl Serializer for &'_ mut RowBinarySerializer { // Max number of types in the Variant data type is 255 // See also: https://github.com/ClickHouse/ClickHouse/issues/54864 if variant_index > 255 { - return Err(Error::VariantDiscriminatorIsOutOfBound( - variant_index as usize, - )); + panic!("max number of types in the Variant data type is 255, got {variant_index}") } self.buffer.put_u8(variant_index as u8); value.serialize(self) diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index cff6a285..619ab0bc 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,28 +1,122 @@ use crate::error::Result; -use clickhouse_rowbinary::data_types::{Column, DataTypeHint, DataTypeNode, DecimalSize, EnumType}; +use clickhouse_rowbinary::data_types::{Column, DataTypeNode, DecimalSize, EnumType}; use std::collections::HashMap; use std::fmt::Display; pub(crate) trait ValidateDataType: Sized { fn validate( - &mut self, - serde_type: &'static SerdeType, - compatible_db_types: &'static [DataTypeHint], - ) -> Result>>; + &'_ mut self, + serde_type: SerdeType, + ) -> Result>>; fn validate_enum8(&mut self, value: i8); fn validate_enum16(&mut self, value: i16); - fn validate_fixed_string(&mut self, len: usize); + fn set_struct_name(&mut self, name: &'static str); } #[derive(Default)] pub(crate) struct DataTypeValidator<'cursor> { + struct_name: Option<&'static str>, + current_column_idx: usize, columns: &'cursor [Column], } impl<'cursor> DataTypeValidator<'cursor> { #[inline(always)] pub(crate) fn new(columns: &'cursor [Column]) -> Self { - Self { columns } + Self { + struct_name: None, + current_column_idx: 0, + columns, + } + } + + fn get_current_column(&self) -> Option<&Column> { + if self.current_column_idx > 0 && self.current_column_idx <= self.columns.len() { + // index is immediately moved to the next column after the root validator is called + Some(&self.columns[self.current_column_idx - 1]) + } else { + None + } + } + + fn get_current_column_name_and_type(&self) -> (String, &DataTypeNode) { + self.get_current_column() + .map(|c| { + ( + format!("{}.{}", self.get_struct_name(), c.name), + &c.data_type, + ) + }) + // both should be defined at this point + .unwrap_or(("Struct".to_string(), &DataTypeNode::Bool)) + } + + fn get_struct_name(&self) -> String { + // should be available at the time of the panic call + self.struct_name.unwrap_or("Struct").to_string() + } + + #[inline(always)] + fn panic_on_schema_mismatch<'de>( + &'de self, + data_type: &DataTypeNode, + serde_type: &SerdeType, + is_inner: bool, + ) -> Result>> { + if is_inner { + let (full_name, full_data_type) = self.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: attempting to deserialize \ + nested ClickHouse type {} as {} which is not compatible", + full_name, full_data_type, data_type, serde_type + ) + } else { + panic!( + "While processing column {}: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + self.get_current_column_name_and_type().0, + data_type, + serde_type + ) + } + } +} + +impl ValidateDataType for DataTypeValidator<'_> { + #[inline] + fn validate( + &'_ mut self, + serde_type: SerdeType, + ) -> Result>> { + if self.current_column_idx < self.columns.len() { + let current_column = &self.columns[self.current_column_idx]; + self.current_column_idx += 1; + validate_impl(self, ¤t_column.data_type, &serde_type, false) + } else { + panic!( + "Struct {} has more fields than columns in the database schema", + self.get_struct_name() + ) + } + } + + #[inline(always)] + fn set_struct_name(&mut self, name: &'static str) { + if self.struct_name.is_none() { + self.struct_name = Some(name); + } + } + + #[cold] + #[inline(never)] + fn validate_enum8(&mut self, _value: i8) { + unreachable!() + } + + #[cold] + #[inline(never)] + fn validate_enum16(&mut self, _value: i16) { + unreachable!() } } @@ -39,8 +133,13 @@ pub(crate) enum ArrayValidatorState { Validated, } +pub(crate) struct InnerDataTypeValidator<'de, 'cursor> { + root: &'de DataTypeValidator<'cursor>, + kind: InnerDataTypeValidatorKind<'cursor>, +} + #[derive(Debug)] -pub(crate) enum InnerDataTypeValidator<'cursor> { +pub(crate) enum InnerDataTypeValidatorKind<'cursor> { Array(&'cursor DataTypeNode, ArrayValidatorState), FixedString(usize), Map( @@ -50,320 +149,326 @@ pub(crate) enum InnerDataTypeValidator<'cursor> { ), Tuple(&'cursor [DataTypeNode]), Enum(&'cursor HashMap), - Variant(&'cursor [DataTypeNode]), + // Variant(&'cursor [DataTypeNode]), Nullable(&'cursor DataTypeNode), } -impl ValidateDataType for () { - #[inline(always)] - fn validate( - &mut self, - _serde_type: &'static SerdeType, - _compatible_db_types: &'static [DataTypeHint], - // _len: usize, - ) -> Result>> { - Ok(None) - } - - #[inline(always)] - fn validate_enum8(&mut self, _enum_value: i8) {} - - #[inline(always)] - fn validate_enum16(&mut self, _enum_value: i16) {} - - #[inline(always)] - fn validate_fixed_string(&mut self, _len: usize) {} -} - -impl<'cursor> ValidateDataType for Option> { +impl<'de, 'cursor> ValidateDataType for Option> { #[inline] fn validate( &mut self, - serde_type: &'static SerdeType, - compatible_db_types: &'static [DataTypeHint], - // seq_len: usize, - ) -> Result>> { + serde_type: SerdeType, + ) -> Result>> { // println!("Validating inner data type: {:?} against serde type: {} with compatible db types: {:?}", // self, serde_type, compatible_db_types); match self { None => Ok(None), - Some(InnerDataTypeValidator::Map(key_type, value_type, state)) => match state { - MapValidatorState::Key => { - let result = validate_impl(key_type, serde_type, compatible_db_types); - *state = MapValidatorState::Value; - result + Some(inner) => match &mut inner.kind { + InnerDataTypeValidatorKind::Map(key_type, value_type, state) => match state { + MapValidatorState::Key => { + let result = validate_impl(inner.root, key_type, &serde_type, true); + *state = MapValidatorState::Value; + result + } + MapValidatorState::Value => { + let result = validate_impl(inner.root, value_type, &serde_type, true); + *state = MapValidatorState::Validated; + result + } + MapValidatorState::Validated => Ok(None), + }, + InnerDataTypeValidatorKind::Array(inner_type, state) => match state { + ArrayValidatorState::Pending => { + let result = validate_impl(inner.root, inner_type, &serde_type, true); + *state = ArrayValidatorState::Validated; + result + } + // TODO: perhaps we can allow to validate the inner type more than once + // avoiding e.g. issues with Array(Nullable(T)) when the first element in NULL + ArrayValidatorState::Validated => Ok(None), + }, + InnerDataTypeValidatorKind::Nullable(inner_type) => { + validate_impl(inner.root, inner_type, &serde_type, true) } - MapValidatorState::Value => { - let result = validate_impl(value_type, serde_type, compatible_db_types); - *state = MapValidatorState::Validated; - result + InnerDataTypeValidatorKind::Tuple(elements_types) => { + match elements_types.split_first() { + Some((first, rest)) => { + *elements_types = rest; + validate_impl(inner.root, first, &serde_type, true) + } + None => { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: \ + attempting to deserialize {} while no more elements are allowed", + full_name, full_data_type, serde_type + ) + } + } } - MapValidatorState::Validated => Ok(None), - }, - Some(InnerDataTypeValidator::Array(inner_type, state)) => match state { - ArrayValidatorState::Pending => { - let result = validate_impl(inner_type, serde_type, compatible_db_types); - *state = ArrayValidatorState::Validated; - result + InnerDataTypeValidatorKind::FixedString(_len) => { + Ok(None) // actually unreachable } - // TODO: perhaps we can allow to validate the inner type more than once - // avoiding e.g. issues with Array(Nullable(T)) when the first element in NULL - ArrayValidatorState::Validated => Ok(None), - }, - Some(InnerDataTypeValidator::Nullable(inner_type)) => { - validate_impl(inner_type, serde_type, compatible_db_types) - } - Some(InnerDataTypeValidator::Tuple(elements_types)) => { - match elements_types.split_first() { - Some((first, rest)) => { - *elements_types = rest; - validate_impl(first, serde_type, compatible_db_types) - } - None => panic!( - "Struct tries to deserialize {} as a tuple element, but there are no more allowed elements in the database schema", - serde_type, - ) + // InnerDataTypeValidatorKind::Variant(_possible_types) => { + // Ok(None) // FIXME: requires comparing DataTypeNode vs TypeHint or SerdeType + // } + InnerDataTypeValidatorKind::Enum(_values_map) => { + todo!() // TODO - check value correctness in the hashmap } - } - Some(InnerDataTypeValidator::FixedString(_len)) => { - Ok(None) // actually unreachable - } - Some(InnerDataTypeValidator::Variant(_possible_types)) => { - Ok(None) // FIXME: requires comparing DataTypeNode vs TypeHint... - } - Some(InnerDataTypeValidator::Enum(_values_map)) => { - todo!() // TODO - check value correctness in the hashmap - } + }, } } #[inline(always)] - fn validate_fixed_string(&mut self, len: usize) { - if let Some(InnerDataTypeValidator::FixedString(expected_len)) = self { - if *expected_len != len { - panic!( - "FixedString byte length mismatch: expected {}, got {}", - expected_len, len - ); + fn validate_enum8(&mut self, value: i8) { + if let Some(inner) = self { + if let InnerDataTypeValidatorKind::Enum(values_map) = &inner.kind { + if !values_map.contains_key(&(value as i16)) { + panic!("Enum8 value `{value}` is not present in the database schema"); + } } } } #[inline(always)] - fn validate_enum8(&mut self, value: i8) { - if let Some(InnerDataTypeValidator::Enum(values_map)) = self { - if !values_map.contains_key(&(value as i16)) { - panic!("Enum8 value `{value}` is not present in the database schema"); + fn validate_enum16(&mut self, value: i16) { + if let Some(inner) = self { + if let InnerDataTypeValidatorKind::Enum(values_map) = &inner.kind { + if !values_map.contains_key(&value) { + panic!("Enum16 value `{value}` is not present in the database schema"); + } } } } - #[inline(always)] - fn validate_enum16(&mut self, enum_value: i16) { - if let Some(InnerDataTypeValidator::Enum(value)) = self { - if !value.contains_key(&enum_value) { - panic!("Enum16 value `{enum_value}` is not present in the database schema"); - } - } + #[cold] + #[inline(never)] + fn set_struct_name(&mut self, _name: &'static str) { + panic!("Struct name should not be set in the inner deserializer"); } } -impl Drop for InnerDataTypeValidator<'_> { +impl Drop for InnerDataTypeValidator<'_, '_> { fn drop(&mut self) { - if let InnerDataTypeValidator::Tuple(elements_types) = self { + if let InnerDataTypeValidatorKind::Tuple(elements_types) = self.kind { if !elements_types.is_empty() { + let (column_name, column_type) = self.root.get_current_column_name_and_type(); panic!( - "Tuple was not fully deserialized, remaining elements: {:?}", + "While processing column {} defined as {}: tuple was not fully deserialized; \ + remaining elements: {}; likely, the field definition is incomplete", + column_name, + column_type, elements_types - ); + .iter() + .map(|c| c.to_string()) + .collect::>() + .join(", ") + ) } } } } #[inline] -fn validate_impl<'cursor>( +fn validate_impl<'de, 'cursor>( + root: &'de DataTypeValidator<'cursor>, data_type: &'cursor DataTypeNode, - serde_type: &'static SerdeType, - compatible_db_types: &'static [DataTypeHint], -) -> Result>> { + serde_type: &SerdeType, + is_inner: bool, +) -> Result>> { // println!( // "Validating data type: {:?} against serde type: {} with compatible db types: {:?}", // data_type, serde_type, compatible_db_types // ); - // FIXME: multiple branches with similar patterns - match data_type { - DataTypeNode::Bool if compatible_db_types.contains(&DataTypeHint::Bool) => Ok(None), - - DataTypeNode::Int8 if compatible_db_types.contains(&DataTypeHint::Int8) => Ok(None), - DataTypeNode::Int16 if compatible_db_types.contains(&DataTypeHint::Int16) => Ok(None), - DataTypeNode::Int32 - | DataTypeNode::Date32 - | DataTypeNode::Decimal(_, _, DecimalSize::Int32) - if compatible_db_types.contains(&DataTypeHint::Int32) => - { - Ok(None) - } - DataTypeNode::Int64 - | DataTypeNode::DateTime64(_, _) - | DataTypeNode::Decimal(_, _, DecimalSize::Int64) - if compatible_db_types.contains(&DataTypeHint::Int64) => - { - Ok(None) - } - DataTypeNode::Int128 | DataTypeNode::Decimal(_, _, DecimalSize::Int128) - if compatible_db_types.contains(&DataTypeHint::Int128) => - { - Ok(None) - } - - DataTypeNode::UInt8 if compatible_db_types.contains(&DataTypeHint::UInt8) => Ok(None), - DataTypeNode::UInt16 | DataTypeNode::Date - if compatible_db_types.contains(&DataTypeHint::UInt16) => - { - Ok(None) - } - DataTypeNode::UInt32 | DataTypeNode::DateTime(_) | DataTypeNode::IPv4 - if compatible_db_types.contains(&DataTypeHint::UInt32) => - { - Ok(None) - } - DataTypeNode::UInt64 if compatible_db_types.contains(&DataTypeHint::UInt64) => Ok(None), - DataTypeNode::UInt128 if compatible_db_types.contains(&DataTypeHint::UInt128) => Ok(None), - - DataTypeNode::Float32 if compatible_db_types.contains(&DataTypeHint::Float32) => Ok(None), - DataTypeNode::Float64 if compatible_db_types.contains(&DataTypeHint::Float64) => Ok(None), - - // Currently, we allow new JSON type only with `output_format_binary_write_json_as_string` - DataTypeNode::String | DataTypeNode::JSON - if compatible_db_types.contains(&DataTypeHint::String) => - { - Ok(None) - } - - DataTypeNode::FixedString(n) - if compatible_db_types.contains(&DataTypeHint::FixedString) => - { - Ok(Some(InnerDataTypeValidator::FixedString(*n))) - } - - // FIXME: IPv4 from ClickHouse ends up reversed. - // Ideally, requires a ReversedSeqAccess implementation. Perhaps memoize IPv4 col index? - // IPv4 = [u8; 4] - // DataTypeNode::IPv4 if compatible_db_types.contains(&DataTypeHint::IPv4) => Ok(Some( - // InnerDataTypeValidator::Array(&DataTypeNode::UInt8, ArrayValidatorState::Pending(4)), - // )), - - // IPv6 = [u8; 16] - DataTypeNode::IPv6 if compatible_db_types.contains(&DataTypeHint::Array) => Ok(Some( - InnerDataTypeValidator::Array(&DataTypeNode::UInt8, ArrayValidatorState::Pending), - )), - - // UUID = [u64; 2] - DataTypeNode::UUID => Ok(Some(InnerDataTypeValidator::Tuple(&[ - DataTypeNode::UInt64, - DataTypeNode::UInt64, - ]))), - - DataTypeNode::Array(inner_type) if compatible_db_types.contains(&DataTypeHint::Array) => { - Ok(Some(InnerDataTypeValidator::Array( - inner_type, - ArrayValidatorState::Pending, - ))) + // TODO: eliminate multiple branches with similar patterns? + match serde_type { + SerdeType::Bool + if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => + { + Ok(None) } - - DataTypeNode::Map(key_type, value_type) - if compatible_db_types.contains(&DataTypeHint::Map) => - { - Ok(Some(InnerDataTypeValidator::Map( - key_type, - value_type, - MapValidatorState::Key, - ))) + SerdeType::I8 => match data_type { + DataTypeNode::Int8 => Ok(None), + DataTypeNode::Enum(EnumType::Enum8, values_map) => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Enum(values_map), + })), + _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + }, + SerdeType::I16 => match data_type { + DataTypeNode::Int16 => Ok(None), + DataTypeNode::Enum(EnumType::Enum16, values_map) => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Enum(values_map), + })), + _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + }, + SerdeType::I32 + if data_type == &DataTypeNode::Int32 + || data_type == &DataTypeNode::Date32 + || matches!(data_type, DataTypeNode::Decimal(_, _, DecimalSize::Int32)) => + { + Ok(None) + } + SerdeType::I64 + if data_type == &DataTypeNode::Int64 + || matches!(data_type, DataTypeNode::DateTime64(_, _)) + || matches!(data_type, DataTypeNode::Decimal(_, _, DecimalSize::Int64)) => + { + Ok(None) + } + SerdeType::I128 + if data_type == &DataTypeNode::Int128 + || matches!(data_type, DataTypeNode::Decimal(_, _, DecimalSize::Int128)) => + { + Ok(None) + } + // TODO: what should be allowed type for SerdeType::Identifier? + SerdeType::Identifier | SerdeType::U8 if data_type == &DataTypeNode::UInt8 => Ok(None), + SerdeType::U16 + if data_type == &DataTypeNode::UInt16 || data_type == &DataTypeNode::Date => + { + Ok(None) + } + SerdeType::U32 + if data_type == &DataTypeNode::UInt32 + || matches!(data_type, DataTypeNode::DateTime(_)) + || data_type == &DataTypeNode::IPv4 => + { + Ok(None) + } + SerdeType::U64 if data_type == &DataTypeNode::UInt64 => Ok(None), + SerdeType::U128 if data_type == &DataTypeNode::UInt128 => Ok(None), + SerdeType::F32 if data_type == &DataTypeNode::Float32 => Ok(None), + SerdeType::F64 if data_type == &DataTypeNode::Float64 => Ok(None), + SerdeType::Str | SerdeType::String + if data_type == &DataTypeNode::String || data_type == &DataTypeNode::JSON => + { + Ok(None) + } + // TODO: find use cases where this is called instead of `deserialize_tuple` + // SerdeType::Bytes | SerdeType::ByteBuf => { + // if let DataTypeNode::FixedString(n) = data_type { + // Ok(Some(InnerDataTypeValidator::FixedString(*n))) + // } else { + // panic!( + // "Expected FixedString(N) for {} call, but got {}", + // serde_type, data_type + // ) + // } + // } + SerdeType::Option => { + if let DataTypeNode::Nullable(inner_type) = data_type { + Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Nullable(inner_type), + })) + } else { + root.panic_on_schema_mismatch(data_type, &serde_type, is_inner) } - - DataTypeNode::Tuple(elements) if compatible_db_types.contains(&DataTypeHint::Tuple) => { - Ok(Some(InnerDataTypeValidator::Tuple(elements))) } - - DataTypeNode::Nullable(inner_type) - if compatible_db_types.contains(&DataTypeHint::Nullable) => - { - Ok(Some(InnerDataTypeValidator::Nullable(inner_type))) + SerdeType::Seq(_) => { + if let DataTypeNode::Array(inner_type) = data_type { + Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + inner_type, + ArrayValidatorState::Pending, + ), + })) + } else { + root.panic_on_schema_mismatch(data_type, &serde_type, is_inner) } - - // LowCardinality is completely transparent on the client side - DataTypeNode::LowCardinality(inner_type) => { - validate_impl(inner_type, serde_type, compatible_db_types) } - - DataTypeNode::Enum(EnumType::Enum8, values_map) - if compatible_db_types.contains(&DataTypeHint::Int8) => - { - Ok(Some(InnerDataTypeValidator::Enum(values_map))) + SerdeType::Tuple(len) => match data_type { + DataTypeNode::FixedString(n) => { + if n == len { + Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::FixedString(*n), + })) + } else { + let (full_name, full_data_type) = root.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: attempting to deserialize \ + nested ClickHouse type {} as {}", + full_name, full_data_type, data_type, serde_type, + ) + } } - DataTypeNode::Enum(EnumType::Enum16, values_map) - if compatible_db_types.contains(&DataTypeHint::Int16) => - { - Ok(Some(InnerDataTypeValidator::Enum(values_map))) + DataTypeNode::Tuple(elements) => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Tuple(elements), + })), + DataTypeNode::Array(inner_type) => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(inner_type, ArrayValidatorState::Pending), + })), + DataTypeNode::IPv6 => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + &DataTypeNode::UInt8, + ArrayValidatorState::Pending, + ), + })), + DataTypeNode::UUID => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Tuple(&[ + DataTypeNode::UInt64, + DataTypeNode::UInt64, + ]), + })), + _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + }, + SerdeType::Map(_) => { + if let DataTypeNode::Map(key_type, value_type) = data_type { + Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Map( + key_type, + value_type, + MapValidatorState::Key, + ), + })) + } else { + panic!( + "Expected Map for {} call, but got {}", + serde_type, data_type + ) } - - DataTypeNode::Variant(possible_types) => { - Ok(Some(InnerDataTypeValidator::Variant(possible_types))) + } + SerdeType::Enum => { + todo!("variant data type validation") } - DataTypeNode::AggregateFunction(_, _) => panic!("AggregateFunction is not supported yet"), - DataTypeNode::Int256 => panic!("Int256 is not supported yet"), - DataTypeNode::UInt256 => panic!("UInt256 is not supported yet"), - DataTypeNode::BFloat16 => panic!("BFloat16 is not supported yet"), - DataTypeNode::Dynamic => panic!("Dynamic is not supported yet"), - - _ => panic!( - "Database type is {}, but struct field is deserialized as {}, which is compatible only with {:?}", - data_type, serde_type, compatible_db_types - ), + _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), } } -impl<'cursor> ValidateDataType for DataTypeValidator<'cursor> { - #[inline] +impl ValidateDataType for () { + #[inline(always)] fn validate( &mut self, - serde_type: &'static SerdeType, - compatible_db_types: &'static [DataTypeHint], - ) -> Result>> { - match self.columns.split_first() { - Some((first, rest)) => { - self.columns = rest; - validate_impl(&first.data_type, serde_type, compatible_db_types) - } - None => panic!("Struct has more fields than columns in the database schema"), - } + _serde_type: SerdeType, + ) -> Result>> { + Ok(None) } - #[cold] - #[inline(never)] - fn validate_enum8(&mut self, _value: i8) { - unreachable!() - } + #[inline(always)] + fn validate_enum8(&mut self, _enum_value: i8) {} - #[cold] - #[inline(never)] - fn validate_enum16(&mut self, _value: i16) { - unreachable!() - } + #[inline(always)] + fn validate_enum16(&mut self, _enum_value: i16) {} - #[cold] - #[inline(never)] - fn validate_fixed_string(&mut self, _len: usize) { - unreachable!() - } + #[inline(always)] + fn set_struct_name(&mut self, _name: &'static str) {} } /// Which Serde data type (De)serializer used for the given type. -/// Displays into Rust types for convenience in errors reporting. +/// Displays into certain Rust types for convenience in errors reporting. +/// See also: available methods in [`serde::Serializer`] and [`serde::Deserializer`]. #[derive(Clone, Debug, PartialEq)] -#[allow(dead_code)] pub(crate) enum SerdeType { Bool, I8, @@ -378,23 +483,23 @@ pub(crate) enum SerdeType { U128, F32, F64, - Char, Str, String, - Bytes, - ByteBuf, Option, - Unit, - UnitStruct, - NewtypeStruct, - Seq, - Tuple, - TupleStruct, - Map, - Struct, Enum, Identifier, - IgnoredAny, + Bytes(usize), + ByteBuf(usize), + Tuple(usize), + Seq(usize), + Map(usize), + // Char, + // Unit, + // Struct, + // NewtypeStruct, + // TupleStruct, + // UnitStruct, + // IgnoredAny, } impl Display for SerdeType { @@ -413,23 +518,23 @@ impl Display for SerdeType { SerdeType::U128 => "u128", SerdeType::F32 => "f32", SerdeType::F64 => "f64", - SerdeType::Char => "char", SerdeType::Str => "&str", SerdeType::String => "String", - SerdeType::Bytes => "&[u8]", - SerdeType::ByteBuf => "Vec", + SerdeType::Bytes(len) => &format!("&[u8; {len}]"), + SerdeType::ByteBuf(_len) => "Vec", SerdeType::Option => "Option", - SerdeType::Unit => "()", - SerdeType::UnitStruct => "unit struct", - SerdeType::NewtypeStruct => "newtype struct", - SerdeType::Seq => "Vec", - SerdeType::Tuple => "tuple", - SerdeType::TupleStruct => "tuple struct", - SerdeType::Map => "map", - SerdeType::Struct => "struct", SerdeType::Enum => "enum", + SerdeType::Seq(_len) => "Vec", + SerdeType::Tuple(len) => &format!("a tuple or sequence with length {len}"), + SerdeType::Map(_len) => "map", SerdeType::Identifier => "identifier", - SerdeType::IgnoredAny => "ignored any", + // SerdeType::Char => "char", + // SerdeType::Unit => "()", + // SerdeType::Struct => "struct", + // SerdeType::NewtypeStruct => "newtype struct", + // SerdeType::TupleStruct => "tuple struct", + // SerdeType::UnitStruct => "unit struct", + // SerdeType::IgnoredAny => "ignored any", }; write!(f, "{}", type_name) } diff --git a/tests/it/main.rs b/tests/it/main.rs index 801219d1..90571b9c 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -23,6 +23,24 @@ use clickhouse::{sql::Identifier, Client, Row}; use serde::{Deserialize, Serialize}; +macro_rules! assert_panic_on_fetch { + ($msg_parts:expr, $query:expr) => { + use futures::FutureExt; + let client = get_client().with_validation_mode(ValidationMode::Each); + let async_panic = + std::panic::AssertUnwindSafe(async { client.query($query).fetch_one::().await }); + let result = async_panic.catch_unwind().await; + assert!(result.is_err()); + let panic_msg = *result.unwrap_err().downcast::().unwrap(); + for &msg in $msg_parts { + assert!( + panic_msg.contains(msg), + "panic message:\n{panic_msg}\ndid not contain the expected part:\n{msg}" + ); + } + }; +} + macro_rules! prepare_database { () => { crate::_priv::prepare_database({ diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index bffb4801..04119c49 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -393,22 +393,18 @@ async fn test_enum() { } #[tokio::test] -#[should_panic] async fn test_nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { n: Option, } - - let client = get_client().with_validation_mode(ValidationMode::Each); - let _ = client - .query("SELECT true AS b, 144 :: Int32 AS n2") - .fetch_one::() - .await; + assert_panic_on_fetch!( + &["Data.b", "Bool", "Option"], + "SELECT true AS b, 144 :: Int32 AS n2" + ); } #[tokio::test] -#[should_panic] #[cfg(feature = "time")] async fn test_serde_with() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -416,24 +412,13 @@ async fn test_serde_with() { #[serde(with = "clickhouse::serde::time::datetime64::millis")] n1: time::OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 } - - let client = get_client().with_validation_mode(ValidationMode::Each); - let _ = client - .query("SELECT 42 :: UInt32 AS n1, 144 :: Int32 AS n2") - .fetch_one::() - .await; - - // FIXME: lack of derive PartialEq for Error prevents proper assertion - // assert_eq!(result, Error::DataTypeMismatch { - // column_name: "n1".to_string(), - // expected_type: "Int64".to_string(), - // actual_type: "Int32".to_string(), - // columns: vec![...], - // }); + assert_panic_on_fetch!( + &["Data.n1", "UInt32", "i64"], + "SELECT 42 :: UInt32 AS n1, 144 :: Int32 AS n2" + ); } #[tokio::test] -#[should_panic] async fn test_too_many_struct_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { @@ -441,12 +426,10 @@ async fn test_too_many_struct_fields() { b: u32, c: u32, } - - let client = get_client().with_validation_mode(ValidationMode::Each); - let _ = client - .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b") - .fetch_one::() - .await; + assert_panic_on_fetch!( + &["Struct Data has more fields than columns in the database schema"], + "SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b" + ); } #[tokio::test] @@ -617,24 +600,21 @@ async fn test_fixed_str() { .await; let data = result.unwrap(); - assert_eq!(String::from_utf8_lossy(&data.a), "1234",); - assert_eq!(String::from_utf8_lossy(&data.b), "777",); + assert_eq!(String::from_utf8_lossy(&data.a), "1234"); + assert_eq!(String::from_utf8_lossy(&data.b), "777"); } #[tokio::test] -#[should_panic] async fn test_fixed_str_too_long() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: [u8; 4], b: [u8; 3], } - - let client = get_client().with_validation_mode(ValidationMode::Each); - let _ = client - .query("SELECT '12345' :: FixedString(5) AS a, '777' :: FixedString(3) AS b") - .fetch_one::() - .await; + assert_panic_on_fetch!( + &["Data.a", "FixedString(5)", "with length 4"], + "SELECT '12345' :: FixedString(5) AS a, '777' :: FixedString(3) AS b" + ); } #[tokio::test] @@ -667,78 +647,87 @@ async fn test_tuple() { } #[tokio::test] -#[should_panic] async fn test_tuple_invalid_definition() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String), b: (i128, HashMap), } - - let client = get_client().with_validation_mode(ValidationMode::Each); - // Map key is UInt64 instead of UInt16 requested in the struct - let _ = client - .query( - " - SELECT - (42, 'foo') :: Tuple(UInt32, String) AS a, - (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt64, String)) AS b - ", - ) - .fetch_one::() - .await; + assert_panic_on_fetch!( + &[ + "Data.b", + "Tuple(Int128, Map(UInt64, String))", + "UInt64 as u16" + ], + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt64, String)) AS b + " + ); } #[tokio::test] -#[should_panic] async fn test_tuple_too_many_elements_in_the_schema() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String), b: (i128, HashMap), } - - let client = get_client().with_validation_mode(ValidationMode::Each); - // too many elements in the db type definition - let _ = client - .query( - " - SELECT - (42, 'foo', true) :: Tuple(UInt32, String, Bool) AS a, - (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b - ", - ) - .fetch_one::() - .await; + assert_panic_on_fetch!( + &[ + "Data.a", + "Tuple(UInt32, String, Bool)", + "remaining elements: Bool" + ], + " + SELECT + (42, 'foo', true) :: Tuple(UInt32, String, Bool) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + " + ); } #[tokio::test] -#[should_panic] async fn test_tuple_too_many_elements_in_the_struct() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String, bool), b: (i128, HashMap), } - - let client = get_client().with_validation_mode(ValidationMode::Each); - // too many elements in the struct enum - let _ = client - .query( - " - SELECT - (42, 'foo') :: Tuple(UInt32, String) AS a, - (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b - ", - ) - .fetch_one::() - .await; + assert_panic_on_fetch!( + &["Data.a", "Tuple(UInt32, String)", "deserialize bool"], + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + " + ); } #[tokio::test] +async fn test_deeply_nested_validation_incorrect_fixed_string() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u32, + col: Vec>>>, + } + // Struct has FixedString(2) instead of FixedString(1) + assert_panic_on_fetch!( + &["Data.col", "FixedString(1)", "with length 2"], + " + SELECT + 42 :: UInt32 AS id, + array(array(map(42, array('1', '2')))) :: Array(Array(Map(UInt32, Array(FixedString(1))))) AS col + " + ); +} + +#[tokio::test] +#[ignore] async fn test_variant() { #[derive(Debug, Deserialize, PartialEq)] enum MyVariant { From 65cb92fec3381d18b4361c93dd12c6c4e7edcf40 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Fri, 23 May 2025 21:58:27 +0200 Subject: [PATCH 10/54] Fix clippy and build --- src/rowbinary/de.rs | 4 ++-- src/rowbinary/validation.rs | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 90d5eae0..28afb2bc 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -9,7 +9,7 @@ use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, Deserialize, }; -use std::{convert::TryFrom, mem, str}; +use std::{convert::TryFrom, str}; /// Deserializes a value from `input` with a row encoded in `RowBinary`. /// @@ -80,7 +80,7 @@ macro_rules! impl_num { #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { self.validator.validate($serde_type)?; - ensure_size(&mut self.input, mem::size_of::<$ty>())?; + ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) } diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 619ab0bc..1874cbfa 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -294,7 +294,7 @@ fn validate_impl<'de, 'cursor>( root, kind: InnerDataTypeValidatorKind::Enum(values_map), })), - _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::I16 => match data_type { DataTypeNode::Int16 => Ok(None), @@ -302,7 +302,7 @@ fn validate_impl<'de, 'cursor>( root, kind: InnerDataTypeValidatorKind::Enum(values_map), })), - _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::I32 if data_type == &DataTypeNode::Int32 @@ -365,7 +365,7 @@ fn validate_impl<'de, 'cursor>( kind: InnerDataTypeValidatorKind::Nullable(inner_type), })) } else { - root.panic_on_schema_mismatch(data_type, &serde_type, is_inner) + root.panic_on_schema_mismatch(data_type, serde_type, is_inner) } } SerdeType::Seq(_) => { @@ -378,7 +378,7 @@ fn validate_impl<'de, 'cursor>( ), })) } else { - root.panic_on_schema_mismatch(data_type, &serde_type, is_inner) + root.panic_on_schema_mismatch(data_type, serde_type, is_inner) } } SerdeType::Tuple(len) => match data_type { @@ -419,7 +419,7 @@ fn validate_impl<'de, 'cursor>( DataTypeNode::UInt64, ]), })), - _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::Map(_) => { if let DataTypeNode::Map(key_type, value_type) = data_type { @@ -442,7 +442,7 @@ fn validate_impl<'de, 'cursor>( todo!("variant data type validation") } - _ => root.panic_on_schema_mismatch(data_type, &serde_type, is_inner), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), } } From fbfbd992e2a4d71cc0bbb81161d36e77ec107e7c Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Fri, 23 May 2025 22:02:11 +0200 Subject: [PATCH 11/54] Fix core::mem::size_of import --- src/rowbinary/de.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 28afb2bc..78c71bcc 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -4,6 +4,7 @@ use crate::rowbinary::validation::SerdeType; use crate::rowbinary::validation::{DataTypeValidator, ValidateDataType}; use bytes::Buf; use clickhouse_rowbinary::data_types::Column; +use core::mem::size_of; use serde::de::MapAccess; use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, From 1d5c01a5acbc057b98778e16a58ccebb733150a5 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 26 May 2025 22:46:06 +0200 Subject: [PATCH 12/54] Slightly faster implementation --- Cargo.toml | 3 +- src/cursors/row.rs | 100 +++++++++++++------------ src/error.rs | 4 +- src/rowbinary/de.rs | 24 ++++-- src/rowbinary/tests.rs | 30 ++++---- src/rowbinary/validation.rs | 2 +- tests/it/rbwnat.rs | 4 +- {rowbinary => types}/Cargo.toml | 4 +- {rowbinary => types}/src/data_types.rs | 0 {rowbinary => types}/src/decoders.rs | 0 {rowbinary => types}/src/error.rs | 0 {rowbinary => types}/src/leb128.rs | 0 {rowbinary => types}/src/lib.rs | 0 13 files changed, 91 insertions(+), 80 deletions(-) rename {rowbinary => types}/Cargo.toml (71%) rename {rowbinary => types}/src/data_types.rs (100%) rename {rowbinary => types}/src/decoders.rs (100%) rename {rowbinary => types}/src/error.rs (100%) rename {rowbinary => types}/src/leb128.rs (100%) rename {rowbinary => types}/src/lib.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 90d0f39d..5256722f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -98,7 +98,7 @@ rustls-tls-native-roots = [ [dependencies] clickhouse-derive = { version = "0.2.0", path = "derive" } -clickhouse-rowbinary = { version = "*", path = "rowbinary" } +clickhouse-types = { version = "*", path = "types" } thiserror = "1.0.16" serde = "1.0.106" bytes = "1.5.0" @@ -131,7 +131,6 @@ quanta = { version = "0.12", optional = true } replace_with = { version = "0.1.7" } [dev-dependencies] -clickhouse-rowbinary = { version = "*", path = "./rowbinary" } criterion = "0.5.0" serde = { version = "1.0.106", features = ["derive"] } tokio = { version = "1.0.1", features = ["full", "test-util"] } diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 8398a898..2d7f25de 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -6,8 +6,8 @@ use crate::{ response::Response, rowbinary, }; -use clickhouse_rowbinary::data_types::Column; -use clickhouse_rowbinary::parse_rbwnat_columns_header; +use clickhouse_types::data_types::Column; +use clickhouse_types::parse_rbwnat_columns_header; use serde::Deserialize; use std::marker::PhantomData; @@ -16,8 +16,8 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, - rows_to_check: u64, - columns: Option>, + columns: Vec, + rows_to_validate: u64, _marker: PhantomData, } @@ -27,14 +27,37 @@ impl RowCursor { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), - rows_to_check: match validation_mode { + columns: Vec::new(), + rows_to_validate: match validation_mode { ValidationMode::First(n) => n as u64, ValidationMode::Each => u64::MAX, }, - columns: None, } } + #[cold] + #[inline(never)] + fn read_columns(&mut self, mut slice: &[u8]) -> Result<()> { + let columns = parse_rbwnat_columns_header(&mut slice)?; + debug_assert!(!columns.is_empty()); + self.bytes.set_remaining(slice.len()); + self.columns = columns; + Ok(()) + } + + #[inline(always)] + fn deserialize_with_validation<'cursor, 'data: 'cursor>( + &'cursor mut self, + slice: &mut &'data [u8], + ) -> (Result, bool) + where + T: Deserialize<'data>, + { + let result = rowbinary::deserialize_from_and_validate::(slice, &self.columns); + self.rows_to_validate -= 1; + result + } + /// Emits the next row. /// /// The result is unspecified if it's called after `Err` is returned. @@ -42,34 +65,36 @@ impl RowCursor { /// # Cancel safety /// /// This method is cancellation safe. - pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result> + pub async fn next<'cursor, 'data: 'cursor>(&'cursor mut self) -> Result> where - T: Deserialize<'b>, + T: Deserialize<'data>, { loop { if self.bytes.remaining() > 0 { let mut slice = super::workaround_51132(self.bytes.slice()); - let deserialize_result = match &self.columns { - None => self.extract_columns_and_deserialize_from(slice), - Some(columns) if self.rows_to_check > 0 => { - rowbinary::deserialize_from_and_validate(&mut slice, columns) - } - Some(_) => { - // Schema is validated already, skipping for better performance - rowbinary::deserialize_from(&mut slice) - } - }; - - match deserialize_result { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - if self.rows_to_check > 0 { - self.rows_to_check -= 1; + if self.columns.is_empty() { + self.read_columns(slice)?; + } else { + debug_assert!(!self.columns.is_empty()); + let (result, not_enough_data) = match self.rows_to_validate { + 0 => rowbinary::deserialize_from_and_validate::(&mut slice, &[]), + u64::MAX => { + rowbinary::deserialize_from_and_validate::(&mut slice, &self.columns) + } + _ => { + // extracting to a separate method boosts performance for Each ~10% + self.deserialize_with_validation(&mut slice) } - return Ok(Some(value)); + }; + if !not_enough_data { + return match result { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + Ok(Some(value)) + } + Err(err) => Err(err), + }; } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), } } @@ -100,25 +125,4 @@ impl RowCursor { pub fn decoded_bytes(&self) -> u64 { self.raw.decoded_bytes() } - - #[cold] - #[inline(never)] - fn extract_columns_and_deserialize_from<'a, 'b: 'a>( - &'a mut self, - mut slice: &'b [u8], - ) -> Result - where - T: Deserialize<'b>, - { - let columns = parse_rbwnat_columns_header(&mut slice)?; - self.bytes.set_remaining(slice.len()); - self.columns = Some(columns); - let columns = self.columns.as_ref().unwrap(); - // usually, the header arrives as a separate first chunk - if self.bytes.remaining() > 0 { - rowbinary::deserialize_from_and_validate(&mut slice, columns) - } else { - Err(Error::NotEnoughData) - } - } } diff --git a/src/error.rs b/src/error.rs index d25545f9..142b7d27 100644 --- a/src/error.rs +++ b/src/error.rs @@ -49,8 +49,8 @@ pub enum Error { assert_impl_all!(Error: StdError, Send, Sync); -impl From for Error { - fn from(err: clickhouse_rowbinary::error::ParserError) -> Self { +impl From for Error { + fn from(err: clickhouse_types::error::ParserError) -> Self { Self::ColumnsHeaderParserError(Box::new(err)) } } diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 78c71bcc..edb5400d 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -3,7 +3,7 @@ use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use crate::rowbinary::validation::SerdeType; use crate::rowbinary::validation::{DataTypeValidator, ValidateDataType}; use bytes::Buf; -use clickhouse_rowbinary::data_types::Column; +use clickhouse_types::data_types::Column; use core::mem::size_of; use serde::de::MapAccess; use serde::{ @@ -30,9 +30,21 @@ pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], columns: &'cursor [Column], -) -> Result { - let mut deserializer = RowBinaryDeserializer::new(input, DataTypeValidator::new(columns)); - T::deserialize(&mut deserializer) +) -> (Result, bool) { + let result = if columns.is_empty() { + let mut deserializer = RowBinaryDeserializer::new(input, ()); + T::deserialize(&mut deserializer) + } else { + let validator = DataTypeValidator::new(columns); + let mut deserializer = RowBinaryDeserializer::new(input, validator); + T::deserialize(&mut deserializer) + }; + // an explicit hint about NotEnoughData error boosts RowCursor performance ~20% + match result { + Ok(value) => (Ok(value), false), + Err(Error::NotEnoughData) => (Err(Error::NotEnoughData), true), + Err(e) => (Err(e), false), + } } /// A deserializer for the RowBinary(WithNamesAndTypes) format. @@ -50,17 +62,14 @@ impl<'cursor, 'data, Validator> RowBinaryDeserializer<'cursor, 'data, Validator> where Validator: ValidateDataType, { - #[inline] fn new(input: &'cursor mut &'data [u8], validator: Validator) -> Self { Self { input, validator } } - #[inline] fn read_vec(&mut self, size: usize) -> Result> { Ok(self.read_slice(size)?.to_vec()) } - #[inline] fn read_slice(&mut self, size: usize) -> Result<&'data [u8]> { ensure_size(&mut self.input, size)?; let slice = &self.input[..size]; @@ -68,7 +77,6 @@ where Ok(slice) } - #[inline] fn read_size(&mut self) -> Result { let size = get_unsigned_leb128(&mut self.input)?; // TODO: what about another error? diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index 2865cbef..f4955333 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -114,18 +114,18 @@ fn it_serializes() { assert_eq!(actual, sample_serialized()); } -#[test] -fn it_deserializes() { - let input = sample_serialized(); - - for i in 0..input.len() { - let (mut left, mut right) = input.split_at(i); - - // It shouldn't panic. - let _: Result, _> = super::deserialize_from(&mut left); - let _: Result, _> = super::deserialize_from(&mut right); - - let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice()).unwrap(); - assert_eq!(actual, sample()); - } -} +// #[test] +// fn it_deserializes() { +// let input = sample_serialized(); +// +// for i in 0..input.len() { +// let (mut left, mut right) = input.split_at(i); +// +// // It shouldn't panic. +// let _: Result, _> = super::deserialize_from(&mut left); +// let _: Result, _> = super::deserialize_from(&mut right); +// +// let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice()).unwrap(); +// assert_eq!(actual, sample()); +// } +// } diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 1874cbfa..8b27a5a6 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use clickhouse_rowbinary::data_types::{Column, DataTypeNode, DecimalSize, EnumType}; +use clickhouse_types::data_types::{Column, DataTypeNode, DecimalSize, EnumType}; use std::collections::HashMap; use std::fmt::Display; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 04119c49..1a1cc758 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -2,8 +2,8 @@ use crate::get_client; use clickhouse::sql::Identifier; use clickhouse::validation_mode::ValidationMode; use clickhouse_derive::Row; -use clickhouse_rowbinary::data_types::{Column, DataTypeNode}; -use clickhouse_rowbinary::parse_rbwnat_columns_header; +use clickhouse_types::data_types::{Column, DataTypeNode}; +use clickhouse_types::parse_rbwnat_columns_header; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::collections::HashMap; diff --git a/rowbinary/Cargo.toml b/types/Cargo.toml similarity index 71% rename from rowbinary/Cargo.toml rename to types/Cargo.toml index b1dd76c1..0f0ac2bd 100644 --- a/rowbinary/Cargo.toml +++ b/types/Cargo.toml @@ -1,7 +1,7 @@ [package] -name = "clickhouse-rowbinary" +name = "clickhouse-types" version = "0.0.1" -description = "Native and RowBinary(WithNamesAndTypes) format utils" +description = "Data types utils to use with Native and RowBinary(WithNamesAndTypes) formats in ClickHouse" authors = ["ClickHouse"] repository = "https://github.com/ClickHouse/clickhouse-rs" homepage = "https://clickhouse.com" diff --git a/rowbinary/src/data_types.rs b/types/src/data_types.rs similarity index 100% rename from rowbinary/src/data_types.rs rename to types/src/data_types.rs diff --git a/rowbinary/src/decoders.rs b/types/src/decoders.rs similarity index 100% rename from rowbinary/src/decoders.rs rename to types/src/decoders.rs diff --git a/rowbinary/src/error.rs b/types/src/error.rs similarity index 100% rename from rowbinary/src/error.rs rename to types/src/error.rs diff --git a/rowbinary/src/leb128.rs b/types/src/leb128.rs similarity index 100% rename from rowbinary/src/leb128.rs rename to types/src/leb128.rs diff --git a/rowbinary/src/lib.rs b/types/src/lib.rs similarity index 100% rename from rowbinary/src/lib.rs rename to types/src/lib.rs From 227617e3cd441124b6aa9fb73c84a5243c0fe0be Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 27 May 2025 22:39:28 +0200 Subject: [PATCH 13/54] Add Geo types, more tests --- src/cursors/row.rs | 1 - src/rowbinary/validation.rs | 87 +++-- tests/it/main.rs | 17 + tests/it/rbwnat.rs | 220 ++++++++++++- types/src/data_types.rs | 610 ++++++++++++++++-------------------- types/src/leb128.rs | 37 ++- 6 files changed, 592 insertions(+), 380 deletions(-) diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 2d7f25de..5b0c5e6a 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -75,7 +75,6 @@ impl RowCursor { if self.columns.is_empty() { self.read_columns(slice)?; } else { - debug_assert!(!self.columns.is_empty()); let (result, not_enough_data) = match self.rows_to_validate { 0 => rowbinary::deserialize_from_and_validate::(&mut slice, &[]), u64::MAX => { diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 8b27a5a6..d6588edb 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,5 +1,5 @@ use crate::error::Result; -use clickhouse_types::data_types::{Column, DataTypeNode, DecimalSize, EnumType}; +use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; @@ -277,10 +277,10 @@ fn validate_impl<'de, 'cursor>( serde_type: &SerdeType, is_inner: bool, ) -> Result>> { - // println!( - // "Validating data type: {:?} against serde type: {} with compatible db types: {:?}", - // data_type, serde_type, compatible_db_types - // ); + println!( + "Validating data type: {} against serde type: {} with compatible db types", + data_type, serde_type, + ); // TODO: eliminate multiple branches with similar patterns? match serde_type { SerdeType::Bool @@ -307,20 +307,29 @@ fn validate_impl<'de, 'cursor>( SerdeType::I32 if data_type == &DataTypeNode::Int32 || data_type == &DataTypeNode::Date32 - || matches!(data_type, DataTypeNode::Decimal(_, _, DecimalSize::Int32)) => + || matches!( + data_type, + DataTypeNode::Decimal(_, _, DecimalType::Decimal32) + ) => { Ok(None) } SerdeType::I64 if data_type == &DataTypeNode::Int64 || matches!(data_type, DataTypeNode::DateTime64(_, _)) - || matches!(data_type, DataTypeNode::Decimal(_, _, DecimalSize::Int64)) => + || matches!( + data_type, + DataTypeNode::Decimal(_, _, DecimalType::Decimal64) + ) => { Ok(None) } SerdeType::I128 if data_type == &DataTypeNode::Int128 - || matches!(data_type, DataTypeNode::Decimal(_, _, DecimalSize::Int128)) => + || matches!( + data_type, + DataTypeNode::Decimal(_, _, DecimalType::Decimal128) + ) => { Ok(None) } @@ -368,19 +377,48 @@ fn validate_impl<'de, 'cursor>( root.panic_on_schema_mismatch(data_type, serde_type, is_inner) } } - SerdeType::Seq(_) => { - if let DataTypeNode::Array(inner_type) = data_type { - Ok(Some(InnerDataTypeValidator { - root, - kind: InnerDataTypeValidatorKind::Array( - inner_type, - ArrayValidatorState::Pending, - ), - })) - } else { - root.panic_on_schema_mismatch(data_type, serde_type, is_inner) - } - } + SerdeType::Seq(_) => match data_type { + DataTypeNode::Array(inner_type) => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(inner_type, ArrayValidatorState::Pending), + })), + DataTypeNode::Ring => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + &DataTypeNode::Point, + ArrayValidatorState::Pending, + ), + })), + DataTypeNode::Polygon => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + &DataTypeNode::Ring, + ArrayValidatorState::Pending, + ), + })), + DataTypeNode::MultiPolygon => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + &DataTypeNode::Polygon, + ArrayValidatorState::Pending, + ), + })), + DataTypeNode::LineString => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + &DataTypeNode::Point, + ArrayValidatorState::Pending, + ), + })), + DataTypeNode::MultiLineString => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array( + &DataTypeNode::LineString, + ArrayValidatorState::Pending, + ), + })), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), + }, SerdeType::Tuple(len) => match data_type { DataTypeNode::FixedString(n) => { if n == len { @@ -419,6 +457,13 @@ fn validate_impl<'de, 'cursor>( DataTypeNode::UInt64, ]), })), + DataTypeNode::Point => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Tuple(&[ + DataTypeNode::Float64, + DataTypeNode::Float64, + ]), + })), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::Map(_) => { diff --git a/tests/it/main.rs b/tests/it/main.rs index 90571b9c..37004ea0 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -23,6 +23,23 @@ use clickhouse::{sql::Identifier, Client, Row}; use serde::{Deserialize, Serialize}; +macro_rules! assert_panic_on_fetch_with_client { + ($client:ident, $msg_parts:expr, $query:expr) => { + use futures::FutureExt; + let async_panic = + std::panic::AssertUnwindSafe(async { $client.query($query).fetch_one::().await }); + let result = async_panic.catch_unwind().await; + assert!(result.is_err()); + let panic_msg = *result.unwrap_err().downcast::().unwrap(); + for &msg in $msg_parts { + assert!( + panic_msg.contains(msg), + "panic message:\n{panic_msg}\ndid not contain the expected part:\n{msg}" + ); + } + }; +} + macro_rules! assert_panic_on_fetch { ($msg_parts:expr, $query:expr) => { use futures::FutureExt; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 1a1cc758..04c15e4a 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -772,6 +772,216 @@ async fn test_variant() { ); } +#[tokio::test] +async fn test_geo() { + #[derive(Clone, Debug, PartialEq)] + #[derive(Row, serde::Serialize, serde::Deserialize)] + struct Data { + id: u32, + point: Point, + ring: Ring, + polygon: Polygon, + multi_polygon: MultiPolygon, + line_string: LineString, + multi_line_string: MultiLineString, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT + 42 :: UInt32 AS id, + (1.0, 2.0) :: Point AS point, + [(3.0, 4.0), (5.0, 6.0)] :: Ring AS ring, + [[(7.0, 8.0), (9.0, 10.0)], [(11.0, 12.0)]] :: Polygon AS polygon, + [[[(13.0, 14.0), (15.0, 16.0)], [(17.0, 18.0)]]] :: MultiPolygon AS multi_polygon, + [(19.0, 20.0), (21.0, 22.0)] :: LineString AS line_string, + [[(23.0, 24.0), (25.0, 26.0)], [(27.0, 28.0)]] :: MultiLineString AS multi_line_string + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + point: (1.0, 2.0), + ring: vec![(3.0, 4.0), (5.0, 6.0)], + polygon: vec![vec![(7.0, 8.0), (9.0, 10.0)], vec![(11.0, 12.0)]], + multi_polygon: vec![vec![vec![(13.0, 14.0), (15.0, 16.0)], vec![(17.0, 18.0)]]], + line_string: vec![(19.0, 20.0), (21.0, 22.0)], + multi_line_string: vec![vec![(23.0, 24.0), (25.0, 26.0)], vec![(27.0, 28.0)]], + } + ); +} + +// TODO: there are two panics; one about schema mismatch, +// another about not all Tuple elements being deserialized +// not easy to assert, same applies to the other Geo types +#[ignore] +#[tokio::test] +async fn test_geo_invalid_point() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u32, + pt: (i32, i32), + } + assert_panic_on_fetch!( + &["Data.pt", "Point", "Float64 as i32"], + " + SELECT + 42 :: UInt32 AS id, + (1.0, 2.0) :: Point AS pt + " + ); +} + +// TODO: unignore after insert implementation uses RBWNAT, too +#[ignore] +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/109#issuecomment-2243197221 +async fn test_issue_109_1() { + #[derive(Debug, Serialize, Deserialize, Row)] + struct Data { + #[serde(skip_deserializing)] + en_id: String, + journey: u32, + drone_id: String, + call_sign: String, + } + let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let statements = vec![ + " + CREATE TABLE issue_109 ( + drone_id String, + call_sign String, + journey UInt32, + en_id String, + ) + ENGINE = MergeTree + ORDER BY (drone_id) + ", + " + INSERT INTO issue_109 VALUES + ('drone_1', 'call_sign_1', 1, 'en_id_1'), + ('drone_2', 'call_sign_2', 2, 'en_id_2'), + ('drone_3', 'call_sign_3', 3, 'en_id_3') + ", + ]; + for stmt in statements { + client + .query(stmt) + .execute() + .await + .expect(&format!("Failed to execute query: {}", stmt)); + } + let data = client + .query("SELECT journey, drone_id, call_sign FROM issue_109") + .fetch_all::() + .await + .unwrap(); + let mut insert = client.insert("issue_109").unwrap(); + for (id, elem) in data.iter().enumerate() { + let elem = Data { + en_id: format!("ABC-{}", id), + journey: elem.journey, + drone_id: elem.drone_id.clone(), + call_sign: elem.call_sign.clone(), + }; + insert.write(&elem).await.unwrap(); + } + insert.end().await.unwrap(); +} + +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/113 +async fn test_issue_113() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u64, + b: f64, + c: f64, + } + let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let statements = vec![ + " + CREATE TABLE issue_113_1( + id UInt32 + ) + ENGINE MergeTree + ORDER BY id + ", + " + CREATE TABLE issue_113_2( + id UInt32, + pos Float64 + ) + ENGINE MergeTree + ORDER BY id + ", + "INSERT INTO issue_113_1 VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)", + "INSERT INTO issue_113_2 VALUES (1, 100.5), (2, 200.2), (3, 300.3), (4, 444.4), (5, 555.5)", + ]; + for stmt in statements { + client + .query(stmt) + .execute() + .await + .expect(&format!("Failed to execute query: {}", stmt)); + } + + // Struct should have had Option instead of f64 + assert_panic_on_fetch_with_client!( + client, + &["Data.b", "Nullable(Float64)", "f64"], + " + SELECT + COUNT(*) AS a, + (COUNT(*) / (SELECT COUNT(*) FROM issue_113_1)) * 100.0 AS b, + AVG(pos) AS c + FROM issue_113_2 + " + ); +} + +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/185 +async fn test_issue_185() { + #[derive(Row, Deserialize, Debug, PartialEq)] + struct Data { + pk: u32, + decimal_col: Option, + } + + let client = prepare_database!().with_validation_mode(ValidationMode::Each); + client + .query( + " + CREATE TABLE issue_185( + pk UInt32, + decimal_col Nullable(Decimal(10, 4))) + ENGINE MergeTree + ORDER BY pk + ", + ) + .execute() + .await + .unwrap(); + client + .query("INSERT INTO issue_185 VALUES (1, 1.1), (2, 2.2), (3, 3.3)") + .execute() + .await + .unwrap(); + + assert_panic_on_fetch_with_client!( + client, + &["Data.decimal_col", "Decimal(10, 4)", "String"], + "SELECT ?fields FROM issue_185" + ); +} + #[tokio::test] #[ignore] // this is currently disabled, see validation todo async fn test_variant_wrong_definition() { @@ -823,7 +1033,7 @@ async fn test_variant_wrong_definition() { #[tokio::test] #[ignore] async fn test_different_struct_field_order() { - #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + #[derive(Debug, Row, Deserialize, PartialEq)] struct Data { c: String, a: String, @@ -843,3 +1053,11 @@ async fn test_different_struct_field_order() { } ); } + +// See https://clickhouse.com/docs/en/sql-reference/data-types/geo +type Point = (f64, f64); +type Ring = Vec; +type Polygon = Vec; +type MultiPolygon = Vec; +type LineString = Vec; +type MultiLineString = Vec; diff --git a/types/src/data_types.rs b/types/src/data_types.rs index 6032d35a..0cf55814 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -42,7 +42,7 @@ pub enum DataTypeNode { Float32, Float64, BFloat16, - Decimal(u8, u8, DecimalSize), // Scale, Precision, 32 | 64 | 128 | 256 + Decimal(u8, u8, DecimalType), // Scale, Precision, 32 | 64 | 128 | 256 String, FixedString(usize), @@ -69,142 +69,13 @@ pub enum DataTypeNode { Variant(Vec), Dynamic, JSON, - // TODO: Geo -} - -// TODO - should be the same top-levels as DataTypeNode; -// gen from DataTypeNode via macro maybe? -#[derive(Debug, Clone, PartialEq)] -#[non_exhaustive] -pub enum DataTypeHint { - Bool, - - UInt8, - UInt16, - UInt32, - UInt64, - UInt128, - UInt256, - Int8, - Int16, - Int32, - Int64, - Int128, - Int256, - - Float32, - Float64, - BFloat16, - Decimal(DecimalSize), - - String, - FixedString, - UUID, - - Date, - Date32, - DateTime, - DateTime64, - - IPv4, - IPv6, - - Nullable, - LowCardinality, - - Array, - Tuple, - Map, - Enum, - - AggregateFunction, - - Variant, - Dynamic, - JSON, - // TODO: Geo -} - -impl Display for DataTypeHint { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - DataTypeHint::Bool => write!(f, "Bool"), - DataTypeHint::UInt8 => write!(f, "UInt8"), - DataTypeHint::UInt16 => write!(f, "UInt16"), - DataTypeHint::UInt32 => write!(f, "UInt32"), - DataTypeHint::UInt64 => write!(f, "UInt64"), - DataTypeHint::UInt128 => write!(f, "UInt128"), - DataTypeHint::UInt256 => write!(f, "UInt256"), - DataTypeHint::Int8 => write!(f, "Int8"), - DataTypeHint::Int16 => write!(f, "Int16"), - DataTypeHint::Int32 => write!(f, "Int32"), - DataTypeHint::Int64 => write!(f, "Int64"), - DataTypeHint::Int128 => write!(f, "Int128"), - DataTypeHint::Int256 => write!(f, "Int256"), - DataTypeHint::Float32 => write!(f, "Float32"), - DataTypeHint::Float64 => write!(f, "Float64"), - DataTypeHint::BFloat16 => write!(f, "BFloat16"), - DataTypeHint::Decimal(size) => write!(f, "Decimal{}", size), - DataTypeHint::String => write!(f, "String"), - DataTypeHint::FixedString => write!(f, "FixedString"), - DataTypeHint::UUID => write!(f, "UUID"), - DataTypeHint::Date => write!(f, "Date"), - DataTypeHint::Date32 => write!(f, "Date32"), - DataTypeHint::DateTime => write!(f, "DateTime"), - DataTypeHint::DateTime64 => write!(f, "DateTime64"), - DataTypeHint::IPv4 => write!(f, "IPv4"), - DataTypeHint::IPv6 => write!(f, "IPv6"), - DataTypeHint::Nullable => write!(f, "Nullable"), - DataTypeHint::LowCardinality => write!(f, "LowCardinality"), - DataTypeHint::Array => { - write!(f, "Array") - } - DataTypeHint::Tuple => { - write!(f, "Tuple") - } - DataTypeHint::Map => { - write!(f, "Map") - } - DataTypeHint::Enum => { - write!(f, "Enum") - } - DataTypeHint::AggregateFunction => { - write!(f, "AggregateFunction") - } - DataTypeHint::Variant => { - write!(f, "Variant") - } - DataTypeHint::Dynamic => { - write!(f, "Dynamic") - } - DataTypeHint::JSON => { - write!(f, "JSON") - } - } - } -} - -impl Into for DataTypeHint { - fn into(self) -> String { - self.to_string() - } -} - -macro_rules! data_type_is { - ($method:ident, $pattern:pat) => { - #[inline] - pub fn $method(&self) -> Result<(), ParserError> { - match self { - $pattern => Ok(()), - _ => Err(ParserError::TypeParsingError(format!( - "Expected {}, got {}", - stringify!($pattern), - self - ))), - } - } - }; + Point, + Ring, + LineString, + MultiLineString, + Polygon, + MultiPolygon, } impl DataTypeNode { @@ -234,6 +105,12 @@ impl DataTypeNode { "Bool" => Ok(Self::Bool), "Dynamic" => Ok(Self::Dynamic), "JSON" => Ok(Self::JSON), + "Point" => Ok(Self::Point), + "Ring" => Ok(Self::Ring), + "LineString" => Ok(Self::LineString), + "MultiLineString" => Ok(Self::MultiLineString), + "Polygon" => Ok(Self::Polygon), + "MultiPolygon" => Ok(Self::MultiPolygon), str if str.starts_with("Decimal") => parse_decimal(str), str if str.starts_with("DateTime64") => parse_datetime64(str), @@ -256,128 +133,6 @@ impl DataTypeNode { ))), } } - - pub fn get_type_hints_internal(&self, hints: &mut Vec) { - match self { - DataTypeNode::Bool => hints.push(DataTypeHint::Bool), - DataTypeNode::UInt8 => hints.push(DataTypeHint::UInt8), - DataTypeNode::UInt16 => hints.push(DataTypeHint::UInt16), - DataTypeNode::UInt32 => hints.push(DataTypeHint::UInt32), - DataTypeNode::UInt64 => hints.push(DataTypeHint::UInt64), - DataTypeNode::UInt128 => hints.push(DataTypeHint::UInt128), - DataTypeNode::UInt256 => hints.push(DataTypeHint::UInt256), - DataTypeNode::Int8 => hints.push(DataTypeHint::Int8), - DataTypeNode::Int16 => hints.push(DataTypeHint::Int16), - DataTypeNode::Int32 => hints.push(DataTypeHint::Int32), - DataTypeNode::Int64 => hints.push(DataTypeHint::Int64), - DataTypeNode::Int128 => hints.push(DataTypeHint::Int128), - DataTypeNode::Int256 => hints.push(DataTypeHint::Int256), - DataTypeNode::Float32 => hints.push(DataTypeHint::Float32), - DataTypeNode::Float64 => hints.push(DataTypeHint::Float64), - DataTypeNode::BFloat16 => hints.push(DataTypeHint::BFloat16), - DataTypeNode::Decimal(_, _, size) => { - hints.push(DataTypeHint::Decimal(size.clone())); - } - DataTypeNode::String => hints.push(DataTypeHint::String), - DataTypeNode::FixedString(_) => hints.push(DataTypeHint::FixedString), - DataTypeNode::UUID => hints.push(DataTypeHint::UUID), - DataTypeNode::Date => hints.push(DataTypeHint::Date), - DataTypeNode::Date32 => hints.push(DataTypeHint::Date32), - DataTypeNode::DateTime(_) => hints.push(DataTypeHint::DateTime), - DataTypeNode::DateTime64(_, _) => hints.push(DataTypeHint::DateTime64), - DataTypeNode::IPv4 => hints.push(DataTypeHint::IPv4), - DataTypeNode::IPv6 => hints.push(DataTypeHint::IPv6), - DataTypeNode::Nullable(inner) => { - hints.push(DataTypeHint::Nullable); - inner.get_type_hints_internal(hints); - } - DataTypeNode::LowCardinality(inner) => { - hints.push(DataTypeHint::LowCardinality); - inner.get_type_hints_internal(hints); - } - DataTypeNode::Array(inner) => { - hints.push(DataTypeHint::Array); - inner.get_type_hints_internal(hints); - } - DataTypeNode::Tuple(elements) => { - hints.push(DataTypeHint::Tuple); - for element in elements { - element.get_type_hints_internal(hints); - } - } - DataTypeNode::Map(key, value) => { - hints.push(DataTypeHint::Map); - key.get_type_hints_internal(hints); - value.get_type_hints_internal(hints); - } - DataTypeNode::Enum(_, _) => hints.push(DataTypeHint::Enum), - DataTypeNode::AggregateFunction(_, args) => { - hints.push(DataTypeHint::AggregateFunction); - for arg in args { - arg.get_type_hints_internal(hints); - } - } - DataTypeNode::Variant(types) => { - hints.push(DataTypeHint::Variant); - for ty in types { - ty.get_type_hints_internal(hints); - } - } - DataTypeNode::Dynamic => hints.push(DataTypeHint::Dynamic), - DataTypeNode::JSON => hints.push(DataTypeHint::JSON), - } - } - - pub fn get_type_hints(&self) -> Vec { - let capacity = match self { - DataTypeNode::Tuple(elements) | DataTypeNode::Variant(elements) => elements.len() + 1, - DataTypeNode::Map(_, _) => 3, - DataTypeNode::Nullable(_) - | DataTypeNode::LowCardinality(_) - | DataTypeNode::Array(_) => 2, - _ => 1, - }; - let mut vec = Vec::with_capacity(capacity); - self.get_type_hints_internal(&mut vec); - vec - } - - data_type_is!(is_bool, DataTypeNode::Bool); - data_type_is!(is_uint8, DataTypeNode::UInt8); - data_type_is!(is_uint16, DataTypeNode::UInt16); - data_type_is!(is_uint32, DataTypeNode::UInt32); - data_type_is!(is_uint64, DataTypeNode::UInt64); - data_type_is!(is_uint128, DataTypeNode::UInt128); - data_type_is!(is_uint256, DataTypeNode::UInt256); - data_type_is!(is_int8, DataTypeNode::Int8); - data_type_is!(is_int16, DataTypeNode::Int16); - data_type_is!(is_int32, DataTypeNode::Int32); - data_type_is!(is_int64, DataTypeNode::Int64); - data_type_is!(is_int128, DataTypeNode::Int128); - data_type_is!(is_int256, DataTypeNode::Int256); - data_type_is!(is_float32, DataTypeNode::Float32); - data_type_is!(is_float64, DataTypeNode::Float64); - data_type_is!(is_bfloat16, DataTypeNode::BFloat16); - data_type_is!(is_string, DataTypeNode::String); - data_type_is!(is_uuid, DataTypeNode::UUID); - data_type_is!(is_date, DataTypeNode::Date); - data_type_is!(is_date32, DataTypeNode::Date32); - data_type_is!(is_datetime, DataTypeNode::DateTime(_)); - data_type_is!(is_datetime64, DataTypeNode::DateTime64(_, _)); - data_type_is!(is_ipv4, DataTypeNode::IPv4); - data_type_is!(is_ipv6, DataTypeNode::IPv6); - data_type_is!(is_nullable, DataTypeNode::Nullable(_)); - data_type_is!(is_array, DataTypeNode::Array(_)); - data_type_is!(is_tuple, DataTypeNode::Tuple(_)); - data_type_is!(is_map, DataTypeNode::Map(_, _)); - data_type_is!(is_low_cardinality, DataTypeNode::LowCardinality(_)); - data_type_is!(is_decimal, DataTypeNode::Decimal(_, _, _)); - data_type_is!(is_enum, DataTypeNode::Enum(_, _)); - data_type_is!(is_aggregate_function, DataTypeNode::AggregateFunction(_, _)); - data_type_is!(is_fixed_string, DataTypeNode::FixedString(_)); - data_type_is!(is_variant, DataTypeNode::Variant(_)); - data_type_is!(is_dynamic, DataTypeNode::Dynamic); - data_type_is!(is_json, DataTypeNode::JSON); } impl Into for DataTypeNode { @@ -454,6 +209,12 @@ impl Display for DataTypeNode { } JSON => "JSON".to_string(), Dynamic => "Dynamic".to_string(), + Point => "Point".to_string(), + Ring => "Ring".to_string(), + LineString => "LineString".to_string(), + MultiLineString => "MultiLineString".to_string(), + Polygon => "Polygon".to_string(), + MultiPolygon => "MultiPolygon".to_string(), }; write!(f, "{}", str) } @@ -466,7 +227,7 @@ pub enum EnumType { } impl Display for EnumType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { EnumType::Enum8 => write!(f, "Enum8"), EnumType::Enum16 => write!(f, "Enum16"), @@ -510,34 +271,34 @@ impl DateTimePrecision { } #[derive(Debug, Clone, PartialEq)] -pub enum DecimalSize { - Int32, - Int64, - Int128, - Int256, +pub enum DecimalType { + Decimal32, + Decimal64, + Decimal128, + Decimal256, } -impl Display for DecimalSize { +impl Display for DecimalType { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { - DecimalSize::Int32 => write!(f, "32"), - DecimalSize::Int64 => write!(f, "64"), - DecimalSize::Int128 => write!(f, "128"), - DecimalSize::Int256 => write!(f, "256"), + DecimalType::Decimal32 => write!(f, "Decimal32"), + DecimalType::Decimal64 => write!(f, "Decimal64"), + DecimalType::Decimal128 => write!(f, "Decimal128"), + DecimalType::Decimal256 => write!(f, "Decimal256"), } } } -impl DecimalSize { +impl DecimalType { pub(crate) fn new(precision: u8) -> Result { if precision <= 9 { - Ok(DecimalSize::Int32) + Ok(DecimalType::Decimal32) } else if precision <= 18 { - Ok(DecimalSize::Int64) + Ok(DecimalType::Decimal64) } else if precision <= 38 { - Ok(DecimalSize::Int128) + Ok(DecimalType::Decimal128) } else if precision <= 76 { - Ok(DecimalSize::Int256) + Ok(DecimalType::Decimal256) } else { return Err(ParserError::TypeParsingError(format!( "Invalid Decimal precision: {}", @@ -548,7 +309,7 @@ impl DecimalSize { } impl Display for DateTimePrecision { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { DateTimePrecision::Precision0 => write!(f, "0"), DateTimePrecision::Precision1 => write!(f, "1"), @@ -576,11 +337,11 @@ fn parse_fixed_string(input: &str) -> Result { if input.len() >= 14 { let size_str = &input[12..input.len() - 1]; let size = size_str.parse::().map_err(|err| { - ParserError::TypeParsingError(format!( - "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", - err, input, size_str - )) - })?; + ParserError::TypeParsingError(format!( + "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", + err, input, size_str + )) + })?; if size == 0 { return Err(ParserError::TypeParsingError(format!( "Invalid FixedString size, expected a positive number, got zero. Input: {}", @@ -676,7 +437,7 @@ fn parse_decimal(input: &str) -> Result { input ))); } - let size = DecimalSize::new(parsed[0])?; + let size = DecimalType::new(parsed[0])?; return Ok(DataTypeNode::Decimal(precision, scale, size)); } Err(ParserError::TypeParsingError(format!( @@ -853,6 +614,24 @@ fn parse_inner_types(input: &str) -> Result, ParserError> { Ok(inner_types) } +#[inline] +fn parse_enum_index(input_bytes: &[u8], input: &str) -> Result { + String::from_utf8(input_bytes.to_vec()) + .map_err(|_| { + ParserError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the enum index: {}", + &input + )) + })? + .parse::() + .map_err(|_| { + ParserError::TypeParsingError(format!( + "Invalid Enum index, expected a valid number. Input: {}", + input + )) + }) +} + fn parse_enum_values_map(input: &str) -> Result, ParserError> { let mut names: Vec = Vec::new(); let mut indices: Vec = Vec::new(); @@ -895,20 +674,7 @@ fn parse_enum_values_map(input: &str) -> Result, ParserErro } // Parsing the index, skipping next iterations until the first non-digit one else if input_bytes[i] < b'0' || input_bytes[i] > b'9' { - let index = String::from_utf8(input_bytes[start_index..i].to_vec()) - .map_err(|_| { - ParserError::TypeParsingError(format!( - "Invalid UTF-8 sequence in input for the enum index: {}", - &input[start_index..i] - )) - })? - .parse::() - .map_err(|_| { - ParserError::TypeParsingError(format!( - "Invalid Enum index, expected a valid number. Input: {}", - input - )) - })?; + let index = parse_enum_index(&input_bytes[start_index..i], input)?; indices.push(index); // the char at this index should be comma @@ -925,28 +691,15 @@ fn parse_enum_values_map(input: &str) -> Result, ParserErro i += 1; } - let index = String::from_utf8(input_bytes[start_index..i].to_vec()) - .map_err(|_| { - ParserError::TypeParsingError(format!( - "Invalid UTF-8 sequence in input for the enum index: {}", - &input[start_index..i] - )) - })? - .parse::() - .map_err(|_| { - ParserError::TypeParsingError(format!( - "Invalid Enum index, expected a valid number. Input: {}", - input - )) - })?; + let index = parse_enum_index(&input_bytes[start_index..i], input)?; indices.push(index); if names.len() != indices.len() { return Err(ParserError::TypeParsingError(format!( - "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", - names.join(", "), - indices.iter().map(|index| index.to_string()).collect::>().join(", "), - ))); + "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", + names.join(", "), + indices.iter().map(|index| index.to_string()).collect::>().join(", "), + ))); } Ok(indices @@ -988,7 +741,7 @@ mod tests { assert_eq!(DataTypeNode::new("Bool").unwrap(), DataTypeNode::Bool); assert_eq!(DataTypeNode::new("Dynamic").unwrap(), DataTypeNode::Dynamic); assert_eq!(DataTypeNode::new("JSON").unwrap(), DataTypeNode::JSON); - assert!(DataType::new("SomeUnknownType").is_err(),); + assert!(DataTypeNode::new("SomeUnknownType").is_err()); } #[test] @@ -1043,19 +796,19 @@ mod tests { fn test_data_type_new_decimal() { assert_eq!( DataTypeNode::new("Decimal(7, 2)").unwrap(), - DataTypeNode::Decimal(7, 2, DecimalSize::Int32) + DataTypeNode::Decimal(7, 2, DecimalType::Decimal32) ); assert_eq!( DataTypeNode::new("Decimal(12, 4)").unwrap(), - DataTypeNode::Decimal(12, 4, DecimalSize::Int64) + DataTypeNode::Decimal(12, 4, DecimalType::Decimal64) ); assert_eq!( DataTypeNode::new("Decimal(27, 6)").unwrap(), - DataTypeNode::Decimal(27, 6, DecimalSize::Int128) + DataTypeNode::Decimal(27, 6, DecimalType::Decimal128) ); assert_eq!( DataTypeNode::new("Decimal(42, 8)").unwrap(), - DataTypeNode::Decimal(42, 8, DecimalSize::Int256) + DataTypeNode::Decimal(42, 8, DecimalType::Decimal256) ); assert!(DataTypeNode::new("Decimal").is_err()); assert!(DataTypeNode::new("Decimal(").is_err()); @@ -1159,6 +912,7 @@ mod tests { ) ); assert!(DataTypeNode::new("DateTime64()").is_err()); + assert!(DataTypeNode::new("DateTime64(x)").is_err()); } #[test] @@ -1177,7 +931,15 @@ mod tests { DataTypeNode::Int32 )))) ); + assert_eq!( + DataTypeNode::new("LowCardinality(Nullable(Int32))").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::Int32 + )))) + ); + assert!(DataTypeNode::new("LowCardinality").is_err()); assert!(DataTypeNode::new("LowCardinality()").is_err()); + assert!(DataTypeNode::new("LowCardinality(X)").is_err()); } #[test] @@ -1190,7 +952,9 @@ mod tests { DataTypeNode::new("Nullable(String)").unwrap(), DataTypeNode::Nullable(Box::new(DataTypeNode::String)) ); + assert!(DataTypeNode::new("Nullable").is_err()); assert!(DataTypeNode::new("Nullable()").is_err()); + assert!(DataTypeNode::new("Nullable(X)").is_err()); } #[test] @@ -1222,6 +986,12 @@ mod tests { ) ); assert!(DataTypeNode::new("Map()").is_err()); + assert!(DataTypeNode::new("Map").is_err()); + assert!(DataTypeNode::new("Map(K)").is_err()); + assert!(DataTypeNode::new("Map(K, V)").is_err()); + assert!(DataTypeNode::new("Map(Int32, V)").is_err()); + assert!(DataTypeNode::new("Map(K, Int32)").is_err()); + assert!(DataTypeNode::new("Map(String, Int32").is_err()); } #[test] @@ -1261,6 +1031,10 @@ mod tests { DataTypeNode::new("Tuple(String, Int32)").unwrap(), DataTypeNode::Tuple(vec![DataTypeNode::String, DataTypeNode::Int32]) ); + assert_eq!( + DataTypeNode::new("Tuple(Bool,Int32)").unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::Bool, DataTypeNode::Int32]) + ); assert_eq!( DataTypeNode::new( "Tuple(Int32, Array(Nullable(String)), Map(Int32, Tuple(String, Array(UInt8))))" @@ -1280,7 +1054,17 @@ mod tests { ) ]) ); + assert_eq!( + DataTypeNode::new(&format!("Tuple(String, {})", ENUM_WITH_ESCAPING_STR)).unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::String, enum_with_escaping()]) + ); assert!(DataTypeNode::new("Tuple").is_err()); + assert!(DataTypeNode::new("Tuple(").is_err()); + assert!(DataTypeNode::new("Tuple()").is_err()); + assert!(DataTypeNode::new("Tuple(,)").is_err()); + assert!(DataTypeNode::new("Tuple(X)").is_err()); + assert!(DataTypeNode::new("Tuple(Int32, X)").is_err()); + assert!(DataTypeNode::new("Tuple(Int32, String, X)").is_err()); } #[test] @@ -1293,7 +1077,6 @@ mod tests { DataTypeNode::new("Enum16('A' = -144)").unwrap(), DataTypeNode::Enum(EnumType::Enum16, HashMap::from([(-144, "A".to_string())])) ); - assert_eq!( DataTypeNode::new("Enum8('A' = 1, 'B' = 2)").unwrap(), DataTypeNode::Enum( @@ -1309,20 +1092,8 @@ mod tests { ) ); assert_eq!( - DataTypeNode::new( - "Enum8('f\\'' = 1, 'x =' = 2, 'b\\'\\'' = 3, '\\'c=4=' = 42, '4' = 100)" - ) - .unwrap(), - DataTypeNode::Enum( - EnumType::Enum8, - HashMap::from([ - (1, "f\\'".to_string()), - (2, "x =".to_string()), - (3, "b\\'\\'".to_string()), - (42, "\\'c=4=".to_string()), - (100, "4".to_string()) - ]) - ) + DataTypeNode::new(ENUM_WITH_ESCAPING_STR).unwrap(), + enum_with_escaping() ); assert_eq!( DataTypeNode::new("Enum8('foo' = 0, '' = 42)").unwrap(), @@ -1335,6 +1106,31 @@ mod tests { assert!(DataTypeNode::new("Enum()").is_err()); assert!(DataTypeNode::new("Enum8()").is_err()); assert!(DataTypeNode::new("Enum16()").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B' = 2)").is_err()); + assert!(DataTypeNode::new("Enum32('A','B')").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B')").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B' =)").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B' = )").is_err()); + assert!(DataTypeNode::new("Enum32('A'= 1,'B' =)").is_err()); + } + + #[test] + fn test_data_type_new_geo() { + assert_eq!(DataTypeNode::new("Point").unwrap(), DataTypeNode::Point); + assert_eq!(DataTypeNode::new("Ring").unwrap(), DataTypeNode::Ring); + assert_eq!( + DataTypeNode::new("LineString").unwrap(), + DataTypeNode::LineString + ); + assert_eq!(DataTypeNode::new("Polygon").unwrap(), DataTypeNode::Polygon); + assert_eq!( + DataTypeNode::new("MultiLineString").unwrap(), + DataTypeNode::MultiLineString + ); + assert_eq!( + DataTypeNode::new("MultiPolygon").unwrap(), + DataTypeNode::MultiPolygon + ); } #[test] @@ -1382,6 +1178,10 @@ mod tests { DataTypeNode::Nullable(Box::new(DataTypeNode::UInt64)).to_string(), "Nullable(UInt64)" ); + assert_eq!( + DataTypeNode::LowCardinality(Box::new(DataTypeNode::String)).to_string(), + "LowCardinality(String)" + ); assert_eq!( DataTypeNode::Array(Box::new(DataTypeNode::String)).to_string(), "Array(String)" @@ -1411,7 +1211,7 @@ mod tests { "Map(String, UInt32)" ); assert_eq!( - DataTypeNode::Decimal(10, 2, DecimalSize::Int32).to_string(), + DataTypeNode::Decimal(10, 2, DecimalType::Decimal32).to_string(), "Decimal(10, 2)" ); assert_eq!( @@ -1422,6 +1222,15 @@ mod tests { .to_string(), "Enum8('A' = 1, 'B' = 2)" ); + assert_eq!( + DataTypeNode::Enum( + EnumType::Enum16, + HashMap::from([(42, "foo".to_string()), (144, "bar".to_string())]), + ) + .to_string(), + "Enum16('foo' = 42, 'bar' = 144)" + ); + assert_eq!(enum_with_escaping().to_string(), ENUM_WITH_ESCAPING_STR); assert_eq!( DataTypeNode::AggregateFunction("sum".to_string(), vec![DataTypeNode::UInt64]) .to_string(), @@ -1432,10 +1241,135 @@ mod tests { DataTypeNode::Variant(vec![DataTypeNode::UInt8, DataTypeNode::Bool]).to_string(), "Variant(UInt8, Bool)" ); - assert_eq!( - DataTypeNode::DateTime64(DateTimePrecision::Precision3, Some("UTC".to_string())) - .to_string(), - "DateTime64(3, 'UTC')" + } + + #[test] + fn test_datetime64_to_string() { + let test_cases = [ + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision0, None), + "DateTime64(0)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision1, None), + "DateTime64(1)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision2, None), + "DateTime64(2)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision3, None), + "DateTime64(3)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision4, None), + "DateTime64(4)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision5, None), + "DateTime64(5)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision6, None), + "DateTime64(6)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision7, None), + "DateTime64(7)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision8, None), + "DateTime64(8)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision9, None), + "DateTime64(9)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision0, Some("UTC".to_string())), + "DateTime64(0, 'UTC')", + ), + ( + DataTypeNode::DateTime64( + DateTimePrecision::Precision3, + Some("America/New_York".to_string()), + ), + "DateTime64(3, 'America/New_York')", + ), + ( + DataTypeNode::DateTime64( + DateTimePrecision::Precision6, + Some("Europe/Amsterdam".to_string()), + ), + "DateTime64(6, 'Europe/Amsterdam')", + ), + ( + DataTypeNode::DateTime64( + DateTimePrecision::Precision9, + Some("Asia/Tokyo".to_string()), + ), + "DateTime64(9, 'Asia/Tokyo')", + ), + ]; + for (data_type, expected_str) in test_cases.iter() { + assert_eq!( + &data_type.to_string(), + expected_str, + "Expected data type {} to be formatted as {}", + data_type, + expected_str + ); + } + } + + #[test] + fn test_data_type_node_into_string() { + let data_type = DataTypeNode::new("Array(Int32)").unwrap(); + let data_type_string: String = data_type.into(); + assert_eq!(data_type_string, "Array(Int32)"); + } + + #[test] + fn test_data_type_to_string_geo() { + assert_eq!(DataTypeNode::Point.to_string(), "Point"); + assert_eq!(DataTypeNode::Ring.to_string(), "Ring"); + assert_eq!(DataTypeNode::LineString.to_string(), "LineString"); + assert_eq!(DataTypeNode::Polygon.to_string(), "Polygon"); + assert_eq!(DataTypeNode::MultiLineString.to_string(), "MultiLineString"); + assert_eq!(DataTypeNode::MultiPolygon.to_string(), "MultiPolygon"); + } + + #[test] + fn test_display_column() { + let column = Column::new( + "col".to_string(), + DataTypeNode::new("Array(Int32)").unwrap(), ); + assert_eq!(column.to_string(), "col: Array(Int32)"); + } + + #[test] + fn test_display_decimal_size() { + assert_eq!(DecimalType::Decimal32.to_string(), "Decimal32"); + assert_eq!(DecimalType::Decimal64.to_string(), "Decimal64"); + assert_eq!(DecimalType::Decimal128.to_string(), "Decimal128"); + assert_eq!(DecimalType::Decimal256.to_string(), "Decimal256"); + } + + const ENUM_WITH_ESCAPING_STR: &'static str = + "Enum8('f\\'' = 1, 'x =' = 2, 'b\\'\\'' = 3, '\\'c=4=' = 42, '4' = 100)"; + + fn enum_with_escaping() -> DataTypeNode { + DataTypeNode::Enum( + EnumType::Enum8, + HashMap::from([ + (1, "f\\'".to_string()), + (2, "x =".to_string()), + (3, "b\\'\\'".to_string()), + (42, "\\'c=4=".to_string()), + (100, "4".to_string()), + ]), + ) } } diff --git a/types/src/leb128.rs b/types/src/leb128.rs index 93ec92b2..27be8d2c 100644 --- a/types/src/leb128.rs +++ b/types/src/leb128.rs @@ -24,24 +24,6 @@ pub fn decode_leb128(buffer: &mut &[u8]) -> Result { Ok(value) } -// FIXME: do not use Vec -pub fn encode_leb128(value: u64) -> Vec { - let mut result = Vec::new(); - let mut val = value; - loop { - let mut byte = (val & 0x7f) as u8; - val >>= 7; - if val != 0 { - byte |= 0x80; - } - result.push(byte); - if val == 0 { - break; - } - } - result -} - mod tests { #[test] fn test_decode_leb128() { @@ -64,6 +46,23 @@ mod tests { #[test] fn test_encode_decode_leb128() { + fn encode_leb128<'a>(value: u64) -> Vec { + let mut result = Vec::new(); + let mut val = value; + loop { + let mut byte = (val & 0x7f) as u8; + val >>= 7; + if val != 0 { + byte |= 0x80; + } + result.push(byte); + if val == 0 { + break; + } + } + result + } + let test_values = vec![ 0u64, 1, @@ -79,7 +78,7 @@ mod tests { ]; for value in test_values { - let encoded = super::encode_leb128(value); + let encoded = encode_leb128(value); let decoded = super::decode_leb128(&mut encoded.as_slice()).unwrap(); assert_eq!( From 986643f4c16cfbbaf56a89959cd5b36e442c5daa Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 28 May 2025 18:28:36 +0200 Subject: [PATCH 14/54] Support root level tuples for fetch --- examples/mock.rs | 7 ++- src/cursors/row.rs | 79 ++++++++++++++++++++---------- src/error.rs | 4 +- src/rowbinary/ser.rs | 42 ++-------------- src/rowbinary/validation.rs | 57 ++++++++++++++------- src/test/handlers.rs | 4 +- src/test/mock.rs | 2 +- tests/it/cursor_error.rs | 80 +++++++++--------------------- tests/it/cursor_stats.rs | 2 +- tests/it/mock.rs | 14 ++++-- tests/it/query.rs | 4 +- types/src/data_types.rs | 98 ++++++++++++++++++------------------- types/src/decoders.rs | 10 ++-- types/src/error.rs | 5 +- types/src/leb128.rs | 86 +++++++++++++++++--------------- types/src/lib.rs | 38 +++++++++++--- 16 files changed, 280 insertions(+), 252 deletions(-) diff --git a/examples/mock.rs b/examples/mock.rs index 3f5bbd30..70d64c16 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -1,4 +1,6 @@ use clickhouse::{error::Result, test, Client, Row}; +use clickhouse_types::Column; +use clickhouse_types::DataTypeNode::UInt32; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq)] @@ -55,7 +57,10 @@ async fn main() { assert!(recording.query().await.contains("CREATE TABLE")); // How to test SELECT. - mock.add(test::handlers::provide(list.clone())); + mock.add(test::handlers::provide( + &vec![Column::new("no".to_string(), UInt32)], + list.clone(), + )); let rows = make_select(&client).await.unwrap(); assert_eq!(rows, list); diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 5b0c5e6a..b55c6fd2 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -7,6 +7,7 @@ use crate::{ rowbinary, }; use clickhouse_types::data_types::Column; +use clickhouse_types::error::TypesError; use clickhouse_types::parse_rbwnat_columns_header; use serde::Deserialize; use std::marker::PhantomData; @@ -37,12 +38,39 @@ impl RowCursor { #[cold] #[inline(never)] - fn read_columns(&mut self, mut slice: &[u8]) -> Result<()> { - let columns = parse_rbwnat_columns_header(&mut slice)?; - debug_assert!(!columns.is_empty()); - self.bytes.set_remaining(slice.len()); - self.columns = columns; - Ok(()) + async fn read_columns(&mut self) -> Result<()> { + loop { + if self.bytes.remaining() > 0 { + let mut slice = self.bytes.slice(); + match parse_rbwnat_columns_header(&mut slice) { + Ok(columns) if !columns.is_empty() => { + self.bytes.set_remaining(slice.len()); + self.columns = columns; + return Ok(()); + } + Ok(_) => { + // or panic instead? + return Err(Error::BadResponse( + "Expected at least one column in the header".to_string(), + )); + } + Err(TypesError::NotEnoughData(_)) => {} + Err(err) => { + return Err(Error::ColumnsHeaderParserError(err.into())); + } + } + } + match self.raw.next().await? { + Some(chunk) => self.bytes.extend(chunk), + None if self.columns.is_empty() => { + return Err(Error::BadResponse( + "Could not read columns header".to_string(), + )); + } + // if the result set is empty, there is only the columns header + None => return Ok(()), + } + } } #[inline(always)] @@ -71,29 +99,28 @@ impl RowCursor { { loop { if self.bytes.remaining() > 0 { - let mut slice = super::workaround_51132(self.bytes.slice()); if self.columns.is_empty() { - self.read_columns(slice)?; - } else { - let (result, not_enough_data) = match self.rows_to_validate { - 0 => rowbinary::deserialize_from_and_validate::(&mut slice, &[]), - u64::MAX => { - rowbinary::deserialize_from_and_validate::(&mut slice, &self.columns) - } - _ => { - // extracting to a separate method boosts performance for Each ~10% - self.deserialize_with_validation(&mut slice) + self.read_columns().await?; + } + let mut slice = super::workaround_51132(self.bytes.slice()); + let (result, not_enough_data) = match self.rows_to_validate { + 0 => rowbinary::deserialize_from_and_validate::(&mut slice, &[]), + u64::MAX => { + rowbinary::deserialize_from_and_validate::(&mut slice, &self.columns) + } + _ => { + // extracting to a separate method boosts performance for Each ~10% + self.deserialize_with_validation(&mut slice) + } + }; + if !not_enough_data { + return match result { + Ok(value) => { + self.bytes.set_remaining(slice.len()); + Ok(Some(value)) } + Err(err) => Err(err), }; - if !not_enough_data { - return match result { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - Ok(Some(value)) - } - Err(err) => Err(err), - }; - } } } diff --git a/src/error.rs b/src/error.rs index 142b7d27..b47901e0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -49,8 +49,8 @@ pub enum Error { assert_impl_all!(Error: StdError, Send, Sync); -impl From for Error { - fn from(err: clickhouse_types::error::ParserError) -> Self { +impl From for Error { + fn from(err: clickhouse_types::error::TypesError) -> Self { Self::ColumnsHeaderParserError(Box::new(err)) } } diff --git a/src/rowbinary/ser.rs b/src/rowbinary/ser.rs index 47682f0a..c644b118 100644 --- a/src/rowbinary/ser.rs +++ b/src/rowbinary/ser.rs @@ -1,4 +1,5 @@ use bytes::BufMut; +use clickhouse_types::put_leb128; use serde::{ ser::{Impossible, SerializeSeq, SerializeStruct, SerializeTuple, Serializer}, Serialize, @@ -42,27 +43,16 @@ impl Serializer for &'_ mut RowBinarySerializer { type SerializeTupleVariant = Impossible<(), Error>; impl_num!(i8, serialize_i8, put_i8); - impl_num!(i16, serialize_i16, put_i16_le); - impl_num!(i32, serialize_i32, put_i32_le); - impl_num!(i64, serialize_i64, put_i64_le); - impl_num!(i128, serialize_i128, put_i128_le); - impl_num!(u8, serialize_u8, put_u8); - impl_num!(u16, serialize_u16, put_u16_le); - impl_num!(u32, serialize_u32, put_u32_le); - impl_num!(u64, serialize_u64, put_u64_le); - impl_num!(u128, serialize_u128, put_u128_le); - impl_num!(f32, serialize_f32, put_f32_le); - impl_num!(f64, serialize_f64, put_f64_le); #[inline] @@ -78,14 +68,14 @@ impl Serializer for &'_ mut RowBinarySerializer { #[inline] fn serialize_str(self, v: &str) -> Result<()> { - put_unsigned_leb128(&mut self.buffer, v.len() as u64); + put_leb128(&mut self.buffer, v.len() as u64); self.buffer.put_slice(v.as_bytes()); Ok(()) } #[inline] fn serialize_bytes(self, v: &[u8]) -> Result<()> { - put_unsigned_leb128(&mut self.buffer, v.len() as u64); + put_leb128(&mut self.buffer, v.len() as u64); self.buffer.put_slice(v); Ok(()) } @@ -157,7 +147,7 @@ impl Serializer for &'_ mut RowBinarySerializer { #[inline] fn serialize_seq(self, len: Option) -> Result { let len = len.ok_or(Error::SequenceMustHaveLength)?; - put_unsigned_leb128(&mut self.buffer, len as u64); + put_leb128(&mut self.buffer, len as u64); Ok(self) } @@ -258,27 +248,3 @@ impl SerializeTuple for &'_ mut RowBinarySerializer { Ok(()) } } - -fn put_unsigned_leb128(mut buffer: impl BufMut, mut value: u64) { - while { - let mut byte = value as u8 & 0x7f; - value >>= 7; - - if value != 0 { - byte |= 0x80; - } - - buffer.put_u8(byte); - - value != 0 - } {} -} - -#[test] -fn it_serializes_unsigned_leb128() { - let mut vec = Vec::new(); - - put_unsigned_leb128(&mut vec, 624_485); - - assert_eq!(vec, [0xe5, 0x8e, 0x26]); -} diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index d6588edb..9bdf2dff 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -13,7 +13,6 @@ pub(crate) trait ValidateDataType: Sized { fn set_struct_name(&mut self, name: &'static str); } -#[derive(Default)] pub(crate) struct DataTypeValidator<'cursor> { struct_name: Option<&'static str>, current_column_idx: usize, @@ -88,15 +87,23 @@ impl ValidateDataType for DataTypeValidator<'_> { &'_ mut self, serde_type: SerdeType, ) -> Result>> { - if self.current_column_idx < self.columns.len() { - let current_column = &self.columns[self.current_column_idx]; - self.current_column_idx += 1; - validate_impl(self, ¤t_column.data_type, &serde_type, false) + if self.current_column_idx == 0 && self.struct_name.is_none() { + // this allows validating and deserializing tuples from fetch calls + Ok(Some(InnerDataTypeValidator { + root: self, + kind: InnerDataTypeValidatorKind::RootTuple(self.columns, 0), + })) } else { - panic!( - "Struct {} has more fields than columns in the database schema", - self.get_struct_name() - ) + if self.current_column_idx < self.columns.len() { + let current_column = &self.columns[self.current_column_idx]; + self.current_column_idx += 1; + validate_impl(self, ¤t_column.data_type, &serde_type, false) + } else { + panic!( + "Struct {} has more fields than columns in the database schema", + self.get_struct_name() + ) + } } } @@ -148,6 +155,8 @@ pub(crate) enum InnerDataTypeValidatorKind<'cursor> { MapValidatorState, ), Tuple(&'cursor [DataTypeNode]), + /// This is a hack to support deserializing tuples (and not structs) from fetch calls + RootTuple(&'cursor [Column], usize), Enum(&'cursor HashMap), // Variant(&'cursor [DataTypeNode]), Nullable(&'cursor DataTypeNode), @@ -159,8 +168,6 @@ impl<'de, 'cursor> ValidateDataType for Option Result>> { - // println!("Validating inner data type: {:?} against serde type: {} with compatible db types: {:?}", - // self, serde_type, compatible_db_types); match self { None => Ok(None), Some(inner) => match &mut inner.kind { @@ -210,6 +217,25 @@ impl<'de, 'cursor> ValidateDataType for Option { Ok(None) // actually unreachable } + InnerDataTypeValidatorKind::RootTuple(columns, current_index) => { + if *current_index < columns.len() - 1 { + *current_index += 1; + validate_impl( + inner.root, + &columns[*current_index].data_type, + &serde_type, + true, + ) + } else { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing root tuple element {} defined as {}: \ + attempting to deserialize {} while no more elements are allowed", + full_name, full_data_type, serde_type + ) + } + } // InnerDataTypeValidatorKind::Variant(_possible_types) => { // Ok(None) // FIXME: requires comparing DataTypeNode vs TypeHint or SerdeType // } @@ -242,11 +268,8 @@ impl<'de, 'cursor> ValidateDataType for Option { @@ -278,7 +301,7 @@ fn validate_impl<'de, 'cursor>( is_inner: bool, ) -> Result>> { println!( - "Validating data type: {} against serde type: {} with compatible db types", + "Validating data type: {} against serde type: {}", data_type, serde_type, ); // TODO: eliminate multiple branches with similar patterns? diff --git a/src/test/handlers.rs b/src/test/handlers.rs index 8da4b0ea..3972394c 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use bytes::Bytes; +use clickhouse_types::{put_rbwnat_columns_header, Column}; use futures::channel::oneshot; use hyper::{Request, Response, StatusCode}; use sealed::sealed; @@ -40,11 +41,12 @@ pub fn failure(status: StatusCode) -> impl Handler { // === provide === #[track_caller] -pub fn provide(rows: impl IntoIterator) -> impl Handler +pub fn provide<'a, T>(schema: &[Column], rows: impl IntoIterator) -> impl Handler where T: Serialize, { let mut buffer = Vec::with_capacity(BUFFER_INITIAL_CAPACITY); + put_rbwnat_columns_header(schema, &mut buffer).expect("failed to write columns header"); for row in rows { rowbinary::serialize_into(&mut buffer, &row).expect("failed to serialize"); } diff --git a/src/test/mock.rs b/src/test/mock.rs index 41636d45..18739e24 100644 --- a/src/test/mock.rs +++ b/src/test/mock.rs @@ -52,9 +52,9 @@ impl Mock { Self { url: format!("http://{addr}"), - shared, non_exhaustive: false, server_handle: server_handle.abort_handle(), + shared, } } diff --git a/tests/it/cursor_error.rs b/tests/it/cursor_error.rs index e4894dc4..afad60a6 100644 --- a/tests/it/cursor_error.rs +++ b/tests/it/cursor_error.rs @@ -1,20 +1,24 @@ -use serde::Deserialize; - -use clickhouse::{error::Error, Client, Compression, Row}; - -#[tokio::test] -async fn deferred() { - let client = prepare_database!(); - max_execution_time(client, false).await; -} +use clickhouse::{Client, Compression}; #[tokio::test] async fn wait_end_of_query() { let client = prepare_database!(); - max_execution_time(client, true).await; + let scenarios = vec![ + // wait_end_of_query=?, expected_rows + (false, 3), // server returns some rows before throwing an error + (true, 0), // server throws an error immediately + ]; + for (wait_end_of_query, expected_rows) in scenarios { + let result = max_execution_time(client.clone(), wait_end_of_query).await; + assert_eq!( + result, expected_rows, + "wait_end_of_query: {}, expected_rows: {}", + wait_end_of_query, expected_rows + ); + } } -async fn max_execution_time(mut client: Client, wait_end_of_query: bool) { +async fn max_execution_time(mut client: Client, wait_end_of_query: bool) -> u8 { if wait_end_of_query { client = client.with_option("wait_end_of_query", "1") } @@ -22,27 +26,24 @@ async fn max_execution_time(mut client: Client, wait_end_of_query: bool) { // TODO: check different `timeout_overflow_mode` let mut cursor = client .with_compression(Compression::None) + // fails on the 4th row .with_option("max_execution_time", "0.1") - .query("SELECT toUInt8(65 + number % 5) FROM system.numbers LIMIT 100000000") + // force streaming one row in a chunk + .with_option("max_block_size", "1") + .query("SELECT sleepEachRow(0.03) AS s FROM system.numbers LIMIT 5") .fetch::() .unwrap(); - let mut i = 0u64; - + let mut i = 0; let err = loop { match cursor.next().await { - Ok(Some(no)) => { - // Check that we haven't parsed something extra. - assert_eq!(no, (65 + i % 5) as u8); - i += 1; - } + Ok(Some(_)) => i += 1, Ok(None) => panic!("DB exception hasn't been found"), Err(err) => break err, } }; - - assert!(wait_end_of_query ^ (i != 0)); assert!(err.to_string().contains("TIMEOUT_EXCEEDED")); + i } #[cfg(feature = "lz4")] @@ -98,40 +99,3 @@ async fn deferred_lz4() { assert_ne!(i, 0); // we're interested only in errors during processing assert!(err.to_string().contains("TIMEOUT_EXCEEDED")); } - -// See #185. -#[tokio::test] -async fn invalid_schema() { - #[derive(Debug, Row, Deserialize)] - #[allow(dead_code)] - struct MyRow { - no: u32, - dec: Option, // valid schema: u64-based types - } - - let client = prepare_database!(); - - client - .query( - "CREATE TABLE test(no UInt32, dec Nullable(Decimal64(4))) - ENGINE = MergeTree - ORDER BY no", - ) - .execute() - .await - .unwrap(); - - client - .query("INSERT INTO test VALUES (1, 1.1), (2, 2.2), (3, 3.3)") - .execute() - .await - .unwrap(); - - let err = client - .query("SELECT ?fields FROM test") - .fetch_all::() - .await - .unwrap_err(); - - assert!(matches!(err, Error::NotEnoughData)); -} diff --git a/tests/it/cursor_stats.rs b/tests/it/cursor_stats.rs index 7ae43bdf..503885ad 100644 --- a/tests/it/cursor_stats.rs +++ b/tests/it/cursor_stats.rs @@ -28,7 +28,7 @@ async fn check(client: Client, expected_ratio: f64) { decoded = cursor.decoded_bytes(); } - assert_eq!(decoded, 15000); + assert_eq!(decoded, 15000 + 23); // 23 extra bytes for the RBWNAT header. assert_eq!(cursor.received_bytes(), dbg!(received)); assert_eq!(cursor.decoded_bytes(), dbg!(decoded)); assert_eq!( diff --git a/tests/it/mock.rs b/tests/it/mock.rs index 2db04537..e7dd9f5f 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -1,16 +1,20 @@ #![cfg(feature = "test-util")] -use std::time::Duration; - -use clickhouse::{test, Client}; - use crate::SimpleRow; +use clickhouse::{test, Client}; +use clickhouse_types::data_types::Column; +use clickhouse_types::DataTypeNode; +use std::time::Duration; async fn test_provide() { let mock = test::Mock::new(); let client = Client::default().with_url(mock.url()); let expected = vec![SimpleRow::new(1, "one"), SimpleRow::new(2, "two")]; - mock.add(test::handlers::provide(&expected)); + let columns = vec![ + Column::new("id".to_string(), DataTypeNode::UInt64), + Column::new("data".to_string(), DataTypeNode::String), + ]; + mock.add(test::handlers::provide(&columns, &expected)); let actual = crate::fetch_rows::(&client, "doesn't matter").await; assert_eq!(actual, expected); diff --git a/tests/it/query.rs b/tests/it/query.rs index 195297ed..7b783e92 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -93,9 +93,9 @@ async fn server_side_param() { .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") .param("val1", 42) .param("val2", 144) - .fetch_one::() + .fetch_one::() .await - .expect("failed to fetch u64"); + .expect("failed to fetch Int64"); assert_eq!(result, 186); let result = client diff --git a/types/src/data_types.rs b/types/src/data_types.rs index 0cf55814..4ef059fc 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -1,4 +1,4 @@ -use crate::error::ParserError; +use crate::error::TypesError; use std::collections::HashMap; use std::fmt::{Display, Formatter}; @@ -79,7 +79,7 @@ pub enum DataTypeNode { } impl DataTypeNode { - pub fn new(name: &str) -> Result { + pub fn new(name: &str) -> Result { match name { "UInt8" => Ok(Self::UInt8), "UInt16" => Ok(Self::UInt16), @@ -127,7 +127,7 @@ impl DataTypeNode { str if str.starts_with("Variant") => parse_variant(str), // ... - str => Err(ParserError::TypeParsingError(format!( + str => Err(TypesError::TypeParsingError(format!( "Unknown data type: {}", str ))), @@ -250,7 +250,7 @@ pub enum DateTimePrecision { } impl DateTimePrecision { - pub(crate) fn new(char: char) -> Result { + pub(crate) fn new(char: char) -> Result { match char { '0' => Ok(DateTimePrecision::Precision0), '1' => Ok(DateTimePrecision::Precision1), @@ -262,7 +262,7 @@ impl DateTimePrecision { '7' => Ok(DateTimePrecision::Precision7), '8' => Ok(DateTimePrecision::Precision8), '9' => Ok(DateTimePrecision::Precision9), - _ => Err(ParserError::TypeParsingError(format!( + _ => Err(TypesError::TypeParsingError(format!( "Invalid DateTime64 precision, expected to be within [0, 9] interval, got {}", char ))), @@ -290,7 +290,7 @@ impl Display for DecimalType { } impl DecimalType { - pub(crate) fn new(precision: u8) -> Result { + pub(crate) fn new(precision: u8) -> Result { if precision <= 9 { Ok(DecimalType::Decimal32) } else if precision <= 18 { @@ -300,7 +300,7 @@ impl DecimalType { } else if precision <= 76 { Ok(DecimalType::Decimal256) } else { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Decimal precision: {}", precision ))); @@ -333,49 +333,49 @@ fn data_types_to_string(elements: &[DataTypeNode]) -> String { .join(", ") } -fn parse_fixed_string(input: &str) -> Result { +fn parse_fixed_string(input: &str) -> Result { if input.len() >= 14 { let size_str = &input[12..input.len() - 1]; let size = size_str.parse::().map_err(|err| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", err, input, size_str )) })?; if size == 0 { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid FixedString size, expected a positive number, got zero. Input: {}", input ))); } return Ok(DataTypeNode::FixedString(size)); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid FixedString format, expected FixedString(N), got {}", input ))) } -fn parse_array(input: &str) -> Result { +fn parse_array(input: &str) -> Result { if input.len() >= 8 { let inner_type_str = &input[6..input.len() - 1]; let inner_type = DataTypeNode::new(inner_type_str)?; return Ok(DataTypeNode::Array(Box::new(inner_type))); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Array format, expected Array(InnerType), got {}", input ))) } -fn parse_enum(input: &str) -> Result { +fn parse_enum(input: &str) -> Result { if input.len() >= 9 { let (enum_type, prefix_len) = if input.starts_with("Enum8") { (EnumType::Enum8, 6) } else if input.starts_with("Enum16") { (EnumType::Enum16, 7) } else { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Enum type, expected Enum8 or Enum16, got {}", input ))); @@ -384,13 +384,13 @@ fn parse_enum(input: &str) -> Result { let enum_values_map = parse_enum_values_map(enum_values_map_str)?; return Ok(DataTypeNode::Enum(enum_type, enum_values_map)); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Enum format, expected Enum8('name' = value), got {}", input ))) } -fn parse_datetime(input: &str) -> Result { +fn parse_datetime(input: &str) -> Result { if input == "DateTime" { return Ok(DataTypeNode::DateTime(None)); } @@ -398,17 +398,17 @@ fn parse_datetime(input: &str) -> Result { let timezone = (&input[10..input.len() - 2]).to_string(); return Ok(DataTypeNode::DateTime(Some(timezone))); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid DateTime format, expected DateTime('timezone'), got {}", input ))) } -fn parse_decimal(input: &str) -> Result { +fn parse_decimal(input: &str) -> Result { if input.len() >= 10 { let precision_and_scale_str = (&input[8..input.len() - 1]).split(", ").collect::>(); if precision_and_scale_str.len() != 2 { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S), got {}", input ))); @@ -418,7 +418,7 @@ fn parse_decimal(input: &str) -> Result { .map(|s| s.parse::()) .collect::, _>>() .map_err(|err| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S), got {}. Underlying error: {}", input, err )) @@ -426,13 +426,13 @@ fn parse_decimal(input: &str) -> Result { let precision = parsed[0]; let scale = parsed[1]; if scale < 1 || precision < 1 { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S) with P > 0 and S > 0, got {}", input ))); } if precision < scale { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S) with P >= S, got {}", input ))); @@ -440,16 +440,16 @@ fn parse_decimal(input: &str) -> Result { let size = DecimalType::new(parsed[0])?; return Ok(DataTypeNode::Decimal(precision, scale, size)); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P), got {}", input ))) } -fn parse_datetime64(input: &str) -> Result { +fn parse_datetime64(input: &str) -> Result { if input.len() >= 13 { let mut chars = (&input[11..input.len() - 1]).chars(); - let precision_char = chars.next().ok_or(ParserError::TypeParsingError(format!( + let precision_char = chars.next().ok_or(TypesError::TypeParsingError(format!( "Invalid DateTime64 precision, expected a positive number. Input: {}", input )))?; @@ -460,42 +460,42 @@ fn parse_datetime64(input: &str) -> Result { }; return Ok(DataTypeNode::DateTime64(precision, maybe_tz)); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid DateTime format, expected DateTime('timezone'), got {}", input ))) } -fn parse_low_cardinality(input: &str) -> Result { +fn parse_low_cardinality(input: &str) -> Result { if input.len() >= 16 { let inner_type_str = &input[15..input.len() - 1]; let inner_type = DataTypeNode::new(inner_type_str)?; return Ok(DataTypeNode::LowCardinality(Box::new(inner_type))); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid LowCardinality format, expected LowCardinality(InnerType), got {}", input ))) } -fn parse_nullable(input: &str) -> Result { +fn parse_nullable(input: &str) -> Result { if input.len() >= 10 { let inner_type_str = &input[9..input.len() - 1]; let inner_type = DataTypeNode::new(inner_type_str)?; return Ok(DataTypeNode::Nullable(Box::new(inner_type))); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Nullable format, expected Nullable(InnerType), got {}", input ))) } -fn parse_map(input: &str) -> Result { +fn parse_map(input: &str) -> Result { if input.len() >= 5 { let inner_types_str = &input[4..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; if inner_types.len() != 2 { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Expected two inner elements in a Map from input {}", input ))); @@ -505,37 +505,37 @@ fn parse_map(input: &str) -> Result { Box::new(inner_types[1].clone()), )); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Map format, expected Map(KeyType, ValueType), got {}", input ))) } -fn parse_tuple(input: &str) -> Result { +fn parse_tuple(input: &str) -> Result { if input.len() > 7 { let inner_types_str = &input[6..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; if inner_types.is_empty() { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Expected at least one inner element in a Tuple from input {}", input ))); } return Ok(DataTypeNode::Tuple(inner_types)); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Tuple format, expected Tuple(Type1, Type2, ...), got {}", input ))) } -fn parse_variant(input: &str) -> Result { +fn parse_variant(input: &str) -> Result { if input.len() >= 9 { let inner_types_str = &input[8..input.len() - 1]; let inner_types = parse_inner_types(inner_types_str)?; return Ok(DataTypeNode::Variant(inner_types)); } - Err(ParserError::TypeParsingError(format!( + Err(TypesError::TypeParsingError(format!( "Invalid Variant format, expected Variant(Type1, Type2, ...), got {}", input ))) @@ -547,7 +547,7 @@ fn parse_variant(input: &str) -> Result { /// let input1 = "Tuple(Enum8('f\'()' = 1))`"; // the result is `f\'()` /// let input2 = "Tuple(Enum8('(' = 1))"; // the result is `(` /// ``` -fn parse_inner_types(input: &str) -> Result, ParserError> { +fn parse_inner_types(input: &str) -> Result, TypesError> { let mut inner_types: Vec = Vec::new(); let input_bytes = input.as_bytes(); @@ -576,7 +576,7 @@ fn parse_inner_types(input: &str) -> Result, ParserError> { let data_type_str = String::from_utf8(input_bytes[last_element_index..i].to_vec()) .map_err(|_| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the inner data type: {}", &input[last_element_index..] )) @@ -602,7 +602,7 @@ fn parse_inner_types(input: &str) -> Result, ParserError> { if open_parens == 0 && last_element_index < input_bytes.len() { let data_type_str = String::from_utf8(input_bytes[last_element_index..].to_vec()).map_err(|_| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the inner data type: {}", &input[last_element_index..] )) @@ -615,24 +615,24 @@ fn parse_inner_types(input: &str) -> Result, ParserError> { } #[inline] -fn parse_enum_index(input_bytes: &[u8], input: &str) -> Result { +fn parse_enum_index(input_bytes: &[u8], input: &str) -> Result { String::from_utf8(input_bytes.to_vec()) .map_err(|_| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the enum index: {}", &input )) })? .parse::() .map_err(|_| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid Enum index, expected a valid number. Input: {}", input )) }) } -fn parse_enum_values_map(input: &str) -> Result, ParserError> { +fn parse_enum_values_map(input: &str) -> Result, TypesError> { let mut names: Vec = Vec::new(); let mut indices: Vec = Vec::new(); let mut parsing_name = true; // false when parsing the index @@ -652,7 +652,7 @@ fn parse_enum_values_map(input: &str) -> Result, ParserErro // non-escaped closing tick - push the name let name_bytes = &input_bytes[start_index..i]; let name = String::from_utf8(name_bytes.to_vec()).map_err(|_| { - ParserError::TypeParsingError(format!( + TypesError::TypeParsingError(format!( "Invalid UTF-8 sequence in input for the enum name: {}", &input[start_index..i] )) @@ -661,7 +661,7 @@ fn parse_enum_values_map(input: &str) -> Result, ParserErro // Skip ` = ` and the first digit, as it will always have at least one if i + 4 >= input_bytes.len() { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Enum format - expected ` = ` after name, input: {}", input, ))); @@ -695,7 +695,7 @@ fn parse_enum_values_map(input: &str) -> Result, ParserErro indices.push(index); if names.len() != indices.len() { - return Err(ParserError::TypeParsingError(format!( + return Err(TypesError::TypeParsingError(format!( "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", names.join(", "), indices.iter().map(|index| index.to_string()).collect::>().join(", "), diff --git a/types/src/decoders.rs b/types/src/decoders.rs index 02de935f..b683b1bc 100644 --- a/types/src/decoders.rs +++ b/types/src/decoders.rs @@ -1,15 +1,15 @@ -use crate::error::ParserError; -use crate::leb128::decode_leb128; +use crate::error::TypesError; +use crate::leb128::read_leb128; use bytes::Buf; #[inline] -pub(crate) fn decode_string(buffer: &mut &[u8]) -> Result { - let length = decode_leb128(buffer)? as usize; +pub(crate) fn decode_string(buffer: &mut &[u8]) -> Result { + let length = read_leb128(buffer)? as usize; if length == 0 { return Ok("".to_string()); } if buffer.remaining() < length { - return Err(ParserError::NotEnoughData(format!( + return Err(TypesError::NotEnoughData(format!( "decoding string, {} bytes remaining, {} bytes required", buffer.remaining(), length, diff --git a/types/src/error.rs b/types/src/error.rs index 1ca0215d..83757b02 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -1,6 +1,6 @@ // FIXME: better errors #[derive(Debug, thiserror::Error)] -pub enum ParserError { +pub enum TypesError { #[error("Not enough data: {0}")] NotEnoughData(String), @@ -9,4 +9,7 @@ pub enum ParserError { #[error("Type parsing error: {0}")] TypeParsingError(String), + + #[error("Unexpected empty list of columns")] + EmptyColumns, } diff --git a/types/src/leb128.rs b/types/src/leb128.rs index 27be8d2c..1e650457 100644 --- a/types/src/leb128.rs +++ b/types/src/leb128.rs @@ -1,8 +1,9 @@ -use crate::error::ParserError; -use crate::error::ParserError::NotEnoughData; -use bytes::Buf; +use crate::error::TypesError; +use crate::error::TypesError::NotEnoughData; +use bytes::{Buf, BufMut}; -pub fn decode_leb128(buffer: &mut &[u8]) -> Result { +#[inline] +pub fn read_leb128(buffer: &mut &[u8]) -> Result { let mut value = 0u64; let mut shift = 0; loop { @@ -24,9 +25,25 @@ pub fn decode_leb128(buffer: &mut &[u8]) -> Result { Ok(value) } +#[inline] +pub fn put_leb128(mut buffer: impl BufMut, mut value: u64) { + while { + let mut byte = value as u8 & 0x7f; + value >>= 7; + + if value != 0 { + byte |= 0x80; + } + + buffer.put_u8(byte); + + value != 0 + } {} +} + mod tests { #[test] - fn test_decode_leb128() { + fn test_read_leb128() { let test_cases = vec![ // (input bytes, expected value) (vec![0], 0), @@ -39,48 +56,39 @@ mod tests { ]; for (input, expected) in test_cases { - let result = super::decode_leb128(&mut input.as_slice()).unwrap(); + let result = super::read_leb128(&mut input.as_slice()).unwrap(); assert_eq!(result, expected, "Failed decoding {:?}", input); } } #[test] - fn test_encode_decode_leb128() { - fn encode_leb128<'a>(value: u64) -> Vec { - let mut result = Vec::new(); - let mut val = value; - loop { - let mut byte = (val & 0x7f) as u8; - val >>= 7; - if val != 0 { - byte |= 0x80; - } - result.push(byte); - if val == 0 { - break; - } - } - result - } - - let test_values = vec![ - 0u64, - 1, - 127, - 128, - 255, - 624773, - 624485, - 300_000, - 10_000_000, - u32::MAX as u64, - (u32::MAX as u64) + 1, + fn test_put_and_read_leb128() { + let test_cases: Vec<(u64, Vec)> = vec![ + // (value, expected encoding) + (0u64, vec![0x00]), + (1, vec![0x01]), + (127, vec![0x7F]), + (128, vec![0x80, 0x01]), + (255, vec![0xFF, 0x01]), + (300_000, vec![0xE0, 0xA7, 0x12]), + (624_773, vec![0x85, 0x91, 0x26]), + (624_485, vec![0xE5, 0x8E, 0x26]), + (10_000_000, vec![0x80, 0xAD, 0xE2, 0x04]), + (u32::MAX as u64, vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F]), ]; - for value in test_values { - let encoded = encode_leb128(value); - let decoded = super::decode_leb128(&mut encoded.as_slice()).unwrap(); + for (value, expected_encoding) in test_cases { + // Test encoding + let mut encoded = Vec::new(); + super::put_leb128(&mut encoded, value); + assert_eq!( + encoded, expected_encoding, + "Incorrect encoding for {}", + value + ); + // Test round-trip + let decoded = super::read_leb128(&mut encoded.as_slice()).unwrap(); assert_eq!( decoded, value, "Failed round trip for {}: encoded as {:?}, decoded as {}", diff --git a/types/src/lib.rs b/types/src/lib.rs index 3ec51c7d..a25abe89 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -1,17 +1,24 @@ -use crate::data_types::{Column, DataTypeNode}; +pub use crate::data_types::{Column, DataTypeNode}; use crate::decoders::decode_string; -use crate::error::ParserError; -use crate::leb128::decode_leb128; +use crate::error::TypesError; +pub use crate::leb128::put_leb128; +pub use crate::leb128::read_leb128; +use bytes::BufMut; pub mod data_types; pub mod decoders; pub mod error; pub mod leb128; -pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, ParserError> { - let num_columns = decode_leb128(bytes)?; +pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, TypesError> { + if bytes.len() < 1 { + return Err(TypesError::NotEnoughData( + "decoding columns header, expected at least one byte to start".to_string(), + )); + } + let num_columns = read_leb128(bytes)?; if num_columns == 0 { - return Err(ParserError::HeaderParsingError( + return Err(TypesError::HeaderParsingError( "Expected at least one column in the header".to_string(), )); } @@ -33,3 +40,22 @@ pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, Par .collect(); Ok(columns) } + +pub fn put_rbwnat_columns_header( + columns: &[Column], + mut buffer: impl BufMut, +) -> Result<(), TypesError> { + if columns.is_empty() { + return Err(TypesError::EmptyColumns); + } + put_leb128(&mut buffer, columns.len() as u64); + for column in columns { + put_leb128(&mut buffer, column.name.len() as u64); + buffer.put_slice(column.name.as_bytes()); + } + for column in columns.into_iter() { + put_leb128(&mut buffer, column.data_type.to_string().len() as u64); + buffer.put_slice(column.data_type.to_string().as_bytes()); + } + Ok(()) +} From b26006e0c7b79886f000287a60d02d6130f624fc Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 28 May 2025 23:57:06 +0200 Subject: [PATCH 15/54] Add Variant support, improve validation, tests --- examples/mock.rs | 2 +- src/cursors/row.rs | 24 ++--- src/rowbinary/de.rs | 32 +++--- src/rowbinary/validation.rs | 175 ++++++++++++++++++++++++--------- src/test/handlers.rs | 2 +- tests/it/insert.rs | 48 +++++---- tests/it/main.rs | 4 +- tests/it/query.rs | 50 +++++----- tests/it/rbwnat.rs | 190 ++++++++++++++++++++++-------------- tests/it/variant.rs | 12 +-- types/src/data_types.rs | 7 ++ types/src/decoders.rs | 25 +++-- types/src/lib.rs | 16 ++- 13 files changed, 368 insertions(+), 219 deletions(-) diff --git a/examples/mock.rs b/examples/mock.rs index 70d64c16..f71bdc29 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -58,7 +58,7 @@ async fn main() { // How to test SELECT. mock.add(test::handlers::provide( - &vec![Column::new("no".to_string(), UInt32)], + &[Column::new("no".to_string(), UInt32)], list.clone(), )); let rows = make_select(&client).await.unwrap(); diff --git a/src/cursors/row.rs b/src/cursors/row.rs index b55c6fd2..a502e634 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -73,19 +73,6 @@ impl RowCursor { } } - #[inline(always)] - fn deserialize_with_validation<'cursor, 'data: 'cursor>( - &'cursor mut self, - slice: &mut &'data [u8], - ) -> (Result, bool) - where - T: Deserialize<'data>, - { - let result = rowbinary::deserialize_from_and_validate::(slice, &self.columns); - self.rows_to_validate -= 1; - result - } - /// Emits the next row. /// /// The result is unspecified if it's called after `Err` is returned. @@ -101,6 +88,9 @@ impl RowCursor { if self.bytes.remaining() > 0 { if self.columns.is_empty() { self.read_columns().await?; + if self.bytes.remaining() == 0 { + continue; + } } let mut slice = super::workaround_51132(self.bytes.slice()); let (result, not_enough_data) = match self.rows_to_validate { @@ -109,8 +99,12 @@ impl RowCursor { rowbinary::deserialize_from_and_validate::(&mut slice, &self.columns) } _ => { - // extracting to a separate method boosts performance for Each ~10% - self.deserialize_with_validation(&mut slice) + let result = rowbinary::deserialize_from_and_validate::( + &mut slice, + &self.columns, + ); + self.rows_to_validate -= 1; + result } }; if !not_enough_data { diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index edb5400d..af73efa5 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -31,6 +31,7 @@ pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data input: &mut &'data [u8], columns: &'cursor [Column], ) -> (Result, bool) { + // println!("deserialize_from_and_validate call"); let result = if columns.is_empty() { let mut deserializer = RowBinaryDeserializer::new(input, ()); T::deserialize(&mut deserializer) @@ -107,7 +108,7 @@ where let mut maybe_enum_validator = self.validator.validate(SerdeType::I8)?; ensure_size(&mut self.input, size_of::())?; let value = self.input.get_i8(); - maybe_enum_validator.validate_enum8(value); + maybe_enum_validator.validate_enum8_value(value); visitor.visit_i8(value) } @@ -117,7 +118,7 @@ where ensure_size(&mut self.input, size_of::())?; let value = self.input.get_i16_le(); // TODO: is there a better way to validate that the deserialized value matches the schema? - maybe_enum_validator.validate_enum16(value); + maybe_enum_validator.validate_enum16_value(value); visitor.visit_i16(value) } @@ -212,8 +213,11 @@ where fn deserialize_identifier>(self, visitor: V) -> Result { // println!("deserialize_identifier call"); - self.validator.validate(SerdeType::Identifier)?; - self.deserialize_u8(visitor) + ensure_size(&mut self.input, size_of::())?; + let value = self.input.get_u8(); + // TODO: is there a better way to validate that the deserialized value matches the schema? + self.validator.set_next_variant_value(value); + visitor.visit_u8(value) } #[inline(always)] @@ -295,10 +299,11 @@ where } } + let validator = self.validator.validate(SerdeType::Enum)?; visitor.visit_enum(RowBinaryEnumAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, - validator: self.validator.validate(SerdeType::Enum)?, + validator, }, }) } @@ -307,11 +312,13 @@ where fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { // println!("deserialize_tuple call, len {}", len); + let validator = self.validator.validate(SerdeType::Tuple(len))?; + let mut de = RowBinaryDeserializer { + input: self.input, + validator, + }; let access = RowBinarySeqAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: self.validator.validate(SerdeType::Tuple(len))?, - }, + deserializer: &mut de, len, }; visitor.visit_seq(access) @@ -350,8 +357,8 @@ where #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { // println!( - // "deserialize_map call", - // ); + // "deserialize_map call", + // ); struct RowBinaryMapAccess<'de, 'cursor, 'data, Validator> where @@ -392,10 +399,11 @@ where } let len = self.read_size()?; + let validator = self.validator.validate(SerdeType::Map(len))?; visitor.visit_map(RowBinaryMapAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, - validator: self.validator.validate(SerdeType::Map(len))?, + validator, }, entries_visited: 0, len, diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 9bdf2dff..dee8f817 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -8,8 +8,9 @@ pub(crate) trait ValidateDataType: Sized { &'_ mut self, serde_type: SerdeType, ) -> Result>>; - fn validate_enum8(&mut self, value: i8); - fn validate_enum16(&mut self, value: i16); + fn set_next_variant_value(&mut self, value: u8); + fn validate_enum8_value(&mut self, value: i8); + fn validate_enum16_value(&mut self, value: i16); fn set_struct_name(&mut self, name: &'static str); } @@ -91,19 +92,30 @@ impl ValidateDataType for DataTypeValidator<'_> { // this allows validating and deserializing tuples from fetch calls Ok(Some(InnerDataTypeValidator { root: self, - kind: InnerDataTypeValidatorKind::RootTuple(self.columns, 0), + kind: if matches!(serde_type, SerdeType::Seq(_)) && self.columns.len() == 1 { + let data_type = &self.columns[0].data_type; + match data_type { + DataTypeNode::Array(inner_type) => { + InnerDataTypeValidatorKind::RootArray(inner_type) + } + _ => panic!( + "Expected Array type when validating root level sequence, but got {}", + self.columns[0].data_type + ), + } + } else { + InnerDataTypeValidatorKind::RootTuple(self.columns, 0) + }, })) + } else if self.current_column_idx < self.columns.len() { + let current_column = &self.columns[self.current_column_idx]; + self.current_column_idx += 1; + validate_impl(self, ¤t_column.data_type, &serde_type, false) } else { - if self.current_column_idx < self.columns.len() { - let current_column = &self.columns[self.current_column_idx]; - self.current_column_idx += 1; - validate_impl(self, ¤t_column.data_type, &serde_type, false) - } else { - panic!( - "Struct {} has more fields than columns in the database schema", - self.get_struct_name() - ) - } + panic!( + "Struct {} has more fields than columns in the database schema", + self.get_struct_name() + ) } } @@ -116,13 +128,19 @@ impl ValidateDataType for DataTypeValidator<'_> { #[cold] #[inline(never)] - fn validate_enum8(&mut self, _value: i8) { + fn validate_enum8_value(&mut self, _value: i8) { unreachable!() } #[cold] #[inline(never)] - fn validate_enum16(&mut self, _value: i16) { + fn validate_enum16_value(&mut self, _value: i16) { + unreachable!() + } + + #[cold] + #[inline(never)] + fn set_next_variant_value(&mut self, _value: u8) { unreachable!() } } @@ -155,19 +173,27 @@ pub(crate) enum InnerDataTypeValidatorKind<'cursor> { MapValidatorState, ), Tuple(&'cursor [DataTypeNode]), - /// This is a hack to support deserializing tuples (and not structs) from fetch calls + /// This is a hack to support deserializing tuples/vectors (and not structs) from fetch calls RootTuple(&'cursor [Column], usize), + RootArray(&'cursor DataTypeNode), Enum(&'cursor HashMap), - // Variant(&'cursor [DataTypeNode]), + Variant(&'cursor [DataTypeNode], VariantValidationState), Nullable(&'cursor DataTypeNode), } +#[derive(Debug)] +pub(crate) enum VariantValidationState { + Pending, + Identifier(u8), +} + impl<'de, 'cursor> ValidateDataType for Option> { #[inline] fn validate( &mut self, serde_type: SerdeType, ) -> Result>> { + // println!("[validate] Validating serde type: {}", serde_type); match self { None => Ok(None), Some(inner) => match &mut inner.kind { @@ -218,14 +244,10 @@ impl<'de, 'cursor> ValidateDataType for Option { - if *current_index < columns.len() - 1 { + if *current_index < columns.len() { + let data_type = &columns[*current_index].data_type; *current_index += 1; - validate_impl( - inner.root, - &columns[*current_index].data_type, - &serde_type, - true, - ) + validate_impl(inner.root, data_type, &serde_type, true) } else { let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); @@ -236,9 +258,28 @@ impl<'de, 'cursor> ValidateDataType for Option { - // Ok(None) // FIXME: requires comparing DataTypeNode vs TypeHint or SerdeType - // } + InnerDataTypeValidatorKind::RootArray(inner_data_type) => { + validate_impl(inner.root, inner_data_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Variant(possible_types, state) => match state { + VariantValidationState::Pending => { + unreachable!() + } + VariantValidationState::Identifier(value) => { + // println!("Validating variant identifier: {}", value); + if *value as usize >= possible_types.len() { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {full_name} defined as {full_data_type}: \ + Variant identifier {value} is out of bounds, max allowed index is {}", + possible_types.len() - 1 + ); + } + let data_type = &possible_types[*value as usize]; + validate_impl(inner.root, data_type, &serde_type, true) + } + }, InnerDataTypeValidatorKind::Enum(_values_map) => { todo!() // TODO - check value correctness in the hashmap } @@ -247,22 +288,48 @@ impl<'de, 'cursor> ValidateDataType for Option { #[inline] fn validate_impl<'de, 'cursor>( root: &'de DataTypeValidator<'cursor>, - data_type: &'cursor DataTypeNode, + column_data_type: &'cursor DataTypeNode, serde_type: &SerdeType, is_inner: bool, ) -> Result>> { - println!( - "Validating data type: {} against serde type: {}", - data_type, serde_type, - ); + // println!( + // "Validating data type: {} against serde type: {}", + // column_data_type, serde_type, + // ); + let data_type = column_data_type.remove_low_cardinality(); // TODO: eliminate multiple branches with similar patterns? match serde_type { SerdeType::Bool @@ -356,8 +424,7 @@ fn validate_impl<'de, 'cursor>( { Ok(None) } - // TODO: what should be allowed type for SerdeType::Identifier? - SerdeType::Identifier | SerdeType::U8 if data_type == &DataTypeNode::UInt8 => Ok(None), + SerdeType::U8 if data_type == &DataTypeNode::UInt8 => Ok(None), SerdeType::U16 if data_type == &DataTypeNode::UInt16 || data_type == &DataTypeNode::Date => { @@ -507,10 +574,27 @@ fn validate_impl<'de, 'cursor>( } } SerdeType::Enum => { - todo!("variant data type validation") + if let DataTypeNode::Variant(possible_types) = data_type { + Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Variant( + possible_types, + VariantValidationState::Pending, + ), + })) + } else { + panic!( + "Expected Variant for {} call, but got {}", + serde_type, data_type + ) + } } - _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), + _ => root.panic_on_schema_mismatch( + data_type, + serde_type, + is_inner || matches!(column_data_type, DataTypeNode::LowCardinality { .. }), + ), } } @@ -524,10 +608,13 @@ impl ValidateDataType for () { } #[inline(always)] - fn validate_enum8(&mut self, _enum_value: i8) {} + fn validate_enum8_value(&mut self, _value: i8) {} + + #[inline(always)] + fn validate_enum16_value(&mut self, _value: i16) {} #[inline(always)] - fn validate_enum16(&mut self, _enum_value: i16) {} + fn set_next_variant_value(&mut self, _value: u8) {} #[inline(always)] fn set_struct_name(&mut self, _name: &'static str) {} @@ -555,12 +642,12 @@ pub(crate) enum SerdeType { String, Option, Enum, - Identifier, Bytes(usize), ByteBuf(usize), Tuple(usize), Seq(usize), Map(usize), + // Identifier, // Char, // Unit, // Struct, @@ -595,7 +682,7 @@ impl Display for SerdeType { SerdeType::Seq(_len) => "Vec", SerdeType::Tuple(len) => &format!("a tuple or sequence with length {len}"), SerdeType::Map(_len) => "map", - SerdeType::Identifier => "identifier", + // SerdeType::Identifier => "identifier", // SerdeType::Char => "char", // SerdeType::Unit => "()", // SerdeType::Struct => "struct", diff --git a/src/test/handlers.rs b/src/test/handlers.rs index 3972394c..f6b61f9f 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -41,7 +41,7 @@ pub fn failure(status: StatusCode) -> impl Handler { // === provide === #[track_caller] -pub fn provide<'a, T>(schema: &[Column], rows: impl IntoIterator) -> impl Handler +pub fn provide(schema: &[Column], rows: impl IntoIterator) -> impl Handler where T: Serialize, { diff --git a/tests/it/insert.rs b/tests/it/insert.rs index 5e7a77e1..47058696 100644 --- a/tests/it/insert.rs +++ b/tests/it/insert.rs @@ -1,26 +1,7 @@ use crate::{create_simple_table, fetch_rows, flush_query_log, SimpleRow}; -use clickhouse::{sql::Identifier, Client, Row}; +use clickhouse::{sql::Identifier, Row}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Row, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "camelCase")] -struct RenameRow { - #[serde(rename = "fix_id")] - pub(crate) fix_id: i64, - #[serde(rename = "extComplexId")] - pub(crate) complex_id: String, - pub(crate) ext_float: f64, -} - -async fn create_rename_table(client: &Client, table_name: &str) { - client - .query("CREATE TABLE ?(fixId UInt64, extComplexId String, extFloat Float64) ENGINE = MergeTree ORDER BY fixId") - .bind(Identifier(table_name)) - .execute() - .await - .unwrap(); -} - #[tokio::test] async fn keeps_client_options() { let table_name = "insert_keeps_client_options"; @@ -144,11 +125,36 @@ async fn empty_insert() { #[tokio::test] async fn rename_insert() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + #[serde(rename_all = "camelCase")] + struct RenameRow { + #[serde(rename = "fix_id")] + pub(crate) fix_id: u64, + #[serde(rename = "extComplexId")] + pub(crate) complex_id: String, + pub(crate) ext_float: f64, + } + let table_name = "insert_rename"; let query_id = uuid::Uuid::new_v4().to_string(); let client = prepare_database!(); - create_rename_table(&client, table_name).await; + client + .query( + " + CREATE TABLE ?( + fixId UInt64, + extComplexId String, + extFloat Float64 + ) + ENGINE = MergeTree + ORDER BY fixId + ", + ) + .bind(Identifier(table_name)) + .execute() + .await + .unwrap(); let row = RenameRow { fix_id: 42, diff --git a/tests/it/main.rs b/tests/it/main.rs index 37004ea0..4148f15c 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -27,7 +27,7 @@ macro_rules! assert_panic_on_fetch_with_client { ($client:ident, $msg_parts:expr, $query:expr) => { use futures::FutureExt; let async_panic = - std::panic::AssertUnwindSafe(async { $client.query($query).fetch_one::().await }); + std::panic::AssertUnwindSafe(async { $client.query($query).fetch_all::().await }); let result = async_panic.catch_unwind().await; assert!(result.is_err()); let panic_msg = *result.unwrap_err().downcast::().unwrap(); @@ -45,7 +45,7 @@ macro_rules! assert_panic_on_fetch { use futures::FutureExt; let client = get_client().with_validation_mode(ValidationMode::Each); let async_panic = - std::panic::AssertUnwindSafe(async { client.query($query).fetch_one::().await }); + std::panic::AssertUnwindSafe(async { client.query($query).fetch_all::().await }); let result = async_panic.catch_unwind().await; assert!(result.is_err()); let panic_msg = *result.unwrap_err().downcast::().unwrap(); diff --git a/tests/it/query.rs b/tests/it/query.rs index 7b783e92..398b8654 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -88,31 +88,31 @@ async fn fetch_one_and_optional() { #[tokio::test] async fn server_side_param() { let client = prepare_database!(); - - let result = client - .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") - .param("val1", 42) - .param("val2", 144) - .fetch_one::() - .await - .expect("failed to fetch Int64"); - assert_eq!(result, 186); - - let result = client - .query("SELECT {val1: String} AS result") - .param("val1", "string") - .fetch_one::() - .await - .expect("failed to fetch string"); - assert_eq!(result, "string"); - - let result = client - .query("SELECT {val1: String} AS result") - .param("val1", "\x01\x02\x03\\ \"\'") - .fetch_one::() - .await - .expect("failed to fetch string"); - assert_eq!(result, "\x01\x02\x03\\ \"\'"); + // + // let result = client + // .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") + // .param("val1", 42) + // .param("val2", 144) + // .fetch_one::() + // .await + // .expect("failed to fetch Int64"); + // assert_eq!(result, 186); + // + // let result = client + // .query("SELECT {val1: String} AS result") + // .param("val1", "string") + // .fetch_one::() + // .await + // .expect("failed to fetch string"); + // assert_eq!(result, "string"); + // + // let result = client + // .query("SELECT {val1: String} AS result") + // .param("val1", "\x01\x02\x03\\ \"\'") + // .fetch_one::() + // .await + // .expect("failed to fetch string"); + // assert_eq!(result, "\x01\x02\x03\\ \"\'"); let result = client .query("SELECT {val1: Array(String)} AS result") diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 04c15e4a..35f29cab 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -394,6 +394,41 @@ async fn test_enum() { #[tokio::test] async fn test_nullable() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: Option, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT * FROM ( + SELECT 1 :: UInt32 AS a, 2 :: Nullable(Int64) AS b + UNION ALL + SELECT 3 :: UInt32 AS a, NULL :: Nullable(Int64) AS b + UNION ALL + SELECT 4 :: UInt32 AS a, 5 :: Nullable(Int64) AS b + ) + ORDER BY a ASC + ", + ) + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { a: 1, b: Some(2) }, + Data { a: 3, b: None }, + Data { a: 4, b: Some(5) }, + ] + ); +} + +#[tokio::test] +async fn test_invalid_nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { n: Option, @@ -404,9 +439,77 @@ async fn test_nullable() { ); } +#[tokio::test] +async fn test_low_cardinality() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: Option, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT * FROM ( + SELECT 1 :: LowCardinality(UInt32) AS a, 2 :: LowCardinality(Nullable(Int64)) AS b + UNION ALL + SELECT 3 :: LowCardinality(UInt32) AS a, NULL :: LowCardinality(Nullable(Int64)) AS b + UNION ALL + SELECT 4 :: LowCardinality(UInt32) AS a, 5 :: LowCardinality(Nullable(Int64)) AS b + ) + ORDER BY a ASC + ", + ) + .with_option("allow_suspicious_low_cardinality_types", "1") + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { a: 1, b: Some(2) }, + Data { a: 3, b: None }, + Data { a: 4, b: Some(5) }, + ] + ); +} + +#[tokio::test] +async fn test_invalid_low_cardinality() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + } + let client = get_client() + .with_validation_mode(ValidationMode::Each) + .with_option("allow_suspicious_low_cardinality_types", "1"); + assert_panic_on_fetch_with_client!( + client, + &["Data.a", "LowCardinality(Int32)", "u32"], + "SELECT 144 :: LowCardinality(Int32) AS a" + ); +} + +#[tokio::test] +async fn test_invalid_nullable_low_cardinality() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: Option, + } + let client = get_client() + .with_validation_mode(ValidationMode::Each) + .with_option("allow_suspicious_low_cardinality_types", "1"); + assert_panic_on_fetch_with_client!( + client, + &["Data.a", "LowCardinality(Nullable(Int32))", "u32"], + "SELECT 144 :: LowCardinality(Nullable(Int32)) AS a" + ); +} + #[tokio::test] #[cfg(feature = "time")] -async fn test_serde_with() { +async fn test_invalid_serde_with() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::datetime64::millis")] @@ -726,52 +829,6 @@ async fn test_deeply_nested_validation_incorrect_fixed_string() { ); } -#[tokio::test] -#[ignore] -async fn test_variant() { - #[derive(Debug, Deserialize, PartialEq)] - enum MyVariant { - Str(String), - U16(u16), - } - - #[derive(Debug, Row, Deserialize, PartialEq)] - struct Data { - id: u8, - var: MyVariant, - } - - let client = get_client() - .with_validation_mode(ValidationMode::Each) - .with_option("allow_experimental_variant_type", "1"); - let result = client - .query( - " - SELECT * FROM ( - SELECT 0 :: UInt8 AS id, 'foo' :: Variant(String, UInt16) AS var - UNION ALL - SELECT 1 :: UInt8 AS id, 144 :: Variant(String, UInt16) AS var - ) ORDER BY id ASC - ", - ) - .fetch_all::() - .await; - - assert_eq!( - result.unwrap(), - vec![ - Data { - id: 0, - var: MyVariant::Str("foo".to_string()) - }, - Data { - id: 1, - var: MyVariant::U16(144) - }, - ] - ); -} - #[tokio::test] async fn test_geo() { #[derive(Clone, Debug, PartialEq)] @@ -875,7 +932,7 @@ async fn test_issue_109_1() { .query(stmt) .execute() .await - .expect(&format!("Failed to execute query: {}", stmt)); + .unwrap_or_else(|e| panic!("Failed to execute query {stmt}, cause: {}", e)); } let data = client .query("SELECT journey, drone_id, call_sign FROM issue_109") @@ -929,7 +986,7 @@ async fn test_issue_113() { .query(stmt) .execute() .await - .expect(&format!("Failed to execute query: {}", stmt)); + .unwrap_or_else(|e| panic!("Failed to execute query {stmt}, cause: {}", e)); } // Struct should have had Option instead of f64 @@ -983,7 +1040,6 @@ async fn test_issue_185() { } #[tokio::test] -#[ignore] // this is currently disabled, see validation todo async fn test_variant_wrong_definition() { #[derive(Debug, Deserialize, PartialEq)] enum MyVariant { @@ -1000,31 +1056,17 @@ async fn test_variant_wrong_definition() { let client = get_client() .with_validation_mode(ValidationMode::Each) .with_option("allow_experimental_variant_type", "1"); - let result = client - .query( - " - SELECT * FROM ( - SELECT 0 :: UInt8 AS id, 'foo' :: Variant(String, UInt16) AS var - UNION ALL - SELECT 1 :: UInt8 AS id, 144 :: Variant(String, UInt16) AS var - ) ORDER BY id ASC - ", - ) - .fetch_all::() - .await; - assert_eq!( - result.unwrap(), - vec![ - Data { - id: 0, - var: MyVariant::Str("foo".to_string()) - }, - Data { - id: 1, - var: MyVariant::U32(144) - }, - ] + assert_panic_on_fetch_with_client!( + client, + &["Data.var", "Variant(String, UInt16)", "u32"], + " + SELECT * FROM ( + SELECT 0 :: UInt8 AS id, 'foo' :: Variant(String, UInt16) AS var + UNION ALL + SELECT 1 :: UInt8 AS id, 144 :: Variant(String, UInt16) AS var + ) ORDER BY id ASC + " ); } diff --git a/tests/it/variant.rs b/tests/it/variant.rs index 14e81901..d5f9dae2 100644 --- a/tests/it/variant.rs +++ b/tests/it/variant.rs @@ -3,13 +3,13 @@ use serde::{Deserialize, Serialize}; use time::Month::January; +use clickhouse::validation_mode::ValidationMode::Each; use clickhouse::Row; - // See also: https://clickhouse.com/docs/en/sql-reference/data-types/variant #[tokio::test] async fn variant_data_type() { - let client = prepare_database!(); + let client = prepare_database!().with_validation_mode(Each); // NB: Inner Variant types are _always_ sorted alphabetically, // and should be defined in _exactly_ the same order in the enum. @@ -30,10 +30,10 @@ async fn variant_data_type() { Int8(i8), String(String), UInt128(u128), - UInt16(i16), + UInt16(u16), UInt32(u32), UInt64(u64), - UInt8(i8), + UInt8(u8), } #[derive(Debug, PartialEq, Row, Serialize, Deserialize)] @@ -42,14 +42,14 @@ async fn variant_data_type() { } // No matter the order of the definition on the Variant types, it will always be sorted as follows: - // Variant(Array(UInt16), Bool, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8) + // Variant(Array(Int16), Bool, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8) client .query( " CREATE OR REPLACE TABLE test_var ( `var` Variant( - Array(UInt16), + Array(Int16), Bool, Date, FixedString(6), diff --git a/types/src/data_types.rs b/types/src/data_types.rs index 4ef059fc..6f5efb75 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -133,6 +133,13 @@ impl DataTypeNode { ))), } } + + pub fn remove_low_cardinality(&self) -> &DataTypeNode { + match self { + DataTypeNode::LowCardinality(inner) => inner, + _ => self, + } + } } impl Into for DataTypeNode { diff --git a/types/src/decoders.rs b/types/src/decoders.rs index b683b1bc..4e9c0865 100644 --- a/types/src/decoders.rs +++ b/types/src/decoders.rs @@ -3,18 +3,27 @@ use crate::leb128::read_leb128; use bytes::Buf; #[inline] -pub(crate) fn decode_string(buffer: &mut &[u8]) -> Result { +pub(crate) fn read_string(buffer: &mut &[u8]) -> Result { + ensure_size(buffer, 1)?; let length = read_leb128(buffer)? as usize; if length == 0 { return Ok("".to_string()); } - if buffer.remaining() < length { - return Err(TypesError::NotEnoughData(format!( - "decoding string, {} bytes remaining, {} bytes required", - buffer.remaining(), - length, - ))); - } + ensure_size(buffer, length)?; let result = String::from_utf8_lossy(&buffer.copy_to_bytes(length)).to_string(); Ok(result) } + +#[inline] +pub(crate) fn ensure_size(buffer: &[u8], size: usize) -> Result<(), TypesError> { + // println!("[ensure_size] buffer remaining: {}, required size: {}", buffer.len(), size); + if buffer.remaining() < size { + Err(TypesError::NotEnoughData(format!( + "expected at least {} bytes, but only {} bytes remaining", + size, + buffer.remaining() + ))) + } else { + Ok(()) + } +} diff --git a/types/src/lib.rs b/types/src/lib.rs index a25abe89..bed7ccea 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -1,5 +1,5 @@ pub use crate::data_types::{Column, DataTypeNode}; -use crate::decoders::decode_string; +use crate::decoders::{ensure_size, read_string}; use crate::error::TypesError; pub use crate::leb128::put_leb128; pub use crate::leb128::read_leb128; @@ -10,13 +10,9 @@ pub mod decoders; pub mod error; pub mod leb128; -pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, TypesError> { - if bytes.len() < 1 { - return Err(TypesError::NotEnoughData( - "decoding columns header, expected at least one byte to start".to_string(), - )); - } - let num_columns = read_leb128(bytes)?; +pub fn parse_rbwnat_columns_header(buffer: &mut &[u8]) -> Result, TypesError> { + ensure_size(buffer, 1)?; + let num_columns = read_leb128(buffer)?; if num_columns == 0 { return Err(TypesError::HeaderParsingError( "Expected at least one column in the header".to_string(), @@ -24,12 +20,12 @@ pub fn parse_rbwnat_columns_header(bytes: &mut &[u8]) -> Result, Typ } let mut columns_names: Vec = Vec::with_capacity(num_columns as usize); for _ in 0..num_columns { - let column_name = decode_string(bytes)?; + let column_name = read_string(buffer)?; columns_names.push(column_name); } let mut column_data_types: Vec = Vec::with_capacity(num_columns as usize); for _ in 0..num_columns { - let column_type = decode_string(bytes)?; + let column_type = read_string(buffer)?; let data_type = DataTypeNode::new(&column_type)?; column_data_types.push(data_type); } From 856720016a7acec9319bf0e67ca1ccf1dc891cc0 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 29 May 2025 00:06:06 +0200 Subject: [PATCH 16/54] Fix compile issues, clippy, etc --- src/cursors/row.rs | 13 ++++--------- src/rowbinary/de.rs | 22 +++++++++------------- src/rowbinary/mod.rs | 1 - src/test/handlers.rs | 3 ++- 4 files changed, 15 insertions(+), 24 deletions(-) diff --git a/src/cursors/row.rs b/src/cursors/row.rs index a502e634..24bf1153 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -49,7 +49,7 @@ impl RowCursor { return Ok(()); } Ok(_) => { - // or panic instead? + // TODO: or panic instead? return Err(Error::BadResponse( "Expected at least one column in the header".to_string(), )); @@ -94,15 +94,10 @@ impl RowCursor { } let mut slice = super::workaround_51132(self.bytes.slice()); let (result, not_enough_data) = match self.rows_to_validate { - 0 => rowbinary::deserialize_from_and_validate::(&mut slice, &[]), - u64::MAX => { - rowbinary::deserialize_from_and_validate::(&mut slice, &self.columns) - } + 0 => rowbinary::deserialize_from::(&mut slice, &[]), + u64::MAX => rowbinary::deserialize_from::(&mut slice, &self.columns), _ => { - let result = rowbinary::deserialize_from_and_validate::( - &mut slice, - &self.columns, - ); + let result = rowbinary::deserialize_from::(&mut slice, &self.columns); self.rows_to_validate -= 1; result } diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index af73efa5..f4063a7a 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -12,26 +12,22 @@ use serde::{ }; use std::{convert::TryFrom, str}; -/// Deserializes a value from `input` with a row encoded in `RowBinary`. +/// Deserializes a value from `input` with a row encoded in `RowBinary(WithNamesAndTypes)`. /// /// It accepts _a reference to_ a byte slice because it somehow leads to a more /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even /// `(&[u8], &mut Option) -> Result`. -pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data [u8]) -> Result { - // println!("deserialize_from call"); - - let mut deserializer = RowBinaryDeserializer::new(input, ()); - T::deserialize(&mut deserializer) -} - -/// Similar to [`deserialize_from`], but expects a slice of [`Column`] objects -/// parsed from the beginning of `RowBinaryWithNamesAndTypes` data stream. +/// +/// Additionally, having a single function speeds up [`crate::cursors::RowCursor::next`] x2. +/// A hint about the [`Error::NotEnoughData`] gives another 20% performance boost. +/// +/// It expects a slice of [`Column`] objects parsed +/// from the beginning of `RowBinaryWithNamesAndTypes` data stream. /// After the header, the rows format is the same as `RowBinary`. -pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data>>( +pub(crate) fn deserialize_from<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], columns: &'cursor [Column], ) -> (Result, bool) { - // println!("deserialize_from_and_validate call"); let result = if columns.is_empty() { let mut deserializer = RowBinaryDeserializer::new(input, ()); T::deserialize(&mut deserializer) @@ -48,7 +44,7 @@ pub(crate) fn deserialize_from_and_validate<'data, 'cursor, T: Deserialize<'data } } -/// A deserializer for the RowBinary(WithNamesAndTypes) format. +/// A deserializer for the `RowBinary(WithNamesAndTypes)` format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. struct RowBinaryDeserializer<'cursor, 'data, Validator = ()> diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index 6b864023..7a1dfbb1 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,5 +1,4 @@ pub(crate) use de::deserialize_from; -pub(crate) use de::deserialize_from_and_validate; pub(crate) use ser::serialize_into; mod de; diff --git a/src/test/handlers.rs b/src/test/handlers.rs index f6b61f9f..10854c49 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -95,7 +95,8 @@ where let mut result = C::default(); while !slice.is_empty() { - let row: T = rowbinary::deserialize_from(slice).expect("failed to deserialize"); + let (de_result, _) = rowbinary::deserialize_from(slice, &[]); + let row: T = de_result.expect("failed to deserialize"); result.extend(std::iter::once(row)); } From a1181a032837d2586f1de0867de91bfb961b09cf Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 29 May 2025 00:17:14 +0200 Subject: [PATCH 17/54] Fix older Rust versions compile issues, docs --- src/lib.rs | 4 +-- src/rowbinary/validation.rs | 49 ++++++++++++++++++------------------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4641ce8f..55c2221f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -298,8 +298,8 @@ impl Client { } /// Specifies the struct validation mode that will be used when calling - /// [`Query::fetch`], [`Query::fetch_one`], [`Query::fetch_all`], - /// and [`Query::fetch_optional`] methods. + /// [`query::Query::fetch`], [`query::Query::fetch_one`], [`query::Query::fetch_all`], + /// and [`query::Query::fetch_optional`] methods. /// See [`ValidationMode`] for more details. pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self { self.validation_mode = mode; diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index dee8f817..af076984 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -659,29 +659,29 @@ pub(crate) enum SerdeType { impl Display for SerdeType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let type_name = match self { - SerdeType::Bool => "bool", - SerdeType::I8 => "i8", - SerdeType::I16 => "i16", - SerdeType::I32 => "i32", - SerdeType::I64 => "i64", - SerdeType::I128 => "i128", - SerdeType::U8 => "u8", - SerdeType::U16 => "u16", - SerdeType::U32 => "u32", - SerdeType::U64 => "u64", - SerdeType::U128 => "u128", - SerdeType::F32 => "f32", - SerdeType::F64 => "f64", - SerdeType::Str => "&str", - SerdeType::String => "String", - SerdeType::Bytes(len) => &format!("&[u8; {len}]"), - SerdeType::ByteBuf(_len) => "Vec", - SerdeType::Option => "Option", - SerdeType::Enum => "enum", - SerdeType::Seq(_len) => "Vec", - SerdeType::Tuple(len) => &format!("a tuple or sequence with length {len}"), - SerdeType::Map(_len) => "map", + match self { + SerdeType::Bool => write!(f, "bool"), + SerdeType::I8 => write!(f, "i8"), + SerdeType::I16 => write!(f, "i16"), + SerdeType::I32 => write!(f, "i32"), + SerdeType::I64 => write!(f, "i64"), + SerdeType::I128 => write!(f, "i128"), + SerdeType::U8 => write!(f, "u8"), + SerdeType::U16 => write!(f, "u16"), + SerdeType::U32 => write!(f, "u32"), + SerdeType::U64 => write!(f, "u64"), + SerdeType::U128 => write!(f, "u128"), + SerdeType::F32 => write!(f, "f32"), + SerdeType::F64 => write!(f, "f64"), + SerdeType::Str => write!(f, "&str"), + SerdeType::String => write!(f, "String"), + SerdeType::Bytes(len) => write!(f, "&[u8; {len}]"), + SerdeType::ByteBuf(_len) => write!(f, "Vec"), + SerdeType::Option => write!(f, "Option"), + SerdeType::Enum => write!(f, "enum"), + SerdeType::Seq(_len) => write!(f, "Vec"), + SerdeType::Tuple(len) => write!(f, "a tuple or sequence with length {len}"), + SerdeType::Map(_len) => write!(f, "Map"), // SerdeType::Identifier => "identifier", // SerdeType::Char => "char", // SerdeType::Unit => "()", @@ -690,7 +690,6 @@ impl Display for SerdeType { // SerdeType::TupleStruct => "tuple struct", // SerdeType::UnitStruct => "unit struct", // SerdeType::IgnoredAny => "ignored any", - }; - write!(f, "{}", type_name) + } } } From 04c7a20bb2cf1b0dde4ef8445e3251aa235c63cf Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 29 May 2025 22:17:26 +0200 Subject: [PATCH 18/54] Add NYC benchmark --- Cargo.toml | 6 +++ benches/select_nyc_taxi_data.rs | 82 +++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 benches/select_nyc_taxi_data.rs diff --git a/Cargo.toml b/Cargo.toml index 0aeddaac..54f13caf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,11 @@ undocumented_unsafe_blocks = "warn" all-features = true rustdoc-args = ["--cfg", "docsrs"] +[[bench]] +name = "select_nyc_taxi_data" +harness = false +required-features = ["time"] + [[bench]] name = "select_numbers" harness = false @@ -132,6 +137,7 @@ replace_with = { version = "0.1.7" } [dev-dependencies] criterion = "0.5.0" +tracy-client = { version = "0.18.0", features = ["enable"]} serde = { version = "1.0.106", features = ["derive"] } tokio = { version = "1.0.1", features = ["full", "test-util"] } hyper = { version = "1.1", features = ["server"] } diff --git a/benches/select_nyc_taxi_data.rs b/benches/select_nyc_taxi_data.rs new file mode 100644 index 00000000..6c89cd72 --- /dev/null +++ b/benches/select_nyc_taxi_data.rs @@ -0,0 +1,82 @@ +#![cfg(feature = "time")] + +use clickhouse::validation_mode::ValidationMode; +use clickhouse::{Client, Compression, Row}; +use criterion::black_box; +use serde::Deserialize; +use serde_repr::Deserialize_repr; +use time::OffsetDateTime; + +#[derive(Debug, Clone, Deserialize_repr)] +#[repr(i8)] +pub enum PaymentType { + CSH = 1, + CRE = 2, + NOC = 3, + DIS = 4, + UNK = 5, +} + +#[derive(Debug, Clone, Row, Deserialize)] +#[allow(dead_code)] +pub struct TripSmall { + trip_id: u32, + #[serde(with = "clickhouse::serde::time::datetime")] + pickup_datetime: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime")] + dropoff_datetime: OffsetDateTime, + pickup_longitude: Option, + pickup_latitude: Option, + dropoff_longitude: Option, + dropoff_latitude: Option, + passenger_count: u8, + trip_distance: f32, + fare_amount: f32, + extra: f32, + tip_amount: f32, + tolls_amount: f32, + total_amount: f32, + payment_type: PaymentType, + pickup_ntaname: String, + dropoff_ntaname: String, +} + +async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) { + let start = std::time::Instant::now(); + let (sum_trip_ids, dec_mbytes, rec_mbytes) = do_bench(compression, validation_mode).await; + assert_eq!(sum_trip_ids, 3630387815532582); + let elapsed = start.elapsed(); + let throughput = dec_mbytes / elapsed.as_secs_f64(); + println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB"); +} + +async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) { + let client = Client::default() + .with_compression(compression) + .with_validation_mode(validation_mode) + .with_url("http://localhost:8123"); + + let mut cursor = client + .query("SELECT * FROM nyc_taxi.trips_small ORDER BY trip_id DESC") + .fetch::() + .unwrap(); + + let mut sum = 0; + while let Some(row) = cursor.next().await.unwrap() { + sum += row.trip_id as u64; + black_box(&row); + } + + let dec_bytes = cursor.decoded_bytes(); + let dec_mbytes = dec_bytes as f64 / 1024.0 / 1024.0; + let recv_bytes = cursor.received_bytes(); + let recv_mbytes = recv_bytes as f64 / 1024.0 / 1024.0; + (sum, dec_mbytes, recv_mbytes) +} + +#[tokio::main] +async fn main() { + println!("compress validation elapsed throughput received"); + bench("none", Compression::None, ValidationMode::First(1)).await; + bench("none", Compression::None, ValidationMode::Each).await; +} From 1f6c9e6c15be3f285497268e4b143b20896f6c10 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 29 May 2025 22:24:03 +0200 Subject: [PATCH 19/54] Add compression to the NYC benchmark --- benches/select_nyc_taxi_data.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/benches/select_nyc_taxi_data.rs b/benches/select_nyc_taxi_data.rs index 6c89cd72..d3c449a9 100644 --- a/benches/select_nyc_taxi_data.rs +++ b/benches/select_nyc_taxi_data.rs @@ -78,5 +78,7 @@ async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> async fn main() { println!("compress validation elapsed throughput received"); bench("none", Compression::None, ValidationMode::First(1)).await; + bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; bench("none", Compression::None, ValidationMode::Each).await; + bench("lz4", Compression::Lz4, ValidationMode::Each).await; } From 9bafc9ad493f981d6d79c48fbe6ea74544f00ace Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 4 Jun 2025 17:04:51 +0200 Subject: [PATCH 20/54] Add more tests --- src/rowbinary/tests.rs | 32 +++++++++-------- tests/it/query.rs | 50 +++++++++++++------------- tests/it/rbwnat.rs | 79 +++++++++++++++++++++++++++++++++++++++++ types/src/data_types.rs | 2 +- 4 files changed, 122 insertions(+), 41 deletions(-) diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index f4955333..0b4c58fc 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -114,18 +114,20 @@ fn it_serializes() { assert_eq!(actual, sample_serialized()); } -// #[test] -// fn it_deserializes() { -// let input = sample_serialized(); -// -// for i in 0..input.len() { -// let (mut left, mut right) = input.split_at(i); -// -// // It shouldn't panic. -// let _: Result, _> = super::deserialize_from(&mut left); -// let _: Result, _> = super::deserialize_from(&mut right); -// -// let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice()).unwrap(); -// assert_eq!(actual, sample()); -// } -// } +#[test] +fn it_deserializes() { + let input = sample_serialized(); + + for i in 0..input.len() { + let (mut left, mut right) = input.split_at(i); + + // It shouldn't panic. + let _: Result, _> = super::deserialize_from(&mut left, &[]).0; + let _: Result, _> = super::deserialize_from(&mut right, &[]).0; + + let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice(), &[]) + .0 + .unwrap(); + assert_eq!(actual, sample()); + } +} diff --git a/tests/it/query.rs b/tests/it/query.rs index 398b8654..7b783e92 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -88,31 +88,31 @@ async fn fetch_one_and_optional() { #[tokio::test] async fn server_side_param() { let client = prepare_database!(); - // - // let result = client - // .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") - // .param("val1", 42) - // .param("val2", 144) - // .fetch_one::() - // .await - // .expect("failed to fetch Int64"); - // assert_eq!(result, 186); - // - // let result = client - // .query("SELECT {val1: String} AS result") - // .param("val1", "string") - // .fetch_one::() - // .await - // .expect("failed to fetch string"); - // assert_eq!(result, "string"); - // - // let result = client - // .query("SELECT {val1: String} AS result") - // .param("val1", "\x01\x02\x03\\ \"\'") - // .fetch_one::() - // .await - // .expect("failed to fetch string"); - // assert_eq!(result, "\x01\x02\x03\\ \"\'"); + + let result = client + .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result") + .param("val1", 42) + .param("val2", 144) + .fetch_one::() + .await + .expect("failed to fetch Int64"); + assert_eq!(result, 186); + + let result = client + .query("SELECT {val1: String} AS result") + .param("val1", "string") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "string"); + + let result = client + .query("SELECT {val1: String} AS result") + .param("val1", "\x01\x02\x03\\ \"\'") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "\x01\x02\x03\\ \"\'"); let result = client .query("SELECT {val1: Array(String)} AS result") diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 35f29cab..826a7f76 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -4,6 +4,8 @@ use clickhouse::validation_mode::ValidationMode; use clickhouse_derive::Row; use clickhouse_types::data_types::{Column, DataTypeNode}; use clickhouse_types::parse_rbwnat_columns_header; +use fixnum::typenum::{U12, U4, U8}; +use fixnum::FixedPoint; use serde::{Deserialize, Serialize}; use serde_repr::{Deserialize_repr, Serialize_repr}; use std::collections::HashMap; @@ -1070,6 +1072,78 @@ async fn test_variant_wrong_definition() { ); } +#[tokio::test] +async fn test_decimals() { + #[derive(Row, Deserialize, Debug, PartialEq)] + struct Data { + decimal32_9_4: Decimal32, + decimal64_18_8: Decimal64, + decimal128_38_12: Decimal128, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT + 42.1234 :: Decimal32(4) AS decimal32_9_4, + 144.56789012 :: Decimal64(8) AS decimal64_18_8, + -17014118346046923173168730.37158841057 :: Decimal128(12) AS decimal128_38_12 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + decimal32_9_4: Decimal32::from_str("42.1234").unwrap(), + decimal64_18_8: Decimal64::from_str("144.56789012").unwrap(), + decimal128_38_12: Decimal128::from_str("-17014118346046923173168730.37158841057") + .unwrap(), + } + ); +} + +#[tokio::test] +async fn test_decimal32_wrong_size() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + decimal32: i16, + } + + assert_panic_on_fetch!( + &["Data.decimal32", "Decimal(9, 4)", "i16"], + "SELECT 42 :: Decimal32(4) AS decimal32" + ); +} + +#[tokio::test] +async fn test_decimal64_wrong_size() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + decimal64: i32, + } + + assert_panic_on_fetch!( + &["Data.decimal64", "Decimal(18, 8)", "i32"], + "SELECT 144 :: Decimal64(8) AS decimal64" + ); +} + +#[tokio::test] +async fn test_decimal128_wrong_size() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + decimal128: i64, + } + + assert_panic_on_fetch!( + &["Data.decimal128", "Decimal(38, 12)", "i64"], + "SELECT -17014118346046923173168730.37158841057 :: Decimal128(12) AS decimal128" + ); +} + // FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! // it is possible to use HashMap to deserialize the struct instead of Tuple visitor #[tokio::test] @@ -1103,3 +1177,8 @@ type Polygon = Vec; type MultiPolygon = Vec; type LineString = Vec; type MultiLineString = Vec; + +// See ClickHouse decimal sizes: https://clickhouse.com/docs/en/sql-reference/data-types/decimal +type Decimal32 = FixedPoint; // Decimal(9, 4) = Decimal32(4) +type Decimal64 = FixedPoint; // Decimal(18, 8) = Decimal64(8) +type Decimal128 = FixedPoint; // Decimal(38, 12) = Decimal128(12) diff --git a/types/src/data_types.rs b/types/src/data_types.rs index 6f5efb75..b0f939e5 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -551,7 +551,7 @@ fn parse_variant(input: &str) -> Result { /// Considers the element type parsed once we reach a comma outside of parens AND after an unescaped tick. /// The most complicated cases are values names in the self-defined Enum types: /// ``` -/// let input1 = "Tuple(Enum8('f\'()' = 1))`"; // the result is `f\'()` +/// let input1 = "Tuple(Enum8('f\'()' = 1))"; // the result is `f\'()` /// let input2 = "Tuple(Enum8('(' = 1))"; // the result is `(` /// ``` fn parse_inner_types(input: &str) -> Result, TypesError> { From c53ba74d456dba9998334783f32092af1f550d54 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 5 Jun 2025 00:01:30 +0200 Subject: [PATCH 21/54] Support structs with different field order via MapAccess --- benches/select_numbers.rs | 1 + src/cursors/row.rs | 25 ++- src/rowbinary/de.rs | 350 +++++++++++++++++++++++------------- src/rowbinary/mod.rs | 1 + src/rowbinary/tests.rs | 6 +- src/rowbinary/utils.rs | 2 + src/rowbinary/validation.rs | 214 +++++++++++++++++++--- src/test/handlers.rs | 2 +- tests/it/chrono.rs | 10 +- tests/it/insert.rs | 2 +- tests/it/rbwnat.rs | 59 ++++-- tests/it/time.rs | 10 +- 12 files changed, 491 insertions(+), 191 deletions(-) diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index b05bd8d3..52494526 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -5,6 +5,7 @@ use clickhouse::{Client, Compression, Row}; #[derive(Row, Deserialize)] struct Data { + #[serde(rename = "number")] no: u64, } diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 24bf1153..4be4044f 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,3 +1,4 @@ +use crate::rowbinary::StructMetadata; use crate::validation_mode::ValidationMode; use crate::{ bytes_ext::BytesExt, @@ -6,7 +7,6 @@ use crate::{ response::Response, rowbinary, }; -use clickhouse_types::data_types::Column; use clickhouse_types::error::TypesError; use clickhouse_types::parse_rbwnat_columns_header; use serde::Deserialize; @@ -17,7 +17,9 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, - columns: Vec, + /// [`None`] until the first call to [`RowCursor::next()`], + /// as [`RowCursor::new`] is not `async`, so it loads lazily. + struct_mapping: Option, rows_to_validate: u64, _marker: PhantomData, } @@ -28,7 +30,7 @@ impl RowCursor { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), - columns: Vec::new(), + struct_mapping: None, rows_to_validate: match validation_mode { ValidationMode::First(n) => n as u64, ValidationMode::Each => u64::MAX, @@ -45,7 +47,7 @@ impl RowCursor { match parse_rbwnat_columns_header(&mut slice) { Ok(columns) if !columns.is_empty() => { self.bytes.set_remaining(slice.len()); - self.columns = columns; + self.struct_mapping = Some(StructMetadata::new(columns)); return Ok(()); } Ok(_) => { @@ -62,7 +64,7 @@ impl RowCursor { } match self.raw.next().await? { Some(chunk) => self.bytes.extend(chunk), - None if self.columns.is_empty() => { + None if self.struct_mapping.is_none() => { return Err(Error::BadResponse( "Could not read columns header".to_string(), )); @@ -86,7 +88,7 @@ impl RowCursor { { loop { if self.bytes.remaining() > 0 { - if self.columns.is_empty() { + if self.struct_mapping.is_none() { self.read_columns().await?; if self.bytes.remaining() == 0 { continue; @@ -94,10 +96,15 @@ impl RowCursor { } let mut slice = super::workaround_51132(self.bytes.slice()); let (result, not_enough_data) = match self.rows_to_validate { - 0 => rowbinary::deserialize_from::(&mut slice, &[]), - u64::MAX => rowbinary::deserialize_from::(&mut slice, &self.columns), + 0 => rowbinary::deserialize_from::(&mut slice, None), + u64::MAX => { + rowbinary::deserialize_from::(&mut slice, self.struct_mapping.as_mut()) + } _ => { - let result = rowbinary::deserialize_from::(&mut slice, &self.columns); + let result = rowbinary::deserialize_from::( + &mut slice, + self.struct_mapping.as_mut(), + ); self.rows_to_validate -= 1; result } diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index f4063a7a..b78c2461 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -1,9 +1,9 @@ use crate::error::{Error, Result}; use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use crate::rowbinary::validation::SerdeType; -use crate::rowbinary::validation::{DataTypeValidator, ValidateDataType}; +use crate::rowbinary::validation::{DataTypeValidator, SchemaValidator}; +use crate::rowbinary::StructMetadata; use bytes::Buf; -use clickhouse_types::data_types::Column; use core::mem::size_of; use serde::de::MapAccess; use serde::{ @@ -26,13 +26,13 @@ use std::{convert::TryFrom, str}; /// After the header, the rows format is the same as `RowBinary`. pub(crate) fn deserialize_from<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], - columns: &'cursor [Column], + mapping: Option<&'cursor mut StructMetadata>, ) -> (Result, bool) { - let result = if columns.is_empty() { + let result = if mapping.is_none() { let mut deserializer = RowBinaryDeserializer::new(input, ()); T::deserialize(&mut deserializer) } else { - let validator = DataTypeValidator::new(columns); + let validator = DataTypeValidator::new(mapping.unwrap()); let mut deserializer = RowBinaryDeserializer::new(input, validator); T::deserialize(&mut deserializer) }; @@ -49,7 +49,7 @@ pub(crate) fn deserialize_from<'data, 'cursor, T: Deserialize<'data>>( /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. struct RowBinaryDeserializer<'cursor, 'data, Validator = ()> where - Validator: ValidateDataType, + Validator: SchemaValidator, { validator: Validator, input: &'cursor mut &'data [u8], @@ -57,7 +57,7 @@ where impl<'cursor, 'data, Validator> RowBinaryDeserializer<'cursor, 'data, Validator> where - Validator: ValidateDataType, + Validator: SchemaValidator, { fn new(input: &'cursor mut &'data [u8], validator: Validator) -> Self { Self { input, validator } @@ -95,7 +95,7 @@ macro_rules! impl_num { impl<'data, Validator> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data, Validator> where - Validator: ValidateDataType, + Validator: SchemaValidator, { type Error = Error; @@ -225,76 +225,6 @@ where ) -> Result { // println!("deserialize_enum call"); - struct RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, - } - - struct VariantDeserializer<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, - } - - impl<'data, Validator> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data, Validator> - where - Validator: ValidateDataType, - { - type Error = Error; - - fn unit_variant(self) -> Result<()> { - panic!("unit variants are unsupported"); - } - - fn newtype_variant_seed(self, seed: T) -> Result - where - T: DeserializeSeed<'data>, - { - DeserializeSeed::deserialize(seed, &mut *self.deserializer) - } - - fn tuple_variant(self, len: usize, visitor: V) -> Result - where - V: Visitor<'data>, - { - self.deserializer.deserialize_tuple(len, visitor) - } - - fn struct_variant( - self, - fields: &'static [&'static str], - visitor: V, - ) -> Result - where - V: Visitor<'data>, - { - self.deserializer.deserialize_tuple(fields.len(), visitor) - } - } - - impl<'de, 'cursor, 'data, Validator> EnumAccess<'data> - for RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - type Error = Error; - type Variant = VariantDeserializer<'de, 'cursor, 'data, Validator>; - - fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error> - where - T: DeserializeSeed<'data>, - { - let value = seed.deserialize(&mut *self.deserializer)?; - let deserializer = VariantDeserializer { - deserializer: self.deserializer, - }; - Ok((value, deserializer)) - } - } - let validator = self.validator.validate(SerdeType::Enum)?; visitor.visit_enum(RowBinaryEnumAccess { deserializer: &mut RowBinaryDeserializer { @@ -356,44 +286,6 @@ where // "deserialize_map call", // ); - struct RowBinaryMapAccess<'de, 'cursor, 'data, Validator> - where - Validator: ValidateDataType, - { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, - entries_visited: usize, - len: usize, - } - - impl<'data, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, Validator> - where - Validator: ValidateDataType, - { - type Error = Error; - - fn next_key_seed(&mut self, seed: K) -> Result> - where - K: DeserializeSeed<'data>, - { - if self.entries_visited >= self.len { - return Ok(None); - } - self.entries_visited += 1; - seed.deserialize(&mut *self.deserializer).map(Some) - } - - fn next_value_seed(&mut self, seed: V) -> Result - where - V: DeserializeSeed<'data>, - { - seed.deserialize(&mut *self.deserializer) - } - - fn size_hint(&self) -> Option { - Some(self.len) - } - } - let len = self.read_size()?; let validator = self.validator.validate(SerdeType::Map(len))?; visitor.visit_map(RowBinaryMapAccess { @@ -415,12 +307,19 @@ where ) -> Result { // println!("deserialize_struct: {} (fields: {:?})", name, fields,); - // TODO - skip validation? - self.validator.set_struct_name(name); - visitor.visit_seq(RowBinarySeqAccess { - deserializer: self, - len: fields.len(), - }) + let should_use_map_access = self.validator.ensure_struct_metadata(name, fields); + if !should_use_map_access { + visitor.visit_seq(RowBinarySeqAccess { + deserializer: self, + len: fields.len(), + }) + } else { + visitor.visit_map(RowBinaryStructAsMapAccess { + deserializer: self, + current_field_idx: 0, + fields, + }) + } } #[inline(always)] @@ -468,9 +367,12 @@ where } } +/// Used in [`Deserializer::deserialize_seq`], [`Deserializer::deserialize_tuple`], +/// and it could be used in [`Deserializer::deserialize_struct`], +/// if we detect that the field order matches the database schema. struct RowBinarySeqAccess<'de, 'cursor, 'data, Validator> where - Validator: ValidateDataType, + Validator: SchemaValidator, { deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, len: usize, @@ -478,7 +380,7 @@ where impl<'data, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, Validator> where - Validator: ValidateDataType, + Validator: SchemaValidator, { type Error = Error; @@ -499,3 +401,203 @@ where Some(self.len) } } + +/// Used in [`Deserializer::deserialize_map`]. +struct RowBinaryMapAccess<'de, 'cursor, 'data, Validator> +where + Validator: SchemaValidator, +{ + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + entries_visited: usize, + len: usize, +} + +impl<'data, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, Validator> +where + Validator: SchemaValidator, +{ + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'data>, + { + if self.entries_visited >= self.len { + return Ok(None); + } + self.entries_visited += 1; + seed.deserialize(&mut *self.deserializer).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'data>, + { + seed.deserialize(&mut *self.deserializer) + } + + fn size_hint(&self) -> Option { + Some(self.len) + } +} + +/// Used in [`Deserializer::deserialize_struct`] to support wrong field order +/// as long as the data types are exactly matching the database schema. +struct RowBinaryStructAsMapAccess<'de, 'cursor, 'data, Validator> +where + Validator: SchemaValidator, +{ + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + current_field_idx: usize, + fields: &'static [&'static str], +} + +struct StructFieldIdentifier(&'static str); + +impl<'de> Deserializer<'de> for StructFieldIdentifier { + type Error = Error; + + fn deserialize_identifier(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + visitor.visit_borrowed_str(self.0) + } + + fn deserialize_any(self, _visitor: V) -> Result + where + V: Visitor<'de>, + { + panic!("StructFieldIdentifier is supposed to use `deserialize_identifier` only"); + } + + serde::forward_to_deserialize_any! { + bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string + bytes byte_buf option unit unit_struct newtype_struct seq tuple + tuple_struct map struct enum ignored_any + } +} + +/// Without schema order "restoration", the following query: +/// +/// ```sql +/// SELECT 'foo' :: String AS a, +/// 'bar' :: String AS c +/// ``` +/// +/// Will produce a wrong result, if the struct is defined as: +/// +/// ```rs +/// struct Data { +/// c: String, +/// a: String, +/// } +/// ``` +impl<'data, Validator> MapAccess<'data> for RowBinaryStructAsMapAccess<'_, '_, 'data, Validator> +where + Validator: SchemaValidator, +{ + type Error = Error; + + fn next_key_seed(&mut self, seed: K) -> Result> + where + K: DeserializeSeed<'data>, + { + if self.current_field_idx >= self.fields.len() { + return Ok(None); + } + let schema_index = self + .deserializer + .validator + .get_schema_index(self.current_field_idx); + let field_id = StructFieldIdentifier(self.fields[schema_index]); + // println!( + // "RowBinaryStructAsMapAccess::next_key_seed: field_id: {}", + // field_id.0 + // ); + self.current_field_idx += 1; + seed.deserialize(field_id).map(Some) + } + + fn next_value_seed(&mut self, seed: V) -> Result + where + V: DeserializeSeed<'data>, + { + // println!( + // "RowBinaryStructAsMapAccess::next_value_seed: current_field_idx: {}", + // self.current_field_idx + // ); + seed.deserialize(&mut *self.deserializer) + } + + fn size_hint(&self) -> Option { + Some(self.fields.len()) + } +} + +/// Used in [`Deserializer::deserialize_enum`]. +struct RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> +where + Validator: SchemaValidator, +{ + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, +} + +struct VariantDeserializer<'de, 'cursor, 'data, Validator> +where + Validator: SchemaValidator, +{ + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, +} + +impl<'data, Validator> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data, Validator> +where + Validator: SchemaValidator, +{ + type Error = Error; + + fn unit_variant(self) -> Result<()> { + panic!("unit variants are unsupported"); + } + + fn newtype_variant_seed(self, seed: T) -> Result + where + T: DeserializeSeed<'data>, + { + DeserializeSeed::deserialize(seed, &mut *self.deserializer) + } + + fn tuple_variant(self, len: usize, visitor: V) -> Result + where + V: Visitor<'data>, + { + self.deserializer.deserialize_tuple(len, visitor) + } + + fn struct_variant(self, fields: &'static [&'static str], visitor: V) -> Result + where + V: Visitor<'data>, + { + self.deserializer.deserialize_tuple(fields.len(), visitor) + } +} + +impl<'de, 'cursor, 'data, Validator> EnumAccess<'data> + for RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> +where + Validator: SchemaValidator, +{ + type Error = Error; + type Variant = VariantDeserializer<'de, 'cursor, 'data, Validator>; + + fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error> + where + T: DeserializeSeed<'data>, + { + let value = seed.deserialize(&mut *self.deserializer)?; + let deserializer = VariantDeserializer { + deserializer: self.deserializer, + }; + Ok((value, deserializer)) + } +} diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index 7a1dfbb1..5a24975b 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,5 +1,6 @@ pub(crate) use de::deserialize_from; pub(crate) use ser::serialize_into; +pub(crate) use validation::StructMetadata; mod de; mod ser; diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index 0b4c58fc..44fd7d62 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -122,10 +122,10 @@ fn it_deserializes() { let (mut left, mut right) = input.split_at(i); // It shouldn't panic. - let _: Result, _> = super::deserialize_from(&mut left, &[]).0; - let _: Result, _> = super::deserialize_from(&mut right, &[]).0; + let _: Result, _> = super::deserialize_from(&mut left, None).0; + let _: Result, _> = super::deserialize_from(&mut right, None).0; - let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice(), &[]) + let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice(), None) .0 .unwrap(); assert_eq!(actual, sample()); diff --git a/src/rowbinary/utils.rs b/src/rowbinary/utils.rs index 3e9a3dc7..e1dc1d6e 100644 --- a/src/rowbinary/utils.rs +++ b/src/rowbinary/utils.rs @@ -1,6 +1,8 @@ use crate::error::Error; use bytes::Buf; +/// TODO: it is theoretically possible to ensure size in chunks, +/// at least for some types, given that we have the database schema. #[inline] pub(crate) fn ensure_size(buffer: impl Buf, size: usize) -> crate::error::Result<()> { if buffer.remaining() < size { diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index af076984..8da07c6a 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -3,37 +3,152 @@ use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; -pub(crate) trait ValidateDataType: Sized { +pub(crate) trait SchemaValidator: Sized { fn validate( &'_ mut self, serde_type: SerdeType, ) -> Result>>; - fn set_next_variant_value(&mut self, value: u8); fn validate_enum8_value(&mut self, value: i8); fn validate_enum16_value(&mut self, value: i16); - fn set_struct_name(&mut self, name: &'static str); + fn set_next_variant_value(&mut self, value: u8); + fn ensure_struct_metadata( + &'_ mut self, + name: &'static str, + fields: &'static [&'static str], + ) -> bool; + fn get_schema_index(&self, struct_idx: usize) -> usize; +} + +#[derive(Debug, PartialEq)] +enum StructMetadataState { + Pending, + WithSeqAccess, + WithMapAccess(Vec), +} + +/// #### StructMetadata +/// +/// Should reside outside the (de)serializer, so it is calculated only once per struct. +/// No lifetimes, so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. +/// +/// #### Lifecycle +/// +/// - the first call to [`crate::cursors::RowCursor::next`] creates an instance with `columns`. +/// - the first call to [`serde::Deserializer::deserialize_struct`] sets the `struct_name`, +/// and the field order is checked. If the order is different from the schema, the state is set to +/// [`StructMetadataState::WithMapAccess`], otherwise to [`StructMetadataState::WithSeqAccess`]. +/// - the following calls to [`crate::cursors::RowCursor::next`] and, consequently, +/// to [`serde::Deserializer::deserialize_struct`], will re-use the same prepared instance, +/// without re-checking the fields order for every struct. +pub(crate) struct StructMetadata { + /// Struct name is defined after the first call to [`serde::Deserializer::deserialize_struct`]. + /// If we are deserializing any other type, e.g., [`u64`], [`Vec`], etc., it is [`None`], + /// and it affects how the validation works, see [`DataTypeValidator::validate`]. + pub(crate) struct_name: Option<&'static str>, + /// Database schema, or columns, are parsed before the first call to (de)serializer. + pub(crate) columns: Vec, + /// This state determines whether we can just use [`crate::rowbinary::de::RowBinarySeqAccess`] + /// or a more sophisticated approach with [`crate::rowbinary::de::RowBinaryStructAsMapAccess`] + /// to support structs defined with different fields order than in the schema. + /// Deserializing a struct as a map will be approximately 40% slower than as a sequence. + state: StructMetadataState, +} + +impl StructMetadata { + pub(crate) fn new(columns: Vec) -> Self { + Self { + columns, + struct_name: None, + state: StructMetadataState::Pending, + } + } + + #[inline(always)] + pub(crate) fn check_should_use_map( + &mut self, + name: &'static str, + fields: &'static [&'static str], + ) -> bool { + match &self.state { + StructMetadataState::WithSeqAccess => false, + StructMetadataState::WithMapAccess(_) => true, + StructMetadataState::Pending => { + let mut mapping = Vec::with_capacity(fields.len()); + let mut expected_index = 0; + let mut should_use_map = false; + for col in &self.columns { + if let Some(index) = fields.iter().position(|field| col.name == *field) { + if index != expected_index { + should_use_map = true + } + expected_index += 1; + mapping.push(index); + } else { + panic!( + "While processing struct {}: database schema has a column {} \ + that was not found in the struct definition.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + name, + col, + join_panic_schema_hint(fields), + join_panic_schema_hint(&self.columns), + ); + } + } + self.state = if should_use_map { + StructMetadataState::WithMapAccess(mapping) + } else { + StructMetadataState::WithSeqAccess + }; + true + } + } + } + + #[inline(always)] + pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize { + match &self.state { + StructMetadataState::WithMapAccess(mapping) => { + if struct_idx < mapping.len() { + mapping[struct_idx] + } else { + panic!( + "Struct {} has more fields than columns in the database schema", + self.struct_name.unwrap_or("Struct") + ) + } + } + // these two branches should be unreachable + StructMetadataState::WithSeqAccess => struct_idx, + StructMetadataState::Pending => { + panic!( + "Struct metadata is not initialized yet, \ + `ensure_struct_metadata` should be called first" + ) + } + } + } } pub(crate) struct DataTypeValidator<'cursor> { - struct_name: Option<&'static str>, + metadata: &'cursor mut StructMetadata, current_column_idx: usize, - columns: &'cursor [Column], } impl<'cursor> DataTypeValidator<'cursor> { #[inline(always)] - pub(crate) fn new(columns: &'cursor [Column]) -> Self { + pub(crate) fn new(metadata: &'cursor mut StructMetadata) -> Self { Self { - struct_name: None, current_column_idx: 0, - columns, + metadata, } } fn get_current_column(&self) -> Option<&Column> { - if self.current_column_idx > 0 && self.current_column_idx <= self.columns.len() { + if self.current_column_idx > 0 && self.current_column_idx <= self.metadata.columns.len() { // index is immediately moved to the next column after the root validator is called - Some(&self.columns[self.current_column_idx - 1]) + let schema_index = self.get_schema_index(self.current_column_idx - 1); + Some(&self.metadata.columns[schema_index]) } else { None } @@ -53,7 +168,7 @@ impl<'cursor> DataTypeValidator<'cursor> { fn get_struct_name(&self) -> String { // should be available at the time of the panic call - self.struct_name.unwrap_or("Struct").to_string() + self.metadata.struct_name.unwrap_or("Struct").to_string() } #[inline(always)] @@ -82,33 +197,40 @@ impl<'cursor> DataTypeValidator<'cursor> { } } -impl ValidateDataType for DataTypeValidator<'_> { +impl SchemaValidator for DataTypeValidator<'_> { #[inline] fn validate( &'_ mut self, serde_type: SerdeType, ) -> Result>> { - if self.current_column_idx == 0 && self.struct_name.is_none() { + // println!( + // "[validate] Validating serde type: {} for column {}", + // serde_type, + // self.get_current_column() + // .map_or("None".to_string(), |c| c.name.clone()) + // ); + if self.current_column_idx == 0 && self.metadata.struct_name.is_none() { // this allows validating and deserializing tuples from fetch calls Ok(Some(InnerDataTypeValidator { root: self, - kind: if matches!(serde_type, SerdeType::Seq(_)) && self.columns.len() == 1 { - let data_type = &self.columns[0].data_type; + kind: if matches!(serde_type, SerdeType::Seq(_)) && self.metadata.columns.len() == 1 + { + let data_type = &self.metadata.columns[0].data_type; match data_type { DataTypeNode::Array(inner_type) => { InnerDataTypeValidatorKind::RootArray(inner_type) } _ => panic!( "Expected Array type when validating root level sequence, but got {}", - self.columns[0].data_type + self.metadata.columns[0].data_type ), } } else { - InnerDataTypeValidatorKind::RootTuple(self.columns, 0) + InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0) }, })) - } else if self.current_column_idx < self.columns.len() { - let current_column = &self.columns[self.current_column_idx]; + } else if self.current_column_idx < self.metadata.columns.len() { + let current_column = &self.metadata.columns[self.current_column_idx]; self.current_column_idx += 1; validate_impl(self, ¤t_column.data_type, &serde_type, false) } else { @@ -120,10 +242,15 @@ impl ValidateDataType for DataTypeValidator<'_> { } #[inline(always)] - fn set_struct_name(&mut self, name: &'static str) { - if self.struct_name.is_none() { - self.struct_name = Some(name); + fn ensure_struct_metadata( + &'_ mut self, + name: &'static str, + fields: &'static [&'static str], + ) -> bool { + if self.metadata.struct_name.is_none() { + self.metadata.struct_name = Some(name); } + self.metadata.check_should_use_map(name, fields) } #[cold] @@ -143,6 +270,11 @@ impl ValidateDataType for DataTypeValidator<'_> { fn set_next_variant_value(&mut self, _value: u8) { unreachable!() } + + #[inline] + fn get_schema_index(&self, struct_idx: usize) -> usize { + self.metadata.get_schema_index(struct_idx) + } } #[derive(Debug)] @@ -187,7 +319,7 @@ pub(crate) enum VariantValidationState { Identifier(u8), } -impl<'de, 'cursor> ValidateDataType for Option> { +impl<'de, 'cursor> SchemaValidator for Option> { #[inline] fn validate( &mut self, @@ -336,7 +468,17 @@ impl<'de, 'cursor> ValidateDataType for Option bool { + false + } + + fn get_schema_index(&self, _struct_idx: usize) -> usize { + unreachable!() + } } impl Drop for InnerDataTypeValidator<'_, '_> { @@ -598,7 +740,7 @@ fn validate_impl<'de, 'cursor>( } } -impl ValidateDataType for () { +impl SchemaValidator for () { #[inline(always)] fn validate( &mut self, @@ -617,7 +759,17 @@ impl ValidateDataType for () { fn set_next_variant_value(&mut self, _value: u8) {} #[inline(always)] - fn set_struct_name(&mut self, _name: &'static str) {} + fn ensure_struct_metadata( + &mut self, + _name: &'static str, + _fields: &'static [&'static str], + ) -> bool { + false + } + + fn get_schema_index(&self, _struct_idx: usize) -> usize { + unreachable!() + } } /// Which Serde data type (De)serializer used for the given type. @@ -693,3 +845,13 @@ impl Display for SerdeType { } } } + +fn join_panic_schema_hint(col: &[T]) -> String { + if col.is_empty() { + return String::default(); + } + col.iter() + .map(|c| format!("- {}", c)) + .collect::>() + .join("\n") +} diff --git a/src/test/handlers.rs b/src/test/handlers.rs index 10854c49..42a1bfb7 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -95,7 +95,7 @@ where let mut result = C::default(); while !slice.is_empty() { - let (de_result, _) = rowbinary::deserialize_from(slice, &[]); + let (de_result, _) = rowbinary::deserialize_from(slice, None); let row: T = de_result.expect("failed to deserialize"); result.extend(std::iter::once(row)); } diff --git a/tests/it/chrono.rs b/tests/it/chrono.rs index 58a4a8b7..ee87eac2 100644 --- a/tests/it/chrono.rs +++ b/tests/it/chrono.rs @@ -101,11 +101,11 @@ async fn datetime() { let row_str = client .query( " - SELECT toString(dt), - toString(dt64s), - toString(dt64ms), - toString(dt64us), - toString(dt64ns) + SELECT toString(dt) AS dt, + toString(dt64s) AS dt64s, + toString(dt64ms) AS dt64ms, + toString(dt64us) AS dt64us, + toString(dt64ns) AS dt64ns FROM test ", ) diff --git a/tests/it/insert.rs b/tests/it/insert.rs index 47058696..952314a1 100644 --- a/tests/it/insert.rs +++ b/tests/it/insert.rs @@ -128,7 +128,7 @@ async fn rename_insert() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] #[serde(rename_all = "camelCase")] struct RenameRow { - #[serde(rename = "fix_id")] + #[serde(rename = "fixId")] pub(crate) fix_id: u64, #[serde(rename = "extComplexId")] pub(crate) complex_id: String, diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 826a7f76..9ef79edc 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -202,7 +202,7 @@ async fn test_several_simple_rows() { async fn test_many_numbers() { #[derive(Row, Deserialize)] struct Data { - no: u64, + number: u64, } let client = get_client().with_validation_mode(ValidationMode::Each); @@ -213,7 +213,7 @@ async fn test_many_numbers() { let mut sum = 0; while let Some(row) = cursor.next().await.unwrap() { - sum += row.no; + sum += row.number; } assert_eq!(sum, (0..2000).sum::()); } @@ -264,8 +264,8 @@ async fn test_arrays() { async fn test_maps() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { - map1: HashMap, - map2: HashMap>, + m1: HashMap, + m2: HashMap>, } let client = get_client().with_validation_mode(ValidationMode::Each); @@ -284,13 +284,13 @@ async fn test_maps() { assert_eq!( result.unwrap(), Data { - map1: vec![ + m1: vec![ ("key1".to_string(), "value1".to_string()), ("key2".to_string(), "value2".to_string()), ] .into_iter() .collect(), - map2: vec![ + m2: vec![ ( 42, vec![("foo".to_string(), 100), ("bar".to_string(), 200)] @@ -436,8 +436,8 @@ async fn test_invalid_nullable() { n: Option, } assert_panic_on_fetch!( - &["Data.b", "Bool", "Option"], - "SELECT true AS b, 144 :: Int32 AS n2" + &["Data.n", "Array(UInt32)", "Option"], + "SELECT array(42) :: Array(UInt32) AS n" ); } @@ -517,10 +517,7 @@ async fn test_invalid_serde_with() { #[serde(with = "clickhouse::serde::time::datetime64::millis")] n1: time::OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 } - assert_panic_on_fetch!( - &["Data.n1", "UInt32", "i64"], - "SELECT 42 :: UInt32 AS n1, 144 :: Int32 AS n2" - ); + assert_panic_on_fetch!(&["Data.n1", "UInt32", "i64"], "SELECT 42 :: UInt32 AS n1"); } #[tokio::test] @@ -1144,11 +1141,8 @@ async fn test_decimal128_wrong_size() { ); } -// FIXME: RBWNAT should allow for tracking the order of fields in the struct and in the database! -// it is possible to use HashMap to deserialize the struct instead of Tuple visitor #[tokio::test] -#[ignore] -async fn test_different_struct_field_order() { +async fn test_different_struct_field_order_same_types() { #[derive(Debug, Row, Deserialize, PartialEq)] struct Data { c: String, @@ -1164,8 +1158,39 @@ async fn test_different_struct_field_order() { assert_eq!( result.unwrap(), Data { - a: "foo".to_string(), c: "bar".to_string(), + a: "foo".to_string(), + } + ); +} + +#[tokio::test] +async fn test_different_struct_field_order_different_types() { + #[derive(Debug, Row, Deserialize, PartialEq)] + struct Data { + b: u32, + a: String, + c: Vec, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT array(true, false, true) AS c, + 42 :: UInt32 AS b, + 'foo' AS a + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + c: vec![true, false, true], + b: 42, + a: "foo".to_string(), } ); } diff --git a/tests/it/time.rs b/tests/it/time.rs index 9a736538..cbba97b7 100644 --- a/tests/it/time.rs +++ b/tests/it/time.rs @@ -93,11 +93,11 @@ async fn datetime() { let row_str = client .query( " - SELECT toString(dt), - toString(dt64s), - toString(dt64ms), - toString(dt64us), - toString(dt64ns) + SELECT toString(dt) AS dt, + toString(dt64s) AS dt64s, + toString(dt64ms) AS dt64ms, + toString(dt64us) AS dt64us, + toString(dt64ns) AS dt64ns FROM test ", ) From 00ff574d68819a419085e49dd339b499269dba9e Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 5 Jun 2025 01:37:20 +0200 Subject: [PATCH 22/54] Add more tests --- src/rowbinary/validation.rs | 18 +++-- tests/it/rbwnat.rs | 131 +++++++++++++++++++++++++++++++++++- 2 files changed, 142 insertions(+), 7 deletions(-) diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 8da07c6a..8bc677d8 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -73,6 +73,18 @@ impl StructMetadata { StructMetadataState::WithSeqAccess => false, StructMetadataState::WithMapAccess(_) => true, StructMetadataState::Pending => { + if self.columns.len() != fields.len() { + panic!( + "While processing struct {}: database schema has {} columns, \ + but the struct definition has {} fields.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + name, + self.columns.len(), + fields.len(), + join_panic_schema_hint(fields), + join_panic_schema_hint(&self.columns), + ); + } let mut mapping = Vec::with_capacity(fields.len()); let mut expected_index = 0; let mut should_use_map = false; @@ -203,12 +215,6 @@ impl SchemaValidator for DataTypeValidator<'_> { &'_ mut self, serde_type: SerdeType, ) -> Result>> { - // println!( - // "[validate] Validating serde type: {} for column {}", - // serde_type, - // self.get_current_column() - // .map_or("None".to_string(), |c| c.name.clone()) - // ); if self.current_column_idx == 0 && self.metadata.struct_name.is_none() { // this allows validating and deserializing tuples from fetch calls Ok(Some(InnerDataTypeValidator { diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 9ef79edc..8a7e04dc 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -529,7 +529,7 @@ async fn test_too_many_struct_fields() { c: u32, } assert_panic_on_fetch!( - &["Struct Data has more fields than columns in the database schema"], + &["2 columns", "3 fields"], "SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b" ); } @@ -894,6 +894,43 @@ async fn test_geo_invalid_point() { ); } +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/100 +async fn test_issue_100() { + { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: i8, + } + assert_panic_on_fetch!( + &["Data.n", "Nullable(Bool)", "i8"], + "SELECT NULL :: Nullable(Bool) AS n" + ); + } + + { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: u8, + } + assert_panic_on_fetch!( + &["Data.n", "Nullable(Bool)", "u8"], + "SELECT NULL :: Nullable(Bool) AS n" + ); + } + + { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: bool, + } + assert_panic_on_fetch!( + &["Data.n", "Nullable(Bool)", "bool"], + "SELECT NULL :: Nullable(Bool) AS n" + ); + } +} + // TODO: unignore after insert implementation uses RBWNAT, too #[ignore] #[tokio::test] @@ -951,6 +988,20 @@ async fn test_issue_109_1() { insert.end().await.unwrap(); } +#[tokio::test] +async fn test_issue_112() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: bool, + b: bool, + } + + assert_panic_on_fetch!( + &["Data.a", "Nullable(Bool)", "bool"], + "WITH (SELECT true) AS a, (SELECT true) AS b SELECT ?fields" + ); +} + #[tokio::test] /// See https://github.com/ClickHouse/clickhouse-rs/issues/113 async fn test_issue_113() { @@ -1002,6 +1053,84 @@ async fn test_issue_113() { ); } +#[tokio::test] +#[cfg(feature = "time")] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/114 +async fn test_issue_114() { + #[derive(Row, Deserialize, Debug, PartialEq)] + struct Data { + #[serde(with = "clickhouse::serde::time::date")] + date: time::Date, + arr: Vec>, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT + '2023-05-01' :: Date AS date, + array(map('k1', 'v1'), map('k2', 'v2')) :: Array(Map(String, String)) AS arr + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + date: time::Date::from_calendar_date(2023, time::Month::May, 1).unwrap(), + arr: vec![ + HashMap::from([("k1".to_owned(), "v1".to_owned())]), + HashMap::from([("k2".to_owned(), "v2".to_owned())]), + ], + } + ); +} + +#[tokio::test] +#[cfg(feature = "time")] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/173 +async fn test_issue_173() { + #[derive(Debug, Serialize, Deserialize, Row)] + struct Data { + log_id: String, + #[serde(with = "clickhouse::serde::time::datetime")] + ts: time::OffsetDateTime, + } + + let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let statements = vec![ + " + CREATE OR REPLACE TABLE logs ( + log_id String, + timestamp DateTime('Europe/Berlin') + ) + ENGINE = MergeTree() + PRIMARY KEY (log_id, timestamp) + ", + "INSERT INTO logs VALUES ('56cde52f-5f34-45e0-9f08-79d6f582e913', '2024-11-05T11:52:52+01:00')", + "INSERT INTO logs VALUES ('0e967129-6271-44f2-967b-0c8d11a60fdc', '2024-11-05T11:59:21+01:00')", + ]; + + for stmt in statements { + client + .query(stmt) + .with_option("date_time_input_format", "best_effort") + .execute() + .await + .unwrap_or_else(|e| panic!("Failed to execute query {stmt}, cause: {}", e)); + } + + // panics as we fetch `ts` two times: one from `?fields` macro, and the second time explicitly + // the resulting dataset will, in fact, contain 3 columns instead of 2: + assert_panic_on_fetch_with_client!( + client, + &["3 columns", "2 fields"], + "SELECT ?fields, toUnixTimestamp(timestamp) AS ts FROM logs ORDER by ts DESC" + ); +} + #[tokio::test] /// See https://github.com/ClickHouse/clickhouse-rs/issues/185 async fn test_issue_185() { From bd71a7796214faf4c2df836cd362cfb10c34c3be Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Fri, 6 Jun 2025 19:23:45 +0200 Subject: [PATCH 23/54] Add more tests, `execute_statements` helper --- tests/it/main.rs | 17 ++++-- tests/it/rbwnat.rs | 126 ++++++++++++++++++++++++--------------------- 2 files changed, 82 insertions(+), 61 deletions(-) diff --git a/tests/it/main.rs b/tests/it/main.rs index 4148f15c..154a4432 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -118,7 +118,7 @@ impl SimpleRow { } } -async fn create_simple_table(client: &Client, table_name: &str) { +pub(crate) async fn create_simple_table(client: &Client, table_name: &str) { client .query("CREATE TABLE ?(id UInt64, data String) ENGINE = MergeTree ORDER BY id") .with_option("wait_end_of_query", "1") @@ -128,7 +128,7 @@ async fn create_simple_table(client: &Client, table_name: &str) { .unwrap(); } -async fn fetch_rows(client: &Client, table_name: &str) -> Vec +pub(crate) async fn fetch_rows(client: &Client, table_name: &str) -> Vec where T: Row + for<'b> Deserialize<'b>, { @@ -140,10 +140,21 @@ where .unwrap() } -async fn flush_query_log(client: &Client) { +pub(crate) async fn flush_query_log(client: &Client) { client.query("SYSTEM FLUSH LOGS").execute().await.unwrap(); } +pub(crate) async fn execute_statements(client: &Client, statements: &[&str]) { + for statement in statements { + client + .query(statement) + .with_option("wait_end_of_query", "1") + .execute() + .await + .unwrap_or_else(|err| panic!("cannot execute statement '{statement}', cause: {err}")); + } +} + mod chrono; mod cloud_jwt; mod compression; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 8a7e04dc..d4a82e45 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -1,4 +1,4 @@ -use crate::get_client; +use crate::{execute_statements, get_client}; use clickhouse::sql::Identifier; use clickhouse::validation_mode::ValidationMode; use clickhouse_derive::Row; @@ -945,31 +945,29 @@ async fn test_issue_109_1() { call_sign: String, } let client = prepare_database!().with_validation_mode(ValidationMode::Each); - let statements = vec![ - " - CREATE TABLE issue_109 ( - drone_id String, - call_sign String, - journey UInt32, - en_id String, - ) - ENGINE = MergeTree - ORDER BY (drone_id) - ", - " - INSERT INTO issue_109 VALUES - ('drone_1', 'call_sign_1', 1, 'en_id_1'), - ('drone_2', 'call_sign_2', 2, 'en_id_2'), - ('drone_3', 'call_sign_3', 3, 'en_id_3') - ", - ]; - for stmt in statements { - client - .query(stmt) - .execute() - .await - .unwrap_or_else(|e| panic!("Failed to execute query {stmt}, cause: {}", e)); - } + execute_statements( + &client, + &[ + " + CREATE TABLE issue_109 ( + drone_id String, + call_sign String, + journey UInt32, + en_id String, + ) + ENGINE = MergeTree + ORDER BY (drone_id) + ", + " + INSERT INTO issue_109 VALUES + ('drone_1', 'call_sign_1', 1, 'en_id_1'), + ('drone_2', 'call_sign_2', 2, 'en_id_2'), + ('drone_3', 'call_sign_3', 3, 'en_id_3') + ", + ], + ) + .await; + let data = client .query("SELECT journey, drone_id, call_sign FROM issue_109") .fetch_all::() @@ -1012,7 +1010,7 @@ async fn test_issue_113() { c: f64, } let client = prepare_database!().with_validation_mode(ValidationMode::Each); - let statements = vec![ + execute_statements(&client, &[ " CREATE TABLE issue_113_1( id UInt32 @@ -1030,14 +1028,7 @@ async fn test_issue_113() { ", "INSERT INTO issue_113_1 VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)", "INSERT INTO issue_113_2 VALUES (1, 100.5), (2, 200.2), (3, 300.3), (4, 444.4), (5, 555.5)", - ]; - for stmt in statements { - client - .query(stmt) - .execute() - .await - .unwrap_or_else(|e| panic!("Failed to execute query {stmt}, cause: {}", e)); - } + ]).await; // Struct should have had Option instead of f64 assert_panic_on_fetch_with_client!( @@ -1099,8 +1090,11 @@ async fn test_issue_173() { ts: time::OffsetDateTime, } - let client = prepare_database!().with_validation_mode(ValidationMode::Each); - let statements = vec![ + let client = prepare_database!() + .with_validation_mode(ValidationMode::Each) + .with_option("date_time_input_format", "best_effort"); + + execute_statements(&client, &[ " CREATE OR REPLACE TABLE logs ( log_id String, @@ -1111,16 +1105,7 @@ async fn test_issue_173() { ", "INSERT INTO logs VALUES ('56cde52f-5f34-45e0-9f08-79d6f582e913', '2024-11-05T11:52:52+01:00')", "INSERT INTO logs VALUES ('0e967129-6271-44f2-967b-0c8d11a60fdc', '2024-11-05T11:59:21+01:00')", - ]; - - for stmt in statements { - client - .query(stmt) - .with_option("date_time_input_format", "best_effort") - .execute() - .await - .unwrap_or_else(|e| panic!("Failed to execute query {stmt}, cause: {}", e)); - } + ]).await; // panics as we fetch `ts` two times: one from `?fields` macro, and the second time explicitly // the resulting dataset will, in fact, contain 3 columns instead of 2: @@ -1141,8 +1126,9 @@ async fn test_issue_185() { } let client = prepare_database!().with_validation_mode(ValidationMode::Each); - client - .query( + execute_statements( + &client, + &[ " CREATE TABLE issue_185( pk UInt32, @@ -1150,15 +1136,10 @@ async fn test_issue_185() { ENGINE MergeTree ORDER BY pk ", - ) - .execute() - .await - .unwrap(); - client - .query("INSERT INTO issue_185 VALUES (1, 1.1), (2, 2.2), (3, 3.3)") - .execute() - .await - .unwrap(); + "INSERT INTO issue_185 VALUES (1, 1.1), (2, 2.2), (3, 3.3)", + ], + ) + .await; assert_panic_on_fetch_with_client!( client, @@ -1167,6 +1148,35 @@ async fn test_issue_185() { ); } +#[tokio::test] +async fn test_issue_218() { + #[derive(Row, Serialize, Deserialize, Debug)] + struct Data { + max_time: chrono::DateTime, + } + + let client = prepare_database!().with_validation_mode(ValidationMode::Each); + execute_statements( + &client, + &[" + CREATE TABLE IF NOT EXISTS issue_218 ( + my_time DateTime64(3, 'UTC') CODEC(Delta, ZSTD), + ) ENGINE = MergeTree + ORDER BY my_time + "], + ) + .await; + + // FIXME: It is not a super clear panic as it hints about `&str`, + // and not about the missing attribute for `chrono::DateTime`. + // Still better than a `premature end of input` error, though. + assert_panic_on_fetch_with_client!( + client, + &["Data.max_time", "DateTime64(3, 'UTC')", "&str"], + "SELECT max(my_time) AS max_time FROM issue_218" + ); +} + #[tokio::test] async fn test_variant_wrong_definition() { #[derive(Debug, Deserialize, PartialEq)] From 6ba6abfa4af2534467193b75c47debc9ab1ddc2d Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Sat, 7 Jun 2025 21:23:26 +0200 Subject: [PATCH 24/54] More optimal struct name/fields acquisition, cleanup --- benches/select_numbers.rs | 10 +- derive/src/lib.rs | 2 + examples/mock.rs | 8 +- src/cursors/row.rs | 25 ++--- src/lib.rs | 8 +- src/row.rs | 16 +++ src/rowbinary/de.rs | 52 +++------- src/rowbinary/mod.rs | 4 +- src/rowbinary/validation.rs | 191 +++--------------------------------- src/struct_metadata.rs | 165 +++++++++++++++++++++++++++++++ src/test/handlers.rs | 11 ++- src/watch.rs | 6 +- tests/it/mock.rs | 4 +- 13 files changed, 255 insertions(+), 247 deletions(-) create mode 100644 src/struct_metadata.rs diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index 52494526..517044fb 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -48,9 +48,9 @@ async fn main() { println!("compress validation elapsed throughput received"); bench("none", Compression::None, ValidationMode::First(1)).await; bench("none", Compression::None, ValidationMode::Each).await; - // #[cfg(feature = "lz4")] - // { - // bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; - // bench("lz4", Compression::Lz4, ValidationMode::Each).await; - // } + #[cfg(feature = "lz4")] + { + bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; + bench("lz4", Compression::Lz4, ValidationMode::Each).await; + } } diff --git a/derive/src/lib.rs b/derive/src/lib.rs index bd5675a8..3988941d 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -58,7 +58,9 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let expanded = quote! { #[automatically_derived] impl #impl_generics clickhouse::Row for #name #ty_generics #where_clause { + const NAME: &'static str = stringify!(#name); const COLUMN_NAMES: &'static [&'static str] = #column_names; + const TYPE: clickhouse::RowType = clickhouse::RowType::Struct; } }; diff --git a/examples/mock.rs b/examples/mock.rs index f71bdc29..57452179 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -56,11 +56,11 @@ async fn main() { make_create(&client).await.unwrap(); assert!(recording.query().await.contains("CREATE TABLE")); + let metadata = + clickhouse::StructMetadata::new::(vec![Column::new("no".to_string(), UInt32)]); + // How to test SELECT. - mock.add(test::handlers::provide( - &[Column::new("no".to_string(), UInt32)], - list.clone(), - )); + mock.add(test::handlers::provide(&metadata, list.clone())); let rows = make_select(&client).await.unwrap(); assert_eq!(rows, list); diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 4be4044f..2ef93988 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,11 +1,11 @@ -use crate::rowbinary::StructMetadata; +use crate::struct_metadata::StructMetadata; use crate::validation_mode::ValidationMode; use crate::{ bytes_ext::BytesExt, cursors::RawCursor, error::{Error, Result}, response::Response, - rowbinary, + rowbinary, Row, }; use clickhouse_types::error::TypesError; use clickhouse_types::parse_rbwnat_columns_header; @@ -19,7 +19,7 @@ pub struct RowCursor { bytes: BytesExt, /// [`None`] until the first call to [`RowCursor::next()`], /// as [`RowCursor::new`] is not `async`, so it loads lazily. - struct_mapping: Option, + struct_metadata: Option, rows_to_validate: u64, _marker: PhantomData, } @@ -30,7 +30,7 @@ impl RowCursor { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), - struct_mapping: None, + struct_metadata: None, rows_to_validate: match validation_mode { ValidationMode::First(n) => n as u64, ValidationMode::Each => u64::MAX, @@ -40,14 +40,17 @@ impl RowCursor { #[cold] #[inline(never)] - async fn read_columns(&mut self) -> Result<()> { + async fn read_columns(&mut self) -> Result<()> + where + T: Row, + { loop { if self.bytes.remaining() > 0 { let mut slice = self.bytes.slice(); match parse_rbwnat_columns_header(&mut slice) { Ok(columns) if !columns.is_empty() => { self.bytes.set_remaining(slice.len()); - self.struct_mapping = Some(StructMetadata::new(columns)); + self.struct_metadata = Some(StructMetadata::new::(columns)); return Ok(()); } Ok(_) => { @@ -64,7 +67,7 @@ impl RowCursor { } match self.raw.next().await? { Some(chunk) => self.bytes.extend(chunk), - None if self.struct_mapping.is_none() => { + None if self.struct_metadata.is_none() => { return Err(Error::BadResponse( "Could not read columns header".to_string(), )); @@ -84,11 +87,11 @@ impl RowCursor { /// This method is cancellation safe. pub async fn next<'cursor, 'data: 'cursor>(&'cursor mut self) -> Result> where - T: Deserialize<'data>, + T: Deserialize<'data> + Row, { loop { if self.bytes.remaining() > 0 { - if self.struct_mapping.is_none() { + if self.struct_metadata.is_none() { self.read_columns().await?; if self.bytes.remaining() == 0 { continue; @@ -98,12 +101,12 @@ impl RowCursor { let (result, not_enough_data) = match self.rows_to_validate { 0 => rowbinary::deserialize_from::(&mut slice, None), u64::MAX => { - rowbinary::deserialize_from::(&mut slice, self.struct_mapping.as_mut()) + rowbinary::deserialize_from::(&mut slice, self.struct_metadata.as_ref()) } _ => { let result = rowbinary::deserialize_from::( &mut slice, - self.struct_mapping.as_mut(), + self.struct_metadata.as_ref(), ); self.rows_to_validate -= 1; result diff --git a/src/lib.rs b/src/lib.rs index 55c2221f..3767480d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,11 +5,12 @@ #[macro_use] extern crate static_assertions; +#[cfg(feature = "test-util")] +pub use self::struct_metadata::StructMetadata; +pub use self::{compression::Compression, row::Row, row::RowType}; use self::{error::Result, http_client::HttpClient, validation_mode::ValidationMode}; -use std::{collections::HashMap, fmt::Display, sync::Arc}; - -pub use self::{compression::Compression, row::Row}; pub use clickhouse_derive::Row; +use std::{collections::HashMap, fmt::Display, sync::Arc}; pub mod error; pub mod insert; @@ -33,6 +34,7 @@ mod request_body; mod response; mod row; mod rowbinary; +mod struct_metadata; #[cfg(feature = "inserter")] mod ticks; diff --git a/src/row.rs b/src/row.rs index c5ca6808..7418dacb 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,7 +1,17 @@ use crate::sql; +#[derive(Debug, Clone)] +pub enum RowType { + Primitive, + Struct, + Tuple, + Vec, +} + pub trait Row { + const NAME: &'static str; const COLUMN_NAMES: &'static [&'static str]; + const TYPE: RowType; // TODO: count // TODO: different list for SELECT/INSERT (de/ser) @@ -32,7 +42,9 @@ macro_rules! impl_row_for_tuple { /// The second one is useful for queries like /// `SELECT ?fields, count() FROM .. GROUP BY ?fields`. impl<$i: Row, $($other: Primitive),+> Row for ($i, $($other),+) { + const NAME: &'static str = $i::NAME; const COLUMN_NAMES: &'static [&'static str] = $i::COLUMN_NAMES; + const TYPE: RowType = RowType::Tuple; } impl_row_for_tuple!($($other)+); @@ -44,13 +56,17 @@ macro_rules! impl_row_for_tuple { impl Primitive for () {} impl Row for P { + const NAME: &'static str = stringify!(P); const COLUMN_NAMES: &'static [&'static str] = &[]; + const TYPE: RowType = RowType::Primitive; } impl_row_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8); impl Row for Vec { + const NAME: &'static str = "Vec"; const COLUMN_NAMES: &'static [&'static str] = &[]; + const TYPE: RowType = RowType::Vec; } /// Collects all field names in depth and joins them with comma. diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index b78c2461..9762b8ec 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -2,7 +2,7 @@ use crate::error::{Error, Result}; use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use crate::rowbinary::validation::SerdeType; use crate::rowbinary::validation::{DataTypeValidator, SchemaValidator}; -use crate::rowbinary::StructMetadata; +use crate::struct_metadata::StructMetadata; use bytes::Buf; use core::mem::size_of; use serde::de::MapAccess; @@ -26,13 +26,13 @@ use std::{convert::TryFrom, str}; /// After the header, the rows format is the same as `RowBinary`. pub(crate) fn deserialize_from<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], - mapping: Option<&'cursor mut StructMetadata>, + metadata: Option<&'cursor StructMetadata>, ) -> (Result, bool) { - let result = if mapping.is_none() { + let result = if metadata.is_none() { let mut deserializer = RowBinaryDeserializer::new(input, ()); T::deserialize(&mut deserializer) } else { - let validator = DataTypeValidator::new(mapping.unwrap()); + let validator = DataTypeValidator::new(metadata.unwrap()); let mut deserializer = RowBinaryDeserializer::new(input, validator); T::deserialize(&mut deserializer) }; @@ -166,8 +166,6 @@ where #[inline(always)] fn deserialize_str>(self, visitor: V) -> Result { - // println!("deserialize_str call"); - self.validator.validate(SerdeType::Str)?; let size = self.read_size()?; let slice = self.read_slice(size)?; @@ -177,8 +175,6 @@ where #[inline(always)] fn deserialize_string>(self, visitor: V) -> Result { - // println!("deserialize_string call"); - self.validator.validate(SerdeType::String)?; let size = self.read_size()?; let vec = self.read_vec(size)?; @@ -188,8 +184,6 @@ where #[inline(always)] fn deserialize_bytes>(self, visitor: V) -> Result { - // println!("deserialize_bytes call"); - let size = self.read_size()?; self.validator.validate(SerdeType::Bytes(size))?; let slice = self.read_slice(size)?; @@ -198,8 +192,6 @@ where #[inline(always)] fn deserialize_byte_buf>(self, visitor: V) -> Result { - // println!("deserialize_byte_buf call"); - let size = self.read_size()?; self.validator.validate(SerdeType::ByteBuf(size))?; visitor.visit_byte_buf(self.read_vec(size)?) @@ -207,8 +199,6 @@ where #[inline(always)] fn deserialize_identifier>(self, visitor: V) -> Result { - // println!("deserialize_identifier call"); - ensure_size(&mut self.input, size_of::())?; let value = self.input.get_u8(); // TODO: is there a better way to validate that the deserialized value matches the schema? @@ -223,8 +213,6 @@ where _variants: &'static [&'static str], visitor: V, ) -> Result { - // println!("deserialize_enum call"); - let validator = self.validator.validate(SerdeType::Enum)?; visitor.visit_enum(RowBinaryEnumAccess { deserializer: &mut RowBinaryDeserializer { @@ -236,8 +224,6 @@ where #[inline(always)] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - // println!("deserialize_tuple call, len {}", len); - let validator = self.validator.validate(SerdeType::Tuple(len))?; let mut de = RowBinaryDeserializer { input: self.input, @@ -252,8 +238,6 @@ where #[inline(always)] fn deserialize_option>(self, visitor: V) -> Result { - // println!("deserialize_option call"); - ensure_size(&mut self.input, 1)?; let inner_validator = self.validator.validate(SerdeType::Option)?; match self.input.get_u8() { @@ -268,8 +252,6 @@ where #[inline(always)] fn deserialize_seq>(self, visitor: V) -> Result { - // println!("deserialize_seq call"); - let len = self.read_size()?; visitor.visit_seq(RowBinarySeqAccess { deserializer: &mut RowBinaryDeserializer { @@ -282,10 +264,6 @@ where #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { - // println!( - // "deserialize_map call", - // ); - let len = self.read_size()?; let validator = self.validator.validate(SerdeType::Map(len))?; visitor.visit_map(RowBinaryMapAccess { @@ -301,14 +279,11 @@ where #[inline(always)] fn deserialize_struct>( self, - name: &'static str, + _name: &'static str, fields: &'static [&'static str], visitor: V, ) -> Result { - // println!("deserialize_struct: {} (fields: {:?})", name, fields,); - - let should_use_map_access = self.validator.ensure_struct_metadata(name, fields); - if !should_use_map_access { + if !self.validator.is_field_order_wrong() { visitor.visit_seq(RowBinarySeqAccess { deserializer: self, len: fields.len(), @@ -441,8 +416,8 @@ where } } -/// Used in [`Deserializer::deserialize_struct`] to support wrong field order -/// as long as the data types are exactly matching the database schema. +/// Used in [`Deserializer::deserialize_struct`] to support wrong struct field order +/// as long as the data types and field names are exactly matching the database schema. struct RowBinaryStructAsMapAccess<'de, 'cursor, 'data, Validator> where Validator: SchemaValidator, @@ -493,6 +468,9 @@ impl<'de> Deserializer<'de> for StructFieldIdentifier { /// a: String, /// } /// ``` +/// +/// If we just use [`RowBinarySeqAccess`] here, `c` will be deserialized into the `a` field, +/// and `a` will be deserialized into the `c` field, which is a classic case of data corruption. impl<'data, Validator> MapAccess<'data> for RowBinaryStructAsMapAccess<'_, '_, 'data, Validator> where Validator: SchemaValidator, @@ -511,10 +489,6 @@ where .validator .get_schema_index(self.current_field_idx); let field_id = StructFieldIdentifier(self.fields[schema_index]); - // println!( - // "RowBinaryStructAsMapAccess::next_key_seed: field_id: {}", - // field_id.0 - // ); self.current_field_idx += 1; seed.deserialize(field_id).map(Some) } @@ -523,10 +497,6 @@ where where V: DeserializeSeed<'data>, { - // println!( - // "RowBinaryStructAsMapAccess::next_value_seed: current_field_idx: {}", - // self.current_field_idx - // ); seed.deserialize(&mut *self.deserializer) } diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index 5a24975b..a465a2cc 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,10 +1,10 @@ pub(crate) use de::deserialize_from; pub(crate) use ser::serialize_into; -pub(crate) use validation::StructMetadata; + +pub(crate) mod validation; mod de; mod ser; #[cfg(test)] mod tests; mod utils; -mod validation; diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 8bc677d8..7c3a50e9 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,4 +1,5 @@ use crate::error::Result; +use crate::struct_metadata::StructMetadata; use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; @@ -11,145 +12,18 @@ pub(crate) trait SchemaValidator: Sized { fn validate_enum8_value(&mut self, value: i8); fn validate_enum16_value(&mut self, value: i16); fn set_next_variant_value(&mut self, value: u8); - fn ensure_struct_metadata( - &'_ mut self, - name: &'static str, - fields: &'static [&'static str], - ) -> bool; fn get_schema_index(&self, struct_idx: usize) -> usize; -} - -#[derive(Debug, PartialEq)] -enum StructMetadataState { - Pending, - WithSeqAccess, - WithMapAccess(Vec), -} - -/// #### StructMetadata -/// -/// Should reside outside the (de)serializer, so it is calculated only once per struct. -/// No lifetimes, so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. -/// -/// #### Lifecycle -/// -/// - the first call to [`crate::cursors::RowCursor::next`] creates an instance with `columns`. -/// - the first call to [`serde::Deserializer::deserialize_struct`] sets the `struct_name`, -/// and the field order is checked. If the order is different from the schema, the state is set to -/// [`StructMetadataState::WithMapAccess`], otherwise to [`StructMetadataState::WithSeqAccess`]. -/// - the following calls to [`crate::cursors::RowCursor::next`] and, consequently, -/// to [`serde::Deserializer::deserialize_struct`], will re-use the same prepared instance, -/// without re-checking the fields order for every struct. -pub(crate) struct StructMetadata { - /// Struct name is defined after the first call to [`serde::Deserializer::deserialize_struct`]. - /// If we are deserializing any other type, e.g., [`u64`], [`Vec`], etc., it is [`None`], - /// and it affects how the validation works, see [`DataTypeValidator::validate`]. - pub(crate) struct_name: Option<&'static str>, - /// Database schema, or columns, are parsed before the first call to (de)serializer. - pub(crate) columns: Vec, - /// This state determines whether we can just use [`crate::rowbinary::de::RowBinarySeqAccess`] - /// or a more sophisticated approach with [`crate::rowbinary::de::RowBinaryStructAsMapAccess`] - /// to support structs defined with different fields order than in the schema. - /// Deserializing a struct as a map will be approximately 40% slower than as a sequence. - state: StructMetadataState, -} - -impl StructMetadata { - pub(crate) fn new(columns: Vec) -> Self { - Self { - columns, - struct_name: None, - state: StructMetadataState::Pending, - } - } - - #[inline(always)] - pub(crate) fn check_should_use_map( - &mut self, - name: &'static str, - fields: &'static [&'static str], - ) -> bool { - match &self.state { - StructMetadataState::WithSeqAccess => false, - StructMetadataState::WithMapAccess(_) => true, - StructMetadataState::Pending => { - if self.columns.len() != fields.len() { - panic!( - "While processing struct {}: database schema has {} columns, \ - but the struct definition has {} fields.\ - \n#### All struct fields:\n{}\n#### All schema columns:\n{}", - name, - self.columns.len(), - fields.len(), - join_panic_schema_hint(fields), - join_panic_schema_hint(&self.columns), - ); - } - let mut mapping = Vec::with_capacity(fields.len()); - let mut expected_index = 0; - let mut should_use_map = false; - for col in &self.columns { - if let Some(index) = fields.iter().position(|field| col.name == *field) { - if index != expected_index { - should_use_map = true - } - expected_index += 1; - mapping.push(index); - } else { - panic!( - "While processing struct {}: database schema has a column {} \ - that was not found in the struct definition.\ - \n#### All struct fields:\n{}\n#### All schema columns:\n{}", - name, - col, - join_panic_schema_hint(fields), - join_panic_schema_hint(&self.columns), - ); - } - } - self.state = if should_use_map { - StructMetadataState::WithMapAccess(mapping) - } else { - StructMetadataState::WithSeqAccess - }; - true - } - } - } - - #[inline(always)] - pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize { - match &self.state { - StructMetadataState::WithMapAccess(mapping) => { - if struct_idx < mapping.len() { - mapping[struct_idx] - } else { - panic!( - "Struct {} has more fields than columns in the database schema", - self.struct_name.unwrap_or("Struct") - ) - } - } - // these two branches should be unreachable - StructMetadataState::WithSeqAccess => struct_idx, - StructMetadataState::Pending => { - panic!( - "Struct metadata is not initialized yet, \ - `ensure_struct_metadata` should be called first" - ) - } - } - } + fn is_field_order_wrong(&self) -> bool; } pub(crate) struct DataTypeValidator<'cursor> { - metadata: &'cursor mut StructMetadata, + metadata: &'cursor StructMetadata, current_column_idx: usize, } impl<'cursor> DataTypeValidator<'cursor> { #[inline(always)] - pub(crate) fn new(metadata: &'cursor mut StructMetadata) -> Self { + pub(crate) fn new(metadata: &'cursor StructMetadata) -> Self { Self { current_column_idx: 0, metadata, @@ -170,7 +44,7 @@ impl<'cursor> DataTypeValidator<'cursor> { self.get_current_column() .map(|c| { ( - format!("{}.{}", self.get_struct_name(), c.name), + format!("{}.{}", self.metadata.struct_name, c.name), &c.data_type, ) }) @@ -178,11 +52,6 @@ impl<'cursor> DataTypeValidator<'cursor> { .unwrap_or(("Struct".to_string(), &DataTypeNode::Bool)) } - fn get_struct_name(&self) -> String { - // should be available at the time of the panic call - self.metadata.struct_name.unwrap_or("Struct").to_string() - } - #[inline(always)] fn panic_on_schema_mismatch<'de>( &'de self, @@ -215,8 +84,8 @@ impl SchemaValidator for DataTypeValidator<'_> { &'_ mut self, serde_type: SerdeType, ) -> Result>> { - if self.current_column_idx == 0 && self.metadata.struct_name.is_none() { - // this allows validating and deserializing tuples from fetch calls + if self.current_column_idx == 0 && !self.metadata.is_struct() { + // this allows validating and deserializing tuples/vectors/primitives from fetch calls Ok(Some(InnerDataTypeValidator { root: self, kind: if matches!(serde_type, SerdeType::Seq(_)) && self.metadata.columns.len() == 1 @@ -242,21 +111,13 @@ impl SchemaValidator for DataTypeValidator<'_> { } else { panic!( "Struct {} has more fields than columns in the database schema", - self.get_struct_name() + self.metadata.struct_name ) } } - #[inline(always)] - fn ensure_struct_metadata( - &'_ mut self, - name: &'static str, - fields: &'static [&'static str], - ) -> bool { - if self.metadata.struct_name.is_none() { - self.metadata.struct_name = Some(name); - } - self.metadata.check_should_use_map(name, fields) + fn is_field_order_wrong(&self) -> bool { + true } #[cold] @@ -331,7 +192,6 @@ impl<'de, 'cursor> SchemaValidator for Option Result>> { - // println!("[validate] Validating serde type: {}", serde_type); match self { None => Ok(None), Some(inner) => match &mut inner.kind { @@ -404,7 +264,6 @@ impl<'de, 'cursor> SchemaValidator for Option { - // println!("Validating variant identifier: {}", value); if *value as usize >= possible_types.len() { let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); @@ -474,11 +333,7 @@ impl<'de, 'cursor> SchemaValidator for Option bool { + fn is_field_order_wrong(&self) -> bool { false } @@ -515,12 +370,10 @@ fn validate_impl<'de, 'cursor>( serde_type: &SerdeType, is_inner: bool, ) -> Result>> { - // println!( - // "Validating data type: {} against serde type: {}", - // column_data_type, serde_type, - // ); let data_type = column_data_type.remove_low_cardinality(); - // TODO: eliminate multiple branches with similar patterns? + // TODO: is there a way to eliminate multiple branches with similar patterns? + // static/const dispatch? + // separate smaller inline functions? match serde_type { SerdeType::Bool if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => @@ -765,11 +618,7 @@ impl SchemaValidator for () { fn set_next_variant_value(&mut self, _value: u8) {} #[inline(always)] - fn ensure_struct_metadata( - &mut self, - _name: &'static str, - _fields: &'static [&'static str], - ) -> bool { + fn is_field_order_wrong(&self) -> bool { false } @@ -851,13 +700,3 @@ impl Display for SerdeType { } } } - -fn join_panic_schema_hint(col: &[T]) -> String { - if col.is_empty() { - return String::default(); - } - col.iter() - .map(|c| format!("- {}", c)) - .collect::>() - .join("\n") -} diff --git a/src/struct_metadata.rs b/src/struct_metadata.rs new file mode 100644 index 00000000..28c4bac2 --- /dev/null +++ b/src/struct_metadata.rs @@ -0,0 +1,165 @@ +// FIXME: this is allowed only temporarily, +// before the insert RBWNAT implementation is ready, +// cause otherwise the caches are never used. +#![allow(dead_code)] + +use crate::row::RowType; +use crate::sql::Identifier; +use crate::Result; +use crate::Row; +use clickhouse_types::{parse_rbwnat_columns_header, Column}; +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; +use tokio::sync::{OnceCell, RwLock}; + +/// Cache for [`StructMetadata`] to avoid allocating it for the same struct more than once +/// during the application lifecycle. Key: fully qualified table name (e.g. `database.table`). +type LockedStructMetadataCache = RwLock>>; +static STRUCT_METADATA_CACHE: OnceCell = OnceCell::const_new(); + +#[derive(Debug, PartialEq)] +enum AccessType { + WithSeqAccess, + WithMapAccess(Vec), +} + +/// [`StructMetadata`] should be owned outside the (de)serializer, +/// as it is calculated only once per struct. It does not have lifetimes, +/// so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. +pub struct StructMetadata { + /// See [`Row::NAME`] + pub(crate) struct_name: &'static str, + /// See [`Row::COLUMN_NAMES`] (currently unused) + // pub(crate) struct_fields: &'static [&'static str], + /// See [`Row::TYPE`] + pub(crate) row_type: RowType, + /// Database schema, or columns, are parsed before the first call to (de)serializer. + pub(crate) columns: Vec, + /// This determines whether we can just use [`crate::rowbinary::de::RowBinarySeqAccess`] + /// or a more sophisticated approach with [`crate::rowbinary::de::RowBinaryStructAsMapAccess`] + /// to support structs defined with different fields order than in the schema. + /// (De)serializing a struct as a map will be approximately 40% slower than as a sequence. + access_type: AccessType, +} + +impl StructMetadata { + // FIXME: perhaps it should not be public? But it is required for mocks/provide. + pub fn new(columns: Vec) -> Self { + let struct_name = T::NAME; + let struct_fields = T::COLUMN_NAMES; + if columns.len() != struct_fields.len() { + panic!( + "While processing struct {}: database schema has {} columns, \ + but the struct definition has {} fields.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + struct_name, + columns.len(), + struct_fields.len(), + join_panic_schema_hint(struct_fields), + join_panic_schema_hint(&columns), + ); + } + let mut mapping = Vec::with_capacity(struct_fields.len()); + let mut expected_index = 0; + let mut should_use_map = false; + for col in &columns { + if let Some(index) = struct_fields.iter().position(|field| col.name == *field) { + if index != expected_index { + should_use_map = true + } + expected_index += 1; + mapping.push(index); + } else { + panic!( + "While processing struct {}: database schema has a column {} \ + that was not found in the struct definition.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + struct_name, + col, + join_panic_schema_hint(struct_fields), + join_panic_schema_hint(&columns), + ); + } + } + Self { + columns, + struct_name, + // struct_fields, + row_type: T::TYPE, + access_type: if should_use_map { + AccessType::WithMapAccess(mapping) + } else { + AccessType::WithSeqAccess + }, + } + } + + #[inline(always)] + pub(crate) fn is_struct(&self) -> bool { + matches!(self.row_type, RowType::Struct) + } + + #[inline(always)] + pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize { + match &self.access_type { + AccessType::WithMapAccess(mapping) => { + if struct_idx < mapping.len() { + mapping[struct_idx] + } else { + panic!( + "Struct {} has more fields than columns in the database schema", + self.struct_name + ) + } + } + AccessType::WithSeqAccess => struct_idx, // should be unreachable + } + } +} + +pub(crate) async fn get_struct_metadata( + client: &crate::Client, + table_name: &str, +) -> Result> { + let locked_cache = STRUCT_METADATA_CACHE + .get_or_init(|| async { RwLock::new(HashMap::new()) }) + .await; + let cache_guard = locked_cache.read().await; + match cache_guard.get(table_name) { + Some(metadata) => Ok(metadata.clone()), + None => cache_struct_metadata::(client, table_name, locked_cache).await, + } +} + +/// Used internally to introspect and cache the table structure to allow validation +/// of serialized rows before submitting the first [`insert::Insert::write`]. +async fn cache_struct_metadata( + client: &crate::Client, + table_name: &str, + locked_cache: &LockedStructMetadataCache, +) -> Result> { + let mut bytes_cursor = client + .query("SELECT * FROM ? LIMIT 0") + .bind(Identifier(table_name)) + .fetch_bytes("RowBinaryWithNamesAndTypes")?; + let mut buffer = Vec::::new(); + while let Some(chunk) = bytes_cursor.next().await? { + buffer.extend_from_slice(&chunk); + } + let columns = parse_rbwnat_columns_header(&mut buffer.as_slice())?; + let mut cache = locked_cache.write().await; + let metadata = Arc::new(StructMetadata::new::(columns)); + cache.insert(table_name.to_string(), metadata.clone()); + Ok(metadata) +} + +fn join_panic_schema_hint(col: &[T]) -> String { + if col.is_empty() { + return String::default(); + } + col.iter() + .map(|c| format!("- {}", c)) + .collect::>() + .join("\n") +} diff --git a/src/test/handlers.rs b/src/test/handlers.rs index 42a1bfb7..4116b8ef 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -1,7 +1,7 @@ use std::marker::PhantomData; use bytes::Bytes; -use clickhouse_types::{put_rbwnat_columns_header, Column}; +use clickhouse_types::put_rbwnat_columns_header; use futures::channel::oneshot; use hyper::{Request, Response, StatusCode}; use sealed::sealed; @@ -9,6 +9,7 @@ use serde::{Deserialize, Serialize}; use super::{Handler, HandlerFn}; use crate::rowbinary; +use crate::struct_metadata::StructMetadata; const BUFFER_INITIAL_CAPACITY: usize = 1024; @@ -41,12 +42,16 @@ pub fn failure(status: StatusCode) -> impl Handler { // === provide === #[track_caller] -pub fn provide(schema: &[Column], rows: impl IntoIterator) -> impl Handler +pub fn provide( + struct_metadata: &StructMetadata, + rows: impl IntoIterator, +) -> impl Handler where T: Serialize, { let mut buffer = Vec::with_capacity(BUFFER_INITIAL_CAPACITY); - put_rbwnat_columns_header(schema, &mut buffer).expect("failed to write columns header"); + put_rbwnat_columns_header(&struct_metadata.columns, &mut buffer) + .expect("failed to write columns header"); for row in rows { rowbinary::serialize_into(&mut buffer, &row).expect("failed to serialize"); } diff --git a/src/watch.rs b/src/watch.rs index 109b642e..f47c8221 100644 --- a/src/watch.rs +++ b/src/watch.rs @@ -153,13 +153,15 @@ struct EventPayload { } impl Row for EventPayload { + const NAME: &'static str = "EventPayload"; const COLUMN_NAMES: &'static [&'static str] = &[]; + const TYPE: crate::row::RowType = crate::row::RowType::Struct; } impl EventCursor { /// Emits the next version. /// - /// An result is unspecified if it's called after `Err` is returned. + /// The result is unspecified if it's called after `Err` is returned. pub async fn next(&mut self) -> Result> { Ok(self.0.next().await?.map(|payload| payload.version)) } @@ -178,7 +180,9 @@ struct RowPayload { } impl Row for RowPayload { + const NAME: &'static str = T::NAME; const COLUMN_NAMES: &'static [&'static str] = T::COLUMN_NAMES; + const TYPE: crate::row::RowType = T::TYPE; } impl RowCursor { diff --git a/tests/it/mock.rs b/tests/it/mock.rs index e7dd9f5f..9ea31a73 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -14,7 +14,9 @@ async fn test_provide() { Column::new("id".to_string(), DataTypeNode::UInt64), Column::new("data".to_string(), DataTypeNode::String), ]; - mock.add(test::handlers::provide(&columns, &expected)); + + let metadata = clickhouse::StructMetadata::new::(columns); + mock.add(test::handlers::provide(&metadata, &expected)); let actual = crate::fetch_rows::(&client, "doesn't matter").await; assert_eq!(actual, expected); From fb49a24a27c22af2cff1187f0cc65f0b2d1ebcd9 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Sat, 7 Jun 2025 21:49:23 +0200 Subject: [PATCH 25/54] Temporarily allow unreachable items --- src/struct_metadata.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/struct_metadata.rs b/src/struct_metadata.rs index 28c4bac2..1075f91a 100644 --- a/src/struct_metadata.rs +++ b/src/struct_metadata.rs @@ -2,6 +2,7 @@ // before the insert RBWNAT implementation is ready, // cause otherwise the caches are never used. #![allow(dead_code)] +#![allow(unreachable_pub)] use crate::row::RowType; use crate::sql::Identifier; @@ -31,7 +32,7 @@ pub struct StructMetadata { /// See [`Row::NAME`] pub(crate) struct_name: &'static str, /// See [`Row::COLUMN_NAMES`] (currently unused) - // pub(crate) struct_fields: &'static [&'static str], + pub(crate) struct_fields: &'static [&'static str], /// See [`Row::TYPE`] pub(crate) row_type: RowType, /// Database schema, or columns, are parsed before the first call to (de)serializer. @@ -85,7 +86,7 @@ impl StructMetadata { Self { columns, struct_name, - // struct_fields, + struct_fields, row_type: T::TYPE, access_type: if should_use_map { AccessType::WithMapAccess(mapping) From 52d095314c4b8a32d2507f90ed0bfa3c7a42162b Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Sat, 7 Jun 2025 21:51:44 +0200 Subject: [PATCH 26/54] Add chrono feature to RBWNAT tests --- tests/it/rbwnat.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index d4a82e45..8c48a6ba 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -1149,6 +1149,7 @@ async fn test_issue_185() { } #[tokio::test] +#[cfg(feature = "chrono")] async fn test_issue_218() { #[derive(Row, Serialize, Deserialize, Debug)] struct Data { From 5ffae76d63d6ce0b62209e367e3db465ed96f19a Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 00:40:43 +0200 Subject: [PATCH 27/54] Allow root primitives, rework benchmarks, address (most of) PR feedback --- Cargo.toml | 3 +- benches/common_select.rs | 141 +++++++++++++++++++++++++++ benches/select_numbers.rs | 63 +++++-------- benches/select_nyc_taxi_data.rs | 88 +++++++++-------- src/cursors/row.rs | 2 +- src/error.rs | 4 +- src/row.rs | 2 +- src/rowbinary/validation.rs | 162 ++++++++++++++++++++++---------- src/struct_metadata.rs | 123 +++++++++++++++--------- src/validation_mode.rs | 15 +-- tests/it/main.rs | 6 +- tests/it/rbwnat.rs | 90 ++++++++++++++++++ types/Cargo.toml | 2 +- 13 files changed, 507 insertions(+), 194 deletions(-) create mode 100644 benches/common_select.rs diff --git a/Cargo.toml b/Cargo.toml index 54f13caf..5b71ded9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -103,7 +103,7 @@ rustls-tls-native-roots = [ [dependencies] clickhouse-derive = { version = "0.2.0", path = "derive" } -clickhouse-types = { version = "*", path = "types" } +clickhouse-types = { version = "0.1.0", path = "types" } thiserror = "1.0.16" serde = "1.0.106" bytes = "1.5.0" @@ -137,7 +137,6 @@ replace_with = { version = "0.1.7" } [dev-dependencies] criterion = "0.5.0" -tracy-client = { version = "0.18.0", features = ["enable"]} serde = { version = "1.0.106", features = ["derive"] } tokio = { version = "1.0.1", features = ["full", "test-util"] } hyper = { version = "1.1", features = ["server"] } diff --git a/benches/common_select.rs b/benches/common_select.rs new file mode 100644 index 00000000..0d653e14 --- /dev/null +++ b/benches/common_select.rs @@ -0,0 +1,141 @@ +#![allow(dead_code)] + +use clickhouse::query::RowCursor; +use clickhouse::validation_mode::ValidationMode; +use clickhouse::{Client, Compression, Row}; +use criterion::black_box; +use serde::Deserialize; +use std::time::{Duration, Instant}; + +pub(crate) trait WithId { + fn id(&self) -> u64; +} +pub(crate) trait WithAccessType { + const ACCESS_TYPE: &'static str; +} +pub(crate) trait BenchmarkRow<'a>: Row + Deserialize<'a> + WithId + WithAccessType {} + +#[macro_export] +macro_rules! impl_benchmark_row { + ($type:ty, $id_field:ident, $access_type:literal) => { + impl WithId for $type { + fn id(&self) -> u64 { + self.$id_field as u64 + } + } + + impl WithAccessType for $type { + const ACCESS_TYPE: &'static str = $access_type; + } + + impl<'a> BenchmarkRow<'a> for $type {} + }; +} + +#[macro_export] +macro_rules! impl_benchmark_row_no_access_type { + ($type:ty, $id_field:ident) => { + impl WithId for $type { + fn id(&self) -> u64 { + self.$id_field + } + } + + impl WithAccessType for $type { + const ACCESS_TYPE: &'static str = ""; + } + + impl<'a> BenchmarkRow<'a> for $type {} + }; +} + +pub(crate) fn print_header(add: Option<&str>) { + let add = add.unwrap_or(""); + println!("compress validation elapsed throughput received{add}"); +} + +pub(crate) fn print_results<'a, T: BenchmarkRow<'a>>( + stats: &BenchmarkStats, + compression: Compression, + validation_mode: ValidationMode, +) { + let BenchmarkStats { + throughput_mbytes_sec, + received_mbytes, + elapsed, + .. + } = stats; + let validation_mode = match validation_mode { + ValidationMode::First(n) => format!("First({})", n), + ValidationMode::Each => "Each".to_string(), + _ => panic!("Unexpected validation mode"), + }; + let compression = match compression { + Compression::None => "none", + Compression::Lz4 => "lz4", + _ => panic!("Unexpected compression mode"), + }; + let access = if T::ACCESS_TYPE.is_empty() { + "" + } else { + let access_type = T::ACCESS_TYPE; + &format!(" {access_type:>6}") + }; + println!("{compression:>8} {validation_mode:>10} {elapsed:>9.3?} {throughput_mbytes_sec:>4.0} MiB/s {received_mbytes:>4.0} MiB{access}"); +} + +pub(crate) async fn fetch_cursor<'a, T: BenchmarkRow<'a>>( + compression: Compression, + validation_mode: ValidationMode, + query: &str, +) -> RowCursor { + let client = Client::default() + .with_compression(compression) + .with_validation_mode(validation_mode) + .with_url("http://localhost:8123"); + client.query(query).fetch::().unwrap() +} + +pub(crate) async fn do_select_bench<'a, T: BenchmarkRow<'a>>( + query: &str, + compression: Compression, + validation_mode: ValidationMode, +) -> BenchmarkStats { + let start = Instant::now(); + let mut cursor = fetch_cursor::(compression, validation_mode, query).await; + + let mut sum = 0; + while let Some(row) = cursor.next().await.unwrap() { + sum += row.id(); + black_box(&row); + } + + BenchmarkStats::new(&cursor, &start, sum) +} + +pub(crate) struct BenchmarkStats { + pub(crate) throughput_mbytes_sec: f64, + pub(crate) decoded_mbytes: f64, + pub(crate) received_mbytes: f64, + pub(crate) elapsed: Duration, + // RustRover is unhappy with pub(crate) + pub result: R, +} + +impl BenchmarkStats { + pub(crate) fn new(cursor: &RowCursor, start: &Instant, result: R) -> Self { + let elapsed = start.elapsed(); + let dec_bytes = cursor.decoded_bytes(); + let decoded_mbytes = dec_bytes as f64 / 1024.0 / 1024.0; + let recv_bytes = cursor.received_bytes(); + let received_mbytes = recv_bytes as f64 / 1024.0 / 1024.0; + let throughput_mbytes_sec = decoded_mbytes / elapsed.as_secs_f64(); + BenchmarkStats { + throughput_mbytes_sec, + decoded_mbytes, + received_mbytes, + elapsed, + result, + } + } +} diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index 517044fb..2cc98aab 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -1,56 +1,39 @@ use serde::Deserialize; +use crate::common_select::{ + do_select_bench, print_header, print_results, BenchmarkRow, WithAccessType, WithId, +}; use clickhouse::validation_mode::ValidationMode; -use clickhouse::{Client, Compression, Row}; +use clickhouse::{Compression, Row}; + +mod common_select; #[derive(Row, Deserialize)] struct Data { - #[serde(rename = "number")] - no: u64, -} - -async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) { - let start = std::time::Instant::now(); - let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression, validation_mode)) - .await - .unwrap(); - assert_eq!(sum, 124999999750000000); - let elapsed = start.elapsed(); - let throughput = dec_mbytes / elapsed.as_secs_f64(); - println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB"); + number: u64, } -async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) { - let client = Client::default() - .with_compression(compression) - .with_validation_mode(validation_mode) - .with_url("http://localhost:8123"); - - let mut cursor = client - .query("SELECT number FROM system.numbers_mt LIMIT 500000000") - .fetch::() - .unwrap(); - - let mut sum = 0; - while let Some(row) = cursor.next().await.unwrap() { - sum += row.no; - } - - let dec_bytes = cursor.decoded_bytes(); - let dec_mbytes = dec_bytes as f64 / 1024.0 / 1024.0; - let recv_bytes = cursor.received_bytes(); - let recv_mbytes = recv_bytes as f64 / 1024.0 / 1024.0; - (sum, dec_mbytes, recv_mbytes) +impl_benchmark_row_no_access_type!(Data, number); + +async fn bench(compression: Compression, validation_mode: ValidationMode) { + let stats = do_select_bench::( + "SELECT number FROM system.numbers_mt LIMIT 500000000", + compression, + validation_mode, + ) + .await; + assert_eq!(stats.result, 124999999750000000); + print_results::(&stats, compression, validation_mode); } #[tokio::main] async fn main() { - println!("compress validation elapsed throughput received"); - bench("none", Compression::None, ValidationMode::First(1)).await; - bench("none", Compression::None, ValidationMode::Each).await; + print_header(None); + bench(Compression::None, ValidationMode::First(1)).await; + bench(Compression::None, ValidationMode::Each).await; #[cfg(feature = "lz4")] { - bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; - bench("lz4", Compression::Lz4, ValidationMode::Each).await; + bench(Compression::Lz4, ValidationMode::First(1)).await; + bench(Compression::Lz4, ValidationMode::Each).await; } } diff --git a/benches/select_nyc_taxi_data.rs b/benches/select_nyc_taxi_data.rs index d3c449a9..1bfb6cf1 100644 --- a/benches/select_nyc_taxi_data.rs +++ b/benches/select_nyc_taxi_data.rs @@ -1,12 +1,16 @@ #![cfg(feature = "time")] +use crate::common_select::{ + do_select_bench, print_header, print_results, BenchmarkRow, WithAccessType, WithId, +}; use clickhouse::validation_mode::ValidationMode; -use clickhouse::{Client, Compression, Row}; -use criterion::black_box; +use clickhouse::{Compression, Row}; use serde::Deserialize; use serde_repr::Deserialize_repr; use time::OffsetDateTime; +mod common_select; + #[derive(Debug, Clone, Deserialize_repr)] #[repr(i8)] pub enum PaymentType { @@ -17,9 +21,10 @@ pub enum PaymentType { UNK = 5, } -#[derive(Debug, Clone, Row, Deserialize)] +/// Uses just `visit_seq` since the order of the fields matches the database schema. +#[derive(Row, Deserialize)] #[allow(dead_code)] -pub struct TripSmall { +struct TripSmallSeqAccess { trip_id: u32, #[serde(with = "clickhouse::serde::time::datetime")] pickup_datetime: OffsetDateTime, @@ -41,44 +46,53 @@ pub struct TripSmall { dropoff_ntaname: String, } -async fn bench(name: &str, compression: Compression, validation_mode: ValidationMode) { - let start = std::time::Instant::now(); - let (sum_trip_ids, dec_mbytes, rec_mbytes) = do_bench(compression, validation_mode).await; - assert_eq!(sum_trip_ids, 3630387815532582); - let elapsed = start.elapsed(); - let throughput = dec_mbytes / elapsed.as_secs_f64(); - println!("{name:>8} {validation_mode:>10} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB"); +/// Uses `visit_map` to deserialize instead of `visit_seq`, +/// since the fields definition is correct, but the order is wrong. +#[derive(Row, Deserialize)] +#[allow(dead_code)] +struct TripSmallMapAccess { + pickup_ntaname: String, + dropoff_ntaname: String, + trip_id: u32, + passenger_count: u8, + trip_distance: f32, + fare_amount: f32, + extra: f32, + tip_amount: f32, + tolls_amount: f32, + total_amount: f32, + payment_type: PaymentType, + #[serde(with = "clickhouse::serde::time::datetime")] + pickup_datetime: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime")] + dropoff_datetime: OffsetDateTime, + pickup_longitude: Option, + pickup_latitude: Option, + dropoff_longitude: Option, + dropoff_latitude: Option, } -async fn do_bench(compression: Compression, validation_mode: ValidationMode) -> (u64, f64, f64) { - let client = Client::default() - .with_compression(compression) - .with_validation_mode(validation_mode) - .with_url("http://localhost:8123"); - - let mut cursor = client - .query("SELECT * FROM nyc_taxi.trips_small ORDER BY trip_id DESC") - .fetch::() - .unwrap(); - - let mut sum = 0; - while let Some(row) = cursor.next().await.unwrap() { - sum += row.trip_id as u64; - black_box(&row); - } +impl_benchmark_row!(TripSmallSeqAccess, trip_id, "seq"); +impl_benchmark_row!(TripSmallMapAccess, trip_id, "map"); - let dec_bytes = cursor.decoded_bytes(); - let dec_mbytes = dec_bytes as f64 / 1024.0 / 1024.0; - let recv_bytes = cursor.received_bytes(); - let recv_mbytes = recv_bytes as f64 / 1024.0 / 1024.0; - (sum, dec_mbytes, recv_mbytes) +async fn bench<'a, T: BenchmarkRow<'a>>(compression: Compression, validation_mode: ValidationMode) { + let stats = do_select_bench::( + "SELECT * FROM nyc_taxi.trips_small ORDER BY trip_id DESC", + compression, + validation_mode, + ) + .await; + assert_eq!(stats.result, 3630387815532582); + print_results::(&stats, compression, validation_mode); } #[tokio::main] async fn main() { - println!("compress validation elapsed throughput received"); - bench("none", Compression::None, ValidationMode::First(1)).await; - bench("lz4", Compression::Lz4, ValidationMode::First(1)).await; - bench("none", Compression::None, ValidationMode::Each).await; - bench("lz4", Compression::Lz4, ValidationMode::Each).await; + print_header(Some(" access")); + bench::(Compression::None, ValidationMode::First(1)).await; + bench::(Compression::Lz4, ValidationMode::First(1)).await; + bench::(Compression::None, ValidationMode::Each).await; + bench::(Compression::Lz4, ValidationMode::Each).await; + bench::(Compression::None, ValidationMode::Each).await; + bench::(Compression::Lz4, ValidationMode::Each).await; } diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 2ef93988..982a08c2 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -61,7 +61,7 @@ impl RowCursor { } Err(TypesError::NotEnoughData(_)) => {} Err(err) => { - return Err(Error::ColumnsHeaderParserError(err.into())); + return Err(Error::InvalidColumnsHeader(err.into())); } } } diff --git a/src/error.rs b/src/error.rs index b47901e0..8b1c1dee 100644 --- a/src/error.rs +++ b/src/error.rs @@ -42,7 +42,7 @@ pub enum Error { #[error("timeout expired")] TimedOut, #[error("error while parsing columns header from the response: {0}")] - ColumnsHeaderParserError(#[source] BoxedError), + InvalidColumnsHeader(#[source] BoxedError), #[error("{0}")] Other(BoxedError), } @@ -51,7 +51,7 @@ assert_impl_all!(Error: StdError, Send, Sync); impl From for Error { fn from(err: clickhouse_types::error::TypesError) -> Self { - Self::ColumnsHeaderParserError(Box::new(err)) + Self::InvalidColumnsHeader(Box::new(err)) } } diff --git a/src/row.rs b/src/row.rs index 7418dacb..5dca119a 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,6 +1,6 @@ use crate::sql; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum RowType { Primitive, Struct, diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 7c3a50e9..65fd7b45 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,5 +1,6 @@ use crate::error::Result; use crate::struct_metadata::StructMetadata; +use crate::RowType; use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; @@ -22,7 +23,6 @@ pub(crate) struct DataTypeValidator<'cursor> { } impl<'cursor> DataTypeValidator<'cursor> { - #[inline(always)] pub(crate) fn new(metadata: &'cursor StructMetadata) -> Self { Self { current_column_idx: 0, @@ -52,28 +52,52 @@ impl<'cursor> DataTypeValidator<'cursor> { .unwrap_or(("Struct".to_string(), &DataTypeNode::Bool)) } - #[inline(always)] fn panic_on_schema_mismatch<'de>( &'de self, data_type: &DataTypeNode, serde_type: &SerdeType, is_inner: bool, ) -> Result>> { - if is_inner { - let (full_name, full_data_type) = self.get_current_column_name_and_type(); - panic!( - "While processing column {} defined as {}: attempting to deserialize \ - nested ClickHouse type {} as {} which is not compatible", - full_name, full_data_type, data_type, serde_type - ) - } else { - panic!( - "While processing column {}: attempting to deserialize \ - ClickHouse type {} as {} which is not compatible", - self.get_current_column_name_and_type().0, - data_type, - serde_type - ) + match self.metadata.row_type { + RowType::Primitive => { + panic!( + "While processing row as a primitive: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + data_type, serde_type + ) + } + RowType::Vec => { + panic!( + "While processing row as a vector: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + data_type, serde_type + ) + } + RowType::Tuple => { + panic!( + "While processing row as a tuple: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + data_type, serde_type + ) + } + RowType::Struct => { + if is_inner { + let (full_name, full_data_type) = self.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: attempting to deserialize \ + nested ClickHouse type {} as {} which is not compatible", + full_name, full_data_type, data_type, serde_type + ) + } else { + panic!( + "While processing column {}: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + self.get_current_column_name_and_type().0, + data_type, + serde_type + ) + } + } } } } @@ -84,40 +108,85 @@ impl SchemaValidator for DataTypeValidator<'_> { &'_ mut self, serde_type: SerdeType, ) -> Result>> { - if self.current_column_idx == 0 && !self.metadata.is_struct() { - // this allows validating and deserializing tuples/vectors/primitives from fetch calls - Ok(Some(InnerDataTypeValidator { - root: self, - kind: if matches!(serde_type, SerdeType::Seq(_)) && self.metadata.columns.len() == 1 - { + match self.metadata.row_type { + // fetch::() for a "primitive row" type + RowType::Primitive => { + if self.current_column_idx == 0 && self.metadata.columns.len() == 1 { let data_type = &self.metadata.columns[0].data_type; - match data_type { - DataTypeNode::Array(inner_type) => { - InnerDataTypeValidatorKind::RootArray(inner_type) - } - _ => panic!( - "Expected Array type when validating root level sequence, but got {}", - self.metadata.columns[0].data_type - ), + validate_impl(self, data_type, &serde_type, false) + } else { + panic!( + "Primitive row is expected to be a single value, got columns: {:?}", + self.metadata.columns + ); + } + } + // fetch::<(i16, i32)>() for a "tuple row" type + RowType::Tuple => { + match serde_type { + SerdeType::Tuple(len) if len == self.metadata.columns.len() => { + Ok(Some(InnerDataTypeValidator { + root: self, + kind: InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0), + })) } + SerdeType::Tuple(len) => { + // TODO: theoretically, we can derive that from the Row macro, + // and check when creating StructMetadata + panic!( + "While processing tuple row: database schema has {} columns, \ + but the tuple definition has {} fields.", + self.metadata.columns.len(), + len + ) + } + _ => { + // should be unreachable + panic!( + "While processing tuple row: expected serde type Tuple(N), got {}", + serde_type + ); + } + } + } + // fetch::>() for a "vector row" type + RowType::Vec => { + let data_type = &self.metadata.columns[0].data_type; + let kind = match data_type { + DataTypeNode::Array(inner_type) => { + InnerDataTypeValidatorKind::RootArray(inner_type) + } + _ => panic!( + "Expected Array type when validating root level sequence, but got {}", + self.metadata.columns[0].data_type + ), + }; + Ok(Some(InnerDataTypeValidator { root: self, kind })) + } + // fetch::() for a "struct row" type, which is supposed to be the default flow + RowType::Struct => { + if self.current_column_idx < self.metadata.columns.len() { + let current_column = &self.metadata.columns[self.current_column_idx]; + self.current_column_idx += 1; + validate_impl(self, ¤t_column.data_type, &serde_type, false) } else { - InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0) - }, - })) - } else if self.current_column_idx < self.metadata.columns.len() { - let current_column = &self.metadata.columns[self.current_column_idx]; - self.current_column_idx += 1; - validate_impl(self, ¤t_column.data_type, &serde_type, false) - } else { - panic!( - "Struct {} has more fields than columns in the database schema", - self.metadata.struct_name - ) + panic!( + "Struct {} has more fields than columns in the database schema", + self.metadata.struct_name + ) + } + } } } + #[inline] fn is_field_order_wrong(&self) -> bool { - true + self.metadata.is_field_order_wrong() + } + + #[inline] + fn get_schema_index(&self, struct_idx: usize) -> usize { + self.metadata.get_schema_index(struct_idx) } #[cold] @@ -137,11 +206,6 @@ impl SchemaValidator for DataTypeValidator<'_> { fn set_next_variant_value(&mut self, _value: u8) { unreachable!() } - - #[inline] - fn get_schema_index(&self, struct_idx: usize) -> usize { - self.metadata.get_schema_index(struct_idx) - } } #[derive(Debug)] diff --git a/src/struct_metadata.rs b/src/struct_metadata.rs index 1075f91a..431c0eeb 100644 --- a/src/struct_metadata.rs +++ b/src/struct_metadata.rs @@ -47,61 +47,85 @@ pub struct StructMetadata { impl StructMetadata { // FIXME: perhaps it should not be public? But it is required for mocks/provide. pub fn new(columns: Vec) -> Self { - let struct_name = T::NAME; - let struct_fields = T::COLUMN_NAMES; - if columns.len() != struct_fields.len() { - panic!( - "While processing struct {}: database schema has {} columns, \ - but the struct definition has {} fields.\ - \n#### All struct fields:\n{}\n#### All schema columns:\n{}", - struct_name, - columns.len(), - struct_fields.len(), - join_panic_schema_hint(struct_fields), - join_panic_schema_hint(&columns), - ); - } - let mut mapping = Vec::with_capacity(struct_fields.len()); - let mut expected_index = 0; - let mut should_use_map = false; - for col in &columns { - if let Some(index) = struct_fields.iter().position(|field| col.name == *field) { - if index != expected_index { - should_use_map = true + let access_type = match T::TYPE { + RowType::Primitive => { + if columns.len() != 1 { + panic!( + "While processing a primitive row: \ + expected only 1 column in the database schema, \ + but got {} instead.\n#### All schema columns:\n{}", + columns.len(), + join_panic_schema_hint(&columns), + ); } - expected_index += 1; - mapping.push(index); - } else { - panic!( - "While processing struct {}: database schema has a column {} \ - that was not found in the struct definition.\ - \n#### All struct fields:\n{}\n#### All schema columns:\n{}", - struct_name, - col, - join_panic_schema_hint(struct_fields), - join_panic_schema_hint(&columns), - ); + AccessType::WithSeqAccess } - } + RowType::Tuple => AccessType::WithSeqAccess, + RowType::Vec => { + if columns.len() != 1 { + panic!( + "While processing a row defined as a vector: \ + expected only 1 column in the database schema, \ + but got {} instead.\n#### All schema columns:\n{}", + columns.len(), + join_panic_schema_hint(&columns), + ); + } + AccessType::WithSeqAccess + } + RowType::Struct => { + if columns.len() != T::COLUMN_NAMES.len() { + panic!( + "While processing struct {}: database schema has {} columns, \ + but the struct definition has {} fields.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + T::NAME, + columns.len(), + T::COLUMN_NAMES.len(), + join_panic_schema_hint(T::COLUMN_NAMES), + join_panic_schema_hint(&columns), + ); + } + let mut mapping = Vec::with_capacity(T::COLUMN_NAMES.len()); + let mut expected_index = 0; + let mut should_use_map = false; + for col in &columns { + if let Some(index) = T::COLUMN_NAMES.iter().position(|field| col.name == *field) + { + if index != expected_index { + should_use_map = true + } + expected_index += 1; + mapping.push(index); + } else { + panic!( + "While processing struct {}: database schema has a column {} \ + that was not found in the struct definition.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + T::NAME, + col, + join_panic_schema_hint(T::COLUMN_NAMES), + join_panic_schema_hint(&columns), + ); + } + } + if should_use_map { + AccessType::WithMapAccess(mapping) + } else { + AccessType::WithSeqAccess + } + } + }; Self { columns, - struct_name, - struct_fields, + access_type, row_type: T::TYPE, - access_type: if should_use_map { - AccessType::WithMapAccess(mapping) - } else { - AccessType::WithSeqAccess - }, + struct_name: T::NAME, + struct_fields: T::COLUMN_NAMES, } } - #[inline(always)] - pub(crate) fn is_struct(&self) -> bool { - matches!(self.row_type, RowType::Struct) - } - - #[inline(always)] + #[inline] pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize { match &self.access_type { AccessType::WithMapAccess(mapping) => { @@ -117,6 +141,11 @@ impl StructMetadata { AccessType::WithSeqAccess => struct_idx, // should be unreachable } } + + #[inline] + pub(crate) fn is_field_order_wrong(&self) -> bool { + matches!(self.access_type, AccessType::WithMapAccess(_)) + } } pub(crate) async fn get_struct_metadata( diff --git a/src/validation_mode.rs b/src/validation_mode.rs index 1755bd3f..a76d7d1e 100644 --- a/src/validation_mode.rs +++ b/src/validation_mode.rs @@ -1,5 +1,3 @@ -#[non_exhaustive] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] /// The preferred mode of validation for struct (de)serialization. /// It also affects which format is used by the client when sending queries. /// @@ -17,13 +15,13 @@ /// It is done to minimize the performance impact of the validation, /// while still providing reasonable safety guarantees by default. /// -/// # Safety -/// /// While it is expected that the default validation mode is sufficient for most use cases, /// in certain corner case scenarios there still can be schema mismatches after the first rows, /// e.g., when a field is `Nullable(T)`, and the first value is `NULL`. In that case, /// consider increasing the number of rows in [`ValidationMode::First`], /// or even using [`ValidationMode::Each`] instead. +#[non_exhaustive] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum ValidationMode { First(usize), Each, @@ -34,12 +32,3 @@ impl Default for ValidationMode { Self::First(1) } } - -impl std::fmt::Display for ValidationMode { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::First(n) => f.pad(&format!("FirstN({})", n)), - Self::Each => f.pad("Each"), - } - } -} diff --git a/tests/it/main.rs b/tests/it/main.rs index 154a4432..6bbe41b6 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -47,7 +47,11 @@ macro_rules! assert_panic_on_fetch { let async_panic = std::panic::AssertUnwindSafe(async { client.query($query).fetch_all::().await }); let result = async_panic.catch_unwind().await; - assert!(result.is_err()); + assert!( + result.is_err(), + "expected a panic, but got a result instead: {:?}", + result.unwrap() + ); let panic_msg = *result.unwrap_err().downcast::().unwrap(); for &msg in $msg_parts { assert!( diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 8c48a6ba..c1c7613e 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -103,6 +103,96 @@ async fn test_header_parsing() { ); } +#[tokio::test] +async fn test_fetch_primitive_row() { + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query("SELECT count() FROM (SELECT * FROM system.numbers LIMIT 3)") + .fetch_one::() + .await; + assert_eq!(result.unwrap(), 3); +} + +#[tokio::test] +async fn test_fetch_primitive_row_schema_mismatch() { + type Data = i32; // expected type is UInt64 + assert_panic_on_fetch!( + &["primitive", "UInt64", "i32"], + "SELECT count() FROM (SELECT * FROM system.numbers LIMIT 3)" + ); +} + +#[tokio::test] +async fn test_fetch_vector_row() { + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query("SELECT [1, 2, 3] :: Array(UInt32)") + .fetch_one::>() + .await; + assert_eq!(result.unwrap(), vec![1, 2, 3]); +} + +#[tokio::test] +async fn test_fetch_vector_row_schema_mismatch_nested_type() { + type Data = Vec; // expected type for Array(UInt32) is Vec + assert_panic_on_fetch!( + &["vector", "UInt32", "i128"], + "SELECT [1, 2, 3] :: Array(UInt32)" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row() { + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b") + .fetch_one::<(u32, String)>() + .await; + assert_eq!(result.unwrap(), (42, "foo".to_string())); +} + +#[tokio::test] +async fn test_fetch_tuple_row_schema_mismatch_first_element() { + type Data = (i128, String); // expected u32 instead of i128 + assert_panic_on_fetch!( + &["tuple", "UInt32", "i128"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_schema_mismatch_second_element() { + type Data = (u32, i64); // expected String instead of i64 + assert_panic_on_fetch!( + &["tuple", "String", "i64"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_schema_mismatch_missing_element() { + type Data = (u32, String); // expected to have the third element as i64 + assert_panic_on_fetch!( + &[ + "database schema has 3 columns", + "tuple definition has 2 fields" + ], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: Int64 AS c" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_schema_mismatch_too_many_elements() { + type Data = (u32, String, i128); // i128 should not be there + assert_panic_on_fetch!( + &[ + "database schema has 2 columns", + "tuple definition has 3 fields" + ], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b" + ); +} + #[tokio::test] async fn test_basic_types() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] diff --git a/types/Cargo.toml b/types/Cargo.toml index 0f0ac2bd..b9576b54 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "clickhouse-types" -version = "0.0.1" +version = "0.1.0" description = "Data types utils to use with Native and RowBinary(WithNamesAndTypes) formats in ClickHouse" authors = ["ClickHouse"] repository = "https://github.com/ClickHouse/clickhouse-rs" From a922d0d518b9eb6feb232c56f5441a0a498d6f54 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 00:45:09 +0200 Subject: [PATCH 28/54] Add LZ4 feature flag --- benches/common_select.rs | 1 + benches/select_nyc_taxi_data.rs | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/benches/common_select.rs b/benches/common_select.rs index 0d653e14..54f183d1 100644 --- a/benches/common_select.rs +++ b/benches/common_select.rs @@ -72,6 +72,7 @@ pub(crate) fn print_results<'a, T: BenchmarkRow<'a>>( }; let compression = match compression { Compression::None => "none", + #[cfg(feature = "lz4")] Compression::Lz4 => "lz4", _ => panic!("Unexpected compression mode"), }; diff --git a/benches/select_nyc_taxi_data.rs b/benches/select_nyc_taxi_data.rs index 1bfb6cf1..b6e96baf 100644 --- a/benches/select_nyc_taxi_data.rs +++ b/benches/select_nyc_taxi_data.rs @@ -90,9 +90,12 @@ async fn bench<'a, T: BenchmarkRow<'a>>(compression: Compression, validation_mod async fn main() { print_header(Some(" access")); bench::(Compression::None, ValidationMode::First(1)).await; - bench::(Compression::Lz4, ValidationMode::First(1)).await; bench::(Compression::None, ValidationMode::Each).await; - bench::(Compression::Lz4, ValidationMode::Each).await; bench::(Compression::None, ValidationMode::Each).await; - bench::(Compression::Lz4, ValidationMode::Each).await; + #[cfg(feature = "lz4")] + { + bench::(Compression::Lz4, ValidationMode::First(1)).await; + bench::(Compression::Lz4, ValidationMode::Each).await; + bench::(Compression::Lz4, ValidationMode::Each).await; + } } From 90132cb785660968485999978c2f499a545a4170 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 01:41:32 +0200 Subject: [PATCH 29/54] Support proper validation for `(Row, P1, P2, ...)` fetching --- derive/src/lib.rs | 3 +- examples/mock.rs | 2 +- src/cursors/row.rs | 16 ++-- src/lib.rs | 6 +- src/row.rs | 31 ++++--- src/{struct_metadata.rs => row_metadata.rs} | 69 ++++++++------- src/rowbinary/de.rs | 4 +- src/rowbinary/validation.rs | 65 ++++++-------- src/test/handlers.rs | 9 +- src/watch.rs | 6 +- tests/it/mock.rs | 2 +- tests/it/rbwnat.rs | 95 +++++++++++++++++++++ 12 files changed, 202 insertions(+), 106 deletions(-) rename src/{struct_metadata.rs => row_metadata.rs} (78%) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 3988941d..5b6ceb92 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -60,7 +60,8 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { impl #impl_generics clickhouse::Row for #name #ty_generics #where_clause { const NAME: &'static str = stringify!(#name); const COLUMN_NAMES: &'static [&'static str] = #column_names; - const TYPE: clickhouse::RowType = clickhouse::RowType::Struct; + const COLUMN_COUNT: usize = ::COLUMN_NAMES.len(); + const KIND: clickhouse::RowKind = clickhouse::RowKind::Struct; } }; diff --git a/examples/mock.rs b/examples/mock.rs index 57452179..ca961f32 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -57,7 +57,7 @@ async fn main() { assert!(recording.query().await.contains("CREATE TABLE")); let metadata = - clickhouse::StructMetadata::new::(vec![Column::new("no".to_string(), UInt32)]); + clickhouse::RowMetadata::new::(vec![Column::new("no".to_string(), UInt32)]); // How to test SELECT. mock.add(test::handlers::provide(&metadata, list.clone())); diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 982a08c2..f66269bf 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,4 +1,4 @@ -use crate::struct_metadata::StructMetadata; +use crate::row_metadata::RowMetadata; use crate::validation_mode::ValidationMode; use crate::{ bytes_ext::BytesExt, @@ -19,7 +19,7 @@ pub struct RowCursor { bytes: BytesExt, /// [`None`] until the first call to [`RowCursor::next()`], /// as [`RowCursor::new`] is not `async`, so it loads lazily. - struct_metadata: Option, + row_metadata: Option, rows_to_validate: u64, _marker: PhantomData, } @@ -30,7 +30,7 @@ impl RowCursor { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), - struct_metadata: None, + row_metadata: None, rows_to_validate: match validation_mode { ValidationMode::First(n) => n as u64, ValidationMode::Each => u64::MAX, @@ -50,7 +50,7 @@ impl RowCursor { match parse_rbwnat_columns_header(&mut slice) { Ok(columns) if !columns.is_empty() => { self.bytes.set_remaining(slice.len()); - self.struct_metadata = Some(StructMetadata::new::(columns)); + self.row_metadata = Some(RowMetadata::new::(columns)); return Ok(()); } Ok(_) => { @@ -67,7 +67,7 @@ impl RowCursor { } match self.raw.next().await? { Some(chunk) => self.bytes.extend(chunk), - None if self.struct_metadata.is_none() => { + None if self.row_metadata.is_none() => { return Err(Error::BadResponse( "Could not read columns header".to_string(), )); @@ -91,7 +91,7 @@ impl RowCursor { { loop { if self.bytes.remaining() > 0 { - if self.struct_metadata.is_none() { + if self.row_metadata.is_none() { self.read_columns().await?; if self.bytes.remaining() == 0 { continue; @@ -101,12 +101,12 @@ impl RowCursor { let (result, not_enough_data) = match self.rows_to_validate { 0 => rowbinary::deserialize_from::(&mut slice, None), u64::MAX => { - rowbinary::deserialize_from::(&mut slice, self.struct_metadata.as_ref()) + rowbinary::deserialize_from::(&mut slice, self.row_metadata.as_ref()) } _ => { let result = rowbinary::deserialize_from::( &mut slice, - self.struct_metadata.as_ref(), + self.row_metadata.as_ref(), ); self.rows_to_validate -= 1; result diff --git a/src/lib.rs b/src/lib.rs index 3767480d..d3760a89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,8 +6,8 @@ extern crate static_assertions; #[cfg(feature = "test-util")] -pub use self::struct_metadata::StructMetadata; -pub use self::{compression::Compression, row::Row, row::RowType}; +pub use self::row_metadata::RowMetadata; +pub use self::{compression::Compression, row::Row, row::RowKind}; use self::{error::Result, http_client::HttpClient, validation_mode::ValidationMode}; pub use clickhouse_derive::Row; use std::{collections::HashMap, fmt::Display, sync::Arc}; @@ -33,8 +33,8 @@ mod http_client; mod request_body; mod response; mod row; +mod row_metadata; mod rowbinary; -mod struct_metadata; #[cfg(feature = "inserter")] mod ticks; diff --git a/src/row.rs b/src/row.rs index 5dca119a..d591ebf9 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,7 +1,7 @@ use crate::sql; #[derive(Debug, Clone, PartialEq)] -pub enum RowType { +pub enum RowKind { Primitive, Struct, Tuple, @@ -11,7 +11,8 @@ pub enum RowType { pub trait Row { const NAME: &'static str; const COLUMN_NAMES: &'static [&'static str]; - const TYPE: RowType; + const COLUMN_COUNT: usize; + const KIND: RowKind; // TODO: count // TODO: different list for SELECT/INSERT (de/ser) @@ -33,18 +34,24 @@ impl_primitive_for![ bool, String, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, ]; +macro_rules! count_tokens { + () => { 0 }; + ($head:tt $($tail:tt)*) => { 1 + count_tokens!($($tail)*) }; +} + +/// Two forms are supported: +/// * (P1, P2, ...) +/// * (SomeRow, P1, P2, ...) +/// +/// The second one is useful for queries like +/// `SELECT ?fields, count() FROM ... GROUP BY ?fields`. macro_rules! impl_row_for_tuple { ($i:ident $($other:ident)+) => { - /// Two forms are supported: - /// * (P1, P2, ...) - /// * (SomeRow, P1, P2, ...) - /// - /// The second one is useful for queries like - /// `SELECT ?fields, count() FROM .. GROUP BY ?fields`. impl<$i: Row, $($other: Primitive),+> Row for ($i, $($other),+) { const NAME: &'static str = $i::NAME; const COLUMN_NAMES: &'static [&'static str] = $i::COLUMN_NAMES; - const TYPE: RowType = RowType::Tuple; + const COLUMN_COUNT: usize = $i::COLUMN_COUNT + count_tokens!($($other)*); + const KIND: RowKind = RowKind::Tuple; } impl_row_for_tuple!($($other)+); @@ -58,7 +65,8 @@ impl Primitive for () {} impl Row for P { const NAME: &'static str = stringify!(P); const COLUMN_NAMES: &'static [&'static str] = &[]; - const TYPE: RowType = RowType::Primitive; + const COLUMN_COUNT: usize = 1; + const KIND: RowKind = RowKind::Primitive; } impl_row_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8); @@ -66,7 +74,8 @@ impl_row_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8); impl Row for Vec { const NAME: &'static str = "Vec"; const COLUMN_NAMES: &'static [&'static str] = &[]; - const TYPE: RowType = RowType::Vec; + const COLUMN_COUNT: usize = 1; + const KIND: RowKind = RowKind::Vec; } /// Collects all field names in depth and joins them with comma. diff --git a/src/struct_metadata.rs b/src/row_metadata.rs similarity index 78% rename from src/struct_metadata.rs rename to src/row_metadata.rs index 431c0eeb..67047fc2 100644 --- a/src/struct_metadata.rs +++ b/src/row_metadata.rs @@ -4,7 +4,7 @@ #![allow(dead_code)] #![allow(unreachable_pub)] -use crate::row::RowType; +use crate::row::RowKind; use crate::sql::Identifier; use crate::Result; use crate::Row; @@ -14,10 +14,10 @@ use std::fmt::Display; use std::sync::Arc; use tokio::sync::{OnceCell, RwLock}; -/// Cache for [`StructMetadata`] to avoid allocating it for the same struct more than once +/// Cache for [`RowMetadata`] to avoid allocating it for the same struct more than once /// during the application lifecycle. Key: fully qualified table name (e.g. `database.table`). -type LockedStructMetadataCache = RwLock>>; -static STRUCT_METADATA_CACHE: OnceCell = OnceCell::const_new(); +type LockedRowMetadataCache = RwLock>>; +static ROW_METADATA_CACHE: OnceCell = OnceCell::const_new(); #[derive(Debug, PartialEq)] enum AccessType { @@ -25,16 +25,14 @@ enum AccessType { WithMapAccess(Vec), } -/// [`StructMetadata`] should be owned outside the (de)serializer, +/// [`RowMetadata`] should be owned outside the (de)serializer, /// as it is calculated only once per struct. It does not have lifetimes, /// so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. -pub struct StructMetadata { +pub struct RowMetadata { /// See [`Row::NAME`] - pub(crate) struct_name: &'static str, - /// See [`Row::COLUMN_NAMES`] (currently unused) - pub(crate) struct_fields: &'static [&'static str], + pub(crate) name: &'static str, /// See [`Row::TYPE`] - pub(crate) row_type: RowType, + pub(crate) kind: RowKind, /// Database schema, or columns, are parsed before the first call to (de)serializer. pub(crate) columns: Vec, /// This determines whether we can just use [`crate::rowbinary::de::RowBinarySeqAccess`] @@ -44,11 +42,11 @@ pub struct StructMetadata { access_type: AccessType, } -impl StructMetadata { +impl RowMetadata { // FIXME: perhaps it should not be public? But it is required for mocks/provide. pub fn new(columns: Vec) -> Self { - let access_type = match T::TYPE { - RowType::Primitive => { + let access_type = match T::KIND { + RowKind::Primitive => { if columns.len() != 1 { panic!( "While processing a primitive row: \ @@ -58,10 +56,22 @@ impl StructMetadata { join_panic_schema_hint(&columns), ); } - AccessType::WithSeqAccess + AccessType::WithSeqAccess // ignored } - RowType::Tuple => AccessType::WithSeqAccess, - RowType::Vec => { + RowKind::Tuple => { + if T::COLUMN_COUNT != columns.len() { + panic!( + "While processing a tuple row: database schema has {} columns, \ + but the tuple definition has {} fields in total.\ + \n#### All schema columns:\n{}", + columns.len(), + T::COLUMN_COUNT, + join_panic_schema_hint(&columns), + ); + } + AccessType::WithSeqAccess // ignored + } + RowKind::Vec => { if columns.len() != 1 { panic!( "While processing a row defined as a vector: \ @@ -71,9 +81,9 @@ impl StructMetadata { join_panic_schema_hint(&columns), ); } - AccessType::WithSeqAccess + AccessType::WithSeqAccess // ignored } - RowType::Struct => { + RowKind::Struct => { if columns.len() != T::COLUMN_NAMES.len() { panic!( "While processing struct {}: database schema has {} columns, \ @@ -119,9 +129,8 @@ impl StructMetadata { Self { columns, access_type, - row_type: T::TYPE, - struct_name: T::NAME, - struct_fields: T::COLUMN_NAMES, + kind: T::KIND, + name: T::NAME, } } @@ -134,7 +143,7 @@ impl StructMetadata { } else { panic!( "Struct {} has more fields than columns in the database schema", - self.struct_name + self.name ) } } @@ -148,27 +157,27 @@ impl StructMetadata { } } -pub(crate) async fn get_struct_metadata( +pub(crate) async fn get_row_metadata( client: &crate::Client, table_name: &str, -) -> Result> { - let locked_cache = STRUCT_METADATA_CACHE +) -> Result> { + let locked_cache = ROW_METADATA_CACHE .get_or_init(|| async { RwLock::new(HashMap::new()) }) .await; let cache_guard = locked_cache.read().await; match cache_guard.get(table_name) { Some(metadata) => Ok(metadata.clone()), - None => cache_struct_metadata::(client, table_name, locked_cache).await, + None => cache_row_metadata::(client, table_name, locked_cache).await, } } /// Used internally to introspect and cache the table structure to allow validation /// of serialized rows before submitting the first [`insert::Insert::write`]. -async fn cache_struct_metadata( +async fn cache_row_metadata( client: &crate::Client, table_name: &str, - locked_cache: &LockedStructMetadataCache, -) -> Result> { + locked_cache: &LockedRowMetadataCache, +) -> Result> { let mut bytes_cursor = client .query("SELECT * FROM ? LIMIT 0") .bind(Identifier(table_name)) @@ -179,7 +188,7 @@ async fn cache_struct_metadata( } let columns = parse_rbwnat_columns_header(&mut buffer.as_slice())?; let mut cache = locked_cache.write().await; - let metadata = Arc::new(StructMetadata::new::(columns)); + let metadata = Arc::new(RowMetadata::new::(columns)); cache.insert(table_name.to_string(), metadata.clone()); Ok(metadata) } diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 9762b8ec..06b52a79 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -1,8 +1,8 @@ use crate::error::{Error, Result}; +use crate::row_metadata::RowMetadata; use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use crate::rowbinary::validation::SerdeType; use crate::rowbinary::validation::{DataTypeValidator, SchemaValidator}; -use crate::struct_metadata::StructMetadata; use bytes::Buf; use core::mem::size_of; use serde::de::MapAccess; @@ -26,7 +26,7 @@ use std::{convert::TryFrom, str}; /// After the header, the rows format is the same as `RowBinary`. pub(crate) fn deserialize_from<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], - metadata: Option<&'cursor StructMetadata>, + metadata: Option<&'cursor RowMetadata>, ) -> (Result, bool) { let result = if metadata.is_none() { let mut deserializer = RowBinaryDeserializer::new(input, ()); diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 65fd7b45..9ba47b51 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,6 +1,6 @@ use crate::error::Result; -use crate::struct_metadata::StructMetadata; -use crate::RowType; +use crate::row_metadata::RowMetadata; +use crate::RowKind; use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; @@ -18,12 +18,12 @@ pub(crate) trait SchemaValidator: Sized { } pub(crate) struct DataTypeValidator<'cursor> { - metadata: &'cursor StructMetadata, + metadata: &'cursor RowMetadata, current_column_idx: usize, } impl<'cursor> DataTypeValidator<'cursor> { - pub(crate) fn new(metadata: &'cursor StructMetadata) -> Self { + pub(crate) fn new(metadata: &'cursor RowMetadata) -> Self { Self { current_column_idx: 0, metadata, @@ -42,12 +42,7 @@ impl<'cursor> DataTypeValidator<'cursor> { fn get_current_column_name_and_type(&self) -> (String, &DataTypeNode) { self.get_current_column() - .map(|c| { - ( - format!("{}.{}", self.metadata.struct_name, c.name), - &c.data_type, - ) - }) + .map(|c| (format!("{}.{}", self.metadata.name, c.name), &c.data_type)) // both should be defined at this point .unwrap_or(("Struct".to_string(), &DataTypeNode::Bool)) } @@ -58,29 +53,29 @@ impl<'cursor> DataTypeValidator<'cursor> { serde_type: &SerdeType, is_inner: bool, ) -> Result>> { - match self.metadata.row_type { - RowType::Primitive => { + match self.metadata.kind { + RowKind::Primitive => { panic!( "While processing row as a primitive: attempting to deserialize \ ClickHouse type {} as {} which is not compatible", data_type, serde_type ) } - RowType::Vec => { + RowKind::Vec => { panic!( "While processing row as a vector: attempting to deserialize \ ClickHouse type {} as {} which is not compatible", data_type, serde_type ) } - RowType::Tuple => { + RowKind::Tuple => { panic!( "While processing row as a tuple: attempting to deserialize \ ClickHouse type {} as {} which is not compatible", data_type, serde_type ) } - RowType::Struct => { + RowKind::Struct => { if is_inner { let (full_name, full_data_type) = self.get_current_column_name_and_type(); panic!( @@ -108,9 +103,9 @@ impl SchemaValidator for DataTypeValidator<'_> { &'_ mut self, serde_type: SerdeType, ) -> Result>> { - match self.metadata.row_type { - // fetch::() for a "primitive row" type - RowType::Primitive => { + match self.metadata.kind { + // `fetch::` for a "primitive row" type + RowKind::Primitive => { if self.current_column_idx == 0 && self.metadata.columns.len() == 1 { let data_type = &self.metadata.columns[0].data_type; validate_impl(self, data_type, &serde_type, false) @@ -121,25 +116,13 @@ impl SchemaValidator for DataTypeValidator<'_> { ); } } - // fetch::<(i16, i32)>() for a "tuple row" type - RowType::Tuple => { + // `fetch::<(i16, i32)>` or `fetch::<(T, u64)>` for a "tuple row" type + RowKind::Tuple => { match serde_type { - SerdeType::Tuple(len) if len == self.metadata.columns.len() => { - Ok(Some(InnerDataTypeValidator { - root: self, - kind: InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0), - })) - } - SerdeType::Tuple(len) => { - // TODO: theoretically, we can derive that from the Row macro, - // and check when creating StructMetadata - panic!( - "While processing tuple row: database schema has {} columns, \ - but the tuple definition has {} fields.", - self.metadata.columns.len(), - len - ) - } + SerdeType::Tuple(_) => Ok(Some(InnerDataTypeValidator { + root: self, + kind: InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0), + })), _ => { // should be unreachable panic!( @@ -149,8 +132,8 @@ impl SchemaValidator for DataTypeValidator<'_> { } } } - // fetch::>() for a "vector row" type - RowType::Vec => { + // `fetch::>` for a "vector row" type + RowKind::Vec => { let data_type = &self.metadata.columns[0].data_type; let kind = match data_type { DataTypeNode::Array(inner_type) => { @@ -163,8 +146,8 @@ impl SchemaValidator for DataTypeValidator<'_> { }; Ok(Some(InnerDataTypeValidator { root: self, kind })) } - // fetch::() for a "struct row" type, which is supposed to be the default flow - RowType::Struct => { + // `fetch::` for a "struct row" type, which is supposed to be the default flow + RowKind::Struct => { if self.current_column_idx < self.metadata.columns.len() { let current_column = &self.metadata.columns[self.current_column_idx]; self.current_column_idx += 1; @@ -172,7 +155,7 @@ impl SchemaValidator for DataTypeValidator<'_> { } else { panic!( "Struct {} has more fields than columns in the database schema", - self.metadata.struct_name + self.metadata.name ) } } diff --git a/src/test/handlers.rs b/src/test/handlers.rs index 4116b8ef..1e514110 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -8,8 +8,8 @@ use sealed::sealed; use serde::{Deserialize, Serialize}; use super::{Handler, HandlerFn}; +use crate::row_metadata::RowMetadata; use crate::rowbinary; -use crate::struct_metadata::StructMetadata; const BUFFER_INITIAL_CAPACITY: usize = 1024; @@ -42,15 +42,12 @@ pub fn failure(status: StatusCode) -> impl Handler { // === provide === #[track_caller] -pub fn provide( - struct_metadata: &StructMetadata, - rows: impl IntoIterator, -) -> impl Handler +pub fn provide(row_metadata: &RowMetadata, rows: impl IntoIterator) -> impl Handler where T: Serialize, { let mut buffer = Vec::with_capacity(BUFFER_INITIAL_CAPACITY); - put_rbwnat_columns_header(&struct_metadata.columns, &mut buffer) + put_rbwnat_columns_header(&row_metadata.columns, &mut buffer) .expect("failed to write columns header"); for row in rows { rowbinary::serialize_into(&mut buffer, &row).expect("failed to serialize"); diff --git a/src/watch.rs b/src/watch.rs index f47c8221..e3f887f1 100644 --- a/src/watch.rs +++ b/src/watch.rs @@ -155,7 +155,8 @@ struct EventPayload { impl Row for EventPayload { const NAME: &'static str = "EventPayload"; const COLUMN_NAMES: &'static [&'static str] = &[]; - const TYPE: crate::row::RowType = crate::row::RowType::Struct; + const COLUMN_COUNT: usize = 1; + const KIND: crate::row::RowKind = crate::row::RowKind::Struct; } impl EventCursor { @@ -182,7 +183,8 @@ struct RowPayload { impl Row for RowPayload { const NAME: &'static str = T::NAME; const COLUMN_NAMES: &'static [&'static str] = T::COLUMN_NAMES; - const TYPE: crate::row::RowType = T::TYPE; + const COLUMN_COUNT: usize = T::COLUMN_COUNT; + const KIND: crate::row::RowKind = T::KIND; } impl RowCursor { diff --git a/tests/it/mock.rs b/tests/it/mock.rs index 9ea31a73..2f3a5659 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -15,7 +15,7 @@ async fn test_provide() { Column::new("data".to_string(), DataTypeNode::String), ]; - let metadata = clickhouse::StructMetadata::new::(columns); + let metadata = clickhouse::RowMetadata::new::(columns); mock.add(test::handlers::provide(&metadata, &expected)); let actual = crate::fetch_rows::(&client, "doesn't matter").await; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index c1c7613e..14cef6d7 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -193,6 +193,101 @@ async fn test_fetch_tuple_row_schema_mismatch_too_many_elements() { ); } +#[tokio::test] +async fn test_fetch_tuple_row_with_struct() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: String, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c") + .fetch_one::<(Data, u64)>() + .await; + assert_eq!( + result.unwrap(), + ( + Data { + a: 42, + b: "foo".to_string() + }, + 144 + ) + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_with_struct_schema_mismatch() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u64, // expected type is u32 + b: String, + } + type Data = (_Data, u64); + assert_panic_on_fetch!( + &["tuple", "UInt32", "u64"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_many_struct_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, + b: String, + c: u64, // this field should not be here + } + type Data = (_Data, u64); + assert_panic_on_fetch!( + &["3 columns", "4 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_many_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, + b: String, + } + type Data = (_Data, u64, u64); // one too many u64 + assert_panic_on_fetch!( + &["3 columns", "4 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_few_struct_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, // the second field is missing now + } + type Data = (_Data, u64); + assert_panic_on_fetch!( + &["3 columns", "2 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_few_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, + b: String, + } + type Data = (_Data, u64); // another u64 is missing here + assert_panic_on_fetch!( + &["4 columns", "3 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c, 255 :: UInt64 AS d" + ); +} + #[tokio::test] async fn test_basic_types() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] From 49af48c8e4e129e201224b4cdb81847c830714f0 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 13:53:26 +0200 Subject: [PATCH 30/54] Use Cargo workspaces, update benchmarks and docs --- Cargo.toml | 17 +++++++++++--- README.md | 4 ++-- benches/README.md | 30 +++++++++++++++++++------ benches/{insert.rs => mocked_insert.rs} | 0 benches/{select.rs => mocked_select.rs} | 0 derive/Cargo.toml | 16 ++++++------- types/Cargo.toml | 17 ++++++-------- 7 files changed, 54 insertions(+), 30 deletions(-) rename benches/{insert.rs => mocked_insert.rs} (100%) rename benches/{select.rs => mocked_select.rs} (100%) diff --git a/Cargo.toml b/Cargo.toml index 5b71ded9..4ab86d98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,21 @@ homepage = "https://clickhouse.com" license = "MIT OR Apache-2.0" readme = "README.md" edition = "2021" -# update `derive/Cargo.toml` and CI if changed +# update `workspace.package.rust-version` below and CI if changed # TODO: after bumping to v1.80, remove `--precise` in the "msrv" CI job rust-version = "1.73.0" +[workspace] +members = ["derive", "types"] + +[workspace.package] +authors = ["ClickHouse Contributors", "Paul Loyd "] +repository = "https://github.com/ClickHouse/clickhouse-rs" +homepage = "https://clickhouse.com" +edition = "2021" +license = "MIT OR Apache-2.0" +rust-version = "1.73.0" + [lints.rust] rust_2018_idioms = { level = "warn", priority = -1 } unreachable_pub = "warn" @@ -36,11 +47,11 @@ name = "select_numbers" harness = false [[bench]] -name = "insert" +name = "mocked_insert" harness = false [[bench]] -name = "select" +name = "mocked_select" harness = false [[example]] diff --git a/README.md b/README.md index edb25f42..65a2ab9b 100644 --- a/README.md +++ b/README.md @@ -282,7 +282,7 @@ How to choose between all these features? Here are some considerations: } ``` -* `Enum(8|16)` are supported using [serde_repr](https://docs.rs/serde_repr/latest/serde_repr/). +* `Enum(8|16)` are supported using [serde_repr](https://docs.rs/serde_repr/latest/serde_repr/). You could use `#[repr(i8)]` for `Enum8` and `#[repr(i16)]` for `Enum16`.
Example @@ -295,7 +295,7 @@ How to choose between all these features? Here are some considerations: } #[derive(Debug, Serialize_repr, Deserialize_repr)] - #[repr(u8)] + #[repr(i8)] enum Level { Debug = 1, Info = 2, diff --git a/benches/README.md b/benches/README.md index d39bc8ab..a57105ba 100644 --- a/benches/README.md +++ b/benches/README.md @@ -4,31 +4,41 @@ All cases are run with `cargo bench --bench `. ## With a mocked server -These benchmarks are run against a mocked server, which is a simple HTTP server that responds with a fixed response. This is useful to measure the overhead of the client itself: -* `select` checks throughput of `Client::query()`. -* `insert` checks throughput of `Client::insert()` and `Client::inserter()` (if the `inserter` features is enabled). +These benchmarks are run against a mocked server, which is a simple HTTP server that responds with a fixed response. +This is useful to measure the overhead of the client itself. + +### Scenarios + +* [mocked_select](mocked_select.rs) checks throughput of `Client::query()`. +* [mocked_insert](mocked_insert.rs) checks throughput of `Client::insert()` and `Client::inserter()` + (requires `inserter` feature). ### How to collect perf data The crate's code runs on the thread with the name `testee`: + ```bash cargo bench --bench & perf record -p `ps -AT | grep testee | awk '{print $2}'` --call-graph dwarf,65528 --freq 5000 -g -- sleep 5 perf script > perf.script ``` -Then upload the `perf.script` file to [Firefox Profiler](https://profiler.firefox.com). +Then upload the `perf.script` file to [Firefox Profiler]. ## With a running ClickHouse server These benchmarks are run against a real ClickHouse server, so it must be started: + ```bash docker compose up -d cargo bench --bench ``` -Cases: -* `select_numbers` measures time of running a big SELECT query to the `system.numbers_mt` table. +### Scenarios + +* [select_numbers.rs](select_numbers.rs) measures time of running a big SELECT query to the `system.numbers_mt` table. +* [select_nyc_taxi_data.rs](select_nyc_taxi_data.rs) measures time of running a fairly large SELECT query (approximately + 3 million records) to the `nyc_taxi_data` table using the [NYC taxi dataset]. ### How to collect perf data @@ -38,4 +48,10 @@ perf record -p `ps -AT | grep | awk '{print $2}'` --call-graph dwarf,6552 perf script > perf.script ``` -Then upload the `perf.script` file to [Firefox Profiler](https://profiler.firefox.com). +Then upload the `perf.script` file to [Firefox Profiler]. + + + +[Firefox Profiler]: https://profiler.firefox.com + +[NYC taxi dataset]: https://clickhouse.com/docs/getting-started/example-datasets/nyc-taxi#create-the-table-trips \ No newline at end of file diff --git a/benches/insert.rs b/benches/mocked_insert.rs similarity index 100% rename from benches/insert.rs rename to benches/mocked_insert.rs diff --git a/benches/select.rs b/benches/mocked_select.rs similarity index 100% rename from benches/select.rs rename to benches/mocked_select.rs diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 56cb3220..4c17fc85 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -1,14 +1,14 @@ [package] name = "clickhouse-derive" -version = "0.2.0" description = "A macro for deriving clickhouse::Row" -authors = ["ClickHouse Contributors", "Paul Loyd "] -repository = "https://github.com/ClickHouse/clickhouse-rs" -homepage = "https://clickhouse.com" -edition = "2021" -license = "MIT OR Apache-2.0" -# update `Cargo.toml` and CI if changed -rust-version = "1.73.0" +version = "0.2.0" + +authors.workspace = true +repository.workspace = true +homepage.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true [lib] proc-macro = true diff --git a/types/Cargo.toml b/types/Cargo.toml index b9576b54..db89c637 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -1,17 +1,14 @@ [package] name = "clickhouse-types" -version = "0.1.0" description = "Data types utils to use with Native and RowBinary(WithNamesAndTypes) formats in ClickHouse" -authors = ["ClickHouse"] -repository = "https://github.com/ClickHouse/clickhouse-rs" -homepage = "https://clickhouse.com" -edition = "2021" -license = "MIT OR Apache-2.0" -# update `Cargo.toml` and CI if changed -rust-version = "1.73.0" +version = "0.1.0" -[lib] -#proc-macro = true +authors.workspace = true +repository.workspace = true +homepage.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true [dependencies] thiserror = "1.0.16" From 926213b483745f7c3b1ff17724795a8f1bbb1a91 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 13:54:02 +0200 Subject: [PATCH 31/54] Fix examples schema mismatch --- examples/data_types_derive_simple.rs | 33 ++++++++++++++++++---------- examples/data_types_variant.rs | 2 +- examples/enums.rs | 8 +++---- 3 files changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/data_types_derive_simple.rs b/examples/data_types_derive_simple.rs index d633a63e..5b464e4e 100644 --- a/examples/data_types_derive_simple.rs +++ b/examples/data_types_derive_simple.rs @@ -53,15 +53,26 @@ async fn main() -> Result<()> { decimal64_18_8 Decimal(18, 8), decimal128_38_12 Decimal(38, 12), -- decimal256_76_20 Decimal(76, 20), - date Date, - date32 Date32, - datetime DateTime, - datetime_tz DateTime('UTC'), - datetime64_0 DateTime64(0), - datetime64_3 DateTime64(3), - datetime64_6 DateTime64(6), - datetime64_9 DateTime64(9), - datetime64_9_tz DateTime64(9, 'UTC') + + time_date Date, + time_date32 Date32, + time_datetime DateTime, + time_datetime_tz DateTime('UTC'), + time_datetime64_0 DateTime64(0), + time_datetime64_3 DateTime64(3), + time_datetime64_6 DateTime64(6), + time_datetime64_9 DateTime64(9), + time_datetime64_9_tz DateTime64(9, 'UTC'), + + chrono_date Date, + chrono_date32 Date32, + chrono_datetime DateTime, + chrono_datetime_tz DateTime('UTC'), + chrono_datetime64_0 DateTime64(0), + chrono_datetime64_3 DateTime64(3), + chrono_datetime64_6 DateTime64(6), + chrono_datetime64_9 DateTime64(9), + chrono_datetime64_9_tz DateTime64(9, 'UTC'), ) ENGINE MergeTree ORDER BY (); ", ) @@ -166,7 +177,7 @@ type Decimal128 = FixedPoint; // Decimal(38, 12) = Decimal128(12) #[derive(Clone, Debug, PartialEq)] #[derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] -#[repr(u8)] +#[repr(i8)] pub enum Enum8 { Foo = 1, Bar = 2, @@ -174,7 +185,7 @@ pub enum Enum8 { #[derive(Clone, Debug, PartialEq)] #[derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] -#[repr(u16)] +#[repr(i16)] pub enum Enum16 { Qaz = 42, Qux = 255, diff --git a/examples/data_types_variant.rs b/examples/data_types_variant.rs index e575464b..35aa7568 100644 --- a/examples/data_types_variant.rs +++ b/examples/data_types_variant.rs @@ -140,7 +140,7 @@ fn get_rows() -> Vec { // This enum represents Variant(Array(UInt16), Bool, Date, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8) #[derive(Debug, PartialEq, Serialize, Deserialize)] enum MyRowVariant { - Array(Vec), + Array(Vec), Boolean(bool), // attributes should work in this case, too #[serde(with = "clickhouse::serde::time::date")] diff --git a/examples/enums.rs b/examples/enums.rs index 851ca3fb..d20e5dc6 100644 --- a/examples/enums.rs +++ b/examples/enums.rs @@ -35,14 +35,14 @@ async fn main() -> Result<()> { #[derive(Debug, Serialize, Deserialize, Row)] struct Event { - timestamp: u64, + timestamp: i64, message: String, level: Level, } // How to define enums that map to `Enum8`/`Enum16`. #[derive(Debug, Serialize_repr, Deserialize_repr)] - #[repr(u8)] + #[repr(i8)] enum Level { Debug = 1, Info = 2, @@ -69,9 +69,9 @@ async fn main() -> Result<()> { Ok(()) } -fn now() -> u64 { +fn now() -> i64 { UNIX_EPOCH .elapsed() .expect("invalid system time") - .as_nanos() as u64 + .as_nanos() as i64 } From da08827969d103666a786bb16aaf0031d210f6fa Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 15:19:35 +0200 Subject: [PATCH 32/54] Bring back `Vec<(K, V)>` for maps, more tests, fix clippy --- derive/src/lib.rs | 2 +- src/rowbinary/validation.rs | 165 +++++++++++++++---------------- tests/it/rbwnat.rs | 190 ++++++++++++++++++++++++++++++++++++ types/src/data_types.rs | 155 ++++++++++++++--------------- types/src/lib.rs | 2 +- 5 files changed, 348 insertions(+), 166 deletions(-) diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 5b6ceb92..5e539250 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -50,7 +50,7 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }; // TODO: do something more clever? - let _ = cx.check().expect("derive context error"); + cx.check().expect("derive context error"); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 9ba47b51..170b33ca 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -191,17 +191,26 @@ impl SchemaValidator for DataTypeValidator<'_> { } } +/// Having a ClickHouse `Map` defined as a `HashMap` in Rust, Serde will call: +/// - `deserialize_map` for `Vec<(K, V)>` +/// - `deserialize_` suitable for `K` +/// - `deserialize_` suitable for `V` #[derive(Debug)] pub(crate) enum MapValidatorState { Key, Value, - Validated, } +/// Having a ClickHouse `Map` defined as `Vec<(K, V)>` in Rust, Serde will call: +/// - `deserialize_seq` for `Vec<(K, V)>` +/// - `deserialize_tuple` for `(K, V)` +/// - `deserialize_` suitable for `K` +/// - `deserialize_` suitable for `V` #[derive(Debug)] -pub(crate) enum ArrayValidatorState { - Pending, - Validated, +pub(crate) enum MapAsSequenceValidatorState { + Tuple, + Key, + Value, } pub(crate) struct InnerDataTypeValidator<'de, 'cursor> { @@ -211,15 +220,13 @@ pub(crate) struct InnerDataTypeValidator<'de, 'cursor> { #[derive(Debug)] pub(crate) enum InnerDataTypeValidatorKind<'cursor> { - Array(&'cursor DataTypeNode, ArrayValidatorState), + Array(&'cursor DataTypeNode), FixedString(usize), - Map( - &'cursor DataTypeNode, - &'cursor DataTypeNode, - MapValidatorState, - ), + Map(&'cursor [Box; 2], MapValidatorState), + /// Allows supporting ClickHouse `Map` defined as `Vec<(K, V)>` in Rust + MapAsSequence(&'cursor [Box; 2], MapAsSequenceValidatorState), Tuple(&'cursor [DataTypeNode]), - /// This is a hack to support deserializing tuples/vectors (and not structs) from fetch calls + /// This is a hack to support deserializing tuples/arrays (and not structs) from fetch calls RootTuple(&'cursor [Column], usize), RootArray(&'cursor DataTypeNode), Enum(&'cursor HashMap), @@ -242,29 +249,41 @@ impl<'de, 'cursor> SchemaValidator for Option Ok(None), Some(inner) => match &mut inner.kind { - InnerDataTypeValidatorKind::Map(key_type, value_type, state) => match state { + InnerDataTypeValidatorKind::Map(kv, state) => match state { MapValidatorState::Key => { - let result = validate_impl(inner.root, key_type, &serde_type, true); + let result = validate_impl(inner.root, &kv[0], &serde_type, true); *state = MapValidatorState::Value; result } MapValidatorState::Value => { - let result = validate_impl(inner.root, value_type, &serde_type, true); - *state = MapValidatorState::Validated; + let result = validate_impl(inner.root, &kv[1], &serde_type, true); + *state = MapValidatorState::Key; result } - MapValidatorState::Validated => Ok(None), }, - InnerDataTypeValidatorKind::Array(inner_type, state) => match state { - ArrayValidatorState::Pending => { - let result = validate_impl(inner.root, inner_type, &serde_type, true); - *state = ArrayValidatorState::Validated; - result + InnerDataTypeValidatorKind::MapAsSequence(kv, state) => { + match state { + // the first state is simply skipped, as the same validator + // will be called again for the Key and then the Value types + MapAsSequenceValidatorState::Tuple => { + *state = MapAsSequenceValidatorState::Key; + Ok(self.take()) + } + MapAsSequenceValidatorState::Key => { + let result = validate_impl(inner.root, &kv[0], &serde_type, true); + *state = MapAsSequenceValidatorState::Value; + result + } + MapAsSequenceValidatorState::Value => { + let result = validate_impl(inner.root, &kv[1], &serde_type, true); + *state = MapAsSequenceValidatorState::Tuple; + result + } } - // TODO: perhaps we can allow to validate the inner type more than once - // avoiding e.g. issues with Array(Nullable(T)) when the first element in NULL - ArrayValidatorState::Validated => Ok(None), - }, + } + InnerDataTypeValidatorKind::Array(inner_type) => { + validate_impl(inner.root, inner_type, &serde_type, true) + } InnerDataTypeValidatorKind::Nullable(inner_type) => { validate_impl(inner.root, inner_type, &serde_type, true) } @@ -315,17 +334,19 @@ impl<'de, 'cursor> SchemaValidator for Option { - todo!() // TODO - check value correctness in the hashmap + unreachable!() } }, } @@ -410,6 +431,9 @@ impl Drop for InnerDataTypeValidator<'_, '_> { } } +// TODO: is there a way to eliminate multiple branches with similar patterns? +// static/const dispatch? +// separate smaller inline functions? #[inline] fn validate_impl<'de, 'cursor>( root: &'de DataTypeValidator<'cursor>, @@ -418,9 +442,6 @@ fn validate_impl<'de, 'cursor>( is_inner: bool, ) -> Result>> { let data_type = column_data_type.remove_low_cardinality(); - // TODO: is there a way to eliminate multiple branches with similar patterns? - // static/const dispatch? - // separate smaller inline functions? match serde_type { SerdeType::Bool if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => @@ -494,17 +515,10 @@ fn validate_impl<'de, 'cursor>( { Ok(None) } - // TODO: find use cases where this is called instead of `deserialize_tuple` - // SerdeType::Bytes | SerdeType::ByteBuf => { - // if let DataTypeNode::FixedString(n) = data_type { - // Ok(Some(InnerDataTypeValidator::FixedString(*n))) - // } else { - // panic!( - // "Expected FixedString(N) for {} call, but got {}", - // serde_type, data_type - // ) - // } - // } + // allows to work with BLOB strings as well + SerdeType::Bytes(_) | SerdeType::ByteBuf(_) if data_type == &DataTypeNode::String => { + Ok(None) + } SerdeType::Option => { if let DataTypeNode::Nullable(inner_type) = data_type { Ok(Some(InnerDataTypeValidator { @@ -518,42 +532,35 @@ fn validate_impl<'de, 'cursor>( SerdeType::Seq(_) => match data_type { DataTypeNode::Array(inner_type) => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array(inner_type, ArrayValidatorState::Pending), + kind: InnerDataTypeValidatorKind::Array(inner_type), })), - DataTypeNode::Ring => Ok(Some(InnerDataTypeValidator { + // A map can be defined as `Vec<(K, V)>` in the struct + DataTypeNode::Map(kv) => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array( - &DataTypeNode::Point, - ArrayValidatorState::Pending, + kind: InnerDataTypeValidatorKind::MapAsSequence( + kv, + MapAsSequenceValidatorState::Tuple, ), })), + DataTypeNode::Ring => Ok(Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Point), + })), DataTypeNode::Polygon => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array( - &DataTypeNode::Ring, - ArrayValidatorState::Pending, - ), + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Ring), })), DataTypeNode::MultiPolygon => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array( - &DataTypeNode::Polygon, - ArrayValidatorState::Pending, - ), + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Polygon), })), DataTypeNode::LineString => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array( - &DataTypeNode::Point, - ArrayValidatorState::Pending, - ), + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Point), })), DataTypeNode::MultiLineString => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array( - &DataTypeNode::LineString, - ArrayValidatorState::Pending, - ), + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::LineString), })), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, @@ -579,40 +586,27 @@ fn validate_impl<'de, 'cursor>( })), DataTypeNode::Array(inner_type) => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array(inner_type, ArrayValidatorState::Pending), + kind: InnerDataTypeValidatorKind::Array(inner_type), })), DataTypeNode::IPv6 => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Array( - &DataTypeNode::UInt8, - ArrayValidatorState::Pending, - ), + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::UInt8), })), DataTypeNode::UUID => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Tuple(&[ - DataTypeNode::UInt64, - DataTypeNode::UInt64, - ]), + kind: InnerDataTypeValidatorKind::Tuple(UUID_TUPLE_ELEMENTS), })), DataTypeNode::Point => Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Tuple(&[ - DataTypeNode::Float64, - DataTypeNode::Float64, - ]), + kind: InnerDataTypeValidatorKind::Tuple(POINT_TUPLE_ELEMENTS), })), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::Map(_) => { - if let DataTypeNode::Map(key_type, value_type) = data_type { + if let DataTypeNode::Map(kv) = data_type { Ok(Some(InnerDataTypeValidator { root, - kind: InnerDataTypeValidatorKind::Map( - key_type, - value_type, - MapValidatorState::Key, - ), + kind: InnerDataTypeValidatorKind::Map(kv, MapValidatorState::Key), })) } else { panic!( @@ -747,3 +741,6 @@ impl Display for SerdeType { } } } + +const UUID_TUPLE_ELEMENTS: &[DataTypeNode; 2] = &[DataTypeNode::UInt64, DataTypeNode::UInt64]; +const POINT_TUPLE_ELEMENTS: &[DataTypeNode; 2] = &[DataTypeNode::Float64, DataTypeNode::Float64]; diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 14cef6d7..203d8ee6 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -350,6 +350,96 @@ async fn test_basic_types() { ); } +#[tokio::test] +async fn test_borrowed_data() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data<'a> { + str: &'a str, + array: Vec<&'a str>, + tuple: (&'a str, &'a str), + str_opt: Option<&'a str>, + vec_map_str: Vec<(&'a str, &'a str)>, + vec_map_f32: Vec<(&'a str, f32)>, + vec_map_nested: Vec<(&'a str, Vec<(&'a str, &'a str)>)>, + hash_map_str: HashMap<&'a str, &'a str>, + hash_map_f32: HashMap<&'a str, f32>, + hash_map_nested: HashMap<&'a str, HashMap<&'a str, &'a str>>, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let mut cursor = client + .query( + " + SELECT + 'a' :: String AS str, + ['b', 'c'] :: Array(String) AS array, + ('d', 'e') :: Tuple(String, String) AS tuple, + NULL :: Nullable(String) AS str_opt, + map('key1', 'value1', 'key2', 'value2') :: Map(String, String) AS hash_map_str, + map('key3', 100, 'key4', 200) :: Map(String, Float32) AS hash_map_f32, + map('n1', hash_map_str) :: Map(String, Map(String, String)) AS hash_map_nested, + hash_map_str AS vec_map_str, + hash_map_f32 AS vec_map_f32, + hash_map_nested AS vec_map_nested + UNION ALL + SELECT + 'f' :: String AS str, + ['g', 'h'] :: Array(String) AS array, + ('i', 'j') :: Tuple(String, String) AS tuple, + 'k' :: Nullable(String) AS str_opt, + map('key4', 'value4', 'key5', 'value5') :: Map(String, String) AS hash_map_str, + map('key6', 300, 'key7', 400) :: Map(String, Float32) AS hash_map_f32, + map('n2', hash_map_str) :: Map(String, Map(String, String)) AS hash_map_nested, + hash_map_str AS vec_map_str, + hash_map_f32 AS vec_map_f32, + hash_map_nested AS vec_map_nested + ", + ) + .fetch::>() + .unwrap(); + + let mut result = Vec::new(); + while let Some(row) = cursor.next().await.unwrap() { + result.push(row); + } + + assert_eq!( + result, + vec![ + Data { + str: "a", + array: vec!["b", "c"], + tuple: ("d", "e"), + str_opt: None, + vec_map_str: vec![("key1", "value1"), ("key2", "value2")], + vec_map_f32: vec![("key3", 100.0), ("key4", 200.0)], + vec_map_nested: vec![("n1", vec![("key1", "value1"), ("key2", "value2")])], + hash_map_str: HashMap::from([("key1", "value1"), ("key2", "value2"),]), + hash_map_f32: HashMap::from([("key3", 100.0), ("key4", 200.0),]), + hash_map_nested: HashMap::from([( + "n1", + HashMap::from([("key1", "value1"), ("key2", "value2"),]), + )]), + }, + Data { + str: "f", + array: vec!["g", "h"], + tuple: ("i", "j"), + str_opt: Some("k"), + vec_map_str: vec![("key4", "value4"), ("key5", "value5")], + vec_map_f32: vec![("key6", 300.0), ("key7", 400.0)], + vec_map_nested: vec![("n2", vec![("key4", "value4"), ("key5", "value5")])], + hash_map_str: HashMap::from([("key4", "value4"), ("key5", "value5"),]), + hash_map_f32: HashMap::from([("key6", 300.0), ("key7", 400.0),]), + hash_map_nested: HashMap::from([( + "n2", + HashMap::from([("key4", "value4"), ("key5", "value5"),]), + )]), + }, + ] + ); +} + #[tokio::test] async fn test_several_simple_rows() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -403,6 +493,28 @@ async fn test_many_numbers() { assert_eq!(sum, (0..2000).sum::()); } +#[tokio::test] +async fn test_blob_string_with_serde_bytes() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + #[serde(with = "serde_bytes")] + blob: Vec, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query("SELECT 'foo' :: String AS blob") + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + blob: "foo".as_bytes().to_vec(), + } + ); +} + #[tokio::test] async fn test_arrays() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -494,6 +606,84 @@ async fn test_maps() { } ); } + +#[tokio::test] +async fn test_map_as_vec_of_tuples() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m1: Vec<(i128, String)>, + m2: Vec<(u16, Vec<(String, i32)>)>, + } + + let client = get_client().with_validation_mode(ValidationMode::Each); + let result = client + .query( + " + SELECT + map(100, 'value1', 200, 'value2') :: Map(Int128, String) AS m1, + map(42, map('foo', 100, 'bar', 200), + 144, map('qaz', 300, 'qux', 400)) :: Map(UInt16, Map(String, Int32)) AS m2 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + m1: vec![(100, "value1".to_string()), (200, "value2".to_string()),], + m2: vec![ + ( + 42, + vec![("foo".to_string(), 100), ("bar".to_string(), 200)] + .into_iter() + .collect() + ), + ( + 144, + vec![("qaz".to_string(), 300), ("qux".to_string(), 400)] + .into_iter() + .collect() + ) + ], + } + ) +} + +#[tokio::test] +async fn test_map_as_vec_of_tuples_schema_mismatch() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m: Vec<(u16, Vec<(String, i32)>)>, + } + + assert_panic_on_fetch!( + &["Data.m", "Map(Int64, String)", "Int64", "u16"], + "SELECT map(100, 'value1', 200, 'value2') :: Map(Int64, String) AS m" + ); +} + +#[tokio::test] +async fn test_map_as_vec_of_tuples_schema_mismatch_nested() { + type Inner = Vec<(i32, i64)>; // the value should be i128 instead of i64 + + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m: Vec<(u16, Vec<(String, Inner)>)>, + } + + assert_panic_on_fetch!( + &[ + "Data.m", + "Map(UInt16, Map(String, Map(Int32, Int128)))", + "Int128", + "i64" + ], + "SELECT map(42, map('foo', map(144, 255))) + :: Map(UInt16, Map(String, Map(Int32, Int128))) AS m" + ); +} + #[tokio::test] async fn test_enum() { #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] diff --git a/types/src/data_types.rs b/types/src/data_types.rs index b0f939e5..c412807b 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -61,9 +61,11 @@ pub enum DataTypeNode { Array(Box), Tuple(Vec), - Map(Box, Box), Enum(EnumType, HashMap), + // key-value pair is defined as an array, so we can also use it as a slice + Map([Box; 2]), + AggregateFunction(String, Vec), Variant(Vec), @@ -142,9 +144,9 @@ impl DataTypeNode { } } -impl Into for DataTypeNode { - fn into(self) -> String { - self.to_string() +impl From for String { + fn from(value: DataTypeNode) -> Self { + value.to_string() } } @@ -181,17 +183,17 @@ impl Display for DataTypeNode { IPv4 => "IPv4".to_string(), IPv6 => "IPv6".to_string(), Bool => "Bool".to_string(), - Nullable(inner) => format!("Nullable({})", inner.to_string()), - Array(inner) => format!("Array({})", inner.to_string()), + Nullable(inner) => format!("Nullable({})", inner), + Array(inner) => format!("Array({})", inner), Tuple(elements) => { let elements_str = data_types_to_string(elements); format!("Tuple({})", elements_str) } - Map(key, value) => { - format!("Map({}, {})", key.to_string(), value.to_string()) + Map([key, value]) => { + format!("Map({}, {})", key, value) } LowCardinality(inner) => { - format!("LowCardinality({})", inner.to_string()) + format!("LowCardinality({})", inner) } Enum(enum_type, values) => { let mut values_vec = values.iter().collect::>(); @@ -402,7 +404,7 @@ fn parse_datetime(input: &str) -> Result { return Ok(DataTypeNode::DateTime(None)); } if input.len() >= 12 { - let timezone = (&input[10..input.len() - 2]).to_string(); + let timezone = input[10..input.len() - 2].to_string(); return Ok(DataTypeNode::DateTime(Some(timezone))); } Err(TypesError::TypeParsingError(format!( @@ -413,7 +415,7 @@ fn parse_datetime(input: &str) -> Result { fn parse_decimal(input: &str) -> Result { if input.len() >= 10 { - let precision_and_scale_str = (&input[8..input.len() - 1]).split(", ").collect::>(); + let precision_and_scale_str = input[8..input.len() - 1].split(", ").collect::>(); if precision_and_scale_str.len() != 2 { return Err(TypesError::TypeParsingError(format!( "Invalid Decimal format, expected Decimal(P, S), got {}", @@ -455,14 +457,14 @@ fn parse_decimal(input: &str) -> Result { fn parse_datetime64(input: &str) -> Result { if input.len() >= 13 { - let mut chars = (&input[11..input.len() - 1]).chars(); + let mut chars = input[11..input.len() - 1].chars(); let precision_char = chars.next().ok_or(TypesError::TypeParsingError(format!( "Invalid DateTime64 precision, expected a positive number. Input: {}", input )))?; let precision = DateTimePrecision::new(precision_char)?; let maybe_tz = match chars.as_str() { - str if str.len() > 2 => Some((&str[3..str.len() - 1]).to_string()), + str if str.len() > 2 => Some(str[3..str.len() - 1].to_string()), _ => None, }; return Ok(DataTypeNode::DateTime64(precision, maybe_tz)); @@ -507,10 +509,10 @@ fn parse_map(input: &str) -> Result { input ))); } - return Ok(DataTypeNode::Map( + return Ok(DataTypeNode::Map([ Box::new(inner_types[0].clone()), Box::new(inner_types[1].clone()), - )); + ])); } Err(TypesError::TypeParsingError(format!( "Invalid Map format, expected Map(KeyType, ValueType), got {}", @@ -572,34 +574,29 @@ fn parse_inner_types(input: &str) -> Result, TypesError> { char_escaped = true; } else if input_bytes[i] == b'\'' { quote_open = !quote_open; // unescaped quote - } else { - if !quote_open { - if input_bytes[i] == b'(' { - open_parens += 1; - } else if input_bytes[i] == b')' { - open_parens -= 1; - } else if input_bytes[i] == b',' { - if open_parens == 0 { - let data_type_str = - String::from_utf8(input_bytes[last_element_index..i].to_vec()) - .map_err(|_| { - TypesError::TypeParsingError(format!( - "Invalid UTF-8 sequence in input for the inner data type: {}", - &input[last_element_index..] - )) - })?; - let data_type = DataTypeNode::new(&data_type_str)?; - inner_types.push(data_type); - // Skip ', ' (comma and space) - if i + 2 <= input_bytes.len() && input_bytes[i + 1] == b' ' { - i += 2; - } else { - i += 1; - } - last_element_index = i; - continue; // Skip the normal increment at the end of the loop - } + } else if !quote_open { + if input_bytes[i] == b'(' { + open_parens += 1; + } else if input_bytes[i] == b')' { + open_parens -= 1; + } else if input_bytes[i] == b',' && open_parens == 0 { + let data_type_str = String::from_utf8(input_bytes[last_element_index..i].to_vec()) + .map_err(|_| { + TypesError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the inner data type: {}", + &input[last_element_index..] + )) + })?; + let data_type = DataTypeNode::new(&data_type_str)?; + inner_types.push(data_type); + // Skip ', ' (comma and space) + if i + 2 <= input_bytes.len() && input_bytes[i + 1] == b' ' { + i += 2; + } else { + i += 1; } + last_element_index = i; + continue; // Skip the normal increment at the end of the loop } } i += 1; @@ -652,31 +649,29 @@ fn parse_enum_values_map(input: &str) -> Result, TypesError if parsing_name { if char_escaped { char_escaped = false; - } else { - if input_bytes[i] == b'\\' { - char_escaped = true; - } else if input_bytes[i] == b'\'' { - // non-escaped closing tick - push the name - let name_bytes = &input_bytes[start_index..i]; - let name = String::from_utf8(name_bytes.to_vec()).map_err(|_| { - TypesError::TypeParsingError(format!( - "Invalid UTF-8 sequence in input for the enum name: {}", - &input[start_index..i] - )) - })?; - names.push(name); - - // Skip ` = ` and the first digit, as it will always have at least one - if i + 4 >= input_bytes.len() { - return Err(TypesError::TypeParsingError(format!( - "Invalid Enum format - expected ` = ` after name, input: {}", - input, - ))); - } - i += 4; - start_index = i; - parsing_name = false; + } else if input_bytes[i] == b'\\' { + char_escaped = true; + } else if input_bytes[i] == b'\'' { + // non-escaped closing tick - push the name + let name_bytes = &input_bytes[start_index..i]; + let name = String::from_utf8(name_bytes.to_vec()).map_err(|_| { + TypesError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the enum name: {}", + &input[start_index..i] + )) + })?; + names.push(name); + + // Skip ` = ` and the first digit, as it will always have at least one + if i + 4 >= input_bytes.len() { + return Err(TypesError::TypeParsingError(format!( + "Invalid Enum format - expected ` = ` after name, input: {}", + input, + ))); } + i += 4; + start_index = i; + parsing_name = false; } } // Parsing the index, skipping next iterations until the first non-digit one @@ -968,29 +963,29 @@ mod tests { fn test_data_type_new_map() { assert_eq!( DataTypeNode::new("Map(UInt8, String)").unwrap(), - DataTypeNode::Map( + DataTypeNode::Map([ Box::new(DataTypeNode::UInt8), Box::new(DataTypeNode::String) - ) + ]) ); assert_eq!( DataTypeNode::new("Map(String, Int32)").unwrap(), - DataTypeNode::Map( + DataTypeNode::Map([ Box::new(DataTypeNode::String), Box::new(DataTypeNode::Int32) - ) + ]) ); assert_eq!( DataTypeNode::new("Map(String, Map(Int32, Array(Nullable(String))))").unwrap(), - DataTypeNode::Map( + DataTypeNode::Map([ Box::new(DataTypeNode::String), - Box::new(DataTypeNode::Map( + Box::new(DataTypeNode::Map([ Box::new(DataTypeNode::Int32), Box::new(DataTypeNode::Array(Box::new(DataTypeNode::Nullable( Box::new(DataTypeNode::String) )))) - )) - ) + ])) + ]) ); assert!(DataTypeNode::new("Map()").is_err()); assert!(DataTypeNode::new("Map").is_err()); @@ -1019,10 +1014,10 @@ mod tests { DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( DataTypeNode::String )))), - DataTypeNode::Map( + DataTypeNode::Map([ Box::new(DataTypeNode::Int32), Box::new(DataTypeNode::String) - ) + ]) ]) ); assert!(DataTypeNode::new("Variant").is_err()); @@ -1052,13 +1047,13 @@ mod tests { DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( DataTypeNode::String )))), - DataTypeNode::Map( + DataTypeNode::Map([ Box::new(DataTypeNode::Int32), Box::new(DataTypeNode::Tuple(vec![ DataTypeNode::String, DataTypeNode::Array(Box::new(DataTypeNode::UInt8)) ])) - ) + ]) ]) ); assert_eq!( @@ -1210,10 +1205,10 @@ mod tests { "Tuple(String, UInt32, Float64)" ); assert_eq!( - DataTypeNode::Map( + DataTypeNode::Map([ Box::new(DataTypeNode::String), Box::new(DataTypeNode::UInt32) - ) + ]) .to_string(), "Map(String, UInt32)" ); diff --git a/types/src/lib.rs b/types/src/lib.rs index bed7ccea..22a49b9c 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -49,7 +49,7 @@ pub fn put_rbwnat_columns_header( put_leb128(&mut buffer, column.name.len() as u64); buffer.put_slice(column.name.as_bytes()); } - for column in columns.into_iter() { + for column in columns.iter() { put_leb128(&mut buffer, column.data_type.to_string().len() as u64); buffer.put_slice(column.data_type.to_string().as_bytes()); } From 1b893a8a753960506d3a32794233cdcac7b4f4bb Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 23:03:46 +0200 Subject: [PATCH 33/54] Fix mocked select benchmark --- benches/common.rs | 61 ++++++++++++++++++------ benches/mocked_select.rs | 100 +++++++++++++++++++++++++++------------ 2 files changed, 116 insertions(+), 45 deletions(-) diff --git a/benches/common.rs b/benches/common.rs index 637447ab..7894928b 100644 --- a/benches/common.rs +++ b/benches/common.rs @@ -11,6 +11,7 @@ use std::{ }; use bytes::Bytes; +use clickhouse::error::Result; use futures::stream::StreamExt; use http_body_util::BodyExt; use hyper::{ @@ -25,35 +26,65 @@ use tokio::{ sync::{mpsc, oneshot}, }; -use clickhouse::error::Result; +pub(crate) struct ServerHandle { + handle: Option>, + shutdown_tx: Option>, +} -pub(crate) struct ServerHandle; +impl ServerHandle { + fn shutdown(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + tx.send(()).unwrap(); + } + if let Some(handle) = self.handle.take() { + handle.join().unwrap(); + } + } +} -pub(crate) fn start_server(addr: SocketAddr, serve: S) -> ServerHandle +impl Drop for ServerHandle { + fn drop(&mut self) { + self.shutdown(); + } +} + +pub(crate) async fn start_server(addr: SocketAddr, serve: S) -> ServerHandle where S: Fn(Request) -> F + Send + Sync + 'static, F: Future> + Send, B: Body + Send + 'static, { + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + let (ready_tx, ready_rx) = oneshot::channel::<()>(); + let serving = async move { let listener = TcpListener::bind(addr).await.unwrap(); + ready_tx.send(()).unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); - - let service = - service::service_fn(|request| async { Ok::<_, Infallible>(serve(request).await) }); - - // SELECT benchmark doesn't read the whole body, so ignore possible errors. - let _ = conn::http1::Builder::new() + let server_future = conn::http1::Builder::new() .timer(TokioTimer::new()) - .serve_connection(TokioIo::new(stream), service) - .await; + .serve_connection( + TokioIo::new(stream), + service::service_fn(|request| async { + Ok::<_, Infallible>(serve(request).await) + }), + ); + tokio::select! { + _ = server_future => {} + _ = &mut shutdown_rx => { break; } + } } }; - run_on_st_runtime("server", serving); - ServerHandle + let handle = Some(run_on_st_runtime("server", serving)); + ready_rx.await.unwrap(); + + ServerHandle { + handle, + shutdown_tx: Some(shutdown_tx), + } } pub(crate) async fn skip_incoming(request: Request) { @@ -105,7 +136,7 @@ pub(crate) fn start_runner() -> RunnerHandle { RunnerHandle { tx } } -fn run_on_st_runtime(name: &str, f: impl Future + Send + 'static) { +fn run_on_st_runtime(name: &str, f: impl Future + Send + 'static) -> thread::JoinHandle<()> { let name = name.to_string(); thread::Builder::new() .name(name.clone()) @@ -121,5 +152,5 @@ fn run_on_st_runtime(name: &str, f: impl Future + Send + 'static) { .unwrap() .block_on(f); }) - .unwrap(); + .unwrap() } diff --git a/benches/mocked_select.rs b/benches/mocked_select.rs index 89836baa..744ed7b2 100644 --- a/benches/mocked_select.rs +++ b/benches/mocked_select.rs @@ -1,10 +1,9 @@ -use std::{ - convert::Infallible, - mem, - time::{Duration, Instant}, -}; - use bytes::Bytes; +use clickhouse::{ + error::{Error, Result}, + Client, Compression, Row, +}; +use clickhouse_types::{Column, DataTypeNode}; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use futures::stream::{self, StreamExt as _}; use http_body_util::StreamBody; @@ -13,21 +12,45 @@ use hyper::{ Request, Response, }; use serde::Deserialize; - -use clickhouse::{ - error::{Error, Result}, - Client, Compression, Row, +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::{ + convert::Infallible, + mem, + time::{Duration, Instant}, }; mod common; async fn serve( request: Request, - chunk: Bytes, + compression: Compression, ) -> Response> { common::skip_incoming(request).await; - let stream = stream::repeat(chunk).map(|chunk| Ok(Frame::data(chunk))); + let write_schema = async move { + let schema = vec![ + Column::new("a".to_string(), DataTypeNode::UInt64), + Column::new("b".to_string(), DataTypeNode::Int64), + Column::new("c".to_string(), DataTypeNode::Int32), + Column::new("d".to_string(), DataTypeNode::UInt32), + ]; + + let mut buffer = Vec::new(); + clickhouse_types::put_rbwnat_columns_header(&schema, &mut buffer).unwrap(); + + let buffer = match compression { + Compression::None => Bytes::from(buffer), + #[cfg(feature = "lz4")] + Compression::Lz4 => clickhouse::_priv::lz4_compress(&buffer).unwrap(), + _ => unreachable!(), + }; + + Ok(Frame::data(buffer)) + }; + + let chunk = prepare_chunk(); + let stream = + stream::once(write_schema).chain(stream::repeat(chunk).map(|chunk| Ok(Frame::data(chunk)))); Response::new(StreamBody::new(stream)) } @@ -49,10 +72,13 @@ fn prepare_chunk() -> Bytes { chunk } +const ADDR: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 6523)); + fn select(c: &mut Criterion) { - let addr = "127.0.0.1:6543".parse().unwrap(); - let chunk = prepare_chunk(); - let _server = common::start_server(addr, move |req| serve(req, chunk.clone())); + async fn start_server(compression: Compression) -> common::ServerHandle { + common::start_server(ADDR, move |req| serve(req, compression)).await + } + let runner = common::start_runner(); #[derive(Default, Debug, Row, Deserialize)] @@ -63,7 +89,9 @@ fn select(c: &mut Criterion) { d: u32, } - async fn select_rows(client: Client, iters: u64) -> Result { + async fn select_rows(client: Client, iters: u64, compression: Compression) -> Result { + let _server = start_server(compression).await; + let mut sum = SomeRow::default(); let start = Instant::now(); let mut cursor = client @@ -81,10 +109,18 @@ fn select(c: &mut Criterion) { } black_box(sum); - Ok(start.elapsed()) + + let elapsed = start.elapsed(); + Ok(elapsed) } - async fn select_bytes(client: Client, min_size: u64) -> Result { + async fn select_bytes( + client: Client, + min_size: u64, + compression: Compression, + ) -> Result { + let _server = start_server(compression).await; + let start = Instant::now(); let mut cursor = client .query("SELECT value FROM some") @@ -103,19 +139,21 @@ fn select(c: &mut Criterion) { group.throughput(Throughput::Bytes(mem::size_of::() as u64)); group.bench_function("uncompressed", |b| { b.iter_custom(|iters| { + let compression = Compression::None; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::None); - runner.run(select_rows(client, iters)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_rows(client, iters, compression)) }) }); #[cfg(feature = "lz4")] group.bench_function("lz4", |b| { b.iter_custom(|iters| { + let compression = Compression::Lz4; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::Lz4); - runner.run(select_rows(client, iters)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_rows(client, iters, compression)) }) }); group.finish(); @@ -125,19 +163,21 @@ fn select(c: &mut Criterion) { group.throughput(Throughput::Bytes(MIB)); group.bench_function("uncompressed", |b| { b.iter_custom(|iters| { + let compression = Compression::None; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::None); - runner.run(select_bytes(client, iters * MIB)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_bytes(client, iters * MIB, compression)) }) }); #[cfg(feature = "lz4")] group.bench_function("lz4", |b| { b.iter_custom(|iters| { + let compression = Compression::None; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::Lz4); - runner.run(select_bytes(client, iters * MIB)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_bytes(client, iters * MIB, compression)) }) }); group.finish(); From 14f8550ad8edc7e02287f98912b59733a485cb52 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 23:29:22 +0200 Subject: [PATCH 34/54] Fix mocked insert benchmark --- benches/mocked_insert.rs | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/benches/mocked_insert.rs b/benches/mocked_insert.rs index f4ebc56f..983c20ab 100644 --- a/benches/mocked_insert.rs +++ b/benches/mocked_insert.rs @@ -1,14 +1,14 @@ -use std::{ - future::Future, - mem, - time::{Duration, Instant}, -}; - use bytes::Bytes; use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; use http_body_util::Empty; use hyper::{body::Incoming, Request, Response}; use serde::Serialize; +use std::net::SocketAddr; +use std::{ + future::Future, + mem, + time::{Duration, Instant}, +}; use clickhouse::{error::Result, Client, Compression, Row}; @@ -46,7 +46,9 @@ impl SomeRow { } } -async fn run_insert(client: Client, iters: u64) -> Result { +async fn run_insert(client: Client, addr: SocketAddr, iters: u64) -> Result { + let _server = common::start_server(addr, serve).await; + let start = Instant::now(); let mut insert = client.insert("table")?; @@ -59,7 +61,13 @@ async fn run_insert(client: Client, iters: u64) -> Result { } #[cfg(feature = "inserter")] -async fn run_inserter(client: Client, iters: u64) -> Result { +async fn run_inserter( + client: Client, + addr: SocketAddr, + iters: u64, +) -> Result { + let _server = common::start_server(addr, serve).await; + let start = Instant::now(); let mut inserter = client.inserter("table")?.with_max_rows(iters); @@ -77,12 +85,11 @@ async fn run_inserter(client: Client, iters: u64) -> Re Ok(start.elapsed()) } -fn run(c: &mut Criterion, name: &str, port: u16, f: impl Fn(Client, u64) -> F) +fn run(c: &mut Criterion, name: &str, port: u16, f: impl Fn(Client, SocketAddr, u64) -> F) where F: Future> + Send + 'static, { - let addr = format!("127.0.0.1:{port}").parse().unwrap(); - let _server = common::start_server(addr, serve); + let addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap(); let runner = common::start_runner(); let mut group = c.benchmark_group(name); @@ -92,7 +99,7 @@ where let client = Client::default() .with_url(format!("http://{addr}")) .with_compression(Compression::None); - runner.run((f)(client, iters)) + runner.run((f)(client, addr, iters)) }) }); #[cfg(feature = "lz4")] @@ -101,7 +108,7 @@ where let client = Client::default() .with_url(format!("http://{addr}")) .with_compression(Compression::Lz4); - runner.run((f)(client, iters)) + runner.run((f)(client, addr, iters)) }) }); group.finish(); From 5509b1232bf8127d84c9e3007cb111ac504bc16b Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 23:51:53 +0200 Subject: [PATCH 35/54] Fix the rest of the examples, add a simple sanity check --- examples/async_insert.rs | 6 +++--- examples/clickhouse_cloud.rs | 2 +- tests/it/examples.rs | 29 +++++++++++++++++++++++++++++ tests/it/main.rs | 1 + 4 files changed, 34 insertions(+), 4 deletions(-) create mode 100644 tests/it/examples.rs diff --git a/examples/async_insert.rs b/examples/async_insert.rs index 7a9c18f6..8b567266 100644 --- a/examples/async_insert.rs +++ b/examples/async_insert.rs @@ -10,7 +10,7 @@ use clickhouse::{error::Result, Client, Row}; #[derive(Debug, Serialize, Deserialize, Row)] struct Event { - timestamp: u64, + timestamp: i64, message: String, } @@ -70,9 +70,9 @@ async fn main() -> Result<()> { Ok(()) } -fn now() -> u64 { +fn now() -> i64 { UNIX_EPOCH .elapsed() .expect("invalid system time") - .as_nanos() as u64 + .as_nanos() as i64 } diff --git a/examples/clickhouse_cloud.rs b/examples/clickhouse_cloud.rs index 7002160d..5c84d84b 100644 --- a/examples/clickhouse_cloud.rs +++ b/examples/clickhouse_cloud.rs @@ -66,7 +66,7 @@ async fn main() -> clickhouse::error::Result<()> { #[derive(Debug, Serialize, Deserialize, Row)] struct Data { - id: u32, + id: i32, name: String, } diff --git a/tests/it/examples.rs b/tests/it/examples.rs new file mode 100644 index 00000000..98b03b09 --- /dev/null +++ b/tests/it/examples.rs @@ -0,0 +1,29 @@ +#[test] +fn test_all_examples_exit_zero() { + let entries = std::fs::read_dir("./examples").unwrap(); + for entry in entries { + let entry = entry.unwrap(); + let path = entry.path(); + if path.is_file() && path.extension().map_or(false, |ext| ext == "rs") { + let file_name = path.file_stem().unwrap().to_str().unwrap(); + if !file_name.ends_with("_test.rs") { + println!("-- Running example: {}", file_name); + let output = std::process::Command::new("cargo") + .args(&["run", "--example", file_name, "--all-features"]) + .envs([ + ("CLICKHOUSE_URL", "http://localhost:8123"), + ("CLICKHOUSE_USER", "default"), + ("CLICKHOUSE_PASSWORD", ""), + ]) + .output() + .expect(&format!("Failed to execute example {}", file_name)); + assert!( + output.status.success(), + "Example '{}' failed with stderr: {}", + file_name, + String::from_utf8_lossy(&output.stderr) + ); + } + } + } +} diff --git a/tests/it/main.rs b/tests/it/main.rs index 6bbe41b6..ebe50ed3 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -164,6 +164,7 @@ mod cloud_jwt; mod compression; mod cursor_error; mod cursor_stats; +mod examples; mod fetch_bytes; mod insert; mod inserter; From 38d771d98e67c6b1424bbdabf3980f6dfe6eaa7b Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 9 Jun 2025 23:52:49 +0200 Subject: [PATCH 36/54] Clippy fixes --- tests/it/examples.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/it/examples.rs b/tests/it/examples.rs index 98b03b09..5dc3e10a 100644 --- a/tests/it/examples.rs +++ b/tests/it/examples.rs @@ -4,19 +4,19 @@ fn test_all_examples_exit_zero() { for entry in entries { let entry = entry.unwrap(); let path = entry.path(); - if path.is_file() && path.extension().map_or(false, |ext| ext == "rs") { + if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") { let file_name = path.file_stem().unwrap().to_str().unwrap(); if !file_name.ends_with("_test.rs") { println!("-- Running example: {}", file_name); let output = std::process::Command::new("cargo") - .args(&["run", "--example", file_name, "--all-features"]) + .args(["run", "--example", file_name, "--all-features"]) .envs([ ("CLICKHOUSE_URL", "http://localhost:8123"), ("CLICKHOUSE_USER", "default"), ("CLICKHOUSE_PASSWORD", ""), ]) .output() - .expect(&format!("Failed to execute example {}", file_name)); + .unwrap_or_else(|_| panic!("Failed to execute example {}", file_name)); assert!( output.status.success(), "Example '{}' failed with stderr: {}", From 446eb7c4f159966864fe138def4c6bbe811918ac Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 10 Jun 2025 00:33:57 +0200 Subject: [PATCH 37/54] Don't use Result as validation always panics --- src/rowbinary/de.rs | 26 ++--- src/rowbinary/validation.rs | 226 +++++++++++++++++++++++------------- 2 files changed, 160 insertions(+), 92 deletions(-) diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 06b52a79..177d1854 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -85,7 +85,7 @@ macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => { #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { - self.validator.validate($serde_type)?; + self.validator.validate($serde_type); ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -101,7 +101,7 @@ where #[inline(always)] fn deserialize_i8>(self, visitor: V) -> Result { - let mut maybe_enum_validator = self.validator.validate(SerdeType::I8)?; + let mut maybe_enum_validator = self.validator.validate(SerdeType::I8); ensure_size(&mut self.input, size_of::())?; let value = self.input.get_i8(); maybe_enum_validator.validate_enum8_value(value); @@ -110,7 +110,7 @@ where #[inline(always)] fn deserialize_i16>(self, visitor: V) -> Result { - let mut maybe_enum_validator = self.validator.validate(SerdeType::I16)?; + let mut maybe_enum_validator = self.validator.validate(SerdeType::I16); ensure_size(&mut self.input, size_of::())?; let value = self.input.get_i16_le(); // TODO: is there a better way to validate that the deserialized value matches the schema? @@ -155,7 +155,7 @@ where #[inline(always)] fn deserialize_bool>(self, visitor: V) -> Result { - self.validator.validate(SerdeType::Bool)?; + self.validator.validate(SerdeType::Bool); ensure_size(&mut self.input, 1)?; match self.input.get_u8() { 0 => visitor.visit_bool(false), @@ -166,7 +166,7 @@ where #[inline(always)] fn deserialize_str>(self, visitor: V) -> Result { - self.validator.validate(SerdeType::Str)?; + self.validator.validate(SerdeType::Str); let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -175,7 +175,7 @@ where #[inline(always)] fn deserialize_string>(self, visitor: V) -> Result { - self.validator.validate(SerdeType::String)?; + self.validator.validate(SerdeType::String); let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -185,7 +185,7 @@ where #[inline(always)] fn deserialize_bytes>(self, visitor: V) -> Result { let size = self.read_size()?; - self.validator.validate(SerdeType::Bytes(size))?; + self.validator.validate(SerdeType::Bytes(size)); let slice = self.read_slice(size)?; visitor.visit_borrowed_bytes(slice) } @@ -193,7 +193,7 @@ where #[inline(always)] fn deserialize_byte_buf>(self, visitor: V) -> Result { let size = self.read_size()?; - self.validator.validate(SerdeType::ByteBuf(size))?; + self.validator.validate(SerdeType::ByteBuf(size)); visitor.visit_byte_buf(self.read_vec(size)?) } @@ -213,7 +213,7 @@ where _variants: &'static [&'static str], visitor: V, ) -> Result { - let validator = self.validator.validate(SerdeType::Enum)?; + let validator = self.validator.validate(SerdeType::Enum); visitor.visit_enum(RowBinaryEnumAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, @@ -224,7 +224,7 @@ where #[inline(always)] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - let validator = self.validator.validate(SerdeType::Tuple(len))?; + let validator = self.validator.validate(SerdeType::Tuple(len)); let mut de = RowBinaryDeserializer { input: self.input, validator, @@ -239,7 +239,7 @@ where #[inline(always)] fn deserialize_option>(self, visitor: V) -> Result { ensure_size(&mut self.input, 1)?; - let inner_validator = self.validator.validate(SerdeType::Option)?; + let inner_validator = self.validator.validate(SerdeType::Option); match self.input.get_u8() { 0 => visitor.visit_some(&mut RowBinaryDeserializer { input: self.input, @@ -256,7 +256,7 @@ where visitor.visit_seq(RowBinarySeqAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, - validator: self.validator.validate(SerdeType::Seq(len))?, + validator: self.validator.validate(SerdeType::Seq(len)), }, len, }) @@ -265,7 +265,7 @@ where #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { let len = self.read_size()?; - let validator = self.validator.validate(SerdeType::Map(len))?; + let validator = self.validator.validate(SerdeType::Map(len)); visitor.visit_map(RowBinaryMapAccess { deserializer: &mut RowBinaryDeserializer { input: self.input, diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 170b33ca..4e2c0ce9 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,4 +1,3 @@ -use crate::error::Result; use crate::row_metadata::RowMetadata; use crate::RowKind; use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; @@ -6,10 +5,7 @@ use std::collections::HashMap; use std::fmt::Display; pub(crate) trait SchemaValidator: Sized { - fn validate( - &'_ mut self, - serde_type: SerdeType, - ) -> Result>>; + fn validate(&'_ mut self, serde_type: SerdeType) -> Option>; fn validate_enum8_value(&mut self, value: i8); fn validate_enum16_value(&mut self, value: i16); fn set_next_variant_value(&mut self, value: u8); @@ -52,7 +48,7 @@ impl<'cursor> DataTypeValidator<'cursor> { data_type: &DataTypeNode, serde_type: &SerdeType, is_inner: bool, - ) -> Result>> { + ) -> Option> { match self.metadata.kind { RowKind::Primitive => { panic!( @@ -99,10 +95,7 @@ impl<'cursor> DataTypeValidator<'cursor> { impl SchemaValidator for DataTypeValidator<'_> { #[inline] - fn validate( - &'_ mut self, - serde_type: SerdeType, - ) -> Result>> { + fn validate(&'_ mut self, serde_type: SerdeType) -> Option> { match self.metadata.kind { // `fetch::` for a "primitive row" type RowKind::Primitive => { @@ -119,10 +112,10 @@ impl SchemaValidator for DataTypeValidator<'_> { // `fetch::<(i16, i32)>` or `fetch::<(T, u64)>` for a "tuple row" type RowKind::Tuple => { match serde_type { - SerdeType::Tuple(_) => Ok(Some(InnerDataTypeValidator { + SerdeType::Tuple(_) => Some(InnerDataTypeValidator { root: self, kind: InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0), - })), + }), _ => { // should be unreachable panic!( @@ -144,7 +137,7 @@ impl SchemaValidator for DataTypeValidator<'_> { self.metadata.columns[0].data_type ), }; - Ok(Some(InnerDataTypeValidator { root: self, kind })) + Some(InnerDataTypeValidator { root: self, kind }) } // `fetch::` for a "struct row" type, which is supposed to be the default flow RowKind::Struct => { @@ -242,12 +235,9 @@ pub(crate) enum VariantValidationState { impl<'de, 'cursor> SchemaValidator for Option> { #[inline] - fn validate( - &mut self, - serde_type: SerdeType, - ) -> Result>> { + fn validate(&mut self, serde_type: SerdeType) -> Option> { match self { - None => Ok(None), + None => None, Some(inner) => match &mut inner.kind { InnerDataTypeValidatorKind::Map(kv, state) => match state { MapValidatorState::Key => { @@ -267,7 +257,7 @@ impl<'de, 'cursor> SchemaValidator for Option { *state = MapAsSequenceValidatorState::Key; - Ok(self.take()) + self.take() } MapAsSequenceValidatorState::Key => { let result = validate_impl(inner.root, &kv[0], &serde_type, true); @@ -305,7 +295,7 @@ impl<'de, 'cursor> SchemaValidator for Option { - Ok(None) // actually unreachable + None // actually unreachable } InnerDataTypeValidatorKind::RootTuple(columns, current_index) => { if *current_index < columns.len() { @@ -431,6 +421,89 @@ impl Drop for InnerDataTypeValidator<'_, '_> { } } +// #[inline] +// fn simple_types_impl<'de, 'cursor>( +// root: &'de DataTypeValidator<'cursor>, +// data_type: &'cursor DataTypeNode, +// serde_type: &SerdeType, +// is_inner: bool, +// ) { +// match serde_type { +// SerdeType::Bool +// if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => +// { +// None +// } +// SerdeType::I8 => match data_type { +// DataTypeNode::Int8 => None, +// DataTypeNode::Enum(EnumType::Enum8, values_map) => Some(InnerDataTypeValidator { +// root, +// kind: InnerDataTypeValidatorKind::Enum(values_map), +// })), +// _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), +// }, +// SerdeType::I16 => match data_type { +// DataTypeNode::Int16 => None, +// DataTypeNode::Enum(EnumType::Enum16, values_map) => Some(InnerDataTypeValidator { +// root, +// kind: InnerDataTypeValidatorKind::Enum(values_map), +// })), +// _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), +// }, +// SerdeType::I32 +// if data_type == &DataTypeNode::Int32 +// || data_type == &DataTypeNode::Date32 +// || matches!( +// data_type, +// DataTypeNode::Decimal(_, _, DecimalType::Decimal32) +// ) => +// { +// None +// } +// SerdeType::I64 +// if data_type == &DataTypeNode::Int64 +// || matches!(data_type, DataTypeNode::DateTime64(_, _)) +// || matches!( +// data_type, +// DataTypeNode::Decimal(_, _, DecimalType::Decimal64) +// ) => +// { +// None +// } +// SerdeType::I128 +// if data_type == &DataTypeNode::Int128 +// || matches!( +// data_type, +// DataTypeNode::Decimal(_, _, DecimalType::Decimal128) +// ) => +// { +// None +// } +// SerdeType::U8 if data_type == &DataTypeNode::UInt8 => None, +// SerdeType::U16 +// if data_type == &DataTypeNode::UInt16 || data_type == &DataTypeNode::Date => +// { +// None +// } +// SerdeType::U32 +// if data_type == &DataTypeNode::UInt32 +// || matches!(data_type, DataTypeNode::DateTime(_)) +// || data_type == &DataTypeNode::IPv4 => +// { +// None +// } +// SerdeType::U64 if data_type == &DataTypeNode::UInt64 => None, +// SerdeType::U128 if data_type == &DataTypeNode::UInt128 => None, +// SerdeType::F32 if data_type == &DataTypeNode::Float32 => None, +// SerdeType::F64 if data_type == &DataTypeNode::Float64 => None, +// SerdeType::Str | SerdeType::String +// if data_type == &DataTypeNode::String || data_type == &DataTypeNode::JSON => +// { +// None +// } +// } +// } + // TODO: is there a way to eliminate multiple branches with similar patterns? // static/const dispatch? // separate smaller inline functions? @@ -440,28 +513,28 @@ fn validate_impl<'de, 'cursor>( column_data_type: &'cursor DataTypeNode, serde_type: &SerdeType, is_inner: bool, -) -> Result>> { +) -> Option> { let data_type = column_data_type.remove_low_cardinality(); match serde_type { SerdeType::Bool if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => { - Ok(None) + None } SerdeType::I8 => match data_type { - DataTypeNode::Int8 => Ok(None), - DataTypeNode::Enum(EnumType::Enum8, values_map) => Ok(Some(InnerDataTypeValidator { + DataTypeNode::Int8 => None, + DataTypeNode::Enum(EnumType::Enum8, values_map) => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Enum(values_map), - })), + }), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::I16 => match data_type { - DataTypeNode::Int16 => Ok(None), - DataTypeNode::Enum(EnumType::Enum16, values_map) => Ok(Some(InnerDataTypeValidator { + DataTypeNode::Int16 => None, + DataTypeNode::Enum(EnumType::Enum16, values_map) => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Enum(values_map), - })), + }), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::I32 @@ -472,7 +545,7 @@ fn validate_impl<'de, 'cursor>( DataTypeNode::Decimal(_, _, DecimalType::Decimal32) ) => { - Ok(None) + None } SerdeType::I64 if data_type == &DataTypeNode::Int64 @@ -482,7 +555,7 @@ fn validate_impl<'de, 'cursor>( DataTypeNode::Decimal(_, _, DecimalType::Decimal64) ) => { - Ok(None) + None } SerdeType::I128 if data_type == &DataTypeNode::Int128 @@ -491,86 +564,84 @@ fn validate_impl<'de, 'cursor>( DataTypeNode::Decimal(_, _, DecimalType::Decimal128) ) => { - Ok(None) + None } - SerdeType::U8 if data_type == &DataTypeNode::UInt8 => Ok(None), + SerdeType::U8 if data_type == &DataTypeNode::UInt8 => None, SerdeType::U16 if data_type == &DataTypeNode::UInt16 || data_type == &DataTypeNode::Date => { - Ok(None) + None } SerdeType::U32 if data_type == &DataTypeNode::UInt32 || matches!(data_type, DataTypeNode::DateTime(_)) || data_type == &DataTypeNode::IPv4 => { - Ok(None) + None } - SerdeType::U64 if data_type == &DataTypeNode::UInt64 => Ok(None), - SerdeType::U128 if data_type == &DataTypeNode::UInt128 => Ok(None), - SerdeType::F32 if data_type == &DataTypeNode::Float32 => Ok(None), - SerdeType::F64 if data_type == &DataTypeNode::Float64 => Ok(None), + SerdeType::U64 if data_type == &DataTypeNode::UInt64 => None, + SerdeType::U128 if data_type == &DataTypeNode::UInt128 => None, + SerdeType::F32 if data_type == &DataTypeNode::Float32 => None, + SerdeType::F64 if data_type == &DataTypeNode::Float64 => None, SerdeType::Str | SerdeType::String if data_type == &DataTypeNode::String || data_type == &DataTypeNode::JSON => { - Ok(None) + None } // allows to work with BLOB strings as well - SerdeType::Bytes(_) | SerdeType::ByteBuf(_) if data_type == &DataTypeNode::String => { - Ok(None) - } + SerdeType::Bytes(_) | SerdeType::ByteBuf(_) if data_type == &DataTypeNode::String => None, SerdeType::Option => { if let DataTypeNode::Nullable(inner_type) = data_type { - Ok(Some(InnerDataTypeValidator { + Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Nullable(inner_type), - })) + }) } else { root.panic_on_schema_mismatch(data_type, serde_type, is_inner) } } SerdeType::Seq(_) => match data_type { - DataTypeNode::Array(inner_type) => Ok(Some(InnerDataTypeValidator { + DataTypeNode::Array(inner_type) => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(inner_type), - })), + }), // A map can be defined as `Vec<(K, V)>` in the struct - DataTypeNode::Map(kv) => Ok(Some(InnerDataTypeValidator { + DataTypeNode::Map(kv) => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::MapAsSequence( kv, MapAsSequenceValidatorState::Tuple, ), - })), - DataTypeNode::Ring => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::Ring => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Point), - })), - DataTypeNode::Polygon => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::Polygon => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Ring), - })), - DataTypeNode::MultiPolygon => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::MultiPolygon => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Polygon), - })), - DataTypeNode::LineString => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::LineString => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Point), - })), - DataTypeNode::MultiLineString => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::MultiLineString => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::LineString), - })), + }), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::Tuple(len) => match data_type { DataTypeNode::FixedString(n) => { if n == len { - Ok(Some(InnerDataTypeValidator { + Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::FixedString(*n), - })) + }) } else { let (full_name, full_data_type) = root.get_current_column_name_and_type(); panic!( @@ -580,34 +651,34 @@ fn validate_impl<'de, 'cursor>( ) } } - DataTypeNode::Tuple(elements) => Ok(Some(InnerDataTypeValidator { + DataTypeNode::Tuple(elements) => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Tuple(elements), - })), - DataTypeNode::Array(inner_type) => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::Array(inner_type) => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(inner_type), - })), - DataTypeNode::IPv6 => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::IPv6 => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::UInt8), - })), - DataTypeNode::UUID => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::UUID => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Tuple(UUID_TUPLE_ELEMENTS), - })), - DataTypeNode::Point => Ok(Some(InnerDataTypeValidator { + }), + DataTypeNode::Point => Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Tuple(POINT_TUPLE_ELEMENTS), - })), + }), _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), }, SerdeType::Map(_) => { if let DataTypeNode::Map(kv) = data_type { - Ok(Some(InnerDataTypeValidator { + Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Map(kv, MapValidatorState::Key), - })) + }) } else { panic!( "Expected Map for {} call, but got {}", @@ -617,13 +688,13 @@ fn validate_impl<'de, 'cursor>( } SerdeType::Enum => { if let DataTypeNode::Variant(possible_types) = data_type { - Ok(Some(InnerDataTypeValidator { + Some(InnerDataTypeValidator { root, kind: InnerDataTypeValidatorKind::Variant( possible_types, VariantValidationState::Pending, ), - })) + }) } else { panic!( "Expected Variant for {} call, but got {}", @@ -642,11 +713,8 @@ fn validate_impl<'de, 'cursor>( impl SchemaValidator for () { #[inline(always)] - fn validate( - &mut self, - _serde_type: SerdeType, - ) -> Result>> { - Ok(None) + fn validate(&mut self, _serde_type: SerdeType) -> Option> { + None } #[inline(always)] From 19760f31ca88e395041d652208990675438b1eda Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 10 Jun 2025 00:36:55 +0200 Subject: [PATCH 38/54] Bring back Unsupported error kind --- src/error.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/error.rs b/src/error.rs index d598f18d..438a6b87 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,6 +43,8 @@ pub enum Error { TimedOut, #[error("error while parsing columns header from the response: {0}")] InvalidColumnsHeader(#[source] BoxedError), + #[error("unsupported: {0}")] + Unsupported(String), #[error("{0}")] Other(BoxedError), } From 5f51dc7b2c39fb286da11927816d9e95904bc18b Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 10 Jun 2025 00:39:37 +0200 Subject: [PATCH 39/54] Remove examples runner from the `it` directory --- tests/it/examples.rs | 29 ----------------------------- tests/it/main.rs | 1 - 2 files changed, 30 deletions(-) delete mode 100644 tests/it/examples.rs diff --git a/tests/it/examples.rs b/tests/it/examples.rs deleted file mode 100644 index 5dc3e10a..00000000 --- a/tests/it/examples.rs +++ /dev/null @@ -1,29 +0,0 @@ -#[test] -fn test_all_examples_exit_zero() { - let entries = std::fs::read_dir("./examples").unwrap(); - for entry in entries { - let entry = entry.unwrap(); - let path = entry.path(); - if path.is_file() && path.extension().is_some_and(|ext| ext == "rs") { - let file_name = path.file_stem().unwrap().to_str().unwrap(); - if !file_name.ends_with("_test.rs") { - println!("-- Running example: {}", file_name); - let output = std::process::Command::new("cargo") - .args(["run", "--example", file_name, "--all-features"]) - .envs([ - ("CLICKHOUSE_URL", "http://localhost:8123"), - ("CLICKHOUSE_USER", "default"), - ("CLICKHOUSE_PASSWORD", ""), - ]) - .output() - .unwrap_or_else(|_| panic!("Failed to execute example {}", file_name)); - assert!( - output.status.success(), - "Example '{}' failed with stderr: {}", - file_name, - String::from_utf8_lossy(&output.stderr) - ); - } - } - } -} diff --git a/tests/it/main.rs b/tests/it/main.rs index 73d7eb32..f51ccb3e 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -164,7 +164,6 @@ mod cloud_jwt; mod compression; mod cursor_error; mod cursor_stats; -mod examples; mod fetch_bytes; mod https_errors; mod insert; From 8f3f3b28429b183021089c1b7fee610a8152a156 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 10 Jun 2025 00:55:57 +0200 Subject: [PATCH 40/54] Ignore an odd test --- tests/it/rbwnat.rs | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 203d8ee6..fc46ea8e 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -350,6 +350,8 @@ async fn test_basic_types() { ); } +// FIXME: somehow this test breaks `cargo test`, but works from RustRover +#[ignore] #[tokio::test] async fn test_borrowed_data() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -370,6 +372,8 @@ async fn test_borrowed_data() { let mut cursor = client .query( " + SELECT * FROM + ( SELECT 'a' :: String AS str, ['b', 'c'] :: Array(String) AS array, @@ -393,6 +397,8 @@ async fn test_borrowed_data() { hash_map_str AS vec_map_str, hash_map_f32 AS vec_map_f32, hash_map_nested AS vec_map_nested + ) + ORDER BY str ", ) .fetch::>() @@ -569,9 +575,9 @@ async fn test_maps() { let result = client .query( " - SELECT + SELECT map('key1', 'value1', 'key2', 'value2') :: Map(String, String) AS m1, - map(42, map('foo', 100, 'bar', 200), + map(42, map('foo', 100, 'bar', 200), 144, map('qaz', 300, 'qux', 400)) :: Map(UInt16, Map(String, Int32)) AS m2 ", ) From d189a78f7f0392739f40cb07785a3d5e98093de5 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 10 Jun 2025 14:16:57 +0200 Subject: [PATCH 41/54] Add CI workflow dispatch and all PR trigger --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c546f56..b0960230 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + workflow_dispatch: env: CARGO_TERM_COLOR: always From ccfac33235d978716902884effe3ffd8c0261b3d Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 10 Jun 2025 21:20:09 +0200 Subject: [PATCH 42/54] Further optimizations, remove validation_mode, remove schema from mocks --- benches/common_select.rs | 21 ++- benches/select_numbers.rs | 15 +- benches/select_nyc_taxi_data.rs | 19 ++- examples/mock.rs | 11 +- src/cursors/row.rs | 61 ++++---- src/lib.rs | 45 +++--- src/query.rs | 16 +- src/row_metadata.rs | 7 +- src/rowbinary/de.rs | 255 ++++++++++++++++++------------- src/rowbinary/mod.rs | 3 +- src/rowbinary/tests.rs | 8 +- src/rowbinary/validation.rs | 260 +++++++++++++------------------- src/test/handlers.rs | 10 +- src/validation_mode.rs | 34 ----- tests/it/main.rs | 2 +- tests/it/mock.rs | 13 +- tests/it/rbwnat.rs | 77 +++++----- tests/it/variant.rs | 3 +- 18 files changed, 403 insertions(+), 457 deletions(-) delete mode 100644 src/validation_mode.rs diff --git a/benches/common_select.rs b/benches/common_select.rs index 54f183d1..a1c3d850 100644 --- a/benches/common_select.rs +++ b/benches/common_select.rs @@ -1,7 +1,6 @@ #![allow(dead_code)] use clickhouse::query::RowCursor; -use clickhouse::validation_mode::ValidationMode; use clickhouse::{Client, Compression, Row}; use criterion::black_box; use serde::Deserialize; @@ -57,7 +56,7 @@ pub(crate) fn print_header(add: Option<&str>) { pub(crate) fn print_results<'a, T: BenchmarkRow<'a>>( stats: &BenchmarkStats, compression: Compression, - validation_mode: ValidationMode, + validation: bool, ) { let BenchmarkStats { throughput_mbytes_sec, @@ -65,11 +64,7 @@ pub(crate) fn print_results<'a, T: BenchmarkRow<'a>>( elapsed, .. } = stats; - let validation_mode = match validation_mode { - ValidationMode::First(n) => format!("First({})", n), - ValidationMode::Each => "Each".to_string(), - _ => panic!("Unexpected validation mode"), - }; + let validation_mode = if validation { "enabled" } else { "disabled" }; let compression = match compression { Compression::None => "none", #[cfg(feature = "lz4")] @@ -87,23 +82,25 @@ pub(crate) fn print_results<'a, T: BenchmarkRow<'a>>( pub(crate) async fn fetch_cursor<'a, T: BenchmarkRow<'a>>( compression: Compression, - validation_mode: ValidationMode, + validation: bool, query: &str, ) -> RowCursor { - let client = Client::default() + let mut client = Client::default() .with_compression(compression) - .with_validation_mode(validation_mode) .with_url("http://localhost:8123"); + if !validation { + client = client.with_disabled_validation(); + } client.query(query).fetch::().unwrap() } pub(crate) async fn do_select_bench<'a, T: BenchmarkRow<'a>>( query: &str, compression: Compression, - validation_mode: ValidationMode, + validation: bool, ) -> BenchmarkStats { let start = Instant::now(); - let mut cursor = fetch_cursor::(compression, validation_mode, query).await; + let mut cursor = fetch_cursor::(compression, validation, query).await; let mut sum = 0; while let Some(row) = cursor.next().await.unwrap() { diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index 2cc98aab..2adfde86 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -3,7 +3,6 @@ use serde::Deserialize; use crate::common_select::{ do_select_bench, print_header, print_results, BenchmarkRow, WithAccessType, WithId, }; -use clickhouse::validation_mode::ValidationMode; use clickhouse::{Compression, Row}; mod common_select; @@ -15,25 +14,25 @@ struct Data { impl_benchmark_row_no_access_type!(Data, number); -async fn bench(compression: Compression, validation_mode: ValidationMode) { +async fn bench(compression: Compression, validation: bool) { let stats = do_select_bench::( "SELECT number FROM system.numbers_mt LIMIT 500000000", compression, - validation_mode, + validation, ) .await; assert_eq!(stats.result, 124999999750000000); - print_results::(&stats, compression, validation_mode); + print_results::(&stats, compression, validation); } #[tokio::main] async fn main() { print_header(None); - bench(Compression::None, ValidationMode::First(1)).await; - bench(Compression::None, ValidationMode::Each).await; + bench(Compression::None, false).await; + bench(Compression::None, true).await; #[cfg(feature = "lz4")] { - bench(Compression::Lz4, ValidationMode::First(1)).await; - bench(Compression::Lz4, ValidationMode::Each).await; + bench(Compression::Lz4, false).await; + bench(Compression::Lz4, true).await; } } diff --git a/benches/select_nyc_taxi_data.rs b/benches/select_nyc_taxi_data.rs index b6e96baf..618ea8c5 100644 --- a/benches/select_nyc_taxi_data.rs +++ b/benches/select_nyc_taxi_data.rs @@ -3,7 +3,6 @@ use crate::common_select::{ do_select_bench, print_header, print_results, BenchmarkRow, WithAccessType, WithId, }; -use clickhouse::validation_mode::ValidationMode; use clickhouse::{Compression, Row}; use serde::Deserialize; use serde_repr::Deserialize_repr; @@ -75,27 +74,27 @@ struct TripSmallMapAccess { impl_benchmark_row!(TripSmallSeqAccess, trip_id, "seq"); impl_benchmark_row!(TripSmallMapAccess, trip_id, "map"); -async fn bench<'a, T: BenchmarkRow<'a>>(compression: Compression, validation_mode: ValidationMode) { +async fn bench<'a, T: BenchmarkRow<'a>>(compression: Compression, validation: bool) { let stats = do_select_bench::( "SELECT * FROM nyc_taxi.trips_small ORDER BY trip_id DESC", compression, - validation_mode, + validation, ) .await; assert_eq!(stats.result, 3630387815532582); - print_results::(&stats, compression, validation_mode); + print_results::(&stats, compression, validation); } #[tokio::main] async fn main() { print_header(Some(" access")); - bench::(Compression::None, ValidationMode::First(1)).await; - bench::(Compression::None, ValidationMode::Each).await; - bench::(Compression::None, ValidationMode::Each).await; + bench::(Compression::None, false).await; + bench::(Compression::None, true).await; + bench::(Compression::None, true).await; #[cfg(feature = "lz4")] { - bench::(Compression::Lz4, ValidationMode::First(1)).await; - bench::(Compression::Lz4, ValidationMode::Each).await; - bench::(Compression::Lz4, ValidationMode::Each).await; + bench::(Compression::Lz4, false).await; + bench::(Compression::Lz4, true).await; + bench::(Compression::Lz4, true).await; } } diff --git a/examples/mock.rs b/examples/mock.rs index ca961f32..d4e7dadb 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -1,6 +1,4 @@ use clickhouse::{error::Result, test, Client, Row}; -use clickhouse_types::Column; -use clickhouse_types::DataTypeNode::UInt32; use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, PartialEq)] @@ -48,7 +46,9 @@ async fn make_watch_only_events(client: &Client) -> Result { #[tokio::main] async fn main() { let mock = test::Mock::new(); - let client = Client::default().with_url(mock.url()); + let client = Client::default() + .with_url(mock.url()) + .with_disabled_validation(); let list = vec![SomeRow { no: 1 }, SomeRow { no: 2 }]; // How to test DDL. @@ -56,11 +56,8 @@ async fn main() { make_create(&client).await.unwrap(); assert!(recording.query().await.contains("CREATE TABLE")); - let metadata = - clickhouse::RowMetadata::new::(vec![Column::new("no".to_string(), UInt32)]); - // How to test SELECT. - mock.add(test::handlers::provide(&metadata, list.clone())); + mock.add(test::handlers::provide(list.clone())); let rows = make_select(&client).await.unwrap(); assert_eq!(rows, list); diff --git a/src/cursors/row.rs b/src/cursors/row.rs index f66269bf..e7ef537d 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,5 +1,4 @@ use crate::row_metadata::RowMetadata; -use crate::validation_mode::ValidationMode; use crate::{ bytes_ext::BytesExt, cursors::RawCursor, @@ -17,24 +16,21 @@ use std::marker::PhantomData; pub struct RowCursor { raw: RawCursor, bytes: BytesExt, + validation: bool, /// [`None`] until the first call to [`RowCursor::next()`], /// as [`RowCursor::new`] is not `async`, so it loads lazily. row_metadata: Option, - rows_to_validate: u64, _marker: PhantomData, } impl RowCursor { - pub(crate) fn new(response: Response, validation_mode: ValidationMode) -> Self { + pub(crate) fn new(response: Response, validation: bool) -> Self { Self { _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), row_metadata: None, - rows_to_validate: match validation_mode { - ValidationMode::First(n) => n as u64, - ValidationMode::Each => u64::MAX, - }, + validation, } } @@ -54,7 +50,9 @@ impl RowCursor { return Ok(()); } Ok(_) => { - // TODO: or panic instead? + // This does not panic, as it could be a network issue + // or a malformed response from the server or LB, + // and a simple retry might help in certain cases. return Err(Error::BadResponse( "Expected at least one column in the header".to_string(), )); @@ -68,6 +66,7 @@ impl RowCursor { match self.raw.next().await? { Some(chunk) => self.bytes.extend(chunk), None if self.row_metadata.is_none() => { + // Similar to the other BadResponse branch above return Err(Error::BadResponse( "Could not read columns header".to_string(), )); @@ -91,35 +90,27 @@ impl RowCursor { { loop { if self.bytes.remaining() > 0 { - if self.row_metadata.is_none() { - self.read_columns().await?; - if self.bytes.remaining() == 0 { - continue; - } - } - let mut slice = super::workaround_51132(self.bytes.slice()); - let (result, not_enough_data) = match self.rows_to_validate { - 0 => rowbinary::deserialize_from::(&mut slice, None), - u64::MAX => { - rowbinary::deserialize_from::(&mut slice, self.row_metadata.as_ref()) - } - _ => { - let result = rowbinary::deserialize_from::( - &mut slice, - self.row_metadata.as_ref(), - ); - self.rows_to_validate -= 1; - result + let mut slice: &[u8]; + let result = if self.validation { + if self.row_metadata.is_none() { + self.read_columns().await?; + if self.bytes.remaining() == 0 { + continue; + } } + slice = super::workaround_51132(self.bytes.slice()); + rowbinary::deserialize_rbwnat::(&mut slice, self.row_metadata.as_ref()) + } else { + slice = super::workaround_51132(self.bytes.slice()); + rowbinary::deserialize_row_binary::(&mut slice) }; - if !not_enough_data { - return match result { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - Ok(Some(value)) - } - Err(err) => Err(err), - }; + match result { + Err(Error::NotEnoughData) => {} + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); + } + Err(err) => return Err(err), } } diff --git a/src/lib.rs b/src/lib.rs index d3760a89..b4d7ebbe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,8 @@ #[macro_use] extern crate static_assertions; -#[cfg(feature = "test-util")] -pub use self::row_metadata::RowMetadata; pub use self::{compression::Compression, row::Row, row::RowKind}; -use self::{error::Result, http_client::HttpClient, validation_mode::ValidationMode}; +use self::{error::Result, http_client::HttpClient}; pub use clickhouse_derive::Row; use std::{collections::HashMap, fmt::Display, sync::Arc}; @@ -21,7 +19,6 @@ pub mod serde; pub mod sql; #[cfg(feature = "test-util")] pub mod test; -pub mod validation_mode; #[cfg(feature = "watch")] pub mod watch; @@ -50,7 +47,7 @@ pub struct Client { options: HashMap, headers: HashMap, products_info: Vec, - validation_mode: ValidationMode, + validation: bool, } #[derive(Clone)] @@ -105,7 +102,7 @@ impl Client { options: HashMap::new(), headers: HashMap::new(), products_info: Vec::default(), - validation_mode: ValidationMode::default(), + validation: true, } } @@ -299,15 +296,6 @@ impl Client { self } - /// Specifies the struct validation mode that will be used when calling - /// [`query::Query::fetch`], [`query::Query::fetch_one`], [`query::Query::fetch_all`], - /// and [`query::Query::fetch_optional`] methods. - /// See [`ValidationMode`] for more details. - pub fn with_validation_mode(mut self, mode: ValidationMode) -> Self { - self.validation_mode = mode; - self - } - /// Starts a new INSERT statement. /// /// # Panics @@ -336,6 +324,22 @@ impl Client { watch::Watch::new(self, query) } + /// Disables [`Row`] types validation against the database schema. + /// Validation is enabled by default. + /// + /// # Warning + /// + /// While disabled validation will result in increased performance, + /// this mode is intended to be used for testing purposes only, + /// and only in scenarios where schema mismatch issues are irrelevant. + /// + /// ***DO NOT*** disable validation in your production code or tests + /// unless you are 100% sure why you are doing it. + pub fn with_disabled_validation(mut self) -> Self { + self.validation = false; + self + } + /// Used internally to modify the options map of an _already cloned_ /// [`Client`] instance. pub(crate) fn add_option(&mut self, name: impl Into, value: impl Into) { @@ -355,7 +359,6 @@ pub mod _priv { #[cfg(test)] mod client_tests { - use crate::validation_mode::ValidationMode; use crate::{Authentication, Client}; #[test] @@ -477,10 +480,10 @@ mod client_tests { #[test] fn it_sets_validation_mode() { let client = Client::default(); - assert_eq!(client.validation_mode, ValidationMode::First(1)); - let client = client.with_validation_mode(ValidationMode::Each); - assert_eq!(client.validation_mode, ValidationMode::Each); - let client = client.with_validation_mode(ValidationMode::First(10)); - assert_eq!(client.validation_mode, ValidationMode::First(10)); + assert!(client.validation); + let client = client.with_disabled_validation(); + assert!(!client.validation); + let client = client.with_disabled_validation(); + assert!(!client.validation); } } diff --git a/src/query.rs b/src/query.rs index 2a1036fa..9c1ff04f 100644 --- a/src/query.rs +++ b/src/query.rs @@ -84,13 +84,21 @@ impl Query { /// # Ok(()) } /// ``` pub fn fetch(mut self) -> Result> { - let validation_mode = self.client.validation_mode; - self.sql.bind_fields::(); - self.sql.set_output_format("RowBinaryWithNamesAndTypes"); + + let validation = self.client.validation; + if validation { + self.sql.set_output_format("RowBinaryWithNamesAndTypes"); + } else { + self.sql.set_output_format("RowBinary"); + } let response = self.do_execute(true)?; - Ok(RowCursor::new(response, validation_mode)) + + // #[cfg(feature = "test_util")] + // if response.headers + + Ok(RowCursor::new(response, validation)) } /// Executes the query and returns just a single row. diff --git a/src/row_metadata.rs b/src/row_metadata.rs index 67047fc2..6dc52518 100644 --- a/src/row_metadata.rs +++ b/src/row_metadata.rs @@ -20,7 +20,7 @@ type LockedRowMetadataCache = RwLock>>; static ROW_METADATA_CACHE: OnceCell = OnceCell::const_new(); #[derive(Debug, PartialEq)] -enum AccessType { +pub(crate) enum AccessType { WithSeqAccess, WithMapAccess(Vec), } @@ -28,7 +28,7 @@ enum AccessType { /// [`RowMetadata`] should be owned outside the (de)serializer, /// as it is calculated only once per struct. It does not have lifetimes, /// so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. -pub struct RowMetadata { +pub(crate) struct RowMetadata { /// See [`Row::NAME`] pub(crate) name: &'static str, /// See [`Row::TYPE`] @@ -43,8 +43,7 @@ pub struct RowMetadata { } impl RowMetadata { - // FIXME: perhaps it should not be public? But it is required for mocks/provide. - pub fn new(columns: Vec) -> Self { + pub(crate) fn new(columns: Vec) -> Self { let access_type = match T::KIND { RowKind::Primitive => { if columns.len() != 1 { diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 177d1854..3fc319bd 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -24,24 +24,20 @@ use std::{convert::TryFrom, str}; /// It expects a slice of [`Column`] objects parsed /// from the beginning of `RowBinaryWithNamesAndTypes` data stream. /// After the header, the rows format is the same as `RowBinary`. -pub(crate) fn deserialize_from<'data, 'cursor, T: Deserialize<'data>>( +pub(crate) fn deserialize_row_binary<'data, 'cursor, T: Deserialize<'data>>( + input: &mut &'data [u8], +) -> Result { + let mut deserializer = RowBinaryDeserializer::new(input, ()); + T::deserialize(&mut deserializer) +} + +pub(crate) fn deserialize_rbwnat<'data, 'cursor, T: Deserialize<'data>>( input: &mut &'data [u8], metadata: Option<&'cursor RowMetadata>, -) -> (Result, bool) { - let result = if metadata.is_none() { - let mut deserializer = RowBinaryDeserializer::new(input, ()); - T::deserialize(&mut deserializer) - } else { - let validator = DataTypeValidator::new(metadata.unwrap()); - let mut deserializer = RowBinaryDeserializer::new(input, validator); - T::deserialize(&mut deserializer) - }; - // an explicit hint about NotEnoughData error boosts RowCursor performance ~20% - match result { - Ok(value) => (Ok(value), false), - Err(Error::NotEnoughData) => (Err(Error::NotEnoughData), true), - Err(e) => (Err(e), false), - } +) -> Result { + let validator = DataTypeValidator::new(metadata.unwrap()); + let mut deserializer = RowBinaryDeserializer::new(input, validator); + T::deserialize(&mut deserializer) } /// A deserializer for the `RowBinary(WithNamesAndTypes)` format. @@ -85,7 +81,9 @@ macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => { #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { - self.validator.validate($serde_type); + if Validator::VALIDATION { + self.validator.validate($serde_type); + } ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -93,54 +91,49 @@ macro_rules! impl_num { }; } +macro_rules! impl_num_or_enum { + ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => { + #[inline(always)] + fn $deser_method>(self, visitor: V) -> Result { + if Validator::VALIDATION { + let mut maybe_enum_validator = self.validator.validate($serde_type); + ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; + let value = self.input.$reader_method(); + maybe_enum_validator.validate_identifier::<$ty>(value); + visitor.$visitor_method(value) + } else { + ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; + visitor.$visitor_method(self.input.$reader_method()) + } + } + }; +} + impl<'data, Validator> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data, Validator> where Validator: SchemaValidator, { type Error = Error; - #[inline(always)] - fn deserialize_i8>(self, visitor: V) -> Result { - let mut maybe_enum_validator = self.validator.validate(SerdeType::I8); - ensure_size(&mut self.input, size_of::())?; - let value = self.input.get_i8(); - maybe_enum_validator.validate_enum8_value(value); - visitor.visit_i8(value) - } - - #[inline(always)] - fn deserialize_i16>(self, visitor: V) -> Result { - let mut maybe_enum_validator = self.validator.validate(SerdeType::I16); - ensure_size(&mut self.input, size_of::())?; - let value = self.input.get_i16_le(); - // TODO: is there a better way to validate that the deserialized value matches the schema? - maybe_enum_validator.validate_enum16_value(value); - visitor.visit_i16(value) - } + impl_num_or_enum!(i8, deserialize_i8, visit_i8, get_i8, SerdeType::I8); + impl_num_or_enum!(i16, deserialize_i16, visit_i16, get_i16_le, SerdeType::I16); impl_num!(i32, deserialize_i32, visit_i32, get_i32_le, SerdeType::I32); impl_num!(i64, deserialize_i64, visit_i64, get_i64_le, SerdeType::I64); - impl_num!( - i128, - deserialize_i128, - visit_i128, - get_i128_le, - SerdeType::I128 - ); + impl_num!(u8, deserialize_u8, visit_u8, get_u8, SerdeType::U8); impl_num!(u16, deserialize_u16, visit_u16, get_u16_le, SerdeType::U16); impl_num!(u32, deserialize_u32, visit_u32, get_u32_le, SerdeType::U32); impl_num!(u64, deserialize_u64, visit_u64, get_u64_le, SerdeType::U64); - impl_num!( - u128, - deserialize_u128, - visit_u128, - get_u128_le, - SerdeType::U128 - ); + impl_num!(f32, deserialize_f32, visit_f32, get_f32_le, SerdeType::F32); impl_num!(f64, deserialize_f64, visit_f64, get_f64_le, SerdeType::F64); + #[rustfmt::skip] + impl_num!(i128, deserialize_i128, visit_i128, get_i128_le, SerdeType::I128); + #[rustfmt::skip] + impl_num!(u128, deserialize_u128, visit_u128, get_u128_le, SerdeType::U128); + #[inline(always)] fn deserialize_any>(self, _: V) -> Result { Err(Error::DeserializeAnyNotSupported) @@ -155,7 +148,9 @@ where #[inline(always)] fn deserialize_bool>(self, visitor: V) -> Result { - self.validator.validate(SerdeType::Bool); + if Validator::VALIDATION { + self.validator.validate(SerdeType::Bool); + } ensure_size(&mut self.input, 1)?; match self.input.get_u8() { 0 => visitor.visit_bool(false), @@ -166,7 +161,9 @@ where #[inline(always)] fn deserialize_str>(self, visitor: V) -> Result { - self.validator.validate(SerdeType::Str); + if Validator::VALIDATION { + self.validator.validate(SerdeType::Str); + } let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -175,7 +172,9 @@ where #[inline(always)] fn deserialize_string>(self, visitor: V) -> Result { - self.validator.validate(SerdeType::String); + if Validator::VALIDATION { + self.validator.validate(SerdeType::String); + } let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -185,7 +184,9 @@ where #[inline(always)] fn deserialize_bytes>(self, visitor: V) -> Result { let size = self.read_size()?; - self.validator.validate(SerdeType::Bytes(size)); + if Validator::VALIDATION { + self.validator.validate(SerdeType::Bytes(size)); + } let slice = self.read_slice(size)?; visitor.visit_borrowed_bytes(slice) } @@ -193,16 +194,25 @@ where #[inline(always)] fn deserialize_byte_buf>(self, visitor: V) -> Result { let size = self.read_size()?; - self.validator.validate(SerdeType::ByteBuf(size)); + if Validator::VALIDATION { + self.validator.validate(SerdeType::ByteBuf(size)); + } visitor.visit_byte_buf(self.read_vec(size)?) } + /// This is used to deserialize identifiers for either: + /// - `Variant` data type + /// - [`RowBinaryStructAsMapAccess`] field. #[inline(always)] fn deserialize_identifier>(self, visitor: V) -> Result { ensure_size(&mut self.input, size_of::())?; let value = self.input.get_u8(); // TODO: is there a better way to validate that the deserialized value matches the schema? - self.validator.set_next_variant_value(value); + if Validator::VALIDATION { + // TODO: theoretically, we can track if we are currently processing a struct field id, + // and don't call the validator in that case, cause it will never be a `Variant`. + self.validator.validate_identifier::(value); + } visitor.visit_u8(value) } @@ -213,67 +223,98 @@ where _variants: &'static [&'static str], visitor: V, ) -> Result { - let validator = self.validator.validate(SerdeType::Enum); - visitor.visit_enum(RowBinaryEnumAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator, - }, - }) + if Validator::VALIDATION { + visitor.visit_enum(RowBinaryEnumAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(SerdeType::Enum), + }, + }) + } else { + visitor.visit_enum(RowBinaryEnumAccess { deserializer: self }) + } } #[inline(always)] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - let validator = self.validator.validate(SerdeType::Tuple(len)); - let mut de = RowBinaryDeserializer { - input: self.input, - validator, - }; - let access = RowBinarySeqAccess { - deserializer: &mut de, - len, - }; - visitor.visit_seq(access) + if Validator::VALIDATION { + visitor.visit_seq(RowBinarySeqAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(SerdeType::Tuple(len)), + }, + len, + }) + } else { + visitor.visit_seq(RowBinarySeqAccess { + deserializer: self, + len, + }) + } } #[inline(always)] fn deserialize_option>(self, visitor: V) -> Result { ensure_size(&mut self.input, 1)?; - let inner_validator = self.validator.validate(SerdeType::Option); - match self.input.get_u8() { - 0 => visitor.visit_some(&mut RowBinaryDeserializer { - input: self.input, - validator: inner_validator, - }), - 1 => visitor.visit_none(), - v => Err(Error::InvalidTagEncoding(v as usize)), + let is_null = self.input.get_u8(); + if Validator::VALIDATION { + let inner_validator = self.validator.validate(SerdeType::Option); + match is_null { + 0 => visitor.visit_some(&mut RowBinaryDeserializer { + input: self.input, + validator: inner_validator, + }), + 1 => visitor.visit_none(), + v => Err(Error::InvalidTagEncoding(v as usize)), + } + } else { + // a bit of copy-paste here, since Deserializer types are not exactly the same + match is_null { + 0 => visitor.visit_some(self), + 1 => visitor.visit_none(), + v => Err(Error::InvalidTagEncoding(v as usize)), + } } } #[inline(always)] fn deserialize_seq>(self, visitor: V) -> Result { let len = self.read_size()?; - visitor.visit_seq(RowBinarySeqAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: self.validator.validate(SerdeType::Seq(len)), - }, - len, - }) + if Validator::VALIDATION { + visitor.visit_seq(RowBinarySeqAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(SerdeType::Seq(len)), + }, + len, + }) + } else { + visitor.visit_seq(RowBinarySeqAccess { + deserializer: self, + len, + }) + } } #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { let len = self.read_size()?; - let validator = self.validator.validate(SerdeType::Map(len)); - visitor.visit_map(RowBinaryMapAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator, - }, - entries_visited: 0, - len, - }) + if Validator::VALIDATION { + visitor.visit_map(RowBinaryMapAccess { + deserializer: &mut RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(SerdeType::Map(len)), + }, + entries_visited: 0, + len, + }) + } else { + visitor.visit_map(RowBinaryMapAccess { + deserializer: self, + entries_visited: 0, + len, + }) + } } #[inline(always)] @@ -283,17 +324,25 @@ where fields: &'static [&'static str], visitor: V, ) -> Result { - if !self.validator.is_field_order_wrong() { + if Validator::VALIDATION { + if !self.validator.is_field_order_wrong() { + visitor.visit_seq(RowBinarySeqAccess { + deserializer: self, + len: fields.len(), + }) + } else { + visitor.visit_map(RowBinaryStructAsMapAccess { + deserializer: self, + current_field_idx: 0, + fields, + }) + } + } else { + // We can't detect incorrect field order with just plain `RowBinary` format visitor.visit_seq(RowBinarySeqAccess { deserializer: self, len: fields.len(), }) - } else { - visitor.visit_map(RowBinaryStructAsMapAccess { - deserializer: self, - current_field_idx: 0, - fields, - }) } } diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index a465a2cc..b25147e2 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,4 +1,5 @@ -pub(crate) use de::deserialize_from; +pub(crate) use de::deserialize_rbwnat; +pub(crate) use de::deserialize_row_binary; pub(crate) use ser::serialize_into; pub(crate) mod validation; diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index 44fd7d62..f11c0850 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -122,12 +122,10 @@ fn it_deserializes() { let (mut left, mut right) = input.split_at(i); // It shouldn't panic. - let _: Result, _> = super::deserialize_from(&mut left, None).0; - let _: Result, _> = super::deserialize_from(&mut right, None).0; + let _: Result, _> = super::deserialize_row_binary(&mut left); + let _: Result, _> = super::deserialize_row_binary(&mut right); - let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice(), None) - .0 - .unwrap(); + let actual: Sample<'_> = super::deserialize_row_binary(&mut input.as_slice()).unwrap(); assert_eq!(actual, sample()); } } diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index 4e2c0ce9..a179700f 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -4,13 +4,29 @@ use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; +/// This trait is used to validate the schema of a [`crate::Row`] against the parsed RBWNAT schema. +/// Note that [`SchemaValidator`] is also implemented for `()`, +/// which is used to skip validation if the user disabled it. pub(crate) trait SchemaValidator: Sized { + /// Ensures that the branching is completely optimized out based on the validation settings. + const VALIDATION: bool; + /// The main entry point. The validation flow based on the [`crate::Row::KIND`]. + /// For container types (nullable, array, map, tuple, variant, etc.), + /// it will return an [`InnerDataTypeValidator`] instance (see [`InnerDataTypeValidatorKind`]), + /// which has its own implementation of this method, allowing recursive validation. fn validate(&'_ mut self, serde_type: SerdeType) -> Option>; - fn validate_enum8_value(&mut self, value: i8); - fn validate_enum16_value(&mut self, value: i16); - fn set_next_variant_value(&mut self, value: u8); - fn get_schema_index(&self, struct_idx: usize) -> usize; + /// Validates that an identifier exists in the values map for enums, + /// or stores the variant identifier for the next serde call. + fn validate_identifier(&mut self, value: T); + /// Having the database schema from RBWNAT, the crate can detect that + /// while the field names and the types are correct, the field order in the struct + /// does not match the column order in the database schema, and we should use + /// `MapAccess` instead of `SeqAccess` to seamlessly deserialize the struct. fn is_field_order_wrong(&self) -> bool; + /// Returns the "restored" index of the schema column for the given struct field index. + /// It is used only if the crate detects that while the field names and the types are correct, + /// the field order in the struct does not match the column order in the database schema. + fn get_schema_index(&self, struct_idx: usize) -> usize; } pub(crate) struct DataTypeValidator<'cursor> { @@ -94,6 +110,8 @@ impl<'cursor> DataTypeValidator<'cursor> { } impl SchemaValidator for DataTypeValidator<'_> { + const VALIDATION: bool = true; + #[inline] fn validate(&'_ mut self, serde_type: SerdeType) -> Option> { match self.metadata.kind { @@ -166,20 +184,7 @@ impl SchemaValidator for DataTypeValidator<'_> { } #[cold] - #[inline(never)] - fn validate_enum8_value(&mut self, _value: i8) { - unreachable!() - } - - #[cold] - #[inline(never)] - fn validate_enum16_value(&mut self, _value: i16) { - unreachable!() - } - - #[cold] - #[inline(never)] - fn set_next_variant_value(&mut self, _value: u8) { + fn validate_identifier(&mut self, _value: T) { unreachable!() } } @@ -234,6 +239,8 @@ pub(crate) enum VariantValidationState { } impl<'de, 'cursor> SchemaValidator for Option> { + const VALIDATION: bool = true; + #[inline] fn validate(&mut self, serde_type: SerdeType) -> Option> { match self { @@ -342,49 +349,37 @@ impl<'de, 'cursor> SchemaValidator for Option(&mut self, value: T) { + use InnerDataTypeValidatorKind::{Enum, Variant}; if let Some(inner) = self { - if let InnerDataTypeValidatorKind::Enum(values_map) = &inner.kind { - if !values_map.contains_key(&value) { - let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); - panic!( - "While processing column {full_name} defined as {full_data_type}: \ - Enum16 value {value} is not present in the database schema" - ); + match T::IDENTIFIER_TYPE { + IdentifierType::Enum8 | IdentifierType::Enum16 => { + if let Enum(values_map) = &inner.kind { + if !values_map.contains_key(&(value.into_i16())) { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {full_name} defined as {full_data_type}: \ + Enum8 value {value} is not present in the database schema" + ); + } + } } - } - } - } - - #[inline(always)] - fn set_next_variant_value(&mut self, value: u8) { - if let Some(inner) = self { - if let InnerDataTypeValidatorKind::Variant(possible_types, state) = &mut inner.kind { - if (value as usize) < possible_types.len() { - *state = VariantValidationState::Identifier(value); - } else { - let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); - panic!( - "While processing column {full_name} defined as {full_data_type}: \ - Variant identifier {value} is out of bounds, max allowed index is {}", - possible_types.len() - 1 - ); + IdentifierType::Variant => { + if let Variant(possible_types, state) = &mut inner.kind { + // ClickHouse guarantees max 255 variants, i.e. the same max value as u8 + if value.into_u8() < (possible_types.len() as u8) { + *state = VariantValidationState::Identifier(value.into_u8()); + } else { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {full_name} defined as {full_data_type}: \ + Variant identifier {value} is out of bounds, max allowed index is {}", + possible_types.len() - 1 + ); + } + } } } } @@ -395,6 +390,7 @@ impl<'de, 'cursor> SchemaValidator for Option usize { unreachable!() } @@ -421,89 +417,6 @@ impl Drop for InnerDataTypeValidator<'_, '_> { } } -// #[inline] -// fn simple_types_impl<'de, 'cursor>( -// root: &'de DataTypeValidator<'cursor>, -// data_type: &'cursor DataTypeNode, -// serde_type: &SerdeType, -// is_inner: bool, -// ) { -// match serde_type { -// SerdeType::Bool -// if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => -// { -// None -// } -// SerdeType::I8 => match data_type { -// DataTypeNode::Int8 => None, -// DataTypeNode::Enum(EnumType::Enum8, values_map) => Some(InnerDataTypeValidator { -// root, -// kind: InnerDataTypeValidatorKind::Enum(values_map), -// })), -// _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), -// }, -// SerdeType::I16 => match data_type { -// DataTypeNode::Int16 => None, -// DataTypeNode::Enum(EnumType::Enum16, values_map) => Some(InnerDataTypeValidator { -// root, -// kind: InnerDataTypeValidatorKind::Enum(values_map), -// })), -// _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), -// }, -// SerdeType::I32 -// if data_type == &DataTypeNode::Int32 -// || data_type == &DataTypeNode::Date32 -// || matches!( -// data_type, -// DataTypeNode::Decimal(_, _, DecimalType::Decimal32) -// ) => -// { -// None -// } -// SerdeType::I64 -// if data_type == &DataTypeNode::Int64 -// || matches!(data_type, DataTypeNode::DateTime64(_, _)) -// || matches!( -// data_type, -// DataTypeNode::Decimal(_, _, DecimalType::Decimal64) -// ) => -// { -// None -// } -// SerdeType::I128 -// if data_type == &DataTypeNode::Int128 -// || matches!( -// data_type, -// DataTypeNode::Decimal(_, _, DecimalType::Decimal128) -// ) => -// { -// None -// } -// SerdeType::U8 if data_type == &DataTypeNode::UInt8 => None, -// SerdeType::U16 -// if data_type == &DataTypeNode::UInt16 || data_type == &DataTypeNode::Date => -// { -// None -// } -// SerdeType::U32 -// if data_type == &DataTypeNode::UInt32 -// || matches!(data_type, DataTypeNode::DateTime(_)) -// || data_type == &DataTypeNode::IPv4 => -// { -// None -// } -// SerdeType::U64 if data_type == &DataTypeNode::UInt64 => None, -// SerdeType::U128 if data_type == &DataTypeNode::UInt128 => None, -// SerdeType::F32 if data_type == &DataTypeNode::Float32 => None, -// SerdeType::F64 if data_type == &DataTypeNode::Float64 => None, -// SerdeType::Str | SerdeType::String -// if data_type == &DataTypeNode::String || data_type == &DataTypeNode::JSON => -// { -// None -// } -// } -// } - // TODO: is there a way to eliminate multiple branches with similar patterns? // static/const dispatch? // separate smaller inline functions? @@ -712,25 +625,22 @@ fn validate_impl<'de, 'cursor>( } impl SchemaValidator for () { + const VALIDATION: bool = false; + #[inline(always)] fn validate(&mut self, _serde_type: SerdeType) -> Option> { None } - #[inline(always)] - fn validate_enum8_value(&mut self, _value: i8) {} - - #[inline(always)] - fn validate_enum16_value(&mut self, _value: i16) {} - - #[inline(always)] - fn set_next_variant_value(&mut self, _value: u8) {} - #[inline(always)] fn is_field_order_wrong(&self) -> bool { false } + #[inline(always)] + fn validate_identifier(&mut self, _value: T) {} + + #[cold] fn get_schema_index(&self, _struct_idx: usize) -> usize { unreachable!() } @@ -810,5 +720,53 @@ impl Display for SerdeType { } } +#[derive(Debug)] +pub(crate) enum IdentifierType { + Enum8, + Enum16, + Variant, +} +pub(crate) trait EnumOrVariantIdentifier: Display + Copy { + const IDENTIFIER_TYPE: IdentifierType; + fn into_u8(self) -> u8; + fn into_i16(self) -> i16; +} +impl EnumOrVariantIdentifier for u8 { + const IDENTIFIER_TYPE: IdentifierType = IdentifierType::Variant; + // none of these should be ever called + #[inline(always)] + fn into_u8(self) -> u8 { + self + } + #[inline(always)] + fn into_i16(self) -> i16 { + self as i16 + } +} +impl EnumOrVariantIdentifier for i8 { + const IDENTIFIER_TYPE: IdentifierType = IdentifierType::Enum8; + #[inline(always)] + fn into_i16(self) -> i16 { + self as i16 + } + // we need only i16 for enum values HashMap + #[inline(always)] + fn into_u8(self) -> u8 { + self as u8 + } +} +impl EnumOrVariantIdentifier for i16 { + const IDENTIFIER_TYPE: IdentifierType = IdentifierType::Enum16; + #[inline(always)] + fn into_i16(self) -> i16 { + self + } + // should not be ever called + #[inline(always)] + fn into_u8(self) -> u8 { + self as u8 + } +} + const UUID_TUPLE_ELEMENTS: &[DataTypeNode; 2] = &[DataTypeNode::UInt64, DataTypeNode::UInt64]; const POINT_TUPLE_ELEMENTS: &[DataTypeNode; 2] = &[DataTypeNode::Float64, DataTypeNode::Float64]; diff --git a/src/test/handlers.rs b/src/test/handlers.rs index 1e514110..fd5edc11 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -1,14 +1,12 @@ use std::marker::PhantomData; use bytes::Bytes; -use clickhouse_types::put_rbwnat_columns_header; use futures::channel::oneshot; use hyper::{Request, Response, StatusCode}; use sealed::sealed; use serde::{Deserialize, Serialize}; use super::{Handler, HandlerFn}; -use crate::row_metadata::RowMetadata; use crate::rowbinary; const BUFFER_INITIAL_CAPACITY: usize = 1024; @@ -42,13 +40,11 @@ pub fn failure(status: StatusCode) -> impl Handler { // === provide === #[track_caller] -pub fn provide(row_metadata: &RowMetadata, rows: impl IntoIterator) -> impl Handler +pub fn provide(rows: impl IntoIterator) -> impl Handler where T: Serialize, { let mut buffer = Vec::with_capacity(BUFFER_INITIAL_CAPACITY); - put_rbwnat_columns_header(&row_metadata.columns, &mut buffer) - .expect("failed to write columns header"); for row in rows { rowbinary::serialize_into(&mut buffer, &row).expect("failed to serialize"); } @@ -97,8 +93,8 @@ where let mut result = C::default(); while !slice.is_empty() { - let (de_result, _) = rowbinary::deserialize_from(slice, None); - let row: T = de_result.expect("failed to deserialize"); + let res = rowbinary::deserialize_row_binary(slice); + let row: T = res.expect("failed to deserialize"); result.extend(std::iter::once(row)); } diff --git a/src/validation_mode.rs b/src/validation_mode.rs deleted file mode 100644 index a76d7d1e..00000000 --- a/src/validation_mode.rs +++ /dev/null @@ -1,34 +0,0 @@ -/// The preferred mode of validation for struct (de)serialization. -/// It also affects which format is used by the client when sending queries. -/// -/// - [`ValidationMode::First`] enables validation _only for the first `N` rows_ -/// emitted by a cursor. For the following rows, validation is skipped. -/// Format: `RowBinaryWithNamesAndTypes`. -/// - [`ValidationMode::Each`] enables validation _for all rows_ emitted by a cursor. -/// This is the slowest mode. Format: `RowBinaryWithNamesAndTypes`. -/// -/// # Default -/// -/// By default, [`ValidationMode::First`] with value `1` is used, -/// meaning that only the first row will be validated against the database schema, -/// which is extracted from the `RowBinaryWithNamesAndTypes` format header. -/// It is done to minimize the performance impact of the validation, -/// while still providing reasonable safety guarantees by default. -/// -/// While it is expected that the default validation mode is sufficient for most use cases, -/// in certain corner case scenarios there still can be schema mismatches after the first rows, -/// e.g., when a field is `Nullable(T)`, and the first value is `NULL`. In that case, -/// consider increasing the number of rows in [`ValidationMode::First`], -/// or even using [`ValidationMode::Each`] instead. -#[non_exhaustive] -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ValidationMode { - First(usize), - Each, -} - -impl Default for ValidationMode { - fn default() -> Self { - Self::First(1) - } -} diff --git a/tests/it/main.rs b/tests/it/main.rs index f51ccb3e..04153fbd 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -43,7 +43,7 @@ macro_rules! assert_panic_on_fetch_with_client { macro_rules! assert_panic_on_fetch { ($msg_parts:expr, $query:expr) => { use futures::FutureExt; - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let async_panic = std::panic::AssertUnwindSafe(async { client.query($query).fetch_all::().await }); let result = async_panic.catch_unwind().await; diff --git a/tests/it/mock.rs b/tests/it/mock.rs index 2f3a5659..a6c72410 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -2,21 +2,16 @@ use crate::SimpleRow; use clickhouse::{test, Client}; -use clickhouse_types::data_types::Column; -use clickhouse_types::DataTypeNode; use std::time::Duration; async fn test_provide() { let mock = test::Mock::new(); - let client = Client::default().with_url(mock.url()); + let client = Client::default() + .with_url(mock.url()) + .with_disabled_validation(); let expected = vec![SimpleRow::new(1, "one"), SimpleRow::new(2, "two")]; - let columns = vec![ - Column::new("id".to_string(), DataTypeNode::UInt64), - Column::new("data".to_string(), DataTypeNode::String), - ]; - let metadata = clickhouse::RowMetadata::new::(columns); - mock.add(test::handlers::provide(&metadata, &expected)); + mock.add(test::handlers::provide(&expected)); let actual = crate::fetch_rows::(&client, "doesn't matter").await; assert_eq!(actual, expected); diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index fc46ea8e..21c0081a 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -1,6 +1,5 @@ use crate::{execute_statements, get_client}; use clickhouse::sql::Identifier; -use clickhouse::validation_mode::ValidationMode; use clickhouse_derive::Row; use clickhouse_types::data_types::{Column, DataTypeNode}; use clickhouse_types::parse_rbwnat_columns_header; @@ -105,7 +104,7 @@ async fn test_header_parsing() { #[tokio::test] async fn test_fetch_primitive_row() { - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT count() FROM (SELECT * FROM system.numbers LIMIT 3)") .fetch_one::() @@ -124,7 +123,7 @@ async fn test_fetch_primitive_row_schema_mismatch() { #[tokio::test] async fn test_fetch_vector_row() { - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT [1, 2, 3] :: Array(UInt32)") .fetch_one::>() @@ -143,7 +142,7 @@ async fn test_fetch_vector_row_schema_mismatch_nested_type() { #[tokio::test] async fn test_fetch_tuple_row() { - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b") .fetch_one::<(u32, String)>() @@ -201,7 +200,7 @@ async fn test_fetch_tuple_row_with_struct() { b: String, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c") .fetch_one::<(Data, u64)>() @@ -307,7 +306,7 @@ async fn test_basic_types() { string_val: String, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -368,7 +367,7 @@ async fn test_borrowed_data() { hash_map_nested: HashMap<&'a str, HashMap<&'a str, &'a str>>, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let mut cursor = client .query( " @@ -454,7 +453,7 @@ async fn test_several_simple_rows() { str: String, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT number AS num, toString(number) AS str FROM system.numbers LIMIT 3") .fetch_all::() @@ -486,7 +485,7 @@ async fn test_many_numbers() { number: u64, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let mut cursor = client .query("SELECT number FROM system.numbers_mt LIMIT 2000") .fetch::() @@ -507,7 +506,7 @@ async fn test_blob_string_with_serde_bytes() { blob: Vec, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT 'foo' :: String AS blob") .fetch_one::() @@ -532,7 +531,7 @@ async fn test_arrays() { description: String, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -571,7 +570,7 @@ async fn test_maps() { m2: HashMap>, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -621,7 +620,7 @@ async fn test_map_as_vec_of_tuples() { m2: Vec<(u16, Vec<(String, i32)>)>, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -719,7 +718,7 @@ async fn test_enum() { let table_name = "test_rbwnat_enum"; - let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let client = prepare_database!(); client .query( " @@ -783,7 +782,7 @@ async fn test_nullable() { b: Option, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -830,7 +829,7 @@ async fn test_low_cardinality() { b: Option, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -864,9 +863,7 @@ async fn test_invalid_low_cardinality() { struct Data { a: u32, } - let client = get_client() - .with_validation_mode(ValidationMode::Each) - .with_option("allow_suspicious_low_cardinality_types", "1"); + let client = get_client().with_option("allow_suspicious_low_cardinality_types", "1"); assert_panic_on_fetch_with_client!( client, &["Data.a", "LowCardinality(Int32)", "u32"], @@ -880,9 +877,7 @@ async fn test_invalid_nullable_low_cardinality() { struct Data { a: Option, } - let client = get_client() - .with_validation_mode(ValidationMode::Each) - .with_option("allow_suspicious_low_cardinality_types", "1"); + let client = get_client().with_option("allow_suspicious_low_cardinality_types", "1"); assert_panic_on_fetch_with_client!( client, &["Data.a", "LowCardinality(Nullable(Int32))", "u32"], @@ -925,7 +920,7 @@ async fn test_serde_skip_deserializing() { c: u32, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS c") .fetch_one::() @@ -966,7 +961,7 @@ async fn test_date_and_time() { date_time64_9: OffsetDateTime, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1014,7 +1009,7 @@ async fn test_uuid() { uuid: uuid::Uuid, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1045,7 +1040,7 @@ async fn test_ipv4_ipv6() { ipv6: std::net::Ipv6Addr, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1076,7 +1071,7 @@ async fn test_fixed_str() { b: [u8; 3], } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT '1234' :: FixedString(4) AS a, '777' :: FixedString(3) AS b") .fetch_one::() @@ -1108,7 +1103,7 @@ async fn test_tuple() { b: (i128, HashMap), } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1223,7 +1218,7 @@ async fn test_geo() { multi_line_string: MultiLineString, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1325,7 +1320,7 @@ async fn test_issue_109_1() { drone_id: String, call_sign: String, } - let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let client = prepare_database!(); execute_statements( &client, &[ @@ -1390,7 +1385,7 @@ async fn test_issue_113() { b: f64, c: f64, } - let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let client = prepare_database!(); execute_statements(&client, &[ " CREATE TABLE issue_113_1( @@ -1436,7 +1431,7 @@ async fn test_issue_114() { arr: Vec>, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1471,9 +1466,7 @@ async fn test_issue_173() { ts: time::OffsetDateTime, } - let client = prepare_database!() - .with_validation_mode(ValidationMode::Each) - .with_option("date_time_input_format", "best_effort"); + let client = prepare_database!().with_option("date_time_input_format", "best_effort"); execute_statements(&client, &[ " @@ -1506,7 +1499,7 @@ async fn test_issue_185() { decimal_col: Option, } - let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let client = prepare_database!(); execute_statements( &client, &[ @@ -1537,7 +1530,7 @@ async fn test_issue_218() { max_time: chrono::DateTime, } - let client = prepare_database!().with_validation_mode(ValidationMode::Each); + let client = prepare_database!(); execute_statements( &client, &[" @@ -1573,9 +1566,7 @@ async fn test_variant_wrong_definition() { var: MyVariant, } - let client = get_client() - .with_validation_mode(ValidationMode::Each) - .with_option("allow_experimental_variant_type", "1"); + let client = get_client().with_option("allow_experimental_variant_type", "1"); assert_panic_on_fetch_with_client!( client, @@ -1599,7 +1590,7 @@ async fn test_decimals() { decimal128_38_12: Decimal128, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " @@ -1670,7 +1661,7 @@ async fn test_different_struct_field_order_same_types() { a: String, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query("SELECT 'foo' AS a, 'bar' :: String AS c") .fetch_one::() @@ -1694,7 +1685,7 @@ async fn test_different_struct_field_order_different_types() { c: Vec, } - let client = get_client().with_validation_mode(ValidationMode::Each); + let client = get_client(); let result = client .query( " diff --git a/tests/it/variant.rs b/tests/it/variant.rs index d5f9dae2..1343905e 100644 --- a/tests/it/variant.rs +++ b/tests/it/variant.rs @@ -3,13 +3,12 @@ use serde::{Deserialize, Serialize}; use time::Month::January; -use clickhouse::validation_mode::ValidationMode::Each; use clickhouse::Row; // See also: https://clickhouse.com/docs/en/sql-reference/data-types/variant #[tokio::test] async fn variant_data_type() { - let client = prepare_database!().with_validation_mode(Each); + let client = prepare_database!(); // NB: Inner Variant types are _always_ sorted alphabetically, // and should be defined in _exactly_ the same order in the enum. From 1544b7b9844831e01600a4d83595af2bc8facc8a Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Wed, 11 Jun 2025 20:41:54 +0200 Subject: [PATCH 43/54] Make validation slightly faster again --- Cargo.toml | 1 + src/row_metadata.rs | 12 +--- src/rowbinary/de.rs | 106 +++++++++++++++++++++--------------- src/rowbinary/tests.rs | 29 ++++++++++ src/rowbinary/validation.rs | 52 ++++++++++-------- src/test/handlers.rs | 4 +- 6 files changed, 125 insertions(+), 79 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4ab86d98..019d9b37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -147,6 +147,7 @@ quanta = { version = "0.12", optional = true } replace_with = { version = "0.1.7" } [dev-dependencies] +clickhouse-derive = { version = "0.2.0", path = "derive" } criterion = "0.5.0" serde = { version = "1.0.106", features = ["derive"] } tokio = { version = "1.0.1", features = ["full", "test-util"] } diff --git a/src/row_metadata.rs b/src/row_metadata.rs index 6dc52518..dbf5dbad 100644 --- a/src/row_metadata.rs +++ b/src/row_metadata.rs @@ -29,10 +29,6 @@ pub(crate) enum AccessType { /// as it is calculated only once per struct. It does not have lifetimes, /// so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. pub(crate) struct RowMetadata { - /// See [`Row::NAME`] - pub(crate) name: &'static str, - /// See [`Row::TYPE`] - pub(crate) kind: RowKind, /// Database schema, or columns, are parsed before the first call to (de)serializer. pub(crate) columns: Vec, /// This determines whether we can just use [`crate::rowbinary::de::RowBinarySeqAccess`] @@ -128,8 +124,6 @@ impl RowMetadata { Self { columns, access_type, - kind: T::KIND, - name: T::NAME, } } @@ -140,10 +134,8 @@ impl RowMetadata { if struct_idx < mapping.len() { mapping[struct_idx] } else { - panic!( - "Struct {} has more fields than columns in the database schema", - self.name - ) + // unreachable + panic!("Struct has more fields than columns in the database schema",) } } AccessType::WithSeqAccess => struct_idx, // should be unreachable diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index 3fc319bd..ea18a724 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -3,6 +3,7 @@ use crate::row_metadata::RowMetadata; use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; use crate::rowbinary::validation::SerdeType; use crate::rowbinary::validation::{DataTypeValidator, SchemaValidator}; +use crate::Row; use bytes::Buf; use core::mem::size_of; use serde::de::MapAccess; @@ -10,9 +11,11 @@ use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, Deserialize, }; +use std::marker::PhantomData; use std::{convert::TryFrom, str}; -/// Deserializes a value from `input` with a row encoded in `RowBinary(WithNamesAndTypes)`. +/// Deserializes a value from `input` with a row encoded in `RowBinary`, +/// i.e. only when [`crate::Row`] validation is disabled in the client. /// /// It accepts _a reference to_ a byte slice because it somehow leads to a more /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even @@ -24,39 +27,47 @@ use std::{convert::TryFrom, str}; /// It expects a slice of [`Column`] objects parsed /// from the beginning of `RowBinaryWithNamesAndTypes` data stream. /// After the header, the rows format is the same as `RowBinary`. -pub(crate) fn deserialize_row_binary<'data, 'cursor, T: Deserialize<'data>>( +pub(crate) fn deserialize_row_binary<'data, 'cursor, T: Deserialize<'data> + Row>( input: &mut &'data [u8], ) -> Result { - let mut deserializer = RowBinaryDeserializer::new(input, ()); + let mut deserializer = RowBinaryDeserializer::::new(input, ()); T::deserialize(&mut deserializer) } -pub(crate) fn deserialize_rbwnat<'data, 'cursor, T: Deserialize<'data>>( +/// Similar to [`deserialize_row_binary`], but uses [`RowMetadata`] +/// parsed from `RowBinaryWithNamesAndTypes` header to validate the data types. +/// This is used when [`crate::Row`] validation is enabled in the client (default). +pub(crate) fn deserialize_rbwnat<'data, 'cursor, T: Deserialize<'data> + Row>( input: &mut &'data [u8], metadata: Option<&'cursor RowMetadata>, ) -> Result { let validator = DataTypeValidator::new(metadata.unwrap()); - let mut deserializer = RowBinaryDeserializer::new(input, validator); + let mut deserializer = RowBinaryDeserializer::::new(input, validator); T::deserialize(&mut deserializer) } /// A deserializer for the `RowBinary(WithNamesAndTypes)` format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. -struct RowBinaryDeserializer<'cursor, 'data, Validator = ()> +struct RowBinaryDeserializer<'cursor, 'data, R: Row, Validator = ()> where - Validator: SchemaValidator, + Validator: SchemaValidator, { validator: Validator, input: &'cursor mut &'data [u8], + _marker: PhantomData, } -impl<'cursor, 'data, Validator> RowBinaryDeserializer<'cursor, 'data, Validator> +impl<'cursor, 'data, R: Row, Validator> RowBinaryDeserializer<'cursor, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { fn new(input: &'cursor mut &'data [u8], validator: Validator) -> Self { - Self { input, validator } + Self { + input, + validator, + _marker: PhantomData::, + } } fn read_vec(&mut self, size: usize) -> Result> { @@ -109,9 +120,10 @@ macro_rules! impl_num_or_enum { }; } -impl<'data, Validator> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data, Validator> +impl<'data, R: Row, Validator> Deserializer<'data> + for &mut RowBinaryDeserializer<'_, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { type Error = Error; @@ -120,20 +132,19 @@ where impl_num!(i32, deserialize_i32, visit_i32, get_i32_le, SerdeType::I32); impl_num!(i64, deserialize_i64, visit_i64, get_i64_le, SerdeType::I64); + #[rustfmt::skip] + impl_num!(i128, deserialize_i128, visit_i128, get_i128_le, SerdeType::I128); impl_num!(u8, deserialize_u8, visit_u8, get_u8, SerdeType::U8); impl_num!(u16, deserialize_u16, visit_u16, get_u16_le, SerdeType::U16); impl_num!(u32, deserialize_u32, visit_u32, get_u32_le, SerdeType::U32); impl_num!(u64, deserialize_u64, visit_u64, get_u64_le, SerdeType::U64); + #[rustfmt::skip] + impl_num!(u128, deserialize_u128, visit_u128, get_u128_le, SerdeType::U128); impl_num!(f32, deserialize_f32, visit_f32, get_f32_le, SerdeType::F32); impl_num!(f64, deserialize_f64, visit_f64, get_f64_le, SerdeType::F64); - #[rustfmt::skip] - impl_num!(i128, deserialize_i128, visit_i128, get_i128_le, SerdeType::I128); - #[rustfmt::skip] - impl_num!(u128, deserialize_u128, visit_u128, get_u128_le, SerdeType::U128); - #[inline(always)] fn deserialize_any>(self, _: V) -> Result { Err(Error::DeserializeAnyNotSupported) @@ -228,6 +239,7 @@ where deserializer: &mut RowBinaryDeserializer { input: self.input, validator: self.validator.validate(SerdeType::Enum), + _marker: PhantomData::, }, }) } else { @@ -242,6 +254,7 @@ where deserializer: &mut RowBinaryDeserializer { input: self.input, validator: self.validator.validate(SerdeType::Tuple(len)), + _marker: PhantomData::, }, len, }) @@ -263,6 +276,7 @@ where 0 => visitor.visit_some(&mut RowBinaryDeserializer { input: self.input, validator: inner_validator, + _marker: PhantomData::, }), 1 => visitor.visit_none(), v => Err(Error::InvalidTagEncoding(v as usize)), @@ -285,6 +299,7 @@ where deserializer: &mut RowBinaryDeserializer { input: self.input, validator: self.validator.validate(SerdeType::Seq(len)), + _marker: PhantomData::, }, len, }) @@ -304,6 +319,7 @@ where deserializer: &mut RowBinaryDeserializer { input: self.input, validator: self.validator.validate(SerdeType::Map(len)), + _marker: PhantomData::, }, entries_visited: 0, len, @@ -394,17 +410,17 @@ where /// Used in [`Deserializer::deserialize_seq`], [`Deserializer::deserialize_tuple`], /// and it could be used in [`Deserializer::deserialize_struct`], /// if we detect that the field order matches the database schema. -struct RowBinarySeqAccess<'de, 'cursor, 'data, Validator> +struct RowBinarySeqAccess<'de, 'cursor, 'data, R: Row, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, len: usize, } -impl<'data, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, Validator> +impl<'data, R: Row, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { type Error = Error; @@ -427,18 +443,18 @@ where } /// Used in [`Deserializer::deserialize_map`]. -struct RowBinaryMapAccess<'de, 'cursor, 'data, Validator> +struct RowBinaryMapAccess<'de, 'cursor, 'data, R: Row, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, entries_visited: usize, len: usize, } -impl<'data, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, Validator> +impl<'data, R: Row, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { type Error = Error; @@ -467,11 +483,11 @@ where /// Used in [`Deserializer::deserialize_struct`] to support wrong struct field order /// as long as the data types and field names are exactly matching the database schema. -struct RowBinaryStructAsMapAccess<'de, 'cursor, 'data, Validator> +struct RowBinaryStructAsMapAccess<'de, 'cursor, 'data, R: Row, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, current_field_idx: usize, fields: &'static [&'static str], } @@ -520,9 +536,10 @@ impl<'de> Deserializer<'de> for StructFieldIdentifier { /// /// If we just use [`RowBinarySeqAccess`] here, `c` will be deserialized into the `a` field, /// and `a` will be deserialized into the `c` field, which is a classic case of data corruption. -impl<'data, Validator> MapAccess<'data> for RowBinaryStructAsMapAccess<'_, '_, 'data, Validator> +impl<'data, R: Row, Validator> MapAccess<'data> + for RowBinaryStructAsMapAccess<'_, '_, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { type Error = Error; @@ -555,23 +572,24 @@ where } /// Used in [`Deserializer::deserialize_enum`]. -struct RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> +struct RowBinaryEnumAccess<'de, 'cursor, 'data, R: Row, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, } -struct VariantDeserializer<'de, 'cursor, 'data, Validator> +struct VariantDeserializer<'de, 'cursor, 'data, R: Row, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { - deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, Validator>, + deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, } -impl<'data, Validator> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data, Validator> +impl<'data, R: Row, Validator> VariantAccess<'data> + for VariantDeserializer<'_, '_, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { type Error = Error; @@ -601,13 +619,13 @@ where } } -impl<'de, 'cursor, 'data, Validator> EnumAccess<'data> - for RowBinaryEnumAccess<'de, 'cursor, 'data, Validator> +impl<'de, 'cursor, 'data, R: Row, Validator> EnumAccess<'data> + for RowBinaryEnumAccess<'de, 'cursor, 'data, R, Validator> where - Validator: SchemaValidator, + Validator: SchemaValidator, { type Error = Error; - type Variant = VariantDeserializer<'de, 'cursor, 'data, Validator>; + type Variant = VariantDeserializer<'de, 'cursor, 'data, R, Validator>; fn variant_seed(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error> where diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index f11c0850..3e22cde1 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -1,3 +1,4 @@ +use crate::Row; use serde::{Deserialize, Serialize}; #[derive(Debug, PartialEq, Serialize, Deserialize)] @@ -36,6 +37,34 @@ struct Sample<'a> { boolean: bool, } +// clickhouse_derive is not working here +impl Row for Sample<'_> { + const NAME: &'static str = "Sample"; + const COLUMN_NAMES: &'static [&'static str] = &[ + "int8", + "int32", + "int64", + "uint8", + "uint32", + "uint64", + "float32", + "float64", + "datetime", + "datetime64", + "decimal64", + "decimal128", + "string", + "blob", + "optional_decimal64", + "optional_datetime", + "fixed_string", + "array", + "boolean", + ]; + const COLUMN_COUNT: usize = 19; + const KIND: crate::RowKind = crate::RowKind::Struct; +} + fn sample() -> Sample<'static> { Sample { int8: -42, diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index a179700f..c5b2a56a 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -1,20 +1,21 @@ use crate::row_metadata::RowMetadata; -use crate::RowKind; +use crate::{Row, RowKind}; use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType}; use std::collections::HashMap; use std::fmt::Display; +use std::marker::PhantomData; /// This trait is used to validate the schema of a [`crate::Row`] against the parsed RBWNAT schema. /// Note that [`SchemaValidator`] is also implemented for `()`, /// which is used to skip validation if the user disabled it. -pub(crate) trait SchemaValidator: Sized { +pub(crate) trait SchemaValidator: Sized { /// Ensures that the branching is completely optimized out based on the validation settings. const VALIDATION: bool; /// The main entry point. The validation flow based on the [`crate::Row::KIND`]. /// For container types (nullable, array, map, tuple, variant, etc.), /// it will return an [`InnerDataTypeValidator`] instance (see [`InnerDataTypeValidatorKind`]), /// which has its own implementation of this method, allowing recursive validation. - fn validate(&'_ mut self, serde_type: SerdeType) -> Option>; + fn validate(&'_ mut self, serde_type: SerdeType) -> Option>; /// Validates that an identifier exists in the values map for enums, /// or stores the variant identifier for the next serde call. fn validate_identifier(&mut self, value: T); @@ -29,16 +30,18 @@ pub(crate) trait SchemaValidator: Sized { fn get_schema_index(&self, struct_idx: usize) -> usize; } -pub(crate) struct DataTypeValidator<'cursor> { +pub(crate) struct DataTypeValidator<'cursor, R: Row> { metadata: &'cursor RowMetadata, current_column_idx: usize, + _marker: PhantomData, } -impl<'cursor> DataTypeValidator<'cursor> { +impl<'cursor, R: Row> DataTypeValidator<'cursor, R> { pub(crate) fn new(metadata: &'cursor RowMetadata) -> Self { Self { - current_column_idx: 0, metadata, + current_column_idx: 0, + _marker: PhantomData::, } } @@ -54,7 +57,7 @@ impl<'cursor> DataTypeValidator<'cursor> { fn get_current_column_name_and_type(&self) -> (String, &DataTypeNode) { self.get_current_column() - .map(|c| (format!("{}.{}", self.metadata.name, c.name), &c.data_type)) + .map(|c| (format!("{}.{}", R::NAME, c.name), &c.data_type)) // both should be defined at this point .unwrap_or(("Struct".to_string(), &DataTypeNode::Bool)) } @@ -64,8 +67,8 @@ impl<'cursor> DataTypeValidator<'cursor> { data_type: &DataTypeNode, serde_type: &SerdeType, is_inner: bool, - ) -> Option> { - match self.metadata.kind { + ) -> Option> { + match R::KIND { RowKind::Primitive => { panic!( "While processing row as a primitive: attempting to deserialize \ @@ -109,12 +112,12 @@ impl<'cursor> DataTypeValidator<'cursor> { } } -impl SchemaValidator for DataTypeValidator<'_> { +impl SchemaValidator for DataTypeValidator<'_, R> { const VALIDATION: bool = true; #[inline] - fn validate(&'_ mut self, serde_type: SerdeType) -> Option> { - match self.metadata.kind { + fn validate(&'_ mut self, serde_type: SerdeType) -> Option> { + match R::KIND { // `fetch::` for a "primitive row" type RowKind::Primitive => { if self.current_column_idx == 0 && self.metadata.columns.len() == 1 { @@ -166,7 +169,7 @@ impl SchemaValidator for DataTypeValidator<'_> { } else { panic!( "Struct {} has more fields than columns in the database schema", - self.metadata.name + R::NAME ) } } @@ -211,8 +214,8 @@ pub(crate) enum MapAsSequenceValidatorState { Value, } -pub(crate) struct InnerDataTypeValidator<'de, 'cursor> { - root: &'de DataTypeValidator<'cursor>, +pub(crate) struct InnerDataTypeValidator<'de, 'cursor, R: Row> { + root: &'de DataTypeValidator<'cursor, R>, kind: InnerDataTypeValidatorKind<'cursor>, } @@ -238,11 +241,14 @@ pub(crate) enum VariantValidationState { Identifier(u8), } -impl<'de, 'cursor> SchemaValidator for Option> { +impl<'de, 'cursor, R: Row> SchemaValidator for Option> { const VALIDATION: bool = true; #[inline] - fn validate(&mut self, serde_type: SerdeType) -> Option> { + fn validate( + &mut self, + serde_type: SerdeType, + ) -> Option> { match self { None => None, Some(inner) => match &mut inner.kind { @@ -396,7 +402,7 @@ impl<'de, 'cursor> SchemaValidator for Option { +impl Drop for InnerDataTypeValidator<'_, '_, R> { fn drop(&mut self) { if let InnerDataTypeValidatorKind::Tuple(elements_types) = self.kind { if !elements_types.is_empty() { @@ -421,12 +427,12 @@ impl Drop for InnerDataTypeValidator<'_, '_> { // static/const dispatch? // separate smaller inline functions? #[inline] -fn validate_impl<'de, 'cursor>( - root: &'de DataTypeValidator<'cursor>, +fn validate_impl<'de, 'cursor, R: Row>( + root: &'de DataTypeValidator<'cursor, R>, column_data_type: &'cursor DataTypeNode, serde_type: &SerdeType, is_inner: bool, -) -> Option> { +) -> Option> { let data_type = column_data_type.remove_low_cardinality(); match serde_type { SerdeType::Bool @@ -624,11 +630,11 @@ fn validate_impl<'de, 'cursor>( } } -impl SchemaValidator for () { +impl SchemaValidator for () { const VALIDATION: bool = false; #[inline(always)] - fn validate(&mut self, _serde_type: SerdeType) -> Option> { + fn validate(&mut self, _serde_type: SerdeType) -> Option> { None } diff --git a/src/test/handlers.rs b/src/test/handlers.rs index fd5edc11..e5c6d47a 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -7,7 +7,7 @@ use sealed::sealed; use serde::{Deserialize, Serialize}; use super::{Handler, HandlerFn}; -use crate::rowbinary; +use crate::{rowbinary, Row}; const BUFFER_INITIAL_CAPACITY: usize = 1024; @@ -82,7 +82,7 @@ pub struct RecordControl { impl RecordControl where - T: for<'a> Deserialize<'a>, + T: for<'a> Deserialize<'a> + Row, { pub async fn collect(self) -> C where From b7b45c5f99348b1359e33ebf7f041427cca4d580 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Tue, 17 Jun 2025 16:16:53 +0200 Subject: [PATCH 44/54] Address PR feedback --- CHANGELOG.md | 11 ++++ benches/common_select.rs | 8 +-- examples/mock.rs | 4 +- src/lib.rs | 11 ++-- src/query.rs | 4 -- tests/it/mock.rs | 2 +- tests/it/rbwnat.rs | 124 +++++++++++++++++++-------------------- types/Cargo.toml | 4 +- types/src/data_types.rs | 41 +++++++++++-- types/src/decoders.rs | 10 ++-- types/src/error.rs | 14 ++--- types/src/leb128.rs | 49 +++++++++++++--- types/src/lib.rs | 59 +++++++++++++++---- 13 files changed, 225 insertions(+), 116 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d196400..17bcc8b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate +### Breaking Changes + +- query: `RowBinaryWithNamesAndTypes` is now used by default for query results. This may cause panics if the row struct definition does not match the database schema. Use `Client::with_validation(false)` to revert to the previous behavior which uses plain `RowBinary` format for fetching rows. ([#221]) +- query: due to `RowBinaryWithNamesAndTypes` format usage, there might be an impact on fetch performance, which largely depends on how the dataset is defined. If you experience performance issues, consider disabling validation by using `Client::with_validation(false)`. + +### Added +- client: added `Client::with_validation` builder method. Validation is enabled by default, meaning that `RowBinaryWithNamesAndTypes` format will be used to fetch rows from the database. If validation is disabled, `RowBinary` format will be used, similarly to the previous versions. ([#221]). +- types: a new crate `clickhouse-types` was added to the project workspace. This crate is required for `RowBinaryWithNamesAndTypes` struct definition validation, as it contains ClickHouse data types AST, as well as functions and utilities to parse the types out of the ClickHouse server response. Note that this crate is not intended for public usage, as it might introduce internal breaking changes not following semver. ([#221]). + +[#221]: https://github.com/ClickHouse/clickhouse-rs/pull/221 + ## [0.13.3] - 2025-05-29 ### Added - client: added `Client::with_access_token` to support JWT authentication ClickHouse Cloud feature ([#215]). diff --git a/benches/common_select.rs b/benches/common_select.rs index a1c3d850..0fb00dd5 100644 --- a/benches/common_select.rs +++ b/benches/common_select.rs @@ -85,12 +85,10 @@ pub(crate) async fn fetch_cursor<'a, T: BenchmarkRow<'a>>( validation: bool, query: &str, ) -> RowCursor { - let mut client = Client::default() + let client = Client::default() .with_compression(compression) - .with_url("http://localhost:8123"); - if !validation { - client = client.with_disabled_validation(); - } + .with_url("http://localhost:8123") + .with_validation(validation); client.query(query).fetch::().unwrap() } diff --git a/examples/mock.rs b/examples/mock.rs index d4e7dadb..e63950e2 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -48,7 +48,9 @@ async fn main() { let mock = test::Mock::new(); let client = Client::default() .with_url(mock.url()) - .with_disabled_validation(); + // disabled schema validation is required for mocks to work; + // it is pointless for mocked tests anyway + .with_validation(false); let list = vec![SomeRow { no: 1 }, SomeRow { no: 2 }]; // How to test DDL. diff --git a/src/lib.rs b/src/lib.rs index b4d7ebbe..f3df63ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -329,14 +329,15 @@ impl Client { /// /// # Warning /// - /// While disabled validation will result in increased performance, + /// While disabled validation will result in increased performance + /// (between 1.1x and 3x, depending on the data), /// this mode is intended to be used for testing purposes only, /// and only in scenarios where schema mismatch issues are irrelevant. /// /// ***DO NOT*** disable validation in your production code or tests /// unless you are 100% sure why you are doing it. - pub fn with_disabled_validation(mut self) -> Self { - self.validation = false; + pub fn with_validation(mut self, enabled: bool) -> Self { + self.validation = enabled; self } @@ -481,9 +482,9 @@ mod client_tests { fn it_sets_validation_mode() { let client = Client::default(); assert!(client.validation); - let client = client.with_disabled_validation(); + let client = client.with_validation(false); assert!(!client.validation); - let client = client.with_disabled_validation(); + let client = client.with_validation(true); assert!(!client.validation); } } diff --git a/src/query.rs b/src/query.rs index 9c1ff04f..095ab7f0 100644 --- a/src/query.rs +++ b/src/query.rs @@ -94,10 +94,6 @@ impl Query { } let response = self.do_execute(true)?; - - // #[cfg(feature = "test_util")] - // if response.headers - Ok(RowCursor::new(response, validation)) } diff --git a/tests/it/mock.rs b/tests/it/mock.rs index a6c72410..5ea5a2e4 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -8,7 +8,7 @@ async fn test_provide() { let mock = test::Mock::new(); let client = Client::default() .with_url(mock.url()) - .with_disabled_validation(); + .with_validation(false); let expected = vec![SimpleRow::new(1, "one"), SimpleRow::new(2, "two")]; mock.add(test::handlers::provide(&expected)); diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs index 21c0081a..73b5fc1f 100644 --- a/tests/it/rbwnat.rs +++ b/tests/it/rbwnat.rs @@ -11,7 +11,7 @@ use std::collections::HashMap; use std::str::FromStr; #[tokio::test] -async fn test_header_parsing() { +async fn header_parsing() { let client = prepare_database!(); client .query( @@ -103,7 +103,7 @@ async fn test_header_parsing() { } #[tokio::test] -async fn test_fetch_primitive_row() { +async fn fetch_primitive_row() { let client = get_client(); let result = client .query("SELECT count() FROM (SELECT * FROM system.numbers LIMIT 3)") @@ -113,7 +113,7 @@ async fn test_fetch_primitive_row() { } #[tokio::test] -async fn test_fetch_primitive_row_schema_mismatch() { +async fn fetch_primitive_row_schema_mismatch() { type Data = i32; // expected type is UInt64 assert_panic_on_fetch!( &["primitive", "UInt64", "i32"], @@ -122,7 +122,7 @@ async fn test_fetch_primitive_row_schema_mismatch() { } #[tokio::test] -async fn test_fetch_vector_row() { +async fn fetch_vector_row() { let client = get_client(); let result = client .query("SELECT [1, 2, 3] :: Array(UInt32)") @@ -132,7 +132,7 @@ async fn test_fetch_vector_row() { } #[tokio::test] -async fn test_fetch_vector_row_schema_mismatch_nested_type() { +async fn fetch_vector_row_schema_mismatch_nested_type() { type Data = Vec; // expected type for Array(UInt32) is Vec assert_panic_on_fetch!( &["vector", "UInt32", "i128"], @@ -141,7 +141,7 @@ async fn test_fetch_vector_row_schema_mismatch_nested_type() { } #[tokio::test] -async fn test_fetch_tuple_row() { +async fn fetch_tuple_row() { let client = get_client(); let result = client .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b") @@ -151,7 +151,7 @@ async fn test_fetch_tuple_row() { } #[tokio::test] -async fn test_fetch_tuple_row_schema_mismatch_first_element() { +async fn fetch_tuple_row_schema_mismatch_first_element() { type Data = (i128, String); // expected u32 instead of i128 assert_panic_on_fetch!( &["tuple", "UInt32", "i128"], @@ -160,7 +160,7 @@ async fn test_fetch_tuple_row_schema_mismatch_first_element() { } #[tokio::test] -async fn test_fetch_tuple_row_schema_mismatch_second_element() { +async fn fetch_tuple_row_schema_mismatch_second_element() { type Data = (u32, i64); // expected String instead of i64 assert_panic_on_fetch!( &["tuple", "String", "i64"], @@ -169,7 +169,7 @@ async fn test_fetch_tuple_row_schema_mismatch_second_element() { } #[tokio::test] -async fn test_fetch_tuple_row_schema_mismatch_missing_element() { +async fn fetch_tuple_row_schema_mismatch_missing_element() { type Data = (u32, String); // expected to have the third element as i64 assert_panic_on_fetch!( &[ @@ -181,7 +181,7 @@ async fn test_fetch_tuple_row_schema_mismatch_missing_element() { } #[tokio::test] -async fn test_fetch_tuple_row_schema_mismatch_too_many_elements() { +async fn fetch_tuple_row_schema_mismatch_too_many_elements() { type Data = (u32, String, i128); // i128 should not be there assert_panic_on_fetch!( &[ @@ -193,7 +193,7 @@ async fn test_fetch_tuple_row_schema_mismatch_too_many_elements() { } #[tokio::test] -async fn test_fetch_tuple_row_with_struct() { +async fn fetch_tuple_row_with_struct() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u32, @@ -218,7 +218,7 @@ async fn test_fetch_tuple_row_with_struct() { } #[tokio::test] -async fn test_fetch_tuple_row_with_struct_schema_mismatch() { +async fn fetch_tuple_row_with_struct_schema_mismatch() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct _Data { a: u64, // expected type is u32 @@ -232,7 +232,7 @@ async fn test_fetch_tuple_row_with_struct_schema_mismatch() { } #[tokio::test] -async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_many_struct_fields() { +async fn fetch_tuple_row_with_struct_schema_mismatch_too_many_struct_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct _Data { a: u32, @@ -247,7 +247,7 @@ async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_many_struct_fields } #[tokio::test] -async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_many_fields() { +async fn fetch_tuple_row_with_struct_schema_mismatch_too_many_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct _Data { a: u32, @@ -261,7 +261,7 @@ async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_many_fields() { } #[tokio::test] -async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_few_struct_fields() { +async fn fetch_tuple_row_with_struct_schema_mismatch_too_few_struct_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct _Data { a: u32, // the second field is missing now @@ -274,7 +274,7 @@ async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_few_struct_fields( } #[tokio::test] -async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_few_fields() { +async fn fetch_tuple_row_with_struct_schema_mismatch_too_few_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct _Data { a: u32, @@ -288,7 +288,7 @@ async fn test_fetch_tuple_row_with_struct_schema_mismatch_too_few_fields() { } #[tokio::test] -async fn test_basic_types() { +async fn basic_types() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { uint8_val: u8, @@ -352,7 +352,7 @@ async fn test_basic_types() { // FIXME: somehow this test breaks `cargo test`, but works from RustRover #[ignore] #[tokio::test] -async fn test_borrowed_data() { +async fn borrowed_data() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data<'a> { str: &'a str, @@ -446,7 +446,7 @@ async fn test_borrowed_data() { } #[tokio::test] -async fn test_several_simple_rows() { +async fn several_simple_rows() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { num: u64, @@ -479,7 +479,7 @@ async fn test_several_simple_rows() { } #[tokio::test] -async fn test_many_numbers() { +async fn many_numbers() { #[derive(Row, Deserialize)] struct Data { number: u64, @@ -499,7 +499,7 @@ async fn test_many_numbers() { } #[tokio::test] -async fn test_blob_string_with_serde_bytes() { +async fn blob_string_with_serde_bytes() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "serde_bytes")] @@ -521,7 +521,7 @@ async fn test_blob_string_with_serde_bytes() { } #[tokio::test] -async fn test_arrays() { +async fn arrays() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { id: u16, @@ -563,7 +563,7 @@ async fn test_arrays() { } #[tokio::test] -async fn test_maps() { +async fn maps() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { m1: HashMap, @@ -613,7 +613,7 @@ async fn test_maps() { } #[tokio::test] -async fn test_map_as_vec_of_tuples() { +async fn map_as_vec_of_tuples() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { m1: Vec<(i128, String)>, @@ -656,7 +656,7 @@ async fn test_map_as_vec_of_tuples() { } #[tokio::test] -async fn test_map_as_vec_of_tuples_schema_mismatch() { +async fn map_as_vec_of_tuples_schema_mismatch() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { m: Vec<(u16, Vec<(String, i32)>)>, @@ -669,7 +669,7 @@ async fn test_map_as_vec_of_tuples_schema_mismatch() { } #[tokio::test] -async fn test_map_as_vec_of_tuples_schema_mismatch_nested() { +async fn map_as_vec_of_tuples_schema_mismatch_nested() { type Inner = Vec<(i32, i64)>; // the value should be i128 instead of i64 #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] @@ -690,7 +690,7 @@ async fn test_map_as_vec_of_tuples_schema_mismatch_nested() { } #[tokio::test] -async fn test_enum() { +async fn enums() { #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] #[repr(i8)] enum MyEnum8 { @@ -775,7 +775,7 @@ async fn test_enum() { } #[tokio::test] -async fn test_nullable() { +async fn nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u32, @@ -810,7 +810,7 @@ async fn test_nullable() { } #[tokio::test] -async fn test_invalid_nullable() { +async fn invalid_nullable() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { n: Option, @@ -822,7 +822,7 @@ async fn test_invalid_nullable() { } #[tokio::test] -async fn test_low_cardinality() { +async fn low_cardinality() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u32, @@ -858,7 +858,7 @@ async fn test_low_cardinality() { } #[tokio::test] -async fn test_invalid_low_cardinality() { +async fn invalid_low_cardinality() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u32, @@ -872,7 +872,7 @@ async fn test_invalid_low_cardinality() { } #[tokio::test] -async fn test_invalid_nullable_low_cardinality() { +async fn invalid_nullable_low_cardinality() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: Option, @@ -887,7 +887,7 @@ async fn test_invalid_nullable_low_cardinality() { #[tokio::test] #[cfg(feature = "time")] -async fn test_invalid_serde_with() { +async fn invalid_serde_with() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::datetime64::millis")] @@ -897,7 +897,7 @@ async fn test_invalid_serde_with() { } #[tokio::test] -async fn test_too_many_struct_fields() { +async fn too_many_struct_fields() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u32, @@ -911,7 +911,7 @@ async fn test_too_many_struct_fields() { } #[tokio::test] -async fn test_serde_skip_deserializing() { +async fn serde_skip_deserializing() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u32, @@ -938,7 +938,7 @@ async fn test_serde_skip_deserializing() { #[tokio::test] #[cfg(feature = "time")] -async fn test_date_and_time() { +async fn date_and_time() { use time::format_description::well_known::Iso8601; use time::Month::{February, January}; use time::OffsetDateTime; @@ -1001,7 +1001,7 @@ async fn test_date_and_time() { #[tokio::test] #[cfg(feature = "uuid")] -async fn test_uuid() { +async fn uuid() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { id: u16, @@ -1031,7 +1031,7 @@ async fn test_uuid() { } #[tokio::test] -async fn test_ipv4_ipv6() { +async fn ipv4_ipv6() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { id: u16, @@ -1064,7 +1064,7 @@ async fn test_ipv4_ipv6() { } #[tokio::test] -async fn test_fixed_str() { +async fn fixed_str() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: [u8; 4], @@ -1083,7 +1083,7 @@ async fn test_fixed_str() { } #[tokio::test] -async fn test_fixed_str_too_long() { +async fn fixed_str_too_long() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: [u8; 4], @@ -1096,7 +1096,7 @@ async fn test_fixed_str_too_long() { } #[tokio::test] -async fn test_tuple() { +async fn tuple() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String), @@ -1125,7 +1125,7 @@ async fn test_tuple() { } #[tokio::test] -async fn test_tuple_invalid_definition() { +async fn tuple_invalid_definition() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String), @@ -1147,7 +1147,7 @@ async fn test_tuple_invalid_definition() { } #[tokio::test] -async fn test_tuple_too_many_elements_in_the_schema() { +async fn tuple_too_many_elements_in_the_schema() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String), @@ -1169,7 +1169,7 @@ async fn test_tuple_too_many_elements_in_the_schema() { } #[tokio::test] -async fn test_tuple_too_many_elements_in_the_struct() { +async fn tuple_too_many_elements_in_the_struct() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: (u32, String, bool), @@ -1187,7 +1187,7 @@ async fn test_tuple_too_many_elements_in_the_struct() { } #[tokio::test] -async fn test_deeply_nested_validation_incorrect_fixed_string() { +async fn deeply_nested_validation_incorrect_fixed_string() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { id: u32, @@ -1205,7 +1205,7 @@ async fn test_deeply_nested_validation_incorrect_fixed_string() { } #[tokio::test] -async fn test_geo() { +async fn geo() { #[derive(Clone, Debug, PartialEq)] #[derive(Row, serde::Serialize, serde::Deserialize)] struct Data { @@ -1254,7 +1254,7 @@ async fn test_geo() { // not easy to assert, same applies to the other Geo types #[ignore] #[tokio::test] -async fn test_geo_invalid_point() { +async fn geo_invalid_point() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { id: u32, @@ -1272,7 +1272,7 @@ async fn test_geo_invalid_point() { #[tokio::test] /// See https://github.com/ClickHouse/clickhouse-rs/issues/100 -async fn test_issue_100() { +async fn issue_100() { { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { @@ -1311,7 +1311,7 @@ async fn test_issue_100() { #[ignore] #[tokio::test] /// See https://github.com/ClickHouse/clickhouse-rs/issues/109#issuecomment-2243197221 -async fn test_issue_109_1() { +async fn issue_109_1() { #[derive(Debug, Serialize, Deserialize, Row)] struct Data { #[serde(skip_deserializing)] @@ -1363,7 +1363,7 @@ async fn test_issue_109_1() { } #[tokio::test] -async fn test_issue_112() { +async fn issue_112() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: bool, @@ -1378,7 +1378,7 @@ async fn test_issue_112() { #[tokio::test] /// See https://github.com/ClickHouse/clickhouse-rs/issues/113 -async fn test_issue_113() { +async fn issue_113() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { a: u64, @@ -1423,7 +1423,7 @@ async fn test_issue_113() { #[tokio::test] #[cfg(feature = "time")] /// See https://github.com/ClickHouse/clickhouse-rs/issues/114 -async fn test_issue_114() { +async fn issue_114() { #[derive(Row, Deserialize, Debug, PartialEq)] struct Data { #[serde(with = "clickhouse::serde::time::date")] @@ -1458,7 +1458,7 @@ async fn test_issue_114() { #[tokio::test] #[cfg(feature = "time")] /// See https://github.com/ClickHouse/clickhouse-rs/issues/173 -async fn test_issue_173() { +async fn issue_173() { #[derive(Debug, Serialize, Deserialize, Row)] struct Data { log_id: String, @@ -1492,7 +1492,7 @@ async fn test_issue_173() { #[tokio::test] /// See https://github.com/ClickHouse/clickhouse-rs/issues/185 -async fn test_issue_185() { +async fn issue_185() { #[derive(Row, Deserialize, Debug, PartialEq)] struct Data { pk: u32, @@ -1524,7 +1524,7 @@ async fn test_issue_185() { #[tokio::test] #[cfg(feature = "chrono")] -async fn test_issue_218() { +async fn issue_218() { #[derive(Row, Serialize, Deserialize, Debug)] struct Data { max_time: chrono::DateTime, @@ -1553,7 +1553,7 @@ async fn test_issue_218() { } #[tokio::test] -async fn test_variant_wrong_definition() { +async fn variant_wrong_definition() { #[derive(Debug, Deserialize, PartialEq)] enum MyVariant { Str(String), @@ -1582,7 +1582,7 @@ async fn test_variant_wrong_definition() { } #[tokio::test] -async fn test_decimals() { +async fn decimals() { #[derive(Row, Deserialize, Debug, PartialEq)] struct Data { decimal32_9_4: Decimal32, @@ -1615,7 +1615,7 @@ async fn test_decimals() { } #[tokio::test] -async fn test_decimal32_wrong_size() { +async fn decimal32_wrong_size() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { decimal32: i16, @@ -1628,7 +1628,7 @@ async fn test_decimal32_wrong_size() { } #[tokio::test] -async fn test_decimal64_wrong_size() { +async fn decimal64_wrong_size() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { decimal64: i32, @@ -1641,7 +1641,7 @@ async fn test_decimal64_wrong_size() { } #[tokio::test] -async fn test_decimal128_wrong_size() { +async fn decimal128_wrong_size() { #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] struct Data { decimal128: i64, @@ -1654,7 +1654,7 @@ async fn test_decimal128_wrong_size() { } #[tokio::test] -async fn test_different_struct_field_order_same_types() { +async fn different_struct_field_order_same_types() { #[derive(Debug, Row, Deserialize, PartialEq)] struct Data { c: String, @@ -1677,7 +1677,7 @@ async fn test_different_struct_field_order_same_types() { } #[tokio::test] -async fn test_different_struct_field_order_different_types() { +async fn different_struct_field_order_different_types() { #[derive(Debug, Row, Deserialize, PartialEq)] struct Data { b: u32, diff --git a/types/Cargo.toml b/types/Cargo.toml index db89c637..a169b67c 100644 --- a/types/Cargo.toml +++ b/types/Cargo.toml @@ -2,7 +2,6 @@ name = "clickhouse-types" description = "Data types utils to use with Native and RowBinary(WithNamesAndTypes) formats in ClickHouse" version = "0.1.0" - authors.workspace = true repository.workspace = true homepage.workspace = true @@ -10,6 +9,9 @@ edition.workspace = true license.workspace = true rust-version.workspace = true +[lints.rust] +missing_docs = "warn" + [dependencies] thiserror = "1.0.16" bytes = "1.10.1" diff --git a/types/src/data_types.rs b/types/src/data_types.rs index c412807b..3182406a 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -2,13 +2,18 @@ use crate::error::TypesError; use std::collections::HashMap; use std::fmt::{Display, Formatter}; +/// A definition of a column in the result set, +/// taken out of the `RowBinaryWithNamesAndTypes` header. #[derive(Debug, Clone, PartialEq)] pub struct Column { + /// The name of the column. pub name: String, + /// The data type of the column. pub data_type: DataTypeNode, } impl Column { + #[allow(missing_docs)] pub fn new(name: String, data_type: DataTypeNode) -> Self { Self { name, data_type } } @@ -20,8 +25,11 @@ impl Display for Column { } } +/// Represents a data type in ClickHouse. +/// See https://clickhouse.com/docs/sql-reference/data-types #[derive(Debug, Clone, PartialEq)] #[non_exhaustive] +#[allow(missing_docs)] pub enum DataTypeNode { Bool, @@ -42,7 +50,9 @@ pub enum DataTypeNode { Float32, Float64, BFloat16, - Decimal(u8, u8, DecimalType), // Scale, Precision, 32 | 64 | 128 | 256 + + /// Scale, Precision, 32 | 64 | 128 | 256 + Decimal(u8, u8, DecimalType), String, FixedString(usize), @@ -50,8 +60,11 @@ pub enum DataTypeNode { Date, Date32, - DateTime(Option), // Optional timezone - DateTime64(DateTimePrecision, Option), // Precision and optional timezone + + /// Optional timezone + DateTime(Option), + /// Precision and optional timezone + DateTime64(DateTimePrecision, Option), IPv4, IPv6, @@ -63,12 +76,15 @@ pub enum DataTypeNode { Tuple(Vec), Enum(EnumType, HashMap), - // key-value pair is defined as an array, so we can also use it as a slice + /// Key-Value pairs are defined as an array, so it can be used as a slice Map([Box; 2]), + /// Function name and its arguments AggregateFunction(String, Vec), + /// Contains all possible types for this variant Variant(Vec), + Dynamic, JSON, @@ -81,6 +97,9 @@ pub enum DataTypeNode { } impl DataTypeNode { + /// Parses a data type from a string that is received + /// in the `RowBinaryWithNamesAndTypes` and `Native` formats headers. + /// See also: https://clickhouse.com/docs/interfaces/formats/RowBinaryWithNamesAndTypes#description pub fn new(name: &str) -> Result { match name { "UInt8" => Ok(Self::UInt8), @@ -136,6 +155,7 @@ impl DataTypeNode { } } + /// LowCardinality(T) -> T pub fn remove_low_cardinality(&self) -> &DataTypeNode { match self { DataTypeNode::LowCardinality(inner) => inner, @@ -229,9 +249,12 @@ impl Display for DataTypeNode { } } +/// Represents the underlying integer size of an Enum type. #[derive(Debug, Clone, PartialEq)] pub enum EnumType { + /// Stored as an `Int8` Enum8, + /// Stored as an `Int16` Enum16, } @@ -244,7 +267,11 @@ impl Display for EnumType { } } +/// DateTime64 precision. +/// Defined as an enum, as it is valid only in the range from 0 to 9. +/// See also: https://clickhouse.com/docs/sql-reference/data-types/datetime64 #[derive(Debug, Clone, PartialEq)] +#[allow(missing_docs)] pub enum DateTimePrecision { Precision0, Precision1, @@ -279,11 +306,17 @@ impl DateTimePrecision { } } +/// Represents the underlying integer type for a Decimal. +/// See also: https://clickhouse.com/docs/sql-reference/data-types/decimal #[derive(Debug, Clone, PartialEq)] pub enum DecimalType { + /// Stored as an `Int32` Decimal32, + /// Stored as an `Int64` Decimal64, + /// Stored as an `Int128` Decimal128, + /// Stored as an `Int256` Decimal256, } diff --git a/types/src/decoders.rs b/types/src/decoders.rs index 4e9c0865..be4355ea 100644 --- a/types/src/decoders.rs +++ b/types/src/decoders.rs @@ -3,20 +3,18 @@ use crate::leb128::read_leb128; use bytes::Buf; #[inline] -pub(crate) fn read_string(buffer: &mut &[u8]) -> Result { - ensure_size(buffer, 1)?; - let length = read_leb128(buffer)? as usize; +pub(crate) fn read_string(mut buffer: impl Buf) -> Result { + let length = read_leb128(&mut buffer)? as usize; if length == 0 { return Ok("".to_string()); } - ensure_size(buffer, length)?; + ensure_size(&mut buffer, length)?; let result = String::from_utf8_lossy(&buffer.copy_to_bytes(length)).to_string(); Ok(result) } #[inline] -pub(crate) fn ensure_size(buffer: &[u8], size: usize) -> Result<(), TypesError> { - // println!("[ensure_size] buffer remaining: {}, required size: {}", buffer.len(), size); +pub(crate) fn ensure_size(buffer: impl Buf, size: usize) -> Result<(), TypesError> { if buffer.remaining() < size { Err(TypesError::NotEnoughData(format!( "expected at least {} bytes, but only {} bytes remaining", diff --git a/types/src/error.rs b/types/src/error.rs index 83757b02..8418d10e 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -1,15 +1,11 @@ -// FIXME: better errors #[derive(Debug, thiserror::Error)] +#[non_exhaustive] +#[doc(hidden)] pub enum TypesError { - #[error("Not enough data: {0}")] + #[error("not enough data: {0}")] NotEnoughData(String), - - #[error("Header parsing error: {0}")] - HeaderParsingError(String), - - #[error("Type parsing error: {0}")] + #[error("type parsing error: {0}")] TypeParsingError(String), - - #[error("Unexpected empty list of columns")] + #[error("unexpected empty list of columns")] EmptyColumns, } diff --git a/types/src/leb128.rs b/types/src/leb128.rs index 1e650457..af7acc0a 100644 --- a/types/src/leb128.rs +++ b/types/src/leb128.rs @@ -1,9 +1,10 @@ use crate::error::TypesError; -use crate::error::TypesError::NotEnoughData; +use crate::error::TypesError::{NotEnoughData, TypeParsingError}; use bytes::{Buf, BufMut}; #[inline] -pub fn read_leb128(buffer: &mut &[u8]) -> Result { +#[doc(hidden)] +pub fn read_leb128(mut buffer: impl Buf) -> Result { let mut value = 0u64; let mut shift = 0; loop { @@ -19,13 +20,16 @@ pub fn read_leb128(buffer: &mut &[u8]) -> Result { } shift += 7; if shift > 57 { - return Err(NotEnoughData("decoding LEB128, invalid shift".to_string())); + return Err(TypeParsingError( + "decoding LEB128, unexpected shift value".to_string(), + )); } } Ok(value) } #[inline] +#[doc(hidden)] pub fn put_leb128(mut buffer: impl BufMut, mut value: u64) { while { let mut byte = value as u8 & 0x7f; @@ -41,9 +45,12 @@ pub fn put_leb128(mut buffer: impl BufMut, mut value: u64) { } {} } +#[cfg(test)] mod tests { + use super::*; + #[test] - fn test_read_leb128() { + fn read() { let test_cases = vec![ // (input bytes, expected value) (vec![0], 0), @@ -56,13 +63,39 @@ mod tests { ]; for (input, expected) in test_cases { - let result = super::read_leb128(&mut input.as_slice()).unwrap(); + let result = read_leb128(&mut input.as_slice()).unwrap(); assert_eq!(result, expected, "Failed decoding {:?}", input); } } #[test] - fn test_put_and_read_leb128() { + fn read_errors() { + let test_cases = vec![ + // (input bytes, expected error message) + (vec![], "decoding LEB128, 0 bytes remaining"), + ( + vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], + "decoding LEB128, unexpected shift value", + ), + ]; + + for (input, expected_error) in test_cases { + let result = read_leb128(&mut input.as_slice()); + assert!(result.is_err(), "Expected error for input {:?}", input); + if let Err(e) = result { + assert!( + e.to_string().contains(expected_error), + "Error message mismatch for `{:?}`; error was: `{}`, should contain: `{}`", + input, + e, + expected_error + ); + } + } + } + + #[test] + fn put_and_read() { let test_cases: Vec<(u64, Vec)> = vec![ // (value, expected encoding) (0u64, vec![0x00]), @@ -80,7 +113,7 @@ mod tests { for (value, expected_encoding) in test_cases { // Test encoding let mut encoded = Vec::new(); - super::put_leb128(&mut encoded, value); + put_leb128(&mut encoded, value); assert_eq!( encoded, expected_encoding, "Incorrect encoding for {}", @@ -88,7 +121,7 @@ mod tests { ); // Test round-trip - let decoded = super::read_leb128(&mut encoded.as_slice()).unwrap(); + let decoded = read_leb128(&mut encoded.as_slice()).unwrap(); assert_eq!( decoded, value, "Failed round trip for {}: encoded as {:?}, decoded as {}", diff --git a/types/src/lib.rs b/types/src/lib.rs index 22a49b9c..9271783b 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -1,31 +1,47 @@ +//! # clickhouse-types +//! +//! This crate is required for `RowBinaryWithNamesAndTypes` struct definition validation, +//! as it contains ClickHouse data types AST, as well as functions and utilities +//! to parse the types out of the ClickHouse server response. +//! +//! Note that this crate is not intended for public usage, +//! as it might introduce internal breaking changes not following semver. + pub use crate::data_types::{Column, DataTypeNode}; -use crate::decoders::{ensure_size, read_string}; +use crate::decoders::read_string; use crate::error::TypesError; +use bytes::{Buf, BufMut}; + +/// Exported for internal usage only. +/// Do not use it directly in your code. pub use crate::leb128::put_leb128; pub use crate::leb128::read_leb128; -use bytes::BufMut; +/// ClickHouse data types AST and utilities to parse it from strings. pub mod data_types; +/// Required decoders to parse the columns definitions from the header of the response. pub mod decoders; +/// Error types for this crate. pub mod error; +/// Utils for working with LEB128 encoding and decoding. pub mod leb128; -pub fn parse_rbwnat_columns_header(buffer: &mut &[u8]) -> Result, TypesError> { - ensure_size(buffer, 1)?; - let num_columns = read_leb128(buffer)?; +/// Parses the columns definitions from the response in `RowBinaryWithNamesAndTypes` format. +/// This is a mandatory step for this format, as it enables client-side data types validation. +#[doc(hidden)] +pub fn parse_rbwnat_columns_header(mut buffer: impl Buf) -> Result, TypesError> { + let num_columns = read_leb128(&mut buffer)?; if num_columns == 0 { - return Err(TypesError::HeaderParsingError( - "Expected at least one column in the header".to_string(), - )); + return Err(TypesError::EmptyColumns); } let mut columns_names: Vec = Vec::with_capacity(num_columns as usize); for _ in 0..num_columns { - let column_name = read_string(buffer)?; + let column_name = read_string(&mut buffer)?; columns_names.push(column_name); } let mut column_data_types: Vec = Vec::with_capacity(num_columns as usize); for _ in 0..num_columns { - let column_type = read_string(buffer)?; + let column_type = read_string(&mut buffer)?; let data_type = DataTypeNode::new(&column_type)?; column_data_types.push(data_type); } @@ -37,6 +53,10 @@ pub fn parse_rbwnat_columns_header(buffer: &mut &[u8]) -> Result, Ty Ok(columns) } +/// Having a table definition as a slice of [`Column`], +/// encodes it into the `RowBinary` format, and puts it into the provided buffer. +/// This is required to insert the data in `RowBinaryWithNamesAndTypes` format. +#[doc(hidden)] pub fn put_rbwnat_columns_header( columns: &[Column], mut buffer: impl BufMut, @@ -55,3 +75,22 @@ pub fn put_rbwnat_columns_header( } Ok(()) } + +#[cfg(test)] +mod test { + use super::*; + use crate::data_types::DataTypeNode; + use bytes::BytesMut; + + #[test] + fn test_rbwnat_header_round_trip() { + let mut buffer = BytesMut::new(); + let columns = vec![ + Column::new("id".to_string(), DataTypeNode::Int32), + Column::new("name".to_string(), DataTypeNode::String), + ]; + put_rbwnat_columns_header(&columns, &mut buffer).unwrap(); + let parsed_columns = parse_rbwnat_columns_header(&mut buffer).unwrap(); + assert_eq!(parsed_columns, columns); + } +} From bcc1e461d936d7e37ec1e057babd06ae1a1d85c3 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 19 Jun 2025 15:15:27 +0200 Subject: [PATCH 45/54] Resolve merge conflicts --- benches/common_select.rs | 3 +-- benches/mocked_select.rs | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/benches/common_select.rs b/benches/common_select.rs index 0fb00dd5..2b1ebaf1 100644 --- a/benches/common_select.rs +++ b/benches/common_select.rs @@ -2,7 +2,6 @@ use clickhouse::query::RowCursor; use clickhouse::{Client, Compression, Row}; -use criterion::black_box; use serde::Deserialize; use std::time::{Duration, Instant}; @@ -103,7 +102,7 @@ pub(crate) async fn do_select_bench<'a, T: BenchmarkRow<'a>>( let mut sum = 0; while let Some(row) = cursor.next().await.unwrap() { sum += row.id(); - black_box(&row); + std::hint::black_box(&row); } BenchmarkStats::new(&cursor, &start, sum) diff --git a/benches/mocked_select.rs b/benches/mocked_select.rs index c72db61f..e9fcef75 100644 --- a/benches/mocked_select.rs +++ b/benches/mocked_select.rs @@ -4,7 +4,7 @@ use clickhouse::{ Client, Compression, Row, }; use clickhouse_types::{Column, DataTypeNode}; -use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::stream::{self, StreamExt as _}; use http_body_util::StreamBody; use hyper::{ @@ -108,7 +108,7 @@ fn select(c: &mut Criterion) { sum.d = sum.d.wrapping_add(row.d); } - black_box(sum); + std::hint::black_box(sum); let elapsed = start.elapsed(); Ok(elapsed) @@ -128,7 +128,7 @@ fn select(c: &mut Criterion) { let mut size = 0; while size < min_size { - let buf = black_box(cursor.next().await?); + let buf = std::hint::black_box(cursor.next().await?); size += buf.unwrap().len() as u64; } From a879945ce3f32f248f4fa13c1044ac4bc34ab129 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 19 Jun 2025 15:16:32 +0200 Subject: [PATCH 46/54] fix cargo fmt --- benches/mocked_insert.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benches/mocked_insert.rs b/benches/mocked_insert.rs index 1e2d1339..cdbef62c 100644 --- a/benches/mocked_insert.rs +++ b/benches/mocked_insert.rs @@ -3,13 +3,13 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use http_body_util::Empty; use hyper::{body::Incoming, Request, Response}; use serde::Serialize; +use std::hint::black_box; use std::net::SocketAddr; use std::{ future::Future, mem, time::{Duration, Instant}, }; -use std::hint::black_box; use clickhouse::{error::Result, Client, Compression, Row}; From b094dd08b593c15377921147952bfa3275e37ada Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Thu, 19 Jun 2025 19:50:14 +0200 Subject: [PATCH 47/54] Fix docs, tests --- src/lib.rs | 2 +- types/src/data_types.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f3df63ec..286d0fca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -485,6 +485,6 @@ mod client_tests { let client = client.with_validation(false); assert!(!client.validation); let client = client.with_validation(true); - assert!(!client.validation); + assert!(client.validation); } } diff --git a/types/src/data_types.rs b/types/src/data_types.rs index 3182406a..d61e1979 100644 --- a/types/src/data_types.rs +++ b/types/src/data_types.rs @@ -26,7 +26,7 @@ impl Display for Column { } /// Represents a data type in ClickHouse. -/// See https://clickhouse.com/docs/sql-reference/data-types +/// See #[derive(Debug, Clone, PartialEq)] #[non_exhaustive] #[allow(missing_docs)] @@ -99,7 +99,7 @@ pub enum DataTypeNode { impl DataTypeNode { /// Parses a data type from a string that is received /// in the `RowBinaryWithNamesAndTypes` and `Native` formats headers. - /// See also: https://clickhouse.com/docs/interfaces/formats/RowBinaryWithNamesAndTypes#description + /// See also: pub fn new(name: &str) -> Result { match name { "UInt8" => Ok(Self::UInt8), @@ -269,7 +269,7 @@ impl Display for EnumType { /// DateTime64 precision. /// Defined as an enum, as it is valid only in the range from 0 to 9. -/// See also: https://clickhouse.com/docs/sql-reference/data-types/datetime64 +/// See also: #[derive(Debug, Clone, PartialEq)] #[allow(missing_docs)] pub enum DateTimePrecision { @@ -307,7 +307,7 @@ impl DateTimePrecision { } /// Represents the underlying integer type for a Decimal. -/// See also: https://clickhouse.com/docs/sql-reference/data-types/decimal +/// See also: #[derive(Debug, Clone, PartialEq)] pub enum DecimalType { /// Stored as an `Int32` From c449ee2cd1be2c2f7ff4ebc7732f7c6bb1336e6e Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Fri, 20 Jun 2025 18:15:33 +0200 Subject: [PATCH 48/54] Update CHANGELOG.md, README.md --- CHANGELOG.md | 2 +- README.md | 158 +++++++++++++++++++++++++++++++++++---------------- 2 files changed, 110 insertions(+), 50 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 17bcc8b5..565583d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - client: added `Client::with_validation` builder method. Validation is enabled by default, meaning that `RowBinaryWithNamesAndTypes` format will be used to fetch rows from the database. If validation is disabled, `RowBinary` format will be used, similarly to the previous versions. ([#221]). -- types: a new crate `clickhouse-types` was added to the project workspace. This crate is required for `RowBinaryWithNamesAndTypes` struct definition validation, as it contains ClickHouse data types AST, as well as functions and utilities to parse the types out of the ClickHouse server response. Note that this crate is not intended for public usage, as it might introduce internal breaking changes not following semver. ([#221]). +- types: a new crate `clickhouse-types` was added to the project workspace. This crate is required for `RowBinaryWithNamesAndTypes` struct definition validation, as it contains ClickHouse data types AST, as well as functions and utilities to parse the types out of the ClickHouse server response. ([#221]). [#221]: https://github.com/ClickHouse/clickhouse-rs/pull/221 diff --git a/README.md b/README.md index 65a2ab9b..7caeec71 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,8 @@ Official pure Rust typed client for ClickHouse DB. * Uses `serde` for encoding/decoding rows. * Supports `serde` attributes: `skip_serializing`, `skip_deserializing`, `rename`. -* Uses `RowBinary` encoding over HTTP transport. +* Uses `RowBinaryWithNamesAndTypes` format over HTTP transport with struct definition validation against the database schema. + * It is possible to use `RowBinary` instead, disabling the validation, which can potentially lead to increased performance ([see below](#validation)). * There are plans to switch to `Native` over TCP. * Supports TLS (see `native-tls` and `rustls-tls` features below). * Supports compression and decompression (LZ4 and LZ4HC). @@ -30,9 +31,24 @@ Official pure Rust typed client for ClickHouse DB. Note: [ch2rs](https://github.com/ClickHouse/ch2rs) is useful to generate a row type from ClickHouse. +## Validation + +Starting from 0.14.0, the crate supports validation of the row types against the ClickHouse schema, as +`RowBinaryWithNamesAndTypes` format is used by default instead of `RowBinary`. Additionally, with enabled validation, +the crate supports structs with correct field names and matching types, but incorrect order of the fields, +with a slight (5-10%) performance penalty. + +If you want to disable validation entirely, essentially reverting the client behavior to pre-0.14.0, you can use +`Client::with_validation(false)`, which will switch the fetch format to `RowBinary` instead. + +Depending on the dataset, disabling validation can yield from x1.1 to x3 performance improvement, +but it is not recommended to use it in production, as it can lead to unclear runtime errors +if the row types do not match the ClickHouse schema. + ## Usage To use the crate, add this to your `Cargo.toml`: + ```toml [dependencies] clickhouse = "0.13.3" @@ -44,16 +60,6 @@ clickhouse = { version = "0.13.3", features = ["test-util"] }
-### Note about ClickHouse prior to v22.6 - - - -CH server older than v22.6 (2022-06-16) handles `RowBinary` [incorrectly](https://github.com/ClickHouse/ClickHouse/issues/37420) in some rare cases. Use 0.11 and enable `wa-37420` feature to solve this problem. Don't use it for newer versions. - -
-
- - ### Create a client @@ -102,7 +108,10 @@ while let Some(row) = cursor.next().await? { .. } * Convenient `fetch_one::()` and `fetch_all::()` can be used to get a first row or all rows correspondingly. * `sql::Identifier` can be used to bind table names. -Note that cursors can return an error even after producing some rows. To avoid this, use `client.with_option("wait_end_of_query", "1")` in order to enable buffering on the server-side. [More details](https://clickhouse.com/docs/en/interfaces/http/#response-buffering). The `buffer_size` option can be useful too. +Note that cursors can return an error even after producing some rows. To avoid this, use +`client.with_option("wait_end_of_query", "1")` in order to enable buffering on the +server-side. [More details](https://clickhouse.com/docs/en/interfaces/http/#response-buffering). The `buffer_size` +option can be useful too.
@@ -130,7 +139,8 @@ insert.end().await?; * If `end()` isn't called, the `INSERT` is aborted. * Rows are being sent progressively to spread network load. -* ClickHouse inserts batches atomically only if all rows fit in the same partition and their number is less [`max_insert_block_size`](https://clickhouse.com/docs/en/operations/settings/settings#max_insert_block_size). +* ClickHouse inserts batches atomically only if all rows fit in the same partition and their number is less [ + `max_insert_block_size`](https://clickhouse.com/docs/en/operations/settings/settings#max_insert_block_size).
@@ -160,14 +170,19 @@ if stats.rows > 0 { } ``` -Please, read [examples](https://github.com/ClickHouse/clickhouse-rs/tree/main/examples/inserter.rs) to understand how to use it properly in different real-world cases. +Please, read [examples](https://github.com/ClickHouse/clickhouse-rs/tree/main/examples/inserter.rs) to understand how to +use it properly in different real-world cases. * `Inserter` ends an active insert in `commit()` if thresholds (`max_bytes`, `max_rows`, `period`) are reached. -* The interval between ending active `INSERT`s can be biased by using `with_period_bias` to avoid load spikes by parallel inserters. -* `Inserter::time_left()` can be used to detect when the current period ends. Call `Inserter::commit()` again to check limits if your stream emits items rarely. -* Time thresholds implemented by using [quanta](https://docs.rs/quanta) crate to speed the inserter up. Not used if `test-util` is enabled (thus, time can be managed by `tokio::time::advance()` in custom tests). +* The interval between ending active `INSERT`s can be biased by using `with_period_bias` to avoid load spikes by + parallel inserters. +* `Inserter::time_left()` can be used to detect when the current period ends. Call `Inserter::commit()` again to check + limits if your stream emits items rarely. +* Time thresholds implemented by using [quanta](https://docs.rs/quanta) crate to speed the inserter up. Not used if + `test-util` is enabled (thus, time can be managed by `tokio::time::advance()` in custom tests). * All rows between `commit()` calls are inserted in the same `INSERT` statement. * Do not forget to flush if you want to terminate inserting: + ```rust,ignore inserter.end().await?; ``` @@ -208,9 +223,11 @@ println!("live view updated: version={:?}", cursor.next().await?); ``` * Use [carefully](https://github.com/ClickHouse/ClickHouse/issues/28309#issuecomment-908666042). -* This code uses or creates if not exists a temporary live view named `lv_{sha1(query)}` to reuse the same live view by parallel watchers. +* This code uses or creates if not exists a temporary live view named `lv_{sha1(query)}` to reuse the same live view by + parallel watchers. * You can specify a name instead of a query. -* This API uses `JSONEachRowWithProgress` under the hood because of [the issue](https://github.com/ClickHouse/ClickHouse/issues/22996). +* This API uses `JSONEachRowWithProgress` under the hood because + of [the issue](https://github.com/ClickHouse/ClickHouse/issues/22996). * Only struct rows can be used. Avoid `fetch::()` and other without specified names.
@@ -218,16 +235,21 @@ println!("live view updated: version={:?}", cursor.next().await?); See [examples](https://github.com/ClickHouse/clickhouse-rs/tree/main/examples). ## Feature Flags -* `lz4` (enabled by default) — enables `Compression::Lz4`. If enabled, `Compression::Lz4` is used by default for all queries except for `WATCH`. + +* `lz4` (enabled by default) — enables `Compression::Lz4`. If enabled, `Compression::Lz4` is used by default for all + queries except for `WATCH`. * `inserter` — enables `client.inserter()`. -* `test-util` — adds mocks. See [the example](https://github.com/ClickHouse/clickhouse-rs/tree/main/examples/mock.rs). Use it only in `dev-dependencies`. +* `test-util` — adds mocks. See [the example](https://github.com/ClickHouse/clickhouse-rs/tree/main/examples/mock.rs). + Use it only in `dev-dependencies`. * `watch` — enables `client.watch` functionality. See the corresponding section for details. * `uuid` — adds `serde::uuid` to work with [uuid](https://docs.rs/uuid) crate. * `time` — adds `serde::time` to work with [time](https://docs.rs/time) crate. * `chrono` — adds `serde::chrono` to work with [chrono](https://docs.rs/chrono) crate. ### TLS + By default, TLS is disabled and one or more following features must be enabled to use HTTPS urls: + * `native-tls` — uses [native-tls], utilizing dynamic linking (e.g. against OpenSSL). * `rustls-tls` — enables `rustls-tls-aws-lc` and `rustls-tls-webpki-roots` features. * `rustls-tls-aws-lc` — uses [rustls] with the `aws-lc` cryptography implementation. @@ -236,26 +258,37 @@ By default, TLS is disabled and one or more following features must be enabled t * `rustls-tls-native-roots` — uses [rustls] with certificates provided by the [rustls-native-certs] crate. If multiple features are enabled, the following priority is applied: + * `native-tls` > `rustls-tls-aws-lc` > `rustls-tls-ring` * `rustls-tls-native-roots` > `rustls-tls-webpki-roots` How to choose between all these features? Here are some considerations: + * A good starting point is `rustls-tls`, e.g. if you use ClickHouse Cloud. * To be more environment-agnostic, prefer `rustls-tls` over `native-tls`. * Enable `rustls-tls-native-roots` or `native-tls` if you want to use self-signed certificates. [native-tls]: https://docs.rs/native-tls + [rustls]: https://docs.rs/rustls + [webpki-roots]: https://docs.rs/webpki-roots + [rustls-native-certs]: https://docs.rs/rustls-native-certs ## Data Types + * `(U)Int(8|16|32|64|128)` maps to/from corresponding `(u|i)(8|16|32|64|128)` types or newtypes around them. -* `(U)Int256` aren't supported directly, but there is [a workaround for it](https://github.com/ClickHouse/clickhouse-rs/issues/48). +* `(U)Int256` aren't supported directly, but there + is [a workaround for it](https://github.com/ClickHouse/clickhouse-rs/issues/48). * `Float(32|64)` maps to/from corresponding `f(32|64)` or newtypes around them. -* `Decimal(32|64|128)` maps to/from corresponding `i(32|64|128)` or newtypes around them. It's more convenient to use [fixnum](https://github.com/loyd/fixnum) or another implementation of signed fixed-point numbers. +* `Decimal(32|64|128)` maps to/from corresponding `i(32|64|128)` or newtypes around them. It's more convenient to + use [fixnum](https://github.com/loyd/fixnum) or another implementation of signed fixed-point numbers. * `Boolean` maps to/from `bool` or newtypes around it. -* `String` maps to/from any string or bytes types, e.g. `&str`, `&[u8]`, `String`, `Vec` or [`SmartString`](https://docs.rs/smartstring/latest/smartstring/struct.SmartString.html). Newtypes are also supported. To store bytes, consider using [serde_bytes](https://docs.rs/serde_bytes/latest/serde_bytes/), because it's more efficient. +* `String` maps to/from any string or bytes types, e.g. `&str`, `&[u8]`, `String`, `Vec` or [ + `SmartString`](https://docs.rs/smartstring/latest/smartstring/struct.SmartString.html). Newtypes are also supported. + To store bytes, consider using [serde_bytes](https://docs.rs/serde_bytes/latest/serde_bytes/), because it's more + efficient.
Example @@ -274,7 +307,7 @@ How to choose between all these features? Here are some considerations: * `FixedString(N)` is supported as an array of bytes, e.g. `[u8; N]`.
Example - + ```rust,ignore #[derive(Row, Debug, Serialize, Deserialize)] struct MyRow { @@ -282,7 +315,8 @@ How to choose between all these features? Here are some considerations: } ```
-* `Enum(8|16)` are supported using [serde_repr](https://docs.rs/serde_repr/latest/serde_repr/). You could use `#[repr(i8)]` for `Enum8` and `#[repr(i16)]` for `Enum16`. +* `Enum(8|16)` are supported using [serde_repr](https://docs.rs/serde_repr/latest/serde_repr/). You could use + `#[repr(i8)]` for `Enum8` and `#[repr(i16)]` for `Enum16`.
Example @@ -304,7 +338,8 @@ How to choose between all these features? Here are some considerations: } ```
-* `UUID` maps to/from [`uuid::Uuid`](https://docs.rs/uuid/latest/uuid/struct.Uuid.html) by using `serde::uuid`. Requires the `uuid` feature. +* `UUID` maps to/from [`uuid::Uuid`](https://docs.rs/uuid/latest/uuid/struct.Uuid.html) by using `serde::uuid`. Requires + the `uuid` feature.
Example @@ -317,7 +352,8 @@ How to choose between all these features? Here are some considerations: ```
* `IPv6` maps to/from [`std::net::Ipv6Addr`](https://doc.rust-lang.org/stable/std/net/struct.Ipv6Addr.html). -* `IPv4` maps to/from [`std::net::Ipv4Addr`](https://doc.rust-lang.org/stable/std/net/struct.Ipv4Addr.html) by using `serde::ipv4`. +* `IPv4` maps to/from [`std::net::Ipv4Addr`](https://doc.rust-lang.org/stable/std/net/struct.Ipv4Addr.html) by using + `serde::ipv4`.
Example @@ -329,9 +365,12 @@ How to choose between all these features? Here are some considerations: } ```
-* `Date` maps to/from `u16` or a newtype around it and represents a number of days elapsed since `1970-01-01`. The following external types are supported: - * [`time::Date`](https://docs.rs/time/latest/time/struct.Date.html) is supported by using `serde::time::date`, requiring the `time` feature. - * [`chrono::NaiveDate`](https://docs.rs/chrono/latest/chrono/struct.NaiveDate.html) is supported by using `serde::chrono::date`, requiring the `chrono` feature. +* `Date` maps to/from `u16` or a newtype around it and represents a number of days elapsed since `1970-01-01`. The + following external types are supported: + * [`time::Date`](https://docs.rs/time/latest/time/struct.Date.html) is supported by using `serde::time::date`, + requiring the `time` feature. + * [`chrono::NaiveDate`](https://docs.rs/chrono/latest/chrono/struct.NaiveDate.html) is supported by using + `serde::chrono::date`, requiring the `chrono` feature.
Example @@ -348,9 +387,12 @@ How to choose between all these features? Here are some considerations: ```
-* `Date32` maps to/from `i32` or a newtype around it and represents a number of days elapsed since `1970-01-01`. The following external types are supported: - * [`time::Date`](https://docs.rs/time/latest/time/struct.Date.html) is supported by using `serde::time::date32`, requiring the `time` feature. - * [`chrono::NaiveDate`](https://docs.rs/chrono/latest/chrono/struct.NaiveDate.html) is supported by using `serde::chrono::date32`, requiring the `chrono` feature. +* `Date32` maps to/from `i32` or a newtype around it and represents a number of days elapsed since `1970-01-01`. The + following external types are supported: + * [`time::Date`](https://docs.rs/time/latest/time/struct.Date.html) is supported by using `serde::time::date32`, + requiring the `time` feature. + * [`chrono::NaiveDate`](https://docs.rs/chrono/latest/chrono/struct.NaiveDate.html) is supported by using + `serde::chrono::date32`, requiring the `chrono` feature.
Example @@ -368,9 +410,12 @@ How to choose between all these features? Here are some considerations: ```
-* `DateTime` maps to/from `u32` or a newtype around it and represents a number of seconds elapsed since UNIX epoch. The following external types are supported: - * [`time::OffsetDateTime`](https://docs.rs/time/latest/time/struct.OffsetDateTime.html) is supported by using `serde::time::datetime`, requiring the `time` feature. - * [`chrono::DateTime`](https://docs.rs/chrono/latest/chrono/struct.DateTime.html) is supported by using `serde::chrono::datetime`, requiring the `chrono` feature. +* `DateTime` maps to/from `u32` or a newtype around it and represents a number of seconds elapsed since UNIX epoch. The + following external types are supported: + * [`time::OffsetDateTime`](https://docs.rs/time/latest/time/struct.OffsetDateTime.html) is supported by using + `serde::time::datetime`, requiring the `time` feature. + * [`chrono::DateTime`](https://docs.rs/chrono/latest/chrono/struct.DateTime.html) is supported by using + `serde::chrono::datetime`, requiring the `chrono` feature.
Example @@ -386,9 +431,12 @@ How to choose between all these features? Here are some considerations: } ```
-* `DateTime64(_)` maps to/from `i64` or a newtype around it and represents a time elapsed since UNIX epoch. The following external types are supported: - * [`time::OffsetDateTime`](https://docs.rs/time/latest/time/struct.OffsetDateTime.html) is supported by using `serde::time::datetime64::*`, requiring the `time` feature. - * [`chrono::DateTime`](https://docs.rs/chrono/latest/chrono/struct.DateTime.html) is supported by using `serde::chrono::datetime64::*`, requiring the `chrono` feature. +* `DateTime64(_)` maps to/from `i64` or a newtype around it and represents a time elapsed since UNIX epoch. The + following external types are supported: + * [`time::OffsetDateTime`](https://docs.rs/time/latest/time/struct.OffsetDateTime.html) is supported by using + `serde::time::datetime64::*`, requiring the `time` feature. + * [`chrono::DateTime`](https://docs.rs/chrono/latest/chrono/struct.DateTime.html) is supported by using + `serde::chrono::datetime64::*`, requiring the `chrono` feature.
Example @@ -420,7 +468,7 @@ How to choose between all these features? Here are some considerations:
* `Tuple(A, B, ...)` maps to/from `(A, B, ...)` or a newtype around it. * `Array(_)` maps to/from any slice, e.g. `Vec<_>`, `&[_]`. Newtypes are also supported. -* `Map(K, V)` behaves like `Array((K, V))`. +* `Map(K, V)` can be deserialized as `HashMap` or `Vec<(K, V)>`. * `LowCardinality(_)` is supported seamlessly. * `Nullable(_)` maps to/from `Option<_>`. For `clickhouse::serde::*` helpers add `::option`.
@@ -449,7 +497,8 @@ How to choose between all these features? Here are some considerations: } ```
-* `Geo` types are supported. `Point` behaves like a tuple `(f64, f64)`, and the rest of the types are just slices of points. +* `Geo` types are supported. `Point` behaves like a tuple `(f64, f64)`, and the rest of the types are just slices of + points.
Example @@ -472,10 +521,13 @@ How to choose between all these features? Here are some considerations: } ```
-* `Variant` data type is supported as a Rust enum. As the inner Variant types are _always_ sorted alphabetically, Rust enum variants should be defined in the _exactly_ same order as it is in the data type; their names are irrelevant, only the order of the types matters. This following example has a column defined as `Variant(Array(UInt16), Bool, Date, String, UInt32)`: +* `Variant` data type is supported as a Rust enum. As the inner Variant types are _always_ sorted alphabetically, Rust + enum variants should be defined in the _exactly_ same order as it is in the data type; their names are irrelevant, + only the order of the types matters. This following example has a column defined as + `Variant(Array(UInt16), Bool, Date, String, UInt32)`:
Example - + ```rust,ignore #[derive(Serialize, Deserialize)] enum MyRowVariant { @@ -494,16 +546,24 @@ How to choose between all these features? Here are some considerations: } ```
-* [New `JSON` data type](https://clickhouse.com/docs/en/sql-reference/data-types/newjson) is currently supported as a string when using ClickHouse 24.10+. See [this example](examples/data_types_new_json.rs) for more details. +* [New `JSON` data type](https://clickhouse.com/docs/en/sql-reference/data-types/newjson) is currently supported as a + string when using ClickHouse 24.10+. See [this example](examples/data_types_new_json.rs) for more details. * `Dynamic` data type is not supported for now. -See also the additional examples: +### See also + +- Examples of deriving ClickHouse data types: -* [Simpler ClickHouse data types](examples/data_types_derive_simple.rs) -* [Container-like ClickHouse data types](examples/data_types_derive_containers.rs) -* [Variant data type](examples/data_types_variant.rs) + * [Simpler ClickHouse data types](examples/data_types_derive_simple.rs) + * [Container-like ClickHouse data types](examples/data_types_derive_containers.rs) + * [Variant data type](examples/data_types_variant.rs) + +- Integration tests that cover most of the data types: + + * [RowBinaryWithNamesAndTypes](tests/it/rbwnat.rs) ## Mocking + The crate provides utils for mocking CH server and testing DDL, `SELECT`, `INSERT` and `WATCH` queries. The functionality can be enabled with the `test-util` feature. Use it **only** in dev-dependencies. From e1706f4556243dc54e8fabaa4c20ee0656a2eef2 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Fri, 20 Jun 2025 18:40:23 +0200 Subject: [PATCH 49/54] Update client usage with mocks --- CHANGELOG.md | 21 +++++++++++++++++---- examples/mock.rs | 8 +++----- src/lib.rs | 30 ++++++++++++++++++++++++++++++ src/query.rs | 2 +- tests/it/mock.rs | 4 +--- 5 files changed, 52 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 565583d9..ed531a9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,12 +10,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Breaking Changes -- query: `RowBinaryWithNamesAndTypes` is now used by default for query results. This may cause panics if the row struct definition does not match the database schema. Use `Client::with_validation(false)` to revert to the previous behavior which uses plain `RowBinary` format for fetching rows. ([#221]) -- query: due to `RowBinaryWithNamesAndTypes` format usage, there might be an impact on fetch performance, which largely depends on how the dataset is defined. If you experience performance issues, consider disabling validation by using `Client::with_validation(false)`. +- query: `RowBinaryWithNamesAndTypes` is now used by default for query results. This may cause panics if the row struct + definition does not match the database schema. Use `Client::with_validation(false)` to revert to the previous behavior + which uses plain `RowBinary` format for fetching rows. ([#221]) +- query: due to `RowBinaryWithNamesAndTypes` format usage, there might be an impact on fetch performance, which largely + depends on how the dataset is defined. If you experience performance issues, consider disabling validation by using + `Client::with_validation(false)`. +- mock: when using `test-util` feature, it is now required to use `Client::with_mock(&mock)` to set up the mock server, + so it properly handles the response format and automatically disables parsing `RowBinaryWithNamesAndTypes` header + parsing and validation. Additionally, it is not required to call `with_url` explicitly. + See the [updated example](./examples/mock.rs). ### Added -- client: added `Client::with_validation` builder method. Validation is enabled by default, meaning that `RowBinaryWithNamesAndTypes` format will be used to fetch rows from the database. If validation is disabled, `RowBinary` format will be used, similarly to the previous versions. ([#221]). -- types: a new crate `clickhouse-types` was added to the project workspace. This crate is required for `RowBinaryWithNamesAndTypes` struct definition validation, as it contains ClickHouse data types AST, as well as functions and utilities to parse the types out of the ClickHouse server response. ([#221]). + +- client: added `Client::with_validation` builder method. Validation is enabled by default, meaning that + `RowBinaryWithNamesAndTypes` format will be used to fetch rows from the database. If validation is disabled, + `RowBinary` format will be used, similarly to the previous versions. ([#221]). +- types: a new crate `clickhouse-types` was added to the project workspace. This crate is required for + `RowBinaryWithNamesAndTypes` struct definition validation, as it contains ClickHouse data types AST, as well as + functions and utilities to parse the types out of the ClickHouse server response. ([#221]). [#221]: https://github.com/ClickHouse/clickhouse-rs/pull/221 diff --git a/examples/mock.rs b/examples/mock.rs index e63950e2..c37d1e94 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -46,11 +46,9 @@ async fn make_watch_only_events(client: &Client) -> Result { #[tokio::main] async fn main() { let mock = test::Mock::new(); - let client = Client::default() - .with_url(mock.url()) - // disabled schema validation is required for mocks to work; - // it is pointless for mocked tests anyway - .with_validation(false); + // Note that an explicit `with_url` call is not required, + // it will be set automatically to the mock server URL. + let client = Client::default().with_mock(&mock); let list = vec![SomeRow { no: 1 }, SomeRow { no: 2 }]; // How to test DDL. diff --git a/src/lib.rs b/src/lib.rs index 286d0fca..27c69174 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,6 +48,9 @@ pub struct Client { headers: HashMap, products_info: Vec, validation: bool, + + #[cfg(feature = "test-util")] + mocked: bool, } #[derive(Clone)] @@ -103,6 +106,8 @@ impl Client { headers: HashMap::new(), products_info: Vec::default(), validation: true, + #[cfg(feature = "test-util")] + mocked: false, } } @@ -341,11 +346,36 @@ impl Client { self } + /// Used internally to check if the validation mode is enabled, + /// as it takes into account the `test-util` feature flag. + #[inline] + pub(crate) fn get_validation(&self) -> bool { + #[cfg(feature = "test-util")] + if self.mocked { + return false; + } + self.validation + } + /// Used internally to modify the options map of an _already cloned_ /// [`Client`] instance. pub(crate) fn add_option(&mut self, name: impl Into, value: impl Into) { self.options.insert(name.into(), value.into()); } + + /// Use a mock server for testing purposes. + /// + /// # Note + /// + /// The client will always use `RowBinary` format instead of `RowBinaryWithNamesAndTypes`, + /// as otherwise it'd be required to provide RBWNAT header in the mocks, + /// which is pointless in that kind of tests. + #[cfg(feature = "test-util")] + pub fn with_mock(mut self, mock: &test::Mock) -> Self { + self.url = mock.url().to_string(); + self.mocked = true; + self + } } /// This is a private API exported only for internal purposes. diff --git a/src/query.rs b/src/query.rs index 095ab7f0..346836c6 100644 --- a/src/query.rs +++ b/src/query.rs @@ -86,7 +86,7 @@ impl Query { pub fn fetch(mut self) -> Result> { self.sql.bind_fields::(); - let validation = self.client.validation; + let validation = self.client.get_validation(); if validation { self.sql.set_output_format("RowBinaryWithNamesAndTypes"); } else { diff --git a/tests/it/mock.rs b/tests/it/mock.rs index 5ea5a2e4..3cc92481 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -6,9 +6,7 @@ use std::time::Duration; async fn test_provide() { let mock = test::Mock::new(); - let client = Client::default() - .with_url(mock.url()) - .with_validation(false); + let client = Client::default().with_mock(&mock); let expected = vec![SimpleRow::new(1, "one"), SimpleRow::new(2, "two")]; mock.add(test::handlers::provide(&expected)); From 3c08c771d8e4bc8e9db75f672aafa449d2b9dcf7 Mon Sep 17 00:00:00 2001 From: Paul Loyd Date: Sat, 21 Jun 2025 14:51:54 +0400 Subject: [PATCH 50/54] chore: stop using nightly-only features of rustfmt --- rustfmt.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/rustfmt.toml b/rustfmt.toml index 5f62e976..ef4162c2 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,6 +1,2 @@ edition = "2021" merge_derives = false -imports_granularity = "Crate" -normalize_comments = true -reorder_impl_items = true -wrap_comments = true From 244d5872a2339e6a2e2e1f9ea7b5a8ff9e2f084a Mon Sep 17 00:00:00 2001 From: Paul Loyd Date: Sat, 21 Jun 2025 17:57:39 +0400 Subject: [PATCH 51/54] refactor(rowbinary/de): dedup code --- src/rowbinary/de.rs | 187 +++++++++++------------------------- src/rowbinary/validation.rs | 36 ++++--- 2 files changed, 76 insertions(+), 147 deletions(-) diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index ea18a724..e6d84193 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -49,24 +49,35 @@ pub(crate) fn deserialize_rbwnat<'data, 'cursor, T: Deserialize<'data> + Row>( /// A deserializer for the `RowBinary(WithNamesAndTypes)` format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. -struct RowBinaryDeserializer<'cursor, 'data, R: Row, Validator = ()> +struct RowBinaryDeserializer<'cursor, 'data, R: Row, V = ()> where - Validator: SchemaValidator, + V: SchemaValidator, { - validator: Validator, input: &'cursor mut &'data [u8], + validator: V, _marker: PhantomData, } -impl<'cursor, 'data, R: Row, Validator> RowBinaryDeserializer<'cursor, 'data, R, Validator> +impl<'cursor, 'data, R: Row, V> RowBinaryDeserializer<'cursor, 'data, R, V> where - Validator: SchemaValidator, + V: SchemaValidator, { - fn new(input: &'cursor mut &'data [u8], validator: Validator) -> Self { + fn new(input: &'cursor mut &'data [u8], validator: V) -> Self { Self { input, validator, - _marker: PhantomData::, + _marker: PhantomData, + } + } + + fn inner( + &mut self, + serde_type: SerdeType, + ) -> RowBinaryDeserializer<'_, 'data, R, V::Inner<'_>> { + RowBinaryDeserializer { + input: self.input, + validator: self.validator.validate(serde_type), + _marker: PhantomData, } } @@ -92,9 +103,7 @@ macro_rules! impl_num { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => { #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { - if Validator::VALIDATION { - self.validator.validate($serde_type); - } + self.validator.validate($serde_type); ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; let value = self.input.$reader_method(); visitor.$visitor_method(value) @@ -106,16 +115,11 @@ macro_rules! impl_num_or_enum { ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => { #[inline(always)] fn $deser_method>(self, visitor: V) -> Result { - if Validator::VALIDATION { - let mut maybe_enum_validator = self.validator.validate($serde_type); - ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; - let value = self.input.$reader_method(); - maybe_enum_validator.validate_identifier::<$ty>(value); - visitor.$visitor_method(value) - } else { - ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; - visitor.$visitor_method(self.input.$reader_method()) - } + let mut maybe_enum_validator = self.validator.validate($serde_type); + ensure_size(&mut self.input, core::mem::size_of::<$ty>())?; + let value = self.input.$reader_method(); + maybe_enum_validator.validate_identifier::<$ty>(value); + visitor.$visitor_method(value) } }; } @@ -159,9 +163,7 @@ where #[inline(always)] fn deserialize_bool>(self, visitor: V) -> Result { - if Validator::VALIDATION { - self.validator.validate(SerdeType::Bool); - } + self.validator.validate(SerdeType::Bool); ensure_size(&mut self.input, 1)?; match self.input.get_u8() { 0 => visitor.visit_bool(false), @@ -172,9 +174,7 @@ where #[inline(always)] fn deserialize_str>(self, visitor: V) -> Result { - if Validator::VALIDATION { - self.validator.validate(SerdeType::Str); - } + self.validator.validate(SerdeType::Str); let size = self.read_size()?; let slice = self.read_slice(size)?; let str = str::from_utf8(slice).map_err(Error::from)?; @@ -183,9 +183,7 @@ where #[inline(always)] fn deserialize_string>(self, visitor: V) -> Result { - if Validator::VALIDATION { - self.validator.validate(SerdeType::String); - } + self.validator.validate(SerdeType::String); let size = self.read_size()?; let vec = self.read_vec(size)?; let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?; @@ -195,9 +193,7 @@ where #[inline(always)] fn deserialize_bytes>(self, visitor: V) -> Result { let size = self.read_size()?; - if Validator::VALIDATION { - self.validator.validate(SerdeType::Bytes(size)); - } + self.validator.validate(SerdeType::Bytes(size)); let slice = self.read_slice(size)?; visitor.visit_borrowed_bytes(slice) } @@ -205,9 +201,7 @@ where #[inline(always)] fn deserialize_byte_buf>(self, visitor: V) -> Result { let size = self.read_size()?; - if Validator::VALIDATION { - self.validator.validate(SerdeType::ByteBuf(size)); - } + self.validator.validate(SerdeType::ByteBuf(size)); visitor.visit_byte_buf(self.read_vec(size)?) } @@ -219,11 +213,9 @@ where ensure_size(&mut self.input, size_of::())?; let value = self.input.get_u8(); // TODO: is there a better way to validate that the deserialized value matches the schema? - if Validator::VALIDATION { - // TODO: theoretically, we can track if we are currently processing a struct field id, - // and don't call the validator in that case, cause it will never be a `Variant`. - self.validator.validate_identifier::(value); - } + // TODO: theoretically, we can track if we are currently processing a struct field id, + // and don't call the validator in that case, cause it will never be a `Variant`. + self.validator.validate_identifier::(value); visitor.visit_u8(value) } @@ -234,103 +226,44 @@ where _variants: &'static [&'static str], visitor: V, ) -> Result { - if Validator::VALIDATION { - visitor.visit_enum(RowBinaryEnumAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: self.validator.validate(SerdeType::Enum), - _marker: PhantomData::, - }, - }) - } else { - visitor.visit_enum(RowBinaryEnumAccess { deserializer: self }) - } + let deserializer = &mut self.inner(SerdeType::Enum); + visitor.visit_enum(RowBinaryEnumAccess { deserializer }) } #[inline(always)] fn deserialize_tuple>(self, len: usize, visitor: V) -> Result { - if Validator::VALIDATION { - visitor.visit_seq(RowBinarySeqAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: self.validator.validate(SerdeType::Tuple(len)), - _marker: PhantomData::, - }, - len, - }) - } else { - visitor.visit_seq(RowBinarySeqAccess { - deserializer: self, - len, - }) - } + let deserializer = &mut self.inner(SerdeType::Tuple(len)); + visitor.visit_seq(RowBinarySeqAccess { deserializer, len }) } #[inline(always)] fn deserialize_option>(self, visitor: V) -> Result { ensure_size(&mut self.input, 1)?; let is_null = self.input.get_u8(); - if Validator::VALIDATION { - let inner_validator = self.validator.validate(SerdeType::Option); - match is_null { - 0 => visitor.visit_some(&mut RowBinaryDeserializer { - input: self.input, - validator: inner_validator, - _marker: PhantomData::, - }), - 1 => visitor.visit_none(), - v => Err(Error::InvalidTagEncoding(v as usize)), - } - } else { - // a bit of copy-paste here, since Deserializer types are not exactly the same - match is_null { - 0 => visitor.visit_some(self), - 1 => visitor.visit_none(), - v => Err(Error::InvalidTagEncoding(v as usize)), - } + let deserializer = &mut self.inner(SerdeType::Option); + match is_null { + 0 => visitor.visit_some(deserializer), + 1 => visitor.visit_none(), + v => Err(Error::InvalidTagEncoding(v as usize)), } } #[inline(always)] fn deserialize_seq>(self, visitor: V) -> Result { let len = self.read_size()?; - if Validator::VALIDATION { - visitor.visit_seq(RowBinarySeqAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: self.validator.validate(SerdeType::Seq(len)), - _marker: PhantomData::, - }, - len, - }) - } else { - visitor.visit_seq(RowBinarySeqAccess { - deserializer: self, - len, - }) - } + let deserializer = &mut self.inner(SerdeType::Seq(len)); + visitor.visit_seq(RowBinarySeqAccess { deserializer, len }) } #[inline(always)] fn deserialize_map>(self, visitor: V) -> Result { let len = self.read_size()?; - if Validator::VALIDATION { - visitor.visit_map(RowBinaryMapAccess { - deserializer: &mut RowBinaryDeserializer { - input: self.input, - validator: self.validator.validate(SerdeType::Map(len)), - _marker: PhantomData::, - }, - entries_visited: 0, - len, - }) - } else { - visitor.visit_map(RowBinaryMapAccess { - deserializer: self, - entries_visited: 0, - len, - }) - } + let deserializer = &mut self.inner(SerdeType::Map(len)); + visitor.visit_map(RowBinaryMapAccess { + deserializer, + entries_visited: 0, + len, + }) } #[inline(always)] @@ -340,25 +273,17 @@ where fields: &'static [&'static str], visitor: V, ) -> Result { - if Validator::VALIDATION { - if !self.validator.is_field_order_wrong() { - visitor.visit_seq(RowBinarySeqAccess { - deserializer: self, - len: fields.len(), - }) - } else { - visitor.visit_map(RowBinaryStructAsMapAccess { - deserializer: self, - current_field_idx: 0, - fields, - }) - } - } else { - // We can't detect incorrect field order with just plain `RowBinary` format + if !self.validator.is_field_order_wrong() { visitor.visit_seq(RowBinarySeqAccess { deserializer: self, len: fields.len(), }) + } else { + visitor.visit_map(RowBinaryStructAsMapAccess { + deserializer: self, + current_field_idx: 0, + fields, + }) } } diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index c5b2a56a..be1b3ec3 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -9,13 +9,15 @@ use std::marker::PhantomData; /// Note that [`SchemaValidator`] is also implemented for `()`, /// which is used to skip validation if the user disabled it. pub(crate) trait SchemaValidator: Sized { - /// Ensures that the branching is completely optimized out based on the validation settings. - const VALIDATION: bool; + type Inner<'de>: SchemaValidator + where + Self: 'de; + /// The main entry point. The validation flow based on the [`crate::Row::KIND`]. /// For container types (nullable, array, map, tuple, variant, etc.), /// it will return an [`InnerDataTypeValidator`] instance (see [`InnerDataTypeValidatorKind`]), /// which has its own implementation of this method, allowing recursive validation. - fn validate(&'_ mut self, serde_type: SerdeType) -> Option>; + fn validate(&mut self, serde_type: SerdeType) -> Self::Inner<'_>; /// Validates that an identifier exists in the values map for enums, /// or stores the variant identifier for the next serde call. fn validate_identifier(&mut self, value: T); @@ -112,11 +114,14 @@ impl<'cursor, R: Row> DataTypeValidator<'cursor, R> { } } -impl SchemaValidator for DataTypeValidator<'_, R> { - const VALIDATION: bool = true; +impl<'cursor, R: Row> SchemaValidator for DataTypeValidator<'cursor, R> { + type Inner<'de> + = Option> + where + Self: 'de; #[inline] - fn validate(&'_ mut self, serde_type: SerdeType) -> Option> { + fn validate(&'_ mut self, serde_type: SerdeType) -> Self::Inner<'_> { match R::KIND { // `fetch::` for a "primitive row" type RowKind::Primitive => { @@ -241,14 +246,14 @@ pub(crate) enum VariantValidationState { Identifier(u8), } -impl<'de, 'cursor, R: Row> SchemaValidator for Option> { - const VALIDATION: bool = true; +impl<'cursor, R: Row> SchemaValidator for Option> { + type Inner<'de> + = Self + where + Self: 'de; #[inline] - fn validate( - &mut self, - serde_type: SerdeType, - ) -> Option> { + fn validate(&mut self, serde_type: SerdeType) -> Self { match self { None => None, Some(inner) => match &mut inner.kind { @@ -631,15 +636,14 @@ fn validate_impl<'de, 'cursor, R: Row>( } impl SchemaValidator for () { - const VALIDATION: bool = false; + type Inner<'de> = (); #[inline(always)] - fn validate(&mut self, _serde_type: SerdeType) -> Option> { - None - } + fn validate(&mut self, _serde_type: SerdeType) {} #[inline(always)] fn is_field_order_wrong(&self) -> bool { + // We can't detect incorrect field order with just plain `RowBinary` format false } From d5af0b8fb9f926de22d1fca61437e27cc16fb109 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 23 Jun 2025 16:16:30 +0200 Subject: [PATCH 52/54] Address PR feedback --- README.md | 31 ++++--- src/cursors/row.rs | 8 +- src/lib.rs | 24 ++++-- src/rowbinary/de.rs | 34 ++++---- src/rowbinary/mod.rs | 4 +- src/rowbinary/tests.rs | 6 +- src/rowbinary/validation.rs | 158 ++++++++++++++++++------------------ src/test/handlers.rs | 2 +- 8 files changed, 137 insertions(+), 130 deletions(-) diff --git a/README.md b/README.md index 7caeec71..632bab3e 100644 --- a/README.md +++ b/README.md @@ -18,9 +18,10 @@ Official pure Rust typed client for ClickHouse DB. * Uses `serde` for encoding/decoding rows. * Supports `serde` attributes: `skip_serializing`, `skip_deserializing`, `rename`. -* Uses `RowBinaryWithNamesAndTypes` format over HTTP transport with struct definition validation against the database schema. - * It is possible to use `RowBinary` instead, disabling the validation, which can potentially lead to increased performance ([see below](#validation)). - * There are plans to switch to `Native` over TCP. +* Uses `RowBinaryWithNamesAndTypes` or `RowBinary` formats over HTTP transport. + * By default, `RowBinaryWithNamesAndTypes` with database schema validation is used. + * It is possible to switch to `RowBinary`, which can potentially lead to increased performance ([see below](#validation)). + * There are plans to implement `Native` format over TCP. * Supports TLS (see `native-tls` and `rustls-tls` features below). * Supports compression and decompression (LZ4 and LZ4HC). * Provides API for selecting. @@ -33,17 +34,23 @@ Note: [ch2rs](https://github.com/ClickHouse/ch2rs) is useful to generate a row t ## Validation -Starting from 0.14.0, the crate supports validation of the row types against the ClickHouse schema, as -`RowBinaryWithNamesAndTypes` format is used by default instead of `RowBinary`. Additionally, with enabled validation, -the crate supports structs with correct field names and matching types, but incorrect order of the fields, -with a slight (5-10%) performance penalty. +Starting from 0.14.0, the crate uses `RowBinaryWithNamesAndTypes` format by default, which allows row types validation +against the ClickHouse schema. This enables clearer error messages in case of schema mismatch at the cost of +performance. Additionally, with enabled validation, the crate supports structs with correct field names and matching +types, but incorrect order of the fields, with an additional slight (5-10%) performance penalty. -If you want to disable validation entirely, essentially reverting the client behavior to pre-0.14.0, you can use -`Client::with_validation(false)`, which will switch the fetch format to `RowBinary` instead. +If you are looking to maximize performance, you could disable validation using `Client::with_validation(false)`. When +validation is disabled, the client switches to `RowBinary` format usage instead. -Depending on the dataset, disabling validation can yield from x1.1 to x3 performance improvement, -but it is not recommended to use it in production, as it can lead to unclear runtime errors -if the row types do not match the ClickHouse schema. +The downside with plain `RowBinary` is that instead of clearer error messages, a mismatch between `Row` and database +schema will result in a `NotEnoughData` error without specific details. + +However, depending on the dataset, there might be x1.1 to x3 performance improvement, but that highly depends on the +shape and volume of the dataset. + +It is always recommended to measure the performance impact of validation in your specific use case. Additionally, +writing smoke tests to ensure that the row types match the ClickHouse schema is highly recommended, if you plan to +disable validation in your application. ## Usage diff --git a/src/cursors/row.rs b/src/cursors/row.rs index e7ef537d..9538e17a 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -99,10 +99,14 @@ impl RowCursor { } } slice = super::workaround_51132(self.bytes.slice()); - rowbinary::deserialize_rbwnat::(&mut slice, self.row_metadata.as_ref()) + rowbinary::deserialize_row_with_validation::( + &mut slice, + // handled above + self.row_metadata.as_ref().unwrap(), + ) } else { slice = super::workaround_51132(self.bytes.slice()); - rowbinary::deserialize_row_binary::(&mut slice) + rowbinary::deserialize_row::(&mut slice) }; match result { Err(Error::NotEnoughData) => {} diff --git a/src/lib.rs b/src/lib.rs index 27c69174..f16062d4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -329,18 +329,24 @@ impl Client { watch::Watch::new(self, query) } - /// Disables [`Row`] types validation against the database schema. - /// Validation is enabled by default. + /// Enables or disables [`Row`] data types validation against the database schema + /// at the cost of performance. Validation is enabled by default, and in this mode, + /// the client will use `RowBinaryWithNamesAndTypes` format. /// - /// # Warning + /// If you are looking to maximize performance, you could disable validation using this method. + /// When validation is disabled, the client switches to `RowBinary` format usage instead. /// - /// While disabled validation will result in increased performance - /// (between 1.1x and 3x, depending on the data), - /// this mode is intended to be used for testing purposes only, - /// and only in scenarios where schema mismatch issues are irrelevant. + /// The downside with plain `RowBinary` is that instead of clearer error messages, + /// a mismatch between [`Row`] and database schema will result + /// in a [`error::Error::NotEnoughData`] error without specific details. /// - /// ***DO NOT*** disable validation in your production code or tests - /// unless you are 100% sure why you are doing it. + /// However, depending on the dataset, there might be x1.1 to x3 performance improvement, + /// but that highly depends on the shape and volume of the dataset. + /// + /// It is always recommended to measure the performance impact of validation + /// in your specific use case. Additionally, writing smoke tests to ensure that + /// the row types match the ClickHouse schema is highly recommended, + /// if you plan to disable validation in your application. pub fn with_validation(mut self, enabled: bool) -> Self { self.validation = enabled; self diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index e6d84193..a1e97854 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -20,28 +20,25 @@ use std::{convert::TryFrom, str}; /// It accepts _a reference to_ a byte slice because it somehow leads to a more /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even /// `(&[u8], &mut Option) -> Result`. -/// -/// Additionally, having a single function speeds up [`crate::cursors::RowCursor::next`] x2. -/// A hint about the [`Error::NotEnoughData`] gives another 20% performance boost. -/// -/// It expects a slice of [`Column`] objects parsed -/// from the beginning of `RowBinaryWithNamesAndTypes` data stream. -/// After the header, the rows format is the same as `RowBinary`. -pub(crate) fn deserialize_row_binary<'data, 'cursor, T: Deserialize<'data> + Row>( +pub(crate) fn deserialize_row<'data, 'cursor, T: Deserialize<'data> + Row>( input: &mut &'data [u8], ) -> Result { let mut deserializer = RowBinaryDeserializer::::new(input, ()); T::deserialize(&mut deserializer) } -/// Similar to [`deserialize_row_binary`], but uses [`RowMetadata`] +/// Similar to [`deserialize_row`], but uses [`RowMetadata`] /// parsed from `RowBinaryWithNamesAndTypes` header to validate the data types. /// This is used when [`crate::Row`] validation is enabled in the client (default). -pub(crate) fn deserialize_rbwnat<'data, 'cursor, T: Deserialize<'data> + Row>( +/// +/// It expects a slice of [`Column`] objects parsed from the beginning +/// of `RowBinaryWithNamesAndTypes` data stream. After the header, +/// the rows format is the same as `RowBinary`. +pub(crate) fn deserialize_row_with_validation<'data, 'cursor, T: Deserialize<'data> + Row>( input: &mut &'data [u8], - metadata: Option<&'cursor RowMetadata>, + metadata: &'cursor RowMetadata, ) -> Result { - let validator = DataTypeValidator::new(metadata.unwrap()); + let validator = DataTypeValidator::new(metadata); let mut deserializer = RowBinaryDeserializer::::new(input, validator); T::deserialize(&mut deserializer) } @@ -261,8 +258,7 @@ where let deserializer = &mut self.inner(SerdeType::Map(len)); visitor.visit_map(RowBinaryMapAccess { deserializer, - entries_visited: 0, - len, + remaining: len, }) } @@ -293,7 +289,6 @@ where _name: &str, visitor: V, ) -> Result { - // TODO - skip validation? visitor.visit_newtype_struct(self) } @@ -373,8 +368,7 @@ where Validator: SchemaValidator, { deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>, - entries_visited: usize, - len: usize, + remaining: usize, } impl<'data, R: Row, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, R, Validator> @@ -387,10 +381,10 @@ where where K: DeserializeSeed<'data>, { - if self.entries_visited >= self.len { + if self.remaining == 0 { return Ok(None); } - self.entries_visited += 1; + self.remaining -= 1; seed.deserialize(&mut *self.deserializer).map(Some) } @@ -402,7 +396,7 @@ where } fn size_hint(&self) -> Option { - Some(self.len) + Some(self.remaining) } } diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs index b25147e2..9c96bca9 100644 --- a/src/rowbinary/mod.rs +++ b/src/rowbinary/mod.rs @@ -1,5 +1,5 @@ -pub(crate) use de::deserialize_rbwnat; -pub(crate) use de::deserialize_row_binary; +pub(crate) use de::deserialize_row; +pub(crate) use de::deserialize_row_with_validation; pub(crate) use ser::serialize_into; pub(crate) mod validation; diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs index 3e22cde1..ac097d6c 100644 --- a/src/rowbinary/tests.rs +++ b/src/rowbinary/tests.rs @@ -151,10 +151,10 @@ fn it_deserializes() { let (mut left, mut right) = input.split_at(i); // It shouldn't panic. - let _: Result, _> = super::deserialize_row_binary(&mut left); - let _: Result, _> = super::deserialize_row_binary(&mut right); + let _: Result, _> = super::deserialize_row(&mut left); + let _: Result, _> = super::deserialize_row(&mut right); - let actual: Sample<'_> = super::deserialize_row_binary(&mut input.as_slice()).unwrap(); + let actual: Sample<'_> = super::deserialize_row(&mut input.as_slice()).unwrap(); assert_eq!(actual, sample()); } } diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs index be1b3ec3..e8b8ad0c 100644 --- a/src/rowbinary/validation.rs +++ b/src/rowbinary/validation.rs @@ -12,7 +12,6 @@ pub(crate) trait SchemaValidator: Sized { type Inner<'de>: SchemaValidator where Self: 'de; - /// The main entry point. The validation flow based on the [`crate::Row::KIND`]. /// For container types (nullable, array, map, tuple, variant, etc.), /// it will return an [`InnerDataTypeValidator`] instance (see [`InnerDataTypeValidatorKind`]), @@ -254,109 +253,106 @@ impl<'cursor, R: Row> SchemaValidator for Option Self { - match self { - None => None, - Some(inner) => match &mut inner.kind { - InnerDataTypeValidatorKind::Map(kv, state) => match state { - MapValidatorState::Key => { + let inner = self.as_mut()?; + match &mut inner.kind { + InnerDataTypeValidatorKind::Map(kv, state) => match state { + MapValidatorState::Key => { + let result = validate_impl(inner.root, &kv[0], &serde_type, true); + *state = MapValidatorState::Value; + result + } + MapValidatorState::Value => { + let result = validate_impl(inner.root, &kv[1], &serde_type, true); + *state = MapValidatorState::Key; + result + } + }, + InnerDataTypeValidatorKind::MapAsSequence(kv, state) => { + match state { + // the first state is simply skipped, as the same validator + // will be called again for the Key and then the Value types + MapAsSequenceValidatorState::Tuple => { + *state = MapAsSequenceValidatorState::Key; + self.take() + } + MapAsSequenceValidatorState::Key => { let result = validate_impl(inner.root, &kv[0], &serde_type, true); - *state = MapValidatorState::Value; + *state = MapAsSequenceValidatorState::Value; result } - MapValidatorState::Value => { + MapAsSequenceValidatorState::Value => { let result = validate_impl(inner.root, &kv[1], &serde_type, true); - *state = MapValidatorState::Key; + *state = MapAsSequenceValidatorState::Tuple; result } - }, - InnerDataTypeValidatorKind::MapAsSequence(kv, state) => { - match state { - // the first state is simply skipped, as the same validator - // will be called again for the Key and then the Value types - MapAsSequenceValidatorState::Tuple => { - *state = MapAsSequenceValidatorState::Key; - self.take() - } - MapAsSequenceValidatorState::Key => { - let result = validate_impl(inner.root, &kv[0], &serde_type, true); - *state = MapAsSequenceValidatorState::Value; - result - } - MapAsSequenceValidatorState::Value => { - let result = validate_impl(inner.root, &kv[1], &serde_type, true); - *state = MapAsSequenceValidatorState::Tuple; - result - } - } - } - InnerDataTypeValidatorKind::Array(inner_type) => { - validate_impl(inner.root, inner_type, &serde_type, true) - } - InnerDataTypeValidatorKind::Nullable(inner_type) => { - validate_impl(inner.root, inner_type, &serde_type, true) } - InnerDataTypeValidatorKind::Tuple(elements_types) => { - match elements_types.split_first() { - Some((first, rest)) => { - *elements_types = rest; - validate_impl(inner.root, first, &serde_type, true) - } - None => { - let (full_name, full_data_type) = - inner.root.get_current_column_name_and_type(); - panic!( - "While processing column {} defined as {}: \ - attempting to deserialize {} while no more elements are allowed", - full_name, full_data_type, serde_type - ) - } + } + InnerDataTypeValidatorKind::Array(inner_type) => { + validate_impl(inner.root, inner_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Nullable(inner_type) => { + validate_impl(inner.root, inner_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Tuple(elements_types) => { + match elements_types.split_first() { + Some((first, rest)) => { + *elements_types = rest; + validate_impl(inner.root, first, &serde_type, true) } - } - InnerDataTypeValidatorKind::FixedString(_len) => { - None // actually unreachable - } - InnerDataTypeValidatorKind::RootTuple(columns, current_index) => { - if *current_index < columns.len() { - let data_type = &columns[*current_index].data_type; - *current_index += 1; - validate_impl(inner.root, data_type, &serde_type, true) - } else { + None => { let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); panic!( - "While processing root tuple element {} defined as {}: \ - attempting to deserialize {} while no more elements are allowed", + "While processing column {} defined as {}: \ + attempting to deserialize {} while no more elements are allowed", full_name, full_data_type, serde_type ) } } - InnerDataTypeValidatorKind::RootArray(inner_data_type) => { - validate_impl(inner.root, inner_data_type, &serde_type, true) + } + InnerDataTypeValidatorKind::FixedString(_len) => { + None // actually unreachable + } + InnerDataTypeValidatorKind::RootTuple(columns, current_index) => { + if *current_index < columns.len() { + let data_type = &columns[*current_index].data_type; + *current_index += 1; + validate_impl(inner.root, data_type, &serde_type, true) + } else { + let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); + panic!( + "While processing root tuple element {} defined as {}: \ + attempting to deserialize {} while no more elements are allowed", + full_name, full_data_type, serde_type + ) } - InnerDataTypeValidatorKind::Variant(possible_types, state) => match state { - VariantValidationState::Pending => { - unreachable!() - } - VariantValidationState::Identifier(value) => { - if *value as usize >= possible_types.len() { - let (full_name, full_data_type) = - inner.root.get_current_column_name_and_type(); - panic!( + } + InnerDataTypeValidatorKind::RootArray(inner_data_type) => { + validate_impl(inner.root, inner_data_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Variant(possible_types, state) => match state { + VariantValidationState::Pending => { + unreachable!() + } + VariantValidationState::Identifier(value) => { + if *value as usize >= possible_types.len() { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( "While processing column {full_name} defined as {full_data_type}: \ Variant identifier {value} is out of bounds, max allowed index is {}", possible_types.len() - 1 ); - } - let data_type = &possible_types[*value as usize]; - validate_impl(inner.root, data_type, &serde_type, true) } - }, - // TODO - check enum string value correctness in the hashmap? - // is this even possible? - InnerDataTypeValidatorKind::Enum(_values_map) => { - unreachable!() + let data_type = &possible_types[*value as usize]; + validate_impl(inner.root, data_type, &serde_type, true) } }, + // TODO - check enum string value correctness in the hashmap? + // is this even possible? + InnerDataTypeValidatorKind::Enum(_values_map) => { + unreachable!() + } } } diff --git a/src/test/handlers.rs b/src/test/handlers.rs index e5c6d47a..6a003f07 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -93,7 +93,7 @@ where let mut result = C::default(); while !slice.is_empty() { - let res = rowbinary::deserialize_row_binary(slice); + let res = rowbinary::deserialize_row(slice); let row: T = res.expect("failed to deserialize"); result.extend(std::iter::once(row)); } From 6d0e77179bd2e008a3c6695eba50490432c0708d Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 23 Jun 2025 16:33:51 +0200 Subject: [PATCH 53/54] Update CHANGELOG.md --- CHANGELOG.md | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0151444b..2bf5d300 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,21 +9,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate ### Removed + - **BREAKING** watch: `Client::watch()` API is removed ([#245]). - **BREAKING** mock: `watch()` and `watch_only_events()` are removed ([#245]). ### Changed -- **BREAKING** query: `RowBinaryWithNamesAndTypes` is now used by default for query results. This may cause panics if the row struct - definition does not match the database schema. Use `Client::with_validation(false)` to revert to the previous behavior - which uses plain `RowBinary` format for fetching rows. ([#221]) +- **BREAKING** query: `RowBinaryWithNamesAndTypes` is now used by default for query results. This may cause panics if + the row struct definition does not match the database schema. Use `Client::with_validation(false)` to revert to the + previous behavior which uses plain `RowBinary` format for fetching rows. ([#221]) +- **BREAKING** mock: when using `test-util` feature, it is now required to use `Client::with_mock(&mock)` to set up the + mock server, so it properly handles the response format and automatically disables parsing + `RowBinaryWithNamesAndTypes` header parsing and validation. Additionally, it is not required to call `with_url` + explicitly. See the [updated example](./examples/mock.rs). - query: due to `RowBinaryWithNamesAndTypes` format usage, there might be an impact on fetch performance, which largely - depends on how the dataset is defined. If you experience performance issues, consider disabling validation by using + depends on how the dataset is defined. If you notice decreased performance, consider disabling validation by using `Client::with_validation(false)`. -- **BREAKING** mock: when using `test-util` feature, it is now required to use `Client::with_mock(&mock)` to set up the mock server, - so it properly handles the response format and automatically disables parsing `RowBinaryWithNamesAndTypes` header - parsing and validation. Additionally, it is not required to call `with_url` explicitly. - See the [updated example](./examples/mock.rs). +- serde: it is now possible to deserialize Map ClickHouse type into `HashMap` (or `BTreeMap`, `IndexMap`, + `DashMap`, etc.). ### Added From 9f495e238daae89f3e04ea7efbde4467e87897a1 Mon Sep 17 00:00:00 2001 From: slvrtrn Date: Mon, 23 Jun 2025 16:37:37 +0200 Subject: [PATCH 54/54] Add missing env variables to docker compose --- docker-compose.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index 3ca667e6..cc309127 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,6 +3,8 @@ services: clickhouse: image: 'clickhouse/clickhouse-server:${CLICKHOUSE_VERSION-latest-alpine}' container_name: 'clickhouse-rs-clickhouse-server' + environment: + CLICKHOUSE_SKIP_USER_SETUP: 1 ports: - '8123:8123' - '9000:9000'