diff --git a/mlir/docs/Passes.md b/mlir/docs/Passes.md --- a/mlir/docs/Passes.md +++ b/mlir/docs/Passes.md @@ -36,6 +36,10 @@ [include "QuantPasses.md"] +## `shape` Dialect Passes + +[include "ShapePasses.md"] + ## `spv` Dialect Passes [include "SPIRVPasses.md"] diff --git a/mlir/include/mlir/Dialect/Shape/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Shape/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Shape/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -327,7 +327,7 @@ let arguments = (ins Shape_ShapeType:$shape, Variadic:$initVals); let results = (outs Variadic:$result); - let regions = (region SizedRegion<1>:$body); + let regions = (region SizedRegion<1>:$region); let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Shape/Transforms/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls) +add_public_tablegen_target(MLIRShapeTransformsIncGen) + +add_mlir_doc(Passes -gen-pass-doc ShapePasses ./) diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h @@ -0,0 +1,30 @@ +//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header file defines prototypes that expose pass constructors in the +// shape transformation library. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ +#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ + +#include + +namespace mlir { + +class Pass; + +/// Creates an instance of the ShapeToShapeLowering pass that legalizes Shape +/// dialect to be convertible to Standard. For example, `shape.num_elements` +/// is transformed to `shape.reduce`, which can be lowered to SCF and Standard. +std::unique_ptr createShapeToShapeLowering(); + +} // end namespace mlir + +#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES_H_ diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.td @@ -0,0 +1,19 @@ +//===-- Passes.td - ShapeOps pass definition file ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES +#define MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES + +include "mlir/Pass/PassBase.td" + +def ShapeToShapeLowering : FunctionPass<"shape-to-shape-lowering"> { + let summary = "Legalize Shape dialect to be convertible to Standard"; + let constructor = "mlir::createShapeToShapeLowering()"; +} + +#endif // MLIR_DIALECT_SHAPE_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -37,6 +37,7 @@ #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SPIRV/Passes.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/Transforms/LocationSnapshot.h" #include "mlir/Transforms/Passes.h" @@ -94,6 +95,10 @@ // Standard #define GEN_PASS_REGISTRATION #include "mlir/Dialect/StandardOps/Transforms/Passes.h.inc" + + // Shape +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/Shape/Transforms/Passes.h.inc" } } // namespace mlir diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/CMakeLists.txt @@ -18,3 +18,5 @@ MLIRIR MLIRSideEffectInterfaces ) + +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -504,7 +504,7 @@ static LogicalResult verify(ReduceOp op) { // Verify block arg types. 
- Block &block = op.body().front(); + Block &block = op.region().front(); auto blockArgsCount = op.initVals().size() + 2; if (block.getNumArguments() != blockArgsCount) @@ -560,7 +560,7 @@ p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() << ") "; p.printOptionalArrowTypeList(op.getResultTypes()); - p.printRegion(op.body()); + p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs()); } diff --git a/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRShapeOpsTransforms + ShapeToShapeLowering.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ShapeOps/Transforms + + DEPENDS + MLIRShapeTransformsIncGen + ) + +target_link_libraries(MLIRShapeOpsTransforms + PUBLIC + MLIRIR + MLIRPass + MLIRShape + MLIRSupport + ) diff --git a/mlir/lib/Dialect/Shape/Transforms/PassDetail.h b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/PassDetail.h @@ -0,0 +1,21 @@ +//===- PassDetail.h - Shape Pass class details ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_ +#define DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { + +#define GEN_PASS_CLASSES +#include "mlir/Dialect/Shape/Transforms/Passes.h.inc" + +} // end namespace mlir + +#endif // DIALECT_SHAPE_TRANSFORMS_PASSDETAIL_H_ diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -0,0 +1,69 @@ +//===- ShapeToShapeLowering.cpp - Prepare for lowering to Standard --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::shape; + +namespace { +/// Converts `shape.num_elements` to `shape.reduce`. +struct NumElementsOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(NumElementsOp op, + PatternRewriter &rewriter) const final; +}; +} // namespace + +LogicalResult +NumElementsOpConverter::matchAndRewrite(NumElementsOp op, + PatternRewriter &rewriter) const { + auto loc = op.getLoc(); + Value init = rewriter.create(loc, rewriter.getIndexAttr(1)); + ReduceOp reduce = rewriter.create(loc, op.shape(), init); + + // Generate reduce operator. 
+ Block *body = reduce.getBody(); + OpBuilder b = OpBuilder::atBlockEnd(body); + Value product = + b.create(loc, body->getArgument(1), body->getArgument(2)); + b.create(loc, product); + + rewriter.replaceOp(op, reduce.result()); + return success(); +} + +namespace { +struct ShapeToShapeLowering + : public ShapeToShapeLoweringBase { + void runOnFunction() override; +}; +} // namespace + +void ShapeToShapeLowering::runOnFunction() { + OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addIllegalOp(); + if (failed(mlir::applyPartialConversion(getFunction(), target, patterns))) + signalPassFailure(); +} + +std::unique_ptr mlir::createShapeToShapeLowering() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: func @num_elements_to_reduce( +// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] { +func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size { + %num_elements = shape.num_elements %shape + return %num_elements : !shape.size +} +// CHECK: [[C1:%.*]] = shape.const_size 1 +// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]] +// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]] +// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]] +// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]] +// CHECK: } +// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]] +