//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Copyright 2023 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------


#include "ShaderPipelineBuilder.hpp"

#include "ShaderCompiler.hpp"
#include <simd/simd.h>

namespace shader_pipeline
{
    // Describes a single DXIL entry point that gets compiled into a Metal function.
    struct DXILFunctionDescriptor
    {
        // Compilation mode for the entry point. Within this file only FuseFunction
        // is inspected (while gathering any-hit intrinsic masks); the precise meaning
        // of each option is presumably defined by the shader compilation helpers — TODO confirm.
        enum class CompilationOption
        {
            Skip,
            SingleFunction,
            FuseFunction
        };
        
        // Entry point name inside the DXIL object; passed to the library compilation helper.
        const char*       entryPointName;
        // Shader stage the entry point is compiled as.
        IRShaderStage     shaderStage;
        // See CompilationOption above.
        CompilationOption compilationOption;
        // Optional second entry point forwarded to the library compilation helper (may be nullptr).
        const char*       fuseEntryPointName;
        // Name handed to IRCompilerSetEntryPointName before this function is compiled.
        const char*       renameEntryPoint;
    };

    // Groups the functions of one hit group together with the DXIL object they come from.
    struct DXILHitGroupDescriptor
    {
        // Functions keyed by the slot they occupy in the compiled-function vector.
        std::unordered_map<uint32_t, DXILFunctionDescriptor> functions;
        // Hit group type applied to the compiler before compiling the functions above.
        IRHitGroupType                      hitGroupType;
        // Forwarded to IRCompilerSetRayTracingPipelineArguments.
        uint32_t                            maxAttributeSizeInBytes;
        // DXIL object containing the entry points; destroyed by newTriangleSphereRTPipeline
        // once compilation has finished.
        IRObject*                           pIR;
    };

    // File-local helpers (internal linkage); their definitions follow below.
    static IRRootSignature* newRTComputeGlobalRootSignature();
    static IRRootSignature* newRTComputeLocalRootSignature();
    static std::vector<DXILHitGroupDescriptor> getRaytracingShaders(const std::string& shaderSearchPath);
    static std::vector<MTL::Function*>         compileRayTracingFunctions(const std::vector<DXILHitGroupDescriptor>& sourceFunctions, const IRRootSignature* pGlobalRootSignature, const IRRootSignature* pLocalRootSignature, MTL::Device* pDevice);
    static std::vector<MTL::Function*>         synthesizeIndirectIntersectionFunctions(MTL::Device* pDevice);
    static MTL::ComputePipelineState*          newRaytracingPipeline(const std::vector<MTL::Function*>& rayTracingFunctions, const std::vector<MTL::Function*>& intersectionFunctions, MTL::Device* pDevice);
    static MTL::VisibleFunctionTable*          newVisibleFunctionTable(const std::vector<MTL::Function*>& raytracingFunctions, MTL::ComputePipelineState* pRTPSO);
    static MTL::IntersectionFunctionTable*     newIntersectionFunctionTable(const std::vector<MTL::Function*>& indirectIntersectionFunctions, MTL::ComputePipelineState* pRTPSO);
}

static IRRootSignature* shader_pipeline::newRTComputeGlobalRootSignature()
{
    // Builds the global root signature for the ray tracing compute pipeline:
    //   parameter 0: root SRV at register 0, space 0 (acceleration structure)
    //   parameter 1: descriptor table holding one UAV at register 0, space 1 (RW texture)
    // Returns nullptr (after logging the converter error) when creation fails.

    IRDescriptorRange1 outputTextureRange = {
        .RangeType                         = IRDescriptorRangeTypeUAV,
        .NumDescriptors                    = 1,
        .BaseShaderRegister                = 0,
        .RegisterSpace                     = 1,
        .OffsetInDescriptorsFromTableStart = 0
    };

    IRRootParameter1 globalParams[] = {
        {
            /* Acceleration Structure */
            .ParameterType    = IRRootParameterTypeSRV,
            .ShaderVisibility = IRShaderVisibilityAll,
            .Descriptor       = {
                .ShaderRegister = 0,
                .RegisterSpace  = 0,
                .Flags          = IRRootDescriptorFlagDataStatic
            }
        },
        {
            /* RW Texture */
            .ParameterType    = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityAll,
            .DescriptorTable  = {
                .NumDescriptorRanges = 1,
                .pDescriptorRanges   = &outputTextureRange
            }
        }
    };

    // Start from an all-zero descriptor so every field not set below is cleared.
    IRVersionedRootSignatureDescriptor globalDesc;
    memset(&globalDesc, 0x0, sizeof(globalDesc));
    globalDesc.version                = IRRootSignatureVersion_1_1;
    globalDesc.desc_1_1.NumParameters = sizeof(globalParams) / sizeof(globalParams[0]);
    globalDesc.desc_1_1.pParameters   = globalParams;

    IRError* pError = nullptr;
    IRRootSignature* pGlobalRootSig = IRRootSignatureCreateFromDescriptor(&globalDesc, &pError);
    if (pGlobalRootSig)
    {
        return pGlobalRootSig;
    }

    // Creation failed: report the payload and dispose of the error object.
    printf("Error creating GRS: %s\n", (const char *)IRErrorGetPayload(pError));
    IRErrorDestroy(pError);
    return pGlobalRootSig;
}

static IRRootSignature* shader_pipeline::newRTComputeLocalRootSignature()
{
    // Builds the local root signature: a single 32-bit root constant bound at
    // register 0, space 16. Returns nullptr (after logging) when creation fails.

    IRRootParameter1 constantParams[] = {
        {
            .ParameterType    = IRRootParameterType32BitConstants,
            .ShaderVisibility = IRShaderVisibilityAll,
            .Constants        = {
                .ShaderRegister = 0,
                .RegisterSpace  = 16,
                .Num32BitValues = 1
            }
        }
    };

    // Start from an all-zero descriptor so every field not set below is cleared.
    IRVersionedRootSignatureDescriptor localDesc;
    memset(&localDesc, 0x0, sizeof(localDesc));
    localDesc.version                = IRRootSignatureVersion_1_1;
    localDesc.desc_1_1.NumParameters = sizeof(constantParams) / sizeof(constantParams[0]);
    localDesc.desc_1_1.pParameters   = constantParams;

    IRError* pError = nullptr;
    IRRootSignature* pSignature = IRRootSignatureCreateFromDescriptor(&localDesc, &pError);
    if (pSignature)
    {
        return pSignature;
    }

    // Creation failed: report the payload and dispose of the error object.
    printf("Error creating LRS: %s\n", (const char *)IRErrorGetPayload(pError));
    IRErrorDestroy(pError);
    return pSignature;
}

static std::vector<shader_pipeline::DXILHitGroupDescriptor> shader_pipeline::getRaytracingShaders(const std::string& shaderSearchPath)
{
    std::string triangleRTLibraryPath = shaderSearchPath + "/rt_triangle_pipeline.dxil";
    IRObject* pTriangleDXIL = newDXILObject(triangleRTLibraryPath);
    assert(pTriangleDXIL);
    
    std::string sphereRTLibraryPath = shaderSearchPath + "/rt_procedural_pipeline.dxil";
    IRObject* pSphereDXIL = newDXILObject(sphereRTLibraryPath);
    assert(pSphereDXIL);

    
    using shader_pipeline::DXILHitGroupDescriptor;
    using shader_pipeline::DXILFunctionDescriptor;
    using Opt=shader_pipeline::DXILFunctionDescriptor::CompilationOption;
    
    std::vector<DXILHitGroupDescriptor> raytracingShaders {
        DXILHitGroupDescriptor {
            .functions = {
                {kSphereRayGenIndex,             { "MainRayGen",         IRShaderStageRayGeneration, Opt::SingleFunction, nullptr,        "SphereRayGen"       }},
                {kSphereIntersectionIndex,       { "SphereIntersection", IRShaderStageIntersection,  Opt::FuseFunction,   nullptr,        "SphereIntersection" }},
                {kSphereClosestHitIndex,         { "SphereClosestHit",   IRShaderStageClosestHit,    Opt::SingleFunction, nullptr,        "SphereClosestHit"   }},
                {kSphereMissIndex,               { "MainMiss",           IRShaderStageMiss,          Opt::SingleFunction, nullptr,        "SphereMiss"         }},
                {kSphereIntersectionAnyHitIndex, { "SphereIntersection", IRShaderStageAnyHit,        Opt::SingleFunction, "SphereAnyHit", "SphereAnyHit"       }}
            },
            .hitGroupType = IRHitGroupTypeProceduralPrimitive,
            .maxAttributeSizeInBytes = 16,
            .pIR = pSphereDXIL,
        },
        DXILHitGroupDescriptor {
            .functions = {
                {kTriangleRayGenIndex,     { "MainRayGen",         IRShaderStageRayGeneration, Opt::SingleFunction, nullptr,        "TriangleRayGen"    }},
                {kTriangleClosestHitIndex, { "TriangleClosestHit", IRShaderStageClosestHit,    Opt::SingleFunction, nullptr,        "TriangleClosestHit"}},
                {kTriangleMissIndex,       { "MainMiss",           IRShaderStageMiss,          Opt::SingleFunction, nullptr,        "TriangleMiss"      }},
                {kTriangleAnyHitIndex,     { "TriangleAnyHit",     IRShaderStageAnyHit,        Opt::SingleFunction, nullptr,        "TriangleAnyHit"    }}
            },
            .hitGroupType = IRHitGroupTypeTriangles,
            .maxAttributeSizeInBytes = 16,
            .pIR = pTriangleDXIL
        }
    };

    return raytracingShaders;
}

// Compiles every function of every hit group in sourceFunctions into a Metal function.
//
// The returned vector is indexed by the key each function has in its hit group's
// `functions` map; slots that no function maps to are left as nullptr. Each non-null
// entry is an owning reference the caller must release.
//
// pGlobalRootSignature / pLocalRootSignature are installed on a temporary IRCompiler
// that is destroyed before returning. pDevice is forwarded to the library helper.
static std::vector<MTL::Function*> shader_pipeline::compileRayTracingFunctions(const std::vector<shader_pipeline::DXILHitGroupDescriptor>& sourceFunctions, const IRRootSignature* pGlobalRootSignature, const IRRootSignature* pLocalRootSignature, MTL::Device* pDevice)
{
    using shader_pipeline::DXILHitGroupDescriptor;
    using shader_pipeline::DXILFunctionDescriptor;
    
    std::vector<MTL::Function*> compiledFunctions;
    
    IRCompiler* pCompiler = IRCompilerCreate();
    IRCompilerSetMinimumDeploymentTarget(pCompiler, IROperatingSystem_macOS, "14.0.0");
    IRCompilerSetGlobalRootSignature(pCompiler, pGlobalRootSignature);
    IRCompilerSetLocalRootSignature(pCompiler, pLocalRootSignature);

    
    // Gather intrinsic masks to tailor code generation to what's needed
    // by the shaders to operate together, improving runtime performance.
    
    uint64_t chsMask    = 0;
    uint64_t missMask   = 0;
    uint64_t anyHitMask = 0;
    
    for (const auto& hg : sourceFunctions)
    {
        for (const auto& [index, function] : hg.functions)
        {
            switch (function.shaderStage)
            {
                case IRShaderStageClosestHit:
                    chsMask |= IRObjectGatherRaytracingIntrinsics(hg.pIR, function.entryPointName);
                    break;
                case IRShaderStageMiss:
                    missMask |= IRObjectGatherRaytracingIntrinsics(hg.pIR, function.entryPointName);
                    break;
                case IRShaderStageAnyHit:
                    // NOTE(review): only any-hit descriptors marked FuseFunction contribute to the
                    // mask. If no any-hit descriptor supplied by the caller uses FuseFunction, the
                    // any-hit mask stays 0 — confirm that this is intended for those shaders.
                    if (function.compilationOption == shader_pipeline::DXILFunctionDescriptor::CompilationOption::FuseFunction)
                    {
                        assert(function.fuseEntryPointName);
                        anyHitMask |= IRObjectGatherRaytracingIntrinsics(hg.pIR, function.fuseEntryPointName);
                    }
                    break;
                case IRShaderStageCallable:
                    // If this sample contained callable shaders, here it would also OR its intrinsic mask.
                    break;
                default:
                    break;
            }
        }
    }
        
    for (const auto& hg : sourceFunctions)
    {
        // The same masks are applied to every hit group so all functions agree on them.
        IRCompilerSetRayTracingPipelineArguments(pCompiler,
                                                 hg.maxAttributeSizeInBytes,
                                                 IRRaytracingPipelineFlagNone,
                                                 chsMask,
                                                 missMask,
                                                 anyHitMask,
                                                 ~0,
                                                 -1,
                                                 IRRayGenerationCompilationVisibleFunction,
                                                 IRIntersectionFunctionCompilationVisibleFunction);
        
        IRCompilerSetHitgroupType(pCompiler, hg.hitGroupType);
        
        for ( const auto& [index, function] : hg.functions )
        {
            IRCompilerSetEntryPointName(pCompiler, function.renameEntryPoint);
            
            // Take ownership of the library so it is released at the end of this iteration;
            // previously the library reference was never released and leaked once per function.
            NS::SharedPtr<MTL::Library> pLib = NS::TransferPtr(newLibraryFromDXILUsingCompiler(hg.pIR, function.shaderStage, function.entryPointName, pCompiler, pDevice, function.fuseEntryPointName));
            assert(pLib);
            
            NS::Array* pFunctionNames = pLib->functionNames();
            assert(pFunctionNames->count() > 0);
            
            // If the metallib only has one function, use that one, otherwise, rely on the
            // entry point name to disambiguate.
            const char* entryPoint = function.entryPointName;
            
            if (pFunctionNames->count() == 1)
            {
                entryPoint = reinterpret_cast<NS::String *>(pFunctionNames->object(0))->utf8String();
            }

            // The function is created while the library (and the name string) are still alive.
            MTL::Function* pFn = newFunction(pLib.get(), entryPoint);
            assert(pFn);
            
            if (index >= compiledFunctions.size())
            {
                compiledFunctions.resize(index + 1);
            }
            compiledFunctions[index] = pFn; // ensure the index in compiledFunctions vector matches the constexpr index from the header, as the ShaderIndentifiers reference visible functions by index.
        }
    }

    IRCompilerDestroy(pCompiler);
    return compiledFunctions;
}

// Creates the synthesized indirect intersection functions (procedural/box and triangle).
//
// Returns a vector of kMaxIntersectionFunctions entries in which the slots
// kSBTBoxIntersectionFunctionIndex and kSBTTriangleIntersectionFunctionIndex hold
// owning function references; any other slot stays nullptr. The caller releases
// the functions.
static std::vector<MTL::Function*> shader_pipeline::synthesizeIndirectIntersectionFunctions(MTL::Device* pDevice)
{
    std::vector<MTL::Function*> intersectionFunctions(kMaxIntersectionFunctions);
    
    // Procedural (box) intersection. The library is validated *before* it is used;
    // the original code dereferenced it first and asserted afterwards.
    MTL::Library* pIndirectBoxIntersectionLib = newSBTIntersectionFunction(IRHitGroupTypeProceduralPrimitive, pDevice);
    assert(pIndirectBoxIntersectionLib);
    MTL::Function* pIndirectBoxIntersectionFn = pIndirectBoxIntersectionLib->newFunction(NS::String::string(kIRIndirectProceduralIntersectionFunctionName, NS::UTF8StringEncoding));
    assert(pIndirectBoxIntersectionFn);
    intersectionFunctions[kSBTBoxIntersectionFunctionIndex] = pIndirectBoxIntersectionFn;
    
    // Triangle intersection, same pattern.
    MTL::Library* pIndirectTriangleIntersectionLib = newSBTIntersectionFunction(IRHitGroupTypeTriangles, pDevice);
    assert(pIndirectTriangleIntersectionLib);
    MTL::Function* pIndirectTriangleIntersectionFn = pIndirectTriangleIntersectionLib->newFunction(NS::String::string(kIRIndirectTriangleIntersectionFunctionName, NS::UTF8StringEncoding));
    assert(pIndirectTriangleIntersectionFn);
    intersectionFunctions[kSBTTriangleIntersectionFunctionIndex] = pIndirectTriangleIntersectionFn;
    
    // The libraries are no longer needed once the functions have been created.
    pIndirectTriangleIntersectionLib->release();
    pIndirectBoxIntersectionLib->release();
    
    return intersectionFunctions;
}

// Builds the ray tracing compute pipeline state.
//
// rayTracingFunctions and intersectionFunctions are linked into the pipeline (null
// entries are skipped); the compute function is the indirect ray dispatch kernel.
// Temporaries are autoreleased, so an autorelease pool should be active in the caller.
// Returns an owning pipeline state reference, or nullptr (after logging) on failure.
static MTL::ComputePipelineState* shader_pipeline::newRaytracingPipeline(const std::vector<MTL::Function *> &rayTracingFunctions, const std::vector<MTL::Function *> &intersectionFunctions, MTL::Device *pDevice)
{
    MTL::ComputePipelineDescriptor* pDesc = MTL::ComputePipelineDescriptor::alloc()->init()->autorelease();
    
    // Collect all functions to link into the PSO (including SBT intersection functions):
    
    std::vector<MTL::Function*> allFunctions;
    allFunctions.reserve(rayTracingFunctions.size() + intersectionFunctions.size());
    allFunctions.insert(allFunctions.end(), rayTracingFunctions.begin(), rayTracingFunctions.end());
    allFunctions.insert(allFunctions.end(), intersectionFunctions.begin(), intersectionFunctions.end());
    std::erase(allFunctions, nullptr); // unused slots are nullptr and must not reach the array
    
    NS::Array* pFunctionsToLink = ((NS::Array*)CFArrayCreate(CFAllocatorGetDefault(),
                                                             (const void **)(allFunctions.data()),
                                                             allFunctions.size(),
                                                             &kCFTypeArrayCallBacks))->autorelease();
    
    // Reference via pipeline descriptor:
    
    MTL::LinkedFunctions* pLinkedFunctions = MTL::LinkedFunctions::alloc()->init()->autorelease();
    pLinkedFunctions->setFunctions(pFunctionsToLink);
    pDesc->setLinkedFunctions(pLinkedFunctions);
    
    // Create indirect dispatch rays function. Each object is validated before it is
    // used; the original asserted only after both had already been dereferenced.
    
    MTL::Library* pDispatchLib = newIndirectRTDispatchLibrary(pDevice);
    assert(pDispatchLib);
    pDispatchLib->autorelease();
    
    MTL::Function* pDispatchFn = pDispatchLib->newFunction(NS::String::string(kIRRayDispatchIndirectionKernelName, NS::UTF8StringEncoding));
    assert(pDispatchFn);
    pDispatchFn->autorelease();
    
    pDesc->setComputeFunction(pDispatchFn);
    
    // Build pipeline:
    NS::Error* pMtlError = nullptr;
    MTL::ComputePipelineState* pRTPSO = pDevice->newComputePipelineState(pDesc, MTL::PipelineOptionNone, nullptr, &pMtlError);
    if (!pRTPSO)
    {
        // Never hand a null string to printf, and end the line so the message is
        // flushed on line-buffered output before the caller's assert aborts.
        const char* pReason = pMtlError ? pMtlError->localizedDescription()->utf8String() : nullptr;
        printf("Error building RTPSO: %s\n", pReason ? pReason : "(no error description)");
    }
    
    return pRTPSO;
}

static MTL::VisibleFunctionTable* shader_pipeline::newVisibleFunctionTable(const std::vector<MTL::Function*>& raytracingFunctions, MTL::ComputePipelineState* pRTPSO)
{
    // Visible function table slots mirror the positions in raytracingFunctions:
    // compileRaytracingFunctions() places each function at the constexpr index from the
    // header file, and RenderCore.cpp's buildSBT creates ShaderIdentifiers that refer to
    // the visible functions through those same indices. Any slot value > 0 is usable;
    // index '0' is typically reserved for the "null function".

    const size_t slotCount = raytracingFunctions.size(); // includes the reserved slot '0'

    MTL::VisibleFunctionTableDescriptor* pTableDesc = MTL::VisibleFunctionTableDescriptor::alloc()->init()->autorelease();
    pTableDesc->setFunctionCount(slotCount);

    MTL::VisibleFunctionTable* pTable = pRTPSO->newVisibleFunctionTable(pTableDesc);

    for (size_t slot = 0; slot < slotCount; ++slot)
    {
        MTL::Function* pFunction = raytracingFunctions[slot];
        if (!pFunction)
        {
            continue; // empty slot, nothing to install
        }

        MTL::FunctionHandle* pFunctionHandle = pRTPSO->functionHandle(pFunction);
        pTable->setFunction(pFunctionHandle, slot);
    }

    return pTable;
}

static MTL::IntersectionFunctionTable* shader_pipeline::newIntersectionFunctionTable(const std::vector<MTL::Function*>& indirectIntersectionFunctions, MTL::ComputePipelineState* pRTPSO)
{
    // Intersection function table slots may be any value >= 0, but must always refer to
    // synthesized indirect intersection functions. RenderCore.cpp's buildSBT relies on
    // these indices, since its ShaderIdentifiers reference them for SBT hitgroups.

    assert(indirectIntersectionFunctions.size() == kMaxIntersectionFunctions);

    MTL::IntersectionFunctionTableDescriptor* pTableDesc = MTL::IntersectionFunctionTableDescriptor::alloc()->init()->autorelease();
    pTableDesc->setFunctionCount(kMaxIntersectionFunctions);

    MTL::IntersectionFunctionTable* pTable = pRTPSO->newIntersectionFunctionTable(pTableDesc);

    // Looks up the handle of the function stored at `slot` and installs it at the same slot.
    const auto installAt = [&](const auto slot)
    {
        const MTL::FunctionHandle* pHandle = pRTPSO->functionHandle(indirectIntersectionFunctions[slot]);
        pTable->setFunction(pHandle, slot);
    };

    installAt(kSBTTriangleIntersectionFunctionIndex); // indirect triangle intersection
    installAt(kSBTBoxIntersectionFunctionIndex);      // indirect box intersection

    return pTable;
}


// Builds the render pipeline used to present the ray traced image.
//
// Loads "present_vs.dxil" (entry point "MainVS") and "present_fs.dxil" (entry point
// "MainFS") from shaderSearchPath and compiles them against a root signature with two
// pixel-visible descriptor tables: one SRV (register 0, space 0) and one sampler
// (register 0, space 1). The pipeline targets BGRA8Unorm_sRGB with source-alpha blending.
//
// Returns an owning pipeline state reference, or nullptr when the root signature or
// the pipeline state could not be created (asserts in debug builds).
MTL::RenderPipelineState* shader_pipeline::newPresentPipeline( const std::string& shaderSearchPath, MTL::Device* pDevice )
{
    std::vector<IRDescriptorRange1> srvRanges {
        IRDescriptorRange1 {
            .RangeType = IRDescriptorRangeTypeSRV,
            .NumDescriptors = 1,
            .BaseShaderRegister = 0,
            .RegisterSpace = 0,
            .Flags = IRDescriptorRangeFlagNone,
            .OffsetInDescriptorsFromTableStart = 0
        }
    };
    
    std::vector<IRDescriptorRange1> smpRanges = {
        IRDescriptorRange1 {
            .RangeType = IRDescriptorRangeTypeSampler,
            .NumDescriptors = 1,
            .BaseShaderRegister = 0,
            .RegisterSpace = 1,
            .Flags = IRDescriptorRangeFlagNone,
            .OffsetInDescriptorsFromTableStart = 0
        }
    };
    
    std::vector<IRRootParameter1> rootParams = {
        {
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityPixel,
            .DescriptorTable = {
                .pDescriptorRanges = srvRanges.data(),
                .NumDescriptorRanges = (uint32_t)srvRanges.size()
            }
        },
        {
            .ParameterType = IRRootParameterTypeDescriptorTable,
            .ShaderVisibility = IRShaderVisibilityPixel,
            .DescriptorTable = {
                .pDescriptorRanges = smpRanges.data(),
                .NumDescriptorRanges = (uint32_t)smpRanges.size()
            }
        }
    };
    
    // Zero the descriptor first, like the other root signature builders in this file do.
    // Without this, every field not assigned below would be read by the converter while
    // still holding indeterminate values.
    IRVersionedRootSignatureDescriptor rootSigDesc;
    memset(&rootSigDesc, 0x0, sizeof(rootSigDesc));
    rootSigDesc.version = IRRootSignatureVersion_1_1;
    rootSigDesc.desc_1_1.pParameters = rootParams.data();
    rootSigDesc.desc_1_1.NumParameters = (uint32_t)rootParams.size();
    rootSigDesc.desc_1_1.NumStaticSamplers = 0;
    
    IRError* pIrError = nullptr;
    IRRootSignature* pRootSig = IRRootSignatureCreateFromDescriptor(&rootSigDesc, &pIrError);
    if (!pRootSig)
    {
        // Report and dispose of the converter error instead of silently dropping it.
        if (pIrError)
        {
            printf("Error creating present root signature: %s\n", (const char *)IRErrorGetPayload(pIrError));
            IRErrorDestroy(pIrError);
        }
        assert(pRootSig);
        return nullptr;
    }
    
    NS::SharedPtr<MTL::Library> pVtxLib = NS::TransferPtr(newLibraryFromDXIL(shaderSearchPath + "/present_vs.dxil",
                                                                             IRShaderStageVertex,
                                                                             "MainVS",
                                                                             pRootSig,
                                                                             pDevice));
    assert(pVtxLib);
    
    NS::SharedPtr<MTL::Library> pFragLib = NS::TransferPtr(newLibraryFromDXIL(shaderSearchPath + "/present_fs.dxil",
                                                                              IRShaderStageFragment,
                                                                              "MainFS",
                                                                              pRootSig,
                                                                              pDevice));
    assert(pFragLib);
    
    NS::SharedPtr<MTL::Function> pVFn = NS::TransferPtr(pVtxLib->newFunction(MTLSTR("MainVS")));
    NS::SharedPtr<MTL::Function> pFFn = NS::TransferPtr(pFragLib->newFunction(MTLSTR("MainFS")));
    assert(pVFn);
    assert(pFFn);
    
    // Vertex layout: float4 at offset 0 followed by float2 at offset 16, one buffer.
    NS::SharedPtr<MTL::VertexDescriptor> pVtxDesc = NS::TransferPtr(MTL::VertexDescriptor::alloc()->init());
    auto pAttrib0 = pVtxDesc->attributes()->object(kIRStageInAttributeStartIndex + 0);
    auto pAttrib1 = pVtxDesc->attributes()->object(kIRStageInAttributeStartIndex + 1);
    auto pLayout  = pVtxDesc->layouts()->object(kIRVertexBufferBindPoint);
    
    pAttrib0->setFormat(MTL::VertexFormatFloat4);
    pAttrib0->setOffset(0);
    pAttrib0->setBufferIndex(kIRVertexBufferBindPoint);
    
    pAttrib1->setFormat(MTL::VertexFormatFloat2);
    pAttrib1->setOffset(sizeof(simd::float4));
    pAttrib1->setBufferIndex(kIRVertexBufferBindPoint);
    
    pLayout->setStride(sizeof(simd::float4) + sizeof(simd::float4)); // clang pads the VertexData struct to 32 bytes, so that's the stride (16+16)
    pLayout->setStepRate(1);
    pLayout->setStepFunction(MTL::VertexStepFunctionPerVertex);
    
    NS::SharedPtr<MTL::RenderPipelineDescriptor> pPsoDesc = NS::TransferPtr(MTL::RenderPipelineDescriptor::alloc()->init());
    pPsoDesc->setVertexDescriptor(pVtxDesc.get());
    pPsoDesc->setVertexFunction(pVFn.get());
    pPsoDesc->setFragmentFunction(pFFn.get());
    
    auto pColorDesc = pPsoDesc->colorAttachments()->object(0);
    pColorDesc->setPixelFormat(MTL::PixelFormatBGRA8Unorm_sRGB);
    pColorDesc->setBlendingEnabled(true);
    pColorDesc->setSourceRGBBlendFactor(MTL::BlendFactorSourceAlpha);
    pColorDesc->setSourceAlphaBlendFactor(MTL::BlendFactorSourceAlpha);
    pColorDesc->setDestinationRGBBlendFactor(MTL::BlendFactorOneMinusSourceAlpha);
    pColorDesc->setDestinationAlphaBlendFactor(MTL::BlendFactorOneMinusSourceAlpha);
    
    NS::Error* pMtlError = nullptr;
    MTL::RenderPipelineState* pPso = pDevice->newRenderPipelineState(pPsoDesc.get(), &pMtlError);
    if (!pPso)
    {
        // Surface the Metal error; it was previously collected but never reported.
        const char* pReason = pMtlError ? pMtlError->localizedDescription()->utf8String() : nullptr;
        printf("Error building present PSO: %s\n", pReason ? pReason : "(no error description)");
    }
    
    // The root signature is only needed while the shaders are being compiled.
    IRRootSignatureDestroy(pRootSig);
    
    assert(pPso);
    
    return pPso;
}

using shader_pipeline::RTPSOContext;

// Builds the complete triangle + sphere ray tracing pipeline.
//
// Creates the global and local root signatures, compiles the DXIL ray tracing shaders
// found under shaderSearchPath into Metal functions, synthesizes the indirect
// intersection functions, and builds the compute pipeline state together with its
// visible function table and intersection function table.
//
// All intermediate objects (functions, DXIL objects, root signatures) are released
// before returning. The three objects in the returned RTPSOContext are owning
// references created with new* calls, so they outlive the local autorelease pool.
RTPSOContext shader_pipeline::newTriangleSphereRTPipeline(const std::string& shaderSearchPath, MTL::Device* pDevice)
{
    NS::AutoreleasePool* pPool = NS::AutoreleasePool::alloc()->init();
    
    // Root signatures:
    
    IRRootSignature* pGlobalRootSig = shader_pipeline::newRTComputeGlobalRootSignature();
    IRRootSignature* pLocalRootSig  = shader_pipeline::newRTComputeLocalRootSignature();
    
    assert(pGlobalRootSig);
    assert(pLocalRootSig);
    
    // Compile input shaders into Metal visible functions:
    
    using shader_pipeline::DXILHitGroupDescriptor;
    std::vector<DXILHitGroupDescriptor> inputFunctions        = getRaytracingShaders(shaderSearchPath);
    std::vector<MTL::Function*> compiledFunctions             = compileRayTracingFunctions(inputFunctions, pGlobalRootSig, pLocalRootSig, pDevice);
    std::vector<MTL::Function*> indirectIntersectionFunctions = synthesizeIndirectIntersectionFunctions(pDevice);

    // Build pipeline:
    
    MTL::ComputePipelineState* pRTPSO = newRaytracingPipeline(compiledFunctions, indirectIntersectionFunctions, pDevice);
    assert(pRTPSO);
    
    // Build visible function table:
    
    MTL::VisibleFunctionTable* pVFT = newVisibleFunctionTable(compiledFunctions, pRTPSO);
    assert(pVFT);
    
    // Build intersection function table:
    
    MTL::IntersectionFunctionTable* pIFT = newIntersectionFunctionTable(indirectIntersectionFunctions, pRTPSO);
    assert(pIFT);
    
    // Release resources. Both function vectors are indexed by slot and can contain
    // nullptr for unused slots (the other helpers in this file filter those out
    // explicitly), so skip empty entries instead of calling release() through null.
    
    for (MTL::Function* pIndirectIntersectionFn : indirectIntersectionFunctions)
    {
        if (pIndirectIntersectionFn)
        {
            pIndirectIntersectionFn->release();
        }
    }
    
    for (MTL::Function* pCompiledFn : compiledFunctions)
    {
        if (pCompiledFn)
        {
            pCompiledFn->release();
        }
    }
    
    for (auto&& hgd : inputFunctions)
    {
        IRObjectDestroy(hgd.pIR);
    }
    
    IRRootSignatureDestroy(pLocalRootSig);
    IRRootSignatureDestroy(pGlobalRootSig);
    
    pPool->release();
    
    return RTPSOContext {
        .pRTPSO = pRTPSO,
        .pVFT   = pVFT,
        .pIFT   = pIFT
    };
}
