diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -350,9 +350,8 @@ // guarantees that we enter and exit in structured ways and the construct // is nestable. // 3. Put the new spv.mlir.selection/spv.mlir.loop op at the beginning of the - // old merge - // block and redirect all branches to the old header block to the old - // merge block (which contains the spv.mlir.selection/spv.mlir.loop op + // old merge block and redirect all branches to the old header block to the + // old merge block (which contains the spv.mlir.selection/spv.mlir.loop op // now). /// For OpPhi instructions, we use block arguments to represent them. OpPhi diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1733,6 +1733,7 @@ LLVM_DEBUG(llvm::dbgs() << "[cf] block " << block << " is a function entry block\n"); } + for (auto &op : *block) newBlock->push_back(op.clone(mapper)); } @@ -1746,9 +1747,8 @@ if (Block *mappedOp = mapper.lookupOrNull(succOp.get())) succOp.set(mappedOp); }; - for (auto &block : body) { + for (auto &block : body) block.walk(remapOperands); - } // We have created the SelectionOp/LoopOp and "moved" all blocks belonging to // the selection/loop construct into its region. Next we need to fix the @@ -1758,8 +1758,12 @@ // SelectionOp/LoopOp resides right now. headerBlock->replaceAllUsesWith(mergeBlock); + LLVM_DEBUG(llvm::dbgs() << "[cf] after cloning and fixing references:\n"); + LLVM_DEBUG(llvm::dbgs() << *headerBlock->getParentOp()); + LLVM_DEBUG(llvm::dbgs() << "\n"); + if (isLoop) { - // The loop selection/loop header block may have block arguments. 
Since now + // The selection/loop header block may have block arguments. Since now // we place the selection/loop op inside the old merge block, we need to // make sure the old merge block has the same block argument list. assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -406,6 +406,9 @@ // instruction to start a new SPIR-V block for ops following this SelectionOp. // The block should use the <id> for the merge block. encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); + LLVM_DEBUG(llvm::dbgs() << "done merge "); + LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); return success(); } @@ -414,10 +417,9 @@ // properly. We don't need to assign for the entry block, which is just for // satisfying MLIR region's structural requirement. auto &body = loopOp.body(); - for (Block &block : - llvm::make_range(std::next(body.begin(), 1), body.end())) { + for (Block &block : llvm::make_range(std::next(body.begin(), 1), body.end())) getOrCreateBlockID(&block); - } + auto *headerBlock = loopOp.getHeaderBlock(); auto *continueBlock = loopOp.getContinueBlock(); auto *mergeBlock = loopOp.getMergeBlock(); @@ -469,6 +471,9 @@ // start a new SPIR-V block for ops following this LoopOp. The block should // use the <id> for the merge block. 
encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeID}); + LLVM_DEBUG(llvm::dbgs() << "done merge "); + LLVM_DEBUG(printBlock(mergeBlock, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << "\n"); return success(); } diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -238,14 +238,18 @@ /// assigns the next available <id> uint32_t getOrCreateBlockID(Block *block); +#ifndef NDEBUG + /// (For debugging) prints the block with its result <id>. + void printBlock(Block *block, raw_ostream &os); +#endif + /// Processes the given `block` and emits SPIR-V instructions for all ops /// inside. Does not emit OpLabel for this block if `omitLabel` is true. - /// `actionBeforeTerminator` is a callback that will be invoked before - /// handling the terminator op. It can be used to inject the Op*Merge - /// instruction if this is a SPIR-V selection/loop header block. - LogicalResult - processBlock(Block *block, bool omitLabel = false, - function_ref actionBeforeTerminator = nullptr); + /// `emitMerge` is a callback that will be invoked before handling the + /// terminator op to inject the Op*Merge instruction if this is a SPIR-V + /// selection/loop header block. + LogicalResult processBlock(Block *block, bool omitLabel = false, + function_ref emitMerge = nullptr); /// Emits OpPhi instructions for the given block if it has block arguments. 
LogicalResult emitPhiForBlockArguments(Block *block); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -921,16 +921,26 @@ return blockIDMap[block] = getNextID(); } +#ifndef NDEBUG +void Serializer::printBlock(Block *block, raw_ostream &os) { + os << "block " << block << " (id = "; + if (uint32_t id = getBlockID(block)) + os << id; + else + os << "unknown"; + os << ")\n"; +} +#endif + LogicalResult Serializer::processBlock(Block *block, bool omitLabel, - function_ref actionBeforeTerminator) { + function_ref emitMerge) { LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); LLVM_DEBUG(block->print(llvm::dbgs())); LLVM_DEBUG(llvm::dbgs() << '\n'); if (!omitLabel) { uint32_t blockID = getOrCreateBlockID(block); - LLVM_DEBUG(llvm::dbgs() - << "[block] " << block << " (id = " << blockID << ")\n"); + LLVM_DEBUG(printBlock(block, llvm::dbgs())); // Emit OpLabel for this block. encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); @@ -940,6 +950,24 @@ if (failed(emitPhiForBlockArguments(block))) return failure(); + // If we need to emit merge instructions, it must happen in this block. Check + // whether we have other structured control flow ops, which will be expanded + // into multiple basic blocks. If that's the case, we need to emit the merge + // right now and then create new blocks for further serialization of the ops + // in this block. + if (emitMerge && llvm::any_of(block->getOperations(), [](Operation &op) { + return isa(op); + })) { + if (failed(emitMerge())) + return failure(); + emitMerge = nullptr; + + // Start a new block for further serialization. 
+ uint32_t blockID = getNextID(); + encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID}); + encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID}); + } + // Process each op in this block except the terminator. for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { if (failed(processOperation(&op))) @@ -947,8 +975,8 @@ } // Process the terminator. - if (actionBeforeTerminator) - if (failed(actionBeforeTerminator())) + if (emitMerge) + if (failed(emitMerge())) return failure(); if (failed(processOperation(&block->back()))) return failure(); @@ -962,14 +990,19 @@ if (block->args_empty() || block->isEntryBlock()) return success(); + LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n"); + // If the block has arguments, we need to create SPIR-V OpPhi instructions. // A SPIR-V OpPhi instruction is of the syntax: // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair // So we need to collect all predecessor blocks and the arguments they send // to this block. SmallVector, 4> predecessors; - for (Block *predecessor : block->getPredecessors()) { - auto *terminator = predecessor->getTerminator(); + for (Block *mlirPredecessor : block->getPredecessors()) { + auto *terminator = mlirPredecessor->getTerminator(); + LLVM_DEBUG(llvm::dbgs() << " mlir predecessor "); + LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs())); + LLVM_DEBUG(llvm::dbgs() << " terminator: " << *terminator << "\n"); // The predecessor here is the immediate one according to MLIR's IR // structure. It does not directly map to the incoming parent block for the // OpPhi instructions at SPIR-V binary level. This is because structured @@ -977,26 +1010,32 @@ // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the // branch op jumping to the OpPhi's block then resides in the previous // structured control flow op's merge block. 
- predecessor = getPhiIncomingBlock(predecessor); + Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor); + LLVM_DEBUG(llvm::dbgs() << " spirv predecessor "); + LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs())); if (auto branchOp = dyn_cast(terminator)) { - predecessors.emplace_back(predecessor, branchOp.getOperands()); + predecessors.emplace_back(spirvPredecessor, branchOp.getOperands()); } else if (auto branchCondOp = dyn_cast(terminator)) { Optional blockOperands; + if (branchCondOp.trueTarget() == block) { + blockOperands = branchCondOp.trueTargetOperands(); + } else { + assert(branchCondOp.falseTarget() == block); + blockOperands = branchCondOp.falseTargetOperands(); + } - for (auto successorIdx : - llvm::seq(0, predecessor->getNumSuccessors())) - if (predecessor->getSuccessors()[successorIdx] == block) { - blockOperands = branchCondOp.getSuccessorOperands(successorIdx); - break; - } - - assert(blockOperands && !blockOperands->empty() && + assert(!blockOperands->empty() && "expected non-empty block operand range"); - predecessors.emplace_back(predecessor, *blockOperands); + predecessors.emplace_back(spirvPredecessor, *blockOperands); } else { return terminator->emitError("unimplemented terminator for Phi creation"); } + LLVM_DEBUG({ + llvm::dbgs() << " block arguments:\n"; + for (Value v : predecessors.back().second) + llvm::dbgs() << " " << v << "\n"; + }); } // Then create OpPhi instruction for each of the block argument. 
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir --- a/mlir/test/Target/SPIRV/loop.mlir +++ b/mlir/test/Target/SPIRV/loop.mlir @@ -4,6 +4,7 @@ spv.module Logical GLSL450 requires #spv.vce { // for (int i = 0; i < count; ++i) {} +// CHECK-LABEL: @loop spv.func @loop(%count : i32) -> () "None" { %zero = spv.Constant 0: i32 %one = spv.Constant 1: i32 @@ -59,9 +60,12 @@ // ----- +// Single loop with block arguments + spv.module Logical GLSL450 requires #spv.vce { spv.GlobalVariable @GV1 bind(0, 0) : !spv.ptr [0])>, StorageBuffer> spv.GlobalVariable @GV2 bind(0, 1) : !spv.ptr [0])>, StorageBuffer> +// CHECK-LABEL: @loop_kernel spv.func @loop_kernel() "None" { %0 = spv.mlir.addressof @GV1 : !spv.ptr [0])>, StorageBuffer> %1 = spv.Constant 0 : i32 @@ -111,6 +115,7 @@ // for (int i = 0; i < count; ++i) { // for (int j = 0; j < count; ++j) { } // } +// CHECK-LABEL: @loop spv.func @loop(%count : i32) -> () "None" { %zero = spv.Constant 0: i32 %one = spv.Constant 1: i32 @@ -207,3 +212,77 @@ spv.EntryPoint "GLCompute" @main } + +// ----- + +// Loop with selection in its header + +spv.module Physical64 OpenCL requires #spv.vce { +// CHECK-LABEL: @kernel +// CHECK-SAME: (%[[INPUT0:.+]]: i64) + spv.func @kernel(%input: i64) "None" { +// CHECK-NEXT: %[[VAR:.+]] = spv.Variable : !spv.ptr +// CHECK-NEXT: spv.Branch ^[[BB:.+]](%[[INPUT0]] : i64) +// CHECK-NEXT: ^[[BB]](%[[INPUT1:.+]]: i64): + %cst0_i64 = spv.Constant 0 : i64 + %true = spv.Constant true + %false = spv.Constant false +// CHECK-NEXT: spv.mlir.loop { + spv.mlir.loop { +// CHECK-NEXT: spv.Branch ^[[LOOP_HEADER:.+]](%[[INPUT1]] : i64) + spv.Branch ^loop_header(%input : i64) +// CHECK-NEXT: ^[[LOOP_HEADER]](%[[ARG1:.+]]: i64): + ^loop_header(%arg1: i64): +// CHECK-NEXT: spv.Branch ^[[LOOP_BODY:.+]] +// CHECK-NEXT: ^[[LOOP_BODY]]: + %gt = spv.SGreaterThan %arg1, %cst0_i64 : i64 + %var = spv.Variable : !spv.ptr +// CHECK-NEXT: spv.mlir.selection { + spv.mlir.selection { +// CHECK-NEXT: 
%[[C0:.+]] = spv.Constant 0 : i64 +// CHECK-NEXT: %[[GT:.+]] = spv.SGreaterThan %[[ARG1]], %[[C0]] : i64 +// CHECK-NEXT: spv.BranchConditional %[[GT]], ^[[THEN:.+]], ^[[ELSE:.+]] + spv.BranchConditional %gt, ^then, ^else +// CHECK-NEXT: ^[[THEN]]: + ^then: +// CHECK-NEXT: %true = spv.Constant true +// CHECK-NEXT: spv.Store "Function" %[[VAR]], %true : i1 + spv.Store "Function" %var, %true : i1 +// CHECK-NEXT: spv.Branch ^[[SELECTION_MERGE:.+]] + spv.Branch ^selection_merge +// CHECK-NEXT: ^[[ELSE]]: + ^else: +// CHECK-NEXT: %false = spv.Constant false +// CHECK-NEXT: spv.Store "Function" %[[VAR]], %false : i1 + spv.Store "Function" %var, %false : i1 +// CHECK-NEXT: spv.Branch ^[[SELECTION_MERGE]] + spv.Branch ^selection_merge +// CHECK-NEXT: ^[[SELECTION_MERGE]]: + ^selection_merge: +// CHECK-NEXT: spv.mlir.merge + spv.mlir.merge +// CHECK-NEXT: } + } +// CHECK-NEXT: %[[LOAD:.+]] = spv.Load "Function" %[[VAR]] : i1 + %load = spv.Load "Function" %var : i1 +// CHECK-NEXT: spv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]] + spv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge +// CHECK-NEXT: ^[[CONTINUE]](%[[ARG2:.+]]: i64): + ^continue(%arg2: i64): +// CHECK-NEXT: %[[C0:.+]] = spv.Constant 0 : i64 +// CHECK-NEXT: %[[LT:.+]] = spv.SLessThan %[[ARG2]], %[[C0]] : i64 + %lt = spv.SLessThan %arg2, %cst0_i64 : i64 +// CHECK-NEXT: spv.Store "Function" %[[VAR]], %[[LT]] : i1 + spv.Store "Function" %var, %lt : i1 +// CHECK-NEXT: spv.Branch ^[[LOOP_HEADER]](%[[ARG2]] : i64) + spv.Branch ^loop_header(%arg2 : i64) +// CHECK-NEXT: ^[[LOOP_MERGE]]: + ^loop_merge: +// CHECK-NEXT: spv.mlir.merge + spv.mlir.merge +// CHECK-NEXT: } + } +// CHECK-NEXT: spv.Return + spv.Return + } +}