diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -613,6 +613,84 @@
   }
 };
 
+/// Erases `spv._merge`; SelectionPattern emits the merge block's `llvm.br`.
+class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::MergeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+/// Converts `spv.selection` with `spv.BranchConditional` in its header block.
+/// All blocks within selection should be reachable for conversion to succeed.
+class SelectionPattern : public SPIRVToLLVMConversion<spirv::SelectionOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::SelectionOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::SelectionOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // There is no support for `Flatten` or `DontFlatten` selection control at
+    // the moment. These are just compiler hints and can be applied during the
+    // optimization passes.
+    if (op.selection_control() != spirv::SelectionControl::None)
+      return failure();
+
+    // `spv.selection` should have at least two blocks: one selection header
+    // block and one merge block. If no blocks are present, or control flow
+    // branches straight to merge block (two blocks are present), the op is
+    // redundant and it is erased.
+    if (op.body().getBlocks().size() <= 2) {
+      rewriter.eraseOp(op);
+      return success();
+    }
+
+    // Match the header block before mutating the IR, so that a selection not
+    // headed by `spv.BranchConditional` (e.g. `spv.Switch`) fails cleanly.
+    auto *headerBlock = op.getHeaderBlock();
+    assert(headerBlock->getOperations().size() == 1);
+    auto condBrOp = dyn_cast<spirv::BranchConditionalOp>(
+        headerBlock->getOperations().front());
+    if (!condBrOp)
+      return failure();
+
+    Location loc = op.getLoc();
+
+    // Split the current block after `spv.selection`. The remaining ops will be
+    // used in `continueBlock`.
+    auto *currentBlock = rewriter.getInsertionBlock();
+    rewriter.setInsertionPointAfter(op);
+    auto position = rewriter.getInsertionPoint();
+    auto *continueBlock = rewriter.splitBlock(currentBlock, position);
+
+    // Branch from merge block to continue block.
+    auto *mergeBlock = op.getMergeBlock();
+    Operation *terminator = mergeBlock->getTerminator();
+    ValueRange terminatorOperands = terminator->getOperands();
+    rewriter.setInsertionPointToEnd(mergeBlock);
+    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, continueBlock);
+
+    // Link current block to `true` and `false` blocks within the selection.
+    // This reads `condBrOp`, so it is done before the header block is erased.
+    rewriter.setInsertionPointToEnd(currentBlock);
+    rewriter.create<LLVM::CondBrOp>(loc, condBrOp.condition(),
+                                    condBrOp.getTrueBlock(),
+                                    condBrOp.trueTargetOperands(),
+                                    condBrOp.getFalseBlock(),
+                                    condBrOp.falseTargetOperands());
+    rewriter.eraseBlock(headerBlock);
+
+    rewriter.inlineRegionBefore(op.body(), continueBlock);
+    rewriter.replaceOp(op, continueBlock->getArguments());
+    return success();
+  }
+};
+
 /// Converts SPIR-V shift ops to LLVM shift ops. Since LLVM dialect
 /// puts a restriction on `Shift` and `Base` to have the same bit width,
 /// `Shift` is zero or sign extended to match this specification.
Cases when
@@ -843,6 +921,7 @@
 
       // Control Flow ops
       BranchConversionPattern, BranchConditionalConversionPattern,
+      SelectionPattern, MergePattern,
 
       // Function Call op
       FunctionCallPattern,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -80,3 +80,93 @@
     spv.Return
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.selection
+//===----------------------------------------------------------------------===//
+
+func @selection_empty() {
+  // CHECK: llvm.return
+  spv.selection {
+  }
+  return
+}
+
+func @selection_with_merge_block_only() {
+  %cond = spv.constant true
+  // CHECK: llvm.return
+  spv.selection {
+    spv.BranchConditional %cond, ^merge, ^merge
+  ^merge:
+    spv._merge
+  }
+  return
+}
+
+func @selection_with_true_block_only() {
+  // CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
+  %cond = spv.constant true
+  // CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
+  spv.selection {
+    spv.BranchConditional %cond, ^true, ^merge
+  // CHECK: ^bb1:
+  ^true:
+    // CHECK: llvm.br ^bb2
+    spv.Branch ^merge
+  // CHECK: ^bb2:
+  ^merge:
+    // CHECK: llvm.br ^bb3
+    spv._merge
+  }
+  // CHECK: ^bb3:
+  // CHECK-NEXT: llvm.return
+  return
+}
+
+func @selection_with_both_true_and_false_block() {
+  // CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
+  %cond = spv.constant true
+  // CHECK: llvm.cond_br %[[COND]], ^bb1, ^bb2
+  spv.selection {
+    spv.BranchConditional %cond, ^true, ^false
+  // CHECK: ^bb1:
+  ^true:
+    // CHECK: llvm.br ^bb3
+    spv.Branch ^merge
+  // CHECK: ^bb2:
+  ^false:
+    // CHECK: llvm.br ^bb3
+    spv.Branch ^merge
+  // CHECK: ^bb3:
+  ^merge:
+    // CHECK: llvm.br ^bb4
+    spv._merge
+  }
+  // CHECK: ^bb4:
+  // CHECK-NEXT: llvm.return
+  return
+}
+
+func @selection_with_early_return(%arg0: i1) -> i32 {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+  %0 = spv.constant 0 : i32
+  // CHECK: llvm.cond_br %{{.*}}, ^bb1(%[[ZERO]] : !llvm.i32), ^bb2
+  spv.selection {
+    spv.BranchConditional %arg0, ^true(%0 : i32), ^merge
+  // CHECK: ^bb1(%[[ARG:.*]]: !llvm.i32):
+  ^true(%arg1: i32):
+    // CHECK: llvm.return %[[ARG]] : !llvm.i32
+    spv.ReturnValue %arg1 : i32
+  // CHECK: ^bb2:
+  ^merge:
+    // CHECK: llvm.br ^bb3
+    spv._merge
+  }
+  // CHECK: ^bb3:
+  // CHECK-NEXT: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+  %one = spv.constant 1 : i32
+  // CHECK-NEXT: llvm.return %[[ONE]] : !llvm.i32
+  spv.ReturnValue %one : i32
+}