diff --git a/src/QuerySpecification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs b/src/QuerySpecification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs index 5e8fce3e..5d8598c2 100644 --- a/src/QuerySpecification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs +++ b/src/QuerySpecification.EntityFrameworkCore/Evaluators/IncludeEvaluator.cs @@ -33,7 +33,7 @@ private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>) && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)); - private readonly record struct CacheKey(Type EntityType, Type PropertyType, Type? PreviousReturnType); + private readonly record struct CacheKey(int IncludeType, Type EntityType, Type PropertyType, Type? PreviousReturnType); private static readonly ConcurrentDictionary> _cache = new(); @@ -53,14 +53,15 @@ public IQueryable Evaluate(IQueryable source, Specification specific { if (item.Bag == (int)IncludeType.Include) { - var key = new CacheKey(typeof(T), expr.ReturnType, null); + var key = new CacheKey(item.Bag, typeof(T), expr.ReturnType, null); previousReturnType = expr.ReturnType; var include = _cache.GetOrAdd(key, CreateIncludeDelegate); source = (IQueryable)include(source, expr); } - else if (item.Bag == (int)IncludeType.ThenInclude) + else if (item.Bag == (int)IncludeType.ThenIncludeAfterReference + || item.Bag == (int)IncludeType.ThenIncludeAfterCollection) { - var key = new CacheKey(typeof(T), expr.ReturnType, previousReturnType); + var key = new CacheKey(item.Bag, typeof(T), expr.ReturnType, previousReturnType); previousReturnType = expr.ReturnType; var include = _cache.GetOrAdd(key, CreateThenIncludeDelegate); source = (IQueryable)include(source, expr); @@ -90,16 +91,16 @@ private static Func CreateThenIncludeD { Debug.Assert(cacheKey.PreviousReturnType is not null); - var thenIncludeInfo = IsGenericEnumerable(cacheKey.PreviousReturnType, out var previousPropertyType) - ? _thenIncludeAfterEnumerableMethodInfo - : _thenIncludeAfterReferenceMethodInfo; + var (thenIncludeMethod, previousPropertyType) = cacheKey.IncludeType == (int)IncludeType.ThenIncludeAfterReference + ? (_thenIncludeAfterReferenceMethodInfo, cacheKey.PreviousReturnType) + : (_thenIncludeAfterEnumerableMethodInfo, cacheKey.PreviousReturnType.GenericTypeArguments[0]); - var thenIncludeMethod = thenIncludeInfo.MakeGenericMethod(cacheKey.EntityType, previousPropertyType, cacheKey.PropertyType); + var thenIncludeMethodGeneric = thenIncludeMethod.MakeGenericMethod(cacheKey.EntityType, previousPropertyType, cacheKey.PropertyType); var sourceParameter = Expression.Parameter(typeof(IQueryable)); var selectorParameter = Expression.Parameter(typeof(LambdaExpression)); var call = Expression.Call( - thenIncludeMethod, + thenIncludeMethodGeneric, // We must pass cacheKey.PreviousReturnType. It must be exact type, not the generic type argument Expression.Convert(sourceParameter, typeof(IIncludableQueryable<,>).MakeGenericType(cacheKey.EntityType, cacheKey.PreviousReturnType)), Expression.Convert(selectorParameter, typeof(Expression<>).MakeGenericType(typeof(Func<,>).MakeGenericType(previousPropertyType, cacheKey.PropertyType)))); @@ -107,16 +108,4 @@ private static Func CreateThenIncludeD var lambda = Expression.Lambda>(call, sourceParameter, selectorParameter); return lambda.Compile(); } - - private static bool IsGenericEnumerable(Type type, out Type propertyType) - { - if (type.IsGenericType && typeof(IEnumerable).IsAssignableFrom(type)) - { - propertyType = type.GenericTypeArguments[0]; - return true; - } - - propertyType = type; - return false; - } } diff --git a/src/QuerySpecification/Builders/Builder_Include.cs b/src/QuerySpecification/Builders/Builder_Include.cs index a61179db..227ff8d2 100644 --- a/src/QuerySpecification/Builders/Builder_Include.cs +++ b/src/QuerySpecification/Builders/Builder_Include.cs @@ -179,7 +179,7 @@ public static IIncludableSpecificationBuilder ThenI { if (condition && !Specification.IsChainDiscarded) { - builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude); + builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterReference); } else { @@ -223,7 +223,7 @@ public static IIncludableSpecificationBuilder ThenInclude.IsChainDiscarded) { - builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude); + builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterReference); } else { @@ -269,7 +269,7 @@ public static IIncludableSpecificationBuilder ThenI { if (condition && !Specification.IsChainDiscarded) { - builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude); + builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterCollection); } else { @@ -313,7 +313,7 @@ public static IIncludableSpecificationBuilder ThenInclude.IsChainDiscarded) { - builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenInclude); + builder.Specification.AddInternal(ItemType.Include, navigationSelector, (int)IncludeType.ThenIncludeAfterCollection); } else { diff --git a/src/QuerySpecification/Expressions/IncludeType.cs b/src/QuerySpecification/Expressions/IncludeType.cs index 8b8c8ac5..d26612da 100644 --- a/src/QuerySpecification/Expressions/IncludeType.cs +++ b/src/QuerySpecification/Expressions/IncludeType.cs @@ -6,12 +6,17 @@ public enum IncludeType { /// - /// Represents an include operation. + /// Represents an Include operation. /// Include = 1, /// - /// Represents a then include operation. + /// Represents a ThenInclude operation after reference include. /// - ThenInclude = 2 + ThenIncludeAfterReference = 2, + + /// + /// Represents a ThenInclude operation after collection include. + /// + ThenIncludeAfterCollection = 3 } diff --git a/tests/QuerySpecification.Benchmarks/Benchmarks/Benchmark6_IncludeEvaluator.cs b/tests/QuerySpecification.Benchmarks/Benchmarks/Benchmark6_IncludeEvaluator.cs index 2c471b9f..34c9d1ad 100644 --- a/tests/QuerySpecification.Benchmarks/Benchmarks/Benchmark6_IncludeEvaluator.cs +++ b/tests/QuerySpecification.Benchmarks/Benchmarks/Benchmark6_IncludeEvaluator.cs @@ -1,8 +1,4 @@ -using Microsoft.EntityFrameworkCore.Query; -using System.Collections; -using System.Diagnostics; -using System.Linq.Expressions; -using System.Reflection; +using System.Linq.Expressions; namespace QuerySpecification.Benchmarks; @@ -41,124 +37,10 @@ public object EFCore() } [Benchmark] - public object Spec_MethodInvoke() - { - var evaluator = IncludeEvaluatorMethodInvoke.Instance; - var result = evaluator.Evaluate(_queryable, _spec); - - return result; - } - - [Benchmark] - public object Spec_v11() + public object Spec() { var evaluator = IncludeEvaluator.Instance; var result = evaluator.Evaluate(_queryable, _spec); return result; } - - private sealed class IncludeEvaluatorMethodInvoke : IEvaluator - { - private static readonly MethodInfo _includeMethodInfo = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.Include)) - .Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 2 - && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IQueryable<>) - && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)); - - private static readonly MethodInfo _thenIncludeAfterReferenceMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude)) - .Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 3 - && mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter - && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>) - && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)); - - private static readonly MethodInfo _thenIncludeAfterEnumerableMethodInfo - = typeof(EntityFrameworkQueryableExtensions) - .GetTypeInfo().GetDeclaredMethods(nameof(EntityFrameworkQueryableExtensions.ThenInclude)) - .Single(mi => mi.IsPublic && mi.GetGenericArguments().Length == 3 - && !mi.GetParameters()[0].ParameterType.GenericTypeArguments[1].IsGenericParameter - && mi.GetParameters()[0].ParameterType.GetGenericTypeDefinition() == typeof(IIncludableQueryable<,>) - && mi.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Expression<>)); - - private IncludeEvaluatorMethodInvoke() { } - public static IncludeEvaluatorMethodInvoke Instance = new(); - - public IQueryable Evaluate(IQueryable source, Specification specification) where T : class - { - if (specification.IsEmpty) return source; - - foreach (var item in specification.Items) - { - if (item.Type == ItemType.IncludeString && item.Reference is string includeString) - { - source = source.Include(includeString); - } - } - - var isPreviousPropertyCollection = false; - - foreach (var item in specification.Items) - { - if (item.Type == ItemType.Include && item.Reference is LambdaExpression expr) - { - if (item.Bag == (int)IncludeType.Include) - { - source = BuildInclude(source, expr); - isPreviousPropertyCollection = IsCollection(expr.ReturnType); - } - else if (item.Bag == (int)IncludeType.ThenInclude) - { - source = BuildThenInclude(source, expr, isPreviousPropertyCollection); - isPreviousPropertyCollection = IsCollection(expr.ReturnType); - } - } - } - - return source; - } - - private static IQueryable BuildInclude(IQueryable source, LambdaExpression includeExpression) - - { - Debug.Assert(includeExpression is not null); - - var result = _includeMethodInfo - .MakeGenericMethod(typeof(T), includeExpression.ReturnType) - .Invoke(null, [source, includeExpression]); - - Debug.Assert(result is not null); - - return (IQueryable)result; - } - - - private static IQueryable BuildThenInclude(IQueryable source, LambdaExpression includeExpression, bool isPreviousPropertyCollection) - { - Debug.Assert(includeExpression is not null); - - var previousPropertyType = includeExpression.Parameters[0].Type; - - var mi = isPreviousPropertyCollection - ? _thenIncludeAfterEnumerableMethodInfo.MakeGenericMethod(typeof(T), previousPropertyType, includeExpression.ReturnType) - : _thenIncludeAfterReferenceMethodInfo.MakeGenericMethod(typeof(T), previousPropertyType, includeExpression.ReturnType); - - var result = mi.Invoke(null, [source, includeExpression]); - - Debug.Assert(result is not null); - - return (IQueryable)result; - } - - public static bool IsCollection(Type type) - { - // Exclude string, which implements IEnumerable but is not considered a collection - if (type == typeof(string)) - { - return false; - } - - return typeof(IEnumerable).IsAssignableFrom(type); - } - } } diff --git a/tests/QuerySpecification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs b/tests/QuerySpecification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs index b114e84a..b175c160 100644 --- a/tests/QuerySpecification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs +++ b/tests/QuerySpecification.EntityFrameworkCore.Tests/Evaluators/IncludeEvaluatorTests.cs @@ -29,6 +29,36 @@ public void QueriesMatch_GivenIncludeExpressions() actual.Should().Be(expected); } + [Fact] + public void QueriesMatch_GivenInheritanceModel() + { + var spec = new Specification(); + spec.Query + .Include(x => x.BarChildren) + .ThenInclude(x => (x as BarDerived)!.BarDerivedInfo); + + var actual = _evaluator + .Evaluate(DbContext.Bars, spec) + .ToQueryString(); + + var spec2 = new Specification(); + spec2.Query + .Include(x => x.BarChildren) + .ThenInclude(x => (x as BarDerived)!.BarDerivedInfo); + + var actual2 = _evaluator + .Evaluate(DbContext.Bars, spec2) + .ToQueryString(); + + var expected = DbContext.Bars + .Include(x => x.BarChildren) + .ThenInclude(x => (x as BarDerived)!.BarDerivedInfo) + .ToQueryString(); + + actual.Should().Be(expected); + actual2.Should().Be(expected); + } + [Fact] public void QueriesMatch_GivenThenIncludeWithVariousNavigationCollectionTypes() { diff --git a/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/Data/Bar.cs b/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/Data/Bar.cs new file mode 100644 index 00000000..2aba8c99 --- /dev/null +++ b/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/Data/Bar.cs @@ -0,0 +1,31 @@ +namespace Tests.Fixture; + +public class Bar +{ + public int Id { get; set; } + public string? Dummy { get; set; } + + private readonly List _barChildren = []; + public IReadOnlyCollection BarChildren => _barChildren.AsReadOnly(); +} + +public class BarChild +{ + public int Id { get; set; } + public string? Dummy { get; set; } + + public int BarId { get; set; } + public Bar Bar { get; set; } = default!; +} + +public class BarDerived : BarChild +{ + public int BarDerivedInfoId { get; set; } + public BarDerivedInfo BarDerivedInfo { get; set; } = default!; +} + +public class BarDerivedInfo +{ + public int Id { get; set; } + public string? Name { get; set; } +} diff --git a/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/TestDbContext.cs b/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/TestDbContext.cs index 7595188c..1313beb8 100644 --- a/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/TestDbContext.cs +++ b/tests/QuerySpecification.EntityFrameworkCore.Tests/Fixture/TestDbContext.cs @@ -2,6 +2,7 @@ public class TestDbContext(DbContextOptions options) : DbContext(options) { + public DbSet Bars => Set(); public DbSet Foos => Set(); public DbSet Countries => Set(); public DbSet Companies => Set(); @@ -23,5 +24,8 @@ protected override void OnModelCreating(ModelBuilder modelBuilder) modelBuilder.Entity() .HasQueryFilter(x => !x.IsDeleted); + + modelBuilder.Entity() + .HasBaseType(); } } diff --git a/tests/QuerySpecification.Tests/Builders/Builder_ThenInclude.cs b/tests/QuerySpecification.Tests/Builders/Builder_ThenInclude.cs index be3c4955..79eb0de4 100644 --- a/tests/QuerySpecification.Tests/Builders/Builder_ThenInclude.cs +++ b/tests/QuerySpecification.Tests/Builders/Builder_ThenInclude.cs @@ -115,7 +115,7 @@ public void DoesNothing_GivenIncludeThenWithDiscardedNestedChain() } [Fact] - public void AddsIncludeThen_GivenIncludeThen() + public void AddsThenInclude_GivenThenIncludeAfterReference() { Expression> expr = x => x.Contact; @@ -132,11 +132,36 @@ public void AddsIncludeThen_GivenIncludeThen() spec1.IncludeExpressions.Should().HaveCount(2); spec1.IncludeExpressions.Last().LambdaExpression.Should().BeSameAs(expr); spec1.IncludeExpressions.First().Type.Should().Be(IncludeType.Include); - spec1.IncludeExpressions.Last().Type.Should().Be(IncludeType.ThenInclude); + spec1.IncludeExpressions.Last().Type.Should().Be(IncludeType.ThenIncludeAfterReference); spec2.IncludeExpressions.Should().HaveCount(2); spec2.IncludeExpressions.Last().LambdaExpression.Should().BeSameAs(expr); spec2.IncludeExpressions.First().Type.Should().Be(IncludeType.Include); - spec2.IncludeExpressions.Last().Type.Should().Be(IncludeType.ThenInclude); + spec2.IncludeExpressions.Last().Type.Should().Be(IncludeType.ThenIncludeAfterReference); + } + + [Fact] + public void AddsThenInclude_GivenThenIncludeAfterCollection() + { + Expression> expr = x => x.Contact; + + var spec1 = new Specification(); + spec1.Query + .Include(x => x.Addresses) + .ThenInclude(expr); + + var spec2 = new Specification(); + spec2.Query + .Include(x => x.Addresses) + .ThenInclude(expr); + + spec1.IncludeExpressions.Should().HaveCount(2); + spec1.IncludeExpressions.Last().LambdaExpression.Should().BeSameAs(expr); + spec1.IncludeExpressions.First().Type.Should().Be(IncludeType.Include); + spec1.IncludeExpressions.Last().Type.Should().Be(IncludeType.ThenIncludeAfterCollection); + spec2.IncludeExpressions.Should().HaveCount(2); + spec2.IncludeExpressions.Last().LambdaExpression.Should().BeSameAs(expr); + spec2.IncludeExpressions.First().Type.Should().Be(IncludeType.Include); + spec2.IncludeExpressions.Last().Type.Should().Be(IncludeType.ThenIncludeAfterCollection); } [Fact] @@ -174,9 +199,11 @@ public void AddsIncludeThen_GivenMultipleIncludeThen() spec1.IncludeExpressions.Should().HaveCount(12); spec1.IncludeExpressions.OrderBy(x => x.Type).Take(4).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.Include)); - spec1.IncludeExpressions.OrderBy(x => x.Type).Skip(4).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.ThenInclude)); + spec1.IncludeExpressions.OrderBy(x => x.Type).Skip(4).Take(4).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.ThenIncludeAfterReference)); + spec1.IncludeExpressions.OrderBy(x => x.Type).Skip(8).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.ThenIncludeAfterCollection)); spec2.IncludeExpressions.Should().HaveCount(12); spec2.IncludeExpressions.OrderBy(x => x.Type).Take(4).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.Include)); - spec2.IncludeExpressions.OrderBy(x => x.Type).Skip(4).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.ThenInclude)); + spec2.IncludeExpressions.OrderBy(x => x.Type).Skip(4).Take(4).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.ThenIncludeAfterReference)); + spec2.IncludeExpressions.OrderBy(x => x.Type).Skip(8).Should().AllSatisfy(x => x.Type.Should().Be(IncludeType.ThenIncludeAfterCollection)); } } diff --git a/tests/QuerySpecification.Tests/SpecificationTests.cs b/tests/QuerySpecification.Tests/SpecificationTests.cs index 87684c3f..de2dcabe 100644 --- a/tests/QuerySpecification.Tests/SpecificationTests.cs +++ b/tests/QuerySpecification.Tests/SpecificationTests.cs @@ -5,7 +5,7 @@ namespace Tests; public class SpecificationTests { private static readonly SpecItem _emptySpecItem = new(); - public record Customer(int Id, string Name, Address Address); + public record Customer(int Id, string Name, Address Address, List
Addresses); public record Address(int Id, City City); public record City(int Id, string Name); @@ -112,7 +112,7 @@ public void LikeExpressions() } [Fact] - public void IncludeExpressions() + public void IncludeExpressions_GivenThenIncludeAfterReference() { Expression> include = x => x.Address; Expression> thenInclude = x => x.City; @@ -127,7 +127,26 @@ public void IncludeExpressions() expressions[0].LambdaExpression.Should().BeSameAs(include); expressions[0].Type.Should().Be(IncludeType.Include); expressions[1].LambdaExpression.Should().BeSameAs(thenInclude); - expressions[1].Type.Should().Be(IncludeType.ThenInclude); + expressions[1].Type.Should().Be(IncludeType.ThenIncludeAfterReference); + } + + [Fact] + public void IncludeExpressions_GivenThenIncludeAfterCollection() + { + Expression>> include = x => x.Addresses; + Expression> thenInclude = x => x.City; + var spec = new Specification(); + spec.Query + .Include(include) + .ThenInclude(thenInclude); + + var expressions = spec.IncludeExpressions.ToList(); + + expressions.Should().HaveCount(2); + expressions[0].LambdaExpression.Should().BeSameAs(include); + expressions[0].Type.Should().Be(IncludeType.Include); + expressions[1].LambdaExpression.Should().BeSameAs(thenInclude); + expressions[1].Type.Should().Be(IncludeType.ThenIncludeAfterCollection); } [Fact]