Update KAddressArbiter implementation to 11.x kernel (#1851)

* Update KAddressArbiter implementation to 11.x kernel

* InsertSortedByPriority is no longer needed
This commit is contained in:
gdkchan 2021-01-01 14:59:26 -03:00 committed by GitHub
parent 0a55657bd2
commit 532b8cad13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
using Ryujinx.HLE.HOS.Kernel.Common; using Ryujinx.HLE.HOS.Kernel.Common;
using Ryujinx.HLE.HOS.Kernel.Process; using Ryujinx.HLE.HOS.Kernel.Process;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
@ -83,7 +84,14 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
KThread currentThread = KernelStatic.GetCurrentThread(); KThread currentThread = KernelStatic.GetCurrentThread();
(KernelResult result, KThread newOwnerThread) = MutexUnlock(currentThread, mutexAddress); (int mutexValue, KThread newOwnerThread) = MutexUnlock(currentThread, mutexAddress);
KernelResult result = KernelResult.Success;
if (!KernelTransfer.KernelToUserInt32(_context, mutexAddress, mutexValue))
{
result = KernelResult.InvalidMemState;
}
if (result != KernelResult.Success && newOwnerThread != null) if (result != KernelResult.Success && newOwnerThread != null)
{ {
@ -96,11 +104,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return result; return result;
} }
public KernelResult WaitProcessWideKeyAtomic( public KernelResult WaitProcessWideKeyAtomic(ulong mutexAddress, ulong condVarAddress, int threadHandle, long timeout)
ulong mutexAddress,
ulong condVarAddress,
int threadHandle,
long timeout)
{ {
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
@ -117,13 +121,15 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return KernelResult.ThreadTerminating; return KernelResult.ThreadTerminating;
} }
(KernelResult result, _) = MutexUnlock(currentThread, mutexAddress); (int mutexValue, _) = MutexUnlock(currentThread, mutexAddress);
if (result != KernelResult.Success) KernelTransfer.KernelToUserInt32(_context, condVarAddress, 1);
if (!KernelTransfer.KernelToUserInt32(_context, mutexAddress, mutexValue))
{ {
_context.CriticalSection.Leave(); _context.CriticalSection.Leave();
return result; return KernelResult.InvalidMemState;
} }
currentThread.MutexAddress = mutexAddress; currentThread.MutexAddress = mutexAddress;
@ -163,7 +169,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return currentThread.ObjSyncResult; return currentThread.ObjSyncResult;
} }
private (KernelResult, KThread) MutexUnlock(KThread currentThread, ulong mutexAddress) private (int, KThread) MutexUnlock(KThread currentThread, ulong mutexAddress)
{ {
KThread newOwnerThread = currentThread.RelinquishMutex(mutexAddress, out int count); KThread newOwnerThread = currentThread.RelinquishMutex(mutexAddress, out int count);
@ -184,46 +190,24 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
newOwnerThread.ReleaseAndResume(); newOwnerThread.ReleaseAndResume();
} }
KernelResult result = KernelResult.Success; return (mutexValue, newOwnerThread);
if (!KernelTransfer.KernelToUserInt32(_context, mutexAddress, mutexValue))
{
result = KernelResult.InvalidMemState;
}
return (result, newOwnerThread);
} }
public void SignalProcessWideKey(ulong address, int count) public void SignalProcessWideKey(ulong address, int count)
{ {
Queue<KThread> signaledThreads = new Queue<KThread>();
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
IOrderedEnumerable<KThread> sortedThreads = _condVarThreads.OrderBy(x => x.DynamicPriority); WakeThreads(_condVarThreads, count, TryAcquireMutex, x => x.CondVarAddress == address);
foreach (KThread thread in sortedThreads.Where(x => x.CondVarAddress == address)) if (!_condVarThreads.Any(x => x.CondVarAddress == address))
{ {
TryAcquireMutex(thread); KernelTransfer.KernelToUserInt32(_context, address, 0);
signaledThreads.Enqueue(thread);
// If the count is <= 0, we should signal all threads waiting.
if (count >= 1 && --count == 0)
{
break;
}
}
while (signaledThreads.TryDequeue(out KThread thread))
{
_condVarThreads.Remove(thread);
} }
_context.CriticalSection.Leave(); _context.CriticalSection.Leave();
} }
private KThread TryAcquireMutex(KThread requester) private static void TryAcquireMutex(KThread requester)
{ {
ulong address = requester.MutexAddress; ulong address = requester.MutexAddress;
@ -235,7 +219,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
requester.SignaledObj = null; requester.SignaledObj = null;
requester.ObjSyncResult = KernelResult.InvalidMemState; requester.ObjSyncResult = KernelResult.InvalidMemState;
return null; return;
} }
ref int mutexRef = ref currentProcess.CpuMemory.GetRef<int>(address); ref int mutexRef = ref currentProcess.CpuMemory.GetRef<int>(address);
@ -267,7 +251,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
requester.ReleaseAndResume(); requester.ReleaseAndResume();
return null; return;
} }
mutexValue &= ~HasListenersMask; mutexValue &= ~HasListenersMask;
@ -287,8 +271,6 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
requester.ReleaseAndResume(); requester.ReleaseAndResume();
} }
return mutexOwner;
} }
public KernelResult WaitForAddressIfEqual(ulong address, int value, long timeout) public KernelResult WaitForAddressIfEqual(ulong address, int value, long timeout)
@ -327,7 +309,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
currentThread.MutexAddress = address; currentThread.MutexAddress = address;
currentThread.WaitingInArbitration = true; currentThread.WaitingInArbitration = true;
InsertSortedByPriority(_arbiterThreads, currentThread); _arbiterThreads.Add(currentThread);
currentThread.Reschedule(ThreadSchedState.Paused); currentThread.Reschedule(ThreadSchedState.Paused);
@ -362,11 +344,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return KernelResult.InvalidState; return KernelResult.InvalidState;
} }
public KernelResult WaitForAddressIfLessThan( public KernelResult WaitForAddressIfLessThan(ulong address, int value, bool shouldDecrement, long timeout)
ulong address,
int value,
bool shouldDecrement,
long timeout)
{ {
KThread currentThread = KernelStatic.GetCurrentThread(); KThread currentThread = KernelStatic.GetCurrentThread();
@ -409,7 +387,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
currentThread.MutexAddress = address; currentThread.MutexAddress = address;
currentThread.WaitingInArbitration = true; currentThread.WaitingInArbitration = true;
InsertSortedByPriority(_arbiterThreads, currentThread); _arbiterThreads.Add(currentThread);
currentThread.Reschedule(ThreadSchedState.Paused); currentThread.Reschedule(ThreadSchedState.Paused);
@ -444,30 +422,6 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return KernelResult.InvalidState; return KernelResult.InvalidState;
} }
private void InsertSortedByPriority(List<KThread> threads, KThread thread)
{
int nextIndex = -1;
for (int index = 0; index < threads.Count; index++)
{
if (threads[index].DynamicPriority > thread.DynamicPriority)
{
nextIndex = index;
break;
}
}
if (nextIndex != -1)
{
threads.Insert(nextIndex, thread);
}
else
{
threads.Add(thread);
}
}
public KernelResult Signal(ulong address, int count) public KernelResult Signal(ulong address, int count)
{ {
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
@ -520,7 +474,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
{ {
_context.CriticalSection.Enter(); _context.CriticalSection.Enter();
int offset; int addend;
// The value is decremented if the number of threads waiting is less // The value is decremented if the number of threads waiting is less
// or equal to the Count of threads to be signaled, or Count is zero // or equal to the Count of threads to be signaled, or Count is zero
@ -529,7 +483,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
foreach (KThread thread in _arbiterThreads.Where(x => x.MutexAddress == address)) foreach (KThread thread in _arbiterThreads.Where(x => x.MutexAddress == address))
{ {
if (++waitingCount > count) if (++waitingCount >= count)
{ {
break; break;
} }
@ -537,11 +491,22 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
if (waitingCount > 0) if (waitingCount > 0)
{ {
offset = waitingCount <= count || count <= 0 ? -1 : 0; if (count <= 0)
{
addend = -2;
}
else if (waitingCount < count)
{
addend = -1;
}
else
{
addend = 0;
}
} }
else else
{ {
offset = 1; addend = 1;
} }
KProcess currentProcess = KernelStatic.GetCurrentProcess(); KProcess currentProcess = KernelStatic.GetCurrentProcess();
@ -568,7 +533,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
return KernelResult.InvalidState; return KernelResult.InvalidState;
} }
} }
while (Interlocked.CompareExchange(ref valueRef, currentValue + offset, currentValue) != currentValue); while (Interlocked.CompareExchange(ref valueRef, currentValue + addend, currentValue) != currentValue);
WakeArbiterThreads(address, count); WakeArbiterThreads(address, count);
@ -579,20 +544,7 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
private void WakeArbiterThreads(ulong address, int count) private void WakeArbiterThreads(ulong address, int count)
{ {
Queue<KThread> signaledThreads = new Queue<KThread>(); static void RemoveArbiterThread(KThread thread)
foreach (KThread thread in _arbiterThreads.Where(x => x.MutexAddress == address))
{
signaledThreads.Enqueue(thread);
// If the count is <= 0, we should signal all threads waiting.
if (count >= 1 && --count == 0)
{
break;
}
}
while (signaledThreads.TryDequeue(out KThread thread))
{ {
thread.SignaledObj = null; thread.SignaledObj = null;
thread.ObjSyncResult = KernelResult.Success; thread.ObjSyncResult = KernelResult.Success;
@ -600,8 +552,24 @@ namespace Ryujinx.HLE.HOS.Kernel.Threading
thread.ReleaseAndResume(); thread.ReleaseAndResume();
thread.WaitingInArbitration = false; thread.WaitingInArbitration = false;
}
_arbiterThreads.Remove(thread); WakeThreads(_arbiterThreads, count, RemoveArbiterThread, x => x.MutexAddress == address);
}
private static void WakeThreads(
List<KThread> threads,
int count,
Action<KThread> removeCallback,
Func<KThread, bool> predicate)
{
var candidates = threads.Where(predicate).OrderBy(x => x.DynamicPriority);
var toSignal = (count > 0 ? candidates.Take(count) : candidates).ToArray();
foreach (KThread thread in toSignal)
{
removeCallback(thread);
threads.Remove(thread);
} }
} }
} }