diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -718,6 +718,54 @@
   }
 };
 
+/// Converts `spv.loop` to LLVM dialect. All blocks within the loop should be
+/// reachable for conversion to succeed. Loops with a loop control other than
+/// `None` are not supported and make the pattern fail. The loop region is
+/// inlined in front of the block that receives the ops following `spv.loop`.
+class LoopPattern : public SPIRVToLLVMConversion<spirv::LoopOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::LoopOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::LoopOp loopOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // There is no support of loop control at the moment.
+    if (loopOp.loop_control() != spirv::LoopControl::None)
+      return failure();
+
+    Location loc = loopOp.getLoc();
+
+    // Split the current block after `spv.loop`. The remaining ops will be used
+    // in `endBlock`.
+    Block *currentBlock = rewriter.getBlock();
+    auto position = Block::iterator(loopOp);
+    Block *endBlock = rewriter.splitBlock(currentBlock, position);
+
+    // Remove entry block and create a branch in the current block going to the
+    // header block.
+    Block *entryBlock = loopOp.getEntryBlock();
+    assert(entryBlock->getOperations().size() == 1);
+    auto brOp = dyn_cast<spirv::BranchOp>(entryBlock->getOperations().front());
+    if (!brOp)
+      return failure();
+    Block *headerBlock = loopOp.getHeaderBlock();
+    rewriter.setInsertionPointToEnd(currentBlock);
+    rewriter.create<LLVM::BrOp>(loc, brOp.getBlockArguments(), headerBlock);
+    rewriter.eraseBlock(entryBlock);
+
+    // Branch from merge block to end block.
+    Block *mergeBlock = loopOp.getMergeBlock();
+    Operation *terminator = mergeBlock->getTerminator();
+    ValueRange terminatorOperands = terminator->getOperands();
+    rewriter.setInsertionPointToEnd(mergeBlock);
+    rewriter.create<LLVM::BrOp>(loc, terminatorOperands, endBlock);
+
+    rewriter.inlineRegionBefore(loopOp.body(), endBlock);
+    rewriter.replaceOp(loopOp, endBlock->getArguments());
+    return success();
+  }
+};
+
 class MergePattern : public SPIRVToLLVMConversion<spirv::MergeOp> {
 public:
   using SPIRVToLLVMConversion<spirv::MergeOp>::SPIRVToLLVMConversion;
@@ -1109,7 +1157,7 @@
       ConstantScalarAndVectorPattern,
 
       // Control Flow ops
-      BranchConversionPattern, BranchConditionalConversionPattern,
+      BranchConversionPattern, BranchConditionalConversionPattern, LoopPattern,
       SelectionPattern, MergePattern,
 
       // Function Call op
diff --git a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/control-flow-ops-to-llvm.mlir
@@ -81,6 +81,44 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spv.loop
+//===----------------------------------------------------------------------===//
+
+spv.module Logical GLSL450 {
+  spv.func @infinite_loop(%count : i32) -> () "None" {
+    // CHECK: llvm.br ^bb1
+    spv.loop {
+      spv.Branch ^header
+    // CHECK: ^bb1:
+    ^header:
+      // CHECK: %[[COND:.*]] = llvm.mlir.constant(true) : !llvm.i1
+      // CHECK-NEXT: llvm.cond_br %[[COND]], ^bb2, ^bb4
+      %cond = spv.constant true
+      spv.BranchConditional %cond, ^body, ^merge
+    // CHECK: ^bb2:
+    ^body:
+      // Do nothing
+      // CHECK: llvm.br ^bb3
+      spv.Branch ^continue
+    // CHECK: ^bb3:
+    ^continue:
+      // Do nothing
+      // CHECK: llvm.br ^bb1
+      spv.Branch ^header
+    // CHECK: ^bb4:
+    ^merge:
+      // CHECK: llvm.br ^bb5
+      spv._merge
+    }
+    // CHECK: ^bb5:
+    // CHECK-NEXT: llvm.return
+    spv.Return
+  }
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.selection
 //===----------------------------------------------------------------------===//