diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ +// RUN: | FileCheck %s --check-prefix=CHECK-12N + // Test case: Most basic case of a 1:N decomposition, an identity function. // CHECK-LABEL: func @identity( @@ -9,6 +13,10 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @identity( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32 func.func @identity(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -20,6 +28,9 @@ // CHECK-LABEL: func @identity_1_to_1_no_materializations( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -31,6 +42,9 @@ // CHECK-LABEL: func @recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>> { return %arg0 : tuple>> } @@ -54,6 +68,10 @@ // CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple>) -> 
tuple // CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple) -> i2 // CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK-12N-LABEL: func @mixed_recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2 func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple>>) -> tuple, tuple, tuple>> { return %arg0 : tuple, tuple, tuple>> } @@ -63,6 +81,7 @@ // Test case: Check decomposition of calls. // CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) +// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32) func.func private @callee(tuple) -> tuple // CHECK-LABEL: func @caller( @@ -76,6 +95,11 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32 func.func @caller(%arg0: tuple) -> tuple { %0 = call @callee(%arg0) : (tuple) -> tuple return %0 : tuple @@ -86,7 +110,12 @@ // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). 
// CHECK-LABEL: func private @callee() +// CHECK-12N-LABEL: func private @callee() func.func private @callee(tuple<>) -> tuple<> + +// CHECK-12N-LABEL: func @caller() { +// CHECK-12N: call @callee() : () -> () +// CHECK-12N: return // CHECK-LABEL: func @caller() { // CHECK: call @callee() : () -> () // CHECK: return @@ -105,6 +134,11 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) { +// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple +// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32 func.func @unconverted_op_result() -> tuple { %0 = "test.source"() : () -> (tuple) return %0 : tuple @@ -125,6 +159,16 @@ // CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple // CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 // CHECK: return %[[V3]], %[[V5]] : i1, i32 +// CHECK-12N-LABEL: func @nested_unconverted_op_result( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple +// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple) -> tuple> +// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple>) -> tuple> +// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple>) -> i1 +// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple +// CHECK-12N: %[[V5:.*]] = 
"test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32 func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple> { %0 = "test.op"(%arg) : (tuple>) -> (tuple>) return %0 : tuple> @@ -136,6 +180,7 @@ // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions. // CHECK-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) // CHECK-LABEL: func @caller( @@ -153,6 +198,15 @@ // CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple) -> i4 // CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple) -> i5 // CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[I1:.*]]: i1, +// CHECK-12N-SAME: %[[I2:.*]]: i2, +// CHECK-12N-SAME: %[[I3:.*]]: i3, +// CHECK-12N-SAME: %[[I4:.*]]: i4, +// CHECK-12N-SAME: %[[I5:.*]]: i5, +// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { +// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple, %arg3: i3, %arg4: tuple, %arg5: i6) -> (tuple<>, i1, tuple, i3, tuple, i6) { %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple, i3, tuple, i6 diff --git a/mlir/test/lib/Conversion/CMakeLists.txt 
b/mlir/test/lib/Conversion/CMakeLists.txt --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(FuncToLLVM) +add_subdirectory(OneToNTypeConversion) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_library(MLIRTestOneToNTypeConversionPass + OneToNTypeConversion.cpp + OneToNTypeConversionFunc.cpp + TestOneToNTypeConversionPass.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRIR + MLIRTestDialect + MLIRTransformUtils + ) + +target_include_directories(MLIRTestOneToNTypeConversionPass + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h @@ -0,0 +1,248 @@ +//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utils for implementing (poor-man's) dialect conversion +// passes with 1:N type conversions. +// +// The main function first applies a set of RewritePatterns, which produce +// unrealized casts to convert the operands and results from and to the source +// types, and then replaces all newly added unrealized casts by user-provided +// materializations. 
For this to work, the main function requires a special
+// TypeConverter and special RewritePatterns, respectively deriving from the
+// provided classes, which extend their respective base classes for 1:N type
+// conversions.
+//
+// Note that this is much more simple-minded than the "real" dialect conversion,
+// which checks for legality before applying patterns and probably does many
+// other additional things. Ideally, some of the extensions here could be
+// integrated there.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H
+#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+/// Extends `TypeConverter` with 1:N target materializations. Such
+/// materializations provide the IR for 1:N type conversions, i.e., they need
+/// to materialize one value with a source type as N values with target types
+/// (which isn't possible in the base class currently).
+class OneToNTypeConverter : public TypeConverter {
+public:
+  using OneToNMaterializationCallbackFn =
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
+                                                      Value, Location)>;
+
+  /// Creates the mapping of the given range of original types to target types
+  /// of the conversion and stores that mapping in the given (signature)
+  /// conversion. This function simply calls TypeConverter::convertSignatureArgs
+  /// and exists here with a different name to reflect the broader semantic.
+  LogicalResult computeTypeMapping(TypeRange types,
+                                   SignatureConversion &result) {
+    return convertSignatureArgs(types, result);
+  }
+
+  /// Applies one of the user-provided 1:N target materializations (in LIFO
+  /// order): they are tried starting with the most recently added one, and the
+  /// result of the first one that succeeds is returned (std::nullopt if none
+  /// succeeds).
+  std::optional<SmallVector<Value>>
+  materializeTargetConversion(OpBuilder &builder, Location loc,
+                              TypeRange resultTypes, Value input) const;
+
+  /// Adds a 1:N target materialization to the converter. Such materializations
+  /// build IR that converts 1 value of the source type into N values with
+  /// target types.
+  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
+    oneToNTargetMaterializations.emplace_back(std::move(callback));
+  }
+
+private:
+  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
+};
+
+/// Stores a 1:N mapping of types and provides several useful accessors. This
+/// class extends SignatureConversion, which already supports 1:N type mappings
+/// but lacks some accessors into the mapping as well as access to the original
+/// types.
+class OneToNTypeMapping : public TypeConverter::SignatureConversion {
+public:
+  OneToNTypeMapping(TypeRange originalTypes)
+      : TypeConverter::SignatureConversion(originalTypes.size()),
+        originalTypes(originalTypes) {}
+
+  using TypeConverter::SignatureConversion::getConvertedTypes;
+
+  /// Returns the list of types that corresponds to the original type at the
+  /// given index.
+  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
+
+  /// Returns the list of original types.
+  TypeRange getOriginalTypes() const { return originalTypes; }
+
+  /// Returns the slice of converted values that corresponds to the original
+  /// value at the given index.
+  ValueRange getConvertedValues(ValueRange convertedValues,
+                                unsigned originalValueNo) const;
+
+  /// Fills the given result vector with as many copies of the location of the
+  /// original value as the number of values it is converted to.
+  void convertLocation(Value originalValue, unsigned originalValueNo,
+                       llvm::SmallVectorImpl<Location> &result) const;
+
+  /// Fills the given result vector with as many copies of the location of each
+  /// original value as the number of values they are respectively converted to.
+  void convertLocations(ValueRange originalValues,
+                        llvm::SmallVectorImpl<Location> &result) const;
+
+  /// Returns true iff at least one type conversion maps an input type to a type
+  /// that is different from itself.
+  bool hasNonIdentityConversion() const;
+
+private:
+  llvm::SmallVector<Type> originalTypes;
+};
+
+/// Extends the basic RewritePattern with a type converter member and some
+/// accessors to it. This is useful for patterns that are not ConversionPatterns
+/// but still require access to a type converter.
+class RewritePatternWithConverter : public mlir::RewritePattern {
+public:
+  /// Construct a conversion pattern with the given converter, and forward the
+  /// remaining arguments to RewritePattern.
+  template <typename... Args>
+  RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
+      : RewritePattern(std::forward<Args>(args)...),
+        typeConverter(&typeConverter) {}
+
+  /// Return the type converter held by this pattern. (It is set from a
+  /// reference in the constructor and is thus never null.)
+  TypeConverter *getTypeConverter() const { return typeConverter; }
+
+  /// Return the type converter held by this pattern, cast to the given derived
+  /// converter type. The caller is responsible for the dynamic type matching.
+  template <typename ConverterTy>
+  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
+                   ConverterTy *>
+  getTypeConverter() const {
+    return static_cast<ConverterTy *>(typeConverter);
+  }
+
+protected:
+  /// A type converter for use by this pattern.
+  TypeConverter *const typeConverter;
+};
+
+/// Specialization of PatternRewriter that OneToNConversionPatterns use. The
+/// class provides additional rewrite methods that are specific to 1:N type
+/// conversions.
+class OneToNPatternRewriter : public PatternRewriter {
+public:
+  OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {}
+
+  /// Replaces the results of the operation with the specified list of values
+  /// mapped back to the original types as specified in the provided type
+  /// mapping. That type mapping must match the replaced op (i.e., the original
+  /// types must be the same as the result types of the op) and the new values
+  /// (i.e., the converted types must be the same as the types of the new
+  /// values).
+  void replaceOp(Operation *op, ValueRange newValues,
+                 const OneToNTypeMapping &resultMapping);
+  using PatternRewriter::replaceOp;
+
+  /// Applies the given argument conversion to the given block. This consists of
+  /// replacing each original argument with N arguments as specified in the
+  /// argument conversion and inserting unrealized casts from the converted
+  /// values to the original types, which are then used in lieu of the original
+  /// ones. (Eventually, applyOneToNConversion replaces these casts with a
+  /// user-provided argument materialization if necessary.) This is similar to
+  /// ArgConverter::applySignatureConversion but (1) handles 1:N type conversion
+  /// properly and probably (2) doesn't handle many other edge cases.
+  Block *applySignatureConversion(Block *block,
+                                  OneToNTypeMapping &argumentConversion);
+};
+
+/// Base class for patterns with 1:N type conversions. Derived classes have to
+/// override the `matchAndRewrite` overload that provides additional
+/// information for 1:N type conversions.
+class OneToNConversionPattern : public RewritePatternWithConverter {
+public:
+  using RewritePatternWithConverter::RewritePatternWithConverter;
+
+  /// This function has to be implemented by derived classes and is called from
+  /// the usual overloads. Like in normal DialectConversion, the function is
+  /// provided with the converted operands (which thus have target types). Since
+  /// 1:N conversions are supported, there is usually no 1:1 relationship
+  /// between the original and the converted operands. Instead, the provided
+  /// `operandMapping` can be used to access the converted operands that
+  /// correspond to a particular original operand. Similarly, `resultMapping`
+  /// is provided to help with assembling the result values (which may have 1:N
+  /// correspondences as well). The function is expected to return success if
+  /// the conversion succeeds and failure otherwise (in which case any
+  /// modifications of the IR it made have to be rolled back first). The
+  /// calling overload inserts the unrealized casts that produce the converted
+  /// operands. To replace the results of the op, implementations can call
+  /// `OneToNPatternRewriter::replaceOp` with converted result values that
+  /// correspond to `resultMapping`; that function inserts unrealized casts
+  /// back to the original result types and replaces the uses of the original
+  /// results with the results of those casts.
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector<Value> &convertedOperands) const = 0;
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final;
+};
+
+/// This class is a wrapper around OneToNConversionPattern for matching against
+/// instances of a particular op class.
+template <typename SourceOp>
+class OneToNOpConversionPattern : public OneToNConversionPattern {
+public:
+  OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+                            PatternBenefit benefit = 1,
+                            ArrayRef<StringRef> generatedNames = {})
+      : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
+                                benefit, context, generatedNames) {}
+
+  using OneToNConversionPattern::matchAndRewrite;
+
+  /// Overload that derived classes have to override for their op type.
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector<Value> &convertedOperands) const = 0;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector<Value> &convertedOperands) const final {
+    return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
+                           resultMapping, convertedOperands);
+  }
+};
+
+/// Main function that 1:N conversion passes should call. The patterns are
+/// expected to insert unrealized casts to maintain the types of operands and
+/// results, which is done automatically if they derive from
+/// OneToNConversionPattern. The function replaces those that do not fold away
+/// until the end of pattern application with user-provided materializations
+/// from the type converter, so those have to be provided if conversions from
+/// source to target types are expected to remain.
+LogicalResult applyOneToNConversion(Operation *op,
+                                    OneToNTypeConverter &typeConverter,
+                                    const FrozenRewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp
@@ -0,0 +1,392 @@
+//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "OneToNTypeConversion.h"
+
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/ErrorHandling.h"
+
+using namespace llvm;
+using namespace mlir;
+
+std::optional<SmallVector<Value>>
+OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
+                                                 Location loc,
+                                                 TypeRange resultTypes,
+                                                 Value input) const {
+  // Try the callbacks from the most recently added one on; the first one that
+  // succeeds wins.
+  for (const OneToNMaterializationCallbackFn &fn :
+       llvm::reverse(oneToNTargetMaterializations)) {
+    if (std::optional<SmallVector<Value>> result =
+            fn(builder, resultTypes, input, loc))
+      return *result;
+  }
+  return {};
+}
+
+TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
+  TypeRange convertedTypes = getConvertedTypes();
+  if (auto mapping = getInputMapping(originalTypeNo))
+    return convertedTypes.slice(mapping->inputNo, mapping->size);
+  return {};
+}
+
+ValueRange
+OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
+                                      unsigned originalValueNo) const {
+  if (auto mapping = getInputMapping(originalValueNo))
+    return convertedValues.slice(mapping->inputNo, mapping->size);
+  return {};
+}
+
+void OneToNTypeMapping::convertLocation(
+    Value originalValue, unsigned originalValueNo,
+    llvm::SmallVectorImpl<Location> &result) const {
+  if (auto mapping = getInputMapping(originalValueNo))
+    result.append(mapping->size, originalValue.getLoc());
+}
+
+void OneToNTypeMapping::convertLocations(
+    ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
+  assert(originalValues.size() == getOriginalTypes().size());
+  // Bind the (index, value) pairs by value: depending on the LLVM version,
+  // `llvm::enumerate` yields them as temporaries, to which a non-const lvalue
+  // reference (`auto &[...]`) cannot bind.
+  for (auto [i, value] : llvm::enumerate(originalValues))
+    convertLocation(value, i, result);
+}
+
+/// Returns true iff `originalType` is mapped to exactly one type, namely
+/// itself.
+static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
+  return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
+}
+
+bool OneToNTypeMapping::hasNonIdentityConversion() const {
+  // NOTE: Presumably, the original types and the converted types are the same
+  // iff there was no non-identity type conversion (the asserts below check
+  // that assumption in debug builds). If that is true, the patterns could
+  // actually test whether there is anything useful to do without having access
+  // to the signature conversion.
+  for (auto [i, originalType] : llvm::enumerate(originalTypes)) {
+    TypeRange types = getConvertedTypes(i);
+    if (!isIdentityConversion(originalType, types)) {
+      assert(TypeRange(originalTypes) != getConvertedTypes());
+      return true;
+    }
+  }
+  assert(TypeRange(originalTypes) == getConvertedTypes());
+  return false;
+}
+
+enum class CastKind {
+  // Casts block arguments in the target type back to the source type. (If
+  // necessary, this cast becomes an argument materialization.)
+  Argument,
+
+  // Casts other values in the target type back to the source type. (If
+  // necessary, this cast becomes a source materialization.)
+  Source,
+
+  // Casts values in the source type to the target type. (If necessary, this
+  // cast becomes a target materialization.)
+  Target
+};
+
+/// Returns the string that encodes the given cast kind in the cast-kind
+/// attribute of inserted casts. (A switch is used instead of a global map so
+/// that no static constructor is needed.)
+static StringRef getCastKindName(CastKind kind) {
+  switch (kind) {
+  case CastKind::Argument:
+    return "argument";
+  case CastKind::Source:
+    return "source";
+  case CastKind::Target:
+    return "target";
+  }
+  llvm_unreachable("unknown cast kind");
+}
+
+/// Attribute name that is used to annotate inserted unrealized casts with their
+/// kind (source, argument, or target).
+static const char *const castKindAttrName =
+    "__one-to-n-type-conversion_cast-kind__";
+
+/// Builds an UnrealizedConversionCastOp from the given inputs to the given
+/// result types and annotates it with the given cast kind. Returns the result
+/// values of the cast.
+static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
+                                      ValueRange inputs, CastKind kind) {
+  // Create cast, located at the first input if there is one.
+  Location loc = builder.getUnknownLoc();
+  if (!inputs.empty())
+    loc = inputs.front().getLoc();
+  auto castOp =
+      builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
+
+  // Store cast kind as attribute.
+  auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
+  castOp->setAttr(castKindAttrName, kindAttr);
+
+  return castOp->getResults();
+}
+
+/// Builds one UnrealizedConversionCastOp for each of the given original values
+/// using the respective target types given in the provided conversion mapping
+/// and returns the results of these casts. If the conversion mapping of a value
+/// maps a type to itself (i.e., is an identity conversion), then no cast is
+/// inserted and the original value is returned instead.
+/// Note that these unrealized casts are different from target materializations
+/// in that they are *always* inserted, even if they immediately fold away, such
+/// that patterns always see valid intermediate IR, whereas materializations are
+/// only used in the places where the unrealized casts *don't* fold away.
+static SmallVector<Value>
+buildUnrealizedForwardCasts(ValueRange originalValues,
+                            const OneToNTypeMapping &conversion,
+                            RewriterBase &rewriter, CastKind kind) {
+
+  // Convert each operand one by one.
+  SmallVector<Value> convertedValues;
+  convertedValues.reserve(conversion.getConvertedTypes().size());
+  for (auto [idx, originalValue] : llvm::enumerate(originalValues)) {
+    TypeRange convertedTypes = conversion.getConvertedTypes(idx);
+
+    // Identity conversion: keep operand as is.
+    if (isIdentityConversion(originalValue.getType(), convertedTypes)) {
+      convertedValues.push_back(originalValue);
+      continue;
+    }
+
+    // Non-identity conversion: materialize target types.
+    ValueRange castResult =
+        buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind);
+    convertedValues.append(castResult.begin(), castResult.end());
+  }
+
+  return convertedValues;
+}
+
+/// Builds one UnrealizedConversionCastOp for each sequence of the given
+/// converted values to one value of the type they originated from, i.e., a
+/// "reverse" conversion from N converted values back to one value of the
+/// original type, using the given (forward) type conversion. If a given value
+/// was mapped to a value of the same type (i.e., the conversion in the mapping
+/// is an identity conversion), then the "converted" value is returned without
+/// cast.
+/// Note that these unrealized casts are different from source materializations
+/// in that they are *always* inserted, even if they immediately fold away, such
+/// that patterns always see valid intermediate IR, whereas materializations are
+/// only used in the places where the unrealized casts *don't* fold away.
+static SmallVector<Value>
+buildUnrealizedBackwardsCasts(ValueRange convertedValues,
+                              const OneToNTypeMapping &typeConversion,
+                              RewriterBase &rewriter) {
+  assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
+
+  // Create unrealized cast op for each converted result of the op.
+  SmallVector<Value> recastValues;
+  TypeRange originalTypes = typeConversion.getOriginalTypes();
+  recastValues.reserve(originalTypes.size());
+  auto convertedValueIt = convertedValues.begin();
+  for (auto [idx, originalType] : llvm::enumerate(originalTypes)) {
+    TypeRange convertedTypes = typeConversion.getConvertedTypes(idx);
+    size_t numConvertedValues = convertedTypes.size();
+    if (isIdentityConversion(originalType, convertedTypes)) {
+      // Identity conversion: take result as is.
+      recastValues.push_back(*convertedValueIt);
+    } else {
+      // Non-identity conversion: cast back to source type.
+      ValueRange recastValue = buildUnrealizedCast(
+          rewriter, originalType,
+          ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
+          CastKind::Source);
+      assert(recastValue.size() == 1);
+      recastValues.push_back(recastValue.front());
+    }
+    convertedValueIt += numConvertedValues;
+  }
+
+  return recastValues;
+}
+
+void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
+                                      const OneToNTypeMapping &resultMapping) {
+  // Create a cast back to the original types and replace the results of the
+  // original op with those.
+  assert(newValues.size() == resultMapping.getConvertedTypes().size());
+  assert(op->getResultTypes() == resultMapping.getOriginalTypes());
+  SmallVector<Value> castResults =
+      buildUnrealizedBackwardsCasts(newValues, resultMapping, *this);
+  replaceOp(op, castResults);
+}
+
+Block *OneToNPatternRewriter::applySignatureConversion(
+    Block *block, OneToNTypeMapping &argumentConversion) {
+  // Create a new block, inserted before the given one, with the updated
+  // signature and redirect all uses of the old block to it.
+  SmallVector<Location> locs;
+  argumentConversion.convertLocations(block->getArguments(), locs);
+  Block *newBlock =
+      createBlock(block, argumentConversion.getConvertedTypes(), locs);
+  replaceAllUsesWith(block, newBlock);
+
+  // Create necessary casts in new block.
+  SmallVector<Value> castResults;
+  for (auto [i, arg] : llvm::enumerate(block->getArguments())) {
+    TypeRange convertedTypes = argumentConversion.getConvertedTypes(i);
+    ValueRange newArgs =
+        argumentConversion.getConvertedValues(newBlock->getArguments(), i);
+    if (isIdentityConversion(arg.getType(), convertedTypes)) {
+      // Identity conversion: take argument as is.
+      assert(newArgs.size() == 1);
+      castResults.push_back(newArgs.front());
+    } else {
+      // Non-identity conversion: cast the converted arguments to the original
+      // type.
+      PatternRewriter::InsertionGuard g(*this);
+      setInsertionPointToStart(newBlock);
+      ValueRange castResult = buildUnrealizedCast(*this, arg.getType(), newArgs,
+                                                  CastKind::Argument);
+      assert(castResult.size() == 1);
+      castResults.push_back(castResult.front());
+    }
+  }
+
+  // Merge old block into new block such that we only have the latter with the
+  // new signature.
+  mergeBlocks(block, newBlock, castResults);
+
+  return newBlock;
+}
+
+/// Erases the unrealized casts inserted by this pass (i.e., carrying the
+/// cast-kind attribute) that define any of the given values and have no uses.
+/// This rolls back the forward casts built for a pattern that then fails.
+static void eraseUnusedCasts(ValueRange values, RewriterBase &rewriter) {
+  // Collect first, then erase: several values may come from the same (1:N)
+  // cast, and the results of an erased op must not be accessed anymore.
+  SmallVector<Operation *> castsToErase;
+  SmallPtrSet<Operation *, 4> seen;
+  for (Value value : values) {
+    auto castOp = value.getDefiningOp<UnrealizedConversionCastOp>();
+    if (!castOp || !castOp->hasAttr(castKindAttrName) || !castOp->use_empty())
+      continue;
+    if (seen.insert(castOp.getOperation()).second)
+      castsToErase.push_back(castOp.getOperation());
+  }
+  for (Operation *castOp : castsToErase)
+    rewriter.eraseOp(castOp);
+}
+
+LogicalResult
+OneToNConversionPattern::matchAndRewrite(Operation *op,
+                                         PatternRewriter &rewriter) const {
+  auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
+
+  // Construct conversion mapping for results.
+  Operation::result_type_range originalResultTypes = op->getResultTypes();
+  OneToNTypeMapping resultMapping(originalResultTypes);
+  if (failed(typeConverter->computeTypeMapping(originalResultTypes,
+                                               resultMapping)))
+    return failure();
+
+  // Construct conversion mapping for operands.
+  Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
+  OneToNTypeMapping operandMapping(originalOperandTypes);
+  if (failed(typeConverter->computeTypeMapping(originalOperandTypes,
+                                               operandMapping)))
+    return failure();
+
+  // Cast operands to target types.
+  SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts(
+      op->getOperands(), operandMapping, rewriter, CastKind::Target);
+
+  // Create a OneToNPatternRewriter for the pattern, which provides additional
+  // functionality.
+  // TODO(ingomueller): I guess it would be better to use only one rewriter
+  //                    throughout the whole pass, but that would require to
+  //                    drive the pattern application ourselves, which is a lot
+  //                    of additional boilerplate code. This seems to work fine,
+  //                    so I leave it like this for the time being.
+  OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext());
+  oneToNPatternRewriter.restoreInsertionPoint(rewriter.saveInsertionPoint());
+  oneToNPatternRewriter.setListener(rewriter.getListener());
+
+  // Apply actual pattern. A pattern that fails to match must leave the IR
+  // unchanged, so remove the (still unused) forward casts built above.
+  if (failed(matchAndRewrite(op, oneToNPatternRewriter, operandMapping,
+                             resultMapping, convertedOperands))) {
+    eraseUnusedCasts(convertedOperands, rewriter);
+    return failure();
+  }
+
+  return success();
+}
+
+namespace mlir {
+
+// This function applies the provided patterns using
+// applyPatternsAndFoldGreedily and then replaces all newly inserted
+// UnrealizedConversionCastOps that haven't folded away. ("Backward" casts from
+// target to source types inserted by a OneToNConversionPattern normally fold
+// away with the "forward" casts from source to target types inserted by the
+// next pattern.) To understand which casts are "newly inserted", all casts
+// inserted by this pass are annotated with a string attribute that also
+// documents the kind of the cast (source, argument, or target).
+LogicalResult applyOneToNConversion(Operation *op,
+                                    OneToNTypeConverter &typeConverter,
+                                    const FrozenRewritePatternSet &patterns) {
+#ifndef NDEBUG
+  // Remember existing unrealized casts. This data structure is only used in
+  // asserts; building it only for that purpose may be an overkill.
+  SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
+  op->walk([&](UnrealizedConversionCastOp castOp) {
+    assert(!castOp->hasAttr(castKindAttrName));
+    existingCasts.insert(castOp);
+  });
+#endif // NDEBUG
+
+  // Apply provided conversion patterns.
+  if (failed(applyPatternsAndFoldGreedily(op, patterns)))
+    return failure();
+
+  // Find all unrealized casts inserted by the pass that haven't folded away.
+  SmallVector<UnrealizedConversionCastOp> worklist;
+  op->walk([&](UnrealizedConversionCastOp castOp) {
+    if (castOp->hasAttr(castKindAttrName)) {
+      assert(!existingCasts.contains(castOp));
+      worklist.push_back(castOp);
+    }
+  });
+
+  // Replace new casts with user materializations.
+  IRRewriter rewriter(op->getContext());
+  for (UnrealizedConversionCastOp castOp : worklist) {
+    TypeRange resultTypes = castOp->getResultTypes();
+    ValueRange operands = castOp->getOperands();
+    StringRef castKind =
+        castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
+    rewriter.setInsertionPoint(castOp);
+
+#ifndef NDEBUG
+    // Determine whether operands or results are already legal to test some
+    // assumptions for the different kind of materializations. These properties
+    // are only used in asserts and it may be overkill to compute them.
+    bool areOperandTypesLegal = llvm::all_of(
+        operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
+    bool areResultsTypesLegal = llvm::all_of(
+        resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
+#endif // NDEBUG
+
+    // Add materialization and remember materialized results.
+    SmallVector<Value> materializedResults;
+    if (castKind == getCastKindName(CastKind::Target)) {
+      // Target materialization.
+      assert(!areOperandTypesLegal && areResultsTypesLegal &&
+             operands.size() == 1 && "found unexpected target cast");
+      std::optional<SmallVector<Value>> maybeResults =
+          typeConverter.materializeTargetConversion(
+              rewriter, castOp->getLoc(), resultTypes, operands.front());
+      if (!maybeResults)
+        return failure();
+      materializedResults = maybeResults.value();
+    } else {
+      // Source and argument materializations.
+      assert(areOperandTypesLegal && !areResultsTypesLegal &&
+             resultTypes.size() == 1 && "found unexpected cast");
+      std::optional<Value> maybeResult;
+      if (castKind == getCastKindName(CastKind::Source)) {
+        // Source materialization.
+        maybeResult = typeConverter.materializeSourceConversion(
+            rewriter, castOp->getLoc(), resultTypes.front(),
+            castOp.getOperands());
+      } else {
+        // Argument materialization.
+        assert(castKind == getCastKindName(CastKind::Argument) &&
+               "unexpected value of cast kind attribute");
+        assert(llvm::all_of(operands,
+                            [&](Value v) { return v.isa<BlockArgument>(); }));
+        maybeResult = typeConverter.materializeArgumentConversion(
+            rewriter, castOp->getLoc(), resultTypes.front(),
+            castOp.getOperands());
+      }
+      if (!maybeResult.has_value() || !maybeResult.value())
+        return failure();
+      materializedResults = {maybeResult.value()};
+    }
+
+    // Replace the cast with the result of the materialization.
+    rewriter.replaceOp(castOp, materializedResults);
+  }
+
+  return success();
+}
+
+} // namespace mlir
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h
@@ -0,0 +1,26 @@
+//===- OneToNTypeConversionFunc.h - 1:N type conversion for Func-*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H
+#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H
+
+namespace mlir {
+class TypeConverter;
+class RewritePatternSet;
+} // namespace mlir
+
+namespace mlir {
+
+// Populates the provided pattern set with patterns that do 1:N type conversions
+// on func ops. This is intended to be used with applyOneToNConversion.
+void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp @@ -0,0 +1,131 @@ +//===-- OneToNTypeConversionFunc.cpp - Func 1:N type conversion -*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The patterns in this file are heavily inspired (and copied from) +// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the +// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N +// type conversions. +// +//===----------------------------------------------------------------------===// + +#include "OneToNTypeConversionFunc.h" + +#include "OneToNTypeConversion.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; +using namespace mlir::func; + +class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + const SmallVector &convertedOperands) const override { + Location loc = op->getLoc(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. 
+ if (!operandMapping.hasNonIdentityConversion() && + !resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new CallOp. + auto newOp = rewriter.create(loc, resultMapping.getConvertedTypes(), + convertedOperands); + newOp->setAttrs(op->getAttrs()); + + rewriter.replaceOp(op, SmallVector(newOp->getResults()), + resultMapping); + return success(); + } +}; + +class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + FuncOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping & /*operandMapping*/, + const OneToNTypeMapping & /*resultMapping*/, + const SmallVector & /*convertedOperands*/) const override { + auto *typeConverter = getTypeConverter(); + + // Construct mapping for function arguments. + OneToNTypeMapping argumentMapping(op.getArgumentTypes()); + if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), + argumentMapping))) + return failure(); + + // Construct mapping for function results. + OneToNTypeMapping funcResultMapping(op.getResultTypes()); + if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), + funcResultMapping))) + return failure(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!argumentMapping.hasNonIdentityConversion() && + !funcResultMapping.hasNonIdentityConversion()) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + argumentMapping.getConvertedTypes(), + funcResultMapping.getConvertedTypes()); + rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); + + // Update block signatures. 
+ if (!op.isExternal()) { + Region *region = &op.getBody(); + Block *block = ®ion->front(); + rewriter.applySignatureConversion(block, argumentMapping); + } + + return success(); + } +}; + +class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + const SmallVector &convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands. + rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +namespace mlir { + +void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertTypesInFuncCallOp, + ConvertTypesInFuncFuncOp, + ConvertTypesInFuncReturnOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +} // namespace mlir diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -0,0 +1,239 @@ +//===- TestOneToNTypeConversion.cpp - Test 1:N type conversion utils ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "OneToNTypeConversion.h"
+#include "OneToNTypeConversionFunc.h"
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that exercises the (poor-man's) 1:N type conversion mechanisms
+/// from this folder by converting built-in tuples to the elements they consist
+/// of as well as some dummy ops operating on these tuples.
+struct TestOneToNTypeConversionPass
+    : public PassWrapper> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
+
+  TestOneToNTypeConversionPass() = default;
+  TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+  }
+
+  StringRef getArgument() const final {
+    return "test-one-to-n-type-conversion";
+  }
+
+  StringRef getDescription() const final {
+    return "Test pass for 1:N type conversion";
+  }
+
+  Option convertFuncOps{*this, "convert-func-ops",
+                        llvm::cl::desc("Enable conversion on func ops"),
+                        llvm::cl::init(false)};
+
+  Option convertTupleOps{*this, "convert-tuple-ops",
+                         llvm::cl::desc("Enable conversion on tuple ops"),
+                         llvm::cl::init(false)};
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestOneToNTypeConversionPass() {
+  PassRegistration();
+}
+} // namespace test
+} // namespace mlir
+
+/// Test pattern for the `make_tuple` op from the test dialect that converts
+/// this kind of op into its "decomposed" form, i.e., the elements of the tuple
+/// that is being produced by `test.make_tuple`, which are really just the
+/// operands of this op.
+class ConvertMakeTupleOp
+    : public OneToNOpConversionPattern<::test::MakeTupleOp> {
+public:
+  using OneToNOpConversionPattern<
+      ::test::MakeTupleOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(::test::MakeTupleOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping & /*operandMapping*/,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector &convertedOperands) const override {
+    // Simply replace the current op with the converted operands.
+    rewriter.replaceOp(op, convertedOperands, resultMapping);
+    return success();
+  }
+};
+
+/// Test pattern for the `get_tuple_element` op from the test dialect that
+/// converts this kind of op into its "decomposed" form, i.e., instead of
+/// "physically" extracting one element from the tuple, we forward the one
+/// element of the decomposed form that is being extracted (or the several
+/// elements in case that element is a nested tuple).
+class ConvertGetTupleElementOp
+    : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
+public:
+  using OneToNOpConversionPattern<
+      ::test::GetTupleElementOp>::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(::test::GetTupleElementOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector &convertedOperands) const override {
+    // Construct mapping for tuple element types.
+    auto stateType = op->getOperand(0).getType().cast();
+    TypeRange originalElementTypes = stateType.getTypes();
+    OneToNTypeMapping elementMapping(originalElementTypes);
+    if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
+                                                   elementMapping)))
+      return failure();
+
+    // Compute converted operands corresponding to original input tuple.
+    ValueRange convertedTuple =
+        operandMapping.getConvertedValues(convertedOperands, 0);
+
+    // Get those converted operands that correspond to the index-th element of
+    // the original input tuple.
+    size_t index = op.getIndex();
+    ValueRange extractedElement =
+        elementMapping.getConvertedValues(convertedTuple, index);
+
+    rewriter.replaceOp(op, extractedElement, resultMapping);
+
+    return success();
+  }
+};
+
+static void populateDecomposeTuplesTestPatterns(TypeConverter &typeConverter,
+                                                RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertMakeTupleOp,
+      ConvertGetTupleElementOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
+/// given tuple value. If some tuple elements are, in turn, tuples, the elements
+/// of those are extracted recursively such that the returned values have the
+/// flattened element types of `input` (`resultTypes` is not inspected).
+///
+/// This function has been copied (with small adaptations) from
+/// TestDecomposeCallGraphTypes.cpp.
+static std::optional>
+buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
+                        Location loc) {
+  TupleType inputType = input.getType().dyn_cast();
+  if (!inputType)
+    return {};
+
+  SmallVector values;
+  for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
+    Value element = builder.create<::test::GetTupleElementOp>(
+        loc, elementType, input, builder.getI32IntegerAttr(idx));
+    if (auto nestedTupleType = elementType.dyn_cast()) {
+      // Recurse if the current element is also a tuple.
+      SmallVector flatRecursiveTypes;
+      nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
+      std::optional> recursiveValues =
+          buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
+      if (!recursiveValues.has_value())
+        return {};
+      values.append(recursiveValues.value());
+    } else {
+      values.push_back(element);
+    }
+  }
+  return values;
+}
+
+/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
+/// type `resultType`. If that type is nested, each nested tuple is built
+/// recursively with another `test.make_tuple` op.
+///
+/// This function has been copied (with small adaptations) from
+/// TestDecomposeCallGraphTypes.cpp.
+static std::optional buildMakeTupleOp(OpBuilder &builder,
+                                      TupleType resultType,
+                                      ValueRange inputs, Location loc) {
+  // Build one value for each element at this nesting level.
+  SmallVector elements;
+  elements.reserve(resultType.getTypes().size());
+  ValueRange::iterator inputIt = inputs.begin();
+  for (Type elementType : resultType.getTypes()) {
+    if (auto nestedTupleType = elementType.dyn_cast()) {
+      // Determine how many input values are needed for the nested elements of
+      // the nested TupleType and advance inputIt by that number.
+      // TODO: We only need the *number* of nested types, not the types
+      //       themselves. Maybe it's worth adding a more efficient overload?
+      SmallVector nestedFlattenedTypes;
+      nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
+      size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
+      ValueRange nestedFlattenedElements(inputIt,
+                                         inputIt + numNestedFlattenedTypes);
+      inputIt += numNestedFlattenedTypes;
+
+      // Recurse on the values for the nested TupleType.
+      std::optional res = buildMakeTupleOp(builder, nestedTupleType,
+                                           nestedFlattenedElements, loc);
+      if (!res.has_value())
+        return {};
+
+      // The tuple constructed by the conversion is the element value.
+      elements.push_back(res.value());
+    } else {
+      // Base case: take one input as is.
+      elements.push_back(*inputIt++);
+    }
+  }
+
+  // Assemble the tuple from the elements.
+  return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
+}
+
+void TestOneToNTypeConversionPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  auto *context = &getContext();
+
+  // Assemble type converter.
+  OneToNTypeConverter typeConverter;
+
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](TupleType tupleType, SmallVectorImpl &types) {
+        tupleType.getFlattenedTypes(types);
+        return success();
+      });
+
+  typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+  typeConverter.addSourceMaterialization(buildMakeTupleOp);
+  typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+
+  // Assemble patterns.
+  RewritePatternSet patterns(context);
+  if (convertTupleOps)
+    populateDecomposeTuplesTestPatterns(typeConverter, patterns);
+  if (convertFuncOps)
+    populateFuncTypeConversionPatterns(typeConverter, patterns);
+
+  // Run conversion.
+  if (failed(applyOneToNConversion(module, typeConverter, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -33,6 +33,7 @@
     MLIRTestDialect
     MLIRTestDynDialect
     MLIRTestIR
+    MLIRTestOneToNTypeConversionPass
     MLIRTestPass
     MLIRTestPDLL
     MLIRTestReducer
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
 void registerTestPDLByteCodePass();
@@ -215,6 +216,7 @@
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestPDLByteCodePass();