diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -812,13 +812,12 @@ def SCFToSPIRV : Pass<"convert-scf-to-spirv"> { let summary = "Convert SCF dialect to SPIR-V dialect."; let description = [{ - This pass converts SCF ops into SPIR-V structured control flow ops. - SPIR-V structured control flow ops does not support yielding values. + Converts SCF ops into SPIR-V structured control flow ops. + SPIR-V structured control flow ops do not support yielding values. So for SCF ops yielding values, SPIR-V variables are created for holding the values and load/store operations are emitted for updating them. }]; - let constructor = "mlir::createConvertSCFToSPIRVPass()"; let dependentDialects = ["spirv::SPIRVDialect"]; } diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -41,16 +41,57 @@ /// StoreOp cannot be created earlier as they may use a different type than /// yield operands. ScfToSPIRVContext::ScfToSPIRVContext() { - impl = std::make_unique(); + impl = std::make_unique<::ScfToSPIRVContextImpl>(); } ScfToSPIRVContext::~ScfToSPIRVContext() = default; +namespace { + //===----------------------------------------------------------------------===// -// Pattern Declarations +// Helper Functions +//===----------------------------------------------------------------------===// + +/// Replaces SCF op outputs with SPIR-V variable loads. +/// We create VariableOps to hold the result values of the control flow region. +/// spirv.mlir.loop/spirv.mlir.selection currently don't yield values. Right +/// after the loop we load the value from the allocation and use it as the SCF +/// op result.
+template +void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, + ConversionPatternRewriter &rewriter, + ScfToSPIRVContextImpl *scfToSPIRVContext, + ArrayRef returnTypes) { + + Location loc = scfOp.getLoc(); + auto &allocas = scfToSPIRVContext->outputVars[newOp]; + // Clearing the allocas is necessary in case a dialect conversion path failed + // previously, and this is the second attempt of this conversion. + allocas.clear(); + SmallVector resultValue; + for (Type convertedType : returnTypes) { + auto pointerType = + spirv::PointerType::get(convertedType, spirv::StorageClass::Function); + rewriter.setInsertionPoint(newOp); + auto alloc = rewriter.create( + loc, pointerType, spirv::StorageClass::Function, + /*initializer=*/nullptr); + allocas.push_back(alloc); + rewriter.setInsertionPointAfter(newOp); + Value loadResult = rewriter.create(loc, alloc); + resultValue.push_back(loadResult); + } + rewriter.replaceOp(scfOp, resultValue); +} + +Region::iterator getBlockIt(Region &region, unsigned index) { + return std::next(region.begin(), index); +} + +//===----------------------------------------------------------------------===// +// Conversion Patterns //===----------------------------------------------------------------------===// -namespace { -/// Common class for all vector to GPU patterns. +/// Common base class for all SCF to SPIR-V conversion patterns. template class SCFToSPIRVPattern : public OpConversionPattern { @@ -79,356 +120,306 @@ SPIRVTypeConverter &typeConverter; }; +//===----------------------------------------------------------------------===// +// scf::ForOp +//===----------------------------------------------------------------------===// + /// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SCFToSPIRVPattern { -public: - using SCFToSPIRVPattern::SCFToSPIRVPattern; +struct ForOpConversion final : SCFToSPIRVPattern { + using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; + ConversionPatternRewriter &rewriter) const override { + // scf::ForOp can be lowered to the structured control flow represented by + // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop + // latch and the merge block the exit block. The resulting spirv::LoopOp has + // a single back edge from the continue to header block, and a single exit + // from header to merge. + auto loc = forOp.getLoc(); + auto loopOp = rewriter.create(loc, spirv::LoopControl::None); + loopOp.addEntryAndMergeBlock(); + + OpBuilder::InsertionGuard guard(rewriter); + // Create the block for the header. + auto *header = new Block(); + // Insert the header. + loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), + header); + + // Create the new induction variable to use. + Value adapLowerBound = adaptor.getLowerBound(); + BlockArgument newIndVar = + header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc()); + for (Value arg : adaptor.getInitArgs()) + header->addArgument(arg.getType(), arg.getLoc()); + Block *body = forOp.getBody(); + + // Apply signature conversion to the body of the forOp. It has a single + // block, with argument which is the induction variable. That has to be + // replaced with the new induction variable. + TypeConverter::SignatureConversion signatureConverter( + body->getNumArguments()); + signatureConverter.remapInput(0, newIndVar); + for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) + signatureConverter.remapInput(i, header->getArgument(i)); + body = rewriter.applySignatureConversion(&forOp.getLoopBody(), + signatureConverter); + + // Move the blocks from the forOp into the loopOp. 
This is the body of the + // loopOp. + rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(), + getBlockIt(loopOp.getBody(), 2)); + + SmallVector args(1, adaptor.getLowerBound()); + args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); + // Branch into it from the entry. + rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); + rewriter.create(loc, header, args); + + // Generate the rest of the loop header. + rewriter.setInsertionPointToEnd(header); + auto *mergeBlock = loopOp.getMergeBlock(); + auto cmpOp = rewriter.create( + loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); + + rewriter.create( + loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); + + // Generate instructions to increment the step of the induction variable and + // branch to the header. + Block *continueBlock = loopOp.getContinueBlock(); + rewriter.setInsertionPointToEnd(continueBlock); + + // Add the step to the induction variable and branch to the header. + Value updatedIndVar = rewriter.create( + loc, newIndVar.getType(), newIndVar, adaptor.getStep()); + rewriter.create(loc, header, updatedIndVar); + + // Infer the return types from the init operands. Vector type may get + // converted to CooperativeMatrix or to Vector type, to avoid having complex + // extra logic to figure out the right type we just infer it from the Init + // operands. + SmallVector initTypes; + for (auto arg : adaptor.getInitArgs()) + initTypes.push_back(arg.getType()); + replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, + initTypes); + return success(); + } }; +//===----------------------------------------------------------------------===// +// scf::IfOp +//===----------------------------------------------------------------------===// + /// Pattern to convert a scf::IfOp within kernel functions into /// spirv::SelectionOp. 
-class IfOpConversion final : public SCFToSPIRVPattern { -public: - using SCFToSPIRVPattern::SCFToSPIRVPattern; +struct IfOpConversion final : SCFToSPIRVPattern { + using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; - -class TerminatorOpConversion final : public SCFToSPIRVPattern { -public: - using SCFToSPIRVPattern::SCFToSPIRVPattern; + ConversionPatternRewriter &rewriter) const override { + // When lowering `scf::IfOp` we explicitly create a selection header block + // before the control flow diverges and a merge block where control flow + // subsequently converges. + auto loc = ifOp.getLoc(); + + // Create `spirv.selection` operation, selection header block and merge + // block. + auto selectionOp = + rewriter.create(loc, spirv::SelectionControl::None); + auto *mergeBlock = rewriter.createBlock(&selectionOp.getBody(), + selectionOp.getBody().end()); + rewriter.create(loc); + + OpBuilder::InsertionGuard guard(rewriter); + auto *selectionHeaderBlock = + rewriter.createBlock(&selectionOp.getBody().front()); + + // Inline `then` region before the merge block and branch to it. + auto &thenRegion = ifOp.getThenRegion(); + auto *thenBlock = &thenRegion.front(); + rewriter.setInsertionPointToEnd(&thenRegion.back()); + rewriter.create(loc, mergeBlock); + rewriter.inlineRegionBefore(thenRegion, mergeBlock); + + auto *elseBlock = mergeBlock; + // If `else` region is not empty, inline that region before the merge block + // and branch to it.
+ if (!ifOp.getElseRegion().empty()) { + auto &elseRegion = ifOp.getElseRegion(); + elseBlock = &elseRegion.front(); + rewriter.setInsertionPointToEnd(&elseRegion.back()); + rewriter.create(loc, mergeBlock); + rewriter.inlineRegionBefore(elseRegion, mergeBlock); + } - LogicalResult - matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; + // Create a `spirv.BranchConditional` operation for selection header block. + rewriter.setInsertionPointToEnd(selectionHeaderBlock); + rewriter.create(loc, adaptor.getCondition(), + thenBlock, ArrayRef(), + elseBlock, ArrayRef()); + + SmallVector returnTypes; + for (auto result : ifOp.getResults()) { + auto convertedType = typeConverter.convertType(result.getType()); + if (!convertedType) + return rewriter.notifyMatchFailure( + loc, + llvm::formatv("failed to convert type '{0}'", result.getType())); + + returnTypes.push_back(convertedType); + } + replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, + returnTypes); + return success(); + } }; -class WhileOpConversion final : public SCFToSPIRVPattern { +//===----------------------------------------------------------------------===// +// scf::YieldOp +//===----------------------------------------------------------------------===// + +/// Lowers scf.yield to stores into the parent op's output VariableOps. +struct TerminatorOpConversion final : SCFToSPIRVPattern { -public: - using SCFToSPIRVPattern::SCFToSPIRVPattern; + using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::WhileOp forOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override; -}; -} // namespace - -/// Helper function to replaces SCF op outputs with SPIR-V variable loads. -/// We create VariableOp to handle the results value of the control flow region. -/// spirv.mlir.loop/spirv.mlir.selection currently don't yield value. Right -/// after the loop we load the value from the allocation and use it as the SCF -/// op result.
-template -static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp, - ConversionPatternRewriter &rewriter, - ScfToSPIRVContextImpl *scfToSPIRVContext, - ArrayRef returnTypes) { - - Location loc = scfOp.getLoc(); - auto &allocas = scfToSPIRVContext->outputVars[newOp]; - // Clearing the allocas is necessary in case a dialect conversion path failed - // previously, and this is the second attempt of this conversion. - allocas.clear(); - SmallVector resultValue; - for (Type convertedType : returnTypes) { - auto pointerType = - spirv::PointerType::get(convertedType, spirv::StorageClass::Function); - rewriter.setInsertionPoint(newOp); - auto alloc = rewriter.create( - loc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); - allocas.push_back(alloc); - rewriter.setInsertionPointAfter(newOp); - Value loadResult = rewriter.create(loc, alloc); - resultValue.push_back(loadResult); + matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange operands = adaptor.getOperands(); + + // If the region returns values, store each value into the associated + // VariableOp created during lowering of the parent region. + if (!operands.empty()) { + auto &allocas = + scfToSPIRVContext->outputVars[terminatorOp->getParentOp()]; + if (allocas.size() != operands.size()) + return failure(); + + auto loc = terminatorOp.getLoc(); + for (unsigned i = 0, e = operands.size(); i < e; i++) + rewriter.create(loc, allocas[i], operands[i]); + if (isa(terminatorOp->getParentOp())) { + // For loops we also need to update the branch jumping back to the + // header.
+ auto br = cast( + rewriter.getInsertionBlock()->getTerminator()); + SmallVector args(br.getBlockArguments()); + args.append(operands.begin(), operands.end()); + rewriter.setInsertionPoint(br); + rewriter.create(terminatorOp.getLoc(), br.getTarget(), + args); + rewriter.eraseOp(br); + } + } + rewriter.eraseOp(terminatorOp); + return success(); } - rewriter.replaceOp(scfOp, resultValue); -} - -static Region::iterator getBlockIt(Region &region, unsigned index) { - return std::next(region.begin(), index); -} +}; //===----------------------------------------------------------------------===// -// scf::ForOp +// scf::WhileOp //===----------------------------------------------------------------------===// -LogicalResult -ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // scf::ForOp can be lowered to the structured control flow represented by - // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop - // latch and the merge block the exit block. The resulting spirv::LoopOp has a - // single back edge from the continue to header block, and a single exit from - // header to merge. - auto loc = forOp.getLoc(); - auto loopOp = rewriter.create(loc, spirv::LoopControl::None); - loopOp.addEntryAndMergeBlock(); - - OpBuilder::InsertionGuard guard(rewriter); - // Create the block for the header. - auto *header = new Block(); - // Insert the header. - loopOp.getBody().getBlocks().insert(getBlockIt(loopOp.getBody(), 1), header); - - // Create the new induction variable to use. - Value adapLowerBound = adaptor.getLowerBound(); - BlockArgument newIndVar = - header->addArgument(adapLowerBound.getType(), adapLowerBound.getLoc()); - for (Value arg : adaptor.getInitArgs()) - header->addArgument(arg.getType(), arg.getLoc()); - Block *body = forOp.getBody(); - - // Apply signature conversion to the body of the forOp. It has a single block, - // with argument which is the induction variable.
That has to be replaced with - // the new induction variable. - TypeConverter::SignatureConversion signatureConverter( - body->getNumArguments()); - signatureConverter.remapInput(0, newIndVar); - for (unsigned i = 1, e = body->getNumArguments(); i < e; i++) - signatureConverter.remapInput(i, header->getArgument(i)); - body = rewriter.applySignatureConversion(&forOp.getLoopBody(), - signatureConverter); - - // Move the blocks from the forOp into the loopOp. This is the body of the - // loopOp. - rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.getBody(), - getBlockIt(loopOp.getBody(), 2)); - - SmallVector args(1, adaptor.getLowerBound()); - args.append(adaptor.getInitArgs().begin(), adaptor.getInitArgs().end()); - // Branch into it from the entry. - rewriter.setInsertionPointToEnd(&(loopOp.getBody().front())); - rewriter.create(loc, header, args); - - // Generate the rest of the loop header. - rewriter.setInsertionPointToEnd(header); - auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = rewriter.create( - loc, rewriter.getI1Type(), newIndVar, adaptor.getUpperBound()); - - rewriter.create( - loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); - - // Generate instructions to increment the step of the induction variable and - // branch to the header. - Block *continueBlock = loopOp.getContinueBlock(); - rewriter.setInsertionPointToEnd(continueBlock); - - // Add the step to the induction variable and branch to the header. - Value updatedIndVar = rewriter.create( - loc, newIndVar.getType(), newIndVar, adaptor.getStep()); - rewriter.create(loc, header, updatedIndVar); - - // Infer the return types from the init operands. Vector type may get - // converted to CooperativeMatrix or to Vector type, to avoid having complex - // extra logic to figure out the right type we just infer it from the Init - // operands. 
- SmallVector initTypes; - for (auto arg : adaptor.getInitArgs()) - initTypes.push_back(arg.getType()); - replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes); - return success(); -} +struct WhileOpConversion final : SCFToSPIRVPattern { + using SCFToSPIRVPattern::SCFToSPIRVPattern; -//===----------------------------------------------------------------------===// -// scf::IfOp -//===----------------------------------------------------------------------===// + LogicalResult + matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = whileOp.getLoc(); + auto loopOp = rewriter.create(loc, spirv::LoopControl::None); + loopOp.addEntryAndMergeBlock(); -LogicalResult -IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - // When lowering `scf::IfOp` we explicitly create a selection header block - // before the control flow diverges and a merge block where control flow - // subsequently converges. - auto loc = ifOp.getLoc(); - - // Create `spirv.selection` operation, selection header block and merge block. - auto selectionOp = - rewriter.create(loc, spirv::SelectionControl::None); - auto *mergeBlock = - rewriter.createBlock(&selectionOp.getBody(), selectionOp.getBody().end()); - rewriter.create(loc); - - OpBuilder::InsertionGuard guard(rewriter); - auto *selectionHeaderBlock = - rewriter.createBlock(&selectionOp.getBody().front()); - - // Inline `then` region before the merge block and branch to it. - auto &thenRegion = ifOp.getThenRegion(); - auto *thenBlock = &thenRegion.front(); - rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, mergeBlock); - rewriter.inlineRegionBefore(thenRegion, mergeBlock); - - auto *elseBlock = mergeBlock; - // If `else` region is not empty, inline that region before the merge block - // and branch to it. 
- if (!ifOp.getElseRegion().empty()) { - auto &elseRegion = ifOp.getElseRegion(); - elseBlock = &elseRegion.front(); - rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, mergeBlock); - rewriter.inlineRegionBefore(elseRegion, mergeBlock); - } + OpBuilder::InsertionGuard guard(rewriter); - // Create a `spirv.BranchConditional` operation for selection header block. - rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create(loc, adaptor.getCondition(), - thenBlock, ArrayRef(), - elseBlock, ArrayRef()); + Region &beforeRegion = whileOp.getBefore(); + Region &afterRegion = whileOp.getAfter(); - SmallVector returnTypes; - for (auto result : ifOp.getResults()) { - auto convertedType = typeConverter.convertType(result.getType()); - if (!convertedType) - return rewriter.notifyMatchFailure( - loc, llvm::formatv("failed to convert type '{0}'", result.getType())); + Block &entryBlock = *loopOp.getEntryBlock(); + Block &beforeBlock = beforeRegion.front(); + Block &afterBlock = afterRegion.front(); + Block &mergeBlock = *loopOp.getMergeBlock(); - returnTypes.push_back(convertedType); - } - replaceSCFOutputValue(ifOp, selectionOp, rewriter, scfToSPIRVContext, - returnTypes); - return success(); -} + auto cond = cast(beforeBlock.getTerminator()); + SmallVector condArgs; + if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs))) + return failure(); -//===----------------------------------------------------------------------===// -// scf::YieldOp -//===----------------------------------------------------------------------===// + Value conditionVal = rewriter.getRemappedValue(cond.getCondition()); + if (!conditionVal) + return failure(); -/// Yield is lowered to stores to the VariableOp created during lowering of the -/// parent region. For loops we also need to update the branch looping back to -/// the header with the loop carried values. 
-LogicalResult TerminatorOpConversion::matchAndRewrite( - scf::YieldOp terminatorOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - ValueRange operands = adaptor.getOperands(); - - // If the region is return values, store each value into the associated - // VariableOp created during lowering of the parent region. - if (!operands.empty()) { - auto &allocas = scfToSPIRVContext->outputVars[terminatorOp->getParentOp()]; - if (allocas.size() != operands.size()) + auto yield = cast(afterBlock.getTerminator()); + SmallVector yieldArgs; + if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs))) return failure(); - auto loc = terminatorOp.getLoc(); - for (unsigned i = 0, e = operands.size(); i < e; i++) - rewriter.create(loc, allocas[i], operands[i]); - if (isa(terminatorOp->getParentOp())) { - // For loops we also need to update the branch jumping back to the header. - auto br = - cast(rewriter.getInsertionBlock()->getTerminator()); - SmallVector args(br.getBlockArguments()); - args.append(operands.begin(), operands.end()); - rewriter.setInsertionPoint(br); - rewriter.create(terminatorOp.getLoc(), br.getTarget(), - args); - rewriter.eraseOp(br); + // Move the while before block as the initial loop header block. + rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(), + getBlockIt(loopOp.getBody(), 1)); + + // Move the while after block as the initial loop body block. + rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(), + getBlockIt(loopOp.getBody(), 2)); + + // Jump from the loop entry block to the loop header block. + rewriter.setInsertionPointToEnd(&entryBlock); + rewriter.create(loc, &beforeBlock, adaptor.getInits()); + + auto condLoc = cond.getLoc(); + + SmallVector resultValues(condArgs.size()); + + // For other SCF ops, the scf.yield op yields the value for the whole SCF + // op. So we use the scf.yield op as the anchor to create/load/store SPIR-V + // local variables. 
But for the scf.while op, the scf.yield op yields a + // value for the before region, which may not match the whole op's + // result. Instead, the scf.condition op returns values matching the whole + // op's results. So we need to create/load/store variables according to + // that. + for (const auto &it : llvm::enumerate(condArgs)) { + auto res = it.value(); + auto i = it.index(); + auto pointerType = + spirv::PointerType::get(res.getType(), spirv::StorageClass::Function); + + // Create local variables before the scf.while op. + rewriter.setInsertionPoint(loopOp); + auto alloc = rewriter.create( + condLoc, pointerType, spirv::StorageClass::Function, + /*initializer=*/nullptr); + + // Load the final result values after the scf.while op. + rewriter.setInsertionPointAfter(loopOp); + auto loadResult = rewriter.create(condLoc, alloc); + resultValues[i] = loadResult; + + // Store the current iteration's result value. + rewriter.setInsertionPointToEnd(&beforeBlock); + rewriter.create(condLoc, alloc, res); } - rewriter.eraseOp(terminatorOp); - return success(); -} -//===----------------------------------------------------------------------===// -// scf::WhileOp -//===----------------------------------------------------------------------===// - -LogicalResult -WhileOpConversion::matchAndRewrite(scf::WhileOp whileOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - auto loc = whileOp.getLoc(); - auto loopOp = rewriter.create(loc, spirv::LoopControl::None); - loopOp.addEntryAndMergeBlock(); - - OpBuilder::InsertionGuard guard(rewriter); - - Region &beforeRegion = whileOp.getBefore(); - Region &afterRegion = whileOp.getAfter(); - - Block &entryBlock = *loopOp.getEntryBlock(); - Block &beforeBlock = beforeRegion.front(); - Block &afterBlock = afterRegion.front(); - Block &mergeBlock = *loopOp.getMergeBlock(); - - auto cond = cast(beforeBlock.getTerminator()); - SmallVector condArgs; - if (failed(rewriter.getRemappedValues(cond.getArgs(), condArgs))) -
return failure(); - - Value conditionVal = rewriter.getRemappedValue(cond.getCondition()); - if (!conditionVal) - return failure(); - - auto yield = cast(afterBlock.getTerminator()); - SmallVector yieldArgs; - if (failed(rewriter.getRemappedValues(yield.getResults(), yieldArgs))) - return failure(); - - // Move the while before block as the initial loop header block. - rewriter.inlineRegionBefore(beforeRegion, loopOp.getBody(), - getBlockIt(loopOp.getBody(), 1)); - - // Move the while after block as the initial loop body block. - rewriter.inlineRegionBefore(afterRegion, loopOp.getBody(), - getBlockIt(loopOp.getBody(), 2)); - - // Jump from the loop entry block to the loop header block. - rewriter.setInsertionPointToEnd(&entryBlock); - rewriter.create(loc, &beforeBlock, adaptor.getInits()); - - auto condLoc = cond.getLoc(); - - SmallVector resultValues(condArgs.size()); - - // For other SCF ops, the scf.yield op yields the value for the whole SCF op. - // So we use the scf.yield op as the anchor to create/load/store SPIR-V local - // variables. But for the scf.while op, the scf.yield op yields a value for - // the before region, which may not matching the whole op's result. Instead, - // the scf.condition op returns values matching the whole op's results. So we - // need to create/load/store variables according to that. - for (const auto &it : llvm::enumerate(condArgs)) { - auto res = it.value(); - auto i = it.index(); - auto pointerType = - spirv::PointerType::get(res.getType(), spirv::StorageClass::Function); - - // Create local variables before the scf.while op. - rewriter.setInsertionPoint(loopOp); - auto alloc = rewriter.create( - condLoc, pointerType, spirv::StorageClass::Function, - /*initializer=*/nullptr); - - // Load the final result values after the scf.while op. - rewriter.setInsertionPointAfter(loopOp); - auto loadResult = rewriter.create(condLoc, alloc); - resultValues[i] = loadResult; - - // Store the current iteration's result value. 
rewriter.setInsertionPointToEnd(&beforeBlock); - rewriter.create(condLoc, alloc, res); - } - - rewriter.setInsertionPointToEnd(&beforeBlock); - rewriter.replaceOpWithNewOp( - cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt); + rewriter.replaceOpWithNewOp( + cond, conditionVal, &afterBlock, condArgs, &mergeBlock, std::nullopt); - // Convert the scf.yield op to a branch back to the header block. - rewriter.setInsertionPointToEnd(&afterBlock); - rewriter.replaceOpWithNewOp(yield, &beforeBlock, yieldArgs); + // Convert the scf.yield op to a branch back to the header block. + rewriter.setInsertionPointToEnd(&afterBlock); + rewriter.replaceOpWithNewOp(yield, &beforeBlock, + yieldArgs); - rewriter.replaceOp(whileOp, resultValues); - return success(); -} + rewriter.replaceOp(whileOp, resultValues); + return success(); + } +}; +} // namespace //===----------------------------------------------------------------------===// -// Hooks +// Public API //===----------------------------------------------------------------------===// void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter, diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp @@ -56,7 +56,3 @@ if (failed(applyPartialConversion(op, *target, std::move(patterns)))) return signalPassFailure(); } - -std::unique_ptr> mlir::createConvertSCFToSPIRVPass() { - return std::make_unique(); -}