diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -613,6 +613,67 @@
   }
 };
 
+/// Converts `spv.loop` to LLVM dialect. All blocks within the loop should be
+/// reachable for conversion to succeed.
+class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // There is no support of loop control at the moment.
+    if (loopOp.loop_control() != spirv::LoopControl::None)
+      return failure();
+
+    // Guard against `spv.loop` containing an empty region. In this case, the op
+    // is redundant and it is erased.
+    if (loopOp.body().empty()) {
+      rewriter.eraseOp(loopOp);
+      return success();
+    }
+
+    // The entry block is expected to hold a single `spv.Branch` to the header.
+    // Validate it before mutating anything, so that a failed match does not
+    // leave a partially rewritten IR behind.
+    Block *entryBlock = loopOp.getEntryBlock();
+    assert(entryBlock->getOperations().size() == 1);
+    auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
+    if (!brOp)
+      return failure();
+
+    Location loc = loopOp.getLoc();
+
+    // Split the current block after `spv.loop`. The remaining ops will be used
+    // in `endBlock`.
+    Block *currentBlock = rewriter.getInsertionBlock();
+    rewriter.setInsertionPointAfter(loopOp);
+    auto position = rewriter.getInsertionPoint();
+    Block *endBlock = rewriter.splitBlock(currentBlock, position);
+
+    // Remove the entry block and create a branch in the current block going to
+    // the header block, forwarding the operands of the entry `spv.Branch`.
+    Block *headerBlock = loopOp.getHeaderBlock();
+    rewriter.setInsertionPointToEnd(currentBlock);
+    rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
+    rewriter.eraseBlock(entryBlock);
+
+    // Branch from the merge block to the end block, forwarding the operands of
+    // the merge block's terminator.
+    Block *mergeBlock = loopOp.getMergeBlock();
+    Operation *terminator = mergeBlock->getTerminator();
+    ValueRange terminatorOperands = terminator->getOperands();
+    rewriter.setInsertionPointToEnd(mergeBlock);
+    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
+
+    // Move the loop blocks in front of `endBlock` and replace the loop's
+    // results with the arguments of `endBlock`.
+    rewriter.inlineRegionBefore(loopOp.body(), endBlock);
+    rewriter.replaceOp(loopOp, endBlock->getArguments());
+    return success();
+  }
+};
+
 class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
 public:
   using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
@@ -920,7 +981,7 @@
       ConstantScalarAndVectorPattern,
 
       // Control Flow ops
-      BranchConversionPattern, BranchConditionalConversionPattern,
+      BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
       SelectionPattern, MergePattern,
 
       // Function Call op
diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -83,6 +83,51 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.loop
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.func @loop_empty() -> () "None" {
+    // CHECK: llvm.return
+    spv.loop {
+    }
+    spv.Return
+  }
+
+  spv.func @infinite_loop(%count : i32) -> () "None" {
+    // CHECK: llvm.br ^bb1
+    spv.loop {
+      spv.Branch ^header
+    // CHECK: ^bb1:
+    ^header:
+      // CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
+      // CHECK-NEXT: llvm.cond_br %[[COND]], ^bb2, ^bb4
+      %cond = spv.constant true
+      spv.BranchConditional %cond, ^body, ^merge
+    // CHECK: ^bb2:
+    ^body:
+      // Do nothing
+      // CHECK: llvm.br ^bb3
+      spv.Branch ^continue
+    // CHECK: ^bb3:
+    ^continue:
+      // Do nothing
+      // CHECK: llvm.br ^bb1
+      spv.Branch ^header
+    // CHECK: ^bb4:
+    ^merge:
+      // CHECK: llvm.br ^bb5
+      spv._merge
+    }
+    // CHECK: ^bb5:
+    // CHECK-NEXT: llvm.return
+    spv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//