diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -197,6 +197,16 @@ } //===----------------------------------------------------------------------===// +// ShapeToStandard +//===----------------------------------------------------------------------===// + +def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> { + let summary = "Convert operations from the shape dialect into the standard " + "dialect"; + let constructor = "mlir::createConvertShapeToStandardPass()"; +} + +//===----------------------------------------------------------------------===// // StandardToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h @@ -0,0 +1,28 @@ +//===- ShapeToStandard.h - Conversion utils from shape to std dialect -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_ +#define MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class MLIRContext; +class ModuleOp; +template +class OperationPass; + +void populateShapeToStandardConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx); + +std::unique_ptr> createConvertShapeToStandardPass(); + +} // namespace mlir + +#endif // MLIR_CONVERSION_SHAPETOSTANDARD_SHAPETOSTANDARD_H_ diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h --- a/mlir/include/mlir/InitAllPasses.h +++ b/mlir/include/mlir/InitAllPasses.h @@ -25,6 +25,7 @@ #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h" #include "mlir/Conversion/SCFToGPU/SCFToGPUPass.h" #include "mlir/Conversion/SCFToStandard/SCFToStandard.h" +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -10,6 +10,7 @@ add_subdirectory(LinalgToStandard) add_subdirectory(SCFToGPU) add_subdirectory(SCFToStandard) +add_subdirectory(ShapeToStandard) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) add_subdirectory(VectorToLLVM) diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -0,0 +1,20 @@ +add_mlir_conversion_library(MLIRShapeToStandard + ShapeToStandard.cpp +
+ ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShapeToStandard + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIREDSC + MLIRIR + MLIRShape + MLIRPass + MLIRSCF + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -0,0 +1,106 @@ +//===- ShapeToStandard.cpp - conversion from Shape to Standard dialect ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h" + +#include "../PassDetail.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace { + +/// Conversion patterns. `size_to_index`/`index_to_size` lower to no-ops.
+class SizeToIndexOpConversion + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(shape::SizeToIndexOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + shape::SizeToIndexOpOperandAdaptor transformed(operands); + rewriter.replaceOp(op.getOperation(), transformed.arg()); + return success(); + } +}; + +class IndexToSizeOpConversion + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(shape::IndexToSizeOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + shape::IndexToSizeOpOperandAdaptor transformed(operands); + rewriter.replaceOp(op.getOperation(), transformed.arg()); + return success(); + } +}; + +/// Type conversions. Maps `!shape.size` to `index`; other types pass through. +class ShapeTypeConverter : public TypeConverter { +public: + using TypeConverter::convertType; + + explicit ShapeTypeConverter(MLIRContext *ctx) { + // Add default pass-through conversion. + addConversion([](Type type) { return type; }); + addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); }); + } +}; + +/// Conversion pass. Applies a full conversion of shape ops/types in a module. +class ConvertShapeToStandardPass + : public ConvertShapeToStandardBase { + + void runOnOperation() override { + + // Setup type conversion. + MLIRContext &ctx = getContext(); + ShapeTypeConverter typeConverter(&ctx); + + // Setup target legality. + ConversionTarget target(ctx); + target.addLegalDialect(); + target.addLegalOp(); + target.addDynamicallyLegalOp([&](FuncOp op) { + return typeConverter.isSignatureLegal(op.getType()); + }); + + // Setup conversion patterns. + OwningRewritePatternList patterns; + populateShapeToStandardConversionPatterns(patterns, &ctx); + populateFuncOpTypeConversionPattern(patterns, &ctx, typeConverter); + + // Apply conversion.
+ auto module = getOperation(); + if (failed(applyFullConversion(module, target, patterns, &typeConverter))) + signalPassFailure(); + } +}; + +} // namespace + +void populateShapeToStandardConversionPatterns( + OwningRewritePatternList &patterns, MLIRContext *ctx) { + // clang-format off + patterns.insert< + IndexToSizeOpConversion, + SizeToIndexOpConversion>(ctx); + // clang-format on +} + +std::unique_ptr> createConvertShapeToStandardPass() { + return std::make_unique(); +} + +} // namespace mlir diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-opt --split-input-file --convert-shape-to-std --verify-diagnostics %s | FileCheck %s --dump-input-on-failure + +// Convert `size` to `index` type. +// CHECK-LABEL: @size_id +// CHECK-SAME: (%[[SIZE:.*]]: index) +func @size_id(%size : !shape.size) -> !shape.size { + // CHECK: return %[[SIZE]] : index + return %size : !shape.size +} + +// ----- + +// Lower `size_to_index` conversion to no-op. +// CHECK-LABEL: @size_to_index +// CHECK-SAME: (%[[SIZE:.*]]: index) -> index +func @size_to_index(%size : !shape.size) -> index { + // CHECK-NEXT: return %[[SIZE]] : index + %index = shape.size_to_index %size + return %index : index +} + +// ----- + +// Lower `index_to_size` conversion to no-op. +// CHECK-LABEL: @index_to_size +// CHECK-SAME: (%[[INDEX:.*]]: index) -> index +func @index_to_size(%index : index) -> !shape.size { + // CHECK-NEXT: return %[[INDEX]] : index + %size = shape.index_to_size %index + return %size : !shape.size +}