@@ -155,7 +155,8 @@ See https://github.com/pola-rs/polars/issues/22149 for more information."
// @TAG: 2.0
// @HACK: `is_in` does supertype casting between primitive numerics, which
// honestly makes very little sense. To stay backwards compatible we keep this,
// but please in 2.0 remove this.
// but please in 2.0 remove this. FirstArgLossless might be a good alternative,
// as used by index_of(), or build on index_of().

let super_type =
polars_core::utils::try_get_supertype(&type_left, type_other_inner)?;
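The `@HACK` note above contrasts `is_in`'s supertype casting with the `FirstArgLossless` rule used by `index_of()`. A minimal Python sketch of the difference, not part of this diff, assuming `is_in` accepts a needle series of a wider integer type (which is what the backwards-compatibility note implies):

```python
import polars as pl

s = pl.Series([1, 2], dtype=pl.Int8)

# Supertype casting: both sides are cast to Int16, so a wider needle type is
# accepted and an out-of-range value simply compares unequal.
print(s.is_in(pl.Series([1, 1000], dtype=pl.Int16)))  # [true, false]

# FirstArgLossless (as used by index_of()): the Int16 needle cannot be cast
# losslessly to Int8, so the operation errors instead.
# s.index_of(pl.lit(1000, dtype=pl.Int16))  # raises InvalidOperationError
```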
152 changes: 133 additions & 19 deletions crates/polars-plan/src/plans/conversion/type_coercion/mod.rs
@@ -6,7 +6,10 @@ mod is_in;
use binary::process_binary;
use polars_core::chunked_array::cast::CastOptions;
use polars_core::prelude::*;
use polars_core::utils::{get_supertype, get_supertype_with_options, materialize_dyn_int};
use polars_core::utils::{
get_numeric_upcast_supertype_lossless, get_supertype, get_supertype_with_options,
materialize_dyn_int,
};
use polars_utils::format_list;
use polars_utils::itertools::Itertools;

@@ -399,24 +402,9 @@ impl OptimizationRule for TypeCoercionRule {
}
},
CastingRules::FirstArgLossless => {
if super_type.is_integer() {
for other in &input[1..] {
let other = other.dtype(schema, expr_arena)?;
if other.is_float() {
polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other)
}
}
}
if super_type.is_categorical() || super_type.is_enum() {
for other in &input[1..] {
let other = other.dtype(schema, expr_arena)?;
if !(other.is_string()
|| other.is_null()
|| *other == super_type)
{
polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other)
}
}
for other in &input[1..] {
let other = other.dtype(schema, expr_arena)?;
can_cast_to_lossless(&super_type, other)?;
}
},
}
@@ -1012,3 +1000,129 @@ fn inline_implode(expr: Node, expr_arena: &mut Arena<AExpr>) -> PolarsResult<Opt

Ok(out)
}

/// Can we cast the `from` dtype to the `to` dtype without losing information?
fn can_cast_to_lossless(to: &DataType, from: &DataType) -> PolarsResult<()> {
let can_cast = match (to, from) {
(a, b) if a == b => true,
(_, DataType::Null) => true,
// Here we know the exact value, so we can report it to the user if it
// doesn't fit:
(to, DataType::Unknown(UnknownKind::Int(value))) => match to {
#[cfg(feature = "dtype-decimal")]
DataType::Decimal(to_precision, to_scale)
if {
let max =
10i128.pow((to_precision.unwrap_or(38) - to_scale.unwrap_or(0)) as u32);
let min = -max;
*value < max && *value > min
} =>
{
true
},
to if to.is_integer() && to.value_within_range(AnyValue::Int128(*value)) => true,
// For floats, make sure the value is in the range where all integers
// convert losslessly; this isn't quite every value that can be converted
// losslessly, but it's good enough:
DataType::Float32 if (*value < 2i128.pow(24)) && (*value > -2i128.pow(24)) => true,
DataType::Float64 if (*value < 2i128.pow(53)) && (*value > -2i128.pow(53)) => true,
// Make sure we have an error message that reports the value:
_ => polars_bail!(InvalidOperation: "cannot cast {} losslessly to {}", value, to),
},
(
DataType::Float32,
DataType::UInt8 | DataType::UInt16 | DataType::Int8 | DataType::Int16,
) => true,
(
DataType::Float64,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::Int8
| DataType::Int16
| DataType::Int32,
) => true,
// When casting an unknown float to Float32 we can't tell whether the value
// will fit, so we conservatively reject it. When casting to Float64 we can
// assume it'll work, since in practice the value is no larger than an f64.
(DataType::Float64, DataType::Unknown(UnknownKind::Float)) => true,
// Handles both String and UnknownKind::Str:
(DataType::String, from) => from.is_string(),
(to, from) if to.is_primitive_numeric() && from.is_primitive_numeric() => {
if let Some(upcast) = get_numeric_upcast_supertype_lossless(to, from) {
&upcast == to
} else {
false
}
},
#[cfg(feature = "dtype-categorical")]
(DataType::Enum(_, _), from) => from.is_string(),
#[cfg(feature = "dtype-categorical")]
(DataType::Categorical(_, _), from) => from.is_string(),
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(p_to, s_to), DataType::Decimal(p_from, s_from)) => {
// Match numbers in DataType::try_to_arrow():
let (p_to, s_to, p_from, s_from) = (
p_to.unwrap_or(38),
s_to.unwrap_or(0),
p_from.unwrap_or(38),
s_from.unwrap_or(0),
);
// 1. The numbers in `from` should fit in `to`'s range.
// 2. The scale in `from` should fit in `to`'s scale.
((p_to - s_to) >= (p_from - s_from)) && (s_to >= s_from)
},
#[cfg(feature = "dtype-decimal")]
(DataType::Decimal(p_to, s_to), dt) if dt.is_primitive_numeric() => {
// Match numbers in DataType::try_to_arrow(); the integer part of the
// decimal holds at most `precision - scale` digits:
let max_value = 10i128.pow((p_to.unwrap_or(38) - s_to.unwrap_or(0)) as u32) - 1;
let min_value = -max_value;
max_value >= dt.max().unwrap().value().extract::<i128>().unwrap()
&& min_value <= dt.min().unwrap().value().extract::<i128>().unwrap()
},
// Can't search for a more granular time_unit in less-granular time_unit
// data, or we'd cast away valid/necessary precision (e.g. nanoseconds to
// milliseconds):
(DataType::Datetime(to_unit, _), DataType::Datetime(from_unit, _)) => to_unit <= from_unit,
(DataType::Duration(to_unit), DataType::Duration(from_unit)) => to_unit <= from_unit,
(DataType::List(to), DataType::List(from)) => return can_cast_to_lossless(to, from),
#[cfg(feature = "dtype-array")]
(DataType::List(to), DataType::Array(from, _)) => return can_cast_to_lossless(to, from),
// If the list doesn't fit the array size, it'll get handled when the cast
// actually happens.
#[cfg(feature = "dtype-array")]
(DataType::Array(to, _), DataType::List(from)) => return can_cast_to_lossless(to, from),
#[cfg(feature = "dtype-array")]
(DataType::Array(to, to_count), DataType::Array(from, from_count)) => {
if from_count != to_count {
false
} else {
return can_cast_to_lossless(to, from);
}
},
#[cfg(feature = "dtype-struct")]
(DataType::Struct(to_fields), DataType::Struct(from_fields)) => {
if to_fields.len() != from_fields.len() {
false
} else {
return to_fields.iter().zip(from_fields.iter()).try_for_each(
|(to_field, from_field)| {
polars_ensure!(
to_field.name == from_field.name,
InvalidOperation:
"cannot cast losslessly from {} to {}",
from,
to
);
can_cast_to_lossless(&to_field.dtype, &from_field.dtype)
},
);
}
},
_ => false,
};
if !can_cast {
polars_bail!(InvalidOperation: "cannot cast losslessly from {} to {}", from, to)
}
Ok(())
}
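For context on the `Float32`/`Float64` bounds used above: an f32 has a 24-bit significand and an f64 a 53-bit one, so every integer of magnitude up to 2^24 (respectively 2^53) converts exactly, while the next integer already rounds. A small numpy sketch of that boundary (illustration only, not part of the diff):

```python
import numpy as np

# f32: 24-bit significand, so integers are exact up to 2**24 ...
assert int(np.float32(2**24)) == 2**24
# ... but 2**24 + 1 rounds back to 2**24, i.e. information is lost.
assert int(np.float32(2**24 + 1)) == 2**24

# f64: the same thing happens at 2**53.
assert int(np.float64(2**53)) == 2**53
assert int(np.float64(2**53 + 1)) == 2**53
```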
147 changes: 137 additions & 10 deletions py-polars/tests/unit/operations/test_index_of.py
@@ -15,7 +15,7 @@
from polars.testing.parametric import series

if TYPE_CHECKING:
from polars._typing import IntoExpr
from polars._typing import IntoExpr, PolarsDataType
from polars.datatypes import IntegerType


@@ -61,29 +61,33 @@ def assert_index_of(
@pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64])
def test_float(dtype: pl.DataType) -> None:
values = [1.5, np.nan, np.inf, 3.0, None, -np.inf, 0.0, -0.0, -np.nan]
if dtype == pl.Float32:
# Can't pass Python literals to index_of() for Float32
values = [(None if v is None else np.float32(v)) for v in values] # type: ignore[misc]

series = pl.Series(values, dtype=dtype)
sorted_series_asc = series.sort(descending=False)
sorted_series_desc = series.sort(descending=True)
chunked_series = pl.concat([pl.Series([1, 7], dtype=dtype), series], rechunk=False)

extra_values = [
np.int8(3),
np.int64(2**42),
np.float64(1.5),
np.float32(1.5),
np.float32(2**37),
np.float64(2**100),
np.float32(2**10),
]
if dtype == pl.Float64:
extra_values.extend([np.int32(2**10), np.float64(2**10), np.float64(1.5)])
for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]:
for value in values:
assert_index_of(s, value, convert_to_literal=True)
assert_index_of(s, value, convert_to_literal=False)
for value in extra_values: # type: ignore[assignment]
assert_index_of(s, value)

# Explicitly check some extra-tricky edge cases:
assert series.index_of(-np.nan) == 1 # -np.nan should match np.nan
assert series.index_of(-0.0) == 6 # -0.0 should match 0.0
# -np.nan should match np.nan:
assert series.index_of(-np.float32("nan")) == 1 # type: ignore[arg-type]
# -0.0 should match 0.0:
assert series.index_of(-np.float32(0.0)) == 6 # type: ignore[arg-type]


def test_null() -> None:
@@ -148,10 +152,16 @@ def test_integer(dtype: IntegerType) -> None:

# Can't cast floats:
for f in [np.float32(3.1), np.float64(3.1), 50.9]:
with pytest.raises(InvalidOperationError, match="cannot cast lossless"):
with pytest.raises(InvalidOperationError, match="cannot cast.*"):
s.index_of(f) # type: ignore[arg-type]


def test_integer_upcast() -> None:
series = pl.Series([0, 123, 456, 789], dtype=pl.Int64)
for should_work in [pl.Int8, pl.UInt8, pl.Int16, pl.UInt16, pl.Int32, pl.UInt32]:
assert series.index_of(pl.lit(123, dtype=should_work)) == 1


def test_groupby() -> None:
df = pl.DataFrame(
{"label": ["a", "b", "a", "b", "a", "b"], "value": [10, 3, 20, 2, 40, 20]}
@@ -350,7 +360,7 @@ def test_categorical(convert_to_literal: bool) -> None:
@pytest.mark.parametrize("value", [0, 0.1])
def test_categorical_wrong_type_keys_dont_work(value: int | float) -> None:
series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical)
msg = "cannot cast lossless"
msg = "cannot cast.*losslessly.*"
with pytest.raises(InvalidOperationError, match=msg):
series.index_of(value)
df = pl.DataFrame({"s": series})
@@ -367,3 +377,120 @@ def test_index_of_null_parametric(s: pl.Series) -> None:
assert idx_null is None
elif s.null_count() == len(s):
assert idx_null == 0


def test_out_of_range_integers() -> None:
series = pl.Series([0, 100, None, 1, 2], dtype=pl.Int8)
with pytest.raises(InvalidOperationError, match="cannot cast 128 losslessly to i8"):
assert series.index_of(128)
with pytest.raises(
InvalidOperationError, match="cannot cast -200 losslessly to i8"
):
assert series.index_of(-200)


def test_out_of_range_decimal() -> None:
# Up to 34 digits of integers:
series = pl.Series([1, None], dtype=pl.Decimal(36, 2))
assert series.index_of(10**34 - 1) is None
assert series.index_of(-(10**34 - 1)) is None
out_of_range = 10**34
with pytest.raises(
InvalidOperationError, match=f"cannot cast {out_of_range} losslessly"
):
assert series.index_of(out_of_range)
with pytest.raises(
InvalidOperationError, match=f"cannot cast {-out_of_range} losslessly"
):
assert series.index_of(-out_of_range)


def test_out_of_range_float64() -> None:
series = pl.Series([0, 255, None], dtype=pl.Float64)
# Small numbers are fine:
assert series.index_of(1_000_000) is None
assert series.index_of(-1_000_000) is None
with pytest.raises(
InvalidOperationError, match=f"cannot cast {2**53} losslessly to f64"
):
assert series.index_of(2**53)
with pytest.raises(
InvalidOperationError, match=f"cannot cast {-(2**53)} losslessly to f64"
):
assert series.index_of(-(2**53))


def test_out_of_range_float32() -> None:
series = pl.Series([0, 255, None], dtype=pl.Float32)
# Small numbers are fine:
assert series.index_of(1_000_000) is None
assert series.index_of(-1_000_000) is None
with pytest.raises(
InvalidOperationError, match=f"cannot cast {2**24} losslessly to f32"
):
assert series.index_of(2**24)
with pytest.raises(
InvalidOperationError, match=f"cannot cast {-(2**24)} losslessly to f32"
):
assert series.index_of(-(2**24))


def assert_lossy_cast_rejected(
series_dtype: PolarsDataType, value: Any, value_dtype: PolarsDataType
) -> None:
# We create a Series containing a null because previously a lossy cast
# could turn the searched-for value into a null, so you'd get an answer
# instead of an error.
series = pl.Series([None], dtype=series_dtype)
with pytest.raises(InvalidOperationError, match="cannot cast losslessly"):
series.index_of(pl.lit(value, dtype=value_dtype))


@pytest.mark.parametrize(
("series_dtype", "value", "value_dtype"),
[
# Completely incompatible:
(pl.String, 1, pl.UInt8),
(pl.UInt8, "1", pl.String),
# Larger integer doesn't fit in smaller integer:
(pl.UInt8, 17, pl.UInt16),
# Can't find negative numbers in unsigned integers:
(pl.UInt16, -1, pl.Int8),
# Values after the decimal point that can't be represented:
(pl.Decimal(3, 1), 1, pl.Decimal(4, 2)),
# Can't fit in Decimal:
(pl.Decimal(3, 0), 1, pl.Decimal(4, 0)),
(pl.Decimal(5, 2), 1, pl.Decimal(5, 1)),
(pl.Decimal(5, 2), 1, pl.UInt16),
# Can't fit nanoseconds in milliseconds:
(pl.Duration("ms"), 1, pl.Duration("ns")),
# Arrays that are the wrong length:
(pl.Array(pl.Int64, 2), [1], pl.Array(pl.Int64, 1)),
# Struct with wrong number of fields:
(
pl.Struct({"a": pl.Int64, "b": pl.Int64}),
{"a": 1},
pl.Struct({"a": pl.Int64}),
),
# Struct with different field name:
(pl.Struct({"a": pl.Int64}), {"x": 1}, pl.Struct({"x": pl.Int64})),
],
)
def test_lossy_casts_are_rejected(
series_dtype: PolarsDataType, value: Any, value_dtype: PolarsDataType
) -> None:
assert_lossy_cast_rejected(series_dtype, value, value_dtype)


def test_lossy_casts_are_rejected_nested_dtypes() -> None:
# Make sure casting rules are applied recursively for Lists, Arrays,
# Struct:
series_dtype, value, value_dtype = pl.UInt8, 17, pl.UInt16
assert_lossy_cast_rejected(pl.List(series_dtype), [value], pl.List(value_dtype))
assert_lossy_cast_rejected(
pl.Array(series_dtype, 1), [value], pl.Array(value_dtype, 1)
)
assert_lossy_cast_rejected(
pl.Struct({"key": series_dtype}),
{"key": value},
pl.Struct({"key": value_dtype}),
)
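
Taken together, the tests above pin down the user-facing behaviour of `index_of()` after this change. A rough usage sketch based on those tests (the exact error wording follows the test expectations above and is not an authoritative spec):

```python
import polars as pl
from polars.exceptions import InvalidOperationError

s = pl.Series([1, 100, None, 2], dtype=pl.Int8)

# A plain Python int is checked against its exact value: 100 fits in Int8.
assert s.index_of(100) == 1

# 128 does not fit in Int8, so the cast is rejected instead of silently
# wrapping or turning the needle into a null.
try:
    s.index_of(128)
except InvalidOperationError as e:
    print(e)  # cannot cast 128 losslessly to i8

# A typed literal of a smaller integer type upcasts losslessly to the
# series dtype:
big = pl.Series([0, 123, 456, 789], dtype=pl.Int64)
assert big.index_of(pl.lit(123, dtype=pl.UInt8)) == 1
```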