diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -107,6 +107,16 @@
   matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
+
+/// Pattern to convert a scf::WhileOp into a spirv::LoopOp.
+class WhileOpConversion final : public SCFToSPIRVPattern<scf::WhileOp> {
+public:
+  using SCFToSPIRVPattern<scf::WhileOp>::SCFToSPIRVPattern;
+
+  LogicalResult
+  matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 /// Helper function to replaces SCF op outputs with SPIR-V variable loads.
@@ -141,6 +151,11 @@
   rewriter.replaceOp(scfOp, resultValue);
 }
 
+/// Returns an iterator to the block at position \p index in \p region.
+static Region::iterator getBlockIt(Region &region, unsigned index) {
+  return std::next(region.begin(), index);
+}
+
 //===----------------------------------------------------------------------===//
 // scf::ForOp
 //===----------------------------------------------------------------------===//
@@ -161,7 +176,7 @@
   // Create the block for the header.
   auto *header = new Block();
   // Insert the header.
-  loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
+  loopOp.body().getBlocks().insert(getBlockIt(loopOp.body(), 1), header);
 
   // Create the new induction variable to use.
   BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType());
@@ -183,7 +198,7 @@
   // Move the blocks from the forOp into the loopOp. This is the body of the
   // loopOp.
   rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(),
-                              std::next(loopOp.body().begin(), 2));
+                              getBlockIt(loopOp.body(), 2));
 
   SmallVector<Value, 8> args(1, adaptor.lowerBound());
   args.append(adaptor.initArgs().begin(), adaptor.initArgs().end());
@@ -293,9 +308,11 @@
   // If the region is return values, store each value into the associated
   // VariableOp created during lowering of the parent region.
   if (!operands.empty()) {
-    auto loc = terminatorOp.getLoc();
     auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()];
-    assert(allocas.size() == operands.size());
+    if (allocas.size() != operands.size())
+      return failure();
+
+    auto loc = terminatorOp.getLoc();
     for (unsigned i = 0, e = operands.size(); i < e; i++)
       rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
     if (isa<spirv::LoopOp>(terminatorOp->getParentOp())) {
@@ -314,6 +331,97 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// scf::WhileOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor,
+                                   ConversionPatternRewriter &rewriter) const {
+  auto loc = whileOp.getLoc();
+  auto loopOp = rewriter.create<spirv::LoopOp>(loc, spirv::LoopControl::None);
+  loopOp.addEntryAndMergeBlock();
+
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  Region &beforeRegion = whileOp.before();
+  Region &afterRegion = whileOp.after();
+
+  Block &entryBlock = *loopOp.getEntryBlock();
+  Block &beforeBlock = beforeRegion.front();
+  Block &afterBlock = afterRegion.front();
+  Block &mergeBlock = *loopOp.getMergeBlock();
+
+  auto cond = cast<scf::ConditionOp>(beforeBlock.getTerminator());
+  SmallVector<Value> condArgs;
+  if (failed(rewriter.getRemappedValues(cond.args(), condArgs)))
+    return failure();
+
+  Value conditionVal = rewriter.getRemappedValue(cond.condition());
+  if (!conditionVal)
+    return failure();
+
+  auto yield = cast<scf::YieldOp>(afterBlock.getTerminator());
+  SmallVector<Value> yieldArgs;
+  if (failed(rewriter.getRemappedValues(yield.results(), yieldArgs)))
+    return failure();
+
+  // Move the while before block as the initial loop header block.
+  rewriter.inlineRegionBefore(beforeRegion, loopOp.body(),
+                              getBlockIt(loopOp.body(), 1));
+
+  // Move the while after block as the initial loop body block.
+  rewriter.inlineRegionBefore(afterRegion, loopOp.body(),
+                              getBlockIt(loopOp.body(), 2));
+
+  // Jump from the loop entry block to the loop header block.
+  rewriter.setInsertionPointToEnd(&entryBlock);
+  rewriter.create<spirv::BranchOp>(loc, &beforeBlock, adaptor.inits());
+
+  auto condLoc = cond.getLoc();
+
+  SmallVector<Value> resultValues(condArgs.size());
+
+  // For other SCF ops, the scf.yield op yields the value for the whole SCF op.
+  // So we use the scf.yield op as the anchor to create/load/store SPIR-V local
+  // variables. But for the scf.while op, the scf.yield op yields a value for
+  // the before region, which may not match the whole op's results. Instead,
+  // the scf.condition op returns values matching the whole op's results. So we
+  // need to create/load/store variables according to that.
+  for (auto it : llvm::enumerate(condArgs)) {
+    auto res = it.value();
+    auto i = it.index();
+    auto pointerType =
+        spirv::PointerType::get(res.getType(), spirv::StorageClass::Function);
+
+    // Create local variables before the scf.while op.
+    rewriter.setInsertionPoint(loopOp);
+    auto alloc = rewriter.create<spirv::VariableOp>(
+        condLoc, pointerType, spirv::StorageClass::Function,
+        /*initializer=*/nullptr);
+
+    // Load the final result values after the scf.while op.
+    rewriter.setInsertionPointAfter(loopOp);
+    auto loadResult = rewriter.create<spirv::LoadOp>(condLoc, alloc);
+    resultValues[i] = loadResult;
+
+    // Store the current iteration's result value.
+    rewriter.setInsertionPointToEnd(&beforeBlock);
+    rewriter.create<spirv::StoreOp>(condLoc, alloc, res);
+  }
+
+  rewriter.setInsertionPointToEnd(&beforeBlock);
+  rewriter.replaceOpWithNewOp<spirv::BranchConditionalOp>(
+      cond, conditionVal, &afterBlock, condArgs, &mergeBlock, llvm::None);
+
+  // Convert the scf.yield op to a branch back to the header block.
+  rewriter.setInsertionPointToEnd(&afterBlock);
+  rewriter.replaceOpWithNewOp<spirv::BranchOp>(yield, &beforeBlock, yieldArgs);
+
+  rewriter.replaceOp(whileOp, resultValues);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Hooks
 //===----------------------------------------------------------------------===//
@@ -321,6 +429,7 @@
 void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                       ScfToSPIRVContext &scfToSPIRVContext,
                                       RewritePatternSet &patterns) {
-  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
-      patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
+  patterns.add<ForOpConversion, IfOpConversion, TerminatorOpConversion,
+               WhileOpConversion>(patterns.getContext(), typeConverter,
+                                  scfToSPIRVContext.getImpl());
 }
diff --git a/mlir/test/Conversion/SCFToSPIRV/while.mlir b/mlir/test/Conversion/SCFToSPIRV/while.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SCFToSPIRV/while.mlir
@@ -0,0 +1,70 @@
+// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-spirv %s -o - | FileCheck %s
+
+module attributes {
+  spv.target_env = #spv.target_env<
+    #spv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, {}>
+} {
+
+// CHECK-LABEL: @while_loop1
+func @while_loop1(%arg0: i32, %arg1: i32) -> i32 {
+  // CHECK-SAME: (%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32)
+  // CHECK: %[[INITVAR:.*]] = spv.Constant 2 : i32
+  // CHECK: %[[VAR1:.*]] = spv.Variable : !spv.ptr<i32, Function>
+  // CHECK: spv.mlir.loop {
+  // CHECK: spv.Branch ^[[HEADER:.*]](%[[ARG1]] : i32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: i32):
+  // CHECK: %[[CMP:.*]] = spv.SLessThan %[[INDVAR1]], %[[ARG2]] : i32
+  // CHECK: spv.Store "Function" %[[VAR1]], %[[INDVAR1]] : i32
+  // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[INDVAR1]] : i32), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i32):
+  // CHECK: %[[UPDATED:.*]] = spv.IMul %[[INDVAR2]], %[[INITVAR]] : i32
+  // CHECK: spv.Branch ^[[HEADER]](%[[UPDATED]] : i32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK: spv.mlir.merge
+  // CHECK: }
+  %c2_i32 = arith.constant 2 : i32
+  %0 = scf.while (%arg3 = %arg0) : (i32) -> (i32) {
+    %1 = arith.cmpi slt, %arg3, %arg1 : i32
+    scf.condition(%1) %arg3 : i32
+  } do {
+  ^bb0(%arg5: i32):
+    %1 = arith.muli %arg5, %c2_i32 : i32
+    scf.yield %1 : i32
+  }
+  // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR1]] : i32
+  // CHECK: spv.ReturnValue %[[OUT]] : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @while_loop2
+func @while_loop2(%arg0: f32) -> i64 {
+  // CHECK-SAME: (%[[ARG:.*]]: f32)
+  // CHECK: %[[VAR:.*]] = spv.Variable : !spv.ptr<i64, Function>
+  // CHECK: spv.mlir.loop {
+  // CHECK: spv.Branch ^[[HEADER:.*]](%[[ARG]] : f32)
+  // CHECK: ^[[HEADER]](%[[INDVAR1:.*]]: f32):
+  // CHECK: %[[SHARED:.*]] = "foo.shared_compute"(%[[INDVAR1]]) : (f32) -> i64
+  // CHECK: %[[CMP:.*]] = "foo.evaluate_condition"(%[[INDVAR1]], %[[SHARED]]) : (f32, i64) -> i1
+  // CHECK: spv.Store "Function" %[[VAR]], %[[SHARED]] : i64
+  // CHECK: spv.BranchConditional %[[CMP]], ^[[BODY:.*]](%[[SHARED]] : i64), ^[[MERGE:.*]]
+  // CHECK: ^[[BODY]](%[[INDVAR2:.*]]: i64):
+  // CHECK: %[[UPDATED:.*]] = "foo.payload"(%[[INDVAR2]]) : (i64) -> f32
+  // CHECK: spv.Branch ^[[HEADER]](%[[UPDATED]] : f32)
+  // CHECK: ^[[MERGE]]:
+  // CHECK: spv.mlir.merge
+  // CHECK: }
+  %res = scf.while (%arg1 = %arg0) : (f32) -> i64 {
+    %shared = "foo.shared_compute"(%arg1) : (f32) -> i64
+    %condition = "foo.evaluate_condition"(%arg1, %shared) : (f32, i64) -> i1
+    scf.condition(%condition) %shared : i64
+  } do {
+  ^bb0(%arg2: i64):
+    %res = "foo.payload"(%arg2) : (i64) -> f32
+    scf.yield %res : f32
+  }
+  // CHECK: %[[OUT:.*]] = spv.Load "Function" %[[VAR]] : i64
+  // CHECK: spv.ReturnValue %[[OUT]] : i64
+  return %res : i64
+}
+
+} // end module