Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ private AstStage RenderProjectStage(
ExpressionTranslationOptions translationOptions,
out IBsonSerializer<TOutput> outputSerializer)
{
var partiallyEvaluatedOutput = (Expression<Func<TGrouping, TOutput>>)PartialEvaluator.EvaluatePartially(_output);
var partiallyEvaluatedOutput = (Expression<Func<TGrouping, TOutput>>)LinqExpressionPreprocessor.Preprocess(_output);
var context = TranslationContext.Create(translationOptions);
var outputTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedOutput, inputSerializer, asRoot: true);
var (projectStage, projectSerializer) = ProjectionHelper.CreateProjectStage(outputTranslation);
Expand Down Expand Up @@ -105,7 +105,7 @@ protected override AstStage RenderGroupingStage(
ExpressionTranslationOptions translationOptions,
out IBsonSerializer<IGrouping<TValue, TInput>> groupingOutputSerializer)
{
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)LinqExpressionPreprocessor.Preprocess(_groupBy);
var context = TranslationContext.Create(translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);

Expand Down Expand Up @@ -149,7 +149,7 @@ protected override AstStage RenderGroupingStage(
ExpressionTranslationOptions translationOptions,
out IBsonSerializer<IGrouping<AggregateBucketAutoResultId<TValue>, TInput>> groupingOutputSerializer)
{
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)LinqExpressionPreprocessor.Preprocess(_groupBy);
var context = TranslationContext.Create(translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);

Expand Down Expand Up @@ -187,7 +187,7 @@ protected override AstStage RenderGroupingStage(
ExpressionTranslationOptions translationOptions,
out IBsonSerializer<IGrouping<TValue, TInput>> groupingOutputSerializer)
{
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)PartialEvaluator.EvaluatePartially(_groupBy);
var partiallyEvaluatedGroupBy = (Expression<Func<TInput, TValue>>)LinqExpressionPreprocessor.Preprocess(_groupBy);
var context = TranslationContext.Create(translationOptions);
var groupByTranslation = ExpressionToAggregationExpressionTranslator.TranslateLambdaBody(context, partiallyEvaluatedGroupBy, inputSerializer, asRoot: true);
var pushElements = AstExpression.AccumulatorField("_elements", AstUnaryAccumulatorOperator.Push, AstExpression.RootVar);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System;
using System.Collections.ObjectModel;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;

namespace MongoDB.Driver.Linq.Linq3Implementation.Misc;

/// <summary>
/// This visitor rewrites expressions where new features of .NET CLR or
/// C# compiler interfere with LINQ expression tree translation.
/// </summary>
internal class ClrCompatExpressionRewriter : ExpressionVisitor
{
private static readonly ClrCompatExpressionRewriter __instance = new();

public static Expression Rewrite(Expression expression)
=> __instance.Visit(expression);

/// <inheritdoc />
protected override Expression VisitMethodCall(MethodCallExpression node)
{
node = (MethodCallExpression)base.VisitMethodCall(node);

var method = node.Method;
var arguments = node.Arguments;

return method.Name switch
{
"Contains" => VisitContainsMethod(node, method, arguments),
"SequenceEqual" => VisitSequenceEqualMethod(node, method, arguments),
_ => node
};

static Expression VisitContainsMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection<Expression> arguments)
{
if (method.IsOneOf(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValue, MemoryExtensionsMethod.ContainsWithSpanAndValue))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one line accurately checks what method we are rewriting.

The many lines of code that it used to take to check this are now encapsulated in the MemoryExtensionsMethod class.

{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
var value = arguments[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few intermediate variables go a long way to make the code understandable.


if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively we could check that unwrappedSpan is an array, but technically all that is required for this rewriting to be OK is that it implements IEnumerable<TItem>.

{
return
Expression.Call(
EnumerableMethod.Contains.MakeGenericMethod(itemType),
[unwrappedSpan, value]);
}
}
else if (method.Is(MemoryExtensionsMethod.ContainsWithReadOnlySpanAndValueAndComparer))
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
var value = arguments[1];
var comparer = arguments[2];

if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType))
{
return
Expression.Call(
EnumerableMethod.ContainsWithComparer.MakeGenericMethod(itemType),
[unwrappedSpan, value, comparer]);
}
}

return node;
}

static Expression VisitSequenceEqualMethod(MethodCallExpression node, MethodInfo method, ReadOnlyCollection<Expression> arguments)
{
if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpan, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpan))
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
var other = arguments[1];

if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
TryUnwrapSpanImplicitCast(other, out var unwrappedOther) &&
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) &&
unwrappedOther.Type.ImplementsIEnumerableOf(itemType))
{
return
Expression.Call(
EnumerableMethod.SequenceEqual.MakeGenericMethod(itemType),
[unwrappedSpan, unwrappedOther]);
}
}
else if (method.IsOneOf(MemoryExtensionsMethod.SequenceEqualWithReadOnlySpanAndReadOnlySpanAndComparer, MemoryExtensionsMethod.SequenceEqualWithSpanAndReadOnlySpanAndComparer))
{
var itemType = method.GetGenericArguments().Single();
var span = arguments[0];
var other = arguments[1];
var comparer = arguments[2];

if (TryUnwrapSpanImplicitCast(span, out var unwrappedSpan) &&
TryUnwrapSpanImplicitCast(other, out var unwrappedOther) &&
unwrappedSpan.Type.ImplementsIEnumerableOf(itemType) &&
unwrappedOther.Type.ImplementsIEnumerableOf(itemType))
{
return
Expression.Call(
EnumerableMethod.SequenceEqualWithComparer.MakeGenericMethod(itemType),
[unwrappedSpan, unwrappedOther, comparer]);
}
}

return node;
}

// Erase implicit casts to ReadOnlySpan<T> and Span<T>
static bool TryUnwrapSpanImplicitCast(Expression expression, out Expression result)
{
if (expression is MethodCallExpression
{
Method:
{
Name: "op_Implicit", DeclaringType: { IsGenericType: true } implicitCastDeclaringType
}
} methodCallExpression
&& implicitCastDeclaringType.GetGenericTypeDefinition() is var genericTypeDefinition
&& (genericTypeDefinition == typeof(Span<>) || genericTypeDefinition == typeof(ReadOnlySpan<>)))
{
result = methodCallExpression.Arguments[0];
return true;
}

result = null;
return false;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2010-present MongoDB Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

using System.Linq.Expressions;

namespace MongoDB.Driver.Linq.Linq3Implementation.Misc;

/// <summary>
/// This class is called before we process any LINQ expression trees
/// to perform any necessary pre-processing such as CLR compatibility
/// and partial evaluation.
/// </summary>
internal static class LinqExpressionPreprocessor
{
public static Expression Preprocess(Expression expression)
{
expression = ClrCompatExpressionRewriter.Rewrite(expression);
expression = PartialEvaluator.EvaluatePartially(expression);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get exceptions if I try to call EvaluatePartially first.

return expression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,59 @@
* limitations under the License.
*/

using System;
using System.Reflection;

namespace MongoDB.Driver.Linq.Linq3Implementation.Misc
{
internal static class MethodInfoExtensions
{
public static bool Has1GenericArgument(this MethodInfo method, out Type genericArgument)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Helper methods that make reading if statements easier by reducing multiple lines of tests to one line with clear intent.

{
if (method.IsGenericMethod &&
method.GetGenericArguments() is var genericArguments &&
genericArguments.Length == 1)
{
genericArgument = genericArguments[0];
return true;
}

genericArgument = null;
return false;
}

public static bool Has2Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2)
{
if (method.GetParameters() is var parameters &&
parameters.Length == 2)
{
parameter1 = parameters[0];
parameter2 = parameters[1];
return true;
}

parameter1 = null;
parameter2 = null;
return false;
}

public static bool Has3Parameters(this MethodInfo method, out ParameterInfo parameter1, out ParameterInfo parameter2, out ParameterInfo parameter3)
{
if (method.GetParameters() is var parameters &&
parameters.Length == 3)
{
parameter1 = parameters[0];
parameter2 = parameters[1];
parameter3 = parameters[2];
return true;
}

parameter1 = null;
parameter2 = null;
parameter3 = null;
return false;
}

public static bool Is(this MethodInfo method, MethodInfo comparand)
{
if (comparand != null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,27 @@ public static bool IsNullableOf(this Type type, Type valueType)
return type.IsNullable(out var nullableValueType) && nullableValueType == valueType;
}

public static bool IsReadOnlySpanOf(this Type type, Type itemType)
{
return
type.IsGenericType &&
type.GetGenericTypeDefinition() == typeof(ReadOnlySpan<>) &&
type.GetGenericArguments()[0] == itemType;
}

public static bool IsSameAsOrNullableOf(this Type type, Type valueType)
{
return type == valueType || type.IsNullableOf(valueType);
}

public static bool IsSpanOf(this Type type, Type itemType)
{
return
type.IsGenericType &&
type.GetGenericTypeDefinition() == typeof(Span<>) &&
type.GetGenericArguments()[0] == itemType;
}

public static bool IsSubclassOfOrImplements(this Type type, Type baseTypeOrInterface)
{
return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ internal static class EnumerableMethod
private static readonly MethodInfo __cast;
private static readonly MethodInfo __concat;
private static readonly MethodInfo __contains;
private static readonly MethodInfo __containsWithComparer;
private static readonly MethodInfo __count;
private static readonly MethodInfo __countWithPredicate;
private static readonly MethodInfo __defaultIfEmpty;
Expand Down Expand Up @@ -150,6 +151,7 @@ internal static class EnumerableMethod
private static readonly MethodInfo __selectManyWithSelectorTakingIndex;
private static readonly MethodInfo __selectWithSelectorTakingIndex;
private static readonly MethodInfo __sequenceEqual;
private static readonly MethodInfo __sequenceEqualWithComparer;
private static readonly MethodInfo __single;
private static readonly MethodInfo __singleOrDefault;
private static readonly MethodInfo __singleOrDefaultWithPredicate;
Expand Down Expand Up @@ -226,6 +228,7 @@ static EnumerableMethod()
__cast = ReflectionInfo.Method((IEnumerable source) => source.Cast<object>());
__concat = ReflectionInfo.Method((IEnumerable<object> first, IEnumerable<object> second) => first.Concat(second));
__contains = ReflectionInfo.Method((IEnumerable<object> source, object value) => source.Contains(value));
__containsWithComparer = ReflectionInfo.Method((IEnumerable<object> source, object value, IEqualityComparer<object> comparer) => source.Contains(value, comparer));
__count = ReflectionInfo.Method((IEnumerable<object> source) => source.Count());
__countWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.Count(predicate));
__defaultIfEmpty = ReflectionInfo.Method((IEnumerable<object> source) => source.DefaultIfEmpty());
Expand Down Expand Up @@ -317,6 +320,7 @@ static EnumerableMethod()
__selectManyWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable<object> source, Func<object, int, IEnumerable<object>> selector) => source.SelectMany(selector));
__selectWithSelectorTakingIndex = ReflectionInfo.Method((IEnumerable<object> source, Func<object, int, object> selector) => source.Select(selector));
__sequenceEqual = ReflectionInfo.Method((IEnumerable<object> first, IEnumerable<object> second) => first.SequenceEqual(second));
__sequenceEqualWithComparer = ReflectionInfo.Method((IEnumerable<object> first, IEnumerable<object> second, IEqualityComparer<object> comparer) => first.SequenceEqual(second, comparer));
__single = ReflectionInfo.Method((IEnumerable<object> source) => source.Single());
__singleOrDefault = ReflectionInfo.Method((IEnumerable<object> source) => source.SingleOrDefault());
__singleOrDefaultWithPredicate = ReflectionInfo.Method((IEnumerable<object> source, Func<object, bool> predicate) => source.SingleOrDefault(predicate));
Expand Down Expand Up @@ -392,6 +396,7 @@ static EnumerableMethod()
public static MethodInfo Cast => __cast;
public static MethodInfo Concat => __concat;
public static MethodInfo Contains => __contains;
public static MethodInfo ContainsWithComparer => __containsWithComparer;
public static MethodInfo Count => __count;
public static MethodInfo CountWithPredicate => __countWithPredicate;
public static MethodInfo DefaultIfEmpty => __defaultIfEmpty;
Expand Down Expand Up @@ -483,6 +488,7 @@ static EnumerableMethod()
public static MethodInfo SelectManyWithSelectorTakingIndex => __selectManyWithSelectorTakingIndex;
public static MethodInfo SelectWithSelectorTakingIndex => __selectWithSelectorTakingIndex;
public static MethodInfo SequenceEqual => __sequenceEqual;
public static MethodInfo SequenceEqualWithComparer => __sequenceEqualWithComparer;
public static MethodInfo Single => __single;
public static MethodInfo SingleOrDefault => __singleOrDefault;
public static MethodInfo SingleOrDefaultWithPredicate => __singleOrDefaultWithPredicate;
Expand Down
Loading
Loading