diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs index 206fd8b308c..7ce6ff1e607 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstSimplifier.cs @@ -14,6 +14,7 @@ */ using System; +using System.Collections.Generic; using System.Linq; using MongoDB.Bson; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; @@ -454,6 +455,79 @@ public override AstNode VisitMapExpression(AstMapExpression node) } } + // { $map : { input : { $map : { input : , as : "inner", in : { A : , B : , ... } } }, as: "outer", in : { F : '$$outer.A', G : "$$outer.B", ... } } } + // => { $map : { input : , as: "inner", in : { F : , G : , ... } } } + if (node.Input is AstMapExpression innerMapExpression && + node.As is var outerVar && + node.In is AstComputedDocumentExpression outerComputedDocumentExpression && + innerMapExpression.Input is var innerInput && + innerMapExpression.As is var innerVar && + innerMapExpression.In is AstComputedDocumentExpression innerComputedDocumentExpression && + outerComputedDocumentExpression.Fields.All(outerField => + outerField.Value is AstGetFieldExpression outerGetFieldExpression && + outerGetFieldExpression.Input == outerVar && + outerGetFieldExpression.FieldName is AstConstantExpression { Value : BsonString { Value : var matchingFieldName } } && + innerComputedDocumentExpression.Fields.Any(innerField => innerField.Path == matchingFieldName))) + { + var rewrittenOuterFields = new List(); + foreach (var outerField in outerComputedDocumentExpression.Fields) + { + var outerGetFieldExpression = (AstGetFieldExpression)outerField.Value; + var matchingFieldName = ((AstConstantExpression)outerGetFieldExpression.FieldName).Value.AsString; + var matchingInnerField = innerComputedDocumentExpression.Fields.Single(innerField => innerField.Path == matchingFieldName); + var rewrittenOuterField = AstExpression.ComputedField(outerField.Path, matchingInnerField.Value); + rewrittenOuterFields.Add(rewrittenOuterField); + } + + var simplified = AstExpression.Map( + input: innerInput, + @as: innerVar, + @in: AstExpression.ComputedDocument(rewrittenOuterFields)); + + return Visit(simplified); + } + + // { $map : { input : [{ A : , B : , ... }, { A : , B : , ... }, ...], as : "item", in: { F : "$$item.A", G : "$$item.B", ... } } } + // => [{ F : , G : ", ... }, { F : , G : , ... }, ...] + if (node.Input is AstComputedArrayExpression inputComputedArray && + inputComputedArray.Items.Count >= 1 && + inputComputedArray.Items[0] is AstComputedDocumentExpression firstComputedDocument && + firstComputedDocument.Fields.Select(inputField => inputField.Path).ToArray() is var inputFieldNames && + inputComputedArray.Items.Skip(1).All(otherItem => + otherItem is AstComputedDocumentExpression otherComputedDocument && + otherComputedDocument.Fields.Select(otherField => otherField.Path).SequenceEqual(inputFieldNames)) && + node.As is var itemVar && + node.In is AstComputedDocumentExpression mappedDocument && + mappedDocument.Fields.All(mappedField => + mappedField.Value is AstGetFieldExpression mappedGetField && + mappedGetField.Input == itemVar && + mappedGetField.FieldName is AstConstantExpression { Value : BsonString { Value : var matchingFieldName } } && + inputFieldNames.Contains(matchingFieldName))) + { + var rewrittenItems = new List(); + foreach (var inputItem in inputComputedArray.Items) + { + var inputDocument = (AstComputedDocumentExpression)inputItem; + + var rewrittenFields = new List(); + foreach (var mappedField in mappedDocument.Fields) + { + var mappedGetField = (AstGetFieldExpression)mappedField.Value; + var matchingFieldName = ((AstConstantExpression)mappedGetField.FieldName).Value.AsString; + var matchingInputField = inputDocument.Fields.Single(inputField => inputField.Path == matchingFieldName); + var rewrittenField = AstExpression.ComputedField(mappedField.Path, matchingInputField.Value); + rewrittenFields.Add(rewrittenField); + } + + var rewrittenItem = AstExpression.ComputedDocument(rewrittenFields); + rewrittenItems.Add(rewrittenItem); + } + + var simplified = AstExpression.ComputedArray(rewrittenItems); + + return Visit(simplified); + } + return base.VisitMapExpression(node); static AstExpression UltimateGetFieldInput(AstGetFieldExpression getField) @@ -574,7 +648,32 @@ arg is AstBinaryExpression argBinaryExpression && return AstExpression.Binary(oppositeComparisonOperator, argBinaryExpression.Arg1, argBinaryExpression.Arg2); } + // { $arrayToObject : [[{ k : 'A', v : }, { k : 'B', v : }, ...]] } => { A : , B : , ... } + if (node.Operator == AstUnaryOperator.ArrayToObject && + arg is AstComputedArrayExpression computedArrayExpression && + computedArrayExpression.Items.All( + item => + item is AstComputedDocumentExpression computedDocumentExpression && + computedDocumentExpression.Fields.Count == 2 && + computedDocumentExpression.Fields[0].Path == "k" && + computedDocumentExpression.Fields[1].Path == "v" && + computedDocumentExpression.Fields[0].Value is AstConstantExpression { Value : { IsString : true } })) + { + var computedFields = computedArrayExpression.Items.Select(KeyValuePairDocumentToComputedField); + return AstExpression.ComputedDocument(computedFields); + } + return node.Update(arg); + + static AstComputedField KeyValuePairDocumentToComputedField(AstExpression expression) + { + // caller has verified that expression is of the form: { k : , v : } + var keyValuePairDocumentExpression = (AstComputedDocumentExpression)expression; + var keyConstantExpression = (AstConstantExpression)keyValuePairDocumentExpression.Fields[0].Value; + var valueExpression = keyValuePairDocumentExpression.Fields[1].Value; + + return AstExpression.ComputedField(keyConstantExpression.Value.AsString, valueExpression); + } } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs new file mode 100644 index 00000000000..ca0ccc27664 --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Reflection/DictionaryConstructor.cs @@ -0,0 +1,37 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System.Collections.Generic; +using System.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Reflection +{ + internal static class DictionaryConstructor + { + public static bool IsWithIEnumerableKeyValuePairConstructor(ConstructorInfo constructor) + { + var declaringType = constructor.DeclaringType; + var parameters = constructor.GetParameters(); + return + declaringType.IsConstructedGenericType && + declaringType.GetGenericTypeDefinition() == typeof(Dictionary<,>) && + parameters.Length == 1 && + parameters[0].ParameterType.ImplementsIEnumerable(out var enumerableType) && + enumerableType.IsConstructedGenericType && + enumerableType.GetGenericTypeDefinition() == typeof(KeyValuePair<,>); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs new file mode 100644 index 00000000000..aee174ac38d --- /dev/null +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslator.cs @@ -0,0 +1,98 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Options; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; +using MongoDB.Driver.Linq.Linq3Implementation.Reflection; + +namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + internal static class NewDictionaryExpressionToAggregationExpressionTranslator + { + public static bool CanTranslate(NewExpression expression) + => DictionaryConstructor.IsWithIEnumerableKeyValuePairConstructor(expression.Constructor); + + public static TranslatedExpression Translate(TranslationContext context, NewExpression expression) + { + var arguments = expression.Arguments; + + var collectionExpression = arguments[0]; + var collectionTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, collectionExpression); + var itemSerializer = ArraySerializerHelper.GetItemSerializer(collectionTranslation.Serializer); + + IBsonSerializer keySerializer; + IBsonSerializer valueSerializer; + AstExpression collectionTranslationAst; + + if (itemSerializer is IBsonDocumentSerializer itemDocumentSerializer) + { + if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Key", out var keyMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Key member"); + } + keySerializer = keyMemberSerializationInfo.Serializer; + + if (!itemDocumentSerializer.TryGetMemberSerializationInfo("Value", out var valueMemberSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {itemSerializer.GetType()} does not have a Value member"); + } + valueSerializer = valueMemberSerializationInfo.Serializer; + + if (keyMemberSerializationInfo.ElementName == "k" && valueMemberSerializationInfo.ElementName == "v") + { + collectionTranslationAst = collectionTranslation.Ast; + } + else + { + var pairVar = AstExpression.Var("pair"); + var computedDocumentAst = AstExpression.ComputedDocument([ + AstExpression.ComputedField("k", AstExpression.GetField(pairVar, keyMemberSerializationInfo.ElementName)), + AstExpression.ComputedField("v", AstExpression.GetField(pairVar, valueMemberSerializationInfo.ElementName)) + ]); + + collectionTranslationAst = AstExpression.Map(collectionTranslation.Ast, pairVar, computedDocumentAst); + } + } + else + { + throw new ExpressionNotSupportedException(expression); + } + + if (keySerializer is not IRepresentationConfigurable { Representation: BsonType.String }) + { + throw new ExpressionNotSupportedException(expression, because: "key does not serialize as a string"); + } + + var ast = AstExpression.Unary(AstUnaryOperator.ArrayToObject, collectionTranslationAst); + var resultSerializer = CreateResultSerializer(keySerializer, valueSerializer); + return new TranslatedExpression(expression, ast, resultSerializer); + } + + private static IBsonSerializer CreateResultSerializer(IBsonSerializer keySerializer, IBsonSerializer valueSerializer) + { + var dictionaryType = typeof(Dictionary<,>).MakeGenericType(keySerializer.ValueType, valueSerializer.ValueType); + var serializerType = typeof(DictionaryInterfaceImplementerSerializer<,,>).MakeGenericType(dictionaryType, keySerializer.ValueType, valueSerializer.ValueType); + + return (IBsonSerializer)Activator.CreateInstance(serializerType, DictionaryRepresentation.Document, keySerializer, valueSerializer); + } + } +} diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs index b54f431e516..af521d05658 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs @@ -50,6 +50,10 @@ public static TranslatedExpression Translate(TranslationContext context, NewExpr { return NewKeyValuePairExpressionToAggregationExpressionTranslator.Translate(context, expression); } + if (NewDictionaryExpressionToAggregationExpressionTranslator.CanTranslate(expression)) + { + return NewDictionaryExpressionToAggregationExpressionTranslator.Translate(context, expression); + } return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty()); } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs new file mode 100644 index 00000000000..59931e8f199 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewDictionaryExpressionToAggregationExpressionTranslatorTests.cs @@ -0,0 +1,155 @@ +/* Copyright 2010-present MongoDB Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#if NET6_0_OR_GREATER || NETCOREAPP3_1_OR_GREATER + +using System; +using System.Collections.Generic; +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Driver.Linq; +using MongoDB.Driver.TestHelpers; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators +{ + public class NewDictionaryExpressionToAggregationExpressionTranslatorTests : LinqIntegrationTest + { + public NewDictionaryExpressionToAggregationExpressionTranslatorTests(ClassFixture fixture) + : base(fixture) + { + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair("A", d.A), new KeyValuePair("B", d.B) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["A"] = "a", ["B"] = "b" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_Create_should_translate() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { KeyValuePair.Create("A", d.A), KeyValuePair.Create("B", d.B) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { A : '$A', B: '$B' }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["A"] = "a", ["B"] = "b" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate_Guid_as_string_key() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair(d.GuidAsString, d.A) })); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { $arrayToObject : [[{ k : '$GuidAsString', v : '$A' }]] }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ [Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE")] = "a" }); + } + + + [Fact] + public void NewDictionary_with_KeyValuePairs_should_translate_dynamic_array() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + d.Items.Select(i => new KeyValuePair(i.P, i.W)))); + + var stages = Translate(collection, queryable); + + AssertStages(stages, "{ $project : { _v : { $arrayToObject : { $map: { input: '$Items', as: 'i', in: { k: '$$i.P', v: '$$i.W' } } } }, _id : 0 } }"); + + var result = queryable.Single(); + result.Should().Equal(new Dictionary{ ["x"] = "y" }); + } + + [Fact] + public void NewDictionary_with_KeyValuePairs_throws_on_non_string_key() + { + var collection = Fixture.Collection; + + var queryable = collection.AsQueryable() + .Select(d => new Dictionary( + new[] { new KeyValuePair(42, d.A) })); + + var exception = Record.Exception(() => queryable.ToList()); + + exception.Should().NotBeNull(); + exception.Should().BeOfType(); + } + + public class C + { + public string A { get; set; } + + public string B { get; set; } + + [BsonRepresentation(BsonType.String)] + public Guid GuidAsString { get; set; } + + public Item[] Items { get; set; } + } + + public class Item + { + public string P { get; set; } + + public string W { get; set; } + } + + public sealed class ClassFixture : MongoCollectionFixture + { + protected override IEnumerable InitialData => + [ + new C + { + A = "a", + B = "b", + GuidAsString = Guid.Parse("3E9AE467-9705-4C17-9655-EE7730BCC2EE"), + Items = [ new Item { P = "x", W = "y" } ] + }, + ]; + } + } +} +#endif