Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo
&& mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>)
&& mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

private readonly record struct CacheKey(Type EntityType, Type PropertyType, Type? PreviousReturnType);
private readonly record struct CacheKey(int IncludeType, Type EntityType, Type PropertyType, Type? PreviousReturnType);
private static readonly ConcurrentDictionary<CacheKey, Func<IQueryable, LambdaExpression, IQueryable>> _cache = new();


Expand All @@ -53,14 +53,15 @@ public IQueryable<T> Evaluate<T>(IQueryable<T> source, Specification<T> specific
{
if (item.Bag == (int)IncludeType.Include)
{
var key = new CacheKey(typeof(T), expr.ReturnType, null);
var key = new CacheKey(item.Bag, typeof(T), expr.ReturnType, null);
previousReturnType = expr.ReturnType;
var include = _cache.GetOrAdd(key, CreateIncludeDelegate);
source = (IQueryable<T>)include(source, expr);
}
else if (item.Bag == (int)IncludeType.ThenInclude)
else if (item.Bag == (int)IncludeType.ThenIncludeAfterReference
|| item.Bag == (int)IncludeType.ThenIncludeAfterCollection)
{
var key = new CacheKey(typeof(T), expr.ReturnType, previousReturnType);
var key = new CacheKey(item.Bag, typeof(T), expr.ReturnType, previousReturnType);
previousReturnType = expr.ReturnType;
var include = _cache.GetOrAdd(key, CreateThenIncludeDelegate);
source = (IQueryable<T>)include(source, expr);
Expand Down Expand Up @@ -90,33 +91,21 @@ private static Func<IQueryable, LambdaExpression, IQueryable> CreateThenIncludeD
{
Debug.Assert(cacheKey.PreviousReturnType is not null);

var thenIncludeInfo = IsGenericEnumerable(cacheKey.PreviousReturnType, out var previousPropertyType)
? _thenIncludeAfterEnumerableMethodInfo
: _thenIncludeAfterReferenceMethodInfo;
var (thenIncludeMethod, previousPropertyType) = cacheKey.IncludeType == (int)IncludeType.ThenIncludeAfterReference
? (_thenIncludeAfterReferenceMethodInfo, cacheKey.PreviousReturnType)
: (_thenIncludeAfterEnumerableMethodInfo, cacheKey.PreviousReturnType.GenericTypeArguments[0]);

var thenIncludeMethod = thenIncludeInfo.MakeGenericMethod(cacheKey.EntityType, previousPropertyType, cacheKey.PropertyType);
var thenIncludeMethodGeneric = thenIncludeMethod.MakeGenericMethod(cacheKey.EntityType, previousPropertyType, cacheKey.PropertyType);
var sourceParameter = Expression.Parameter(typeof(IQueryable));
var selectorParameter = Expression.Parameter(typeof(LambdaExpression));

var call = Expression.Call(
thenIncludeMethod,
thenIncludeMethodGeneric,
// We must pass cacheKey.PreviousReturnType. It must be exact type, not the generic type argument
Expression.Convert(sourceParameter, typeof(IIncludableQueryable<,>).MakeGenericType(cacheKey.EntityType, cacheKey.PreviousReturnType)),
Expression.Convert(selectorParameter, typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(previousPropertyType, cacheKey.PropertyType))));

var lambda = Expression.Lambda<Func<IQueryable, LambdaExpression, IQueryable>>(call, sourceParameter, selectorParameter);
return lambda.Compile();
}

private static bool IsGenericEnumerable(Type type, out Type propertyType)
{
if (type.IsGenericType && typeof(IEnumerable).IsAssignableFrom(type))
{
propertyType = type.GenericTypeArguments[0];
return true;
}

propertyType = type;
return false;
}
}
8 changes: 4 additions & 4 deletions src/QuerySpecification/Builders/Builder_Include.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ public static IIncludableSpecificationBuilder<TEntity, TResult, TProperty> ThenI
{
if (condition && !Specification<TEntity, TResult>.IsChainDiscarded)
{
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude);
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterReference);
}
else
{
Expand Down Expand Up @@ -223,7 +223,7 @@ public static IIncludableSpecificationBuilder<TEntity, TProperty> ThenInclude<TE
{
if (condition && !Specification<TEntity>.IsChainDiscarded)
{
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude);
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterReference);
}
else
{
Expand Down Expand Up @@ -269,7 +269,7 @@ public static IIncludableSpecificationBuilder<TEntity, TResult, TProperty> ThenI
{
if (condition && !Specification<TEntity, TResult>.IsChainDiscarded)
{
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude);
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterCollection);
}
else
{
Expand Down Expand Up @@ -313,7 +313,7 @@ public static IIncludableSpecificationBuilder<TEntity, TProperty> ThenInclude<TE
{
if (condition && !Specification<TEntity>.IsChainDiscarded)
{
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude);
builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterCollection);
}
else
{
Expand Down
11 changes: 8 additions & 3 deletions src/QuerySpecification/Expressions/IncludeType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@
public enum IncludeType
{
/// <summary>
/// Represents an include operation.
/// Represents an Include operation.
/// </summary>
Include = 1,

/// <summary>
/// Represents a then include operation.
/// Represents a ThenInclude operation after reference include.
/// </summary>
ThenInclude = 2
ThenIncludeAfterReference = 2,

/// <summary>
/// Represents a ThenInclude operation after collection include.
/// </summary>
ThenIncludeAfterCollection = 3
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
using Microsoft.EntityFrameworkCore.Query;
using System.Collections;
using System.Diagnostics;
using System.Linq.Expressions;
using System.Reflection;
using System.Linq.Expressions;

namespace QuerySpecification.Benchmarks;

Expand Down Expand Up @@ -41,124 +37,10 @@ public object EFCore()
}

[Benchmark]
public object Spec_MethodInvoke()
{
var evaluator = IncludeEvaluatorMethodInvoke.Instance;
var result = evaluator.Evaluate(_queryable, _spec);

return result;
}

[Benchmark]
public object Spec_v11()
public object Spec()
{
var evaluator = IncludeEvaluator.Instance;
var result = evaluator.Evaluate(_queryable, _spec);
return result;
}

private sealed class IncludeEvaluatorMethodInvoke : IEvaluator
{
private static readonly MethodInfo _includeMethodInfo = typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include))
.Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 2
&& mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IQueryable<>)
&& mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

private static readonly MethodInfo _thenIncludeAfterReferenceMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude))
.Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 3
&& mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter
&& mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>)
&& mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo
= typeof(EntityFrameworkQueryableExtensions)
.GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude))
.Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 3
&& !mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter
&& mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>)
&& mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>));

private IncludeEvaluatorMethodInvoke() { }
public static IncludeEvaluatorMethodInvoke Instance = new();

public IQueryable<T> Evaluate<T>(IQueryable<T> source, Specification<T> specification) where T : class
{
if (specification.IsEmpty) return source;

foreach (var item in specification.Items)
{
if (item.Type == ItemType.IncludeString && item.Reference is string includeString)
{
source = source.Include(includeString);
}
}

var isPreviousPropertyCollection = false;

foreach (var item in specification.Items)
{
if (item.Type == ItemType.Include && item.Reference is LambdaExpression expr)
{
if (item.Bag == (int)IncludeType.Include)
{
source = BuildInclude<T>(source, expr);
isPreviousPropertyCollection = IsCollection(expr.ReturnType);
}
else if (item.Bag == (int)IncludeType.ThenInclude)
{
source = BuildThenInclude<T>(source, expr, isPreviousPropertyCollection);
isPreviousPropertyCollection = IsCollection(expr.ReturnType);
}
}
}

return source;
}

private static IQueryable<T> BuildInclude<T>(IQueryable source, LambdaExpression includeExpression)

{
Debug.Assert(includeExpression is not null);

var result = _includeMethodInfo
.MakeGenericMethod(typeof(T), includeExpression.ReturnType)
.Invoke(null, [source, includeExpression]);

Debug.Assert(result is not null);

return (IQueryable<T>)result;
}


private static IQueryable<T> BuildThenInclude<T>(IQueryable source, LambdaExpression includeExpression, bool isPreviousPropertyCollection)
{
Debug.Assert(includeExpression is not null);

var previousPropertyType = includeExpression.Parameters[0].Type;

var mi = isPreviousPropertyCollection
? _thenIncludeAfterEnumerableMethodInfo.MakeGenericMethod(typeof(T), previousPropertyType, includeExpression.ReturnType)
: _thenIncludeAfterReferenceMethodInfo.MakeGenericMethod(typeof(T), previousPropertyType, includeExpression.ReturnType);

var result = mi.Invoke(null, [source, includeExpression]);

Debug.Assert(result is not null);

return (IQueryable<T>)result;
}

public static bool IsCollection(Type type)
{
// Exclude string, which implements IEnumerable but is not considered a collection
if (type == typeof(string))
{
return false;
}

return typeof(IEnumerable).IsAssignableFrom(type);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,36 @@ public void QueriesMatch_GivenIncludeExpressions()
actual.Should().Be(expected);
}

[Fact]
public void QueriesMatch_GivenInheritanceModel()
{
var spec = new Specification<Bar>();
spec.Query
.Include(x => x.BarChildren)
.ThenInclude(x => (x as BarDerived)!.BarDerivedInfo);

var actual = _evaluator
.Evaluate(DbContext.Bars, spec)
.ToQueryString();

var spec2 = new Specification<Bar>();
spec2.Query
.Include(x => x.BarChildren)
.ThenInclude<Bar, BarChild, BarDerivedInfo>(x => (x as BarDerived)!.BarDerivedInfo);

var actual2 = _evaluator
.Evaluate(DbContext.Bars, spec2)
.ToQueryString();

var expected = DbContext.Bars
.Include(x => x.BarChildren)
.ThenInclude(x => (x as BarDerived)!.BarDerivedInfo)
.ToQueryString();

actual.Should().Be(expected);
actual2.Should().Be(expected);
}

[Fact]
public void QueriesMatch_GivenThenIncludeWithVariousNavigationCollectionTypes()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
namespace Tests.Fixture;

public class Bar
{
public int Id { get; set; }
public string? Dummy { get; set; }

private readonly List<BarChild> _barChildren = [];
public IReadOnlyCollection<BarChild> BarChildren => _barChildren.AsReadOnly();
}

public class BarChild
{
public int Id { get; set; }
public string? Dummy { get; set; }

public int BarId { get; set; }
public Bar Bar { get; set; } = default!;
}

public class BarDerived : BarChild
{
public int BarDerivedInfoId { get; set; }
public BarDerivedInfo BarDerivedInfo { get; set; } = default!;
}

public class BarDerivedInfo
{
public int Id { get; set; }
public string? Name { get; set; }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

public class TestDbContext(DbContextOptions options) : DbContext(options)
{
public DbSet<Bar> Bars => Set<Bar>();
public DbSet<Foo> Foos => Set<Foo>();
public DbSet<Country> Countries => Set<Country>();
public DbSet<Company> Companies => Set<Company>();
Expand All @@ -23,5 +24,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)

modelBuilder.Entity<Country>()
.HasQueryFilter(x => !x.IsDeleted);

modelBuilder.Entity<BarDerived>()
.HasBaseType<BarChild>();
}
}
Loading