diff --git a/src/ARMeilleure/Instructions/InstEmitException.cs b/src/ARMeilleure/Instructions/InstEmitException.cs index d30fb2fbd..a91716c64 100644 --- a/src/ARMeilleure/Instructions/InstEmitException.cs +++ b/src/ARMeilleure/Instructions/InstEmitException.cs @@ -19,7 +19,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(op.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(op.Address)); } public static void Svc(ArmEmitterContext context) @@ -49,7 +49,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(op.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(op.Address)); } } } diff --git a/src/ARMeilleure/Instructions/InstEmitException32.cs b/src/ARMeilleure/Instructions/InstEmitException32.cs index 57af1522b..e5bad56ef 100644 --- a/src/ARMeilleure/Instructions/InstEmitException32.cs +++ b/src/ARMeilleure/Instructions/InstEmitException32.cs @@ -33,7 +33,7 @@ namespace ARMeilleure.Instructions context.LoadFromContext(); - context.Return(Const(context.CurrOp.Address)); + InstEmitFlowHelper.EmitReturn(context, Const(context.CurrOp.Address)); } } } diff --git a/src/ARMeilleure/Instructions/InstEmitFlow.cs b/src/ARMeilleure/Instructions/InstEmitFlow.cs index a986bf66f..cb214d3d5 100644 --- a/src/ARMeilleure/Instructions/InstEmitFlow.cs +++ b/src/ARMeilleure/Instructions/InstEmitFlow.cs @@ -66,7 +66,7 @@ namespace ARMeilleure.Instructions { OpCodeBReg op = (OpCodeBReg)context.CurrOp; - context.Return(GetIntOrZR(context, op.Rn)); + EmitReturn(context, GetIntOrZR(context, op.Rn)); } public static void Tbnz(ArmEmitterContext context) => EmitTb(context, onNotZero: true); diff --git a/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs b/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs index a602ea49e..55947c3b0 100644 --- a/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs +++ b/src/ARMeilleure/Instructions/InstEmitFlowHelper.cs @@ -12,6 +12,10 @@ namespace ARMeilleure.Instructions { static class InstEmitFlowHelper { + // How many calls we can have in our call stack before we give up and return to the dispatcher. + // This prevents stack overflows caused by deep recursive calls. + private const int MaxCallDepth = 200; + public static void EmitCondBranch(ArmEmitterContext context, Operand target, Condition cond) { if (cond != Condition.Al) @@ -163,12 +167,7 @@ namespace ARMeilleure.Instructions { if (isReturn) { - if (target.Type == OperandType.I32) - { - target = context.ZeroExtend32(OperandType.I64, target); - } - - context.Return(target); + EmitReturn(context, target); } else { @@ -176,6 +175,19 @@ namespace ARMeilleure.Instructions } } + public static void EmitReturn(ArmEmitterContext context, Operand target) + { + Operand nativeContext = context.LoadArgument(OperandType.I64, 0); + DecreaseCallDepth(context, nativeContext); + + if (target.Type == OperandType.I32) + { + target = context.ZeroExtend32(OperandType.I64, target); + } + + context.Return(target); + } + private static void EmitTableBranch(ArmEmitterContext context, Operand guestAddress, bool isJump) { context.StoreToContext(); @@ -238,6 +250,8 @@ namespace ARMeilleure.Instructions if (isJump) { + DecreaseCallDepth(context, nativeContext); + context.Tailcall(hostAddress, nativeContext); } else @@ -259,8 +273,42 @@ namespace ARMeilleure.Instructions Operand lblContinue = context.GetLabel(nextAddr.Value); context.BranchIf(lblContinue, returnAddress, nextAddr, Comparison.Equal, BasicBlockFrequency.Cold); + DecreaseCallDepth(context, nativeContext); + context.Return(returnAddress); } } + + public static void EmitCallDepthCheckAndIncrement(EmitterContext context, Operand guestAddress) + { + if (!Optimizations.EnableDeepCallRecursionProtection) + { + return; + } + + Operand nativeContext = context.LoadArgument(OperandType.I64, 0); + Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); + Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr); + Operand lblDoCall = Label(); + + context.BranchIf(lblDoCall, currentCallDepth, Const(MaxCallDepth), Comparison.LessUI); + context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1))); + context.Return(guestAddress); + + context.MarkLabel(lblDoCall); + context.Store(callDepthAddr, context.Add(currentCallDepth, Const(1))); + } + + private static void DecreaseCallDepth(EmitterContext context, Operand nativeContext) + { + if (!Optimizations.EnableDeepCallRecursionProtection) + { + return; + } + + Operand callDepthAddr = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); + Operand currentCallDepth = context.Load(OperandType.I32, callDepthAddr); + context.Store(callDepthAddr, context.Subtract(currentCallDepth, Const(1))); + } } } diff --git a/src/ARMeilleure/Optimizations.cs b/src/ARMeilleure/Optimizations.cs index 231274e41..eaa09df27 100644 --- a/src/ARMeilleure/Optimizations.cs +++ b/src/ARMeilleure/Optimizations.cs @@ -15,6 +15,7 @@ namespace ARMeilleure public static bool AllowLcqInFunctionTable { get; set; } = true; public static bool UseUnmanagedDispatchLoop { get; set; } = true; + public static bool EnableDeepCallRecursionProtection { get; set; } = true; public static bool UseAdvSimdIfAvailable { get; set; } = true; public static bool UseArm64AesIfAvailable { get; set; } = true; diff --git a/src/ARMeilleure/State/ExecutionContext.cs b/src/ARMeilleure/State/ExecutionContext.cs index ce10a591c..7c89dd4be 100644 --- a/src/ARMeilleure/State/ExecutionContext.cs +++ b/src/ARMeilleure/State/ExecutionContext.cs @@ -128,6 +128,11 @@ namespace ARMeilleure.State public bool GetFPstateFlag(FPState flag) => _nativeContext.GetFPStateFlag(flag); public void SetFPstateFlag(FPState flag, bool value) => _nativeContext.SetFPStateFlag(flag, value); + internal void ResetCallDepth() + { + _nativeContext.ResetCallDepth(); + } + internal void CheckInterrupt() { if (_interrupted) diff --git a/src/ARMeilleure/State/NativeContext.cs b/src/ARMeilleure/State/NativeContext.cs index f84cb5080..2cf0de530 100644 --- a/src/ARMeilleure/State/NativeContext.cs +++ b/src/ARMeilleure/State/NativeContext.cs @@ -21,6 +21,7 @@ namespace ARMeilleure.State public ulong ExclusiveValueLow; public ulong ExclusiveValueHigh; public int Running; + public int CallDepth; public long Tpidr2El0; } @@ -186,6 +187,8 @@ namespace ARMeilleure.State public bool GetRunning() => GetStorage().Running != 0; public void SetRunning(bool value) => GetStorage().Running = value ? 1 : 0; + public void ResetCallDepth() => GetStorage().CallDepth = 0; + public unsafe static int GetRegisterOffset(Register reg) { if (reg.Type == RegisterType.Integer) @@ -266,6 +269,11 @@ namespace ARMeilleure.State return StorageOffset(ref _dummyStorage, ref _dummyStorage.Running); } + public static int GetCallDepthOffset() + { + return StorageOffset(ref _dummyStorage, ref _dummyStorage.CallDepth); + } + private static int StorageOffset(ref NativeCtxStorage storage, ref T target) { return (int)Unsafe.ByteOffset(ref Unsafe.As(ref storage), ref target); diff --git a/src/ARMeilleure/Translation/PTC/Ptc.cs b/src/ARMeilleure/Translation/PTC/Ptc.cs index c3993f66e..91c065aed 100644 --- a/src/ARMeilleure/Translation/PTC/Ptc.cs +++ b/src/ARMeilleure/Translation/PTC/Ptc.cs @@ -30,7 +30,7 @@ namespace ARMeilleure.Translation.PTC private const string OuterHeaderMagicString = "PTCohd\0\0"; private const string InnerHeaderMagicString = "PTCihd\0\0"; - private const uint InternalVersion = 7008; //! To be incremented manually for each change to the ARMeilleure project. + private const uint InternalVersion = 7010; //! To be incremented manually for each change to the ARMeilleure project. private const string ActualDir = "0"; private const string BackupDir = "1"; diff --git a/src/ARMeilleure/Translation/Translator.cs b/src/ARMeilleure/Translation/Translator.cs index 4eb4dd69a..094630ed3 100644 --- a/src/ARMeilleure/Translation/Translator.cs +++ b/src/ARMeilleure/Translation/Translator.cs @@ -168,6 +168,7 @@ namespace ARMeilleure.Translation Statistics.StartTimer(); + context.ResetCallDepth(); ulong nextAddr = func.Execute(Stubs.ContextWrapper, context); Statistics.StopTimer(address); @@ -239,6 +240,7 @@ namespace ARMeilleure.Translation Logger.StartPass(PassName.Translation); + InstEmitFlowHelper.EmitCallDepthCheckAndIncrement(context, Const(address)); EmitSynchronization(context); if (blocks[0].Address != address) diff --git a/src/ARMeilleure/Translation/TranslatorStubs.cs b/src/ARMeilleure/Translation/TranslatorStubs.cs index 594ccd575..f3754d7f0 100644 --- a/src/ARMeilleure/Translation/TranslatorStubs.cs +++ b/src/ARMeilleure/Translation/TranslatorStubs.cs @@ -262,10 +262,18 @@ namespace ARMeilleure.Translation Operand runningAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetRunningOffset())); Operand dispatchAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetDispatchAddressOffset())); + Operand callDepthAddress = context.Add(nativeContext, Const((ulong)NativeContext.GetCallDepthOffset())); EmitSyncFpContext(context, nativeContext, true); context.MarkLabel(beginLbl); + + if (Optimizations.EnableDeepCallRecursionProtection) + { + // Reset the call depth counter, since this is our first guest function call. + context.Store(callDepthAddress, Const(0)); + } + context.Store(dispatchAddress, guestAddress); context.Copy(guestAddress, context.Call(Const((ulong)DispatchStub), OperandType.I64, nativeContext)); context.BranchIfFalse(endLbl, guestAddress);