diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h @@ -23,7 +23,7 @@ class OwningRewritePatternList; class Pass; -/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape +/// Creates an instance of the `ShapeToShapeLowering` pass that legalizes Shape /// dialect to be convertible to Standard. For example, `shape.num_elements` get /// transformed to `shape.reduce`, which can be lowered to SCF and Standard. std::unique_ptr createShapeToShapeLowering(); @@ -32,6 +32,13 @@ void populateShapeRewritePatterns(MLIRContext *context, OwningRewritePatternList &patterns); +/// Canonicalize common patterns of `shape` operations to `std` operations. +/// For example, `shape.get_extent` applied to the result of `shape.shape_of` +/// is rewritten in terms of a single `std.dim` operation. +std::unique_ptr createCanonicalizeShapeToStandardPass(); +void populateCanonicalizeShapeToStandardPatterns( + OwningRewritePatternList *patterns, MLIRContext *ctx); + // Collects a set of patterns to replace all constraints with passing witnesses. 
// This is intended to then allow all ShapeConstraint related ops and data to // have no effects and allow them to be freely removed such as through diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td --- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td @@ -21,4 +21,10 @@ let constructor = "mlir::createShapeToShapeLowering()"; } +def CanonicalizeShapeToStandard : FunctionPass<"canonicalize-shape-to-standard"> { + let summary = "Canonicalize common patterns of shape operations to " + "standard operations"; + let constructor = "mlir::createCanonicalizeShapeToStandardPass()"; +} + #endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1408,7 +1408,9 @@ let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value memrefOrTensor, int64_t index"> + "Value memrefOrTensor, int64_t index">, + OpBuilder<"OpBuilder &builder, OperationState &result, " + "Value memrefOrTensor, Value index"> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt @@ -1,6 +1,11 @@ +set(LLVM_TARGET_DEFINITIONS CanonicalizeShapeToStandardPatterns.td) +mlir_tablegen(CanonicalizeShapeToStandardPatterns.inc -gen-rewriters) +add_public_tablegen_target(CanonicalizeShapeToStandardPatternsIncGen) + add_mlir_dialect_library(MLIRShapeOpsTransforms + CanonicalizeShapeToStandard.cpp RemoveShapeConstraints.cpp ShapeToShapeLowering.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms @@ -14,6 +19,7 @@ MLIRIR
MLIRPass MLIRShape + MLIRStandardOps MLIRSupport MLIRTransforms ) diff --git a/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandard.cpp b/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandard.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandard.cpp @@ -0,0 +1,54 @@ +//===- CanonicalizeShapeToStandard.cpp - Canonicalize shape ops to std ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::shape; + +namespace { + +// Use generated patterns.
+#include "CanonicalizeShapeToStandardPatterns.inc" + +/// Applies the patterns from `populateCanonicalizeShapeToStandardPatterns` to +/// the current function via `applyPartialConversion`; fails the pass if the +/// conversion fails. +struct CanonicalizeShapeToStandardPass + : public CanonicalizeShapeToStandardBase { + + void runOnFunction() override { + OwningRewritePatternList patterns; + populateCanonicalizeShapeToStandardPatterns(&patterns, &getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + + if (failed(applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr mlir::createCanonicalizeShapeToStandardPass() { + return std::make_unique(); +} + +void mlir::populateCanonicalizeShapeToStandardPatterns( + OwningRewritePatternList *patterns, MLIRContext *ctx) { + populateWithGenerated(ctx, patterns); +} diff --git a/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td b/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/CanonicalizeShapeToStandardPatterns.td @@ -0,0 +1,6 @@ +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" + +def GetExtentShapeOfConversion : Pat< + (Shape_GetExtentOp (Shape_ShapeOfOp $arg), $idx), + (Shape_IndexToSizeOp (DimOp $arg, (Shape_SizeToIndexOp $idx)))>; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1273,8 +1273,13 @@ Value memrefOrTensor, int64_t index) { auto loc = result.location; Value indexValue = builder.create(loc, index); + build(builder, result, memrefOrTensor, indexValue); +} + +void DimOp::build(OpBuilder &builder, OperationState &result, + Value memrefOrTensor, Value index) { auto indexTy = builder.getIndexType(); - build(builder, result, indexTy, memrefOrTensor, indexValue); + build(builder, result, indexTy, memrefOrTensor, index); } Optional DimOp::getConstantIndex() { diff --git
a/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir b/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Shape/canonicalize-shape-to-standard.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt --canonicalize-shape-to-standard --split-input-file %s | FileCheck %s + +// Express `get_extent` as `std.dim` when it relies directly on the outcome of a +// `shape_of` operation. +// CHECK-LABEL: @get_extent +// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>, %[[IDX:.*]]: !shape.size) -> !shape.size +func @get_extent(%arg : tensor<2x3xf32>, %idx : !shape.size) -> !shape.size { + // CHECK-DAG: %[[STD_IDX:.*]] = shape.size_to_index %[[IDX]] + // CHECK-DAG: %[[STD_RESULT:.*]] = dim %[[ARG]], %[[STD_IDX]] : tensor<2x3xf32> + // CHECK-DAG: %[[RESULT:.*]] = shape.index_to_size %[[STD_RESULT]] + // CHECK: return %[[RESULT]] : !shape.size + %shape = shape.shape_of %arg : tensor<2x3xf32> + %result = shape.get_extent %shape, %idx + return %result : !shape.size +}