Skip to content

Commit dd77756

Browse files
committed
[fix](function) Result type of regexp_extract_all should be Array
1 parent 85b7bee commit dd77756

File tree

9 files changed

+277
-174
lines changed

9 files changed

+277
-174
lines changed

be/src/vec/functions/function_regexp.cpp

Lines changed: 109 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include <re2/stringpiece.h>
2121
#include <stddef.h>
2222

23-
#include <algorithm>
2423
#include <memory>
2524
#include <string>
2625
#include <string_view>
@@ -32,18 +31,22 @@
3231
#include "udf/udf.h"
3332
#include "vec/aggregate_functions/aggregate_function.h"
3433
#include "vec/columns/column.h"
34+
#include "vec/columns/column_array.h"
3535
#include "vec/columns/column_const.h"
3636
#include "vec/columns/column_nullable.h"
3737
#include "vec/columns/column_string.h"
3838
#include "vec/columns/column_vector.h"
3939
#include "vec/columns/columns_number.h"
40+
#include "vec/common/assert_cast.h"
4041
#include "vec/common/string_ref.h"
4142
#include "vec/core/block.h"
4243
#include "vec/core/column_numbers.h"
4344
#include "vec/core/column_with_type_and_name.h"
4445
#include "vec/core/types.h"
4546
#include "vec/data_types/data_type.h"
47+
#include "vec/data_types/data_type_array.h"
4648
#include "vec/data_types/data_type_nullable.h"
49+
#include "vec/data_types/data_type_number.h"
4750
#include "vec/data_types/data_type_string.h"
4851
#include "vec/functions/function.h"
4952
#include "vec/functions/simple_function_factory.h"
@@ -286,100 +289,106 @@ struct RegexpExtractImpl {
286289
struct RegexpExtractAllImpl {
287290
static constexpr auto name = "regexp_extract_all";
288291

289-
size_t get_number_of_arguments() const { return 2; }
290-
291-
static void execute_impl(FunctionContext* context, ColumnPtr argument_columns[],
292-
size_t input_rows_count, ColumnString::Chars& result_data,
293-
ColumnString::Offsets& result_offset, NullMap& null_map) {
292+
template <bool first_const, bool second_const, bool third_const>
293+
static void execute_impl(FunctionContext* context, const ColumnPtr* argument_columns,
294+
size_t input_rows_count, ColumnArray::MutablePtr& result_column) {
294295
const auto* str_col = check_and_get_column<ColumnString>(argument_columns[0].get());
295296
const auto* pattern_col = check_and_get_column<ColumnString>(argument_columns[1].get());
296-
for (int i = 0; i < input_rows_count; ++i) {
297-
if (null_map[i]) {
298-
StringOP::push_null_string(i, result_data, result_offset, null_map);
299-
continue;
297+
const auto* group_idx_col = check_and_get_column<ColumnInt32>(argument_columns[2].get());
298+
299+
auto& result_array_col = assert_cast<ColumnArray&>(*result_column);
300+
if constexpr (second_const && third_const) {
301+
auto* re = reinterpret_cast<re2::RE2*>(
302+
context->get_function_state(FunctionContext::THREAD_LOCAL));
303+
if (re != nullptr) {
304+
auto group_idx = group_idx_col->get_int(0);
305+
306+
if (re->NumberOfCapturingGroups() < group_idx) {
307+
result_array_col.insert_many_defaults(input_rows_count);
308+
return;
309+
}
300310
}
301-
_execute_inner_loop<false>(context, str_col, pattern_col, result_data, result_offset,
302-
null_map, i);
303311
}
304-
}
305312

306-
static void execute_impl_const_args(FunctionContext* context, ColumnPtr argument_columns[],
307-
size_t input_rows_count, ColumnString::Chars& result_data,
308-
ColumnString::Offsets& result_offset, NullMap& null_map) {
309-
const auto* str_col = check_and_get_column<ColumnString>(argument_columns[0].get());
310-
const auto* pattern_col = check_and_get_column<ColumnString>(argument_columns[1].get());
313+
auto& column_nullable = assert_cast<ColumnNullable&>(result_array_col.get_data());
314+
auto& null_map = column_nullable.get_null_map_data();
315+
auto& column_string = assert_cast<ColumnString&>(column_nullable.get_nested_column());
316+
auto& offsets = result_array_col.get_offsets();
317+
311318
for (int i = 0; i < input_rows_count; ++i) {
312-
if (null_map[i]) {
313-
StringOP::push_null_string(i, result_data, result_offset, null_map);
314-
continue;
315-
}
316-
_execute_inner_loop<true>(context, str_col, pattern_col, result_data, result_offset,
317-
null_map, i);
319+
_execute_inner_loop<first_const, second_const, third_const>(
320+
context, str_col, pattern_col, group_idx_col, i, column_string, null_map,
321+
offsets);
318322
}
319323
}
320-
template <bool Const>
324+
325+
template <bool first_const, bool second_const, bool third_const>
321326
static void _execute_inner_loop(FunctionContext* context, const ColumnString* str_col,
322327
const ColumnString* pattern_col,
323-
ColumnString::Chars& result_data,
324-
ColumnString::Offsets& result_offset, NullMap& null_map,
325-
const size_t index_now) {
326-
re2::RE2* re = reinterpret_cast<re2::RE2*>(
328+
const ColumnInt32* group_idx_col, const size_t index_now,
329+
ColumnString& result_string_column, NullMap& null_map,
330+
ColumnArray::Offsets64& result_offsets) {
331+
auto* re = reinterpret_cast<re2::RE2*>(
327332
context->get_function_state(FunctionContext::THREAD_LOCAL));
328333
std::unique_ptr<re2::RE2> scoped_re;
334+
329335
if (re == nullptr) {
330336
std::string error_str;
331-
const auto& pattern = pattern_col->get_data_at(index_check_const(index_now, Const));
337+
const auto& pattern =
338+
pattern_col->get_data_at(index_check_const(index_now, second_const));
332339
bool st = StringFunctions::compile_regex(pattern, &error_str, StringRef(), scoped_re);
333340
if (!st) {
334341
context->add_warning(error_str.c_str());
335-
StringOP::push_null_string(index_now, result_data, result_offset, null_map);
342+
null_map.push_back(1);
343+
result_string_column.insert_default();
344+
result_offsets.emplace_back(result_offsets.back() + 1);
336345
return;
337346
}
338347
re = scoped_re.get();
339348
}
340-
if (re->NumberOfCapturingGroups() == 0) {
341-
StringOP::push_empty_string(index_now, result_data, result_offset);
349+
350+
auto group_idx = group_idx_col->get_element(index_check_const(index_now, third_const));
351+
352+
if (re->NumberOfCapturingGroups() < group_idx || group_idx < 0) {
353+
result_offsets.emplace_back(result_offsets.back());
342354
return;
343355
}
344-
const auto& str = str_col->get_data_at(index_now);
345-
int max_matches = 1 + re->NumberOfCapturingGroups();
356+
357+
const auto& str = str_col->get_data_at(index_check_const(index_now, first_const));
358+
int max_matches = 1 + group_idx;
346359
std::vector<re2::StringPiece> res_matches;
347360
size_t pos = 0;
348361
while (pos < str.size) {
349-
auto str_pos = str.data + pos;
362+
const auto* str_pos = str.data + pos;
350363
auto str_size = str.size - pos;
351364
re2::StringPiece str_sp = re2::StringPiece(str_pos, str_size);
352365
std::vector<re2::StringPiece> matches(max_matches);
353-
bool success =
354-
re->Match(str_sp, 0, str_size, re2::RE2::UNANCHORED, &matches[0], max_matches);
366+
bool success = re->Match(str_sp, 0, str_size, re2::RE2::UNANCHORED, matches.data(),
367+
max_matches);
355368
if (!success) {
356-
StringOP::push_empty_string(index_now, result_data, result_offset);
357369
break;
358370
}
371+
359372
if (matches[0].empty()) {
360-
StringOP::push_empty_string(index_now, result_data, result_offset);
361373
pos += 1;
362374
continue;
363375
}
364-
res_matches.push_back(matches[1]);
376+
377+
res_matches.push_back(matches[group_idx]);
365378
auto offset = std::string(str_pos, str_size).find(std::string(matches[0].as_string()));
366379
pos += offset + matches[0].size();
367380
}
368381

369382
if (res_matches.empty()) {
370-
StringOP::push_empty_string(index_now, result_data, result_offset);
383+
result_offsets.emplace_back(result_offsets.back());
371384
return;
372385
}
373386

374-
std::string res = "[";
375-
for (int j = 0; j < res_matches.size(); ++j) {
376-
res += "'" + res_matches[j].as_string() + "'";
377-
if (j < res_matches.size() - 1) {
378-
res += ",";
379-
}
387+
for (auto res_match : res_matches) {
388+
result_string_column.insert_data(res_match.data(), res_match.size());
389+
null_map.push_back(0);
380390
}
381-
res += "]";
382-
StringOP::push_value_string(std::string_view(res), index_now, result_data, result_offset);
391+
result_offsets.emplace_back(result_offsets.back() + res_matches.size());
383392
}
384393
};
385394

@@ -395,13 +404,21 @@ class FunctionRegexpFunctionality : public IFunction {
395404

396405
size_t get_number_of_arguments() const override {
397406
if constexpr (std::is_same_v<Impl, RegexpExtractAllImpl>) {
398-
return 2;
407+
return 0;
408+
} else {
409+
return 3;
399410
}
400-
return 3;
401411
}
402412

413+
bool is_variadic() const override { return std::is_same_v<Impl, RegexpExtractAllImpl>; }
414+
403415
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
404-
return make_nullable(std::make_shared<DataTypeString>());
416+
if constexpr (std::is_same_v<Impl, RegexpExtractAllImpl>) {
417+
return std::make_shared<DataTypeArray>(
418+
std::make_shared<DataTypeNullable>(std::make_shared<DataTypeString>()));
419+
} else {
420+
return make_nullable(std::make_shared<DataTypeString>());
421+
}
405422
}
406423

407424
Status open(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {
@@ -433,12 +450,6 @@ class FunctionRegexpFunctionality : public IFunction {
433450
uint32_t result, size_t input_rows_count) const override {
434451
size_t argument_size = arguments.size();
435452

436-
auto result_null_map = ColumnUInt8::create(input_rows_count, 0);
437-
auto result_data_column = ColumnString::create();
438-
auto& result_data = result_data_column->get_chars();
439-
auto& result_offset = result_data_column->get_offsets();
440-
result_offset.resize(input_rows_count);
441-
442453
bool col_const[3];
443454
ColumnPtr argument_columns[3];
444455
for (int i = 0; i < argument_size; ++i) {
@@ -449,23 +460,45 @@ class FunctionRegexpFunctionality : public IFunction {
449460
.convert_to_full_column()
450461
: block.get_by_position(arguments[0]).column;
451462
if constexpr (std::is_same_v<Impl, RegexpExtractAllImpl>) {
452-
default_preprocess_parameter_columns(argument_columns, col_const, {1}, block,
453-
arguments);
463+
DCHECK_LE(argument_size, 3);
464+
DCHECK_GE(argument_size, 2);
465+
auto actual_arguments = arguments;
466+
if (argument_size == 2) {
467+
// Default value of is 1.
468+
auto third_arg_column =
469+
ColumnConst::create(ColumnInt32::create(1, 1), input_rows_count);
470+
auto third_arg_idx = block.columns();
471+
block.insert({std::move(third_arg_column), std::make_shared<DataTypeInt32>(),
472+
"group_idx"});
473+
col_const[2] = true;
474+
actual_arguments.emplace_back(third_arg_idx);
475+
}
476+
477+
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2}, block,
478+
actual_arguments);
479+
480+
auto result_column = ColumnArray::create(
481+
ColumnNullable::create(ColumnString::create(), ColumnUInt8::create()),
482+
ColumnArray::ColumnOffsets::create());
483+
484+
std::visit(
485+
[&](auto first_const, auto second_const, auto third_const) {
486+
Impl::template execute_impl<first_const, second_const, third_const>(
487+
context, argument_columns, input_rows_count, result_column);
488+
},
489+
make_bool_variant(col_const[0]), make_bool_variant(col_const[1]),
490+
make_bool_variant(col_const[2]));
491+
block.get_by_position(result).column = std::move(result_column);
492+
return Status::OK();
454493
} else {
455494
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2}, block,
456495
arguments);
457-
}
458496

459-
if constexpr (std::is_same_v<Impl, RegexpExtractAllImpl>) {
460-
if (col_const[1]) {
461-
Impl::execute_impl_const_args(context, argument_columns, input_rows_count,
462-
result_data, result_offset,
463-
result_null_map->get_data());
464-
} else {
465-
Impl::execute_impl(context, argument_columns, input_rows_count, result_data,
466-
result_offset, result_null_map->get_data());
467-
}
468-
} else {
497+
auto result_null_map = ColumnUInt8::create(input_rows_count, 0);
498+
auto result_data_column = ColumnString::create();
499+
auto& result_data = result_data_column->get_chars();
500+
auto& result_offset = result_data_column->get_offsets();
501+
result_offset.resize(input_rows_count);
469502
if (col_const[1] && col_const[2]) {
470503
Impl::execute_impl_const_args(context, argument_columns, input_rows_count,
471504
result_data, result_offset,
@@ -474,11 +507,11 @@ class FunctionRegexpFunctionality : public IFunction {
474507
Impl::execute_impl(context, argument_columns, input_rows_count, result_data,
475508
result_offset, result_null_map->get_data());
476509
}
477-
}
478510

479-
block.get_by_position(result).column =
480-
ColumnNullable::create(std::move(result_data_column), std::move(result_null_map));
481-
return Status::OK();
511+
block.get_by_position(result).column = ColumnNullable::create(
512+
std::move(result_data_column), std::move(result_null_map));
513+
return Status::OK();
514+
}
482515
}
483516

484517
Status close(FunctionContext* context, FunctionContext::FunctionStateScope scope) override {

be/test/vec/function/function_like_test.cpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
#include <gtest/gtest.h>
19+
1820
#include <cstdint>
1921
#include <string>
2022

2123
#include "function_test_util.h"
2224
#include "gtest/gtest_pred_impl.h"
2325
#include "testutil/any_type.h"
26+
#include "vec/core/field.h"
2427
#include "vec/core/types.h"
28+
#include "vec/data_types/data_type_array.h"
2529
#include "vec/data_types/data_type_nullable.h"
2630
#include "vec/data_types/data_type_number.h"
2731
#include "vec/data_types/data_type_string.h"
@@ -196,22 +200,21 @@ TEST(FunctionLikeTest, regexp_extract_or_null) {
196200

197201
TEST(FunctionLikeTest, regexp_extract_all) {
198202
std::string func_name = "regexp_extract_all";
199-
200203
DataSet data_set = {
201204
{{std::string("x=a3&x=18abc&x=2&y=3&x=4&x=17bcd"), std::string("x=([0-9]+)([a-z]+)")},
202-
std::string("['18','17']")},
205+
TestArray {(std::string("18")), (std::string("17"))}},
203206
{{std::string("x=a3&x=18abc&x=2&y=3&x=4"), std::string("^x=([a-z]+)([0-9]+)")},
204-
std::string("['a']")},
207+
TestArray {(string("a"))}},
205208
{{std::string("http://a.m.baidu.com/i41915173660.htm"), std::string("i([0-9]+)")},
206-
std::string("['41915173660']")},
209+
TestArray {(std::string("41915173660"))}},
207210
{{std::string("http://a.m.baidu.com/i41915i73660.htm"), std::string("i([0-9]+)")},
208-
std::string("['41915','73660']")},
209-
210-
{{std::string("hitdecisiondlist"), std::string("(i)(.*?)(e)")}, std::string("['i']")},
211+
TestArray {(std::string("41915")), (std::string("73660"))}},
212+
{{std::string("hitdecisiondlist"), std::string("(i)(.*?)(e)")},
213+
TestArray {(std::string("i"))}},
211214
{{std::string("hitdecisioendlist"), std::string("(i)(.*?)(e)")},
212-
std::string("['i','i']")},
215+
TestArray {(std::string("i")), (std::string("i"))}},
213216
{{std::string("hitdecisioendliset"), std::string("(i)(.*?)(e)")},
214-
std::string("['i','i','i']")},
217+
TestArray {(std::string("i")), (std::string("i")), (std::string("i"))}},
215218
// null
216219
{{std::string("abc"), Null()}, Null()},
217220
{{Null(), std::string("i([0-9]+)")}, Null()}};
@@ -220,8 +223,8 @@ TEST(FunctionLikeTest, regexp_extract_all) {
220223
InputTypeSet const_pattern_input_types = {TypeIndex::String, Consted {TypeIndex::String}};
221224
for (const auto& line : data_set) {
222225
DataSet const_pattern_dataset = {line};
223-
static_cast<void>(check_function<DataTypeString, true>(func_name, const_pattern_input_types,
224-
const_pattern_dataset));
226+
static_cast<void>(check_function<DataTypeArray, true, -1, -1, DataTypeString>(
227+
func_name, const_pattern_input_types, const_pattern_dataset));
225228
}
226229
}
227230

0 commit comments

Comments
 (0)