diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -207,6 +207,15 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ShapeToSCF
+//===----------------------------------------------------------------------===//
+
+def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
+  let summary = "Convert operations from the shape dialect to the SCF dialect";
+  let constructor = "mlir::createConvertShapeToSCFPass()";
+}
+
+//===----------------------------------------------------------------------===//
 // SPIRVToLLVM
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Conversion/ShapeToSCF/ShapeToSCF.h
@@ -0,0 +1,30 @@
+//===- ShapeToSCF.h - Conversion utils from Shape to SCF dialect ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
+#define MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
+
+#include <memory>
+
+namespace mlir {
+
+class MLIRContext;
+class FunctionPass;
+class OwningRewritePatternList;
+
+/// Appends to `patterns` the rewrite patterns that lower Shape dialect
+/// operations to the SCF dialect (currently `shape.reduce` to `scf.for`).
+void populateShapeToSCFConversionPatterns(OwningRewritePatternList &patterns,
+                                          MLIRContext *ctx);
+
+/// Creates a function pass that applies the Shape-to-SCF lowering patterns.
+std::unique_ptr<FunctionPass> createConvertShapeToSCFPass();
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_SHAPETOSCF_SHAPETOSCF_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h"
 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
 #include "mlir/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVMPass.h"
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -10,6 +10,7 @@
 add_subdirectory(LinalgToStandard)
 add_subdirectory(SCFToGPU)
 add_subdirectory(SCFToStandard)
+add_subdirectory(ShapeToSCF)
 add_subdirectory(ShapeToStandard)
 add_subdirectory(SPIRVToLLVM)
 add_subdirectory(StandardToLLVM)
diff --git a/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToSCF/CMakeLists.txt
@@ -0,0 +1,20 @@
+add_mlir_conversion_library(MLIRShapeToSCF
+  ShapeToSCF.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToSCF
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRShape
+  MLIRPass
+  MLIRSCF
+  MLIRStandardOps
+  MLIRTransforms
+  )
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -0,0 +1,116 @@
+//===- ShapeToSCF.cpp - conversion from Shape to SCF dialect --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/ShapeToSCF/ShapeToSCF.h"
+
+#include "../PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Converts `shape.reduce` to `scf.for`.
+///
+/// The shape operand is materialized as a 1-D extent tensor of `index` values
+/// and the loop iterates over its elements, carrying the reduction's initial
+/// values as iteration arguments. The body of the `shape.reduce` region is
+/// cloned into the loop with its block arguments mapped to (induction
+/// variable, current extent as `!shape.size`, loop-carried values), and the
+/// operands of the region's terminator are yielded with `scf.yield`.
+struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
+                                   PatternRewriter &rewriter) const {
+  Location loc = reduceOp.getLoc();
+
+  // Materialize the extents; their count is the loop's trip count.
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value extentTensor = rewriter.create<ToExtentTensorOp>(
+      loc,
+      RankedTensorType::get({ShapedType::kDynamicSize},
+                            rewriter.getIndexType()),
+      reduceOp.shape());
+  Value size =
+      rewriter.create<DimOp>(loc, rewriter.getIndexType(), extentTensor, zero);
+
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, size, one, reduceOp.initVals(),
+      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+        Value indexExtent =
+            b.create<ExtractElementOp>(nestedLoc, extentTensor, iv);
+        Value sizeExtent = b.create<IndexToSizeOp>(nestedLoc, indexExtent);
+
+        // Values for the reduction body's block arguments, in order:
+        // (index, extent, accumulators...).
+        SmallVector<Value, 4> mappedValues{iv, sizeExtent};
+        mappedValues.append(args.begin(), args.end());
+
+        BlockAndValueMapping mapping;
+        Block *reduceBody = reduceOp.getBody();
+        mapping.map(reduceBody->getArguments(), mappedValues);
+        for (Operation &nested : reduceBody->without_terminator())
+          b.clone(nested, mapping);
+
+        // The terminator may yield values defined above the region; those are
+        // not in `mapping`, so forward them unchanged. (`lookup` asserts that
+        // the value is mapped.)
+        SmallVector<Value, 2> mappedResults;
+        for (Value result : reduceBody->getTerminator()->getOperands())
+          mappedResults.push_back(mapping.lookupOrDefault(result));
+        b.create<scf::YieldOp>(nestedLoc, mappedResults);
+      });
+
+  rewriter.replaceOp(reduceOp, loop.getResults());
+  return success();
+}
+
+namespace {
+/// Function pass that lowers `shape.reduce` to SCF with a partial conversion;
+/// all other Shape dialect operations are left in place.
+struct ConvertShapeToSCFPass
+    : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void ConvertShapeToSCFPass::runOnFunction() {
+  MLIRContext &ctx = getContext();
+
+  OwningRewritePatternList patterns;
+  populateShapeToSCFConversionPatterns(patterns, &ctx);
+
+  // Only `shape.reduce` must be eliminated. Ops created by the lowering, and
+  // the remaining Shape ops such as `shape.to_extent_tensor`, stay legal.
+  ConversionTarget target(ctx);
+  target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
+  target.addIllegalOp<ReduceOp>();
+  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+void mlir::populateShapeToSCFConversionPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  patterns.insert<ReduceOpConverter>(ctx);
+}
+
+std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
+  return std::make_unique<ConvertShapeToSCFPass>();
+}
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: shape_reduce
+// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+func @shape_reduce(%shape : !shape.shape) -> !shape.size {
+  %init = shape.const_size 1
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+      %new_acc = shape.mul %acc, %dim
+      shape.yield %new_acc : !shape.size
+  }
+  return %num_elements : !shape.size
+}
+// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
+// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
+// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
+
+// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
+// CHECK-NEXT: [[SIZE:%.*]] = dim [[EXTENTS]], [[C0]] : tensor<?xindex>
+
+// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[SIZE]]
+// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
+// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
+// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
+// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
+// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: }
+// CHECK-NEXT: return [[RESULT]] : !shape.size
+
+// -----
+
+// A value captured from above the region can be yielded directly.
+// CHECK-LABEL: shape_reduce_yield_captured
+// CHECK-SAME: {{%.*}}: !shape.shape, [[X:%.*]]: !shape.size) -> !shape.size {
+func @shape_reduce_yield_captured(%shape : !shape.shape, %x : !shape.size) -> !shape.size {
+  %init = shape.const_size 1
+  %result = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+      shape.yield %x : !shape.size
+  }
+  return %result : !shape.size
+}
+// CHECK: scf.for
+// CHECK-NEXT: extract_element
+// CHECK-NEXT: shape.index_to_size
+// CHECK-NEXT: scf.yield [[X]] : !shape.size
+// CHECK-NEXT: }