diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h @@ -22,14 +22,21 @@ class OwningRewritePatternList; class Pass; -/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape +/// Creates an instance of the `ShapeToShapeLowering` pass that legalizes Shape /// dialect to be convertible to Standard. For example, `shape.num_elements` get /// transformed to `shape.reduce`, which can be lowered to SCF and Standard. std::unique_ptr createShapeToShapeLowering(); +/// Creates an instance of the `SimplifyShapePass` that reduces common patterns +/// of `shape` operations to `std` operations. For example, `shape.get_extent` +/// applied to the result of `shape.shape_of` is rewritten in terms of a +/// `std.dim` operation. +std::unique_ptr createSimplifyShapePass(); + /// Collects a set of patterns to rewrite ops within the Shape dialect.
void populateShapeRewritePatterns(MLIRContext *context, OwningRewritePatternList &patterns); + } // end namespace mlir #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td @@ -16,4 +16,14 @@ let constructor = "mlir::createShapeToShapeLowering()"; } +def SimplifyShape : FunctionPass<"simplify-shape"> { + let summary = "Simplify common patterns of shape operations"; + let description = [{ + Rewrites common patterns of `shape` operations into `std` operations. For + example, `shape.get_extent` applied to the result of `shape.shape_of` is + rewritten in terms of a `std.dim` operation. + }]; + let constructor = "mlir::createSimplifyShapePass()"; +} + #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1408,7 +1408,9 @@ let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value memrefOrTensor, int64_t index"> + "Value memrefOrTensor, int64_t index">, + OpBuilder<"OpBuilder &builder, OperationState &result, " + "Value memrefOrTensor, Value index"> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt @@ -1,5 +1,10 @@ +set(LLVM_TARGET_DEFINITIONS SimplifyShapePatterns.td) +mlir_tablegen(SimplifyShapePatterns.inc -gen-rewriters) +add_public_tablegen_target(SimplifyShapePatternsIncGen) + add_mlir_dialect_library(MLIRShapeOpsTransforms ShapeToShapeLowering.cpp + SimplifyShape.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms @@ -13,6 +18,7 @@ MLIRIR MLIRPass MLIRShape + MLIRStandardOps MLIRSupport MLIRTransforms ) diff --git a/mlir/lib/Dialect/Shape/Transforms/SimplifyShape.cpp
b/mlir/lib/Dialect/Shape/Transforms/SimplifyShape.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/SimplifyShape.cpp @@ -0,0 +1,45 @@ +//===- SimplifyShape.cpp - Simplify shape operations where possible -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::shape; + +namespace { + +// Use generated patterns. +#include "SimplifyShapePatterns.inc" + +struct SimplifyShapePass : public SimplifyShapeBase { + + void runOnFunction() override { + OwningRewritePatternList patterns; + populateWithGenerated(&getContext(), &patterns); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::createSimplifyShapePass() { + return std::make_unique(); +} diff --git a/mlir/lib/Dialect/Shape/Transforms/SimplifyShapePatterns.td b/mlir/lib/Dialect/Shape/Transforms/SimplifyShapePatterns.td new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/SimplifyShapePatterns.td @@ -0,0 +1,16 @@ +//===- SimplifyShapePatterns.td - Simplify shape ops -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" + +// Rewrite `shape.get_extent` of a `shape.shape_of` result as a `std.dim`, +// converting the extent index to `index` and the result back to `!shape.size`. +def GetExtentShapeOfConversion : Pat< + (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx), + (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx)))>; diff --git
a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1273,8 +1273,13 @@ Value memrefOrTensor, int64_t index) { auto loc = result.location; Value indexValue = builder.create(loc, index); + build(builder, result, memrefOrTensor, indexValue); +} + +void DimOp::build(OpBuilder &builder, OperationState &result, + Value memrefOrTensor, Value index) { auto indexTy = builder.getIndexType(); - build(builder, result, indexTy, memrefOrTensor, indexValue); + build(builder, result, indexTy, memrefOrTensor, index); } Optional DimOp::getConstantIndex() { diff --git a/mlir/test/Dialect/Shape/simplify-shape.mlir b/mlir/test/Dialect/Shape/simplify-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Shape/simplify-shape.mlir @@ -0,0 +1,28 @@ +// RUN: mlir-opt --simplify-shape --split-input-file %s | FileCheck %s + +// Express `get_extent` as `std.dim` when it relies directly on the outcome of a +// `shape_of` operation. +// CHECK-LABEL: @get_extent +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: !shape.size) -> !shape.size +func @get_extent(%arg : tensor<2x3xf32>, %idx : !shape.size) -> !shape.size { + // CHECK-DAG: %[[STD_IDX:.*]] = shape.size_to_index %[[IDX]] + // CHECK-DAG: %[[STD_RESULT:.*]] = dim %[[ARG]], %[[STD_IDX]] : tensor<2x3xf32> + // CHECK-DAG: %[[RESULT:.*]] = shape.index_to_size %[[STD_RESULT]] + // CHECK: return %[[RESULT]] : !shape.size + %shape = shape.shape_of %arg : tensor<2x3xf32> + %result = shape.get_extent %shape, %idx + return %result : !shape.size +} + +// ----- + +// Leave `get_extent` untouched when its shape operand is not produced by a +// `shape_of` operation. +// CHECK-LABEL: @no_shape_of +// CHECK-SAME: (%[[SHAPE:.*]]: !shape.shape, %[[IDX:.*]]: !shape.size) -> !shape.size +func @no_shape_of(%shape : !shape.shape, %idx : !shape.size) -> !shape.size { + // CHECK: %[[RESULT:.*]] = shape.get_extent %[[SHAPE]], %[[IDX]] + // CHECK: return %[[RESULT]] : !shape.size + %result = shape.get_extent %shape, %idx + return %result : !shape.size +}