Skip to content
191 changes: 170 additions & 21 deletions src/FastExpressionCompiler/FastExpressionCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,15 @@ public static Result TryCollectInfo(ref ClosureInfo closure, Expression expr,
}
case ExpressionType.Conditional:
var condExpr = (ConditionalExpression)expr;
// Try structural branch elimination - skip collecting dead branch info
{
var reducedCond = Tools.TryReduceConditional(condExpr);
if (!ReferenceEquals(reducedCond, condExpr))
{
expr = reducedCond;
continue;
}
}
if ((r = TryCollectInfo(ref closure, condExpr.Test, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK ||
(r = TryCollectInfo(ref closure, condExpr.IfFalse, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK)
return r;
Expand Down Expand Up @@ -1604,6 +1613,16 @@ public static Result TryCollectInfo(ref ClosureInfo closure, Expression expr,

case ExpressionType.Switch:
var switchExpr = ((SwitchExpression)expr);
// Compile-time switch branch elimination (#489): if switch value is interpretable, collect only the matching branch
if (Interpreter.TryFindSwitchBranch(switchExpr, flags, out var switchMatchedBody))
{
if (switchMatchedBody != null)
{
expr = switchMatchedBody;
continue;
}
return r; // no matched body and no default → nothing to collect
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove special llm symbos like unicode arrows, long -, replace with regular ->, and -

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in commit f21739e — replaced all with -> and with - in the comments I added across FastExpressionCompiler.cs, Issue472_*.cs, and Issue489_Switch_BranchElimination.cs.

}
if ((r = TryCollectInfo(ref closure, switchExpr.SwitchValue, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK ||
switchExpr.DefaultBody != null && // todo: @check is the order of collection affects the result?
(r = TryCollectInfo(ref closure, switchExpr.DefaultBody, paramExprs, nestedLambda, ref rootNestedLambdas, flags)) != Result.OK)
Expand Down Expand Up @@ -2247,6 +2266,15 @@ public static bool TryEmit(Expression expr,
expr = testIsTrue ? condExpr.IfTrue : condExpr.IfFalse;
continue; // no recursion, just continue with the left or right side of condition
}
// Try structural branch elimination (e.g., null == Default(X) → always true/false)
{
var reducedCond = Tools.TryReduceConditional(condExpr);
if (!ReferenceEquals(reducedCond, condExpr))
{
expr = reducedCond;
continue;
}
}
return TryEmitConditional(testExpr, condExpr.IfTrue, condExpr.IfFalse, paramExprs, il, ref closure, setup, parent);

case ExpressionType.PostIncrementAssign:
Expand Down Expand Up @@ -5379,25 +5407,8 @@ private struct TestValueAndMultiTestCaseIndex
public int MultiTestValCaseBodyIdxPlusOne; // 0 means not multi-test case, otherwise index+1
}

private static long ConvertValueObjectToLong(object valObj)
{
Debug.Assert(valObj != null);
var type = valObj.GetType();
type = type.IsEnum ? Enum.GetUnderlyingType(type) : type;
return Type.GetTypeCode(type) switch
{
TypeCode.Char => (long)(char)valObj,
TypeCode.SByte => (long)(sbyte)valObj,
TypeCode.Byte => (long)(byte)valObj,
TypeCode.Int16 => (long)(short)valObj,
TypeCode.UInt16 => (long)(ushort)valObj,
TypeCode.Int32 => (long)(int)valObj,
TypeCode.UInt32 => (long)(uint)valObj,
TypeCode.Int64 => (long)valObj,
TypeCode.UInt64 => (long)(ulong)valObj,
_ => 0 // unreachable
};
}
private static long ConvertValueObjectToLong(object valObj) =>
Interpreter.ConvertValueObjectToLong(valObj);

#if LIGHT_EXPRESSION
private static bool TryEmitSwitch(SwitchExpression expr, IParameterProvider paramExprs, ILGenerator il, ref ClosureInfo closure,
Expand All @@ -5413,6 +5424,10 @@ private static bool TryEmitSwitch(SwitchExpression expr, IReadOnlyList<PE> param
var caseCount = cases.Count;
var defaultBody = expr.DefaultBody;

// Compile-time switch branch elimination (#489): if the switch value is interpretable, select the matching branch
if (Interpreter.TryFindSwitchBranch(expr, setup, out var matchedBody))
return matchedBody == null || TryEmit(matchedBody, paramExprs, il, ref closure, setup, parent);

// Optimization for the single case
if (caseCount == 1 & defaultBody != null)
{
Expand Down Expand Up @@ -7213,6 +7228,28 @@ internal static bool TryUnboxToPrimitiveValue(ref PValue value, object boxedValu
_ => UnreachableCase(code, (object)null)
};

/// <summary>Converts an integer/enum/char boxed value to <c>long</c> for uniform comparison.</summary>
[MethodImpl((MethodImplOptions)256)]
internal static long ConvertValueObjectToLong(object valObj)
{
Debug.Assert(valObj != null);
var type = valObj.GetType();
type = type.IsEnum ? Enum.GetUnderlyingType(type) : type;
return Type.GetTypeCode(type) switch
{
TypeCode.Char => (long)(char)valObj,
TypeCode.SByte => (long)(sbyte)valObj,
TypeCode.Byte => (long)(byte)valObj,
TypeCode.Int16 => (long)(short)valObj,
TypeCode.UInt16 => (long)(ushort)valObj,
TypeCode.Int32 => (long)(int)valObj,
TypeCode.UInt32 => (long)(uint)valObj,
TypeCode.Int64 => (long)valObj,
TypeCode.UInt64 => (long)(ulong)valObj,
_ => 0 // unreachable
};
}

internal static bool ComparePrimitiveValues(ref PValue left, ref PValue right, TypeCode code, ExpressionType nodeType)
{
switch (nodeType)
Expand Down Expand Up @@ -7545,7 +7582,7 @@ public static bool TryInterpretBool(out bool result, Expression expr, CompilerFl
{
var exprType = expr.Type;
Debug.Assert(exprType.IsPrimitive, // todo: @feat nullables are not supported yet // || Nullable.GetUnderlyingType(exprType)?.IsPrimitive == true,
"Can only reduce the boolean for the expressions of primitive types but found " + expr.Type);
"Can only reduce the boolean for the expressions of primitive type but found " + expr.Type);
result = false;
if ((flags & CompilerFlags.DisableInterpreter) != 0)
return false;
Expand All @@ -7564,6 +7601,95 @@ public static bool TryInterpretBool(out bool result, Expression expr, CompilerFl
}
}

/// <summary>
/// Tries to determine at compile time which branch a switch expression will take.
/// Works for integer/enum and string switch values with no custom equality method.
/// Returns true when the switch value is deterministic; <paramref name="matchedBody"/> is set to
/// the branch body to emit (null means use default body which may itself be null).
/// </summary>
public static bool TryFindSwitchBranch(SwitchExpression switchExpr, CompilerFlags flags, out Expression matchedBody)
{
matchedBody = null;
if (switchExpr.Comparison != null) return false; // custom equality: can't interpret statically
if ((flags & CompilerFlags.DisableInterpreter) != 0) return false;
var switchValueExpr = switchExpr.SwitchValue;
var switchValueType = switchValueExpr.Type;
var cases = switchExpr.Cases;
try
{
// String switch: only constant switch values supported
if (switchValueType == typeof(string))
{
if (switchValueExpr is not ConstantExpression ce) return false;
var switchStr = ce.Value;
for (var i = 0; i < cases.Count; i++)
{
var testValues = cases[i].TestValues;
for (var j = 0; j < testValues.Count; j++)
{
if (testValues[j] is not ConstantExpression testConst) return false;
if (Equals(switchStr, testConst.Value)) { matchedBody = cases[i].Body; return true; }
}
}
matchedBody = switchExpr.DefaultBody;
return true;
}

// Integer / enum / char switch
var effectiveType = switchValueType.IsEnum ? Enum.GetUnderlyingType(switchValueType) : switchValueType;
var typeCode = Type.GetTypeCode(effectiveType);
if (typeCode < TypeCode.Char || typeCode > TypeCode.UInt64) return false; // non-integral (e.g. float, decimal)

long switchValLong;
if (switchValueExpr is ConstantExpression switchConst && switchConst.Value != null)
switchValLong = ConvertValueObjectToLong(switchConst.Value);
else if (typeCode == TypeCode.Int32)
{
var intVal = 0;
if (!TryInterpretInt(ref intVal, switchValueExpr, switchValueExpr.NodeType)) return false;
switchValLong = intVal;
}
else
{
PValue pv = default;
if (!TryInterpretPrimitiveValue(ref pv, switchValueExpr, typeCode, switchValueExpr.NodeType)) return false;
switchValLong = PValueToLong(ref pv, typeCode);
}

for (var i = 0; i < cases.Count; i++)
{
var testValues = cases[i].TestValues;
for (var j = 0; j < testValues.Count; j++)
{
if (testValues[j] is not ConstantExpression testConst || testConst.Value == null) continue;
if (switchValLong == ConvertValueObjectToLong(testConst.Value)) { matchedBody = cases[i].Body; return true; }
}
}
matchedBody = switchExpr.DefaultBody;
return true;
}
catch
{
return false;
}
}

/// <summary>Converts a <see cref="PValue"/> union to a <c>long</c> for integer/char comparison.</summary>
[MethodImpl((MethodImplOptions)256)]
internal static long PValueToLong(ref PValue value, TypeCode code) => code switch
{
TypeCode.Char => (long)value.CharValue,
TypeCode.SByte => (long)value.SByteValue,
TypeCode.Byte => (long)value.ByteValue,
TypeCode.Int16 => (long)value.Int16Value,
TypeCode.UInt16 => (long)value.UInt16Value,
TypeCode.Int32 => (long)value.Int32Value,
TypeCode.UInt32 => (long)value.UInt32Value,
TypeCode.Int64 => value.Int64Value,
TypeCode.UInt64 => (long)value.UInt64Value,
_ => 0L,
};

// todo: @perf try split to `TryInterpretBinary` overload to streamline the calls for TryEmitConditional and similar
/// <summary>Tries to interpret the expression of the Primitive type of Constant, Convert, Logical, Comparison, Arithmetic.</summary>
internal static bool TryInterpretBool(ref bool resultBool, Expression expr, ExpressionType nodeType)
Expand Down Expand Up @@ -8591,7 +8717,9 @@ public static Expression TryReduceConditional(ConditionalExpression condExpr)
var testExpr = TryReduceConditionalTest(condExpr.Test);
if (testExpr is BinaryExpression bi && (bi.NodeType == ExpressionType.Equal || bi.NodeType == ExpressionType.NotEqual))
{
if (bi.Left is ConstantExpression lc && bi.Right is ConstantExpression rc)
var left = bi.Left;
var right = bi.Right;
if (left is ConstantExpression lc && right is ConstantExpression rc)
{
#if INTERPRETATION_DIAGNOSTICS
Console.WriteLine("//Reduced Conditional in Interpretation: " + condExpr);
Expand All @@ -8601,12 +8729,33 @@ public static Expression TryReduceConditional(ConditionalExpression condExpr)
? (equals ? condExpr.IfTrue : condExpr.IfFalse)
: (equals ? condExpr.IfFalse: condExpr.IfTrue);
}

// Handle compile-time branch elimination for null/default equality:
// e.g. Constant(null) == Default(typeof(X)) or Default(typeof(X)) == Constant(null)
// where X is a reference, interface, or nullable type - both represent null, so they are always equal
var leftIsNull = left is ConstantExpression lnc && lnc.Value == null ||
left is DefaultExpression lde && IsNullDefault(lde.Type);
var rightIsNull = right is ConstantExpression rnc && rnc.Value == null ||
right is DefaultExpression rde && IsNullDefault(rde.Type);
if (leftIsNull && rightIsNull)
{
#if INTERPRETATION_DIAGNOSTICS
Console.WriteLine("//Reduced Conditional (null/default equality) in Interpretation: " + condExpr);
#endif
// both sides represent null, so they are equal
return bi.NodeType == ExpressionType.Equal ? condExpr.IfTrue : condExpr.IfFalse;
}
}

return testExpr is ConstantExpression constExpr && constExpr.Value is bool testBool
? (testBool ? condExpr.IfTrue : condExpr.IfFalse)
: condExpr;
}

// Returns true if the type's default value is null (reference types, interfaces, and Nullable<T>)
[MethodImpl((MethodImplOptions)256)]
internal static bool IsNullDefault(Type type) =>
type.IsClass || type.IsInterface || Nullable.GetUnderlyingType(type) != null;
}

[RequiresUnreferencedCode(Trimming.Message)]
Expand Down
Loading
Loading