diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h @@ -0,0 +1,32 @@ +//===- SCFToSPIRV.h - SCF to SPIR-V Patterns --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides patterns for lowering SCF ops to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_ +#define MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_ + +#include + +namespace mlir { +class MLIRContext; +class Pass; + +// Owning list of rewriting patterns. +class OwningRewritePatternList; +class SPIRVTypeConverter; + +/// Collects a set of patterns to lower from scf.for, scf.if, and +/// scf.yield to CFG operations within the SPIR-V dialect.
+void populateSCFToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns); +} // namespace mlir + +#endif // MLIR_CONVERSION_SCFTOSPIRV_SCFTOSPIRV_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(LinalgToSPIRV) add_subdirectory(LinalgToStandard) add_subdirectory(SCFToGPU) +add_subdirectory(SCFToSPIRV) add_subdirectory(SCFToStandard) add_subdirectory(ShapeToSCF) add_subdirectory(ShapeToStandard) diff --git a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToSPIRV/CMakeLists.txt @@ -14,6 +14,7 @@ MLIRGPU MLIRIR MLIRPass + MLIRSCFToSPIRV MLIRSPIRV MLIRStandardOps MLIRStandardToSPIRVTransforms diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -11,7 +11,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVLowering.h" #include "mlir/Dialect/SPIRV/SPIRVOps.h" @@ -20,41 +19,6 @@ using namespace mlir; namespace { - -/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
-class ForOpConversion final : public SPIRVOpLowering { -public: - using SPIRVOpLowering::SPIRVOpLowering; - - LogicalResult - matchAndRewrite(scf::ForOp forOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Pattern to convert a scf::IfOp within kernel functions into -/// spirv::SelectionOp. -class IfOpConversion final : public SPIRVOpLowering { -public: - using SPIRVOpLowering::SPIRVOpLowering; - - LogicalResult - matchAndRewrite(scf::IfOp IfOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -/// Pattern to erase a scf::YieldOp. -class TerminatorOpConversion final : public SPIRVOpLowering { -public: - using SPIRVOpLowering::SPIRVOpLowering; - - LogicalResult - matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - rewriter.eraseOp(terminatorOp); - return success(); - } -}; - /// Pattern lowering GPU block/thread size/id to loading SPIR-V invocation /// builtin variables. template @@ -129,134 +93,6 @@ } // namespace //===----------------------------------------------------------------------===// -// scf::ForOp. -//===----------------------------------------------------------------------===// - -LogicalResult -ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - // scf::ForOp can be lowered to the structured control flow represented by - // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop - // latch and the merge block the exit block. The resulting spirv::LoopOp has a - // single back edge from the continue to header block, and a single exit from - // header to merge.
- scf::ForOpAdaptor forOperands(operands); - auto loc = forOp.getLoc(); - auto loopControl = rewriter.getI32IntegerAttr( - static_cast(spirv::LoopControl::None)); - auto loopOp = rewriter.create(loc, loopControl); - loopOp.addEntryAndMergeBlock(); - - OpBuilder::InsertionGuard guard(rewriter); - // Create the block for the header. - auto header = new Block(); - // Insert the header. - loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); - - // Create the new induction variable to use. - BlockArgument newIndVar = - header->addArgument(forOperands.lowerBound().getType()); - Block *body = forOp.getBody(); - - // Apply signature conversion to the body of the forOp. It has a single block, - // with argument which is the induction variable. That has to be replaced with - // the new induction variable. - TypeConverter::SignatureConversion signatureConverter( - body->getNumArguments()); - signatureConverter.remapInput(0, newIndVar); - FailureOr newBody = rewriter.convertRegionTypes( - &forOp.getLoopBody(), typeConverter, &signatureConverter); - if (failed(newBody)) - return failure(); - body = *newBody; - - // Delete the loop terminator. - rewriter.eraseOp(body->getTerminator()); - - // Move the blocks from the forOp into the loopOp. This is the body of the - // loopOp. - rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(), - std::next(loopOp.body().begin(), 2)); - - // Branch into it from the entry. - rewriter.setInsertionPointToEnd(&(loopOp.body().front())); - rewriter.create(loc, header, forOperands.lowerBound()); - - // Generate the rest of the loop header.
- rewriter.setInsertionPointToEnd(header); - auto mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = rewriter.create( - loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); - rewriter.create( - loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); - - // Generate instructions to increment the step of the induction variable and - // branch to the header. - Block *continueBlock = loopOp.getContinueBlock(); - rewriter.setInsertionPointToEnd(continueBlock); - - // Add the step to the induction variable and branch to the header. - Value updatedIndVar = rewriter.create( - loc, newIndVar.getType(), newIndVar, forOperands.step()); - rewriter.create(loc, header, updatedIndVar); - - rewriter.eraseOp(forOp); - return success(); -} - -//===----------------------------------------------------------------------===// -// scf::IfOp. -//===----------------------------------------------------------------------===// - -LogicalResult -IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - // When lowering `scf::IfOp` we explicitly create a selection header block - // before the control flow diverges and a merge block where control flow - // subsequently converges. - scf::IfOpAdaptor ifOperands(operands); - auto loc = ifOp.getLoc(); - - // Create `spv.selection` operation, selection header block and merge block. - auto selectionControl = rewriter.getI32IntegerAttr( - static_cast(spirv::SelectionControl::None)); - auto selectionOp = rewriter.create(loc, selectionControl); - selectionOp.addMergeBlock(); - auto *mergeBlock = selectionOp.getMergeBlock(); - - OpBuilder::InsertionGuard guard(rewriter); - auto *selectionHeaderBlock = new Block(); - selectionOp.body().getBlocks().push_front(selectionHeaderBlock); - - // Inline `then` region before the merge block and branch to it.
- auto &thenRegion = ifOp.thenRegion(); - auto *thenBlock = &thenRegion.front(); - rewriter.setInsertionPointToEnd(&thenRegion.back()); - rewriter.create(loc, mergeBlock); - rewriter.inlineRegionBefore(thenRegion, mergeBlock); - - auto *elseBlock = mergeBlock; - // If `else` region is not empty, inline that region before the merge block - // and branch to it. - if (!ifOp.elseRegion().empty()) { - auto &elseRegion = ifOp.elseRegion(); - elseBlock = &elseRegion.front(); - rewriter.setInsertionPointToEnd(&elseRegion.back()); - rewriter.create(loc, mergeBlock); - rewriter.inlineRegionBefore(elseRegion, mergeBlock); - } - - // Create a `spv.BranchConditional` operation for selection header block. - rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create(loc, ifOperands.condition(), - thenBlock, ArrayRef(), - elseBlock, ArrayRef()); - - rewriter.eraseOp(ifOp); - return success(); -} - -//===----------------------------------------------------------------------===// // Builtins.
//===----------------------------------------------------------------------===// @@ -479,8 +315,7 @@ OwningRewritePatternList &patterns) { populateWithGenerated(context, &patterns); patterns.insert< - ForOpConversion, GPUFuncOpConversion, GPUModuleConversion, - GPUReturnOpConversion, IfOpConversion, + GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion, LaunchConfigConversion, LaunchConfigConversion, LaunchConfigConversion, SingleDimLaunchConfigConversion, - TerminatorOpConversion, WorkGroupSizeConversion>(context, typeConverter); + WorkGroupSizeConversion>(context, typeConverter); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h" #include "../PassDetail.h" #include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.h" +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/SCF/SCF.h" @@ -59,6 +60,7 @@ SPIRVTypeConverter typeConverter(targetAttr); OwningRewritePatternList patterns; populateGPUToSPIRVPatterns(context, typeConverter, patterns); + populateSCFToSPIRVPatterns(context, typeConverter, patterns); populateStandardToSPIRVPatterns(context, typeConverter, patterns); if (failed(applyFullConversion(kernelModules, *target, patterns))) diff --git a/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/SCFToSPIRV/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_conversion_library(MLIRSCFToSPIRV + SCFToSPIRV.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/SCFToSPIRV + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC +
MLIRAffineOps + MLIRAffineToStandard + MLIRSCF + MLIRSPIRV + MLIRIR + MLIRLinalgOps + MLIRPass + MLIRStandardOps + MLIRSupport + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -0,0 +1,206 @@ +//===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the conversion patterns from SCF ops to SPIR-V dialect. +// +//===----------------------------------------------------------------------===// +#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/SPIRV/SPIRVLowering.h" +#include "mlir/Dialect/SPIRV/SPIRVOps.h" +#include "mlir/IR/Module.h" + +using namespace mlir; + +namespace { + +/// Pattern to convert an scf::ForOp without iter_args into spirv::LoopOp. +class ForOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(scf::ForOp forOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Pattern to convert an scf::IfOp without results into +/// spirv::SelectionOp. +class IfOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Pattern to erase an scf::YieldOp that yields no values.
+class TerminatorOpConversion final : public SPIRVOpLowering { +public: + using SPIRVOpLowering::SPIRVOpLowering; + + LogicalResult + matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + // Yielded values have no destination in the lowered control flow, so only + // value-less yields can be erased. + if (!operands.empty()) + return failure(); + rewriter.eraseOp(terminatorOp); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// scf::ForOp. +//===----------------------------------------------------------------------===// + +LogicalResult +ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + // scf::ForOp can be lowered to the structured control flow represented by + // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop + // latch and the merge block the exit block. The resulting spirv::LoopOp has a + // single back edge from the continue to header block, and a single exit from + // header to merge. + // Only the induction variable is threaded through the spv.loop header, so + // loop-carried values (iter_args) and loop results cannot be lowered here; + // reject such loops before creating any IR instead of dropping the values. + if (forOp.getOperation()->getNumResults() != 0) + return failure(); + + scf::ForOpAdaptor forOperands(operands); + auto loc = forOp.getLoc(); + auto loopControl = rewriter.getI32IntegerAttr( + static_cast(spirv::LoopControl::None)); + auto loopOp = rewriter.create(loc, loopControl); + loopOp.addEntryAndMergeBlock(); + + OpBuilder::InsertionGuard guard(rewriter); + // Create the block for the header. + auto *header = new Block(); + // Insert the header. + loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); + + // Create the new induction variable to use. + BlockArgument newIndVar = + header->addArgument(forOperands.lowerBound().getType()); + Block *body = forOp.getBody(); + + // Apply signature conversion to the body of the forOp. It has a single block, + // with argument which is the induction variable. That has to be replaced with + // the new induction variable.
+ TypeConverter::SignatureConversion signatureConverter( + body->getNumArguments()); + signatureConverter.remapInput(0, newIndVar); + FailureOr newBody = rewriter.convertRegionTypes( + &forOp.getLoopBody(), typeConverter, &signatureConverter); + if (failed(newBody)) + return failure(); + body = *newBody; + + // Delete the loop terminator. + rewriter.eraseOp(body->getTerminator()); + + // Move the blocks from the forOp into the loopOp. This is the body of the + // loopOp. + rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(), + std::next(loopOp.body().begin(), 2)); + + // Branch into it from the entry. + rewriter.setInsertionPointToEnd(&(loopOp.body().front())); + rewriter.create(loc, header, forOperands.lowerBound()); + + // Generate the rest of the loop header. + rewriter.setInsertionPointToEnd(header); + auto *mergeBlock = loopOp.getMergeBlock(); + auto cmpOp = rewriter.create( + loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); + rewriter.create( + loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); + + // Generate instructions to increment the step of the induction variable and + // branch to the header. + Block *continueBlock = loopOp.getContinueBlock(); + rewriter.setInsertionPointToEnd(continueBlock); + + // Add the step to the induction variable and branch to the header. + Value updatedIndVar = rewriter.create( + loc, newIndVar.getType(), newIndVar, forOperands.step()); + rewriter.create(loc, header, updatedIndVar); + + rewriter.eraseOp(forOp); + return success(); +} + +//===----------------------------------------------------------------------===// +// scf::IfOp.
+//===----------------------------------------------------------------------===// + +LogicalResult +IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + // When lowering `scf::IfOp` we explicitly create a selection header block + // before the control flow diverges and a merge block where control flow + // subsequently converges. + // The regions are inlined without forwarding any yielded values, so reject + // an scf.if that produces results before creating any IR. + if (ifOp.getOperation()->getNumResults() != 0) + return failure(); + + scf::IfOpAdaptor ifOperands(operands); + auto loc = ifOp.getLoc(); + + // Create `spv.selection` operation, selection header block and merge block. + auto selectionControl = rewriter.getI32IntegerAttr( + static_cast(spirv::SelectionControl::None)); + auto selectionOp = rewriter.create(loc, selectionControl); + selectionOp.addMergeBlock(); + auto *mergeBlock = selectionOp.getMergeBlock(); + + OpBuilder::InsertionGuard guard(rewriter); + auto *selectionHeaderBlock = new Block(); + selectionOp.body().getBlocks().push_front(selectionHeaderBlock); + + // Inline `then` region before the merge block and branch to it. + auto &thenRegion = ifOp.thenRegion(); + auto *thenBlock = &thenRegion.front(); + rewriter.setInsertionPointToEnd(&thenRegion.back()); + rewriter.create(loc, mergeBlock); + rewriter.inlineRegionBefore(thenRegion, mergeBlock); + + auto *elseBlock = mergeBlock; + // If `else` region is not empty, inline that region before the merge block + // and branch to it. + if (!ifOp.elseRegion().empty()) { + auto &elseRegion = ifOp.elseRegion(); + elseBlock = &elseRegion.front(); + rewriter.setInsertionPointToEnd(&elseRegion.back()); + rewriter.create(loc, mergeBlock); + rewriter.inlineRegionBefore(elseRegion, mergeBlock); + } + + // Create a `spv.BranchConditional` operation for selection header block.
+ rewriter.setInsertionPointToEnd(selectionHeaderBlock); + rewriter.create(loc, ifOperands.condition(), + thenBlock, ArrayRef(), + elseBlock, ArrayRef()); + + rewriter.eraseOp(ifOp); + return success(); +} + +void mlir::populateSCFToSPIRVPatterns(MLIRContext *context, + SPIRVTypeConverter &typeConverter, + OwningRewritePatternList &patterns) { + patterns.insert( + context, typeConverter); +}