diff --git a/src/Microsoft.Data.Analysis/DataFrame.IO.cs b/src/Microsoft.Data.Analysis/DataFrame.IO.cs index 7fb56833da..3714fb18c2 100644 --- a/src/Microsoft.Data.Analysis/DataFrame.IO.cs +++ b/src/Microsoft.Data.Analysis/DataFrame.IO.cs @@ -20,51 +20,51 @@ public partial class DataFrame { private const int DefaultStreamReaderBufferSize = 1024; - private static Type GuessKind(int col, List read) + private static Type DefaultGuessTypeFunction(IEnumerable columnValues) { - Type res = typeof(string); + Type result = typeof(string); int nbline = 0; - foreach (var line in read) - { - if (col >= line.Length) - throw new FormatException(string.Format(Strings.LessColumnsThatExpected, nbline + 1)); - - string val = line[col]; - if (string.Equals(val, "null", StringComparison.OrdinalIgnoreCase)) + foreach (var columnValue in columnValues) + { + if (string.Equals(columnValue, "null", StringComparison.OrdinalIgnoreCase)) { continue; } - if (!string.IsNullOrEmpty(val)) + if (!string.IsNullOrEmpty(columnValue)) { - bool boolParse = bool.TryParse(val, out bool boolResult); - if (boolParse) + if (bool.TryParse(columnValue, out bool boolResult)) { - res = DetermineType(nbline == 0, typeof(bool), res); - ++nbline; - continue; + result = DetermineType(nbline == 0, typeof(bool), result); } - bool floatParse = float.TryParse(val, out float floatResult); - if (floatParse) + else if (float.TryParse(columnValue, out float floatResult)) { - res = DetermineType(nbline == 0, typeof(float), res); - ++nbline; - continue; + result = DetermineType(nbline == 0, typeof(float), result); } - bool dateParse = DateTime.TryParse(val, out DateTime dateResult); - if (dateParse) + else if (DateTime.TryParse(columnValue, out DateTime dateResult)) { - res = DetermineType(nbline == 0, typeof(DateTime), res); - ++nbline; - continue; + result = DetermineType(nbline == 0, typeof(DateTime), result); + } + else + { + result = DetermineType(nbline == 0, typeof(string), result); } - res = DetermineType(nbline == 0, typeof(string), res); - ++nbline; + nbline++; } } - return res; + + return result; + } + + private static Type GuessKind(int col, List<(long LineNumber, string[] Line)> read, Func, Type> guessTypeFunction) + { + IEnumerable lines = read.Select(line => col < line.Line.Length ? line.Line[col] : throw new FormatException(string.Format(Strings.LessColumnsThatExpected, line.LineNumber + 1))); + + return guessTypeFunction != null + ? guessTypeFunction.Invoke(lines) + : DefaultGuessTypeFunction(lines); } private static Type DetermineType(bool first, Type suggested, Type previous) @@ -357,7 +357,7 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe string[] columnNames = null, Type[] dataTypes = null, long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false, bool renameDuplicatedColumns = false, - CultureInfo cultureInfo = null) + CultureInfo cultureInfo = null, Func, Type> guessTypeFunction = null) { if (cultureInfo == null) { @@ -376,7 +376,7 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe TextFieldParser parser = new TextFieldParser(textReader); parser.SetDelimiters(separator.ToString()); - var linesForGuessType = new List(); + var linesForGuessType = new List<(long LineNumber, string[] Line)>(); long rowline = 0; int numberOfColumns = dataTypes?.Length ?? 0; @@ -420,7 +420,7 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe } else { - linesForGuessType.Add(fields); + linesForGuessType.Add((rowline, fields)); numberOfColumns = Math.Max(numberOfColumns, fields.Length); } } @@ -441,7 +441,7 @@ private static DataFrame ReadCsvLinesIntoDataFrame(WrappedStreamReaderOrStringRe // Guesses types or looks up dataTypes and adds columns. for (int i = 0; i < numberOfColumns; ++i) { - Type kind = dataTypes == null ? GuessKind(i, linesForGuessType) : dataTypes[i]; + Type kind = dataTypes == null ? GuessKind(i, linesForGuessType, guessTypeFunction) : dataTypes[i]; columns.Add(CreateColumn(kind, columnNames, i)); } } @@ -534,16 +534,17 @@ public TextReader GetTextReader() /// add one column with the row index /// If set to true, columns with repeated names are auto-renamed. /// culture info for formatting values + /// function used to guess the type of a column based on its values /// public static DataFrame LoadCsvFromString(string csvString, char separator = ',', bool header = true, string[] columnNames = null, Type[] dataTypes = null, long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false, bool renameDuplicatedColumns = false, - CultureInfo cultureInfo = null) + CultureInfo cultureInfo = null, Func, Type> guessTypeFunction = null) { WrappedStreamReaderOrStringReader wrappedStreamReaderOrStringReader = new WrappedStreamReaderOrStringReader(csvString); - return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns, cultureInfo); + return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns, cultureInfo, guessTypeFunction); } /// @@ -560,12 +561,14 @@ public static DataFrame LoadCsvFromString(string csvString, /// The character encoding. Defaults to UTF8 if not specified /// If set to true, columns with repeated names are auto-renamed. /// culture info for formatting values + /// function used to guess the type of a column based on its values /// public static DataFrame LoadCsv(Stream csvStream, char separator = ',', bool header = true, string[] columnNames = null, Type[] dataTypes = null, long numberOfRowsToRead = -1, int guessRows = 10, bool addIndexColumn = false, - Encoding encoding = null, bool renameDuplicatedColumns = false, CultureInfo cultureInfo = null) + Encoding encoding = null, bool renameDuplicatedColumns = false, CultureInfo cultureInfo = null, + Func, Type> guessTypeFunction = null) { if (!csvStream.CanSeek) { @@ -578,7 +581,7 @@ public static DataFrame LoadCsv(Stream csvStream, } WrappedStreamReaderOrStringReader wrappedStreamReaderOrStringReader = new WrappedStreamReaderOrStringReader(csvStream, encoding ?? Encoding.UTF8); - return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns, cultureInfo); + return ReadCsvLinesIntoDataFrame(wrappedStreamReaderOrStringReader, separator, header, columnNames, dataTypes, numberOfRowsToRead, guessRows, addIndexColumn, renameDuplicatedColumns, cultureInfo, guessTypeFunction); } /// diff --git a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs index 5c1eb7aaec..46db3c177d 100644 --- a/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs +++ b/test/Microsoft.Data.Analysis.Tests/DataFrame.IOTests.cs @@ -1512,5 +1512,161 @@ public void TestSaveCsvWithTextQualifiers(string data, char separator, Type[] da DataFrame df2 = DataFrame.LoadCsv(csvStream, dataTypes: dataTypes, separator: separator); helper.VerifyLoadCsv(df2); } + + [Fact] + public void TestLoadCsvWithGuessTypes() + { + string csvString = """ + Name,Age,Description,UpdatedOn,Weight,LargeNumber,NullColumn + Paul,34,"Paul lives in Vermont, VA.",2024-01-23T05:06:15.028,195.48,123,null + Victor,29,"Victor: Funny guy",2023-11-04T17:27:59.167,175.3,2147483648,null + Clara,,,,,,null + Ellie,null,null,null,null,null,null + Maria,31,,2024-03-31T07:20:47.250,126,456,null + """; + + var defaultResultVerifyingHelper = new LoadCsvVerifyingHelper( + 7, + 5, + new string[] { "Name", "Age", "Description", "UpdatedOn", "Weight", "LargeNumber", "NullColumn" }, + new Type[] { typeof(string), typeof(float), typeof(string), typeof(DateTime), typeof(float), typeof(float), typeof(string) }, + new object[][] + { + new object[] { "Paul", 34f, "Paul lives in Vermont, VA.", DateTime.Parse("2024-01-23T05:06:15.028"), 195.48f, 123f, null }, + new object[] { "Victor", 29f, "Victor: Funny guy", DateTime.Parse("2023-11-04T17:27:59.167"), 175.3f, 2147483648f, null }, + new object[] { "Clara", null, "", null, null, null, null }, + new object[] { "Ellie", null, null, null, null, null, null }, + new object[] { "Maria", 31f, "", DateTime.Parse("2024-03-31T07:20:47.250"), 126f, 456f, null } + } + ); + + var customResultVerifyingHelper = new LoadCsvVerifyingHelper( + 7, + 5, + new string[] { "Name", "Age", "Description", "UpdatedOn", "Weight", "LargeNumber", "NullColumn" }, + new Type[] { typeof(string), typeof(int), typeof(string), typeof(DateTime), typeof(double), typeof(long), typeof(string) }, + new object[][] + { + new object[] { "Paul", 34, "Paul lives in Vermont, VA.", DateTime.Parse("2024-01-23T05:06:15.028"), 195.48, 123L, null }, + new object[] { "Victor", 29, "Victor: Funny guy", DateTime.Parse("2023-11-04T17:27:59.167"), 175.3, 2147483648L, null }, + new object[] { "Clara", null, "", null, null, null, null }, + new object[] { "Ellie", null, null, null, null, null, null }, + new object[] { "Maria", 31, "", DateTime.Parse("2024-03-31T07:20:47.250"), 126.0, 456L, null } + } + ); + + Type CustomGuessTypeFunction(IEnumerable columnValues) + { + List types = [ + typeof(bool), + typeof(int), + typeof(long), + typeof(double), + typeof(DateTime) + ]; + + bool allNullData = true; + + HashSet possibleTypes = new HashSet(types); + + foreach (var item in columnValues) + { + if (string.IsNullOrEmpty(item) || string.Equals(item, "null", StringComparison.OrdinalIgnoreCase)) + { + continue; + } + else + { + allNullData = false; + } + + List typesToRemove = new List(possibleTypes.Count); + + foreach (var type in possibleTypes) + { + if (type == typeof(bool)) + { + if (!bool.TryParse(item, out bool result)) + { + typesToRemove.Add(type); + } + } + else if (type == typeof(int)) + { + if (!int.TryParse(item, out int result)) + { + typesToRemove.Add(type); + } + } + else if (type == typeof(long)) + { + if (!long.TryParse(item, out long result)) + { + typesToRemove.Add(type); + } + } + else if (type == typeof(double)) + { + if (!double.TryParse(item, out double result)) + { + typesToRemove.Add(type); + } + } + else if (type == typeof(DateTime)) + { + if (!DateTime.TryParse(item, out DateTime result)) + { + typesToRemove.Add(type); + } + } + } + + foreach (var type in typesToRemove) + { + possibleTypes.Remove(type); + } + } + + if (allNullData) + { + // Could not determine type since all data was null + return typeof(string); + } + + foreach (var type in types) + { + if (possibleTypes.Contains(type)) + { + return type; + } + } + + return typeof(string); + } + + DataFrame defaultDf = DataFrame.LoadCsvFromString(csvString); + + defaultResultVerifyingHelper.VerifyLoadCsv(defaultDf); + + DataFrame customDf = DataFrame.LoadCsvFromString(csvString, guessTypeFunction: CustomGuessTypeFunction); + + customResultVerifyingHelper.VerifyLoadCsv(customDf); + } + + [Fact] + public void TestLoadCsvWithMismatchedNumberOfColumnsInDataRows() + { + // Victor line is missing the "LargeNumber" row + string csvString = """ + Name,Age,Description,UpdatedOn,Weight,LargeNumber + Paul,34,"Paul lives in Vermont, VA.",2024-01-23T05:06:15.028,195.48,123 + Victor,29,"Victor: Funny guy",2023-11-04T17:27:59.167,175.3 + Clara,,,,, + Ellie,null,null,null,null,null + Maria,31,,2024-03-31T07:20:47.250,126,456 + """; + + Assert.Throws(() => DataFrame.LoadCsvFromString(csvString)); + } } }