diff --git a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h --- a/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h @@ -53,8 +53,8 @@ /// This method registers a callback function that will be called to decompose /// a value of a certain type into 0, 1, or multiple values. - template ::template arg_t<2>> + template >::template arg_t<2>> void addDecomposeValueConversion(FnT &&callback) { decomposeValueConversions.emplace_back( wrapDecomposeValueConversionCallback(std::forward(callback))); diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -37,6 +37,29 @@ // ----- +// Test case: Type that needs to be recursively decomposed at different recursion depths. 
+ +// CHECK-LABEL: func @mixed_recursive_decomposition( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> +// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]]) : (i1) -> tuple +// CHECK: %[[V2:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple +// CHECK: %[[V3:.*]] = "test.make_tuple"(%[[V2]]) : (tuple) -> tuple> +// CHECK: %[[V4:.*]] = "test.make_tuple"(%[[V0]], %[[V1]], %[[V3]]) : (tuple<>, tuple, tuple>) -> tuple, tuple, tuple>> +// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple, tuple, tuple>>) -> tuple<> +// CHECK: %[[V6:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 1 : i32} : (tuple, tuple, tuple>>) -> tuple +// CHECK: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK: %[[V8:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 2 : i32} : (tuple, tuple, tuple>>) -> tuple> +// CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple>) -> tuple +// CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple) -> i2 +// CHECK: return %[[V7]], %[[V10]] : i1, i2 +func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple>>) -> tuple, tuple, tuple>> { + return %arg0 : tuple, tuple, tuple>> +} + +// ----- + // Test case: Check decomposition of calls. // CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) @@ -89,6 +112,26 @@ // ----- +// Test case: Ensure decompositions are inserted properly around results of +// unconverted ops in the case of different nesting levels. 
+ +// CHECK-LABEL: func @nested_unconverted_op_result( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple +// CHECK: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple) -> tuple> +// CHECK: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple>) -> tuple> +// CHECK: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple>) -> i1 +// CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple +// CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 +// CHECK: return %[[V3]], %[[V5]] : i1, i32 +func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple> { + %0 = "test.op"(%arg) : (tuple>) -> (tuple>) + return %0 : tuple> +} + +// ----- + // Test case: Check mixed decomposed and non-decomposed args. // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions. diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -16,6 +16,70 @@ using namespace mlir; namespace { +/// Creates a sequence of `test.get_tuple_element` ops for all elements of a +/// given tuple value. If some tuple elements are, in turn, tuples, the elements +/// of those are extracted recursively such that the returned values have the +/// same types as `resultType.getFlattenedTypes()`. 
+static LogicalResult buildDecomposeTuple(OpBuilder &builder, Location loc, + TupleType resultType, Value value, + SmallVectorImpl &values) { + for (unsigned i = 0, e = resultType.size(); i < e; ++i) { + Type elementType = resultType.getType(i); + Value element = builder.create( + loc, elementType, value, builder.getI32IntegerAttr(i)); + if (auto nestedTupleType = elementType.dyn_cast()) { + // Recurse if the current element is also a tuple. + if (failed(buildDecomposeTuple(builder, loc, nestedTupleType, element, + values))) + return failure(); + } else { + values.push_back(element); + } + } + return success(); +} + +/// Creates a `test.make_tuple` op out of the given inputs building a tuple of +/// type `resultType`. If that type is nested, each nested tuple is built +/// recursively with another `test.make_tuple` op. +static std::optional buildMakeTupleOp(OpBuilder &builder, + TupleType resultType, + ValueRange inputs, Location loc) { + // Build one value for each element at this nesting level. + SmallVector elements; + elements.reserve(resultType.getTypes().size()); + ValueRange::iterator inputIt = inputs.begin(); + for (Type elementType : resultType.getTypes()) { + if (auto nestedTupleType = elementType.dyn_cast()) { + // Determine how many input values are needed for the nested elements of + // the nested TupleType and advance inputIt by that number. + // TODO: We only need the *number* of nested types, not the types themselves. + // Maybe it's worth adding a more efficient overload? + SmallVector nestedFlattenedTypes; + nestedTupleType.getFlattenedTypes(nestedFlattenedTypes); + size_t numNestedFlattenedTypes = nestedFlattenedTypes.size(); + ValueRange nestedFlattenedelements(inputIt, + inputIt + numNestedFlattenedTypes); + inputIt += numNestedFlattenedTypes; + + // Recurse on the values for the nested TupleType. 
+ std::optional res = buildMakeTupleOp(builder, nestedTupleType, + nestedFlattenedelements, loc); + if (!res.has_value()) + return {}; + + // The tuple constructed by the conversion is the element value. + elements.push_back(res.value()); + } else { + // Base case: take one input as is (no bounds check; assumes an input remains). + elements.push_back(*inputIt++); + } + } + + // Assemble the tuple from the elements. + return builder.create(loc, resultType, elements); +} + /// A pass for testing call graph type decomposition. /// /// This instantiates the patterns with a TypeConverter and ValueDecomposer @@ -39,7 +103,6 @@ auto *context = &getContext(); TypeConverter typeConverter; ConversionTarget target(*context); - ValueDecomposer decomposer; RewritePatternSet patterns(context); target.addLegalDialect(); @@ -59,27 +122,10 @@ tupleType.getFlattenedTypes(types); return success(); }); + typeConverter.addArgumentMaterialization(buildMakeTupleOp); - decomposer.addDecomposeValueConversion([](OpBuilder &builder, Location loc, - TupleType resultType, Value value, - SmallVectorImpl &values) { - for (unsigned i = 0, e = resultType.size(); i < e; ++i) { - Value res = builder.create( - loc, resultType.getType(i), value, builder.getI32IntegerAttr(i)); - values.push_back(res); - } - return success(); - }); - - typeConverter.addArgumentMaterialization( - [](OpBuilder &builder, TupleType resultType, ValueRange inputs, - Location loc) -> std::optional { - if (inputs.size() == 1) - return std::nullopt; - TupleType tuple = builder.getTupleType(inputs.getTypes()); - Value value = builder.create(loc, tuple, inputs); - return value; - }); + ValueDecomposer decomposer; + decomposer.addDecomposeValueConversion(buildDecomposeTuple); populateDecomposeCallGraphTypesPatterns(context, typeConverter, decomposer, patterns);