diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -718,6 +718,98 @@
   }
 };
 
+/// Converts `spv.loop` to LLVM dialect. All blocks within the loop should be
+/// reachable for conversion to succeed. Only `None` loop control is handled.
+/// The structure of the loop in LLVM dialect will be the following:
+///
+///      +------------------------------------+
+///      | <code before spv.loop>             |
+///      | llvm.br ^header                    |
+///      +------------------------------------+
+///                           |
+///   +----------------+      |
+///   |                |      |
+///   |                V      V
+///   |  +------------------------------------+
+///   |  | ^header:                           |
+///   |  |   <header code>                    |
+///   |  |   llvm.cond_br %cond, ^body, ^exit |
+///   |  +------------------------------------+
+///   |                    |
+///   |                    |----------------------+
+///   |                    |                      |
+///   |                    V                      |
+///   |  +------------------------------------+   |
+///   |  | ^body:                             |   |
+///   |  |   <body code>                      |   |
+///   |  |   llvm.br ^continue                |   |
+///   |  +------------------------------------+   |
+///   |                    |                      |
+///   |                    V                      |
+///   |  +------------------------------------+   |
+///   |  | ^continue:                         |   |
+///   |  |   <continue code>                  |   |
+///   |  |   llvm.br ^header                  |   |
+///   |  +------------------------------------+   |
+///   |               |                           |
+///   +---------------+    +----------------------+
+///                        |
+///                        V
+///      +------------------------------------+
+///      | ^exit:                             |
+///      |   llvm.br ^remaining               |
+///      +------------------------------------+
+///                        |
+///                        V
+///      +------------------------------------+
+///      | ^remaining:                        |
+///      |   <code after spv.loop>            |
+///      +------------------------------------+
+///
+class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // There is no support of loop control at the moment.
+    if (loopOp.loop_control() != spirv::LoopControl::None)
+      return failure();
+
+    Location loc = loopOp.getLoc();
+
+    // Match the entry block first so that a failure leaves the IR untouched.
+    Block *entryBlock = loopOp.getEntryBlock();
+    assert(entryBlock->getOperations().size() == 1);
+    auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
+    if (!brOp)
+      return failure();
+
+    // Split at `spv.loop`: the loop and all later ops move into `endBlock`.
+    Block *currentBlock = rewriter.getBlock();
+    auto position = Block::iterator(loopOp);
+    Block *endBlock = rewriter.splitBlock(currentBlock, position);
+
+    // Branch from the current block to the header block; drop the entry block.
+    Block *headerBlock = loopOp.getHeaderBlock();
+    rewriter.setInsertionPointToEnd(currentBlock);
+    rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
+    rewriter.eraseBlock(entryBlock);
+
+    // Branch from merge block to end block.
+    Block *mergeBlock = loopOp.getMergeBlock();
+    Operation *terminator = mergeBlock->getTerminator();
+    ValueRange terminatorOperands = terminator->getOperands();
+    rewriter.setInsertionPointToEnd(mergeBlock);
+    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
+
+    rewriter.inlineRegionBefore(loopOp.body(), endBlock);
+    rewriter.replaceOp(loopOp, endBlock->getArguments());
+    return success();
+  }
+};
+
 class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
 public:
   using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
@@ -1109,7 +1201,7 @@
       ConstantScalarAndVectorPattern,
 
       // Control Flow ops
-      BranchConversionPattern, BranchConditionalConversionPattern,
+      BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
       SelectionPattern, MergePattern,
 
       // Function Call op
diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -81,6 +81,45 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.loop
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  // CHECK-LABEL: @infinite_loop
+  spv.func @infinite_loop(%count : i32) -> () "None" {
+    // CHECK:   llvm.br ^[[BB1:.*]]
+    // CHECK: ^[[BB1]]:
+    // CHECK:   %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
+    // CHECK:   llvm.cond_br %[[COND]], ^[[BB2:.*]], ^[[BB4:.*]]
+    // CHECK: ^[[BB2]]:
+    // CHECK:   llvm.br ^[[BB3:.*]]
+    // CHECK: ^[[BB3]]:
+    // CHECK:   llvm.br ^[[BB1]]
+    // CHECK: ^[[BB4]]:
+    // CHECK:   llvm.br ^[[BB5:.*]]
+    // CHECK: ^[[BB5]]:
+    // CHECK:   llvm.return
+    spv.loop {
+      spv.Branch ^header
+    ^header:
+      %cond = spv.constant true
+      spv.BranchConditional %cond, ^body, ^merge
+    ^body:
+      // Do nothing
+      spv.Branch ^continue
+    ^continue:
+      // Do nothing
+      spv.Branch ^header
+    ^merge:
+      spv._merge
+    }
+    spv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//