Skip to content
43 changes: 42 additions & 1 deletion src/FastExpressionCompiler/FastExpressionCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,15 @@ public static Result TryCollectInfo(ref ClosureInfo closure, Expression expr,
}
case ExpressionType.Conditional:
var condExpr = (ConditionalExpression)expr;
// Try structural branch elimination - skip collecting dead branch info
{
var reducedCond = Tools.TryReduceConditional(condExpr);
if (!ReferenceEquals(reducedCond, condExpr))
{
expr = reducedCond;
continue;
}
}
if ((r = TryCollectInfo(ref closure, condExpr.Test, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK ||
(r = TryCollectInfo(ref closure, condExpr.IfFalse, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK)
return r;
Expand Down Expand Up @@ -2247,6 +2256,15 @@ public static bool TryEmit(Expression expr,
expr = testIsTrue ? condExpr.IfTrue : condExpr.IfFalse;
continue; // no recursion, just continue with the left or right side of condition
}
// Try structural branch elimination (e.g., null == Default(X) → always true/false)
{
var reducedCond = Tools.TryReduceConditional(condExpr);
if (!ReferenceEquals(reducedCond, condExpr))
{
expr = reducedCond;
continue;
}
}
return TryEmitConditional(testExpr, condExpr.IfTrue, condExpr.IfFalse, paramExprs, il, ref closure, setup, parent);

case ExpressionType.PostIncrementAssign:
Expand Down Expand Up @@ -8591,7 +8609,9 @@ public static Expression TryReduceConditional(ConditionalExpression condExpr)
var testExpr = TryReduceConditionalTest(condExpr.Test);
if (testExpr is BinaryExpression bi && (bi.NodeType == ExpressionType.Equal || bi.NodeType == ExpressionType.NotEqual))
{
if (bi.Left is ConstantExpression lc && bi.Right is ConstantExpression rc)
var left = bi.Left;
var right = bi.Right;
if (left is ConstantExpression lc && right is ConstantExpression rc)
{
#if INTERPRETATION_DIAGNOSTICS
Console.WriteLine("//Reduced Conditional in Interpretation: " + condExpr);
Expand All @@ -8601,12 +8621,33 @@ public static Expression TryReduceConditional(ConditionalExpression condExpr)
? (equals ? condExpr.IfTrue : condExpr.IfFalse)
: (equals ? condExpr.IfFalse: condExpr.IfTrue);
}

// Handle compile-time branch elimination for null/default equality:
// e.g. Constant(null) == Default(typeof(X)) or Default(typeof(X)) == Constant(null)
// where X is a reference, interface, or nullable type - both represent null, so they are always equal
var leftIsNull = left is ConstantExpression lnc && lnc.Value == null ||
left is DefaultExpression lde && IsNullDefault(lde.Type);
var rightIsNull = right is ConstantExpression rnc && rnc.Value == null ||
right is DefaultExpression rde && IsNullDefault(rde.Type);
if (leftIsNull && rightIsNull)
{
#if INTERPRETATION_DIAGNOSTICS
Console.WriteLine("//Reduced Conditional (null/default equality) in Interpretation: " + condExpr);
#endif
// both sides represent null, so they are equal
return bi.NodeType == ExpressionType.Equal ? condExpr.IfTrue : condExpr.IfFalse;
}
}

return testExpr is ConstantExpression constExpr && constExpr.Value is bool testBool
? (testBool ? condExpr.IfTrue : condExpr.IfFalse)
: condExpr;
}

// Returns true if the type's default value is null (reference types, interfaces, and Nullable<T>)
[MethodImpl((MethodImplOptions)256)]
internal static bool IsNullDefault(Type type) =>
type.IsClass || type.IsInterface || Nullable.GetUnderlyingType(type) != null;
}

[RequiresUnreferencedCode(Trimming.Message)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ public void Run(TestRun t)
{
Logical_expression_started_with_not_Without_Interpreter_due_param_use(t);
Logical_expression_started_with_not(t);
Condition_with_null_constant_equal_to_default_of_class_type_is_eliminated(t);
Condition_with_default_class_type_equal_to_null_constant_is_eliminated(t);
Condition_with_two_defaults_of_class_type_is_eliminated(t);
Condition_with_not_equal_null_and_default_of_class_type_is_eliminated(t);
Condition_with_nullable_default_equal_to_null_is_eliminated(t);
Condition_with_null_constant_equal_to_non_null_constant_is_not_eliminated(t);
}

public void Logical_expression_started_with_not(TestContext t)
Expand Down Expand Up @@ -58,4 +64,127 @@ public void Logical_expression_started_with_not_Without_Interpreter_due_param_us
t.IsFalse(ff(true));
t.IsTrue(ff(false));
}

// Branch elimination: Constant(null) == Default(typeof(X)) where X is a class → always true
// Models the AutoMapper pattern: after inlining a null argument into a null-check lambda
public void Condition_with_null_constant_equal_to_default_of_class_type_is_eliminated(TestContext t)
{
// Condition(Equal(Constant(null), Default(typeof(string))), Constant("trueBranch"), Constant("falseBranch"))
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this comment if it repeats the actual expression below? I would rather see the reasult of PrintCSharp output for the more compact/evident C# reprrsentation.

// Since null == default(string) is always true, this should reduce to "trueBranch"
var expr = Lambda<Func<string>>(
Condition(
Equal(Constant(null, typeof(string)), Default(typeof(string))),
Constant("trueBranch"),
Constant("falseBranch")));

expr.PrintCSharp();

var fs = expr.CompileSys();
fs.PrintIL();
t.AreEqual("trueBranch", fs());

var ff = expr.CompileFast(false);
ff.PrintIL();
t.AreEqual("trueBranch", ff());
}

// Branch elimination: Default(typeof(X)) == Constant(null) where X is a class → always true (symmetric)
public void Condition_with_default_class_type_equal_to_null_constant_is_eliminated(TestContext t)
{
var expr = Lambda<Func<string>>(
Condition(
Equal(Default(typeof(string)), Constant(null, typeof(string))),
Constant("trueBranch"),
Constant("falseBranch")));

expr.PrintCSharp();

var fs = expr.CompileSys();
fs.PrintIL();
t.AreEqual("trueBranch", fs());

var ff = expr.CompileFast(false);
ff.PrintIL();
t.AreEqual("trueBranch", ff());
}

// Branch elimination: Default(typeof(X)) == Default(typeof(X)) where X is a class → always true
public void Condition_with_two_defaults_of_class_type_is_eliminated(TestContext t)
{
var expr = Lambda<Func<string>>(
Condition(
Equal(Default(typeof(string)), Default(typeof(string))),
Constant("trueBranch"),
Constant("falseBranch")));

expr.PrintCSharp();

var fs = expr.CompileSys();
fs.PrintIL();
t.AreEqual("trueBranch", fs());

var ff = expr.CompileFast(false);
ff.PrintIL();
t.AreEqual("trueBranch", ff());
}

// Branch elimination: Constant(null) != Default(typeof(X)) where X is a class → always false
public void Condition_with_not_equal_null_and_default_of_class_type_is_eliminated(TestContext t)
{
var expr = Lambda<Func<string>>(
Condition(
NotEqual(Constant(null, typeof(string)), Default(typeof(string))),
Constant("trueBranch"),
Constant("falseBranch")));

expr.PrintCSharp();

var fs = expr.CompileSys();
fs.PrintIL();
t.AreEqual("falseBranch", fs());

var ff = expr.CompileFast(false);
ff.PrintIL();
t.AreEqual("falseBranch", ff());
}

// Branch elimination: Constant(null) == Default(typeof(int?)) → always true (null == default(int?) is null == null)
public void Condition_with_nullable_default_equal_to_null_is_eliminated(TestContext t)
{
var expr = Lambda<Func<int?>>(
Condition(
Equal(Constant(null, typeof(int?)), Default(typeof(int?))),
Constant(42, typeof(int?)),
Constant(0, typeof(int?))));

expr.PrintCSharp();

var fs = expr.CompileSys();
fs.PrintIL();
t.AreEqual(42, fs());

var ff = expr.CompileFast(false);
ff.PrintIL();
t.AreEqual(42, ff());
}

// Sanity check: Constant(null) == Constant("hello") should NOT be eliminated (false, not a null-null case)
public void Condition_with_null_constant_equal_to_non_null_constant_is_not_eliminated(TestContext t)
{
var expr = Lambda<Func<string>>(
Condition(
Equal(Constant(null, typeof(string)), Constant("hello", typeof(string))),
Constant("trueBranch"),
Constant("falseBranch")));

expr.PrintCSharp();

var fs = expr.CompileSys();
fs.PrintIL();
t.AreEqual("falseBranch", fs());

var ff = expr.CompileFast(false);
ff.PrintIL();
t.AreEqual("falseBranch", ff());
}
}
Loading