Skip to content

[fix](nereids) fix fold constant return wrong scale of datetime type #50142

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.GlobalVariable;
import org.apache.doris.thrift.TUniqueId;
Expand Down Expand Up @@ -539,6 +540,7 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre

@Override
public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
CaseWhen originCaseWhen = caseWhen;
caseWhen = rewriteChildren(caseWhen, context);
Expression newDefault = null;
boolean foundNewDefault = false;
Expand All @@ -564,29 +566,35 @@ public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext cont
defaultResult = newDefault;
}
if (whenClauses.isEmpty()) {
return defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult;
return TypeCoercionUtils.ensureSameResultType(
originCaseWhen, defaultResult == null ? new NullLiteral(caseWhen.getDataType()) : defaultResult,
context
);
}
if (defaultResult == null) {
if (caseWhen.getDataType().isNullType()) {
// if caseWhen's type is NULL_TYPE, means all possible return values are nulls
// it's safe to return null literal here
return new NullLiteral();
} else {
return new CaseWhen(whenClauses);
return TypeCoercionUtils.ensureSameResultType(originCaseWhen, new CaseWhen(whenClauses), context);
}
}
return new CaseWhen(whenClauses, defaultResult);
return TypeCoercionUtils.ensureSameResultType(
originCaseWhen, new CaseWhen(whenClauses, defaultResult), context
);
}

@Override
public Expression visitIf(If ifExpr, ExpressionRewriteContext context) {
If originIf = ifExpr;
ifExpr = rewriteChildren(ifExpr, context);
if (ifExpr.child(0) instanceof NullLiteral || ifExpr.child(0).equals(BooleanLiteral.FALSE)) {
return ifExpr.child(2);
return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(2), context);
} else if (ifExpr.child(0).equals(BooleanLiteral.TRUE)) {
return ifExpr.child(1);
return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr.child(1), context);
}
return ifExpr;
return TypeCoercionUtils.ensureSameResultType(originIf, ifExpr, context);
}

@Override
Expand Down Expand Up @@ -684,17 +692,20 @@ public Expression visitVersion(Version version, ExpressionRewriteContext context

@Override
public Expression visitNvl(Nvl nvl, ExpressionRewriteContext context) {
Nvl originNvl = nvl;
nvl = rewriteChildren(nvl, context);

for (Expression expr : nvl.children()) {
if (expr.isLiteral()) {
if (!expr.isNullLiteral()) {
return expr;
return TypeCoercionUtils.ensureSameResultType(originNvl, expr, context);
}
} else {
return nvl;
return TypeCoercionUtils.ensureSameResultType(originNvl, nvl, context);
}
}
// all nulls
return nvl.child(0);
return TypeCoercionUtils.ensureSameResultType(originNvl, nvl.child(0), context);
}

private <E extends Expression> E rewriteChildren(E expr, ExpressionRewriteContext context) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionMatchingContext;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
Expand All @@ -26,6 +27,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;

Expand All @@ -38,11 +40,11 @@ public class SimplifyConditionalFunction implements ExpressionPatternRuleFactory
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(Coalesce.class).then(SimplifyConditionalFunction::rewriteCoalesce)
matchesType(Coalesce.class).thenApply(SimplifyConditionalFunction::rewriteCoalesce)
.toRule(ExpressionRuleType.SIMPLIFY_CONDITIONAL_FUNCTION),
matchesType(Nvl.class).then(SimplifyConditionalFunction::rewriteNvl)
matchesType(Nvl.class).thenApply(SimplifyConditionalFunction::rewriteNvl)
.toRule(ExpressionRuleType.SIMPLIFY_CONDITIONAL_FUNCTION),
matchesType(NullIf.class).then(SimplifyConditionalFunction::rewriteNullIf)
matchesType(NullIf.class).thenApply(SimplifyConditionalFunction::rewriteNullIf)
.toRule(ExpressionRuleType.SIMPLIFY_CONDITIONAL_FUNCTION)
);
}
Expand All @@ -53,46 +55,52 @@ public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
* coalesce(null,null) => null
* coalesce(expr1) => expr1
* */
private static Expression rewriteCoalesce(Coalesce expression) {
if (1 == expression.arity()) {
return expression.child(0);
private static Expression rewriteCoalesce(ExpressionMatchingContext<Coalesce> ctx) {
Coalesce coalesce = ctx.expr;
if (1 == coalesce.arity()) {
return TypeCoercionUtils.ensureSameResultType(coalesce, coalesce.child(0), ctx.rewriteContext);
}
if (!(expression.child(0) instanceof NullLiteral) && expression.child(0).nullable()) {
return expression;
if (!(coalesce.child(0) instanceof NullLiteral) && coalesce.child(0).nullable()) {
return TypeCoercionUtils.ensureSameResultType(coalesce, coalesce, ctx.rewriteContext);
}
ImmutableList.Builder<Expression> childBuilder = ImmutableList.builder();
for (int i = 0; i < expression.arity(); i++) {
Expression child = expression.children().get(i);
for (int i = 0; i < coalesce.arity(); i++) {
Expression child = coalesce.children().get(i);
if (child instanceof NullLiteral) {
continue;
}
if (!child.nullable()) {
return child;
return TypeCoercionUtils.ensureSameResultType(coalesce, child, ctx.rewriteContext);
} else {
for (int j = i; j < expression.arity(); j++) {
childBuilder.add(expression.children().get(j));
for (int j = i; j < coalesce.arity(); j++) {
childBuilder.add(coalesce.children().get(j));
}
break;
}
}
List<Expression> newChildren = childBuilder.build();
if (newChildren.isEmpty()) {
return new NullLiteral(expression.getDataType());
return TypeCoercionUtils.ensureSameResultType(
coalesce, new NullLiteral(coalesce.getDataType()), ctx.rewriteContext
);
} else {
return expression.withChildren(newChildren);
return TypeCoercionUtils.ensureSameResultType(
coalesce, coalesce.withChildren(newChildren), ctx.rewriteContext
);
}
}

/*
* nvl(null,R) => R
* nvl(L(not-nullable ),R) => L
* */
private static Expression rewriteNvl(Nvl nvl) {
private static Expression rewriteNvl(ExpressionMatchingContext<Nvl> ctx) {
Nvl nvl = ctx.expr;
if (nvl.child(0) instanceof NullLiteral) {
return nvl.child(1);
return TypeCoercionUtils.ensureSameResultType(nvl, nvl.child(1), ctx.rewriteContext);
}
if (!nvl.child(0).nullable()) {
return nvl.child(0);
return TypeCoercionUtils.ensureSameResultType(nvl, nvl.child(0), ctx.rewriteContext);
}
return nvl;
}
Expand All @@ -101,9 +109,12 @@ private static Expression rewriteNvl(Nvl nvl) {
* nullif(null, R) => Null
* nullif(L, null) => Null
*/
private static Expression rewriteNullIf(NullIf nullIf) {
private static Expression rewriteNullIf(ExpressionMatchingContext<NullIf> ctx) {
NullIf nullIf = ctx.expr;
if (nullIf.child(0) instanceof NullLiteral || nullIf.child(1) instanceof NullLiteral) {
return new Nullable(nullIf.child(0));
return TypeCoercionUtils.ensureSameResultType(
nullIf, new Nullable(nullIf.child(0)), ctx.rewriteContext
);
} else {
return nullIf;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
Expand Down Expand Up @@ -153,6 +155,26 @@ public class TypeCoercionUtils {

private static final Logger LOG = LogManager.getLogger(TypeCoercionUtils.class);

/**
* ensure the result's data type equals to the originExpr's dataType,
* ATTN: this method usually used in fold constant rule
*/
public static Expression ensureSameResultType(
Expression originExpr, Expression result, ExpressionRewriteContext context) {
DataType originDataType = originExpr.getDataType();
DataType newDataType = result.getDataType();
if (originDataType.equals(newDataType)) {
return result;
}
// backend can direct use all string like type without cast
if (originDataType.isStringLikeType() && newDataType.isStringLikeType()) {
return result;
}
return FoldConstantRuleOnFE.PATTERN_MATCH_INSTANCE.visitCast(
new Cast(result, originDataType), context
);
}

/**
* Return Optional.empty() if we cannot do implicit cast.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.rules.expression.rules.SimplifyConditionalFunction;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
Expand Down Expand Up @@ -122,6 +124,15 @@ void testCaseWhenFold() {
assertRewriteAfterTypeCoercion("case when null = 2 then 1 else 4 end", "4");
assertRewriteAfterTypeCoercion("case when null = 2 then 1 end", "null");
assertRewriteAfterTypeCoercion("case when TA = TB then 1 when TC is null then 2 end", "CASE WHEN (TA = TB) THEN 1 WHEN TC IS NULL THEN 2 END");

// make sure the case when return datetime(6)
Expression analyzedCaseWhen = ExpressionAnalyzer.analyzeFunction(null, null, PARSER.parseExpression(
"case when true then cast('2025-04-17' as datetime(0)) else cast('2025-04-18 01:02:03.123456' as datetime(6)) end"));
Assertions.assertEquals(DateTimeV2Type.of(6), analyzedCaseWhen.getDataType());
Assertions.assertEquals(DateTimeV2Type.of(6), ((CaseWhen) analyzedCaseWhen).getWhenClauses().get(0).getResult().getDataType());
Assertions.assertEquals(DateTimeV2Type.of(6), ((CaseWhen) analyzedCaseWhen).getDefaultValue().get().getDataType());
Expression foldCaseWhen = executor.rewrite(analyzedCaseWhen, context);
Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldCaseWhen);
}

@Test
Expand Down Expand Up @@ -1175,14 +1186,21 @@ void testFoldNvl() {
executor = new ExpressionRuleExecutor(ImmutableList.of(
ExpressionAnalyzer.FUNCTION_ANALYZER_RULE,
bottomUp(
FoldConstantRule.INSTANCE
FoldConstantRule.INSTANCE,
SimplifyConditionalFunction.INSTANCE
)
));

assertRewriteExpression("nvl(NULL, 1)", "1");
assertRewriteExpression("nvl(NULL, NULL)", "NULL");
assertRewriteAfterTypeCoercion("nvl(IA, NULL)", "ifnull(IA, NULL)");
assertRewriteAfterTypeCoercion("nvl(IA, 1)", "ifnull(IA, 1)");

Expression foldNvl = executor.rewrite(
PARSER.parseExpression("nvl(cast('2025-04-17' as datetime(0)), cast('2025-04-18 01:02:03.123456' as datetime(6)))"),
context
);
Assertions.assertEquals(new DateTimeV2Literal(DateTimeV2Type.of(6), "2025-04-17"), foldNvl);
}

private void assertRewriteExpression(String actualExpression, String expectedExpression) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce;
import org.apache.doris.nereids.trees.expressions.functions.scalar.NullIf;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

Expand Down Expand Up @@ -71,6 +73,18 @@ public void testCoalesce() {

// coalesce(null, nullable_slot, literal) -> coalesce(nullable_slot, slot, literal)
assertRewrite(new Coalesce(slot, nonNullableSlot), new Coalesce(slot, nonNullableSlot));

SlotReference datetimeSlot = new SlotReference("dt", DateTimeV2Type.of(0), false);
// coalesce(null_datetime(0), non-nullable_slot_datetime(6))
assertRewrite(
new Coalesce(new NullLiteral(DateTimeV2Type.of(6)), datetimeSlot),
new Cast(datetimeSlot, DateTimeV2Type.of(6))
);
// coalesce(non-nullable_slot_datetime(6), null_datetime(0))
assertRewrite(
new Coalesce(datetimeSlot, new NullLiteral(DateTimeV2Type.of(6))),
new Cast(datetimeSlot, DateTimeV2Type.of(6))
);
}

@Test
Expand All @@ -92,6 +106,18 @@ public void testNvl() {

// nvl(null, null) -> null
assertRewrite(new Nvl(NullLiteral.INSTANCE, NullLiteral.INSTANCE), new NullLiteral(BooleanType.INSTANCE));

SlotReference datetimeSlot = new SlotReference("dt", DateTimeV2Type.of(0), false);
// nvl(null_datetime(0), non-nullable_slot_datetime(6))
assertRewrite(
new Nvl(new NullLiteral(DateTimeV2Type.of(6)), datetimeSlot),
new Cast(datetimeSlot, DateTimeV2Type.of(6))
);
// nvl(non-nullable_slot_datetime(6), null_datetime(0))
assertRewrite(
new Nvl(datetimeSlot, new NullLiteral(DateTimeV2Type.of(6))),
new Cast(datetimeSlot, DateTimeV2Type.of(6))
);
}

@Test
Expand All @@ -108,6 +134,15 @@ public void testNullIf() {

// nullif(non-nullable_slot, null) -> non-nullable_slot
assertRewrite(new NullIf(nonNullableSlot, NullLiteral.INSTANCE), new Nullable(nonNullableSlot));

// nullif(null_datetime(0), null_datetime(6)) -> null_datetime(6)
assertRewrite(
new NullIf(
new NullLiteral(DateTimeV2Type.of(0)),
new NullLiteral(DateTimeV2Type.of(6))
),
new Cast(new Nullable(new NullLiteral(DateTimeV2Type.of(0))), DateTimeV2Type.of(6))
);
}

}
Loading