From c91f37662898cf2ee03477a0905f8bb803681a8c Mon Sep 17 00:00:00 2001 From: Manon Oomen Date: Tue, 12 May 2026 18:22:10 +0200 Subject: [PATCH] Add Mesh Shader Support to Vulkan. --- lib/API/VK/Device.cpp | 420 +++++++++++++++++++++++++++++++++++------- test/lit.cfg.py | 2 + 2 files changed, 359 insertions(+), 63 deletions(-) diff --git a/lib/API/VK/Device.cpp b/lib/API/VK/Device.cpp index da683603a..39d0b219d 100644 --- a/lib/API/VK/Device.cpp +++ b/lib/API/VK/Device.cpp @@ -400,6 +400,18 @@ struct VulkanInstance { namespace { +struct MeshShaderFunctions { + PFN_vkCmdDrawMeshTasksEXT VkCmdDrawMeshTasksEXT = nullptr; + + static MeshShaderFunctions create(VkDevice Device) { + MeshShaderFunctions Result; + Result.VkCmdDrawMeshTasksEXT = + (PFN_vkCmdDrawMeshTasksEXT)vkGetDeviceProcAddr(Device, + "vkCmdDrawMeshTasksEXT"); + return Result; + } +}; + class VulkanBuffer : public offloadtest::Buffer { public: VkDevice Dev; // Needed for clean-up @@ -569,6 +581,7 @@ class VulkanQueue : public offloadtest::Queue { class VulkanCommandBuffer : public offloadtest::CommandBuffer { public: VkDevice Device = VK_NULL_HANDLE; + MeshShaderFunctions MeshShaderFns; // Owned per command buffer so that recording, submission, and lifetime // management of each command buffer are independently safe without external // synchronization. @@ -586,7 +599,8 @@ class VulkanCommandBuffer : public offloadtest::CommandBuffer { create(VkDevice Device, uint32_t QueueFamilyIdx, PFN_vkCmdBeginDebugUtilsLabelEXT CmdBeginDebugUtilsLabel, PFN_vkCmdEndDebugUtilsLabelEXT CmdEndDebugUtilsLabel, - PFN_vkCmdInsertDebugUtilsLabelEXT CmdInsertDebugUtilsLabel) { + PFN_vkCmdInsertDebugUtilsLabelEXT CmdInsertDebugUtilsLabel, + MeshShaderFunctions MeshShaderFns) { auto CB = std::unique_ptr(new VulkanCommandBuffer()); CB->Device = Device; @@ -618,6 +632,7 @@ class VulkanCommandBuffer : public offloadtest::CommandBuffer { CB->CmdBeginDebugUtilsLabel = CmdBeginDebugUtilsLabel; CB->CmdEndDebugUtilsLabel = CmdEndDebugUtilsLabel; CB->CmdInsertDebugUtilsLabel = CmdInsertDebugUtilsLabel; + CB->MeshShaderFns = MeshShaderFns; return CB; } @@ -931,13 +946,19 @@ class VulkanRenderEncoder : public offloadtest::RenderEncoder { llvm::Error dispatchMesh(const offloadtest::PipelineState &PSO, uint32_t GroupCountX, uint32_t GroupCountY, uint32_t GroupCountZ) override { - (void)PSO; - (void)GroupCountX; - (void)GroupCountY; - (void)GroupCountZ; + if (!ViewportSet) + return llvm::createStringError(std::errc::invalid_argument, + "Viewport must be set before drawing."); + if (!ScissorSet) + return llvm::createStringError(std::errc::invalid_argument, + "Scissor must be set before drawing."); - return llvm::createStringError( - "dispatchMesh is unimplemented in the Vulkan backend."); + const auto &VKPSO = llvm::cast(PSO); + vkCmdBindPipeline(CB.CmdBuffer, VK_PIPELINE_BIND_POINT_GRAPHICS, + VKPSO.Pipeline); + CB.MeshShaderFns.VkCmdDrawMeshTasksEXT(CB.CmdBuffer, GroupCountX, + GroupCountY, GroupCountZ); + return llvm::Error::success(); } void endEncodingImpl() override { @@ -1082,6 +1103,7 @@ class VulkanDevice : public offloadtest::Device { PFN_vkCmdBeginDebugUtilsLabelEXT CmdBeginDebugUtilsLabel = nullptr; PFN_vkCmdEndDebugUtilsLabelEXT CmdEndDebugUtilsLabel = nullptr; PFN_vkCmdInsertDebugUtilsLabelEXT CmdInsertDebugUtilsLabel = nullptr; + MeshShaderFunctions MeshShaderFns; struct BufferRef { VkBuffer Buffer; @@ -1236,6 +1258,31 @@ class VulkanDevice : public offloadtest::Device { #endif vkGetPhysicalDeviceFeatures2(PhysicalDevice, &Features); + const VulkanDevice::ExtensionVector AvailableDeviceExtensions = + queryDeviceExtensions(PhysicalDevice); + + llvm::SmallVector EnabledDeviceExtensions; + const llvm::StringRef ExtensionName = "VK_EXT_mesh_shader"; + VkPhysicalDeviceMeshShaderFeaturesEXT MeshFeatures{}; + if (isExtensionSupported(AvailableDeviceExtensions, ExtensionName)) { + EnabledDeviceExtensions.push_back(ExtensionName.data()); + MeshFeatures.sType = + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MESH_SHADER_FEATURES_EXT; + MeshFeatures.taskShader = 1; + MeshFeatures.meshShader = 1; + MeshFeatures.multiviewMeshShader = 0; + MeshFeatures.primitiveFragmentShadingRateMeshShader = 0; + MeshFeatures.meshShaderQueries = 0; +#ifdef VK_VERSION_1_4 + Features14.pNext = &MeshFeatures; +#else + Features13.pNext = &MeshFeatures; +#endif + } + + DeviceInfo.enabledExtensionCount = + static_cast(EnabledDeviceExtensions.size()); + DeviceInfo.ppEnabledExtensionNames = EnabledDeviceExtensions.data(); DeviceInfo.pEnabledFeatures = &Features.features; DeviceInfo.pNext = Features.pNext; @@ -1253,16 +1300,18 @@ class VulkanDevice : public offloadtest::Device { VulkanQueue GraphicsQueue(DeviceQueue, QueueFamilyIdx, Device, std::move(*SubmitFenceOrErr)); - return std::make_unique(Instance, PhysicalDevice, Props, - Device, std::move(GraphicsQueue), - std::move(InstanceLayers)); + return std::make_unique( + Instance, PhysicalDevice, Props, Device, std::move(GraphicsQueue), + std::move(InstanceLayers), std::move(AvailableDeviceExtensions)); } VulkanDevice(std::shared_ptr I, VkPhysicalDevice P, VkPhysicalDeviceProperties Props, VkDevice D, VulkanQueue Q, - llvm::SmallVector InstanceLayers) + llvm::SmallVector InstanceLayers, + ExtensionVector DeviceExtensions) : Instance(I), PhysicalDevice(P), Props(Props), Device(D), - GraphicsQueue(std::move(Q)), InstanceLayers(std::move(InstanceLayers)) { + GraphicsQueue(std::move(Q)), InstanceLayers(std::move(InstanceLayers)), + DeviceExtensions(std::move(DeviceExtensions)) { const uint64_t DeviceNameSz = strnlen(Props.deviceName, VK_MAX_PHYSICAL_DEVICE_NAME_SIZE); Description = std::string(Props.deviceName, DeviceNameSz); @@ -1292,8 +1341,6 @@ class VulkanDevice : public offloadtest::Device { Description += " (" + DriverName + ")"; #endif - DeviceExtensions = queryDeviceExtensions(PhysicalDevice); - CmdBeginDebugUtilsLabel = (PFN_vkCmdBeginDebugUtilsLabelEXT)vkGetDeviceProcAddr( Device, "vkCmdBeginDebugUtilsLabelEXT"); @@ -1302,6 +1349,8 @@ class VulkanDevice : public offloadtest::Device { CmdInsertDebugUtilsLabel = (PFN_vkCmdInsertDebugUtilsLabelEXT)vkGetDeviceProcAddr( Device, "vkCmdInsertDebugUtilsLabelEXT"); + + MeshShaderFns = MeshShaderFunctions::create(Device); } VulkanDevice(const VulkanDevice &) = delete; @@ -1707,6 +1756,205 @@ class VulkanDevice : public offloadtest::Device { Name, Device, Pipeline, PipelineLayout, std::move(SetLayouts)); } + llvm::Expected> + createPipelineAsMsPs(llvm::StringRef Name, const BindingsDesc &BindingsDesc, + llvm::ArrayRef RTFormats, + std::optional DSFormat, + std::optional AS, ShaderContainer MS, + std::optional PS) /*override*/ { + assert(RTFormats.size() <= 8); + + VkShaderStageFlags GraphicsFlags = VK_SHADER_STAGE_MESH_BIT_EXT; + llvm::SmallVector ShaderStages; + // No longer need shader modules after pipeline compilation. + auto ShaderModuleCleanUp = llvm::scope_exit([&] { + for (auto &Stage : ShaderStages) + vkDestroyShaderModule(Device, Stage.module, nullptr); + }); + + llvm::SmallVector MSSpecEntries; + llvm::SmallVector MSSpecData; + VkSpecializationInfo MSSpecInfo = {}; + { + if (auto Err = parseSpecializationConstants(MS.SpecializationConstants, + MSSpecEntries, MSSpecData, + MSSpecInfo)) + return Err; + + auto MSModOrErr = createShaderModule(MS.Shader, "mesh"); + if (!MSModOrErr) + return MSModOrErr.takeError(); + + VkPipelineShaderStageCreateInfo ShaderStage = {}; + ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + ShaderStage.stage = VK_SHADER_STAGE_MESH_BIT_EXT; + ShaderStage.module = *MSModOrErr; + ShaderStage.pName = MS.EntryPoint.c_str(); + ShaderStage.pSpecializationInfo = + MS.SpecializationConstants.empty() ? nullptr : &MSSpecInfo; + ShaderStages.push_back(ShaderStage); + } + + llvm::SmallVector ASSpecEntries; + llvm::SmallVector ASSpecData; + VkSpecializationInfo ASSpecInfo = {}; + if (AS) { + if (auto Err = parseSpecializationConstants((*AS).SpecializationConstants, + ASSpecEntries, ASSpecData, + ASSpecInfo)) + return Err; + + auto ASModOrErr = createShaderModule((*AS).Shader, "task"); + if (!ASModOrErr) + return ASModOrErr.takeError(); + + GraphicsFlags |= VK_SHADER_STAGE_TASK_BIT_EXT; + + VkPipelineShaderStageCreateInfo ShaderStage = {}; + ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + ShaderStage.stage = VK_SHADER_STAGE_TASK_BIT_EXT; + ShaderStage.module = *ASModOrErr; + ShaderStage.pName = (*AS).EntryPoint.c_str(); + ShaderStage.pSpecializationInfo = + (*AS).SpecializationConstants.empty() ? nullptr : &ASSpecInfo; + ShaderStages.push_back(ShaderStage); + } + + llvm::SmallVector PSSpecEntries; + llvm::SmallVector PSSpecData; + VkSpecializationInfo PSSpecInfo = {}; + if (PS) { + if (auto Err = parseSpecializationConstants((*PS).SpecializationConstants, + PSSpecEntries, PSSpecData, + PSSpecInfo)) + return Err; + + auto PSModOrErr = createShaderModule((*PS).Shader, "pixel"); + if (!PSModOrErr) + return PSModOrErr.takeError(); + + GraphicsFlags |= VK_SHADER_STAGE_FRAGMENT_BIT; + + VkPipelineShaderStageCreateInfo ShaderStage = {}; + ShaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + ShaderStage.stage = VK_SHADER_STAGE_FRAGMENT_BIT; + ShaderStage.module = *PSModOrErr; + ShaderStage.pName = (*PS).EntryPoint.c_str(); + ShaderStage.pSpecializationInfo = + (*PS).SpecializationConstants.empty() ? nullptr : &PSSpecInfo; + ShaderStages.push_back(ShaderStage); + } + + // Build a RenderPassDesc from the PSO's RT/DS formats. + RenderPassDesc PassDesc; + PassDesc.ColorAttachments.reserve(RTFormats.size()); + for (const Format F : RTFormats) { + ColorAttachmentFormatDesc CA = {}; + CA.Fmt = F; + CA.Load = LoadAction::DontCare; + CA.Store = StoreAction::DontCare; + PassDesc.ColorAttachments.push_back(CA); + } + if (DSFormat) { + DepthStencilAttachmentFormatDesc DS = {}; + DS.Fmt = *DSFormat; + DS.DepthLoad = LoadAction::DontCare; + DS.DepthStore = StoreAction::DontCare; + DS.StencilLoad = LoadAction::DontCare; + DS.StencilStore = StoreAction::DontCare; + PassDesc.DepthStencil = DS; + } + + // NOTE: After pipeline creation this render pass can be dropped. Later + // render passes just need to be compatible with this render pass, or in + // other words: the format, sample count and number of targets (rt and ds), + // need to match. + auto RenderPassOrErr = createRenderPass(PassDesc); + if (!RenderPassOrErr) + return RenderPassOrErr.takeError(); + const std::unique_ptr RenderPass = + std::move(*RenderPassOrErr); + VkRenderPass RenderPassHandle = + llvm::cast(*RenderPass).Handle; + + llvm::SmallVector SetLayouts; + VkPipelineLayout PipelineLayout = VK_NULL_HANDLE; + if (auto Err = createPipelineLayout(BindingsDesc, GraphicsFlags, SetLayouts, + PipelineLayout)) + return Err; + + VkPipelineViewportStateCreateInfo ViewportCI = {}; + ViewportCI.sType = VK_STRUCTURE_TYPE_PIPELINE_VIEWPORT_STATE_CREATE_INFO; + ViewportCI.viewportCount = 1; + ViewportCI.scissorCount = 1; + + const VkDynamicState DynStates[] = {VK_DYNAMIC_STATE_VIEWPORT, + VK_DYNAMIC_STATE_SCISSOR}; + VkPipelineDynamicStateCreateInfo DynamicCI = {}; + DynamicCI.sType = VK_STRUCTURE_TYPE_PIPELINE_DYNAMIC_STATE_CREATE_INFO; + DynamicCI.dynamicStateCount = 2; + DynamicCI.pDynamicStates = DynStates; + + VkPipelineRasterizationStateCreateInfo RastCI = {}; + RastCI.sType = VK_STRUCTURE_TYPE_PIPELINE_RASTERIZATION_STATE_CREATE_INFO; + RastCI.polygonMode = VK_POLYGON_MODE_FILL; + RastCI.cullMode = VK_CULL_MODE_NONE; + RastCI.frontFace = VK_FRONT_FACE_COUNTER_CLOCKWISE; + RastCI.lineWidth = 1.0f; + + VkPipelineMultisampleStateCreateInfo MultisampleCI = {}; + MultisampleCI.sType = + VK_STRUCTURE_TYPE_PIPELINE_MULTISAMPLE_STATE_CREATE_INFO; + MultisampleCI.rasterizationSamples = VK_SAMPLE_COUNT_1_BIT; + + VkPipelineDepthStencilStateCreateInfo DepthStencilCI = {}; + DepthStencilCI.sType = + VK_STRUCTURE_TYPE_PIPELINE_DEPTH_STENCIL_STATE_CREATE_INFO; + DepthStencilCI.depthTestEnable = VK_TRUE; + DepthStencilCI.depthWriteEnable = VK_TRUE; + DepthStencilCI.depthCompareOp = VK_COMPARE_OP_LESS_OR_EQUAL; + DepthStencilCI.back.failOp = VK_STENCIL_OP_KEEP; + DepthStencilCI.back.passOp = VK_STENCIL_OP_KEEP; + DepthStencilCI.back.compareOp = VK_COMPARE_OP_ALWAYS; + DepthStencilCI.front = DepthStencilCI.back; + + llvm::SmallVector BlendAttachments( + RTFormats.size()); + for (auto &BA : BlendAttachments) + BA.colorWriteMask = 0xf; + VkPipelineColorBlendStateCreateInfo BlendCI = {}; + BlendCI.sType = VK_STRUCTURE_TYPE_PIPELINE_COLOR_BLEND_STATE_CREATE_INFO; + BlendCI.attachmentCount = static_cast(BlendAttachments.size()); + BlendCI.pAttachments = BlendAttachments.data(); + + VkGraphicsPipelineCreateInfo PipelineCI = {}; + PipelineCI.sType = VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_CREATE_INFO; + PipelineCI.stageCount = static_cast(ShaderStages.size()); + PipelineCI.pStages = ShaderStages.data(); + PipelineCI.pViewportState = &ViewportCI; + PipelineCI.pRasterizationState = &RastCI; + PipelineCI.pMultisampleState = &MultisampleCI; + PipelineCI.pDepthStencilState = &DepthStencilCI; + PipelineCI.pColorBlendState = &BlendCI; + PipelineCI.pDynamicState = &DynamicCI; + PipelineCI.layout = PipelineLayout; + PipelineCI.renderPass = RenderPassHandle; + + VkPipeline Pipeline = VK_NULL_HANDLE; + if (auto Err = VK::toError(vkCreateGraphicsPipelines(Device, VK_NULL_HANDLE, + 1, &PipelineCI, + nullptr, &Pipeline), + "Failed to create mesh shader pipeline.")) { + vkDestroyPipelineLayout(Device, PipelineLayout, nullptr); + for (auto *L : SetLayouts) + vkDestroyDescriptorSetLayout(Device, L, nullptr); + return Err; + } + + return std::make_unique( + Name, Device, Pipeline, PipelineLayout, std::move(SetLayouts)); + } + llvm::Expected> createFence(llvm::StringRef Name) override { return VulkanFence::create(Device, Name); @@ -1940,7 +2188,7 @@ class VulkanDevice : public offloadtest::Device { createCommandBuffer() override { return VulkanCommandBuffer::create( Device, GraphicsQueue.QueueFamilyIdx, CmdBeginDebugUtilsLabel, - CmdEndDebugUtilsLabel, CmdInsertDebugUtilsLabel); + CmdEndDebugUtilsLabel, CmdInsertDebugUtilsLabel, MeshShaderFns); } llvm::Expected> @@ -2998,7 +3246,7 @@ class VulkanDevice : public offloadtest::Device { << P.DispatchParameters.DispatchGroupCount[0] << ", " << P.DispatchParameters.DispatchGroupCount[1] << ", " << P.DispatchParameters.DispatchGroupCount[2] << " }\n"; - } else if (P.isTraditionalRaster()) { + } else if (P.isRaster()) { RenderPassBeginDesc BeginDesc = {}; BeginDesc.Pass = IS.RenderPass.get(); BeginDesc.ColorAttachments.push_back(IS.RenderTarget.get()); @@ -3021,14 +3269,22 @@ class VulkanDevice : public offloadtest::Device { Scissor.Height = static_cast(VP.Height); Encoder.setScissor(Scissor); - if (IS.VB) - Encoder.setVertexBuffer(0, IS.VB.get(), 0, - P.Bindings.getVertexStride()); + if (P.isTraditionalRaster()) { + if (IS.VB) + Encoder.setVertexBuffer(0, IS.VB.get(), 0, + P.Bindings.getVertexStride()); - if (auto Err = - Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(), - /*InstanceCount=*/1)) - return Err; + if (auto Err = + Encoder.drawInstanced(*IS.Pipeline.get(), P.getVertexCount(), + /*InstanceCount=*/1)) + return Err; + } else if (P.isMeshShaderRaster()) { + if (auto Err = Encoder.dispatchMesh( + *IS.Pipeline.get(), P.DispatchParameters.DispatchGroupCount[0], + P.DispatchParameters.DispatchGroupCount[1], + P.DispatchParameters.DispatchGroupCount[2])) + return Err; + } Encoder.endEncoding(); copyTextureToReadback(IS.CB->CmdBuffer, @@ -3162,7 +3418,7 @@ class VulkanDevice : public offloadtest::Device { auto CBOrErr = VulkanCommandBuffer::create( Device, GraphicsQueue.QueueFamilyIdx, CmdBeginDebugUtilsLabel, - CmdEndDebugUtilsLabel, CmdInsertDebugUtilsLabel); + CmdEndDebugUtilsLabel, CmdInsertDebugUtilsLabel, MeshShaderFns); if (!CBOrErr) return CBOrErr.takeError(); State.CB = std::move(*CBOrErr); @@ -3226,44 +3482,7 @@ class VulkanDevice : public offloadtest::Device { return PipelineStateOrErr.takeError(); State.Pipeline = std::move(*PipelineStateOrErr); llvm::outs() << "Compute Pipeline created.\n"; - } else if (P.isTraditionalRaster()) { - TraditionalRasterPipelineCreateDesc PipelineDesc = {}; - PipelineDesc.Topology = P.Bindings.Topology; - PipelineDesc.DSFormat = Format::D32FloatS8Uint; - for (auto &Shader : P.Shaders) { - ShaderContainer SC = {}; - SC.EntryPoint = Shader.Entry; - SC.Shader = Shader.Shader.get(); - SC.SpecializationConstants = Shader.SpecializationConstants; - PipelineDesc.setShader(Shader.Stage, std::move(SC)); - } - - // Create the input layout based on the vertex attributes. - for (auto &Attr : P.Bindings.VertexAttributes) { - auto FormatOrErr = toFormat(Attr.Format, Attr.Channels); - if (!FormatOrErr) - return FormatOrErr.takeError(); - - InputLayoutDesc Layout = {}; - Layout.Name = Attr.Name; - Layout.Fmt = *FormatOrErr; - Layout.OffsetInBytes = Attr.Offset; - PipelineDesc.InputLayout.push_back(Layout); - } - - auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format, - P.Bindings.RTargetBufferPtr->Channels); - if (!FormatOrErr) - return FormatOrErr.takeError(); - PipelineDesc.RTFormats.push_back(*FormatOrErr); - - auto PipelineStateOrErr = createTraditionalRasterPipeline( - "Graphics Pipeline State", BindingsDesc, PipelineDesc); - if (!PipelineStateOrErr) - return PipelineStateOrErr.takeError(); - State.Pipeline = std::move(*PipelineStateOrErr); - llvm::outs() << "Graphics Pipeline created.\n"; - + } else if (P.isRaster()) { ColorAttachmentFormatDesc ColorAttachment = {}; ColorAttachment.Fmt = State.RenderTarget->getDesc().Fmt; ColorAttachment.Load = LoadAction::Clear; @@ -3286,6 +3505,81 @@ class VulkanDevice : public offloadtest::Device { State.RenderPass = std::move(*RenderPassOrErr); llvm::outs() << "Render pass created.\n"; + if (P.isTraditionalRaster()) { + TraditionalRasterPipelineCreateDesc PipelineDesc = {}; + PipelineDesc.Topology = P.Bindings.Topology; + PipelineDesc.DSFormat = Format::D32FloatS8Uint; + for (auto &Shader : P.Shaders) { + ShaderContainer SC = {}; + SC.EntryPoint = Shader.Entry; + SC.Shader = Shader.Shader.get(); + SC.SpecializationConstants = Shader.SpecializationConstants; + PipelineDesc.setShader(Shader.Stage, std::move(SC)); + } + + // Create the input layout based on the vertex attributes. + for (auto &Attr : P.Bindings.VertexAttributes) { + auto FormatOrErr = toFormat(Attr.Format, Attr.Channels); + if (!FormatOrErr) + return FormatOrErr.takeError(); + + InputLayoutDesc Layout = {}; + Layout.Name = Attr.Name; + Layout.Fmt = *FormatOrErr; + Layout.OffsetInBytes = Attr.Offset; + PipelineDesc.InputLayout.push_back(Layout); + } + + auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format, + P.Bindings.RTargetBufferPtr->Channels); + if (!FormatOrErr) + return FormatOrErr.takeError(); + PipelineDesc.RTFormats.push_back(*FormatOrErr); + + auto PipelineStateOrErr = createTraditionalRasterPipeline( + "Graphics Pipeline State", BindingsDesc, PipelineDesc); + if (!PipelineStateOrErr) + return PipelineStateOrErr.takeError(); + State.Pipeline = std::move(*PipelineStateOrErr); + llvm::outs() << "Graphics Pipeline created.\n"; + } else if (P.isMeshShaderRaster()) { + std::optional AS = {}; + ShaderContainer MS = {}; + std::optional PS = {}; + for (auto &Shader : P.Shaders) { + if (Shader.Stage == Stages::Amplification) { + ShaderContainer Container; + Container.EntryPoint = Shader.Entry; + Container.Shader = Shader.Shader.get(); + AS = Container; + } else if (Shader.Stage == Stages::Mesh) { + MS.EntryPoint = Shader.Entry; + MS.Shader = Shader.Shader.get(); + } else if (Shader.Stage == Stages::Pixel) { + ShaderContainer Container; + Container.EntryPoint = Shader.Entry; + Container.Shader = Shader.Shader.get(); + PS = Container; + } + } + + auto FormatOrErr = toFormat(P.Bindings.RTargetBufferPtr->Format, + P.Bindings.RTargetBufferPtr->Channels); + if (!FormatOrErr) + return FormatOrErr.takeError(); + + llvm::SmallVector RTFormats; + RTFormats.push_back(*FormatOrErr); + + auto PipelineStateOrErr = + createPipelineAsMsPs("Mesh Shader Pipeline State", BindingsDesc, + RTFormats, Format::D32FloatS8Uint, AS, MS, PS); + if (!PipelineStateOrErr) + return PipelineStateOrErr.takeError(); + State.Pipeline = std::move(*PipelineStateOrErr); + llvm::outs() << "Mesh Shader Pipeline created.\n"; + } + if (auto Err = createFrameBuffer(State)) return Err; llvm::outs() << "Frame buffer created.\n"; @@ -3303,7 +3597,7 @@ class VulkanDevice : public offloadtest::Device { llvm::outs() << "Executed copy command buffer.\n"; auto DispatchCBOrErr = VulkanCommandBuffer::create( Device, GraphicsQueue.QueueFamilyIdx, CmdBeginDebugUtilsLabel, - CmdEndDebugUtilsLabel, CmdInsertDebugUtilsLabel); + CmdEndDebugUtilsLabel, CmdInsertDebugUtilsLabel, MeshShaderFns); if (!DispatchCBOrErr) return DispatchCBOrErr.takeError(); State.CB = std::move(*DispatchCBOrErr); diff --git a/test/lit.cfg.py b/test/lit.cfg.py index cc59667d7..5f92b6a2e 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -174,6 +174,8 @@ def setDeviceFeatures(config, device, compiler): # Add supported extensions. for Extension in device["Extensions"]: config.available_features.add(Extension["ExtensionName"]) + if Extension["ExtensionName"] == "VK_EXT_mesh_shader": + config.available_features.add("MeshShader") offloader_args = []