diff --git a/src/Ryujinx.Graphics.GAL/Capabilities.cs b/src/Ryujinx.Graphics.GAL/Capabilities.cs index f4b1d4d101..d41f8e59fd 100644 --- a/src/Ryujinx.Graphics.GAL/Capabilities.cs +++ b/src/Ryujinx.Graphics.GAL/Capabilities.cs @@ -52,6 +52,7 @@ namespace Ryujinx.Graphics.GAL public readonly int MaximumComputeSharedMemorySize; public readonly float MaximumSupportedAnisotropy; + public readonly int ShaderSubgroupSize; public readonly int StorageBufferOffsetAlignment; public readonly int GatherBiasPrecision; @@ -101,6 +102,7 @@ namespace Ryujinx.Graphics.GAL uint maximumImagesPerStage, int maximumComputeSharedMemorySize, float maximumSupportedAnisotropy, + int shaderSubgroupSize, int storageBufferOffsetAlignment, int gatherBiasPrecision) { @@ -148,6 +150,7 @@ namespace Ryujinx.Graphics.GAL MaximumImagesPerStage = maximumImagesPerStage; MaximumComputeSharedMemorySize = maximumComputeSharedMemorySize; MaximumSupportedAnisotropy = maximumSupportedAnisotropy; + ShaderSubgroupSize = shaderSubgroupSize; StorageBufferOffsetAlignment = storageBufferOffsetAlignment; GatherBiasPrecision = gatherBiasPrecision; } diff --git a/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs b/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs index 9afc5b6180..71a738255b 100644 --- a/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs +++ b/src/Ryujinx.Graphics.Gpu/Shader/DiskCache/DiskCacheHostStorage.cs @@ -22,7 +22,7 @@ namespace Ryujinx.Graphics.Gpu.Shader.DiskCache private const ushort FileFormatVersionMajor = 1; private const ushort FileFormatVersionMinor = 2; private const uint FileFormatVersionPacked = ((uint)FileFormatVersionMajor << 16) | FileFormatVersionMinor; - private const uint CodeGenVersion = 5576; + private const uint CodeGenVersion = 5540; private const string SharedTocFileName = "shared.toc"; private const string SharedDataFileName = "shared.data"; diff --git a/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs b/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs index e7a2d345ff..52193940b5 100644 --- a/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs +++ b/src/Ryujinx.Graphics.Gpu/Shader/GpuAccessorBase.cs @@ -137,6 +137,8 @@ namespace Ryujinx.Graphics.Gpu.Shader public int QueryHostStorageBufferOffsetAlignment() => _context.Capabilities.StorageBufferOffsetAlignment; + public int QueryHostSubgroupSize() => _context.Capabilities.ShaderSubgroupSize; + public bool QueryHostSupportsBgraFormat() => _context.Capabilities.SupportsBgraFormat; public bool QueryHostSupportsFragmentShaderInterlock() => _context.Capabilities.SupportsFragmentShaderInterlock; diff --git a/src/Ryujinx.Graphics.OpenGL/Constants.cs b/src/Ryujinx.Graphics.OpenGL/Constants.cs index 8817011a97..38fedea0d5 100644 --- a/src/Ryujinx.Graphics.OpenGL/Constants.cs +++ b/src/Ryujinx.Graphics.OpenGL/Constants.cs @@ -7,5 +7,6 @@ public const int MaxVertexAttribs = 16; public const int MaxVertexBuffers = 16; public const int MaxTransformFeedbackBuffers = 4; + public const int MaxSubgroupSize = 64; } } diff --git a/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs b/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs index 8a7ac85595..35d1569fe7 100644 --- a/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs +++ b/src/Ryujinx.Graphics.OpenGL/OpenGLRenderer.cs @@ -175,6 +175,7 @@ namespace Ryujinx.Graphics.OpenGL maximumImagesPerStage: 8, maximumComputeSharedMemorySize: HwCapabilities.MaximumComputeSharedMemorySize, maximumSupportedAnisotropy: HwCapabilities.MaximumSupportedAnisotropy, + shaderSubgroupSize: Constants.MaxSubgroupSize, storageBufferOffsetAlignment: HwCapabilities.StorageBufferOffsetAlignment, gatherBiasPrecision: intelWindows || amdWindows ? 8 : 0); // Precision is 8 for these vendors on Vulkan. } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs index e181ae98d5..607ff431e4 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Declarations.cs @@ -25,6 +25,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl { context.AppendLine("#extension GL_KHR_shader_subgroup_basic : enable"); context.AppendLine("#extension GL_KHR_shader_subgroup_ballot : enable"); + context.AppendLine("#extension GL_KHR_shader_subgroup_shuffle : enable"); } context.AppendLine("#extension GL_ARB_shader_group_vote : enable"); @@ -201,26 +202,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/MultiplyHighU32.glsl"); } - if ((info.HelperFunctionsMask & HelperFunctionsMask.Shuffle) != 0) - { - AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/Shuffle.glsl"); - } - - if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleDown) != 0) - { - AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleDown.glsl"); - } - - if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleUp) != 0) - { - AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleUp.glsl"); - } - - if ((info.HelperFunctionsMask & HelperFunctionsMask.ShuffleXor) != 0) - { - AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleXor.glsl"); - } - if ((info.HelperFunctionsMask & HelperFunctionsMask.SwizzleAdd) != 0) { AppendHelperFunction(context, "Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/SwizzleAdd.glsl"); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/HelperFunctionNames.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/HelperFunctionNames.cs index 2218027271..0b80ac2b6b 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/HelperFunctionNames.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/HelperFunctionNames.cs @@ -5,10 +5,6 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl public static string MultiplyHighS32 = "Helper_MultiplyHighS32"; public static string MultiplyHighU32 = "Helper_MultiplyHighU32"; - public static string Shuffle = "Helper_Shuffle"; - public static string ShuffleDown = "Helper_ShuffleDown"; - public static string ShuffleUp = "Helper_ShuffleUp"; - public static string ShuffleXor = "Helper_ShuffleXor"; public static string SwizzleAdd = "Helper_SwizzleAdd"; } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/Shuffle.glsl b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/Shuffle.glsl deleted file mode 100644 index 7cb4764dd4..0000000000 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/Shuffle.glsl +++ /dev/null @@ -1,11 +0,0 @@ -float Helper_Shuffle(float x, uint index, uint mask, out bool valid) -{ - uint clamp = mask & 0x1fu; - uint segMask = (mask >> 8) & 0x1fu; - uint minThreadId = $SUBGROUP_INVOCATION$ & segMask; - uint maxThreadId = minThreadId | (clamp & ~segMask); - uint srcThreadId = (index & ~segMask) | minThreadId; - valid = srcThreadId <= maxThreadId; - float v = $SUBGROUP_BROADCAST$(x, srcThreadId); - return valid ? v : x; -} \ No newline at end of file diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleDown.glsl b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleDown.glsl deleted file mode 100644 index 71d901d5d2..0000000000 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleDown.glsl +++ /dev/null @@ -1,11 +0,0 @@ -float Helper_ShuffleDown(float x, uint index, uint mask, out bool valid) -{ - uint clamp = mask & 0x1fu; - uint segMask = (mask >> 8) & 0x1fu; - uint minThreadId = $SUBGROUP_INVOCATION$ & segMask; - uint maxThreadId = minThreadId | (clamp & ~segMask); - uint srcThreadId = $SUBGROUP_INVOCATION$ + index; - valid = srcThreadId <= maxThreadId; - float v = $SUBGROUP_BROADCAST$(x, srcThreadId); - return valid ? v : x; -} \ No newline at end of file diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleUp.glsl b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleUp.glsl deleted file mode 100644 index ae264d8704..0000000000 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleUp.glsl +++ /dev/null @@ -1,9 +0,0 @@ -float Helper_ShuffleUp(float x, uint index, uint mask, out bool valid) -{ - uint segMask = (mask >> 8) & 0x1fu; - uint minThreadId = $SUBGROUP_INVOCATION$ & segMask; - uint srcThreadId = $SUBGROUP_INVOCATION$ - index; - valid = int(srcThreadId) >= int(minThreadId); - float v = $SUBGROUP_BROADCAST$(x, srcThreadId); - return valid ? v : x; -} \ No newline at end of file diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleXor.glsl b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleXor.glsl deleted file mode 100644 index 789089d69c..0000000000 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/HelperFunctions/ShuffleXor.glsl +++ /dev/null @@ -1,11 +0,0 @@ -float Helper_ShuffleXor(float x, uint index, uint mask, out bool valid) -{ - uint clamp = mask & 0x1fu; - uint segMask = (mask >> 8) & 0x1fu; - uint minThreadId = $SUBGROUP_INVOCATION$ & segMask; - uint maxThreadId = minThreadId | (clamp & ~segMask); - uint srcThreadId = $SUBGROUP_INVOCATION$ ^ index; - valid = srcThreadId <= maxThreadId; - float v = $SUBGROUP_BROADCAST$(x, srcThreadId); - return valid ? v : x; -} \ No newline at end of file diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs index 9208ceeadd..796eb4417a 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGen.cs @@ -9,6 +9,7 @@ using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenFSI; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenMemory; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenPacking; +using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenShuffle; using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenVector; using static Ryujinx.Graphics.Shader.StructuredIr.InstructionInfo; @@ -174,6 +175,9 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions case Instruction.PackHalf2x16: return PackHalf2x16(context, operation); + case Instruction.Shuffle: + return Shuffle(context, operation); + case Instruction.Store: return Store(context, operation); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs index b44759c0dc..6cc7048bd7 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenBallot.cs @@ -13,14 +13,15 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions AggregateType dstType = GetSrcVarType(operation.Inst, 0); string arg = GetSoureExpr(context, operation.GetSource(0), dstType); + char component = "xyzw"[operation.Index]; if (context.HostCapabilities.SupportsShaderBallot) { - return $"unpackUint2x32(ballotARB({arg})).x"; + return $"unpackUint2x32(ballotARB({arg})).{component}"; } else { - return $"subgroupBallot({arg}).x"; + return $"subgroupBallot({arg}).{component}"; } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs index c3d52b2c53..eb194c2097 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenHelper.cs @@ -108,10 +108,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3); - Add(Instruction.Shuffle, InstType.CallQuaternary, HelperFunctionNames.Shuffle); - Add(Instruction.ShuffleDown, InstType.CallQuaternary, HelperFunctionNames.ShuffleDown); - Add(Instruction.ShuffleUp, InstType.CallQuaternary, HelperFunctionNames.ShuffleUp); - Add(Instruction.ShuffleXor, InstType.CallQuaternary, HelperFunctionNames.ShuffleXor); + Add(Instruction.Shuffle, InstType.Special); + Add(Instruction.ShuffleDown, InstType.CallBinary, "subgroupShuffleDown"); + Add(Instruction.ShuffleUp, InstType.CallBinary, "subgroupShuffleUp"); + Add(Instruction.ShuffleXor, InstType.CallBinary, "subgroupShuffleXor"); Add(Instruction.Sine, InstType.CallUnary, "sin"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.Store, InstType.Special); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenShuffle.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenShuffle.cs new file mode 100644 index 0000000000..6d3859efdc --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Glsl/Instructions/InstGenShuffle.cs @@ -0,0 +1,25 @@ +using Ryujinx.Graphics.Shader.StructuredIr; +using Ryujinx.Graphics.Shader.Translation; + +using static Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions.InstGenHelper; + +namespace Ryujinx.Graphics.Shader.CodeGen.Glsl.Instructions +{ + static class InstGenShuffle + { + public static string Shuffle(CodeGenContext context, AstOperation operation) + { + string value = GetSoureExpr(context, operation.GetSource(0), AggregateType.FP32); + string index = GetSoureExpr(context, operation.GetSource(1), AggregateType.U32); + + if (context.HostCapabilities.SupportsShaderBallot) + { + return $"readInvocationARB({value}, {index})"; + } + else + { + return $"subgroupShuffle({value}, {index})"; + } + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs index 98c1b9d284..719ccf0cf0 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs @@ -231,7 +231,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var execution = context.Constant(context.TypeU32(), Scope.Subgroup); var maskVector = context.GroupNonUniformBallot(uvec4Type, execution, context.Get(AggregateType.Bool, source)); - var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)0); + var mask = context.CompositeExtract(context.TypeU32(), maskVector, (SpvLiteralInteger)operation.Index); return new OperationResult(AggregateType.U32, mask); } @@ -1100,117 +1100,40 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv private static OperationResult GenerateShuffle(CodeGenContext context, AstOperation operation) { - var x = context.GetFP32(operation.GetSource(0)); + var value = context.GetFP32(operation.GetSource(0)); var index = context.GetU32(operation.GetSource(1)); - var mask = context.GetU32(operation.GetSource(2)); - var const31 = context.Constant(context.TypeU32(), 31); - var const8 = context.Constant(context.TypeU32(), 8); - - var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31); - var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31); - var notSegMask = context.Not(context.TypeU32(), segMask); - var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask); - var indexNotSegMask = context.BitwiseAnd(context.TypeU32(), index, notSegMask); - - var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); - - var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); - var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask); - var srcThreadId = context.BitwiseOr(context.TypeU32(), indexNotSegMask, minThreadId); - var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId); - var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId); - var result = context.Select(context.TypeFP32(), valid, value, x); - - var validLocal = (AstOperand)operation.GetSource(3); - - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); + var result = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index); return new OperationResult(AggregateType.FP32, result); } private static OperationResult GenerateShuffleDown(CodeGenContext context, AstOperation operation) { - var x = context.GetFP32(operation.GetSource(0)); + var value = context.GetFP32(operation.GetSource(0)); var index = context.GetU32(operation.GetSource(1)); - var mask = context.GetU32(operation.GetSource(2)); - var const31 = context.Constant(context.TypeU32(), 31); - var const8 = context.Constant(context.TypeU32(), 8); - - var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31); - var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31); - var notSegMask = context.Not(context.TypeU32(), segMask); - var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask); - - var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); - - var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); - var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask); - var srcThreadId = context.IAdd(context.TypeU32(), threadId, index); - var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId); - var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId); - var result = context.Select(context.TypeFP32(), valid, value, x); - - var validLocal = (AstOperand)operation.GetSource(3); - - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); + var result = context.GroupNonUniformShuffleDown(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index); return new OperationResult(AggregateType.FP32, result); } private static OperationResult GenerateShuffleUp(CodeGenContext context, AstOperation operation) { - var x = context.GetFP32(operation.GetSource(0)); + var value = context.GetFP32(operation.GetSource(0)); var index = context.GetU32(operation.GetSource(1)); - var mask = context.GetU32(operation.GetSource(2)); - var const31 = context.Constant(context.TypeU32(), 31); - var const8 = context.Constant(context.TypeU32(), 8); - - var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31); - - var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); - - var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); - var srcThreadId = context.ISub(context.TypeU32(), threadId, index); - var valid = context.SGreaterThanEqual(context.TypeBool(), srcThreadId, minThreadId); - var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId); - var result = context.Select(context.TypeFP32(), valid, value, x); - - var validLocal = (AstOperand)operation.GetSource(3); - - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); + var result = context.GroupNonUniformShuffleUp(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index); return new OperationResult(AggregateType.FP32, result); } private static OperationResult GenerateShuffleXor(CodeGenContext context, AstOperation operation) { - var x = context.GetFP32(operation.GetSource(0)); + var value = context.GetFP32(operation.GetSource(0)); var index = context.GetU32(operation.GetSource(1)); - var mask = context.GetU32(operation.GetSource(2)); - var const31 = context.Constant(context.TypeU32(), 31); - var const8 = context.Constant(context.TypeU32(), 8); - - var clamp = context.BitwiseAnd(context.TypeU32(), mask, const31); - var segMask = context.BitwiseAnd(context.TypeU32(), context.ShiftRightLogical(context.TypeU32(), mask, const8), const31); - var notSegMask = context.Not(context.TypeU32(), segMask); - var clampNotSegMask = context.BitwiseAnd(context.TypeU32(), clamp, notSegMask); - - var threadId = GetScalarInput(context, IoVariable.SubgroupLaneId); - - var minThreadId = context.BitwiseAnd(context.TypeU32(), threadId, segMask); - var maxThreadId = context.BitwiseOr(context.TypeU32(), minThreadId, clampNotSegMask); - var srcThreadId = context.BitwiseXor(context.TypeU32(), threadId, index); - var valid = context.ULessThanEqual(context.TypeBool(), srcThreadId, maxThreadId); - var value = context.GroupNonUniformShuffle(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), x, srcThreadId); - var result = context.Select(context.TypeFP32(), valid, value, x); - - var validLocal = (AstOperand)operation.GetSource(3); - - context.Store(context.GetLocalPointer(validLocal), context.BitcastIfNeeded(validLocal.VarType, AggregateType.Bool, valid)); + var result = context.GroupNonUniformShuffleXor(context.TypeFP32(), context.Constant(context.TypeU32(), (int)Scope.Subgroup), value, index); return new OperationResult(AggregateType.FP32, result); } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs index 5eee888e47..70f1dd3c42 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Spirv/SpirvGenerator.cs @@ -28,12 +28,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv _poolLock = new object(); } - private const HelperFunctionsMask NeedsInvocationIdMask = - HelperFunctionsMask.Shuffle | - HelperFunctionsMask.ShuffleDown | - HelperFunctionsMask.ShuffleUp | - HelperFunctionsMask.ShuffleXor | - HelperFunctionsMask.SwizzleAdd; + private const HelperFunctionsMask NeedsInvocationIdMask = HelperFunctionsMask.SwizzleAdd; public static byte[] Generate(StructuredProgramInfo info, CodeGenParameters parameters) { diff --git a/src/Ryujinx.Graphics.Shader/Decoders/Decoder.cs b/src/Ryujinx.Graphics.Shader/Decoders/Decoder.cs index d18a9baf8c..4266dedcae 100644 --- a/src/Ryujinx.Graphics.Shader/Decoders/Decoder.cs +++ b/src/Ryujinx.Graphics.Shader/Decoders/Decoder.cs @@ -307,6 +307,9 @@ namespace Ryujinx.Graphics.Shader.Decoders case InstName.Sts: context.SetUsedFeature(FeatureFlags.SharedMemory); break; + case InstName.Shfl: + context.SetUsedFeature(FeatureFlags.Shuffle); + break; } block.OpCodes.Add(op); diff --git a/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs b/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs index ee31f02d17..ba10f2720d 100644 --- a/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs +++ b/src/Ryujinx.Graphics.Shader/IGpuAccessor.cs @@ -194,6 +194,15 @@ namespace Ryujinx.Graphics.Shader return 16; } + /// + /// Queries host shader subgroup size. + /// + /// Host shader subgroup size in invocations + int QueryHostSubgroupSize() + { + return 32; + } + /// /// Queries host support for texture formats with BGRA component order (such as BGRA8). /// diff --git a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs index 9d1c7d087c..944039d652 100644 --- a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs +++ b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitMove.cs @@ -76,7 +76,7 @@ namespace Ryujinx.Graphics.Shader.Instructions switch (op.SReg) { case SReg.LaneId: - src = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + src = EmitLoadSubgroupLaneId(context); break; case SReg.InvocationId: @@ -146,19 +146,19 @@ namespace Ryujinx.Graphics.Shader.Instructions break; case SReg.EqMask: - src = context.Load(StorageKind.Input, IoVariable.SubgroupEqMask, null, Const(0)); + src = EmitLoadSubgroupMask(context, IoVariable.SubgroupEqMask); break; case SReg.LtMask: - src = context.Load(StorageKind.Input, IoVariable.SubgroupLtMask, null, Const(0)); + src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLtMask); break; case SReg.LeMask: - src = context.Load(StorageKind.Input, IoVariable.SubgroupLeMask, null, Const(0)); + src = EmitLoadSubgroupMask(context, IoVariable.SubgroupLeMask); break; case SReg.GtMask: - src = context.Load(StorageKind.Input, IoVariable.SubgroupGtMask, null, Const(0)); + src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGtMask); break; case SReg.GeMask: - src = context.Load(StorageKind.Input, IoVariable.SubgroupGeMask, null, Const(0)); + src = EmitLoadSubgroupMask(context, IoVariable.SubgroupGeMask); break; default: @@ -169,6 +169,52 @@ namespace Ryujinx.Graphics.Shader.Instructions context.Copy(GetDest(op.Dest), src); } + private static Operand EmitLoadSubgroupLaneId(EmitterContext context) + { + if (context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize() <= 32) + { + return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + } + + return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f)); + } + + private static Operand EmitLoadSubgroupMask(EmitterContext context, IoVariable ioVariable) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return context.Load(StorageKind.Input, ioVariable, null, Const(0)); + } + else if (subgroupSize == 64) + { + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand low = context.Load(StorageKind.Input, ioVariable, null, Const(0)); + Operand high = context.Load(StorageKind.Input, ioVariable, null, Const(1)); + + return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low); + } + else + { + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand element = context.ShiftRightU32(laneId, Const(5)); + + Operand res = context.Load(StorageKind.Input, ioVariable, null, Const(0)); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(1)), + context.Load(StorageKind.Input, ioVariable, null, Const(1)), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(2)), + context.Load(StorageKind.Input, ioVariable, null, Const(2)), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(3)), + context.Load(StorageKind.Input, ioVariable, null, Const(3)), res); + + return res; + } + } + public static void SelR(EmitterContext context) { InstSelR op = context.GetOp(); diff --git a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs index a84944e43e..73eea5c34d 100644 --- a/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs +++ b/src/Ryujinx.Graphics.Shader/Instructions/InstEmitWarp.cs @@ -50,20 +50,7 @@ namespace Ryujinx.Graphics.Shader.Instructions InstVote op = context.GetOp(); Operand pred = GetPredicate(context, op.SrcPred, op.SrcPredInv); - Operand res = null; - - switch (op.VoteMode) - { - case VoteMode.All: - res = context.VoteAll(pred); - break; - case VoteMode.Any: - res = context.VoteAny(pred); - break; - case VoteMode.Eq: - res = context.VoteAllEqual(pred); - break; - } + Operand res = EmitVote(context, op.VoteMode, pred); if (res != null) { @@ -76,7 +63,81 @@ namespace Ryujinx.Graphics.Shader.Instructions if (op.Dest != RegisterConsts.RegisterZeroIndex) { - context.Copy(GetDest(op.Dest), context.Ballot(pred)); + context.Copy(GetDest(op.Dest), EmitBallot(context, pred)); + } + } + + private static Operand EmitVote(EmitterContext context, VoteMode voteMode, Operand pred) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return voteMode switch + { + VoteMode.All => context.VoteAll(pred), + VoteMode.Any => context.VoteAny(pred), + VoteMode.Eq => context.VoteAllEqual(pred), + _ => null, + }; + } + + // Emulate vote with ballot masks. + // We do that when the GPU thread count is not 32, + // since the shader code assumes it is 32. + // allInvocations => ballot(pred) == ballot(true), + // anyInvocation => ballot(pred) != 0, + // allInvocationsEqual => ballot(pred) == balot(true) || ballot(pred) == 0 + Operand ballotMask = EmitBallot(context, pred); + + Operand AllTrue() => context.ICompareEqual(ballotMask, EmitBallot(context, Const(IrConsts.True))); + + return voteMode switch + { + VoteMode.All => AllTrue(), + VoteMode.Any => context.ICompareNotEqual(ballotMask, Const(0)), + VoteMode.Eq => context.BitwiseOr(AllTrue(), context.ICompareEqual(ballotMask, Const(0))), + _ => null, + }; + } + + private static Operand EmitBallot(EmitterContext context, Operand pred) + { + int subgroupSize = context.TranslatorContext.GpuAccessor.QueryHostSubgroupSize(); + + if (subgroupSize <= 32) + { + return context.Ballot(pred, 0); + } + else if (subgroupSize == 64) + { + // TODO: Add support for vector destination and do that with a single operation. + + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand low = context.Ballot(pred, 0); + Operand high = context.Ballot(pred, 1); + + return context.ConditionalSelect(context.BitwiseAnd(laneId, Const(32)), high, low); + } + else + { + // TODO: Add support for vector destination and do that with a single operation. + + Operand laneId = context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + Operand element = context.ShiftRightU32(laneId, Const(5)); + + Operand res = context.Ballot(pred, 0); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(1)), + context.Ballot(pred, 1), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(2)), + context.Ballot(pred, 2), res); + res = context.ConditionalSelect( + context.ICompareEqual(element, Const(3)), + context.Ballot(pred, 3), res); + + return res; } } } diff --git a/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj b/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj index b1f1fb9633..ea9a7821b1 100644 --- a/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj +++ b/src/Ryujinx.Graphics.Shader/Ryujinx.Graphics.Shader.csproj @@ -12,10 +12,6 @@ - - - - diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs index 73ce908278..2a3d65e75e 100644 --- a/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs +++ b/src/Ryujinx.Graphics.Shader/StructuredIr/HelperFunctionsMask.cs @@ -7,10 +7,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr { MultiplyHighS32 = 1 << 2, MultiplyHighU32 = 1 << 3, - Shuffle = 1 << 4, - ShuffleDown = 1 << 5, - ShuffleUp = 1 << 6, - ShuffleXor = 1 << 7, SwizzleAdd = 1 << 10, FSI = 1 << 11, } diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs index 6cd0fd0863..1169512e98 100644 --- a/src/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs +++ b/src/Ryujinx.Graphics.Shader/StructuredIr/InstructionInfo.cs @@ -109,14 +109,15 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Add(Instruction.PackDouble2x32, AggregateType.FP64, AggregateType.U32, AggregateType.U32); Add(Instruction.PackHalf2x16, AggregateType.U32, AggregateType.FP32, AggregateType.FP32); Add(Instruction.ReciprocalSquareRoot, AggregateType.Scalar, AggregateType.Scalar); + Add(Instruction.Return, AggregateType.Void, AggregateType.U32); Add(Instruction.Round, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.ShiftLeft, AggregateType.S32, AggregateType.S32, AggregateType.S32); Add(Instruction.ShiftRightS32, AggregateType.S32, AggregateType.S32, AggregateType.S32); Add(Instruction.ShiftRightU32, AggregateType.U32, AggregateType.U32, AggregateType.S32); - Add(Instruction.Shuffle, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); - Add(Instruction.ShuffleDown, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); - Add(Instruction.ShuffleUp, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); - Add(Instruction.ShuffleXor, AggregateType.FP32, AggregateType.FP32, AggregateType.U32, AggregateType.U32, AggregateType.Bool); + Add(Instruction.Shuffle, AggregateType.FP32, AggregateType.FP32, AggregateType.U32); + Add(Instruction.ShuffleDown, AggregateType.FP32, AggregateType.FP32, AggregateType.U32); + Add(Instruction.ShuffleUp, AggregateType.FP32, AggregateType.FP32, AggregateType.U32); + Add(Instruction.ShuffleXor, AggregateType.FP32, AggregateType.FP32, AggregateType.U32); Add(Instruction.Sine, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.SquareRoot, AggregateType.Scalar, AggregateType.Scalar); Add(Instruction.Store, AggregateType.Void); @@ -131,7 +132,7 @@ namespace Ryujinx.Graphics.Shader.StructuredIr Add(Instruction.VoteAll, AggregateType.Bool, AggregateType.Bool); Add(Instruction.VoteAllEqual, AggregateType.Bool, AggregateType.Bool); Add(Instruction.VoteAny, AggregateType.Bool, AggregateType.Bool); -#pragma warning restore IDE0055v +#pragma warning restore IDE0055 } private static void Add(Instruction inst, AggregateType destType, params AggregateType[] srcTypes) diff --git a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs index 862fef1267..b0db0ffb0d 100644 --- a/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs +++ b/src/Ryujinx.Graphics.Shader/StructuredIr/StructuredProgram.cs @@ -282,18 +282,6 @@ namespace Ryujinx.Graphics.Shader.StructuredIr case Instruction.MultiplyHighU32: context.Info.HelperFunctionsMask |= HelperFunctionsMask.MultiplyHighU32; break; - case Instruction.Shuffle: - context.Info.HelperFunctionsMask |= HelperFunctionsMask.Shuffle; - break; - case Instruction.ShuffleDown: - context.Info.HelperFunctionsMask |= HelperFunctionsMask.ShuffleDown; - break; - case Instruction.ShuffleUp: - context.Info.HelperFunctionsMask |= HelperFunctionsMask.ShuffleUp; - break; - case Instruction.ShuffleXor: - context.Info.HelperFunctionsMask |= HelperFunctionsMask.ShuffleXor; - break; case Instruction.SwizzleAdd: context.Info.HelperFunctionsMask |= HelperFunctionsMask.SwizzleAdd; break; diff --git a/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs b/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs index 6cb572381f..a08c8ea9d4 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/EmitterContextInsts.cs @@ -112,9 +112,13 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.AtomicXor, storageKind, Local(), Const(binding), e0, e1, value); } - public static Operand Ballot(this EmitterContext context, Operand a) + public static Operand Ballot(this EmitterContext context, Operand a, int index) { - return context.Add(Instruction.Ballot, Local(), a); + Operand dest = Local(); + + context.Add(new Operation(Instruction.Ballot, index, dest, a)); + + return dest; } public static Operand Barrier(this EmitterContext context) @@ -782,21 +786,41 @@ namespace Ryujinx.Graphics.Shader.Translation return context.Add(Instruction.ShiftRightU32, Local(), a, b); } + public static Operand Shuffle(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.Shuffle, Local(), a, b); + } + public static (Operand, Operand) Shuffle(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.Shuffle, (Local(), Local()), a, b, c); } + public static Operand ShuffleDown(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.ShuffleDown, Local(), a, b); + } + public static (Operand, Operand) ShuffleDown(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ShuffleDown, (Local(), Local()), a, b, c); } + public static Operand ShuffleUp(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.ShuffleUp, Local(), a, b); + } + public static (Operand, Operand) ShuffleUp(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ShuffleUp, (Local(), Local()), a, b, c); } + public static Operand ShuffleXor(this EmitterContext context, Operand a, Operand b) + { + return context.Add(Instruction.ShuffleXor, Local(), a, b); + } + public static (Operand, Operand) ShuffleXor(this EmitterContext context, Operand a, Operand b, Operand c) { return context.Add(Instruction.ShuffleXor, (Local(), Local()), a, b, c); diff --git a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs index 5b7226acdc..552a3f3100 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/FeatureFlags.cs @@ -18,6 +18,7 @@ namespace Ryujinx.Graphics.Shader.Translation InstanceId = 1 << 3, DrawParameters = 1 << 4, RtLayer = 1 << 5, + Shuffle = 1 << 6, FixedFuncAttr = 1 << 9, LocalMemory = 1 << 10, SharedMemory = 1 << 11, diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs index 2addff5c0a..ef2f8759da 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionManager.cs @@ -56,6 +56,20 @@ namespace Ryujinx.Graphics.Shader.Translation return functionId; } + public int GetOrCreateShuffleFunctionId(HelperFunctionName functionName, int subgroupSize) + { + if (_functionIds.TryGetValue((int)functionName, out int functionId)) + { + return functionId; + } + + Function function = GenerateShuffleFunction(functionName, subgroupSize); + functionId = AddFunction(function); + _functionIds.Add((int)functionName, functionId); + + return functionId; + } + private Function GenerateFunction(HelperFunctionName functionName) { return functionName switch @@ -216,6 +230,137 @@ namespace Ryujinx.Graphics.Shader.Translation return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, $"SharedStore{bitSize}_{id}", false, 2, 0); } + private static Function GenerateShuffleFunction(HelperFunctionName functionName, int subgroupSize) + { + return functionName switch + { + HelperFunctionName.Shuffle => GenerateShuffle(subgroupSize), + HelperFunctionName.ShuffleDown => GenerateShuffleDown(subgroupSize), + HelperFunctionName.ShuffleUp => GenerateShuffleUp(subgroupSize), + HelperFunctionName.ShuffleXor => GenerateShuffleXor(subgroupSize), + _ => throw new ArgumentException($"Invalid function name {functionName}"), + }; + } + + private static Function GenerateShuffle(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand clamp = context.BitwiseAnd(mask, Const(0x1f)); + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand minThreadId = context.BitwiseAnd(GenerateLoadSubgroupLaneId(context, subgroupSize), segMask); + Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId); + Operand srcThreadId = context.BitwiseOr(context.BitwiseAnd(index, context.BitwiseNot(segMask)), minThreadId); + Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "Shuffle", true, 3, 1); + } + + private static Function GenerateShuffleDown(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand clamp = context.BitwiseAnd(mask, Const(0x1f)); + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize); + Operand minThreadId = context.BitwiseAnd(laneId, segMask); + Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId); + Operand srcThreadId = context.IAdd(laneId, index); + Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleDown", true, 3, 1); + } + + private static Function GenerateShuffleUp(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize); + Operand minThreadId = context.BitwiseAnd(laneId, segMask); + Operand srcThreadId = context.ISubtract(laneId, index); + Operand valid = context.ICompareGreaterOrEqual(srcThreadId, minThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleUp", true, 3, 1); + } + + private static Function GenerateShuffleXor(int subgroupSize) + { + EmitterContext context = new(); + + Operand value = Argument(0); + Operand index = Argument(1); + Operand mask = Argument(2); + + Operand clamp = context.BitwiseAnd(mask, Const(0x1f)); + Operand segMask = context.BitwiseAnd(context.ShiftRightU32(mask, Const(8)), Const(0x1f)); + Operand laneId = GenerateLoadSubgroupLaneId(context, subgroupSize); + Operand minThreadId = context.BitwiseAnd(laneId, segMask); + Operand maxThreadId = context.BitwiseOr(context.BitwiseAnd(clamp, context.BitwiseNot(segMask)), minThreadId); + Operand srcThreadId = context.BitwiseExclusiveOr(laneId, index); + Operand valid = context.ICompareLessOrEqualUnsigned(srcThreadId, maxThreadId); + + context.Copy(Argument(3), valid); + + Operand result = context.Shuffle(value, GenerateSubgroupShuffleIndex(context, srcThreadId, subgroupSize)); + + context.Return(context.ConditionalSelect(valid, result, value)); + + return new Function(ControlFlowGraph.Create(context.GetOperations()).Blocks, "ShuffleXor", true, 3, 1); + } + + private static Operand GenerateLoadSubgroupLaneId(EmitterContext context, int subgroupSize) + { + if (subgroupSize <= 32) + { + return context.Load(StorageKind.Input, IoVariable.SubgroupLaneId); + } + + return context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x1f)); + } + + private static Operand GenerateSubgroupShuffleIndex(EmitterContext context, Operand srcThreadId, int subgroupSize) + { + if (subgroupSize <= 32) + { + return srcThreadId; + } + + return context.BitwiseOr( + context.BitwiseAnd(context.Load(StorageKind.Input, IoVariable.SubgroupLaneId), Const(0x60)), + srcThreadId); + } + private Function GenerateTexelFetchScaleFunction() { EmitterContext context = new(); diff --git a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs index e5af173556..09b17729d4 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/HelperFunctionName.cs @@ -2,12 +2,18 @@ namespace Ryujinx.Graphics.Shader.Translation { enum HelperFunctionName { + Invalid, + ConvertDoubleToFloat, ConvertFloatToDouble, SharedAtomicMaxS32, SharedAtomicMinS32, SharedStore8, SharedStore16, + Shuffle, + ShuffleDown, + ShuffleUp, + ShuffleXor, TexelFetchScale, TextureSizeUnscale, } diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs new file mode 100644 index 0000000000..839d4f8185 --- /dev/null +++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/ShufflePass.cs @@ -0,0 +1,52 @@ +using Ryujinx.Graphics.Shader.IntermediateRepresentation; +using Ryujinx.Graphics.Shader.Translation.Optimizations; +using System.Collections.Generic; +using static Ryujinx.Graphics.Shader.IntermediateRepresentation.OperandHelper; + +namespace Ryujinx.Graphics.Shader.Translation.Transforms +{ + class ShufflePass : ITransformPass + { + public static bool IsEnabled(IGpuAccessor gpuAccessor, ShaderStage stage, TargetLanguage targetLanguage, FeatureFlags usedFeatures) + { + return usedFeatures.HasFlag(FeatureFlags.Shuffle); + } + + public static LinkedListNode RunPass(TransformContext context, LinkedListNode node) + { + Operation operation = (Operation)node.Value; + + HelperFunctionName functionName = operation.Inst switch + { + Instruction.Shuffle => HelperFunctionName.Shuffle, + Instruction.ShuffleDown => HelperFunctionName.ShuffleDown, + Instruction.ShuffleUp => HelperFunctionName.ShuffleUp, + Instruction.ShuffleXor => HelperFunctionName.ShuffleXor, + _ => HelperFunctionName.Invalid, + }; + + if (functionName == HelperFunctionName.Invalid || operation.SourcesCount != 3 || operation.DestsCount != 2) + { + return node; + } + + int functionId = context.Hfm.GetOrCreateShuffleFunctionId(functionName, context.GpuAccessor.QueryHostSubgroupSize()); + + Operand result = operation.GetDest(0); + Operand valid = operation.GetDest(1); + Operand value = operation.GetSource(0); + Operand index = operation.GetSource(1); + Operand mask = operation.GetSource(2); + + operation.Dest = null; + + Operand[] callArgs = new Operand[] { Const(functionId), value, index, mask, valid }; + + LinkedListNode newNode = node.List.AddBefore(node, new Operation(Instruction.Call, 0, result, callArgs)); + + Utils.DeleteNode(node, operation); + + return newNode; + } + } +} diff --git a/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs b/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs index c3bbe7ddf8..2939388079 100644 --- a/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs +++ b/src/Ryujinx.Graphics.Shader/Translation/Transforms/TransformPasses.cs @@ -13,6 +13,7 @@ namespace Ryujinx.Graphics.Shader.Translation.Transforms RunPass(context); RunPass(context); RunPass(context); + RunPass(context); } private static void RunPass(TransformContext context) where T : ITransformPass diff --git a/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs b/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs index e76a332f42..798de5c904 100644 --- a/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs +++ b/src/Ryujinx.Graphics.Vulkan/HardwareCapabilities.cs @@ -25,7 +25,6 @@ namespace Ryujinx.Graphics.Vulkan public readonly bool SupportsIndirectParameters; public readonly bool SupportsFragmentShaderInterlock; public readonly bool SupportsGeometryShaderPassthrough; - public readonly bool SupportsSubgroupSizeControl; public readonly bool SupportsShaderFloat64; public readonly bool SupportsShaderInt8; public readonly bool SupportsShaderStencilExport; @@ -45,9 +44,7 @@ namespace Ryujinx.Graphics.Vulkan public readonly bool SupportsViewportArray2; public readonly bool SupportsHostImportedMemory; public readonly bool SupportsDepthClipControl; - public readonly uint MinSubgroupSize; - public readonly uint MaxSubgroupSize; - public readonly ShaderStageFlags RequiredSubgroupSizeStages; + public readonly uint SubgroupSize; public readonly SampleCountFlags SupportedSampleCounts; public readonly PortabilitySubsetFlags PortabilitySubset; public readonly uint VertexBufferAlignment; @@ -64,7 +61,6 @@ namespace Ryujinx.Graphics.Vulkan bool supportsIndirectParameters, bool supportsFragmentShaderInterlock, bool supportsGeometryShaderPassthrough, - bool supportsSubgroupSizeControl, bool supportsShaderFloat64, bool supportsShaderInt8, bool supportsShaderStencilExport, @@ -84,9 +80,7 @@ namespace Ryujinx.Graphics.Vulkan bool supportsViewportArray2, bool supportsHostImportedMemory, bool supportsDepthClipControl, - uint minSubgroupSize, - uint maxSubgroupSize, - ShaderStageFlags requiredSubgroupSizeStages, + uint subgroupSize, SampleCountFlags supportedSampleCounts, PortabilitySubsetFlags portabilitySubset, uint vertexBufferAlignment, @@ -102,7 +96,6 @@ namespace Ryujinx.Graphics.Vulkan SupportsIndirectParameters = supportsIndirectParameters; SupportsFragmentShaderInterlock = supportsFragmentShaderInterlock; SupportsGeometryShaderPassthrough = supportsGeometryShaderPassthrough; - SupportsSubgroupSizeControl = supportsSubgroupSizeControl; SupportsShaderFloat64 = supportsShaderFloat64; SupportsShaderInt8 = supportsShaderInt8; SupportsShaderStencilExport = supportsShaderStencilExport; @@ -122,9 +115,7 @@ namespace Ryujinx.Graphics.Vulkan SupportsViewportArray2 = supportsViewportArray2; SupportsHostImportedMemory = supportsHostImportedMemory; SupportsDepthClipControl = supportsDepthClipControl; - MinSubgroupSize = minSubgroupSize; - MaxSubgroupSize = maxSubgroupSize; - RequiredSubgroupSizeStages = requiredSubgroupSizeStages; + SubgroupSize = subgroupSize; SupportedSampleCounts = supportedSampleCounts; PortabilitySubset = portabilitySubset; VertexBufferAlignment = vertexBufferAlignment; diff --git a/src/Ryujinx.Graphics.Vulkan/PipelineState.cs b/src/Ryujinx.Graphics.Vulkan/PipelineState.cs index cc9af5b6d9..5a30cff8ec 100644 --- a/src/Ryujinx.Graphics.Vulkan/PipelineState.cs +++ b/src/Ryujinx.Graphics.Vulkan/PipelineState.cs @@ -352,11 +352,6 @@ namespace Ryujinx.Graphics.Vulkan return pipeline; } - if (gd.Capabilities.SupportsSubgroupSizeControl) - { - UpdateStageRequiredSubgroupSizes(gd, 1); - } - var pipelineCreateInfo = new ComputePipelineCreateInfo { SType = StructureType.ComputePipelineCreateInfo, @@ -616,11 +611,6 @@ namespace Ryujinx.Graphics.Vulkan PDynamicStates = dynamicStates, }; - if (gd.Capabilities.SupportsSubgroupSizeControl) - { - UpdateStageRequiredSubgroupSizes(gd, (int)StagesCount); - } - var pipelineCreateInfo = new GraphicsPipelineCreateInfo { SType = StructureType.GraphicsPipelineCreateInfo, @@ -659,19 +649,6 @@ namespace Ryujinx.Graphics.Vulkan return pipeline; } - private readonly unsafe void UpdateStageRequiredSubgroupSizes(VulkanRenderer gd, int count) - { - for (int index = 0; index < count; index++) - { - bool canUseExplicitSubgroupSize = - (gd.Capabilities.RequiredSubgroupSizeStages & Stages[index].Stage) != 0 && - gd.Capabilities.MinSubgroupSize <= RequiredSubgroupSize && - gd.Capabilities.MaxSubgroupSize >= RequiredSubgroupSize; - - Stages[index].PNext = canUseExplicitSubgroupSize ? StageRequiredSubgroupSizes.Pointer + index : null; - } - } - private void UpdateVertexAttributeDescriptions(VulkanRenderer gd) { // Vertex attributes exceeding the stride are invalid. diff --git a/src/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs b/src/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs index 6f73397b80..973c6d396f 100644 --- a/src/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs +++ b/src/Ryujinx.Graphics.Vulkan/VulkanInitialization.cs @@ -37,7 +37,6 @@ namespace Ryujinx.Graphics.Vulkan "VK_EXT_shader_stencil_export", "VK_KHR_shader_float16_int8", "VK_EXT_shader_subgroup_ballot", - "VK_EXT_subgroup_size_control", "VK_NV_geometry_shader_passthrough", "VK_NV_viewport_array2", "VK_EXT_depth_clip_control", diff --git a/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs b/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs index 7848bc8779..6755122933 100644 --- a/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs +++ b/src/Ryujinx.Graphics.Vulkan/VulkanRenderer.cs @@ -151,6 +151,14 @@ namespace Ryujinx.Graphics.Vulkan SType = StructureType.PhysicalDeviceProperties2, }; + PhysicalDeviceSubgroupProperties propertiesSubgroup = new() + { + SType = StructureType.PhysicalDeviceSubgroupProperties, + PNext = properties2.PNext, + }; + + properties2.PNext = &propertiesSubgroup; + PhysicalDeviceBlendOperationAdvancedPropertiesEXT propertiesBlendOperationAdvanced = new() { SType = StructureType.PhysicalDeviceBlendOperationAdvancedPropertiesExt, @@ -164,18 +172,6 @@ namespace Ryujinx.Graphics.Vulkan properties2.PNext = &propertiesBlendOperationAdvanced; } - PhysicalDeviceSubgroupSizeControlPropertiesEXT propertiesSubgroupSizeControl = new() - { - SType = StructureType.PhysicalDeviceSubgroupSizeControlPropertiesExt, - }; - - bool supportsSubgroupSizeControl = _physicalDevice.IsDeviceExtensionPresent("VK_EXT_subgroup_size_control"); - - if (supportsSubgroupSizeControl) - { - properties2.PNext = &propertiesSubgroupSizeControl; - } - bool supportsTransformFeedback = _physicalDevice.IsDeviceExtensionPresent(ExtTransformFeedback.ExtensionName); PhysicalDeviceTransformFeedbackPropertiesEXT propertiesTransformFeedback = new() @@ -315,7 +311,6 @@ namespace Ryujinx.Graphics.Vulkan _physicalDevice.IsDeviceExtensionPresent(KhrDrawIndirectCount.ExtensionName), _physicalDevice.IsDeviceExtensionPresent("VK_EXT_fragment_shader_interlock"), _physicalDevice.IsDeviceExtensionPresent("VK_NV_geometry_shader_passthrough"), - supportsSubgroupSizeControl, features2.Features.ShaderFloat64, featuresShaderInt8.ShaderInt8, _physicalDevice.IsDeviceExtensionPresent("VK_EXT_shader_stencil_export"), @@ -335,9 +330,7 @@ namespace Ryujinx.Graphics.Vulkan _physicalDevice.IsDeviceExtensionPresent("VK_NV_viewport_array2"), _physicalDevice.IsDeviceExtensionPresent(ExtExternalMemoryHost.ExtensionName), supportsDepthClipControl && featuresDepthClipControl.DepthClipControl, - propertiesSubgroupSizeControl.MinSubgroupSize, - propertiesSubgroupSizeControl.MaxSubgroupSize, - propertiesSubgroupSizeControl.RequiredSubgroupSizeStages, + propertiesSubgroup.SubgroupSize, supportedSampleCounts, portabilityFlags, vertexBufferAlignment, @@ -623,6 +616,7 @@ namespace Ryujinx.Graphics.Vulkan maximumImagesPerStage: Constants.MaxImagesPerStage, maximumComputeSharedMemorySize: (int)limits.MaxComputeSharedMemorySize, maximumSupportedAnisotropy: (int)limits.MaxSamplerAnisotropy, + shaderSubgroupSize: (int)Capabilities.SubgroupSize, storageBufferOffsetAlignment: (int)limits.MinStorageBufferOffsetAlignment, gatherBiasPrecision: IsIntelWindows || IsAmdWindows ? (int)Capabilities.SubTexelPrecisionBits : 0); }