diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -25,6 +25,11 @@
 /// transformed to `shape.reduce`, which can be lowered to SCF and Standard.
 std::unique_ptr<Pass> createShapeToShapeLowering();
 
+/// Creates an instance of the ExpandReduce pass that rewrites Shape ReduceOp as
+/// an SCF ForOp. The operations within the region of ReduceOp are cloned to the
+/// body of ForOp.
+std::unique_ptr<Pass> createExpandReducePass();
+
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td
@@ -11,9 +11,14 @@
 
 include "mlir/Pass/PassBase.td"
 
-def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> {
+def ShapeToShapeLowering : FunctionPass<"shape-to-shape"> {
   let summary = "Legalize Shape dialect to be convertible to Standard";
   let constructor = "mlir::createShapeToShapeLowering()";
 }
 
+def ExpandReduce : FunctionPass<"expand-reduce"> {
+  let summary = "Rewrites `shape.reduce` as an `scf.for` loop";
+  let constructor = "mlir::createExpandReducePass()";
+}
+
 #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRShapeOpsTransforms
+  ExpandReduce.cpp
   ShapeToShapeLowering.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Dialect/Shape/Transforms/ExpandReduce.cpp b/mlir/lib/Dialect/Shape/Transforms/ExpandReduce.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/Transforms/ExpandReduce.cpp
@@ -0,0 +1,94 @@
+//===- ExpandReduce.cpp - Expand to SCF ForOp -----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::shape;
+
+namespace {
+/// Converts `shape.reduce` to `scf.for`.
+struct ReduceOpConverter : public OpRewritePattern<ReduceOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ReduceOp op,
+                                PatternRewriter &rewriter) const final;
+};
+} // namespace
+
+LogicalResult
+ReduceOpConverter::matchAndRewrite(ReduceOp reduceOp,
+                                   PatternRewriter &rewriter) const {
+  auto loc = reduceOp.getLoc();
+
+  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
+  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
+  Value extent_tensor = rewriter.create<ToExtentTensorOp>(
+      loc,
+      RankedTensorType::get({ShapedType::kDynamicSize},
+                            rewriter.getIndexType()),
+      reduceOp.shape());
+  Value rank = rewriter.create<RankOp>(loc, extent_tensor);
+
+  auto loop = rewriter.create<scf::ForOp>(
+      loc, zero, rank, one, reduceOp.initVals(),
+      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+        Value index_extent = b.create<ExtractElementOp>(loc, extent_tensor, iv);
+        Value size_extent = rewriter.create<IndexToSizeOp>(loc, index_extent);
+
+        SmallVector<Value, 2> mapped_values{iv, size_extent};
+        mapped_values.append(args.begin(), args.end());
+
+        BlockAndValueMapping mapping;
+        Block *loopBody = reduceOp.getBody();
+        mapping.map(loopBody->getArguments(), mapped_values);
+        for (auto &nested : loopBody->without_terminator()) {
+          Operation *clone = b.clone(nested, mapping);
+          mapping.map(nested.getResults(), clone->getResults());
+        }
+
+        SmallVector<Value, 2> mapped_results;
+        for (auto result : loopBody->getTerminator()->getOperands())
+          mapped_results.push_back(mapping.lookup(result));
+        b.create<scf::YieldOp>(loc, mapped_results);
+      });
+
+  rewriter.replaceOp(reduceOp, loop.getResults());
+  return success();
+}
+
+namespace {
+struct ExpandReduce : public ExpandReduceBase<ExpandReduce> {
+  void runOnFunction() override;
+};
+} // namespace
+
+void ExpandReduce::runOnFunction() {
+  OwningRewritePatternList patterns;
+  patterns.insert<ReduceOpConverter>(&getContext());
+
+  ConversionTarget target(getContext());
+  target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
+  target.addIllegalOp<ReduceOp>();
+  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+    signalPassFailure();
+}
+
+std::unique_ptr<Pass> mlir::createExpandReducePass() {
+  return std::make_unique<ExpandReduce>();
+}
diff --git a/mlir/test/Dialect/Shape/expand-reduce.mlir b/mlir/test/Dialect/Shape/expand-reduce.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/expand-reduce.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt -expand-reduce -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: shape_reduce
+// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size {
+func @shape_reduce(%shape : !shape.shape) -> !shape.size {
+  %init = shape.const_size 1
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size):
+      %new_acc = shape.mul %acc, %dim
+      shape.yield %new_acc : !shape.size
+  }
+  return %num_elements : !shape.size
+}
+// CHECK-NEXT: [[SHAPE_C1:%.*]] = shape.const_size 1
+// CHECK-NEXT: [[C0:%.*]] = constant 0 : index
+// CHECK-NEXT: [[C1:%.*]] = constant 1 : index
+
+// CHECK-NEXT: [[EXTENTS:%.*]] = "shape.to_extent_tensor"([[SHAPE]])
+// CHECK-NEXT: [[RANK:%.*]] = rank [[EXTENTS]] : tensor<?xindex>
+
+// CHECK-NEXT: [[RESULT:%.*]] = scf.for [[I:%.*]] = [[C0]] to [[RANK]]
+// CHECK-SAME: step [[C1]] iter_args([[ACC:%.*]] = [[SHAPE_C1]]) -> (!shape.size)
+// CHECK-NEXT: [[EXTENT_INDEX:%.*]] = extract_element [[EXTENTS]]{{\[}}[[I]]]
+// CHECK-NEXT: [[EXTENT:%.*]] = shape.index_to_size [[EXTENT_INDEX]]
+// CHECK-NEXT: [[NEW_ACC:%.*]] = shape.mul [[ACC]], [[EXTENT]]
+// CHECK-NEXT: scf.yield [[NEW_ACC]] : !shape.size
+// CHECK-NEXT: }
+// CHECK-NEXT: return [[RESULT]] : !shape.size
diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir
--- a/mlir/test/Dialect/Shape/shape-to-shape.mlir
+++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -shape-to-shape -split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: func @num_elements_to_reduce(
 // CHECK-SAME:  [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] {