diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -107,6 +107,16 @@
   matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+/// Pattern to convert a scf::WhileOp into spirv::LoopOp.
+class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
+public:
+  using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
+
+  LogicalResult
+  matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
@@ -141,6 +151,11 @@
   rewriter.replaceOp(scfOp, resultValue);
 }
 
+/// Returns an iterator to the `index`-th (0-based) block of `region`.
+static Region::iterator getBlockIt(Region &region, unsigned index) {
+  return std::next(region.begin(), index);
+}
+
 //===----------------------------------------------------------------------===//
 // scf::ForOp
 //===----------------------------------------------------------------------===//
@@ -161,7 +176,7 @@
   // Create the block for the header.
   auto *header = new Block();
   // Insert the header.
-  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+  loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
 
   // Create the new induction variable to use.
   BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType());
@@ -183,7 +198,7 @@
   // Move the blocks from the forOp into the loopOp. This is the body of the
   // loopOp.
   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
-                              std::next(loopOp.body().begin(), 2));
+                              getBlockIt(loopOp.body(), 2));
 
   SmallVector<Value, 8> args(1, adaptor.lowerBound());
   args.append(adaptor.initArgs().begin(), adaptor.initArgs().end());
@@ -293,9 +308,13 @@
   // If the region is return values, store each value into the associated
   // VariableOp created during lowering of the parent region.
   if (!operands.empty()) {
-    auto loc = terminatorOp.getLoc();
     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
-    assert(allocas.size() == operands.size());
+    // Fail the match instead of asserting when no output variables matching
+    // the yielded values were recorded for the parent op.
+    if (allocas.size() != operands.size())
+      return failure();
+
+    auto loc = terminatorOp.getLoc();
     for (unsigned i = 0, e = operands.size(); i < e; i++)
       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
@@ -315,12 +334,97 @@
 }
 
 //===----------------------------------------------------------------------===//
+// scf::WhileOp
+//===----------------------------------------------------------------------===//
+
+/// Lowers an scf.while into a spv.mlir.loop whose blocks are, in order: the
+/// entry block, the "before" region (loop header), the "after" region (loop
+/// body) and the merge block. The values forwarded by scf.condition become
+/// the scf.while results: each is stored into a Function-storage variable in
+/// the header and loaded back after the loop.
+LogicalResult
+WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+  auto loc = whileOp.getLoc();
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  Region &beforeReg = whileOp.before();
+  Region &afterReg = whileOp.after();
+
+  Block &beforeBlock = beforeReg.front();
+  Block &afterBlock = afterReg.front();
+  Block &mergeBlock = *loopOp.getMergeBlock();
+
+  // Move the "before" block right after the entry block (loop header) and the
+  // "after" block right after it (loop body), ahead of the merge block.
+  rewriter.inlineRegionBefore(beforeReg, loopOp.body(),
+                              getBlockIt(loopOp.body(), 1));
+
+  rewriter.inlineRegionBefore(afterReg, loopOp.body(),
+                              getBlockIt(loopOp.body(), 2));
+
+  // Enter the loop: branch from the entry block to the header, passing the
+  // scf.while init values.
+  rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
+  rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.inits());
+
+  auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
+
+  // NOTE(review): the values forwarded by scf.condition / scf.yield and the
+  // block arguments of the inlined regions are used without type conversion;
+  // confirm this holds for types the SPIR-V type converter rewrites (e.g.
+  // index).
+  auto condLoc = cond.getLoc();
+  SmallVector<Value> resultValues;
+  resultValues.reserve(cond.args().size());
+  for (Value res : cond.args()) {
+    // The results of scf.while are the operands of scf.condition rather than
+    // of scf.yield, so they are stored in the header on every iteration and
+    // loaded once after the loop exits.
+    auto pointerType =
+        spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
+    rewriter.setInsertionPoint(loopOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        condLoc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+
+    rewriter.setInsertionPointAfter(loopOp);
+    auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+    resultValues.emplace_back(loadResult);
+
+    // Insert ahead of the terminator so no op ever follows it in the block.
+    rewriter.setInsertionPoint(cond);
+    rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
+  }
+
+  // Header exit: continue into the body with the forwarded values, or leave
+  // the loop through the merge block.
+  rewriter.setInsertionPoint(cond);
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      cond, cond.condition(), &afterBlock, cond.args(), &mergeBlock,
+      llvm::None);
+
+  // Back edge: the body's scf.yield becomes a branch back to the header.
+  auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
+  rewriter.setInsertionPoint(yield);
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock,
+                                               yield.results());
+
+  rewriter.replaceOp(whileOp, resultValues);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
 // Hooks
 //===----------------------------------------------------------------------===//
 
 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       ScfToSPIRVContext &scfToSPIRVContext,
                                       RewritePatternSet &patterns) {
-  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
-      patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
+  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
+               WhileOpConversion>(patterns.getContext(), typeConverter,
+                                  scfToSPIRVContext.getImpl());
 }
diff --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -convert-scf-to-spirv %s -o - | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
+} {
+
+// CHECK-LABEL: @while_loop
+func @while_loop(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+  // CHECK: %[[INITVAR:.*]] = spv.Constant 2 : i32
+  // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<i32, Function>
+  // CHECK: spv.mlir.loop {
+  // CHECK:   spv.Branch ^[[HEADER:.*]](%[[ARG1]] : i32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
+  // CHECK:   %[[CMP:.*]] = spv.SLessThan %[[INDVAR1]], %[[ARG2]] : i32
+  // CHECK:   spv.Store "Function" %[[VAR1]], %[[INDVAR1]] : i32
+  // CHECK:   spv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[INDVAR1]] : i32), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
+  // CHECK:   %[[UPDATED:.*]] = spv.IMul %[[INDVAR2]], %[[INITVAR]] : i32
+  // CHECK:   spv.Branch ^[[HEADER]](%[[UPDATED]] : i32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK:   spv.mlir.merge
+  // CHECK: }
+  %c2_i32 = arith.constant 2 : i32
+  %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i32
+    scf.condition(%1) %arg3 : i32
+  } do {
+  ^bb0(%arg5: i32):
+    %1 = arith.muli %arg5, %c2_i32 : i32
+    scf.yield %1 : i32
+  }
+  // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR1]] : i32
+  // CHECK: spv.ReturnValue %[[OUT]] : i32
+  return %0 : i32
+}
+
+} // end module