diff --git a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir
@@ -0,0 +1,140 @@
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-tuple-ops" \
+// RUN: | FileCheck --check-prefix=CHECK-TUP %s
+
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops" \
+// RUN: | FileCheck --check-prefix=CHECK-FUNC %s
+
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops convert-tuple-ops" \
+// RUN: | FileCheck --check-prefix=CHECK-BOTH %s
+
+// Test case: Matching nested packs and unpacks just disappear.
+
+// CHECK-TUP-LABEL: func.func @pack_unpack(
+// CHECK-TUP-SAME:                         %[[ARG0:.*]]: i1,
+// CHECK-TUP-SAME:                         %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-TUP-NEXT:    return %[[ARG0]], %[[ARG1]] : i1, i2
+func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) {
+  %0 = "test.make_tuple"() : () -> tuple<>
+  %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2>
+  %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>>
+  %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
+  %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+  %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+  %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+  %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+  %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple<i2>) -> i2
+  return %5, %8 : i1, i2
+}
+
+// -----
+
+// Test case: Appropriate materializations are created depending on which ops are
+// converted.
+
+// If we only convert the tuple ops, the original `get_tuple_element` ops will
+// disappear but one target materialization will be inserted from the
+// unconverted function arguments to each of the return values (which have
+// redundancy among themselves).
+//
+// CHECK-TUP-LABEL: func.func @materializations(
+// CHECK-TUP-SAME:                              %[[ARG0:.*]]: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) {
+// CHECK-TUP-NEXT:    %0 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-TUP-NEXT:    %1 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-TUP-NEXT:    %2 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-TUP-NEXT:    %3 = "test.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-TUP-NEXT:    %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-TUP-NEXT:    %5 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-TUP-NEXT:    %6 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-TUP-NEXT:    %7 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-TUP-NEXT:    %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-TUP-NEXT:    %9 = "test.get_tuple_element"(%8) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-TUP-NEXT:    return %1, %9 : i1, i2
+
+// If we only convert the func ops, argument materializations are created from
+// the converted tuple elements back to the tuples that the `get_tuple_element`
+// ops expect.
+//
+// CHECK-FUNC-LABEL: func.func @materializations(
+// CHECK-FUNC-SAME:                              %[[ARG0:.*]]: i1,
+// CHECK-FUNC-SAME:                              %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-FUNC-NEXT:    %0 = "test.make_tuple"() : () -> tuple<>
+// CHECK-FUNC-NEXT:    %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2>
+// CHECK-FUNC-NEXT:    %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>>
+// CHECK-FUNC-NEXT:    %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
+// CHECK-FUNC-NEXT:    %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-FUNC-NEXT:    %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-FUNC-NEXT:    %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-FUNC-NEXT:    %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-FUNC-NEXT:    %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-FUNC-NEXT:    return %5, %8 : i1, i2
+
+// If we convert both tuple and func ops, basically everything disappears.
+//
+// CHECK-BOTH-LABEL: func.func @materializations(
+// CHECK-BOTH-SAME:                              %[[ARG0:.*]]: i1,
+// CHECK-BOTH-SAME:                              %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-BOTH-NEXT:    return %arg0, %arg1 : i1, i2
+
+func.func @materializations(%arg0: tuple<tuple<>, i1, tuple<tuple<i2>>>) -> (i1, i2) {
+  %0 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+  %1 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+  %2 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+  %3 = "test.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+  %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<i2>) -> i2
+  return %1, %4 : i1, i2
+}
+// -----
+
+// Test case: Appropriate materializations are created depending on which ops are
+// converted.
+
+// If we only convert the tuple ops, the original `make_tuple` ops will
+// disappear but a source materialization will be inserted from the result of
+// conversion (which, for `make_tuple`, are the original ops that get forwarded)
+// to the operands of the unconverted op with the original type (i.e.,
+// `return`).
+
+// CHECK-TUP-LABEL: func.func @materializations(
+// CHECK-TUP-SAME:                              %[[ARG0:.*]]: i1,
+// CHECK-TUP-SAME:                              %[[ARG1:.*]]: i2) -> tuple<tuple<>, i1, tuple<tuple<i2>>> {
+// CHECK-TUP-NEXT:    %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-TUP-NEXT:    %[[V1:.*]] = "test.make_tuple"(%[[ARG1]]) : (i2) -> tuple<i2>
+// CHECK-TUP-NEXT:    %[[V2:.*]] = "test.make_tuple"(%[[V1]]) : (tuple<i2>) -> tuple<tuple<i2>>
+// CHECK-TUP-NEXT:    %[[V3:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]], %[[V2]]) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
+// CHECK-TUP-NEXT:    return %[[V3]] : tuple<tuple<>, i1, tuple<tuple<i2>>>
+
+// If we only convert the func ops, target materializations are created from
+// original tuples produced by `make_tuple` to its constituent elements that the
+// converted op (i.e., `return`) expects.
+//
+// CHECK-FUNC-LABEL: func.func @materializations(
+// CHECK-FUNC-SAME:                              %[[ARG0:.*]]: i1,
+// CHECK-FUNC-SAME:                              %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-FUNC-NEXT:    %0 = "test.make_tuple"() : () -> tuple<>
+// CHECK-FUNC-NEXT:    %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2>
+// CHECK-FUNC-NEXT:    %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>>
+// CHECK-FUNC-NEXT:    %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
+// CHECK-FUNC-NEXT:    %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<>
+// CHECK-FUNC-NEXT:    %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> i1
+// CHECK-FUNC-NEXT:    %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple<tuple<>, i1, tuple<tuple<i2>>>) -> tuple<tuple<i2>>
+// CHECK-FUNC-NEXT:    %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple<tuple<i2>>) -> tuple<i2>
+// CHECK-FUNC-NEXT:    %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple<i2>) -> i2
+// CHECK-FUNC-NEXT:    return %5, %8 : i1, i2
+
+// If we convert both tuple and func ops, basically everything disappears.
+//
+// CHECK-BOTH-LABEL: func.func @materializations(
+// CHECK-BOTH-SAME:                              %[[ARG0:.*]]: i1,
+// CHECK-BOTH-SAME:                              %[[ARG1:.*]]: i2) -> (i1, i2) {
+// CHECK-BOTH-NEXT:    return %arg0, %arg1 : i1, i2
+
+func.func @materializations(%arg0: i1, %arg1: i2) -> tuple<tuple<>, i1, tuple<tuple<i2>>> {
+  %0 = "test.make_tuple"() : () -> tuple<>
+  %1 = "test.make_tuple"(%arg1) : (i2) -> tuple<i2>
+  %2 = "test.make_tuple"(%1) : (tuple<i2>) -> tuple<tuple<i2>>
+  %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple<tuple<i2>>) -> tuple<tuple<>, i1, tuple<tuple<i2>>>
+  return %3 : tuple<tuple<>, i1, tuple<tuple<i2>>>
+}
diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-type-conversion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-type-conversion.mlir
@@ -0,0 +1,118 @@
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \
+// RUN: | FileCheck %s
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield.
+
+// CHECK-LABEL: func.func @if_result(
+// CHECK-SAME:                       %[[ARG0:.*]]: i1,
+// CHECK-SAME:                       %[[ARG1:.*]]: i2,
+// CHECK-SAME:                       %[[ARG2:.*]]: i1) -> (i1, i2) {
+// CHECK-NEXT:    %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) {
+// CHECK-NEXT:      scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      scf.yield %[[ARG0]], %[[ARG1]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) {
+    scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+  } else {
+    scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.if and
+// scf.yield and unconverted ops inside have proper materializations.
+
+// CHECK-LABEL: func.func @if_tuple_ops(
+// CHECK-SAME:                          %[[ARG0:.*]]: i1,
+// CHECK-SAME:                          %[[ARG1:.*]]: i1) -> i1 {
+// CHECK-NEXT:    %[[V0:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:    %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:    %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) {
+// CHECK-NEXT:      %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V5]] : i1
+// CHECK-NEXT:    } else {
+// CHECK-NEXT:      %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V8]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V2]] : i1
+func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+  %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) {
+    %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  } else {
+    %1 = "test.source"() : () -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  return %0 : tuple<tuple<>, i1>
+}
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield.
+
+// CHECK-LABEL: func.func @while_operands_results(
+// CHECK-SAME:                                    %[[ARG0:.*]]: i1,
+// CHECK-SAME:                                    %[[ARG1:.*]]: i2,
+// CHECK-SAME:                                    %[[ARG2:.*]]: i1) -> (i1, i2) {
+// CHECK-NEXT:    %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) {
+// CHECK-NEXT:      scf.condition(%[[ARG2]]) %[[ARG3]], %[[ARG4]] : i1, i2
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2):
+// CHECK-NEXT:      scf.yield %[[ARG5]], %[[ARG6]] : i1, i2
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]]#0, %[[V0]]#1 : i1, i2
+func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> {
+    scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>):
+    scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>>
+  }
+  return %0 : tuple<tuple<>, i1, tuple<i2>>
+}
+
+// -----
+
+// Test case: Nested 1:N type conversion is carried through scf.while,
+// scf.condition, and scf.yield, and unconverted ops inside have proper materializations.
+
+// CHECK-LABEL: func.func @while_tuple_ops(
+// CHECK-SAME:                             %[[ARG0:.*]]: i1,
+// CHECK-SAME:                             %[[ARG1:.*]]: i1) -> i1 {
+// CHECK-NEXT:    %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 {
+// CHECK-NEXT:      %[[V1:.*]] = "test.make_tuple"() : () -> tuple<>
+// CHECK-NEXT:      %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.condition(%[[ARG1]]) %[[V5]] : i1
+// CHECK-NEXT:    } do {
+// CHECK-NEXT:    ^bb0(%[[ARG3:.*]]: i1):
+// CHECK-NEXT:      %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1>
+// CHECK-NEXT:      %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<>
+// CHECK-NEXT:      %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1
+// CHECK-NEXT:      scf.yield %[[V8]] : i1
+// CHECK-NEXT:    }
+// CHECK-NEXT:    return %[[V0]] : i1
+func.func @while_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> {
+  %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> {
+    %1 = "test.op"(%arg2) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1>
+    scf.condition(%arg1) %1 : tuple<tuple<>, i1>
+  } do {
+  ^bb0(%arg2: tuple<tuple<>, i1>):
+    %1 = "test.source"() : () -> tuple<tuple<>, i1>
+    scf.yield %1 : tuple<tuple<>, i1>
+  }
+  return %0 : tuple<tuple<>, i1>
+}
diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir
--- a/mlir/test/Transforms/decompose-call-graph-types.mlir
+++ b/mlir/test/Transforms/decompose-call-graph-types.mlir
@@ -1,5 +1,9 @@
 // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s
 
+// RUN: mlir-opt %s -split-input-file \
+// RUN:   -test-one-to-n-type-conversion="convert-func-ops" \
+// RUN: | FileCheck %s --check-prefix=CHECK-12N
+
 // Test case: Most basic case of a 1:N decomposition, an identity function.
// CHECK-LABEL: func @identity( @@ -9,6 +13,10 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @identity( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32 func.func @identity(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -20,6 +28,9 @@ // CHECK-LABEL: func @identity_1_to_1_no_materializations( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -31,6 +42,9 @@ // CHECK-LABEL: func @recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>> { return %arg0 : tuple>> } @@ -54,6 +68,10 @@ // CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple>) -> tuple // CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple) -> i2 // CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK-12N-LABEL: func @mixed_recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2 func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple>>) -> tuple, tuple, tuple>> { return %arg0 : tuple, tuple, tuple>> } @@ -63,6 +81,7 @@ // Test case: Check decomposition of calls. 
// CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) +// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32) func.func private @callee(tuple) -> tuple // CHECK-LABEL: func @caller( @@ -76,6 +95,11 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32 func.func @caller(%arg0: tuple) -> tuple { %0 = call @callee(%arg0) : (tuple) -> tuple return %0 : tuple @@ -86,7 +110,12 @@ // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). // CHECK-LABEL: func private @callee() +// CHECK-12N-LABEL: func private @callee() func.func private @callee(tuple<>) -> tuple<> + +// CHECK-12N-LABEL: func @caller() { +// CHECK-12N: call @callee() : () -> () +// CHECK-12N: return // CHECK-LABEL: func @caller() { // CHECK: call @callee() : () -> () // CHECK: return @@ -105,6 +134,11 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) { +// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple +// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32 func.func 
@unconverted_op_result() -> tuple { %0 = "test.source"() : () -> (tuple) return %0 : tuple @@ -125,6 +159,16 @@ // CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple // CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 // CHECK: return %[[V3]], %[[V5]] : i1, i32 +// CHECK-12N-LABEL: func @nested_unconverted_op_result( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple +// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple) -> tuple> +// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple>) -> tuple> +// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple>) -> i1 +// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple +// CHECK-12N: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32 func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple> { %0 = "test.op"(%arg) : (tuple>) -> (tuple>) return %0 : tuple> @@ -136,6 +180,7 @@ // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions. 
// CHECK-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) // CHECK-LABEL: func @caller( @@ -153,6 +198,15 @@ // CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple) -> i4 // CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple) -> i5 // CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[I1:.*]]: i1, +// CHECK-12N-SAME: %[[I2:.*]]: i2, +// CHECK-12N-SAME: %[[I3:.*]]: i3, +// CHECK-12N-SAME: %[[I4:.*]]: i4, +// CHECK-12N-SAME: %[[I5:.*]]: i5, +// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { +// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple, %arg3: i3, %arg4: tuple, %arg5: i6) -> (tuple<>, i1, tuple, i3, tuple, i6) { %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple, i3, tuple, i6 diff --git a/mlir/test/lib/Conversion/CMakeLists.txt b/mlir/test/lib/Conversion/CMakeLists.txt --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(FuncToLLVM) +add_subdirectory(OneToNTypeConversion) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt 
b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_library(MLIRTestOneToNTypeConversionPass + OneToNTypeConversion.cpp + OneToNTypeConversionFunc.cpp + OneToNTypeConversionSCF.cpp + TestOneToNTypeConversionPass.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRIR + MLIRSCFDialect + MLIRTestDialect + MLIRTransformUtils + ) + +target_include_directories(MLIRTestOneToNTypeConversionPass + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h @@ -0,0 +1,248 @@ +//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utils for implementing (poor-man's) dialect conversion +// passes with 1:N type conversions. +// +// The main function first applies a set of RewritePatterns, which produce +// unrealized casts to convert the operands and results from and to the source +// types, and then replaces all newly added unrealized casts by user-provided +// materializations. For this to work, the main function requires a special +// TypeConverter and special RewritePatterns, respectively deriving from the +// provided classes, which extend their respective base classes for 1:N type +// converions. 
+//
+// Note that this is much more simple-minded than the "real" dialect conversion,
+// which checks for legality before applying patterns and does probably many
+// other additional things. Ideally, some of the extensions here could be
+// integrated there.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H
+#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H
+
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace mlir {
+
+/// Extends `TypeConverter` with 1:N target materializations. Such
+/// materializations implement 1:N type conversions at the value level, i.e.,
+/// they need to materialize one value with a source type into N values with
+/// target types (which isn't possible in the base class currently).
+class OneToNTypeConverter : public TypeConverter {
+public:
+  using OneToNMaterializationCallbackFn =
+      std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange,
+                                                      Value, Location)>;
+
+  /// Creates the mapping of the given range of original types to target types
+  /// of the conversion and stores that mapping in the given (signature)
+  /// conversion. This function simply calls TypeConverter::convertSignatureArgs
+  /// and exists here with a different name to reflect the broader semantic.
+  LogicalResult computeTypeMapping(TypeRange types,
+                                   SignatureConversion &result) {
+    return convertSignatureArgs(types, result);
+  }
+
+  /// Applies one of the user-provided 1:N target materializations (in LIFO
+  /// order).
+  std::optional<SmallVector<Value>>
+  materializeTargetConversion(OpBuilder &builder, Location loc,
+                              TypeRange resultTypes, Value input) const;
+
+  /// Adds a 1:N target materialization to the converter. Such materializations
+  /// build IR that converts 1 value of the source type into N values with
+  /// target types.
+  void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) {
+    oneToNTargetMaterializations.emplace_back(std::move(callback));
+  }
+
+private:
+  SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations;
+};
+
+/// Stores a 1:N mapping of types and provides several useful accessors. This
+/// class extends SignatureConversion, which already supports 1:N type mappings
+/// but lacks some accessors into the mapping as well as access to the original
+/// types.
+class OneToNTypeMapping : public TypeConverter::SignatureConversion {
+public:
+  OneToNTypeMapping(TypeRange originalTypes)
+      : TypeConverter::SignatureConversion(originalTypes.size()),
+        originalTypes(originalTypes) {}
+
+  using TypeConverter::SignatureConversion::getConvertedTypes;
+
+  /// Returns the list of types that corresponds to the original type at the
+  /// given index.
+  TypeRange getConvertedTypes(unsigned originalTypeNo) const;
+
+  /// Returns the list of original types.
+  TypeRange getOriginalTypes() const { return originalTypes; }
+
+  /// Returns the slice of converted values that corresponds to the original
+  /// value at the given index.
+  ValueRange getConvertedValues(ValueRange convertedValues,
+                                unsigned originalValueNo) const;
+
+  /// Fills the given result vector with as many copies of the location of the
+  /// original value as the number of values it is converted to.
+  void convertLocation(Value originalValue, unsigned originalValueNo,
+                       llvm::SmallVectorImpl<Location> &result) const;
+
+  /// Fills the given result vector with as many copies of the location of each
+  /// original value as the number of values they are respectively converted to.
+  void convertLocations(ValueRange originalValues,
+                        llvm::SmallVectorImpl<Location> &result) const;
+
+  /// Returns true iff at least one type conversion maps an input type to a type
+  /// that is different from itself.
+  bool hasNonIdentityConversion() const;
+
+private:
+  llvm::SmallVector<Type> originalTypes;
+};
+
+/// Extends the basic RewritePattern with a type converter member and some
+/// accessors to it. This is useful for patterns that are not ConversionPatterns
+/// but still require access to a type converter.
+class RewritePatternWithConverter : public mlir::RewritePattern {
+public:
+  /// Construct a conversion pattern with the given converter, and forward the
+  /// remaining arguments to RewritePattern.
+  template <typename... Args>
+  RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args)
+      : RewritePattern(std::forward<Args>(args)...),
+        typeConverter(&typeConverter) {}
+
+  /// Return the type converter held by this pattern, or nullptr if the pattern
+  /// does not require type conversion.
+  TypeConverter *getTypeConverter() const { return typeConverter; }
+
+  template <typename ConverterTy>
+  std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value,
+                   ConverterTy *>
+  getTypeConverter() const {
+    return static_cast<ConverterTy *>(typeConverter);
+  }
+
+protected:
+  /// A type converter for use by this pattern.
+  TypeConverter *const typeConverter;
+};
+
+/// Specialization of PatternRewriter that OneToNConversionPatterns use. The
+/// class provides additional rewrite methods that are specific to 1:N type
+/// conversions.
+class OneToNPatternRewriter : public PatternRewriter {
+public:
+  OneToNPatternRewriter(MLIRContext *context) : PatternRewriter(context) {}
+
+  /// Replaces the results of the operation with the specified list of values
+  /// mapped back to the original types as specified in the provided type
+  /// mapping. That type mapping must match the replaced op (i.e., the original
+  /// types must be the same as the result types of the op) and the new values
+  /// (i.e., the converted types must be the same as the types of the new
+  /// values).
+  void replaceOp(Operation *op, ValueRange newValues,
+                 const OneToNTypeMapping &resultMapping);
+  using PatternRewriter::replaceOp;
+
+  /// Applies the given argument conversion to the given block. This consists of
+  /// replacing each original argument with N arguments as specified in the
+  /// argument conversion and inserting unrealized casts from the converted
+  /// values to the original types, which are then used in lieu of the original
+  /// ones. (Eventually, applyOneToNConversion replaces these casts with a
+  /// user-provided argument materialization if necessary.) This is similar to
+  /// ArgConverter::applySignatureConversion but (1) handles 1:N type conversion
+  /// properly and probably (2) doesn't handle many other edge cases.
+  Block *applySignatureConversion(Block *block,
+                                  OneToNTypeMapping &argumentConversion);
+};
+
+/// Base class for patterns with 1:N type conversions. Derived classes have to
+/// override the `matchAndRewrite` overload that provides additional
+/// information for 1:N type conversions.
+class OneToNConversionPattern : public RewritePatternWithConverter {
+public:
+  using RewritePatternWithConverter::RewritePatternWithConverter;
+
+  /// This function has to be implemented by derived classes and is called from
+  /// the usual overloads. Like in normal DialectConversion, the function is
+  /// provided with the converted operands (which thus have target types). Since
+  /// 1:N conversions are supported, there is usually no 1:1 relationship between
+  /// the original and the converted operands. Instead, the provided
+  /// `operandMapping` can be used to access the converted operands that
+  /// correspond to a particular original operand. Similarly, `resultMapping`
+  /// is provided to help with assembling the result values (which may have 1:N
+  /// correspondences as well). The function is expected to produce the converted
+  /// result values if the conversion succeeds and to return failure otherwise (in
+  /// which case any modifications of the IR have to be rolled back first). The
+  /// correspondence of original and converted result values needs to correspond
+  /// to `resultMapping`. For both the converted operands and results, the
+  /// calling overload inserts appropriate unrealized casts that produce and
+  /// consume them, and replaces the uses of the results with the results of the
+  /// casts. If the produced result values are the same as those of the original
+  /// op, an in-place update is assumed and the result values are left as is.
+  virtual LogicalResult
+  matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector<Value> &convertedOperands) const = 0;
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final;
+};
+
+/// This class is a wrapper around OneToNConversionPattern for matching against
+/// instances of a particular op class.
+template <typename SourceOp>
+class OneToNOpConversionPattern : public OneToNConversionPattern {
+public:
+  OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context,
+                            PatternBenefit benefit = 1,
+                            ArrayRef<StringRef> generatedNames = {})
+      : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(),
+                                benefit, context, generatedNames) {}
+
+  using OneToNConversionPattern::matchAndRewrite;
+
+  /// Overload that derived classes have to override for their op type.
+  virtual LogicalResult
+  matchAndRewrite(SourceOp op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector<Value> &convertedOperands) const = 0;
+
+  LogicalResult
+  matchAndRewrite(Operation *op, OneToNPatternRewriter &rewriter,
+                  const OneToNTypeMapping &operandMapping,
+                  const OneToNTypeMapping &resultMapping,
+                  const SmallVector<Value> &convertedOperands) const final {
+    return matchAndRewrite(cast<SourceOp>(op), rewriter, operandMapping,
+                           resultMapping, convertedOperands);
+  }
+};
+
+/// Main function that 1:N conversion passes should call. The patterns are
+/// expected to insert unrealized casts to maintain the types of operands and
+/// results, which is done automatically if they derive from
+/// OneToNConversionPattern. The function replaces those that do not fold away
+/// until the end of pattern application with user-provided materializations
+/// from the type converter, so those have to be provided if conversions from
+/// source to target types are expected to remain.
+LogicalResult applyOneToNConversion(Operation *op,
+                                    OneToNTypeConverter &typeConverter,
+                                    const FrozenRewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp
@@ -0,0 +1,374 @@
+//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "OneToNTypeConversion.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace llvm; +using namespace mlir; + +std::optional> +OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, + Location loc, + TypeRange resultTypes, + Value input) const { + for (const OneToNMaterializationCallbackFn &fn : + llvm::reverse(oneToNTargetMaterializations)) { + if (std::optional> result = + fn(builder, resultTypes, input, loc)) + return *result; + } + return {}; +} + +TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { + TypeRange convertedTypes = getConvertedTypes(); + if (auto mapping = getInputMapping(originalTypeNo)) + return convertedTypes.slice(mapping->inputNo, mapping->size); + return {}; +} + +ValueRange +OneToNTypeMapping::getConvertedValues(ValueRange convertedValues, + unsigned originalValueNo) const { + if (auto mapping = getInputMapping(originalValueNo)) + return convertedValues.slice(mapping->inputNo, mapping->size); + return {}; +} + +void OneToNTypeMapping::convertLocation( + Value originalValue, unsigned originalValueNo, + llvm::SmallVectorImpl &result) const { + if (auto mapping = getInputMapping(originalValueNo)) + result.append(mapping->size, originalValue.getLoc()); +} + +void OneToNTypeMapping::convertLocations( + ValueRange originalValues, llvm::SmallVectorImpl &result) const { + assert(originalValues.size() == getOriginalTypes().size()); + for (auto [i, value] : llvm::enumerate(originalValues)) + convertLocation(value, i, result); +} + +static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { + return convertedTypes.size() == 1 && convertedTypes[0] == originalType; +} + +bool OneToNTypeMapping::hasNonIdentityConversion() const { + // XXX: I think that the original types and the
converted types are the same + // iff there was no non-identity type conversion. If that is true, the + // patterns could actually test whether there is anything useful to do + // without having access to the signature conversion. + for (auto [i, originalType] : llvm::enumerate(originalTypes)) { + TypeRange types = getConvertedTypes(i); + if (!isIdentityConversion(originalType, types)) { + assert(TypeRange(originalTypes) != getConvertedTypes()); + return true; + } + } + assert(TypeRange(originalTypes) == getConvertedTypes()); + return false; +} + +enum class CastKind { + // Casts block arguments in the target type back to the source type. (If + // necessary, this cast becomes an argument materialization.) + Argument, + + // Casts other values in the target type back to the source type. (If + // necessary, this cast becomes a source materialization.) + Source, + + // Casts values in the source type to the target type. (If necessary, this + // cast becomes a target materialization.) + Target +}; + +/// Mapping of enum values to string values. +static const std::unordered_map castKindNames = { + {CastKind::Argument, "argument"}, + {CastKind::Source, "source"}, + {CastKind::Target, "target"}}; + +/// Attribute name that is used to annotate inserted unrealized casts with their +/// kind (source, argument, or target). +static const char *const castKindAttrName = + "__one-to-n-type-conversion_cast-kind__"; + +/// Builds an UnrealizedConversionCastOp from the given inputs to the given +/// result types. Returns the result values of the cast. +static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, + ValueRange inputs, CastKind kind) { + // Create cast. + Location loc = builder.getUnknownLoc(); + if (!inputs.empty()) + loc = inputs.front().getLoc(); + auto castOp = + builder.create(loc, resultTypes, inputs); + + // Store cast kind as attribute.
+ auto kindAttr = StringAttr::get(builder.getContext(), castKindNames.at(kind)); + castOp->setAttr(castKindAttrName, kindAttr); + + return castOp->getResults(); +} + +/// Builds one UnrealizedConversionCastOp for each of the given original values +/// using the respective target types given in the provided conversion mapping +/// and returns the results of these casts. If the conversion mapping of a value +/// maps a type to itself (i.e., is an identity conversion), then no cast is +/// inserted and the original value is returned instead. +/// Note that these unrealized casts differ from target materializations in +/// that they are *always* inserted, even if they immediately fold away, such +/// that patterns always see valid intermediate IR, whereas materializations are +/// only used in the places where the unrealized casts *don't* fold away. +static SmallVector +buildUnrealizedForwardCasts(ValueRange originalValues, + OneToNTypeMapping &conversion, + RewriterBase &rewriter, CastKind kind) { + + // Convert each operand one by one. + SmallVector convertedValues; + convertedValues.reserve(conversion.getConvertedTypes().size()); + for (auto [idx, originalValue] : llvm::enumerate(originalValues)) { + TypeRange convertedTypes = conversion.getConvertedTypes(idx); + + // Identity conversion: keep operand as is. + if (isIdentityConversion(originalValue.getType(), convertedTypes)) { + convertedValues.push_back(originalValue); + continue; + } + + // Non-identity conversion: materialize target types.
+ ValueRange castResult = + buildUnrealizedCast(rewriter, convertedTypes, originalValue, kind); + convertedValues.append(castResult.begin(), castResult.end()); + } + + return convertedValues; +} + +/// Builds one UnrealizedConversionCastOp for each sequence of the given +/// converted values to one value of the type they originated from, i.e., a +/// "reverse" conversion from N converted values back to one value of the +/// original type, using the given (forward) type conversion. If a given value +/// was mapped to a value of the same type (i.e., the conversion in the mapping +/// is an identity conversion), then the "converted" value is returned without +/// cast. +/// Note that these unrealized casts differ from source materializations in +/// that they are *always* inserted, even if they immediately fold away, such +/// that patterns always see valid intermediate IR, whereas materializations are +/// only used in the places where the unrealized casts *don't* fold away. +static SmallVector +buildUnrealizedBackwardsCasts(ValueRange convertedValues, + const OneToNTypeMapping &typeConversion, + RewriterBase &rewriter) { + assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); + + // Create unrealized cast op for each converted result of the op. + SmallVector recastValues; + TypeRange originalTypes = typeConversion.getOriginalTypes(); + recastValues.reserve(originalTypes.size()); + auto convertedValueIt = convertedValues.begin(); + for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { + TypeRange convertedTypes = typeConversion.getConvertedTypes(idx); + size_t numConvertedValues = convertedTypes.size(); + if (isIdentityConversion(originalType, convertedTypes)) { + // Identity conversion: take result as is. + recastValues.push_back(*convertedValueIt); + } else { + // Non-identity conversion: cast back to source type.
+ ValueRange recastValue = buildUnrealizedCast( + rewriter, originalType, + ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}, + CastKind::Source); + assert(recastValue.size() == 1); + recastValues.push_back(recastValue.front()); + } + convertedValueIt += numConvertedValues; + } + + return recastValues; +} + +LogicalResult +OneToNConversionPattern::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + // Construct conversion mapping for results. + Operation::result_type_range originalResultTypes = op->getResultTypes(); + OneToNSignatureConversion resultConversion(originalResultTypes); + if (failed(typeConverter->convertSignatureArgs(originalResultTypes, + resultConversion))) + return failure(); + + // Construct conversion mapping for operands. + Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); + OneToNSignatureConversion operandConversion(originalOperandTypes); + if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, + operandConversion))) + return failure(); + + // Cast operands to target types. + SmallVector convertedOperands = buildUnrealizedForwardCasts( + op->getOperands(), operandConversion, rewriter, CastKind::Target); + + // Apply actual pattern. + auto result = matchAndRewrite(op, rewriter, operandConversion, + resultConversion, convertedOperands); + + if (failed(result)) + return failure(); + SmallVector &replacementValues = result.value(); + + // If replacementValues consist of the results of the original op, assume + // in-place update. + // TODO: This isn't particularly elegant. Not sure how else to handle that + // case without tracking modifications through the rewriter, which + // would require a custom pattern application driver. + if (ValueRange{op->getResults()} == replacementValues) + return success(); + + // Cast op results back to the original types and use those.
+ SmallVector castResults = buildUnrealizedBackwardsCasts( + replacementValues, resultConversion, rewriter); + rewriter.replaceOp(op, castResults); + + return success(); +} + +namespace mlir { + +Block *applySignatureConversion(Block *block, + OneToNSignatureConversion &argumentConversion, + RewriterBase &rewriter) { + // Split the block at the beginning to get a new block to use for the + // updated signature. + Block *newBlock = rewriter.splitBlock(block, block->begin()); + rewriter.replaceAllUsesWith(block, newBlock); + + // Add block arguments to new block. + for (size_t i = 0; i < block->getNumArguments(); i++) { + BlockArgument arg = block->getArgument(i); + ArrayRef convertedTypes = argumentConversion.getConvertedTypes(i); + if (isIdentityConversion(arg.getType(), convertedTypes)) { + // Identity conversion: take argument as is. + BlockArgument newArg = newBlock->addArgument(arg.getType(), arg.getLoc()); + rewriter.replaceAllUsesWith(arg, newArg); + } else { + // Non-identity conversion: cast the converted arguments to the original + // type. + SmallVector locs(convertedTypes.size(), arg.getLoc()); + auto newArgsRange = newBlock->addArguments(convertedTypes, locs); + SmallVector newArgs(newArgsRange.begin(), newArgsRange.end()); + PatternRewriter::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(newBlock); + ValueRange castArgument = + buildUnrealizedCast(rewriter, arg.getType(), newArgs, CastKind::Argument); + assert(castArgument.size() == 1); + rewriter.replaceAllUsesWith(arg, castArgument.front()); + } + } + + // Delete old (now empty) block. + rewriter.eraseBlock(block); + + return newBlock; +} + +// This function applies the provided patterns using +// applyPatternsAndFoldGreedily and then replaces all newly inserted +// UnrealizedConversionCastOps that haven't folded away.
("Backward" casts from +// target to source types inserted by a OneToNConversionPattern normally fold +// away with the "forward" casts from source to target types inserted by the +// next pattern.) To understand which casts are "newly inserted", we save a list +// of all casts existing before the patterns are applied and assume that all +// casts not in that list after the application are new. (This is probably not +// correct: It might be possible that an existing cast is folded away and a new +// cast happens to be allocated with exactly the same pointer. Dealing with that +// possibility is an open TODO.) Also, we do not track which inserted casts are +// needed for source, target, or argument materialization, so we do some +// educated guessing to recover that information. Fixing both issues would +// require using a PatternRewriter that overloads various `notify*` functions +// and similar and tracks all changes there. However, that would require a +// dedicated pattern application driver, which is currently also left as an open +// TODO. +LogicalResult applyOneToNConversion(Operation *op, + OneToNTypeConverter &typeConverter, + const FrozenRewritePatternSet &patterns) { + // Remember existing unrealized casts. + SmallSet existingCasts; + op->walk( + [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); }); + + // Apply provided conversion patterns. + if (failed(applyPatternsAndFoldGreedily(op, patterns))) + return failure(); + + // Find all newly inserted unrealized casts (that haven't folded away). + SmallVector worklist; + op->walk([&](UnrealizedConversionCastOp castOp) { + if (!existingCasts.contains(castOp)) + worklist.push_back(castOp); + }); + + // Replace new casts with user materializations. + IRRewriter rewriter(op->getContext()); + for (UnrealizedConversionCastOp castOp : worklist) { + // Create user materialization.
+ TypeRange resultTypes = castOp->getResultTypes(); + rewriter.setInsertionPoint(castOp); + SmallVector materializedResults; + + // Determine whether operands or results are already legal to know which + // kind of materialization this is. + ValueRange operands = castOp.getOperands(); + bool areOperandTypesLegal = llvm::all_of( + operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); + bool areResultsTypesLegal = llvm::all_of( + resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); + + if (!areOperandTypesLegal && areResultsTypesLegal && operands.size() == 1) { + // This is a target materialization. + std::optional> maybeResults = + typeConverter.materializeTargetConversion( + rewriter, castOp->getLoc(), resultTypes, operands.front()); + if (!maybeResults) + return failure(); + materializedResults = maybeResults.value(); + } else if (areOperandTypesLegal && !areResultsTypesLegal && + resultTypes.size() == 1) { + // This is a source or an argument materialization. + std::optional maybeResult; + if (llvm::all_of(operands, [&](Value v) { return v.getDefiningOp(); })) { + // This is a source materialization (no operand is a block argument). + maybeResult = typeConverter.materializeSourceConversion( + rewriter, castOp->getLoc(), resultTypes.front(), + castOp.getOperands()); + } else { + // This is an argument materialization. + maybeResult = typeConverter.materializeArgumentConversion( + rewriter, castOp->getLoc(), resultTypes.front(), + castOp.getOperands()); + } + if (!maybeResult.has_value() || !maybeResult.value()) + return failure(); + materializedResults = {maybeResult.value()}; + } else { + assert(false && "unexpected cast inserted"); + } + + // Replace cast with materialization.
+ rewriter.replaceOp(castOp, materializedResults); + } + + return success(); +} + +} // namespace mlir diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h @@ -0,0 +1,26 @@ +//===- OneToNTypeConversionFunc.h - 1:N type conversion for Func-*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H +#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H + +namespace mlir { +class TypeConverter; +class RewritePatternSet; +} // namespace mlir + +namespace mlir { + +// Populates the provided pattern set with patterns that do 1:N type conversions +// on func ops. This is intended to be used with applyOneToNConversion. +void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp @@ -0,0 +1,131 @@ +//===-- OneToNTypeConversionFunc.cpp - Func 1:N type conversion -*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The patterns in this file are heavily inspired by (and copied from) +// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the +// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N +// type conversions. +// +//===----------------------------------------------------------------------===// + +#include "OneToNTypeConversionFunc.h" + +#include "OneToNTypeConversion.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +using namespace mlir; +using namespace mlir::func; + +class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(CallOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + const SmallVector &convertedOperands) const override { + Location loc = op->getLoc(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!operandMapping.hasNonIdentityConversion() && + !resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new CallOp.
+ auto newOp = rewriter.create(loc, resultMapping.getConvertedTypes(), + convertedOperands); + newOp->setAttrs(op->getAttrs()); + + rewriter.replaceOp(op, SmallVector(newOp->getResults()), + resultMapping); + return success(); + } +}; + +class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + FuncOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping & /*operandMapping*/, + const OneToNTypeMapping & /*resultMapping*/, + const SmallVector & /*convertedOperands*/) const override { + auto *typeConverter = getTypeConverter(); + + // Construct mapping for function arguments. + OneToNTypeMapping argumentMapping(op.getArgumentTypes()); + if (failed(typeConverter->computeTypeMapping(op.getArgumentTypes(), + argumentMapping))) + return failure(); + + // Construct mapping for function results. + OneToNTypeMapping funcResultMapping(op.getResultTypes()); + if (failed(typeConverter->computeTypeMapping(op.getResultTypes(), + funcResultMapping))) + return failure(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!argumentMapping.hasNonIdentityConversion() && + !funcResultMapping.hasNonIdentityConversion()) + return failure(); + + // Update the function signature in-place. + auto newType = FunctionType::get(rewriter.getContext(), + argumentMapping.getConvertedTypes(), + funcResultMapping.getConvertedTypes()); + rewriter.updateRootInPlace(op, [&] { op.setType(newType); }); + + // Update block signatures.
+ if (!op.isExternal()) { + Region *region = &op.getBody(); + Block *block = &region->front(); + rewriter.applySignatureConversion(block, argumentMapping); + } + + return success(); + } +}; + +class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + const SmallVector &convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands. + rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +namespace mlir { + +void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertTypesInFuncCallOp, + ConvertTypesInFuncFuncOp, + ConvertTypesInFuncReturnOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +} // namespace mlir diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.h @@ -0,0 +1,26 @@ +//===-- OneToNTypeConversionSCF.h - 1:N type conversion for scf -*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONSCF_H +#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONSCF_H + +namespace mlir { +class TypeConverter; +class RewritePatternSet; +} // namespace mlir + +namespace mlir { + +// Populates the provided pattern set with patterns that do 1:N type conversions +// on (some) SCF ops. This is intended to be used with applyOneToNConversion. +void populateSCFTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONSCF_H diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.cpp @@ -0,0 +1,159 @@ +//===-- OneToNTypeConversionSCF.cpp - SCF 1:N type conversion ---*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// The patterns in this file are heavily inspired by (and copied from) +// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N +// type conversions.
+// +//===----------------------------------------------------------------------===// + +#include "OneToNTypeConversionSCF.h" + +#include "OneToNTypeConversion.h" +#include "mlir/Dialect/SCF/IR/SCF.h" + +using namespace mlir; +using namespace mlir::scf; + +class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult matchAndRewrite( + IfOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping & /*operandMapping*/, + const OneToNTypeMapping &resultMapping, + const SmallVector & /*convertedOperands*/) const override { + Location loc = op->getLoc(); + + // Nothing to do if there is no non-identity conversion. + if (!resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new IfOp. + TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); + auto newOp = rewriter.create(loc, convertedResultTypes, + op.getCondition(), true); + newOp->setAttrs(op->getAttrs()); + + // We do not need the empty blocks created by the rewriter. + rewriter.eraseBlock(newOp.elseBlock()); + rewriter.eraseBlock(newOp.thenBlock()); + + // Inline the blocks from the original operation.
+ rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(), + newOp.getThenRegion().end()); + rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(), + newOp.getElseRegion().end()); + + rewriter.replaceOp(op, SmallVector(newOp->getResults()), + resultMapping); + return success(); + } +}; + +class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(WhileOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + const SmallVector &convertedOperands) const override { + Location loc = op->getLoc(); + + // Nothing to do if the op doesn't have any non-identity conversions for its + // operands or results. + if (!operandMapping.hasNonIdentityConversion() && + !resultMapping.hasNonIdentityConversion()) + return failure(); + + // Create new WhileOp. + TypeRange convertedResultTypes = resultMapping.getConvertedTypes(); + + auto newOp = + rewriter.create(loc, convertedResultTypes, convertedOperands); + newOp->setAttrs(op->getAttrs()); + + // Update block signatures. + std::array blockMappings = {operandMapping, + resultMapping}; + for (unsigned int i : {0u, 1u}) { + Region *region = &op.getRegion(i); + Block *block = &region->front(); + + rewriter.applySignatureConversion(block, blockMappings[i]); + + // Move updated region to new WhileOp.
+ Region &dstRegion = newOp.getRegion(i); + rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end()); + } + + rewriter.replaceOp(op, SmallVector(newOp->getResults()), + resultMapping); + return success(); + } +}; + +class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(YieldOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + const SmallVector &convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands. + rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +class ConvertTypesInSCFConditionOp + : public OneToNOpConversionPattern { +public: + using OneToNOpConversionPattern::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(ConditionOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping & /*resultMapping*/, + const SmallVector &convertedOperands) const override { + // Nothing to do if there is no non-identity conversion. + if (!operandMapping.hasNonIdentityConversion()) + return failure(); + + // Convert operands.
+ rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); }); + + return success(); + } +}; + +namespace mlir { + +void populateSCFTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertTypesInSCFConditionOp, + ConvertTypesInSCFIfOp, + ConvertTypesInSCFWhileOp, + ConvertTypesInSCFYieldOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +} // namespace mlir diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -0,0 +1,228 @@ +//===- TestOneToNTypeConversion.cpp - Test 1:N type conversion utils ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "OneToNTypeConversion.h" +#include "OneToNTypeConversionFunc.h" +#include "OneToNTypeConversionSCF.h" +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestOneToNTypeConversionPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass) + + TestOneToNTypeConversionPass() = default; + TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + StringRef getArgument() const final { + return "test-one-to-n-type-conversion"; + } + + StringRef getDescription() const final { + return "Test pass for 1:N type conversion"; + } + + Option convertFuncOps{*this, "convert-func-ops", + llvm::cl::desc("Enable conversion on func ops"), + llvm::cl::init(false)}; + + Option convertSCFOps{*this, "convert-scf-ops", + llvm::cl::desc("Enable conversion on scf ops"), + llvm::cl::init(false)}; + + Option convertTupleOps{*this, "convert-tuple-ops", + llvm::cl::desc("Enable conversion on tuple ops"), + llvm::cl::init(false)}; + + void runOnOperation() override; +}; + +} // namespace + +namespace mlir { +namespace test { +void registerTestOneToNTypeConversionPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir + +class ConvertMakeTupleOp + : public OneToNOpConversionPattern<::test::MakeTupleOp> { +public: + using OneToNOpConversionPattern< + ::test::MakeTupleOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(::test::MakeTupleOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + const SmallVector &convertedOperands) const override { + // Simply replace the current op with the converted operands.
+ rewriter.replaceOp(op, convertedOperands, resultMapping); + return success(); + } +}; + +class ConvertGetTupleElementOp + : public OneToNOpConversionPattern<::test::GetTupleElementOp> { +public: + using OneToNOpConversionPattern< + ::test::GetTupleElementOp>::OneToNOpConversionPattern; + + LogicalResult + matchAndRewrite(::test::GetTupleElementOp op, OneToNPatternRewriter &rewriter, + const OneToNTypeMapping &operandMapping, + const OneToNTypeMapping &resultMapping, + const SmallVector &convertedOperands) const override { + // Construct mapping for tuple element types. + auto stateType = op->getOperand(0).getType().cast(); + TypeRange originalElementTypes = stateType.getTypes(); + OneToNTypeMapping elementMapping(originalElementTypes); + if (failed(typeConverter->convertSignatureArgs(originalElementTypes, + elementMapping))) + return failure(); + + // Compute converted operands corresponding to original input tuple. + ValueRange convertedTuple = + operandMapping.getConvertedValues(convertedOperands, 0); + + // Get those converted operands that correspond to the index-th element of + // the original input tuple. + size_t index = op.getIndex(); + ValueRange extractedElement = + elementMapping.getConvertedValues(convertedTuple, index); + + rewriter.replaceOp(op, extractedElement, resultMapping); + + return success(); + } +}; + +static void populateDecomposeTuplesPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns) { + patterns.add< + // clang-format off + ConvertMakeTupleOp, + ConvertGetTupleElementOp + // clang-format on + >(typeConverter, patterns.getContext()); +} + +/// Creates a sequence of `test.get_tuple_element` ops for all elements of a +/// given tuple value. If some tuple elements are, in turn, tuples, the elements +/// of those are extracted recursively such that the returned values have the +/// same types as the flattened element types of the tuple type of `input`.
+static std::optional<SmallVector<Value>>
+buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
+                        Location loc) {
+  TupleType inputType = input.getType().dyn_cast<TupleType>();
+  if (!inputType)
+    return std::nullopt;
+
+  // Only materialize if the flattened element types of the tuple are exactly
+  // the requested result types. Checking this before creating any op avoids
+  // leaving unused ops behind and never hands out values of unexpected types.
+  SmallVector<Type> flattenedTypes;
+  inputType.getFlattenedTypes(flattenedTypes);
+  if (TypeRange(flattenedTypes) != resultTypes)
+    return std::nullopt;
+
+  SmallVector<Value> values;
+  for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
+    Value element = builder.create<::test::GetTupleElementOp>(
+        loc, elementType, input, builder.getI32IntegerAttr(idx));
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Recurse if the current element is also a tuple. (The op above is
+      // created even for an empty nested tuple, which contributes no values.)
+      SmallVector<Type> flatRecursiveTypes;
+      nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
+      std::optional<SmallVector<Value>> recursiveValues =
+          buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
+      if (!recursiveValues.has_value())
+        return std::nullopt;
+      values.append(recursiveValues.value());
+    } else {
+      values.push_back(element);
+    }
+  }
+  return values;
+}
+
+/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
+/// type `resultType`. If that type is nested, each nested tuple is built
+/// recursively with another `test.make_tuple` op. `inputs` must hold exactly
+/// one value per flattened element type of `resultType`; otherwise, no op is
+/// built and `std::nullopt` is returned.
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+                                             TupleType resultType,
+                                             ValueRange inputs, Location loc) {
+  // Reject mismatching inputs up front: the loop below consumes inputs without
+  // bounds checks, and checking here avoids creating any op on failure.
+  SmallVector<Type> flattenedTypes;
+  resultType.getFlattenedTypes(flattenedTypes);
+  if (flattenedTypes.size() != inputs.size())
+    return std::nullopt;
+
+  // Build one value for each element at this nesting level.
+  SmallVector<Value> elements;
+  elements.reserve(resultType.getTypes().size());
+  ValueRange::iterator inputIt = inputs.begin();
+  for (Type elementType : resultType.getTypes()) {
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Determine how many input values are needed for the nested elements of
+      // the nested TupleType and advance inputIt by that number.
+      // TODO: We only need the *number* of nested types, not the types
+      // themselves. Maybe it's worth adding a more efficient overload?
+      SmallVector<Type> nestedFlattenedTypes;
+      nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
+      size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
+      ValueRange nestedFlattenedElements(inputIt,
+                                         inputIt + numNestedFlattenedTypes);
+      inputIt += numNestedFlattenedTypes;
+
+      // Recurse on the values for the nested TupleType.
+      std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
+                                                  nestedFlattenedElements, loc);
+      if (!res.has_value())
+        return std::nullopt;
+
+      // The tuple constructed by the recursion is the element value.
+      elements.push_back(res.value());
+    } else {
+      // Base case: take one input as is.
+      elements.push_back(*inputIt++);
+    }
+  }
+
+  // Assemble the tuple from the elements.
+  return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
+}
+
+/// Sets up a 1:N type converter that flattens tuples, adds the patterns that
+/// the pass options enable, and applies the 1:N conversion to the module.
+void TestOneToNTypeConversionPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  auto *context = &getContext();
+
+  // Assemble type converter: tuples become their flattened element types, all
+  // other types are kept as-is. Conversions registered later are tried first,
+  // so the tuple conversion takes precedence over the identity fallback.
+  OneToNTypeConverter typeConverter;
+
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](TupleType tupleType, SmallVectorImpl<Type> &types) {
+        tupleType.getFlattenedTypes(types);
+        return success();
+      });
+
+  // `test.make_tuple` rebuilds tuples from flattened values, while
+  // `test.get_tuple_element` decomposes tuples into flattened values.
+  typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+  typeConverter.addSourceMaterialization(buildMakeTupleOp);
+  typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+
+  // Assemble patterns.
+  RewritePatternSet patterns(context);
+  if (convertTupleOps)
+    populateDecomposeTuplesPatterns(typeConverter, patterns);
+  if (convertFuncOps)
+    populateFuncTypeConversionPatterns(typeConverter, patterns);
+  if (convertSCFOps)
+    populateSCFTypeConversionPatterns(typeConverter, patterns);
+
+  // Run conversion.
+  if (failed(applyOneToNConversion(module, typeConverter, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt --- a/mlir/tools/mlir-opt/CMakeLists.txt +++ b/mlir/tools/mlir-opt/CMakeLists.txt @@ -33,6 +33,7 @@ MLIRTestDialect MLIRTestDynDialect MLIRTestIR + MLIRTestOneToNTypeConversionPass MLIRTestPass MLIRTestPDLL MLIRTestReducer diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -106,6 +106,7 @@ void registerTestMathPolynomialApproximationPass(); void registerTestMemRefDependenceCheck(); void registerTestMemRefStrideCalculation(); +void registerTestOneToNTypeConversionPass(); void registerTestOpaqueLoc(); void registerTestPadFusion(); void registerTestPDLByteCodePass(); @@ -215,6 +216,7 @@ mlir::test::registerTestMathPolynomialApproximationPass(); mlir::test::registerTestMemRefDependenceCheck(); mlir::test::registerTestMemRefStrideCalculation(); + mlir::test::registerTestOneToNTypeConversionPass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestPadFusion(); mlir::test::registerTestPDLByteCodePass();