diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -25,6 +25,11 @@
 /// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
 std::unique_ptr<Pass> createShapeToShapeLowering();
 
+/// Creates an instance of the ShapeToSCFLowering pass that converts Shape ops
+/// to SCF dialect. For example, `shape.reduce` gets rewritten as an `scf.for`.
+/// The operations within the region of ReduceOp are cloned to the ForOp body.
+std::unique_ptr<Pass> createShapeToSCFLowering();
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,9 +11,14 @@
 
 include "mlir/Pass/PassBase.td"
 
-def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
+def ShapeToShapeLowering : FunctionPass<"shape-to-shape"> {
   let summary = "Legalize Shape dialect to be convertible to Standard";
   let constructor = "mlir::createShapeToShapeLowering()";
 }
 
+def ShapeToSCFLowering : FunctionPass<"shape-to-scf"> {
+  let summary = "Convert Shape ops to SCF dialect";
+  let constructor = "mlir::createShapeToSCFLowering()";
+}
+
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
   ShapeToShapeLowering.cpp
+  ShapeToSCFLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToSCFLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToSCFLowering.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToSCFLowering.cpp
@@ -0,0 +1,94 @@
+//===- ShapeToSCFLowering.cpp - Expand to SCF ForOp -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Converts `shape.reduce` to `scf.for`.
+struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
+                                   PatternRewriter &rewriter) const {
+  auto loc = reduceOp.getLoc();
+
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value extentTensor = rewriter.create<ToExtentTensorOp>(
+      loc,
+      RankedTensorType::get({ShapedType::kDynamicSize},
+                            rewriter.getIndexType()),
+      reduceOp.shape());
+  Value rank = rewriter.create<RankOp>(loc, extentTensor);
+
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, rank, one, reduceOp.initVals(),
+      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+        Value indexExtent = b.create<ExtractElementOp>(loc, extentTensor, iv);
+        Value sizeExtent = b.create<IndexToSizeOp>(loc, indexExtent);
+
+        SmallVector<Value, 2> mapped_values{iv, sizeExtent};
+        mapped_values.append(args.begin(), args.end());
+
+        BlockAndValueMapping mapping;
+        Block *loopBody = reduceOp.getBody();
+        mapping.map(loopBody->getArguments(), mapped_values);
+        for (auto &nested : loopBody->without_terminator()) {
+          Operation *clone = b.clone(nested, mapping);
+          mapping.map(nested.getResults(), clone->getResults());
+        }
+
+        SmallVector<Value, 2> mappedResults;
+        for (auto result : loopBody->getTerminator()->getOperands())
+          mappedResults.push_back(mapping.lookup(result));
+        b.create<scf::YieldOp>(loc, mappedResults);
+      });
+
+  rewriter.replaceOp(reduceOp, loop.getResults());
+  return success();
+}
+
+namespace {
+struct ShapeToSCFLowering : public ShapeToSCFLoweringBase<ShapeToSCFLowering> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void ShapeToSCFLowering::runOnFunction() {
+  OwningRewritePatternList patterns;
+  patterns.insert<ReduceOpConverter>(&getContext());
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
+  target.addIllegalOp<ReduceOp>();
+  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createShapeToSCFLowering() {
+  return std::make_unique<ShapeToSCFLowering>();
+}
diff --git a/mlir/test/Dialect/Shape/shape-to-scf.mlir b/mlir/test/Dialect/Shape/shape-to-scf.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/shape-to-scf.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -shape-to-scf -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: shape_reduce
+// CHECK-SAME:  [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+func @shape_reduce(%shape : !shape.shape) -> !shape.size {
+  %init = shape.const_size 1
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+      %new_acc = shape.mul %acc, %dim
+      shape.yield %new_acc : !shape.size
+  }
+  return %num_elements : !shape.size
+}
+// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
+// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
+// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
+
+// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
+// CHECK-NEXT: [[RANK:%.*]] = rank [[EXTENTS]] : tensor<?xindex>
+
+// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[RANK]]
+// CHECK-SAME:     step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]])
+// CHECK-NEXT:   [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
+// CHECK-NEXT:   [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
+// CHECK-NEXT:   [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
+// CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: }
+// CHECK-NEXT: return [[RESULT]] : !shape.size
diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -shape-to-shape -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @num_elements_to_reduce(
 // CHECK-SAME:  [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {