Implement the synchronization primitives like the Horizon kernel does (#97)

* Started to work in improving the sync primitives

* Some fixes

* Check that the mutex address matches before waking a waiting thread

* Add MutexOwner field to keep track of the thread owning the mutex, update wait list when priority changes, other tweaks

* Add new priority information to the log

* SvcSetThreadPriority should update just the WantedPriority
This commit is contained in:
gdkchan 2018-04-21 16:07:16 -03:00 committed by GitHub
parent 267ea14cb5
commit 90279d96ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 577 additions and 375 deletions

View File

@ -13,17 +13,23 @@ namespace Ryujinx.Core.OsHle.Handles
{ {
public KThread Thread { get; private set; } public KThread Thread { get; private set; }
public AutoResetEvent WaitEvent { get; private set; } public ManualResetEvent SyncWaitEvent { get; private set; }
public AutoResetEvent SchedWaitEvent { get; private set; }
public bool Active { get; set; } public bool Active { get; set; }
public int SyncTimeout { get; set; }
public SchedulerThread(KThread Thread) public SchedulerThread(KThread Thread)
{ {
this.Thread = Thread; this.Thread = Thread;
WaitEvent = new AutoResetEvent(false); SyncWaitEvent = new ManualResetEvent(true);
SchedWaitEvent = new AutoResetEvent(false);
Active = true; Active = true;
SyncTimeout = 0;
} }
public void Dispose() public void Dispose()
@ -35,7 +41,8 @@ namespace Ryujinx.Core.OsHle.Handles
{ {
if (Disposing) if (Disposing)
{ {
WaitEvent.Dispose(); SyncWaitEvent.Dispose();
SchedWaitEvent.Dispose();
} }
} }
} }
@ -71,9 +78,9 @@ namespace Ryujinx.Core.OsHle.Handles
{ {
SchedThread = Threads[Index]; SchedThread = Threads[Index];
if (HighestPriority > SchedThread.Thread.Priority) if (HighestPriority > SchedThread.Thread.ActualPriority)
{ {
HighestPriority = SchedThread.Thread.Priority; HighestPriority = SchedThread.Thread.ActualPriority;
HighestPrioIndex = Index; HighestPrioIndex = Index;
} }
@ -194,45 +201,66 @@ namespace Ryujinx.Core.OsHle.Handles
throw new InvalidOperationException(); throw new InvalidOperationException();
} }
lock (SchedLock)
{
bool OldState = SchedThread.Active;
SchedThread.Active = Active; SchedThread.Active = Active;
if (!OldState && Active) UpdateSyncWaitEvent(SchedThread);
WaitIfNeeded(SchedThread);
}
public bool EnterWait(KThread Thread, int Timeout = -1)
{ {
if (ActiveProcessors.Add(Thread.ProcessorId)) if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread))
{ {
RunThread(SchedThread); throw new InvalidOperationException();
}
SchedThread.SyncTimeout = Timeout;
UpdateSyncWaitEvent(SchedThread);
return WaitIfNeeded(SchedThread);
}
public void WakeUp(KThread Thread)
{
if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread))
{
throw new InvalidOperationException();
}
SchedThread.SyncTimeout = 0;
UpdateSyncWaitEvent(SchedThread);
WaitIfNeeded(SchedThread);
}
private void UpdateSyncWaitEvent(SchedulerThread SchedThread)
{
if (SchedThread.Active && SchedThread.SyncTimeout == 0)
{
SchedThread.SyncWaitEvent.Set();
} }
else else
{ {
WaitingToRun[Thread.ProcessorId].Push(SchedThread); SchedThread.SyncWaitEvent.Reset();
}
}
PrintDbgThreadInfo(Thread, "entering wait state..."); private bool WaitIfNeeded(SchedulerThread SchedThread)
}
}
else if (OldState && !Active)
{ {
if (Thread.Thread.IsCurrentThread()) KThread Thread = SchedThread.Thread;
if (!IsActive(SchedThread) && Thread.Thread.IsCurrentThread())
{ {
Suspend(Thread.ProcessorId); Suspend(Thread.ProcessorId);
PrintDbgThreadInfo(Thread, "entering inactive wait state..."); return Resume(Thread);
} }
else else
{ {
WaitingToRun[Thread.ProcessorId].Remove(SchedThread); return false;
}
}
}
if (!Active && Thread.Thread.IsCurrentThread())
{
SchedThread.WaitEvent.WaitOne();
PrintDbgThreadInfo(Thread, "resuming execution...");
} }
} }
@ -261,66 +289,78 @@ namespace Ryujinx.Core.OsHle.Handles
lock (SchedLock) lock (SchedLock)
{ {
SchedulerThread SchedThread = WaitingToRun[Thread.ProcessorId].Pop(Thread.Priority); SchedulerThread SchedThread = WaitingToRun[Thread.ProcessorId].Pop(Thread.ActualPriority);
if (SchedThread == null) if (IsActive(Thread) && SchedThread == null)
{ {
PrintDbgThreadInfo(Thread, "resumed because theres nothing better to run."); PrintDbgThreadInfo(Thread, "resumed because theres nothing better to run.");
return; return;
} }
if (SchedThread != null)
{
RunThread(SchedThread); RunThread(SchedThread);
} }
}
Resume(Thread); Resume(Thread);
} }
public void Resume(KThread Thread) public bool Resume(KThread Thread)
{ {
if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread)) if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread))
{ {
throw new InvalidOperationException(); throw new InvalidOperationException();
} }
TryResumingExecution(SchedThread); return TryResumingExecution(SchedThread);
} }
private void TryResumingExecution(SchedulerThread SchedThread) private bool TryResumingExecution(SchedulerThread SchedThread)
{ {
KThread Thread = SchedThread.Thread; KThread Thread = SchedThread.Thread;
if (SchedThread.Active) if (!SchedThread.Active || SchedThread.SyncTimeout != 0)
{ {
PrintDbgThreadInfo(Thread, "entering inactive wait state...");
}
bool Result = false;
if (SchedThread.SyncTimeout != 0)
{
Result = SchedThread.SyncWaitEvent.WaitOne(SchedThread.SyncTimeout);
SchedThread.SyncTimeout = 0;
}
lock (SchedLock) lock (SchedLock)
{ {
if (ActiveProcessors.Add(Thread.ProcessorId)) if (ActiveProcessors.Add(Thread.ProcessorId))
{ {
PrintDbgThreadInfo(Thread, "resuming execution..."); PrintDbgThreadInfo(Thread, "resuming execution...");
return; return Result;
} }
WaitingToRun[Thread.ProcessorId].Push(SchedThread); WaitingToRun[Thread.ProcessorId].Push(SchedThread);
PrintDbgThreadInfo(Thread, "entering wait state..."); PrintDbgThreadInfo(Thread, "entering wait state...");
} }
}
else
{
PrintDbgThreadInfo(Thread, "entering inactive wait state...");
}
SchedThread.WaitEvent.WaitOne(); SchedThread.SchedWaitEvent.WaitOne();
PrintDbgThreadInfo(Thread, "resuming execution..."); PrintDbgThreadInfo(Thread, "resuming execution...");
return Result;
} }
private void RunThread(SchedulerThread SchedThread) private void RunThread(SchedulerThread SchedThread)
{ {
if (!SchedThread.Thread.Thread.Execute()) if (!SchedThread.Thread.Thread.Execute())
{ {
SchedThread.WaitEvent.Set(); SchedThread.SchedWaitEvent.Set();
} }
else else
{ {
@ -328,12 +368,28 @@ namespace Ryujinx.Core.OsHle.Handles
} }
} }
private bool IsActive(KThread Thread)
{
if (!AllThreads.TryGetValue(Thread, out SchedulerThread SchedThread))
{
throw new InvalidOperationException();
}
return IsActive(SchedThread);
}
private bool IsActive(SchedulerThread SchedThread)
{
return SchedThread.Active && SchedThread.SyncTimeout == 0;
}
private void PrintDbgThreadInfo(KThread Thread, string Message) private void PrintDbgThreadInfo(KThread Thread, string Message)
{ {
Logging.Debug(LogClass.KernelScheduler, "(" + Logging.Debug(LogClass.KernelScheduler, "(" +
"ThreadId: " + Thread.ThreadId + ", " + "ThreadId: " + Thread.ThreadId + ", " +
"ProcessorId: " + Thread.ProcessorId + ", " + "ProcessorId: " + Thread.ProcessorId + ", " +
"Priority: " + Thread.Priority + ") " + Message); "ActualPriority: " + Thread.ActualPriority + ", " +
"WantedPriority: " + Thread.WantedPriority + ") " + Message);
} }
public void Dispose() public void Dispose()

View File

@ -1,4 +1,5 @@
using ChocolArm64; using ChocolArm64;
using System;
namespace Ryujinx.Core.OsHle.Handles namespace Ryujinx.Core.OsHle.Handles
{ {
@ -6,10 +7,20 @@ namespace Ryujinx.Core.OsHle.Handles
{ {
public AThread Thread { get; private set; } public AThread Thread { get; private set; }
public KThread MutexOwner { get; set; }
public KThread NextMutexThread { get; set; }
public KThread NextCondVarThread { get; set; }
public long MutexAddress { get; set; }
public long CondVarAddress { get; set; }
public int ActualPriority { get; private set; }
public int WantedPriority { get; private set; }
public int ProcessorId { get; private set; } public int ProcessorId { get; private set; }
public int Priority { get; set; } public int WaitHandle { get; set; }
public int Handle { get; set; }
public int ThreadId => Thread.ThreadId; public int ThreadId => Thread.ThreadId;
@ -17,7 +28,86 @@ namespace Ryujinx.Core.OsHle.Handles
{ {
this.Thread = Thread; this.Thread = Thread;
this.ProcessorId = ProcessorId; this.ProcessorId = ProcessorId;
this.Priority = Priority;
ActualPriority = WantedPriority = Priority;
}
public void SetPriority(int Priority)
{
WantedPriority = Priority;
UpdatePriority();
}
public void UpdatePriority()
{
int OldPriority = ActualPriority;
int CurrPriority = WantedPriority;
if (NextMutexThread != null && CurrPriority > NextMutexThread.WantedPriority)
{
CurrPriority = NextMutexThread.WantedPriority;
}
if (CurrPriority != OldPriority)
{
ActualPriority = CurrPriority;
UpdateWaitList();
MutexOwner?.UpdatePriority();
}
}
private void UpdateWaitList()
{
KThread OwnerThread = MutexOwner;
if (OwnerThread != null)
{
//The MutexOwner field should only be non null when the thread is
//waiting for the lock, and the lock belongs to another thread.
if (OwnerThread == this)
{
throw new InvalidOperationException();
}
lock (OwnerThread)
{
//Remove itself from the list.
KThread CurrThread = OwnerThread;
while (CurrThread.NextMutexThread != null)
{
if (CurrThread.NextMutexThread == this)
{
CurrThread.NextMutexThread = NextMutexThread;
break;
}
CurrThread = CurrThread.NextMutexThread;
}
//Re-add taking new priority into account.
CurrThread = OwnerThread;
while (CurrThread.NextMutexThread != null)
{
if (CurrThread.NextMutexThread.ActualPriority < ActualPriority)
{
break;
}
CurrThread = CurrThread.NextMutexThread;
}
NextMutexThread = CurrThread.NextMutexThread;
CurrThread.NextMutexThread = this;
}
}
} }
} }
} }

View File

@ -1,148 +0,0 @@
using Ryujinx.Core.OsHle.Handles;
using System.Collections.Generic;
using System.Threading;
namespace Ryujinx.Core.OsHle.Kernel
{
class ConditionVariable
{
private Process Process;
private long CondVarAddress;
private bool OwnsCondVarValue;
private List<(KThread Thread, AutoResetEvent WaitEvent)> WaitingThreads;
public ConditionVariable(Process Process, long CondVarAddress)
{
this.Process = Process;
this.CondVarAddress = CondVarAddress;
WaitingThreads = new List<(KThread, AutoResetEvent)>();
}
public bool WaitForSignal(KThread Thread, ulong Timeout)
{
bool Result = true;
int Count = Process.Memory.ReadInt32(CondVarAddress);
if (Count <= 0)
{
using (AutoResetEvent WaitEvent = new AutoResetEvent(false))
{
lock (WaitingThreads)
{
WaitingThreads.Add((Thread, WaitEvent));
}
if (Timeout == ulong.MaxValue)
{
Result = WaitEvent.WaitOne();
}
else
{
Result = WaitEvent.WaitOne(NsTimeConverter.GetTimeMs(Timeout));
lock (WaitingThreads)
{
WaitingThreads.Remove((Thread, WaitEvent));
}
}
}
}
AcquireCondVarValue();
Count = Process.Memory.ReadInt32(CondVarAddress);
if (Result && Count > 0)
{
Process.Memory.WriteInt32(CondVarAddress, Count - 1);
}
ReleaseCondVarValue();
return Result;
}
public void SetSignal(KThread Thread, int Count)
{
lock (WaitingThreads)
{
if (Count < 0)
{
foreach ((_, AutoResetEvent WaitEvent) in WaitingThreads)
{
IncrementCondVarValue();
WaitEvent.Set();
}
WaitingThreads.Clear();
}
else
{
while (WaitingThreads.Count > 0 && Count-- > 0)
{
int HighestPriority = WaitingThreads[0].Thread.Priority;
int HighestPrioIndex = 0;
for (int Index = 1; Index < WaitingThreads.Count; Index++)
{
if (HighestPriority > WaitingThreads[Index].Thread.Priority)
{
HighestPriority = WaitingThreads[Index].Thread.Priority;
HighestPrioIndex = Index;
}
}
IncrementCondVarValue();
WaitingThreads[HighestPrioIndex].WaitEvent.Set();
WaitingThreads.RemoveAt(HighestPrioIndex);
}
}
}
Process.Scheduler.Yield(Thread);
}
private void IncrementCondVarValue()
{
AcquireCondVarValue();
int Count = Process.Memory.ReadInt32(CondVarAddress);
Process.Memory.WriteInt32(CondVarAddress, Count + 1);
ReleaseCondVarValue();
}
private void AcquireCondVarValue()
{
if (!OwnsCondVarValue)
{
while (!Process.Memory.AcquireAddress(CondVarAddress))
{
Thread.Yield();
}
OwnsCondVarValue = true;
}
}
private void ReleaseCondVarValue()
{
if (OwnsCondVarValue)
{
OwnsCondVarValue = false;
Process.Memory.ReleaseAddress(CondVarAddress);
}
}
}
}

View File

@ -2,6 +2,8 @@ namespace Ryujinx.Core.OsHle.Kernel
{ {
static class KernelErr static class KernelErr
{ {
public const int InvalidAlignment = 102;
public const int InvalidAddress = 106;
public const int InvalidMemRange = 110; public const int InvalidMemRange = 110;
public const int InvalidHandle = 114; public const int InvalidHandle = 114;
public const int Timeout = 117; public const int Timeout = 117;

View File

@ -1,95 +0,0 @@
using Ryujinx.Core.OsHle.Handles;
using System.Collections.Generic;
using System.Threading;
namespace Ryujinx.Core.OsHle.Kernel
{
class MutualExclusion
{
private const int MutexHasListenersMask = 0x40000000;
private Process Process;
private long MutexAddress;
private int OwnerThreadHandle;
private List<(KThread Thread, AutoResetEvent WaitEvent)> WaitingThreads;
public MutualExclusion(Process Process, long MutexAddress)
{
this.Process = Process;
this.MutexAddress = MutexAddress;
WaitingThreads = new List<(KThread, AutoResetEvent)>();
}
public void WaitForLock(KThread RequestingThread)
{
WaitForLock(RequestingThread, OwnerThreadHandle);
}
public void WaitForLock(KThread RequestingThread, int OwnerThreadHandle)
{
if (OwnerThreadHandle == RequestingThread.Handle ||
OwnerThreadHandle == 0)
{
return;
}
using (AutoResetEvent WaitEvent = new AutoResetEvent(false))
{
lock (WaitingThreads)
{
WaitingThreads.Add((RequestingThread, WaitEvent));
}
Process.Scheduler.Suspend(RequestingThread.ProcessorId);
WaitEvent.WaitOne();
Process.Scheduler.Resume(RequestingThread);
}
}
public void Unlock()
{
lock (WaitingThreads)
{
int HasListeners = WaitingThreads.Count > 1 ? MutexHasListenersMask : 0;
if (WaitingThreads.Count > 0)
{
int HighestPriority = WaitingThreads[0].Thread.Priority;
int HighestPrioIndex = 0;
for (int Index = 1; Index < WaitingThreads.Count; Index++)
{
if (HighestPriority > WaitingThreads[Index].Thread.Priority)
{
HighestPriority = WaitingThreads[Index].Thread.Priority;
HighestPrioIndex = Index;
}
}
int Handle = WaitingThreads[HighestPrioIndex].Thread.Handle;
WaitingThreads[HighestPrioIndex].WaitEvent.Set();
WaitingThreads.RemoveAt(HighestPrioIndex);
Process.Memory.WriteInt32(MutexAddress, HasListeners | Handle);
OwnerThreadHandle = Handle;
}
else
{
Process.Memory.WriteInt32(MutexAddress, 0);
OwnerThreadHandle = 0;
}
}
}
}
}

View File

@ -3,7 +3,6 @@ using ChocolArm64.Memory;
using ChocolArm64.State; using ChocolArm64.State;
using Ryujinx.Core.OsHle.Handles; using Ryujinx.Core.OsHle.Handles;
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
namespace Ryujinx.Core.OsHle.Kernel namespace Ryujinx.Core.OsHle.Kernel
@ -18,8 +17,7 @@ namespace Ryujinx.Core.OsHle.Kernel
private Process Process; private Process Process;
private AMemory Memory; private AMemory Memory;
private ConcurrentDictionary<long, MutualExclusion> Mutexes; private object CondVarLock;
private ConcurrentDictionary<long, ConditionVariable> CondVars;
private HashSet<(HSharedMem, long)> MappedSharedMems; private HashSet<(HSharedMem, long)> MappedSharedMems;
@ -71,8 +69,7 @@ namespace Ryujinx.Core.OsHle.Kernel
this.Process = Process; this.Process = Process;
this.Memory = Process.Memory; this.Memory = Process.Memory;
Mutexes = new ConcurrentDictionary<long, MutualExclusion>(); CondVarLock = new object();
CondVars = new ConcurrentDictionary<long, ConditionVariable>();
MappedSharedMems = new HashSet<(HSharedMem, long)>(); MappedSharedMems = new HashSet<(HSharedMem, long)>();
} }

View File

@ -91,7 +91,7 @@ namespace Ryujinx.Core.OsHle.Kernel
if (CurrThread != null) if (CurrThread != null)
{ {
ThreadState.X0 = 0; ThreadState.X0 = 0;
ThreadState.X1 = (ulong)CurrThread.Priority; ThreadState.X1 = (ulong)CurrThread.ActualPriority;
} }
else else
{ {
@ -110,7 +110,7 @@ namespace Ryujinx.Core.OsHle.Kernel
if (CurrThread != null) if (CurrThread != null)
{ {
CurrThread.Priority = Priority; CurrThread.SetPriority(Priority);
ThreadState.X0 = 0; ThreadState.X0 = 0;
} }

View File

@ -1,5 +1,7 @@
using ChocolArm64.State; using ChocolArm64.State;
using Ryujinx.Core.OsHle.Handles; using Ryujinx.Core.OsHle.Handles;
using System;
using System.Threading;
using static Ryujinx.Core.OsHle.ErrorCode; using static Ryujinx.Core.OsHle.ErrorCode;
@ -7,11 +9,31 @@ namespace Ryujinx.Core.OsHle.Kernel
{ {
partial class SvcHandler partial class SvcHandler
{ {
private const int MutexHasListenersMask = 0x40000000;
private void SvcArbitrateLock(AThreadState ThreadState) private void SvcArbitrateLock(AThreadState ThreadState)
{ {
int OwnerThreadHandle = (int)ThreadState.X0; int OwnerThreadHandle = (int)ThreadState.X0;
long MutexAddress = (long)ThreadState.X1; long MutexAddress = (long)ThreadState.X1;
int RequestingThreadHandle = (int)ThreadState.X2; int WaitThreadHandle = (int)ThreadState.X2;
if (IsPointingInsideKernel(MutexAddress))
{
Logging.Warn(LogClass.KernelSvc, $"Invalid mutex address 0x{MutexAddress:x16}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAddress);
return;
}
if (IsWordAddressUnaligned(MutexAddress))
{
Logging.Warn(LogClass.KernelSvc, $"Unaligned mutex address 0x{MutexAddress:x16}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAlignment);
return;
}
KThread OwnerThread = Process.HandleTable.GetData<KThread>(OwnerThreadHandle); KThread OwnerThread = Process.HandleTable.GetData<KThread>(OwnerThreadHandle);
@ -24,20 +46,20 @@ namespace Ryujinx.Core.OsHle.Kernel
return; return;
} }
KThread RequestingThread = Process.HandleTable.GetData<KThread>(RequestingThreadHandle); KThread WaitThread = Process.HandleTable.GetData<KThread>(WaitThreadHandle);
if (RequestingThread == null) if (WaitThread == null)
{ {
Logging.Warn(LogClass.KernelSvc, $"Invalid requesting thread handle 0x{RequestingThreadHandle:x8}!"); Logging.Warn(LogClass.KernelSvc, $"Invalid requesting thread handle 0x{WaitThreadHandle:x8}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidHandle); ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidHandle);
return; return;
} }
MutualExclusion Mutex = GetMutex(MutexAddress); KThread CurrThread = Process.GetThread(ThreadState.Tpidr);
Mutex.WaitForLock(RequestingThread, OwnerThreadHandle); MutexLock(CurrThread, WaitThread, OwnerThreadHandle, WaitThreadHandle, MutexAddress);
ThreadState.X0 = 0; ThreadState.X0 = 0;
} }
@ -46,9 +68,28 @@ namespace Ryujinx.Core.OsHle.Kernel
{ {
long MutexAddress = (long)ThreadState.X0; long MutexAddress = (long)ThreadState.X0;
GetMutex(MutexAddress).Unlock(); if (IsPointingInsideKernel(MutexAddress))
{
Logging.Warn(LogClass.KernelSvc, $"Invalid mutex address 0x{MutexAddress:x16}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAddress);
return;
}
if (IsWordAddressUnaligned(MutexAddress))
{
Logging.Warn(LogClass.KernelSvc, $"Unaligned mutex address 0x{MutexAddress:x16}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAlignment);
return;
}
if (MutexUnlock(Process.GetThread(ThreadState.Tpidr), MutexAddress))
{
Process.Scheduler.Yield(Process.GetThread(ThreadState.Tpidr)); Process.Scheduler.Yield(Process.GetThread(ThreadState.Tpidr));
}
ThreadState.X0 = 0; ThreadState.X0 = 0;
} }
@ -60,6 +101,24 @@ namespace Ryujinx.Core.OsHle.Kernel
int ThreadHandle = (int)ThreadState.X2; int ThreadHandle = (int)ThreadState.X2;
ulong Timeout = ThreadState.X3; ulong Timeout = ThreadState.X3;
if (IsPointingInsideKernel(MutexAddress))
{
Logging.Warn(LogClass.KernelSvc, $"Invalid mutex address 0x{MutexAddress:x16}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAddress);
return;
}
if (IsWordAddressUnaligned(MutexAddress))
{
Logging.Warn(LogClass.KernelSvc, $"Unaligned mutex address 0x{MutexAddress:x16}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidAlignment);
return;
}
KThread Thread = Process.HandleTable.GetData<KThread>(ThreadHandle); KThread Thread = Process.HandleTable.GetData<KThread>(ThreadHandle);
if (Thread == null) if (Thread == null)
@ -67,24 +126,22 @@ namespace Ryujinx.Core.OsHle.Kernel
Logging.Warn(LogClass.KernelSvc, $"Invalid thread handle 0x{ThreadHandle:x8}!"); Logging.Warn(LogClass.KernelSvc, $"Invalid thread handle 0x{ThreadHandle:x8}!");
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidHandle); ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.InvalidHandle);
return;
} }
Process.Scheduler.Suspend(Thread.ProcessorId); KThread CurrThread = Process.GetThread(ThreadState.Tpidr);
MutualExclusion Mutex = GetMutex(MutexAddress); MutexUnlock(CurrThread, MutexAddress);
Mutex.Unlock(); if (!CondVarWait(CurrThread, ThreadHandle, MutexAddress, CondVarAddress, Timeout))
if (!GetCondVar(CondVarAddress).WaitForSignal(Thread, Timeout))
{ {
ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.Timeout); ThreadState.X0 = MakeError(ErrorModule.Kernel, KernelErr.Timeout);
return; return;
} }
Mutex.WaitForLock(Thread); Process.Scheduler.Yield(Process.GetThread(ThreadState.Tpidr));
Process.Scheduler.Resume(Thread);
ThreadState.X0 = 0; ThreadState.X0 = 0;
} }
@ -94,31 +151,274 @@ namespace Ryujinx.Core.OsHle.Kernel
long CondVarAddress = (long)ThreadState.X0; long CondVarAddress = (long)ThreadState.X0;
int Count = (int)ThreadState.X1; int Count = (int)ThreadState.X1;
KThread CurrThread = Process.GetThread(ThreadState.Tpidr); CondVarSignal(CondVarAddress, Count);
GetCondVar(CondVarAddress).SetSignal(CurrThread, Count);
ThreadState.X0 = 0; ThreadState.X0 = 0;
} }
private MutualExclusion GetMutex(long MutexAddress) private void MutexLock(
KThread CurrThread,
KThread WaitThread,
int OwnerThreadHandle,
int WaitThreadHandle,
long MutexAddress)
{ {
MutualExclusion MutexFactory(long Key) int MutexValue = Process.Memory.ReadInt32(MutexAddress);
if (MutexValue != (OwnerThreadHandle | MutexHasListenersMask))
{ {
return new MutualExclusion(Process, MutexAddress); return;
} }
return Mutexes.GetOrAdd(MutexAddress, MutexFactory); CurrThread.WaitHandle = WaitThreadHandle;
CurrThread.MutexAddress = MutexAddress;
InsertWaitingMutexThread(OwnerThreadHandle, WaitThread);
Process.Scheduler.EnterWait(WaitThread);
} }
private ConditionVariable GetCondVar(long CondVarAddress) private bool MutexUnlock(KThread CurrThread, long MutexAddress)
{ {
ConditionVariable CondVarFactory(long Key) if (CurrThread == null)
{ {
return new ConditionVariable(Process, CondVarAddress); Logging.Warn(LogClass.KernelSvc, $"Invalid mutex 0x{MutexAddress:x16}!");
return false;
} }
return CondVars.GetOrAdd(CondVarAddress, CondVarFactory); lock (CurrThread)
{
//This is the new thread that will not own the mutex.
//If no threads are waiting for the lock, then it should be null.
KThread OwnerThread = CurrThread.NextMutexThread;
while (OwnerThread != null && OwnerThread.MutexAddress != MutexAddress)
{
OwnerThread = OwnerThread.NextMutexThread;
}
CurrThread.NextMutexThread = null;
if (OwnerThread != null)
{
int HasListeners = OwnerThread.NextMutexThread != null ? MutexHasListenersMask : 0;
Process.Memory.WriteInt32(MutexAddress, HasListeners | OwnerThread.WaitHandle);
OwnerThread.WaitHandle = 0;
OwnerThread.MutexAddress = 0;
OwnerThread.CondVarAddress = 0;
OwnerThread.MutexOwner = null;
OwnerThread.UpdatePriority();
Process.Scheduler.WakeUp(OwnerThread);
return true;
}
else
{
Process.Memory.WriteInt32(MutexAddress, 0);
return false;
}
}
}
private bool CondVarWait(
KThread WaitThread,
int WaitThreadHandle,
long MutexAddress,
long CondVarAddress,
ulong Timeout)
{
WaitThread.WaitHandle = WaitThreadHandle;
WaitThread.MutexAddress = MutexAddress;
WaitThread.CondVarAddress = CondVarAddress;
lock (CondVarLock)
{
KThread CurrThread = Process.ThreadArbiterList;
if (CurrThread != null)
{
bool DoInsert = CurrThread != WaitThread;
while (CurrThread.NextCondVarThread != null)
{
if (CurrThread.NextCondVarThread.ActualPriority < WaitThread.ActualPriority)
{
break;
}
CurrThread = CurrThread.NextCondVarThread;
DoInsert &= CurrThread != WaitThread;
}
//Only insert if the node doesn't already exist in the list.
//This prevents circular references.
if (DoInsert)
{
if (WaitThread.NextCondVarThread != null)
{
throw new InvalidOperationException();
}
WaitThread.NextCondVarThread = CurrThread.NextCondVarThread;
CurrThread.NextCondVarThread = WaitThread;
}
}
else
{
Process.ThreadArbiterList = WaitThread;
}
}
if (Timeout != ulong.MaxValue)
{
return Process.Scheduler.EnterWait(WaitThread, NsTimeConverter.GetTimeMs(Timeout));
}
else
{
return Process.Scheduler.EnterWait(WaitThread);
}
}
private void CondVarSignal(long CondVarAddress, int Count)
{
lock (CondVarLock)
{
KThread PrevThread = null;
KThread CurrThread = Process.ThreadArbiterList;
while (CurrThread != null && (Count == -1 || Count > 0))
{
if (CurrThread.CondVarAddress == CondVarAddress)
{
if (PrevThread != null)
{
PrevThread.NextCondVarThread = CurrThread.NextCondVarThread;
}
else
{
Process.ThreadArbiterList = CurrThread.NextCondVarThread;
}
CurrThread.NextCondVarThread = null;
AcquireMutexValue(CurrThread.MutexAddress);
int MutexValue = Process.Memory.ReadInt32(CurrThread.MutexAddress);
MutexValue &= ~MutexHasListenersMask;
if (MutexValue == 0)
{
//Give the lock to this thread.
Process.Memory.WriteInt32(CurrThread.MutexAddress, CurrThread.WaitHandle);
CurrThread.WaitHandle = 0;
CurrThread.MutexAddress = 0;
CurrThread.CondVarAddress = 0;
CurrThread.MutexOwner = null;
CurrThread.UpdatePriority();
Process.Scheduler.WakeUp(CurrThread);
}
else
{
//Wait until the lock is released.
InsertWaitingMutexThread(MutexValue, CurrThread);
MutexValue |= MutexHasListenersMask;
Process.Memory.WriteInt32(CurrThread.MutexAddress, MutexValue);
}
ReleaseMutexValue(CurrThread.MutexAddress);
Count--;
}
PrevThread = CurrThread;
CurrThread = CurrThread.NextCondVarThread;
}
}
}
private void InsertWaitingMutexThread(int OwnerThreadHandle, KThread WaitThread)
{
KThread OwnerThread = Process.HandleTable.GetData<KThread>(OwnerThreadHandle);
if (OwnerThread == null)
{
Logging.Warn(LogClass.KernelSvc, $"Invalid thread handle 0x{OwnerThreadHandle:x8}!");
return;
}
WaitThread.MutexOwner = OwnerThread;
lock (OwnerThread)
{
KThread CurrThread = OwnerThread;
while (CurrThread.NextMutexThread != null)
{
if (CurrThread == WaitThread)
{
return;
}
if (CurrThread.NextMutexThread.ActualPriority < WaitThread.ActualPriority)
{
break;
}
CurrThread = CurrThread.NextMutexThread;
}
if (CurrThread != WaitThread)
{
if (WaitThread.NextCondVarThread != null)
{
throw new InvalidOperationException();
}
WaitThread.NextMutexThread = CurrThread.NextMutexThread;
CurrThread.NextMutexThread = WaitThread;
}
}
OwnerThread.UpdatePriority();
}
private void AcquireMutexValue(long MutexAddress)
{
while (!Process.Memory.AcquireAddress(MutexAddress))
{
Thread.Yield();
}
}
private void ReleaseMutexValue(long MutexAddress)
{
Process.Memory.ReleaseAddress(MutexAddress);
}
private bool IsPointingInsideKernel(long Address)
{
return ((ulong)Address + 0x1000000000) < 0xffffff000;
}
private bool IsWordAddressUnaligned(long Address)
{
return (Address & 3) != 0;
} }
} }
} }

View File

@ -35,6 +35,8 @@ namespace Ryujinx.Core.OsHle
public KProcessScheduler Scheduler { get; private set; } public KProcessScheduler Scheduler { get; private set; }
public KThread ThreadArbiterList { get; set; }
public KProcessHandleTable HandleTable { get; private set; } public KProcessHandleTable HandleTable { get; private set; }
public AppletStateMgr AppletState { get; private set; } public AppletStateMgr AppletState { get; private set; }
@ -43,7 +45,7 @@ namespace Ryujinx.Core.OsHle
private ConcurrentDictionary<int, AThread> TlsSlots; private ConcurrentDictionary<int, AThread> TlsSlots;
private ConcurrentDictionary<long, KThread> ThreadsByTpidr; private ConcurrentDictionary<long, KThread> Threads;
private List<Executable> Executables; private List<Executable> Executables;
@ -71,7 +73,7 @@ namespace Ryujinx.Core.OsHle
TlsSlots = new ConcurrentDictionary<int, AThread>(); TlsSlots = new ConcurrentDictionary<int, AThread>();
ThreadsByTpidr = new ConcurrentDictionary<long, KThread>(); Threads = new ConcurrentDictionary<long, KThread>();
Executables = new List<Executable>(); Executables = new List<Executable>();
@ -185,34 +187,32 @@ namespace Ryujinx.Core.OsHle
throw new ObjectDisposedException(nameof(Process)); throw new ObjectDisposedException(nameof(Process));
} }
AThread Thread = new AThread(GetTranslator(), Memory, EntryPoint); AThread CpuThread = new AThread(GetTranslator(), Memory, EntryPoint);
KThread KernelThread = new KThread(Thread, ProcessorId, Priority); KThread Thread = new KThread(CpuThread, ProcessorId, Priority);
int Handle = HandleTable.OpenHandle(KernelThread); int Handle = HandleTable.OpenHandle(Thread);
KernelThread.Handle = Handle; int ThreadId = GetFreeTlsSlot(CpuThread);
int ThreadId = GetFreeTlsSlot(Thread);
long Tpidr = MemoryRegions.TlsPagesAddress + ThreadId * TlsSize; long Tpidr = MemoryRegions.TlsPagesAddress + ThreadId * TlsSize;
Thread.ThreadState.ProcessId = ProcessId; CpuThread.ThreadState.ProcessId = ProcessId;
Thread.ThreadState.ThreadId = ThreadId; CpuThread.ThreadState.ThreadId = ThreadId;
Thread.ThreadState.CntfrqEl0 = TickFreq; CpuThread.ThreadState.CntfrqEl0 = TickFreq;
Thread.ThreadState.Tpidr = Tpidr; CpuThread.ThreadState.Tpidr = Tpidr;
Thread.ThreadState.X0 = (ulong)ArgsPtr; CpuThread.ThreadState.X0 = (ulong)ArgsPtr;
Thread.ThreadState.X1 = (ulong)Handle; CpuThread.ThreadState.X1 = (ulong)Handle;
Thread.ThreadState.X31 = (ulong)StackTop; CpuThread.ThreadState.X31 = (ulong)StackTop;
Thread.ThreadState.Break += BreakHandler; CpuThread.ThreadState.Break += BreakHandler;
Thread.ThreadState.SvcCall += SvcHandler.SvcCall; CpuThread.ThreadState.SvcCall += SvcHandler.SvcCall;
Thread.ThreadState.Undefined += UndefinedHandler; CpuThread.ThreadState.Undefined += UndefinedHandler;
Thread.WorkFinished += ThreadFinished; CpuThread.WorkFinished += ThreadFinished;
ThreadsByTpidr.TryAdd(Thread.ThreadState.Tpidr, KernelThread); Threads.TryAdd(CpuThread.ThreadState.Tpidr, Thread);
return Handle; return Handle;
} }
@ -324,7 +324,7 @@ namespace Ryujinx.Core.OsHle
public KThread GetThread(long Tpidr) public KThread GetThread(long Tpidr)
{ {
if (!ThreadsByTpidr.TryGetValue(Tpidr, out KThread Thread)) if (!Threads.TryGetValue(Tpidr, out KThread Thread))
{ {
Logging.Error(LogClass.KernelScheduler, $"Thread with TPIDR 0x{Tpidr:x16} not found!"); Logging.Error(LogClass.KernelScheduler, $"Thread with TPIDR 0x{Tpidr:x16} not found!");
} }