diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs index 4d62eaea95c..40e41bbd51c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/GroupingWithOutputExpressionStageDefinitions.cs @@ -63,7 +63,7 @@ private AstStage RenderProjectStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer outputSerializer) { - var partiallyEvaluatedOutput = (Expression>)PartialEvaluator.EvaluatePartially(_output); + var partiallyEvaluatedOutput = (Expression>)LinqExpressionPreprocessor.Preprocess(_output); var context = TranslationContext.Create(translationOptions); var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true); var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation); @@ -105,7 +105,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); + var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); @@ -149,7 +149,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer, TInput>> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); + var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); @@ -187,7 +187,7 @@ protected override AstStage RenderGroupingStage( ExpressionTranslationOptions translationOptions, out IBsonSerializer> groupingOutputSerializer) { - var partiallyEvaluatedGroupBy = (Expression>)PartialEvaluator.EvaluatePartially(_groupBy); + var partiallyEvaluatedGroupBy = (Expression>)LinqExpressionPreprocessor.Preprocess(_groupBy); var context = TranslationContext.Create(translationOptions); var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true); var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs new file mode 100644 index 00000000000..62983df83a7 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/ClrCompatExpressionRewriter.cs @@ -0,0 +1,150 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +/// +/// This visitor rewrites expressions where new features of .NET CLR or +/// C# compiler interfere with LINQ expression tree translation. +/// +internal class ClrCompatExpressionRewriter : ExpressionVisitor +{ + private static readonly ClrCompatExpressionRewriter __instance = new(); + + public static Expression Rewrite(Expression expression) + => __instance.Visit(expression); + + /// + protected override Expression VisitMethodCall(MethodCallExpression node) + { + node = (MethodCallExpression)base.VisitMethodCall(node); + + var method = node.Method; + var arguments = node.Arguments; + + return method.Name switch + { + "Contains" => VisitContainsMethod(node, method, arguments), + "SequenceEqual" => VisitSequenceEqualMethod(node, method, arguments), + _ => node + }; + + static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.IsOneOf(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue, MemoryExtensionsMethod.ContainsWithSpanAndValue)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var value = arguments[1]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.Contains.MakeGenericMethod(itemType), + [unwrappedSpan, value]); + } + } + else if (method.Is(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var value = arguments[1]; + var comparer = arguments[2]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.ContainsWithComparer.MakeGenericMethod(itemType), + [unwrappedSpan, value, comparer]); + } + } + + return node; + } + + static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpan)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var other = arguments[1]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + TryUnwrapSpanImplicitCast(other, out var unwrappedOther) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) && + unwrappedOther.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.SequenceEqual.MakeGenericMethod(itemType), + [unwrappedSpan, unwrappedOther]); + } + } + else if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpanAndComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var span = arguments[0]; + var other = arguments[1]; + var comparer = arguments[2]; + + if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) && + TryUnwrapSpanImplicitCast(other, out var unwrappedOther) && + unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) && + unwrappedOther.Type.ImplementsIEnumerableOf(itemType)) + { + return + Expression.Call( + EnumerableMethod.SequenceEqualWithComparer.MakeGenericMethod(itemType), + [unwrappedSpan, unwrappedOther, comparer]); + } + } + + return node; + } + + // Erase implicit casts to ReadOnlySpan and Span + static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result) + { + if (expression is MethodCallExpression + { + Method: + { + Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType + } + } methodCallExpression + && implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition + && (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>))) + { + result = methodCallExpression.Arguments[0]; + return true; + } + + result = null; + return false; + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs new file mode 100644 index 00000000000..d302dd05cb4 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/LinqExpressionPreprocessor.cs @@ -0,0 +1,33 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Linq.Expressions; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Misc; + +/// +/// This class is called before we process any LINQ expression trees +/// to perform any necessary pre-processing such as CLR compatibility +/// and partial evaluation. +/// +internal static class LinqExpressionPreprocessor +{ + public static Expression Preprocess(Expression expression) + { + expression = ClrCompatExpressionRewriter.Rewrite(expression); + expression = PartialEvaluator.EvaluatePartially(expression); + return expression; + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs index d7a61ab07d5..54ef243728b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/MethodInfoExtensions.cs @@ -13,12 +13,59 @@ * limitations under the License. */ +using System; using System.Reflection; namespace MongoDB.Driver.Linq.Linq3Implementation.Misc { internal static class MethodInfoExtensions { + public static bool Has1GenericArgument(this MethodInfo method, out Type genericArgument) + { + if (method.IsGenericMethod && + method.GetGenericArguments() is var genericArguments && + genericArguments.Length == 1) + { + genericArgument = genericArguments[0]; + return true; + } + + genericArgument = null; + return false; + } + + public static bool Has2Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2) + { + if (method.GetParameters() is var parameters && + parameters.Length == 2) + { + parameter1 = parameters[0]; + parameter2 = parameters[1]; + return true; + } + + parameter1 = null; + parameter2 = null; + return false; + } + + public static bool Has3Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2, out ParameterInfo parameter3) + { + if (method.GetParameters() is var parameters && + parameters.Length == 3) + { + parameter1 = parameters[0]; + parameter2 = parameters[1]; + parameter3 = parameters[2]; + return true; + } + + parameter1 = null; + parameter2 = null; + parameter3 = null; + return false; + } + public static bool Is(this MethodInfo method, MethodInfo comparand) { if (comparand != null) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs index ccb8f699740..f6d2f758940 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Misc/TypeExtensions.cs @@ -248,11 +248,27 @@ public static bool IsNullableOf(this Type type, Type valueType) return type.IsNullable(out var nullableValueType) && nullableValueType == valueType; } + public static bool IsReadOnlySpanOf(this Type type, Type itemType) + { + return + type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) && + type.GetGenericArguments()[0] == itemType; + } + public static bool IsSameAsOrNullableOf(this Type type, Type valueType) { return type == valueType || type.IsNullableOf(valueType); } + public static bool IsSpanOf(this Type type, Type itemType) + { + return + type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(Span<>) && + type.GetGenericArguments()[0] == itemType; + } + public static bool IsSubclassOfOrImplements(this Type type, Type baseTypeOrInterface) { return diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs index 0ae3e99ca4a..a10f2a67531 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/EnumerableMethod.cs @@ -59,6 +59,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __cast; private static readonly MethodInfo __concat; private static readonly MethodInfo __contains; + private static readonly MethodInfo __containsWithComparer; private static readonly MethodInfo __count; private static readonly MethodInfo __countWithPredicate; private static readonly MethodInfo __defaultIfEmpty; @@ -150,6 +151,7 @@ internal static class EnumerableMethod private static readonly MethodInfo __selectManyWithSelectorTakingIndex; private static readonly MethodInfo __selectWithSelectorTakingIndex; private static readonly MethodInfo __sequenceEqual; + private static readonly MethodInfo __sequenceEqualWithComparer; private static readonly MethodInfo __single; private static readonly MethodInfo __singleOrDefault; private static readonly MethodInfo __singleOrDefaultWithPredicate; @@ -226,6 +228,7 @@ static EnumerableMethod() __cast = ReflectionInfo.Method((IEnumerable source) => source.Cast()); __concat = ReflectionInfo.Method((IEnumerable first, IEnumerable second) => first.Concat(second)); __contains = ReflectionInfo.Method((IEnumerable source, object value) => source.Contains(value)); + __containsWithComparer = ReflectionInfo.Method((IEnumerable source, object value, IEqualityComparer comparer) => source.Contains(value, comparer)); __count = ReflectionInfo.Method((IEnumerable source) => source.Count()); __countWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.Count(predicate)); __defaultIfEmpty = ReflectionInfo.Method((IEnumerable source) => source.DefaultIfEmpty()); @@ -317,6 +320,7 @@ static EnumerableMethod() __selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func> selector) => source.SelectMany(selector)); __selectWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable source, Func selector) => source.Select(selector)); __sequenceEqual = ReflectionInfo.Method((IEnumerable first, IEnumerable second) => first.SequenceEqual(second)); + __sequenceEqualWithComparer = ReflectionInfo.Method((IEnumerable first, IEnumerable second, IEqualityComparer comparer) => first.SequenceEqual(second, comparer)); __single = ReflectionInfo.Method((IEnumerable source) => source.Single()); __singleOrDefault = ReflectionInfo.Method((IEnumerable source) => source.SingleOrDefault()); __singleOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable source, Func predicate) => source.SingleOrDefault(predicate)); @@ -392,6 +396,7 @@ static EnumerableMethod() public static MethodInfo Cast => __cast; public static MethodInfo Concat => __concat; public static MethodInfo Contains => __contains; + public static MethodInfo ContainsWithComparer => __containsWithComparer; public static MethodInfo Count => __count; public static MethodInfo CountWithPredicate => __countWithPredicate; public static MethodInfo DefaultIfEmpty => __defaultIfEmpty; @@ -483,6 +488,7 @@ static EnumerableMethod() public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex; public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex; public static MethodInfo SequenceEqual => __sequenceEqual; + public static MethodInfo SequenceEqualWithComparer => __sequenceEqualWithComparer; public static MethodInfo Single => __single; public static MethodInfo SingleOrDefault => __singleOrDefault; public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs new file mode 100644 index 00000000000..59aa2432c5e --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/MemoryExtensionsMethod.cs @@ -0,0 +1,150 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +internal static class MemoryExtensionsMethod +{ + // private static fields + private static readonly MethodInfo __containsWithReadOnlySpanAndValue; + private static readonly MethodInfo __containsWithReadOnlySpanAndValueAndComparer; + private static readonly MethodInfo __containsWithSpanAndValue; + private static readonly MethodInfo __sequenceEqualWithReadOnlySpanAndReadOnlySpan; + private static readonly MethodInfo __sequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer; + private static readonly MethodInfo __sequenceEqualWithSpanAndReadOnlySpan; + private static readonly MethodInfo __sequenceEqualWithSpanAndReadOnlySpanAndComparer; + + // static constructor + static MemoryExtensionsMethod() + { + __containsWithReadOnlySpanAndValue = GetContainsWithReadOnlySpanAndValueMethod(); + __containsWithReadOnlySpanAndValueAndComparer = GetContainsWithReadOnlySpanAndValueAndComparerMethod(); + __containsWithSpanAndValue = GetContainsWithSpanAndValueMethod(); + __sequenceEqualWithReadOnlySpanAndReadOnlySpan = GetSequenceEqualWithReadOnlySpanAndReadOnlySpan(); + __sequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer = GetSequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer(); + __sequenceEqualWithSpanAndReadOnlySpan = GetSequenceEqualWithSpanAndReadOnlySpan(); + __sequenceEqualWithSpanAndReadOnlySpanAndComparer = GetSequenceEqualWithSpanAndReadOnlySpanAndComparer(); + } + + // public static properties + public static MethodInfo ContainsWithReadOnlySpanAndValue => __containsWithReadOnlySpanAndValue; + public static MethodInfo ContainsWithReadOnlySpanAndValueAndComparer => __containsWithReadOnlySpanAndValueAndComparer; + public static MethodInfo ContainsWithSpanAndValue => __containsWithSpanAndValue; + public static MethodInfo SequenceEqualWithReadOnlySpanAndReadOnlySpan => __sequenceEqualWithReadOnlySpanAndReadOnlySpan; + public static MethodInfo SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer => __sequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer; + public static MethodInfo SequenceEqualWithSpanAndReadOnlySpan => __sequenceEqualWithSpanAndReadOnlySpan; + public static MethodInfo SequenceEqualWithSpanAndReadOnlySpanAndComparer => __sequenceEqualWithSpanAndReadOnlySpanAndComparer; + + // private static methods + private static MethodInfo GetContainsWithReadOnlySpanAndValueMethod() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "Contains" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var valueParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + valueParameter.ParameterType == itemType); + } + + private static MethodInfo GetContainsWithReadOnlySpanAndValueAndComparerMethod() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "Contains" && + m.Has1GenericArgument(out var itemType) && + m.Has3Parameters(out var spanParameter, out var valueParameter, out var comparerParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + valueParameter.ParameterType == itemType && + comparerParameter.ParameterType == typeof(IEqualityComparer<>).MakeGenericType(itemType)); + } + + private static MethodInfo GetContainsWithSpanAndValueMethod() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "Contains" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var valueParameter) && + spanParameter.ParameterType.IsSpanOf(itemType) && + valueParameter.ParameterType == itemType); + } + + private static MethodInfo GetSequenceEqualWithReadOnlySpanAndReadOnlySpan() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var otherParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType)); + } + + private static MethodInfo GetSequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has3Parameters(out var spanParameter, out var otherParameter, out var comparerParameter) && + spanParameter.ParameterType.IsReadOnlySpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType) && + comparerParameter.ParameterType == typeof(IEqualityComparer<>).MakeGenericType(itemType)); + } + + private static MethodInfo GetSequenceEqualWithSpanAndReadOnlySpan() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has2Parameters(out var spanParameter, out var otherParameter) && + spanParameter.ParameterType.IsSpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType)); + } + + private static MethodInfo GetSequenceEqualWithSpanAndReadOnlySpanAndComparer() + { + return + typeof(MemoryExtensions) + .GetMethods(BindingFlags.Public | BindingFlags.Static) + .FirstOrDefault(m => + m.Name == "SequenceEqual" && + m.Has1GenericArgument(out var itemType) && + m.Has3Parameters(out var spanParameter, out var otherParameter, out var comparerParameter) && + spanParameter.ParameterType.IsSpanOf(itemType) && + otherParameter.ParameterType.IsReadOnlySpanOf(itemType) && + comparerParameter.ParameterType == typeof(IEqualityComparer<>).MakeGenericType(itemType)); + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs index b96a193e323..a6a89b7639f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToExecutableQueryTranslators/ExpressionToExecutableQueryTranslator.cs @@ -29,7 +29,7 @@ public static ExecutableQuery> Translate TranslateScalar(provider); body = ExpressionReplacer.Replace(body, queryableParameter, Expression.Constant(queryable)); - body = PartialEvaluator.EvaluatePartially(body); + body = LinqExpressionPreprocessor.Preprocess(body); return ExpressionToPipelineTranslator.Translate(context, body); } @@ -338,7 +338,7 @@ private static TranslatedPipeline TranslateLookupPipelineAgainstQueryable>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions, contextData); var translation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, expression, sourceSerializer, asRoot: true); var simplifiedAst = AstSimplifier.Simplify(translation.Ast); @@ -74,7 +74,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (LambdaExpression)PartialEvaluator.EvaluatePartially(expression); + expression = (LambdaExpression)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); @@ -104,7 +104,7 @@ internal static RenderedFieldDefinition TranslateExpressionToField>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); var context = TranslationContext.Create(translationOptions); var symbol = context.CreateSymbol(parameter, documentSerializer, isCurrent: true); @@ -124,7 +124,7 @@ internal static BsonDocument TranslateExpressionToElemMatchFilter( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions); var parameter = expression.Parameters.Single(); var symbol = context.CreateSymbol(parameter, "@", elementSerializer); // @ represents the implied element @@ -141,7 +141,7 @@ internal static BsonDocument TranslateExpressionToFilter( IBsonSerializerRegistry serializerRegistry, ExpressionTranslationOptions translationOptions) { - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions); var filter = ExpressionToFilterTranslator.TranslateLambda(context, expression, documentSerializer, asRoot: true); filter = AstSimplifier.SimplifyAndConvert(filter); @@ -175,7 +175,7 @@ private static RenderedProjectionDefinition TranslateExpressionToProjec return new RenderedProjectionDefinition(null, (IBsonSerializer)inputSerializer); } - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var context = TranslationContext.Create(translationOptions); var simplifier = forFind ? new AstFindProjectionSimplifier() : new AstSimplifier(); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs new file mode 100644 index 00000000000..470f8828777 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5749Tests.cs @@ -0,0 +1,257 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.TestHelpers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira; + +#if NET6_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER +public class CSharp5749Tests : LinqIntegrationTest +{ + public CSharp5749Tests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void MemoryExtensions_Contains_in_Where_should_work() + { + var collection = Fixture.Collection; + var names = new[] { "Two", "Three" }; + + var queryable = collection.AsQueryable().Where(Rewrite((C x) => names.Contains(x.Name))); + + var results = queryable.ToArray(); + results.Select(x => x.Id).Should().Equal(2, 3); + } + + [Fact] + public void MemoryExtensions_Contains_in_Single_should_work() + { + var collection = Fixture.Collection; + var names = new[] { "Two" }; + + var result = collection.AsQueryable().Single(Rewrite((C x) => names.Contains(x.Name))); + + result.Id.Should().Be(2); + } + + [Fact] + public void MemoryExtensions_Contains_in_Any_should_work() + { + var collection = Fixture.Collection; + var ids = new[] { 2 }; + + var result = collection.AsQueryable().Any(Rewrite((C x) => ids.Contains(x.Id))); + + result.Should().BeTrue(); + } + + [Fact] + public void MemoryExtensions_Contains_in_Count_should_work() + { + var collection = Fixture.Collection; + var ids = new[] { 2 }; + + var result = collection.AsQueryable().Count(Rewrite((C x) => ids.Contains(x.Id))); + + result.Should().Be(1); + } + + [Fact] + public void MemoryExtensions_SequenceEqual_in_Where_should_work() + { + var collection = Fixture.Collection; + var ratings = new[] { 1, 9, 6 }; + + var queryable = collection.AsQueryable().Where(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); + + var results = queryable.ToArray(); + results.Select(x => x.Id).Should().Equal(3); + } + + [Fact] + public void MemoryExtensions_SequenceEqual_in_Single_should_work() + { + var collection = Fixture.Collection; + var ratings = new[] { 1, 9, 6 }; + + var result = collection.AsQueryable().Single(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); + + result.Id.Should().Be(3); + } + + [Fact] + public void MemoryExtensions_SequenceEqual_in_Any_should_work() + { + var collection = Fixture.Collection; + var ratings = new[] { 1, 2, 3, 4, 5 }; + + var result = collection.AsQueryable().Any(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); + + result.Should().BeTrue(); + } + + [Fact] + public void MemoryExtensions_SequenceEqual_in_Count_should_work() + { + var collection = Fixture.Collection; + var ratings = new[] { 3, 4, 5, 6, 7 }; + + var result = collection.AsQueryable().Count(Rewrite((C x) => ratings.SequenceEqual(x.Ratings))); + + result.Should().Be(1); + } + + public class C + { + public int Id { get; set; } + public string Name { get; set; } + public int[] Ratings { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + BsonDocument.Parse("{ _id : 1, Name : \"One\", Ratings : [1, 2, 3, 4, 5] }"), + BsonDocument.Parse("{ _id : 2, Name : \"Two\", Ratings : [3, 4, 5, 6, 7] }"), + BsonDocument.Parse("{ _id : 3, Name : \"Three\", Ratings : [1, 9, 6] }") + ]; + } + + private Expression> Rewrite(Expression> predicate) + { + return (Expression>)new EnumerableToMemoryExtensionsRewriter().Visit(predicate); + } + + public class EnumerableToMemoryExtensionsRewriter : ExpressionVisitor + { + protected override Expression VisitMethodCall(MethodCallExpression node) + { + node = (MethodCallExpression)base.VisitMethodCall(node); + + var method = node.Method; + var arguments = node.Arguments; + + return method.Name switch + { + "Contains" => VisitContainsMethod(node, method, arguments), + "SequenceEqual" => VisitSequenceEqualMethod(node, method, arguments), + _ => node + }; + + static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.Is(EnumerableMethod.Contains)) + { + var itemType = method.GetGenericArguments().Single(); + var source = arguments[0]; + var value = arguments[1]; + + if (source.Type.IsArray) + { + var readOnlySpan = ImplicitCastArrayToSpan(source, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue.MakeGenericMethod(itemType), + [readOnlySpan, value]); + } + } + else if (method.Is(EnumerableMethod.ContainsWithComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var source = arguments[0]; + var value = arguments[1]; + var comparer = arguments[2]; + + if (source.Type.IsArray) + { + var readOnlySpan = ImplicitCastArrayToSpan(source, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer.MakeGenericMethod(itemType), + [readOnlySpan, value, comparer]); + } + } + + + return node; + } + + static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection arguments) + { + if (method.Is(EnumerableMethod.SequenceEqual)) + { + var itemType = method.GetGenericArguments().Single(); + var first = arguments[0]; + var second = arguments[1]; + + if (first.Type.IsArray && second.Type.IsArray) + { + var firstReadOnlySpan = ImplicitCastArrayToSpan(first, typeof(ReadOnlySpan<>), itemType); + var secondReadOnlySpan = ImplicitCastArrayToSpan(second, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan.MakeGenericMethod(itemType), + [firstReadOnlySpan, secondReadOnlySpan]); + } + } + else if (method.Is(EnumerableMethod.SequenceEqualWithComparer)) + { + var itemType = method.GetGenericArguments().Single(); + var first = arguments[0]; + var second = arguments[1]; + var comparer = arguments[2]; + + if (first.Type.IsArray && second.Type.IsArray) + { + var firstReadOnlySpan = ImplicitCastArrayToSpan(first, typeof(ReadOnlySpan<>), itemType); + var secondReadOnlySpan = ImplicitCastArrayToSpan(second, typeof(ReadOnlySpan<>), itemType); + return + Expression.Call( + MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan.MakeGenericMethod(itemType), + [firstReadOnlySpan, secondReadOnlySpan, comparer]); + } + } + + return node; + } + + static Expression ImplicitCastArrayToSpan(Expression value, Type spanType, Type itemType) + { + var opImplicitMethod = spanType.MakeGenericType(itemType).GetMethod( + "op_Implicit", + BindingFlags.Public | BindingFlags.Static, + null, + [itemType.MakeArrayType()], + null); + return Expression.Call(opImplicitMethod, value); + } + } + } +} +#endif diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs index cd49af1955f..fa01543be13 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/LegacyPredicateTranslatorTests.cs @@ -1180,7 +1180,7 @@ public void TestWhereXNotEquals1Not() private void Assert(Expression> expression, int expectedCount, string expectedFilter) { - expression = (Expression>)PartialEvaluator.EvaluatePartially(expression); + expression = (Expression>)LinqExpressionPreprocessor.Preprocess(expression); var parameter = expression.Parameters.Single(); var serializer = BsonSerializer.LookupSerializer(); diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs index 0869d70822e..c96a2a96b9a 100644 --- a/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3ImplementationWithLinq2Tests/Translators/PredicateTranslatorTests.cs @@ -1150,7 +1150,7 @@ private void Assert(IMongoCollection collection, Expressio public List Assert(IMongoCollection collection, Expression> filter, int expectedCount, BsonDocument expectedFilter) { - filter = (Expression>)PartialEvaluator.EvaluatePartially(filter); + filter = (Expression>)LinqExpressionPreprocessor.Preprocess(filter); var serializer = BsonSerializer.SerializerRegistry.GetSerializer(); var parameter = filter.Parameters.Single();