diff --git a/mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h b/mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/StandardToStandard/StandardToStandard.h @@ -0,0 +1,31 @@ +//===- StandardToStandard.h - Std intra-dialect conversion ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains patterns for lowering within the Standard dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_STANDARDTOSTANDARD_STANDARDTOSTANDARD_H_ +#define MLIR_CONVERSION_STANDARDTOSTANDARD_STANDARDTOSTANDARD_H_ + +namespace mlir { + +// Forward declarations. +class MLIRContext; +class OwningRewritePatternList; +class TypeConverter; + +/// Add a pattern to the given pattern list to convert the result types of a +/// CallOp with the given type converter. 
+void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns, + MLIRContext *ctx, + TypeConverter &converter); + +} // end namespace mlir + +#endif // MLIR_CONVERSION_STANDARDTOSTANDARD_STANDARDTOSTANDARD_H_ diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -10,5 +10,6 @@ add_subdirectory(LoopToStandard) add_subdirectory(StandardToLLVM) add_subdirectory(StandardToSPIRV) +add_subdirectory(StandardToStandard) add_subdirectory(VectorToLLVM) add_subdirectory(VectorToLoops) diff --git a/mlir/lib/Conversion/StandardToStandard/CMakeLists.txt b/mlir/lib/Conversion/StandardToStandard/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/StandardToStandard/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_conversion_library(MLIRStandardToStandard + DialectConversion.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/StandardToStandard + + DEPENDS + intrinsics_gen + ) +target_link_libraries(MLIRStandardToStandard + PUBLIC + MLIRIR + MLIRPass + MLIRStandardOps + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/StandardToStandard/DialectConversion.cpp b/mlir/lib/Conversion/StandardToStandard/DialectConversion.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/StandardToStandard/DialectConversion.cpp @@ -0,0 +1,51 @@ +//===- DialectConversion.cpp - Std intra-dialect conversion patterns ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/StandardToStandard/StandardToStandard.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +/// Conversion pattern that rewrites the result types of a CallOp using a +/// TypeConverter. +struct CallOpSignatureConversion : public OpConversionPattern<CallOp> { + CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) + : OpConversionPattern(ctx), converter(converter) {} + + /// Replaces `callOp` with a call that has converted result types. + PatternMatchResult + matchAndRewrite(CallOp callOp, ArrayRef<Value> operands, + ConversionPatternRewriter &rewriter) const override { + FunctionType type = callOp.getCalleeType(); + + // Convert the callee's result types; bail out if any cannot be converted. + SmallVector<Type, 1> convertedResults; + if (failed(converter.convertTypes(type.getResults(), convertedResults))) + return matchFailure(); + + // Substitute with the new result types from the corresponding FuncType + // conversion. + rewriter.replaceOpWithNewOp<CallOp>(callOp, callOp.callee(), + convertedResults, operands); + return matchSuccess(); + } + + /// The type converter to use when rewriting the signature.
+ TypeConverter &converter; +}; + +} // end anonymous namespace + +void mlir::populateCallOpTypeConversionPattern( + OwningRewritePatternList &patterns, MLIRContext *ctx, + TypeConverter &converter) { + patterns.insert<CallOpSignatureConversion>(ctx, converter); +} diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -23,6 +23,13 @@ "test.invalid"(%arg0) : (i64) -> () } +// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64) +func @remap_call_1_to_1(%arg0: i64) { + // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> () + call @remap_input_1_to_1(%arg0) : (i64) -> () + return +} + // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16) func @remap_input_1_to_N(%arg0: f32) -> f32 { // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "mlir/Conversion/StandardToStandard/StandardToStandard.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -381,6 +382,8 @@ patterns.insert(&getContext(), converter); mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), converter); + mlir::populateCallOpTypeConversionPattern(patterns, &getContext(), + converter); // Define the conversion target used for the test. ConversionTarget target(getContext()); diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -45,6 +45,7 @@ MLIRLoopOps MLIRGPU MLIRPass + MLIRStandardToStandard MLIRTestDialect MLIRTransformUtils MLIRVectorToLoops