diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -1763,15 +1763,19 @@ LLVM_DEBUG(llvm::dbgs() << "\n"); if (isLoop) { - // The selection/loop header block may have block arguments. Since now - // we place the selection/loop op inside the old merge block, we need to - // make sure the old merge block has the same block argument list. - assert(mergeBlock->args_empty() && "OpPhi in loop merge block unsupported"); + if (!mergeBlock->args_empty()) { + return mergeBlock->getParentOp()->emitError( + "OpPhi in loop merge block unsupported"); + } + + // The loop header block may have block arguments. Since now we place the + // loop op inside the old merge block, we need to make sure the old merge + // block has the same block argument list. for (BlockArgument blockArg : headerBlock->getArguments()) { mergeBlock->addArgument(blockArg.getType()); } - // If the loop header block has block arguments, make sure the spv.branch op + // If the loop header block has block arguments, make sure the spv.Branch op // matches. SmallVector blockArgs; if (!headerBlock->args_empty()) @@ -1792,6 +1796,19 @@ for (auto *block : constructBlocks) block->dropAllReferences(); + // Check whether any op in the to-be-erased blocks still has uses. Those + // uses come from blocks that won't be sunk into the SelectionOp/LoopOp's + // region. We cannot handle such cases given that once a value is sunk into + // the SelectionOp/LoopOp's region, there is no escape for it: + // SelectionOp/LoopOp does not support yield values right now. + for (auto *block : constructBlocks) { + for (Operation &op : *block) + if (!op.use_empty()) + return op.emitOpError( + "failed control flow structurization: it has uses outside of the " + "enclosing selection/loop construct"); + } + // Then erase all old blocks.
for (auto *block : constructBlocks) { // We've cloned all blocks belonging to this construct into the structured @@ -1799,26 +1816,31 @@ // selection/loop. If so, they will be recorded within blockMergeInfo. // We need to update the pointers there to the newly remapped ones so we can // continue structurizing them later. - // TODO: The asserts in the following assumes input SPIR-V blob - // forms correctly nested selection/loop constructs. We should relax this - // and support error cases better. + // TODO: The asserts in the following assume input SPIR-V blob forms + // correctly nested selection/loop constructs. We should relax this and + // support error cases better. auto it = blockMergeInfo.find(block); if (it != blockMergeInfo.end()) { + // Use the original location for nested selection/loop ops. + Location loc = it->second.loc; + Block *newHeader = mapper.lookupOrNull(block); - assert(newHeader && "nested loop header block should be remapped!"); + if (!newHeader) + return emitError(loc, "failed control flow structurization: nested " + "loop header block should be remapped!"); Block *newContinue = it->second.continueBlock; if (newContinue) { newContinue = mapper.lookupOrNull(newContinue); - assert(newContinue && "nested loop continue block should be remapped!"); + if (!newContinue) + return emitError(loc, "failed control flow structurization: nested " + "loop continue block should be remapped!"); } Block *newMerge = it->second.mergeBlock; if (Block *mappedTo = mapper.lookupOrNull(newMerge)) newMerge = mappedTo; - // Keep original location for nested selection/loop ops. - Location loc = it->second.loc; // The iterator should be erased before adding a new entry into // blockMergeInfo to avoid iterator invalidation.
blockMergeInfo.erase(it); diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -369,9 +369,17 @@ auto *headerBlock = selectionOp.getHeaderBlock(); auto *mergeBlock = selectionOp.getMergeBlock(); + auto headerID = getBlockID(headerBlock); auto mergeID = getBlockID(mergeBlock); auto loc = selectionOp.getLoc(); + // This SelectionOp is in some MLIR block with preceding and following ops. In + // the binary format, it should reside in separate SPIR-V blocks from its + // preceding and following ops. So we need to emit unconditional branches to + // jump to this SelectionOp's SPIR-V blocks and to jump back to the normal + // flow afterwards. + encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerID}); + // Emit the selection header block, which dominates all other blocks, first. // We need to emit an OpSelectionMerge instruction before the selection header // block's terminator. @@ -384,13 +392,8 @@ {mergeID, static_cast(selectionOp.selection_control())}); return success(); }; - // For structured selection, we cannot have blocks in the selection construct - // branching to the selection header block. Entering the selection (and - // reaching the selection header) must be from the block containing the - // spv.mlir.selection op. If there are ops ahead of the spv.mlir.selection op - // in the block, we can "merge" them into the selection header. So here we - // don't need to emit a separate block; just continue with the existing block.
- if (failed(processBlock(headerBlock, /*omitLabel=*/true, emitSelectionMerge))) + if (failed( + processBlock(headerBlock, /*omitLabel=*/false, emitSelectionMerge))) return failure(); // Process all blocks with a depth-first visitor starting from the header diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir --- a/mlir/test/Target/SPIRV/loop.mlir +++ b/mlir/test/Target/SPIRV/loop.mlir @@ -222,8 +222,8 @@ // CHECK-SAME: (%[[INPUT0:.+]]: i64) spv.func @kernel(%input: i64) "None" { // CHECK-NEXT: %[[VAR:.+]] = spv.Variable : !spv.ptr -// CHECK-NEXT: spv.Branch ^[[BB:.+]](%[[INPUT0]] : i64) -// CHECK-NEXT: ^[[BB]](%[[INPUT1:.+]]: i64): +// CHECK-NEXT: spv.Branch ^[[BB0:.+]](%[[INPUT0]] : i64) +// CHECK-NEXT: ^[[BB0]](%[[INPUT1:.+]]: i64): %cst0_i64 = spv.Constant 0 : i64 %true = spv.Constant true %false = spv.Constant false @@ -235,12 +235,14 @@ ^loop_header(%arg1: i64): // CHECK-NEXT: spv.Branch ^[[LOOP_BODY:.+]] // CHECK-NEXT: ^[[LOOP_BODY]]: +// CHECK-NEXT: %[[C0:.+]] = spv.Constant 0 : i64 %gt = spv.SGreaterThan %arg1, %cst0_i64 : i64 +// CHECK-NEXT: %[[GT:.+]] = spv.SGreaterThan %[[ARG1]], %[[C0]] : i64 +// CHECK-NEXT: spv.Branch ^[[BB1:.+]] +// CHECK-NEXT: ^[[BB1]]: %var = spv.Variable : !spv.ptr // CHECK-NEXT: spv.mlir.selection { spv.mlir.selection { -// CHECK-NEXT: %[[C0:.+]] = spv.Constant 0 : i64 -// CHECK-NEXT: %[[GT:.+]] = spv.SGreaterThan %[[ARG1]], %[[C0]] : i64 // CHECK-NEXT: spv.BranchConditional %[[GT]], ^[[THEN:.+]], ^[[ELSE:.+]] spv.BranchConditional %gt, ^then, ^else // CHECK-NEXT: ^[[THEN]]: diff --git a/mlir/test/Target/SPIRV/selection.mlir b/mlir/test/Target/SPIRV/selection.mlir --- a/mlir/test/Target/SPIRV/selection.mlir +++ b/mlir/test/Target/SPIRV/selection.mlir @@ -3,38 +3,39 @@ // Selection with both then and else branches spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @selection spv.func @selection(%cond: i1) -> () "None" { -// CHECK: spv.Branch ^bb1 -// CHECK-NEXT: ^bb1: +// CHECK-NEXT:
spv.Constant 0 +// CHECK-NEXT: spv.Variable +// CHECK: spv.Branch ^[[BB:.+]] +// CHECK-NEXT: ^[[BB]]: %zero = spv.Constant 0: i32 %one = spv.Constant 1: i32 %two = spv.Constant 2: i32 %var = spv.Variable init(%zero) : !spv.ptr // CHECK-NEXT: spv.mlir.selection control(Flatten) -// CHECK-NEXT: spv.Constant 0 -// CHECK-NEXT: spv.Variable spv.mlir.selection control(Flatten) { -// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^bb1, ^bb2 +// CHECK-NEXT: spv.BranchConditional %{{.*}} [5, 10], ^[[THEN:.+]], ^[[ELSE:.+]] spv.BranchConditional %cond [5, 10], ^then, ^else -// CHECK-NEXT: ^bb1: +// CHECK-NEXT: ^[[THEN]]: ^then: // CHECK-NEXT: spv.Constant 1 // CHECK-NEXT: spv.Store spv.Store "Function" %var, %one : i32 -// CHECK-NEXT: spv.Branch ^bb3 +// CHECK-NEXT: spv.Branch ^[[MERGE:.+]] spv.Branch ^merge -// CHECK-NEXT: ^bb2: +// CHECK-NEXT: ^[[ELSE]]: ^else: // CHECK-NEXT: spv.Constant 2 // CHECK-NEXT: spv.Store spv.Store "Function" %var, %two : i32 -// CHECK-NEXT: spv.Branch ^bb3 +// CHECK-NEXT: spv.Branch ^[[MERGE]] spv.Branch ^merge -// CHECK-NEXT: ^bb3: +// CHECK-NEXT: ^[[MERGE]]: ^merge: // CHECK-NEXT: spv.mlir.merge spv.mlir.merge @@ -56,21 +57,22 @@ // Selection in function entry block spv.module Logical GLSL450 requires #spv.vce { -// CHECK: spv.func @selection(%[[ARG:.*]]: i1 +// CHECK-LABEL: spv.func @selection +// CHECK-SAME: (%[[ARG:.*]]: i1) spv.func @selection(%cond: i1) -> (i32) "None" { -// CHECK: spv.Branch ^bb1 -// CHECK-NEXT: ^bb1: +// CHECK: spv.Branch ^[[BB:.+]] +// CHECK-NEXT: ^[[BB]]: // CHECK-NEXT: spv.mlir.selection spv.mlir.selection { -// CHECK-NEXT: spv.BranchConditional %[[ARG]], ^bb1, ^bb2 +// CHECK-NEXT: spv.BranchConditional %[[ARG]], ^[[THEN:.+]], ^[[ELSE:.+]] spv.BranchConditional %cond, ^then, ^merge -// CHECK: ^bb1: +// CHECK: ^[[THEN]]: ^then: %zero = spv.Constant 0 : i32 spv.ReturnValue %zero : i32 -// CHECK: ^bb2: +// CHECK: ^[[ELSE]]: ^merge: // CHECK-NEXT: spv.mlir.merge spv.mlir.merge @@ -87,3 +89,62 @@ spv.ExecutionMode
@main "LocalSize", 1, 1, 1 } +// ----- + +// Selection with control flow afterwards +// SSA value defined before the selection and used after it + +spv.module Logical GLSL450 requires #spv.vce { +// CHECK-LABEL: @selection_cf() + spv.func @selection_cf() -> () "None" { + %true = spv.Constant true + %false = spv.Constant false + %zero = spv.Constant 0 : i32 + %one = spv.Constant 1 : i32 +// CHECK-NEXT: %[[VAR:.+]] = spv.Variable + %var = spv.Variable : !spv.ptr +// CHECK-NEXT: spv.Branch ^[[BB:.+]] +// CHECK-NEXT: ^[[BB]]: + +// CHECK-NEXT: spv.mlir.selection { + spv.mlir.selection { +// CHECK: spv.BranchConditional %{{.+}}, ^[[THEN0:.+]], ^[[ELSE0:.+]] + spv.BranchConditional %true, ^then0, ^else0 + +// CHECK-NEXT: ^[[THEN0]]: +// CHECK: spv.Store "Function" %[[VAR]] +// CHECK-NEXT: spv.Branch ^[[MERGE:.+]] + ^then0: + spv.Store "Function" %var, %true : i1 + spv.Branch ^merge + +// CHECK-NEXT: ^[[ELSE0]]: +// CHECK: spv.Store "Function" %[[VAR]] +// CHECK-NEXT: spv.Branch ^[[MERGE]] + ^else0: + spv.Store "Function" %var, %false : i1 + spv.Branch ^merge + +// CHECK-NEXT: ^[[MERGE]]: +// CHECK-NEXT: spv.mlir.merge + ^merge: + spv.mlir.merge +// CHECK-NEXT: } + } + +// CHECK-NEXT: %[[COND:.+]] = spv.Load "Function" %[[VAR]] + %cond = spv.Load "Function" %var : i1 +// CHECK: spv.BranchConditional %[[COND]], ^[[THEN1:.+]](%{{.+}} : i32), ^[[ELSE1:.+]](%{{.+}}, %{{.+}} : i32, i32) + spv.BranchConditional %cond, ^then1(%one: i32), ^else1(%zero, %zero: i32, i32) + +// CHECK-NEXT: ^[[THEN1]](%{{.+}}: i32): +// CHECK-NEXT: spv.Return + ^then1(%arg0: i32): + spv.Return + +// CHECK-NEXT: ^[[ELSE1]](%{{.+}}: i32, %{{.+}}: i32): +// CHECK-NEXT: spv.Return + ^else1(%arg1: i32, %arg2: i32): + spv.Return + } +}