diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
--- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp
@@ -75,16 +75,43 @@
   return success();
 }
 
-/// Returns the last structured control flow op's merge block if the given
-/// `block` contains any structured control flow op. Otherwise returns nullptr.
-static Block *getLastStructuredControlFlowOpMergeBlock(Block *block) {
+/// Returns the merge block if the given `op` is a structured control flow op.
+/// Otherwise returns nullptr.
+static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
+  if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
+    return selectionOp.getMergeBlock();
+  if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
+    return loopOp.getMergeBlock();
+  return nullptr;
+}
+
+/// Given a predecessor `block` for a block with arguments, returns the block
+/// that should be used as the parent block for SPIR-V OpPhi instructions
+/// corresponding to the block arguments.
+static Block *getPhiIncomingBlock(Block *block) {
+  // If the predecessor block in question is the entry block for a spv.loop,
+  // we jump to this spv.loop from its enclosing block.
+  if (block->isEntryBlock()) {
+    if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
+      // Then the incoming parent block for OpPhi should be the merge block of
+      // the structured control flow op before this loop.
+      Operation *op = loopOp.getOperation();
+      while ((op = op->getPrevNode()) != nullptr)
+        if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op))
+          return incomingBlock;
+      // Or the enclosing block itself if no structured control flow op
+      // exists before this loop.
+      return loopOp.getOperation()->getBlock();
+    }
+  }
+
+  // Otherwise, we jump from the given predecessor block. Try to see if there is
+  // a structured control flow op inside it.
   for (Operation &op : llvm::reverse(block->getOperations())) {
-    if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op))
-      return selectionOp.getMergeBlock();
-    if (auto loopOp = dyn_cast<spirv::LoopOp>(op))
-      return loopOp.getMergeBlock();
+    if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op))
+      return incomingBlock;
   }
-  return nullptr;
+  return block;
 }
 
 namespace {
@@ -1374,12 +1401,14 @@
   SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors;
   for (Block *predecessor : block->getPredecessors()) {
     auto *terminator = predecessor->getTerminator();
-    // Check whether this predecessor block contains a structured control flow
-    // op. If so, the structured control flow op will be serialized to multiple
-    // SPIR-V blocks. The branch op jumping to the OpPhi's block then resides in
-    // the last structured control flow op's merge block.
-    if (auto *merge = getLastStructuredControlFlowOpMergeBlock(predecessor))
-      predecessor = merge;
+    // The predecessor here is the immediate one according to MLIR's IR
+    // structure. It does not directly map to the incoming parent block for the
+    // OpPhi instructions at SPIR-V binary level. This is because structured
+    // control flow ops are serialized to multiple SPIR-V blocks. If there is a
+    // spv.selection/spv.loop op in the MLIR predecessor block, the branch op
+    // jumping to the OpPhi's block then resides in the previous structured
+    // control flow op's merge block.
+    predecessor = getPhiIncomingBlock(predecessor);
     if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
       predecessors.emplace_back(predecessor, branchOp.operand_begin());
     } else {
@@ -1400,6 +1429,7 @@
     LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                             << arg << " (id = " << phiID << ")\n");
 
+    // Prepare the (value <id-ref>, parent block <id-ref>) pairs.
     SmallVector<uint32_t, 8> phiArgs;
     phiArgs.push_back(phiTypeID);
     phiArgs.push_back(phiID);
@@ -1499,16 +1529,9 @@
   // afterwards.
   encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID});
 
-  // We omit the LoopOp's entry block and start serialization from the loop
-  // header block. The entry block should not contain any additional ops other
-  // than a single spv.Branch that jumps to the loop header block. However,
-  // the spv.Branch can contain additional block arguments. Those block
-  // arguments must come from out of the loop using implicit capture. We will
-  // need to query the <id> for the value sent and the <id> for the incoming
-  // parent block. For the latter, we need to make sure this block is
-  // registered. The value sent should come from the block this loop resides in.
-  blockIDMap[loopOp.getEntryBlock()] =
-      getBlockID(loopOp.getOperation()->getBlock());
+  // LoopOp's entry block is just there for satisfying MLIR's structural
+  // requirements so we omit it and start serialization from the loop header
+  // block.
 
   // Emit the loop header block, which dominates all other blocks, first. We
   // need to emit an OpLoopMerge instruction before the loop header block's
diff --git a/mlir/test/Dialect/SPIRV/Serialization/phi.mlir b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/phi.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/phi.mlir
@@ -236,3 +236,53 @@
   spv.EntryPoint "GLCompute" @fmul_kernel, @__builtin_var_WorkgroupId__, @__builtin_var_NumWorkgroups__
   spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1
 }
+
+// -----
+
+// Test back-to-back loops with block arguments
+
+spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
+  spv.func @fmul_kernel() "None" {
+    %cst4 = spv.constant 4 : i32
+
+    %val1 = spv.constant 43 : i32
+    %val2 = spv.constant 44 : i32
+
+// CHECK: spv.constant 43
+// CHECK-NEXT: spv.Branch ^[[BB1:.+]](%{{.+}} : i32)
+// CHECK-NEXT: ^[[BB1]](%{{.+}}: i32):
+// CHECK-NEXT: spv.loop
+    spv.loop { // loop 1
+      spv.Branch ^bb1(%val1 : i32)
+    ^bb1(%loop1_bb_arg: i32):
+      %loop1_lt = spv.SLessThan %loop1_bb_arg, %cst4 : i32
+      spv.BranchConditional %loop1_lt, ^bb2, ^bb3
+    ^bb2:
+      %loop1_add = spv.IAdd %loop1_bb_arg, %cst4 : i32
+      spv.Branch ^bb1(%loop1_add : i32)
+    ^bb3:
+      spv._merge
+    }
+
+// CHECK: spv.constant 44
+// CHECK-NEXT: spv.Branch ^[[BB2:.+]](%{{.+}} : i32)
+// CHECK-NEXT: ^[[BB2]](%{{.+}}: i32):
+// CHECK-NEXT: spv.loop
+    spv.loop { // loop 2
+      spv.Branch ^bb1(%val2 : i32)
+    ^bb1(%loop2_bb_arg: i32):
+      %loop2_lt = spv.SLessThan %loop2_bb_arg, %cst4 : i32
+      spv.BranchConditional %loop2_lt, ^bb2, ^bb3
+    ^bb2:
+      %loop2_add = spv.IAdd %loop2_bb_arg, %cst4 : i32
+      spv.Branch ^bb1(%loop2_add : i32)
+    ^bb3:
+      spv._merge
+    }
+
+    spv.Return
+  }
+
+  spv.EntryPoint "GLCompute" @fmul_kernel
+  spv.ExecutionMode @fmul_kernel "LocalSize", 32, 1, 1
+}