diff --git a/src/GraphQL.EntityFramework/Where/ArgumentReader.cs b/src/GraphQL.EntityFramework/Where/ArgumentReader.cs index 1b75e792..c6db1df6 100644 --- a/src/GraphQL.EntityFramework/Where/ArgumentReader.cs +++ b/src/GraphQL.EntityFramework/Where/ArgumentReader.cs @@ -9,17 +9,8 @@ public static bool TryReadWhere(IResolveFieldContext context, out IReadOnlyColle public static IReadOnlyCollection ReadOrderBy(IResolveFieldContext context) => ReadList(context, "orderBy"); - public static bool TryReadIds(IResolveFieldContext context, [NotNullWhen(true)] out string[]? idValues) + public static bool TryReadIds(IResolveFieldContext context, [NotNullWhen(true)] out object[]? idValues) { - static string ArgumentToExpression(object argument) => - argument switch - { - long l => l.ToString(CultureInfo.InvariantCulture), - int i => i.ToString(CultureInfo.InvariantCulture), - string s => s, - _ => throw new($"TryReadId got an 'id' argument of type '{argument.GetType().FullName}' which is not supported.") - }; - var arguments = context.Arguments; if (arguments == null) { @@ -43,7 +34,7 @@ static string ArgumentToExpression(object argument) => return false; } - var expressions = new List(); + var expressions = new List(); if (id.Source != ArgumentSource.FieldDefault) { @@ -53,7 +44,7 @@ static string ArgumentToExpression(object argument) => throw new("Null 'id' is not supported."); } - expressions.Add(ArgumentToExpression(idValue)); + expressions.Add(idValue); } if (ids.Source != ArgumentSource.FieldDefault) @@ -63,7 +54,7 @@ static string ArgumentToExpression(object argument) => throw new($"TryReadIds got an 'ids' argument of type '{ids.Value!.GetType().FullName}' which is not supported."); } - expressions.AddRange(objCollection.Select(ArgumentToExpression)); + expressions.AddRange(objCollection); } idValues = expressions.ToArray(); diff --git a/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs b/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs index 4f8d70dc..859f7bc3 100644 --- a/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs +++ b/src/GraphQL.EntityFramework/Where/ExpressionBuilder.cs @@ -109,7 +109,7 @@ static Expression MakePredicateBody(string path, Comparison comparison, string?[ /// /// Create a single predicate for the single set of supplied conditional arguments /// - public static Expression> BuildIdPredicate(string path, string[] values) + public static Expression> BuildIdPredicate(string path, object[] values) { var expressionBody = MakeIdPredicateBody(path, values); var param = PropertyCache.SourceParameter; @@ -117,11 +117,22 @@ public static Expression> BuildIdPredicate(string path, string[] v return Expression.Lambda>(expressionBody, param); } - static Expression MakeIdPredicateBody(string path, string[] values) + static Expression MakeIdPredicateBody(string path, object[] values) { try { - return GetExpression(path, Comparison.In, values); + var property = PropertyCache.GetProperty(path); + + if (property.PropertyType == typeof(string)) + { + return MakeStringListInComparison(values.Cast().ToArray(), property); + } + + var objects = TypeConverter.ConvertIdObjects(values, property.Info); + // Make the object values a constant expression + var constant = Expression.Constant(objects); + // Build and return the expression body + return Expression.Call(constant, property.SafeListContains, property.Left); } catch (Exception exception) { diff --git a/src/GraphQL.EntityFramework/Where/TypeConverter.cs b/src/GraphQL.EntityFramework/Where/TypeConverter.cs index e1b76ac4..6abbd389 100644 --- a/src/GraphQL.EntityFramework/Where/TypeConverter.cs +++ b/src/GraphQL.EntityFramework/Where/TypeConverter.cs @@ -162,8 +162,81 @@ static IList ConvertStringsToListInternal(IEnumerable values, Type type) throw new($"Could not convert strings to {type.FullName}."); } + public static IList ConvertIdObjects(IEnumerable values, MemberInfo member) + { + member.GetNullabilityInfo() + + if (!property.IsNullable() && hasNull) + { + throw new($"Null passed to In expression for non nullable type '{type.FullName}'."); + } + if (type == typeof(Guid)) + { + return values.Cast().ToList(); + } + + if (type == typeof(int)) + { + return values.Cast().ToList(); + } + + if (type == typeof(short)) + { + return values.Cast().ToList(); + } + + if (type == typeof(long)) + { + return values.Cast().ToList(); + } + + if (type == typeof(uint)) + { + return values.Cast().ToList(); + } + + if (type == typeof(ushort)) + { + return values.Cast().ToList(); + } + + if (type == typeof(ulong)) + { + return values.Cast().ToList(); + } + + if (type == typeof(DateTime)) + { + return values.Cast().ToList(); + } + + if (type == typeof(Time)) + { + return values.Cast