diff --git a/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/OneToNTypeConversion/one-to-n-type-conversion.mlir @@ -0,0 +1,88 @@ +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-tuple-ops" \ +// RUN: | FileCheck --check-prefix=CHECK-TUP %s + +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ +// RUN: | FileCheck --check-prefix=CHECK-FUNC %s + +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-tuple-ops" \ +// RUN: | FileCheck --check-prefix=CHECK-BOTH %s + +// Test case: Matching nested packs and unpacks just disappear. + +// CHECK-TUP-LABEL: func.func @pack_unpack( +// CHECK-TUP-SAME: %[[ARG0:.*]]: i1, +// CHECK-TUP-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-TUP-NEXT: return %[[ARG0]], %[[ARG1]] : i1, i2 +func.func @pack_unpack(%arg0: i1, %arg1: i2) -> (i1, i2) { + %0 = "test.make_tuple"() : () -> tuple<> + %1 = "test.make_tuple"(%arg1) : (i2) -> tuple + %2 = "test.make_tuple"(%1) : (tuple) -> tuple> + %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple>) -> tuple, i1, tuple>> + %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple, i1, tuple>>) -> tuple<> + %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple, i1, tuple>>) -> i1 + %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple, i1, tuple>>) -> tuple> + %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple>) -> tuple + %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple) -> i2 + return %5, %8 : i1, i2 +} + +// ----- + +// Test case: Appropriate materilizations are created depending on which ops are +// converted. 
+ +// If we only convert the tuple ops, the original `get_tuple_element` ops will +// disappear but one target materialization will be inserted from the +// unconverted function arguments to the return values (which have redundancy +// among themselves). +// +// CHECK-TUP-LABEL: func.func @materializations( +// CHECK-TUP-SAME: %[[ARG0:.*]]: tuple, i1, tuple>>) -> (i1, i2) { +// CHECK-TUP-NEXT: %0 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, i1, tuple>>) -> tuple<> +// CHECK-TUP-NEXT: %1 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, i1, tuple>>) -> i1 +// CHECK-TUP-NEXT: %2 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, i1, tuple>>) -> tuple> +// CHECK-TUP-NEXT: %3 = "test.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tuple +// CHECK-TUP-NEXT: %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple) -> i2 +// CHECK-TUP-NEXT: %5 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, i1, tuple>>) -> tuple<> +// CHECK-TUP-NEXT: %6 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, i1, tuple>>) -> i1 +// CHECK-TUP-NEXT: %7 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, i1, tuple>>) -> tuple> +// CHECK-TUP-NEXT: %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple>) -> tuple +// CHECK-TUP-NEXT: %9 = "test.get_tuple_element"(%8) {index = 0 : i32} : (tuple) -> i2 +// CHECK-TUP-NEXT: return %1, %9 : i1, i2 + +// If we only convert the func ops, argument materializations are created from +// the converted tuple elements back to the tuples that the `get_tuple_element` +// ops expect. 
+// +// CHECK-FUNC-LABEL: func.func @materializations( +// CHECK-FUNC-SAME: %[[ARG0:.*]]: i1, +// CHECK-FUNC-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-FUNC-NEXT: %0 = "test.make_tuple"() : () -> tuple<> +// CHECK-FUNC-NEXT: %1 = "test.make_tuple"(%arg1) : (i2) -> tuple +// CHECK-FUNC-NEXT: %2 = "test.make_tuple"(%1) : (tuple) -> tuple> +// CHECK-FUNC-NEXT: %3 = "test.make_tuple"(%0, %arg0, %2) : (tuple<>, i1, tuple>) -> tuple, i1, tuple>> +// CHECK-FUNC-NEXT: %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple, i1, tuple>>) -> tuple<> +// CHECK-FUNC-NEXT: %5 = "test.get_tuple_element"(%3) {index = 1 : i32} : (tuple, i1, tuple>>) -> i1 +// CHECK-FUNC-NEXT: %6 = "test.get_tuple_element"(%3) {index = 2 : i32} : (tuple, i1, tuple>>) -> tuple> +// CHECK-FUNC-NEXT: %7 = "test.get_tuple_element"(%6) {index = 0 : i32} : (tuple>) -> tuple +// CHECK-FUNC-NEXT: %8 = "test.get_tuple_element"(%7) {index = 0 : i32} : (tuple) -> i2 +// CHECK-FUNC-NEXT: return %5, %8 : i1, i2 + +// If we convert both tuple and func ops, basically everything disappears. 
+// +// CHECK-BOTH-LABEL: func.func @materializations( +// CHECK-BOTH-SAME: %[[ARG0:.*]]: i1, +// CHECK-BOTH-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-BOTH-NEXT: return %arg0, %arg1 : i1, i2 + +func.func @materializations(%arg0: tuple, i1, tuple>>) -> (i1, i2) { + %0 = "test.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, i1, tuple>>) -> tuple<> + %1 = "test.get_tuple_element"(%arg0) {index = 1 : i32} : (tuple, i1, tuple>>) -> i1 + %2 = "test.get_tuple_element"(%arg0) {index = 2 : i32} : (tuple, i1, tuple>>) -> tuple> + %3 = "test.get_tuple_element"(%2) {index = 0 : i32} : (tuple>) -> tuple + %4 = "test.get_tuple_element"(%3) {index = 0 : i32} : (tuple) -> i2 + return %1, %4 : i1, i2 +} diff --git a/mlir/test/Conversion/OneToNTypeConversion/scf-structural-type-conversion.mlir b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-type-conversion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/OneToNTypeConversion/scf-structural-type-conversion.mlir @@ -0,0 +1,118 @@ +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops convert-scf-ops" \ +// RUN: | FileCheck %s + +// Test case: Nested 1:N type conversion is carried through scf.if and +// scf.yield. 
+ +// CHECK-LABEL: func.func @if_result( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i2, +// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) { +// CHECK-NEXT: %[[V0:.*]]:2 = scf.if %[[ARG2]] -> (i1, i2) { +// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2 +// CHECK-NEXT: } else { +// CHECK-NEXT: scf.yield %[[ARG0]], %[[ARG1]] : i1, i2 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]]#0, %[[V0]]#1 : i1, i2 +func.func @if_result(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> { + %0 = scf.if %arg1 -> (tuple<tuple<>, i1, tuple<i2>>) { + scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>> + } else { + scf.yield %arg0 : tuple<tuple<>, i1, tuple<i2>> + } + return %0 : tuple<tuple<>, i1, tuple<i2>> +} + +// ----- + +// Test case: Nested 1:N type conversion is carried through scf.if and +// scf.yield and unconverted ops inside have proper materializations. + +// CHECK-LABEL: func.func @if_tuple_ops( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 { +// CHECK-NEXT: %[[V0:.*]] = "test.make_tuple"() : () -> tuple<> +// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"(%[[V0]], %[[ARG0]]) : (tuple<>, i1) -> tuple<tuple<>, i1> +// CHECK-NEXT: %[[V2:.*]] = scf.if %[[ARG1]] -> (i1) { +// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V1]]) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> +// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<> +// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1 +// CHECK-NEXT: scf.yield %[[V5]] : i1 +// CHECK-NEXT: } else { +// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple<tuple<>, i1> +// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple<tuple<>, i1>) -> tuple<> +// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple<tuple<>, i1>) -> i1 +// CHECK-NEXT: scf.yield %[[V8]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V2]] : i1 +func.func @if_tuple_ops(%arg0: tuple<tuple<>, i1>, %arg1: i1) -> tuple<tuple<>, i1> { + %0 = scf.if %arg1 -> (tuple<tuple<>, i1>) {
+ %1 = "test.op"(%arg0) : (tuple<tuple<>, i1>) -> tuple<tuple<>, i1> + scf.yield %1 : tuple<tuple<>, i1> + } else { + %1 = "test.source"() : () -> tuple<tuple<>, i1> + scf.yield %1 : tuple<tuple<>, i1> + } + return %0 : tuple<tuple<>, i1> +} +// ----- + +// Test case: Nested 1:N type conversion is carried through scf.while, +// scf.condition, and scf.yield. + +// CHECK-LABEL: func.func @while_operands_results( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i2, +// CHECK-SAME: %[[ARG2:.*]]: i1) -> (i1, i2) { +// CHECK-NEXT: %[[V0:.*]]:2 = scf.while (%[[ARG3:.*]] = %[[ARG0]], %[[ARG4:.*]] = %[[ARG1]]) : (i1, i2) -> (i1, i2) { +// CHECK-NEXT: scf.condition(%[[ARG2]]) %[[ARG3]], %[[ARG4]] : i1, i2 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[ARG5:.*]]: i1, %[[ARG6:.*]]: i2): +// CHECK-NEXT: scf.yield %[[ARG5]], %[[ARG6]] : i1, i2 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]]#0, %[[V0]]#1 : i1, i2 +func.func @while_operands_results(%arg0: tuple<tuple<>, i1, tuple<i2>>, %arg1: i1) -> tuple<tuple<>, i1, tuple<i2>> { + %0 = scf.while (%arg2 = %arg0) : (tuple<tuple<>, i1, tuple<i2>>) -> tuple<tuple<>, i1, tuple<i2>> { + scf.condition(%arg1) %arg2 : tuple<tuple<>, i1, tuple<i2>> + } do { + ^bb0(%arg2: tuple<tuple<>, i1, tuple<i2>>): + scf.yield %arg2 : tuple<tuple<>, i1, tuple<i2>> + } + return %0 : tuple<tuple<>, i1, tuple<i2>> +} + +// ----- + +// Test case: Nested 1:N type conversion is carried through scf.while, +// scf.condition, and scf.yield, and unconverted ops inside get materializations.
+ +// CHECK-LABEL: func.func @while_tuple_ops( +// CHECK-SAME: %[[ARG0:.*]]: i1, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> i1 { +// CHECK-NEXT: %[[V0:.*]] = scf.while (%[[ARG2:.*]] = %[[ARG0]]) : (i1) -> i1 { +// CHECK-NEXT: %[[V1:.*]] = "test.make_tuple"() : () -> tuple<> +// CHECK-NEXT: %[[V2:.*]] = "test.make_tuple"(%[[V1]], %[[ARG2]]) : (tuple<>, i1) -> tuple, i1> +// CHECK-NEXT: %[[V3:.*]] = "test.op"(%[[V2]]) : (tuple, i1>) -> tuple, i1> +// CHECK-NEXT: %[[V4:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 0 : i32} : (tuple, i1>) -> tuple<> +// CHECK-NEXT: %[[V5:.*]] = "test.get_tuple_element"(%[[V3]]) {index = 1 : i32} : (tuple, i1>) -> i1 +// CHECK-NEXT: scf.condition(%[[ARG1]]) %[[V5]] : i1 +// CHECK-NEXT: } do { +// CHECK-NEXT: ^bb0(%[[ARG3:.*]]: i1): +// CHECK-NEXT: %[[V6:.*]] = "test.source"() : () -> tuple, i1> +// CHECK-NEXT: %[[V7:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 0 : i32} : (tuple, i1>) -> tuple<> +// CHECK-NEXT: %[[V8:.*]] = "test.get_tuple_element"(%[[V6]]) {index = 1 : i32} : (tuple, i1>) -> i1 +// CHECK-NEXT: scf.yield %[[V8]] : i1 +// CHECK-NEXT: } +// CHECK-NEXT: return %[[V0]] : i1 +func.func @while_tuple_ops(%arg0: tuple, i1>, %arg1: i1) -> tuple, i1> { + %0 = scf.while (%arg2 = %arg0) : (tuple, i1>) -> tuple, i1> { + %1 = "test.op"(%arg2) : (tuple, i1>) -> tuple, i1> + scf.condition(%arg1) %1 : tuple, i1> + } do { + ^bb0(%arg2: tuple, i1>): + %1 = "test.source"() : () -> tuple, i1> + scf.yield %1 : tuple, i1> + } + return %0 : tuple, i1> +} diff --git a/mlir/test/Transforms/decompose-call-graph-types.mlir b/mlir/test/Transforms/decompose-call-graph-types.mlir --- a/mlir/test/Transforms/decompose-call-graph-types.mlir +++ b/mlir/test/Transforms/decompose-call-graph-types.mlir @@ -1,5 +1,9 @@ // RUN: mlir-opt %s -split-input-file -test-decompose-call-graph-types | FileCheck %s +// RUN: mlir-opt %s -split-input-file \ +// RUN: -test-one-to-n-type-conversion="convert-func-ops" \ +// RUN: | FileCheck %s --check-prefix=CHECK-12N + // 
Test case: Most basic case of a 1:N decomposition, an identity function. // CHECK-LABEL: func @identity( @@ -9,6 +13,10 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[ARG_MATERIALIZED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @identity( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i32 func.func @identity(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -20,6 +28,9 @@ // CHECK-LABEL: func @identity_1_to_1_no_materializations( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @identity_1_to_1_no_materializations( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @identity_1_to_1_no_materializations(%arg0: tuple) -> tuple { return %arg0 : tuple } @@ -31,6 +42,9 @@ // CHECK-LABEL: func @recursive_decomposition( // CHECK-SAME: %[[ARG0:.*]]: i1) -> i1 { // CHECK: return %[[ARG0]] : i1 +// CHECK-12N-LABEL: func @recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1) -> i1 { +// CHECK-12N: return %[[ARG0]] : i1 func.func @recursive_decomposition(%arg0: tuple>>) -> tuple>> { return %arg0 : tuple>> } @@ -54,6 +68,10 @@ // CHECK: %[[V9:.*]] = "test.get_tuple_element"(%[[V8]]) {index = 0 : i32} : (tuple>) -> tuple // CHECK: %[[V10:.*]] = "test.get_tuple_element"(%[[V9]]) {index = 0 : i32} : (tuple) -> i2 // CHECK: return %[[V7]], %[[V10]] : i1, i2 +// CHECK-12N-LABEL: func @mixed_recursive_decomposition( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i2) -> (i1, i2) { +// CHECK-12N: return %[[ARG0]], %[[ARG1]] : i1, i2 func.func @mixed_recursive_decomposition(%arg0: tuple, tuple, tuple>>) -> tuple, tuple, tuple>> { return %arg0 : tuple, tuple, tuple>> } @@ -63,6 +81,7 @@ // 
Test case: Check decomposition of calls. // CHECK-LABEL: func private @callee(i1, i32) -> (i1, i32) +// CHECK-12N-LABEL: func private @callee(i1, i32) -> (i1, i32) func.func private @callee(tuple) -> tuple // CHECK-LABEL: func @caller( @@ -76,6 +95,11 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[CALL_RESULT_RECOMPOSED]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]]:2 = call @callee(%[[ARG0]], %[[ARG1]]) : (i1, i32) -> (i1, i32) +// CHECK-12N: return %[[V0]]#0, %[[V0]]#1 : i1, i32 func.func @caller(%arg0: tuple) -> tuple { %0 = call @callee(%arg0) : (tuple) -> tuple return %0 : tuple @@ -86,7 +110,12 @@ // Test case: Type that decomposes to nothing (that is, a 1:0 decomposition). 
// CHECK-LABEL: func private @callee() +// CHECK-12N-LABEL: func private @callee() func.func private @callee(tuple<>) -> tuple<> + +// CHECK-12N-LABEL: func @caller() { +// CHECK-12N: call @callee() : () -> () +// CHECK-12N: return // CHECK-LABEL: func @caller() { // CHECK: call @callee() : () -> () // CHECK: return @@ -105,6 +134,11 @@ // CHECK: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 // CHECK: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 // CHECK: return %[[RET0]], %[[RET1]] : i1, i32 +// CHECK-12N-LABEL: func @unconverted_op_result() -> (i1, i32) { +// CHECK-12N: %[[UNCONVERTED_VALUE:.*]] = "test.source"() : () -> tuple +// CHECK-12N: %[[RET0:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 0 : i32} : (tuple) -> i1 +// CHECK-12N: %[[RET1:.*]] = "test.get_tuple_element"(%[[UNCONVERTED_VALUE]]) {index = 1 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[RET0]], %[[RET1]] : i1, i32 func.func @unconverted_op_result() -> tuple { %0 = "test.source"() : () -> (tuple) return %0 : tuple @@ -125,6 +159,16 @@ // CHECK: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple // CHECK: %[[V5:.*]] = "test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 // CHECK: return %[[V3]], %[[V5]] : i1, i32 +// CHECK-12N-LABEL: func @nested_unconverted_op_result( +// CHECK-12N-SAME: %[[ARG0:.*]]: i1, +// CHECK-12N-SAME: %[[ARG1:.*]]: i32) -> (i1, i32) { +// CHECK-12N: %[[V0:.*]] = "test.make_tuple"(%[[ARG1]]) : (i32) -> tuple +// CHECK-12N: %[[V1:.*]] = "test.make_tuple"(%[[ARG0]], %[[V0]]) : (i1, tuple) -> tuple> +// CHECK-12N: %[[V2:.*]] = "test.op"(%[[V1]]) : (tuple>) -> tuple> +// CHECK-12N: %[[V3:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 0 : i32} : (tuple>) -> i1 +// CHECK-12N: %[[V4:.*]] = "test.get_tuple_element"(%[[V2]]) {index = 1 : i32} : (tuple>) -> tuple +// CHECK-12N: %[[V5:.*]] = 
"test.get_tuple_element"(%[[V4]]) {index = 0 : i32} : (tuple) -> i32 +// CHECK-12N: return %[[V3]], %[[V5]] : i1, i32 func.func @nested_unconverted_op_result(%arg: tuple>) -> tuple> { %0 = "test.op"(%arg) : (tuple>) -> (tuple>) return %0 : tuple> @@ -136,6 +180,7 @@ // This makes sure to test the cases if 1:0, 1:1, and 1:N decompositions. // CHECK-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N-LABEL: func private @callee(i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) func.func private @callee(tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) // CHECK-LABEL: func @caller( @@ -153,6 +198,15 @@ // CHECK: %[[RET_TUPLE_0:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 0 : i32} : (tuple) -> i4 // CHECK: %[[RET_TUPLE_1:.*]] = "test.get_tuple_element"(%[[RET_TUPLE]]) {index = 1 : i32} : (tuple) -> i5 // CHECK: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[RET_TUPLE_0]], %[[RET_TUPLE_1]], %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 +// CHECK-12N-LABEL: func @caller( +// CHECK-12N-SAME: %[[I1:.*]]: i1, +// CHECK-12N-SAME: %[[I2:.*]]: i2, +// CHECK-12N-SAME: %[[I3:.*]]: i3, +// CHECK-12N-SAME: %[[I4:.*]]: i4, +// CHECK-12N-SAME: %[[I5:.*]]: i5, +// CHECK-12N-SAME: %[[I6:.*]]: i6) -> (i1, i2, i3, i4, i5, i6) { +// CHECK-12N: %[[CALL:.*]]:6 = call @callee(%[[I1]], %[[I2]], %[[I3]], %[[I4]], %[[I5]], %[[I6]]) : (i1, i2, i3, i4, i5, i6) -> (i1, i2, i3, i4, i5, i6) +// CHECK-12N: return %[[CALL]]#0, %[[CALL]]#1, %[[CALL]]#2, %[[CALL]]#3, %[[CALL]]#4, %[[CALL]]#5 : i1, i2, i3, i4, i5, i6 func.func @caller(%arg0: tuple<>, %arg1: i1, %arg2: tuple, %arg3: i3, %arg4: tuple, %arg5: i6) -> (tuple<>, i1, tuple, i3, tuple, i6) { %0, %1, %2, %3, %4, %5 = call @callee(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (tuple<>, i1, tuple, i3, tuple, i6) -> (tuple<>, i1, tuple, i3, tuple, i6) return %0, %1, %2, %3, %4, %5 : tuple<>, i1, tuple, i3, tuple, i6 diff --git a/mlir/test/lib/Conversion/CMakeLists.txt 
b/mlir/test/lib/Conversion/CMakeLists.txt --- a/mlir/test/lib/Conversion/CMakeLists.txt +++ b/mlir/test/lib/Conversion/CMakeLists.txt @@ -1 +1,2 @@ add_subdirectory(FuncToLLVM) +add_subdirectory(OneToNTypeConversion) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/CMakeLists.txt @@ -0,0 +1,21 @@ +add_mlir_library(MLIRTestOneToNTypeConversionPass + OneToNTypeConversion.cpp + OneToNTypeConversionFunc.cpp + OneToNTypeConversionSCF.cpp + TestOneToNTypeConversionPass.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRIR + MLIRSCFDialect + MLIRTestDialect + MLIRTransformUtils + ) + +target_include_directories(MLIRTestOneToNTypeConversionPass + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.h @@ -0,0 +1,212 @@ +//===-- OneToNTypeConversion.h - Utils for 1:N type conversion --*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides utils for implementing (poor-man's) dialect conversion +// passes with 1:N type conversions. +// +// The main function first applies a set of RewritePatterns, which produce +// unrealized casts to convert the operands and results from and to the source +// types, and then replaces all newly added unrealized casts by user-provided +// materializations. 
For this to work, the main function requires a special +// TypeConverter and special RewritePatterns, respectively deriving from the +// provided classes, which extend their respective base classes for 1:N type +// conversions. +// +// Note that this is much more simple-minded than the "real" dialect conversion, +// which checks for legality before applying patterns and does probably many +// other additional things. Ideally, some of the extensions here could be +// integrated there. +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H +#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H + +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallVector.h" + +namespace mlir { + +/// Extends `TypeConverter` with 1:N target materializations. Such +/// materializations realize a 1:N type conversion on values, i.e., they need +/// to materialize one value with a source type as N values with target types +/// (which isn't possible in the base class currently). +class OneToNTypeConverter : public TypeConverter { +public: + using OneToNMaterializationCallbackFn = + std::function<std::optional<SmallVector<Value>>(OpBuilder &, TypeRange, + Value, Location)>; + + /// Applies the user-provided 1:N target materializations in LIFO order and + /// returns the result of the first one that succeeds (std::nullopt if none). + std::optional<SmallVector<Value>> + materializeTargetConversion(OpBuilder &builder, Location loc, + TypeRange resultTypes, Value input) const; + + /// Adds a 1:N target materialization to the converter. Such materializations + /// build IR that converts 1 value of a source type into N values with the + /// given target types.
+ void addTargetMaterialization(OneToNMaterializationCallbackFn &&callback) { + oneToNTargetMaterializations.emplace_back(std::move(callback)); + } + +private: + SmallVector<OneToNMaterializationCallbackFn> oneToNTargetMaterializations; +}; + +/// This class extends SignatureConversion with several helpers for writing 1:N +/// conversion patterns. SignatureConversion provides a 1:N mapping of types; +/// the extensions provide additional accessors into the mapping as well as +/// access to the original types. +class OneToNSignatureConversion : public TypeConverter::SignatureConversion { +public: + OneToNSignatureConversion(TypeRange originalTypes) + : TypeConverter::SignatureConversion(originalTypes.size()), + originalTypes(originalTypes) {} + + using TypeConverter::SignatureConversion::getConvertedTypes; + + /// Returns the list of types that corresponds to the original type at the + /// given index. + ArrayRef<Type> getConvertedTypes(unsigned originalTypeNo) const; + + /// Returns the list of original types. + ArrayRef<Type> getOriginalTypes() const { return originalTypes; } + + /// Returns the slice of converted values that corresponds to the original + /// value at the given index. + ArrayRef<Value> getConvertedValues(ArrayRef<Value> convertedValues, + unsigned originalValueNo) const; + + /// Returns true iff at least one original type is not mapped to exactly + /// itself (i.e., to a different type, or to zero or several types). + bool hasNonIdentityConversion() const; + +private: + llvm::SmallVector<Type> originalTypes; +}; + +/// Extends the basic RewritePattern with a type converter member and some +/// accessors to it. This is useful for patterns that are not ConversionPatterns +/// but still require access to a type converter. +class RewritePatternWithConverter : public mlir::RewritePattern { +public: + /// Construct a conversion pattern with the given converter, and forward the + /// remaining arguments to RewritePattern.
+ template <typename... Args> + RewritePatternWithConverter(TypeConverter &typeConverter, Args &&...args) + : RewritePattern(std::forward<Args>(args)...), + typeConverter(&typeConverter) {} + + /// Return the type converter held by this pattern. It is never null since it + /// is bound to the converter reference passed at construction. + TypeConverter *getTypeConverter() const { return typeConverter; } + + template <typename ConverterTy> + std::enable_if_t<std::is_base_of<TypeConverter, ConverterTy>::value, + ConverterTy *> + getTypeConverter() const { + return static_cast<ConverterTy *>(typeConverter); + } + +protected: + /// A type converter for use by this pattern. + TypeConverter *const typeConverter; +}; + +/// Base class for patterns with 1:N type conversions. Derived classes have to +/// override the `matchAndRewrite` overload that provides additional information +/// for 1:N type conversions. +class OneToNConversionPattern : public RewritePatternWithConverter { +public: + using RewritePatternWithConverter::RewritePatternWithConverter; + + /// This function has to be implemented by derived classes and is called from + /// the usual overloads. Like in normal DialectConversion, the function is + /// provided with the converted operands (which thus have target types). Since + /// 1:N conversions are supported, there is usually no 1:1 relationship between + /// the original and the converted operands. Instead, the provided + /// `operandConversion` can be used to access the converted operands that + /// correspond to a particular original operand. Similarly, `resultConversion` + /// is provided to help with assembling the result values (which may have 1:N + /// correspondences as well). The function is expected to return the converted + /// result values if the conversion succeeds and failure otherwise (in which + /// case any modifications of the IR have to be rolled back first). The + /// correspondence of original and converted result values needs to correspond + /// to `resultConversion`.
For both the converted operands and results, the + /// calling overload inserts appropriate unrealized casts that produce and + /// consume them, and replaces the uses of the results with the results of the + /// casts. If the returned result values are the same as those of the original + /// op, an in-place update is assumed and the result values are left as is. + virtual FailureOr<SmallVector<Value>> + matchAndRewrite(Operation *op, PatternRewriter &rewriter, + const OneToNSignatureConversion &operandConversion, + const OneToNSignatureConversion &resultConversion, + const SmallVector<Value> &convertedOperands) const = 0; + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; +}; + +/// This class is a wrapper around OneToNConversionPattern for matching against +/// instances of a particular op class. +template <typename SourceOp> +class OneToNOpConversionPattern : public OneToNConversionPattern { +public: + OneToNOpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1, + ArrayRef<StringRef> generatedNames = {}) + : OneToNConversionPattern(typeConverter, SourceOp::getOperationName(), + benefit, context, generatedNames) {} + + using OneToNConversionPattern::matchAndRewrite; + + /// Overload that derived classes have to override for their op type. + virtual FailureOr<SmallVector<Value>> + matchAndRewrite(SourceOp op, PatternRewriter &rewriter, + const OneToNSignatureConversion &operandConversion, + const OneToNSignatureConversion &resultConversion, + const SmallVector<Value> &convertedOperands) const = 0; + + FailureOr<SmallVector<Value>> + matchAndRewrite(Operation *op, PatternRewriter &rewriter, + const OneToNSignatureConversion &operandConversion, + const OneToNSignatureConversion &resultConversion, + const SmallVector<Value> &convertedOperands) const final { + return matchAndRewrite(cast<SourceOp>(op), rewriter, operandConversion, + resultConversion, convertedOperands); + } +}; + +/// Applies the given argument conversion to the given block.
This consists of +/// replacing each original argument with N arguments as specified in the +/// argument conversion and inserting unrealized casts from the converted values +/// to the original types, which are then used in lieu of the original ones. +/// (Eventually, applyOneToNConversion replaces these casts with a +/// user-provided argument materialization if necessary.) This is similar to +/// ArgConverter::applySignatureConversion but (1) handles 1:N type conversion +/// properly and probably (2) doesn't handle many other edge cases. +Block *applySignatureConversion(Block *block, + OneToNSignatureConversion &argumentConversion, + RewriterBase &rewriter); + +/// Main function that 1:N conversion passes should call. The patterns are +/// expected to insert unrealized casts to maintain the types of operands and +/// results, which is done automatically if they derive from +/// OneToNConversionPattern. The function replaces those that do not fold away +/// until the end of pattern application with user-provided materializations +/// from the type converter, so those have to be provided if conversions from +/// source to target types are expected to remain. +LogicalResult applyOneToNConversion(Operation *op, + OneToNTypeConverter &typeConverter, + const FrozenRewritePatternSet &patterns); + +} // namespace mlir + +#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSION_H diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversion.cpp @@ -0,0 +1,321 @@ +//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "OneToNTypeConversion.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/SmallSet.h" + +using namespace llvm; +using namespace mlir; + +std::optional<SmallVector<Value>> +OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, + Location loc, + TypeRange resultTypes, + Value input) const { + for (const OneToNMaterializationCallbackFn &fn : + llvm::reverse(oneToNTargetMaterializations)) { + if (std::optional<SmallVector<Value>> result = + fn(builder, resultTypes, input, loc)) + return result; + } + return {}; +} + +ArrayRef<Type> +OneToNSignatureConversion::getConvertedTypes(unsigned originalTypeNo) const { + ArrayRef<Type> convertedTypes = getConvertedTypes(); + if (auto mapping = getInputMapping(originalTypeNo)) + return convertedTypes.slice(mapping->inputNo, mapping->size); + return {}; +} + +ArrayRef<Value> +OneToNSignatureConversion::getConvertedValues(ArrayRef<Value> convertedValues, + unsigned originalValueNo) const { + if (auto mapping = getInputMapping(originalValueNo)) + return convertedValues.slice(mapping->inputNo, mapping->size); + return {}; +} + +static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { + return convertedTypes.size() == 1 && convertedTypes[0] == originalType; +} + +bool OneToNSignatureConversion::hasNonIdentityConversion() const { + // XXX: I think that the original types and the converted types are the same + // iff there was no non-identity type conversion. If that is true, the + // patterns could actually test whether there is anything useful to do + // without having access to the signature conversion.
+ for (unsigned i = 0, e = originalTypes.size(); i < e; ++i) { + ArrayRef<Type> types = getConvertedTypes(i); + if (!isIdentityConversion(originalTypes[i], types)) { + assert(ArrayRef<Type>(originalTypes) != getConvertedTypes()); + return true; + } + } + assert(ArrayRef<Type>(originalTypes) == getConvertedTypes()); + return false; +} + +/// Builds an UnrealizedConversionCastOp from the given inputs to the given +/// result types. Returns the result values of the cast. +static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, + ValueRange inputs) { + Location loc = builder.getUnknownLoc(); + if (!inputs.empty()) + loc = inputs.front().getLoc(); + auto castOp = + builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs); + return castOp->getResults(); +} + +/// Builds one UnrealizedConversionCastOp for each of the given original values +/// using the respective target types given in the provided conversion mapping +/// and returns the results of these casts. If the conversion mapping of a value +/// maps a type to itself (i.e., is an identity conversion), then no cast is +/// inserted and the original value is returned instead. +static SmallVector<Value> +buildUnrealizedForwardCasts(ValueRange originalValues, + const OneToNSignatureConversion &conversion, + RewriterBase &rewriter) { + + // Convert each operand one by one. + SmallVector<Value> convertedValues; + convertedValues.reserve(conversion.getConvertedTypes().size()); + for (auto [idx, originalValue] : llvm::enumerate(originalValues)) { + ArrayRef<Type> convertedTypes = conversion.getConvertedTypes(idx); + + // Identity conversion: keep operand as is. + if (isIdentityConversion(originalValue.getType(), convertedTypes)) { + convertedValues.push_back(originalValue); + continue; + } + + // Non-identity conversion: materialize target types.
+ ValueRange castResult = + buildUnrealizedCast(rewriter, convertedTypes, originalValue); + convertedValues.append(castResult.begin(), castResult.end()); + } + + return convertedValues; +} + +/// Builds one UnrealizedConversionCastOp for each sequence of the given +/// converted values to one value of the type they originated from, i.e., a +/// "reverse" conversion from N converted values back to one value of the +/// original type, using the given (forward) type conversion. If a given value +/// was mapped to a value of the same type (i.e., the conversion in the mapping +/// is an identity conversion), then the "converted" value is returned without +/// cast. +static SmallVector<Value> +buildUnrealizedBackwardsCasts(ValueRange convertedValues, + const OneToNSignatureConversion &typeConversion, + RewriterBase &rewriter) { + assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); + + // Create unrealized cast op for each converted result of the op. + SmallVector<Value> recastValues; + ArrayRef<Type> originalTypes = typeConversion.getOriginalTypes(); + recastValues.reserve(originalTypes.size()); + auto convertedValueIt = convertedValues.begin(); + for (auto [idx, originalType] : llvm::enumerate(originalTypes)) { + ArrayRef<Type> convertedTypes = typeConversion.getConvertedTypes(idx); + size_t numConvertedValues = convertedTypes.size(); + if (isIdentityConversion(originalType, convertedTypes)) { + // Identity conversion: take result as is. + recastValues.push_back(*convertedValueIt); + } else { + // Non-identity conversion: cast back to source type.
+ ValueRange recastValue = buildUnrealizedCast( + rewriter, originalType, + ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}); + assert(recastValue.size() == 1); + recastValues.push_back(recastValue.front()); + } + convertedValueIt += numConvertedValues; + } + + return recastValues; +} + +LogicalResult +OneToNConversionPattern::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { + // Construct conversion mapping for results. + Operation::result_type_range originalResultTypes = op->getResultTypes(); + OneToNSignatureConversion resultConversion(originalResultTypes); + if (failed(typeConverter->convertSignatureArgs(originalResultTypes, + resultConversion))) + return failure(); + + // Construct conversion mapping for operands. + Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); + OneToNSignatureConversion operandConversion(originalOperandTypes); + if (failed(typeConverter->convertSignatureArgs(originalOperandTypes, + operandConversion))) + return failure(); + + // Cast operands to target types. + SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts( + op->getOperands(), operandConversion, rewriter); + + // Apply actual pattern. + auto result = matchAndRewrite(op, rewriter, operandConversion, + resultConversion, convertedOperands); + + if (failed(result)) + return failure(); + SmallVector<Value> &replacementValues = result.value(); + + // If replacementValues consist of the results of the original op, assume + // in-place update. + // TODO: This isn't particularly elegant. Not sure how else to handle that + // case without tracking modifications through the rewriter, which + // would require a custom pattern application driver. + if (ValueRange{op->getResults()} == replacementValues) + return success(); + + // Cast op results back to the original types and use those.
+ SmallVector castResults = buildUnrealizedBackwardsCasts( + replacementValues, resultConversion, rewriter); + rewriter.replaceOp(op, castResults); + + return success(); +} + +namespace mlir { + +Block *applySignatureConversion(Block *block, + OneToNSignatureConversion &argumentConversion, + RewriterBase &rewriter) { + // Split the block at the beginning to get a new block to use for the + // updated signature. + Block *newBlock = rewriter.splitBlock(block, block->begin()); + rewriter.replaceAllUsesWith(block, newBlock); + + // Add block arguments to new block. + for (size_t i = 0; i < block->getNumArguments(); i++) { + BlockArgument arg = block->getArgument(i); + ArrayRef convertedTypes = argumentConversion.getConvertedTypes(i); + if (isIdentityConversion(arg.getType(), convertedTypes)) { + // Identity conversion: take argument as is. + BlockArgument newArg = newBlock->addArgument(arg.getType(), arg.getLoc()); + rewriter.replaceAllUsesWith(arg, newArg); + } else { + // Non-identity conversion: cast the converted arguments to the original + // type. + SmallVector locs(convertedTypes.size(), arg.getLoc()); + auto newArgsRange = newBlock->addArguments(convertedTypes, locs); + SmallVector newArgs(newArgsRange.begin(), newArgsRange.end()); + PatternRewriter::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(newBlock); + ValueRange castArgument = + buildUnrealizedCast(rewriter, arg.getType(), newArgs); + assert(castArgument.size() == 1); + rewriter.replaceAllUsesWith(arg, castArgument.front()); + } + } + + // Delete old (now empty) block. + rewriter.eraseBlock(block); + + return newBlock; +} + +// This function applies the provided patterns using +// applyPatternsAndFoldGreedily and then replaces all newly inserted +// UnrealizedConversionCastOps that haven't folded away. 
("Backward" casts from +// target to source types inserted by a OneToNConversionPattern normally fold +// away with the "forward" casts from source to target types inserted by the +// next pattern.) To understand which casts are "newly inserted", we save a list +// of all casts existing before the patterns are applied and assume that all +// casts not in that list after the application are new. (This is probably not +// correct: It might be possible that an existing cast is folded away and a new +// cast happens to be allocated with exactly the same pointer. Dealing with that +// possiblity is an open TODO.) Also, we do not track which inserted casts are +// needed for source, target, or argument materialization, so we do some +// educated guessing to recover that information. Fixing both issues would +// require to use a PatternRewriter that overloads various `notify*` functions +// and similar and tracks all changes there. However, that would require a +// dedicated pattern application driver, which is currently also left as an open +// TODO.) +LogicalResult applyOneToNConversion(Operation *op, + OneToNTypeConverter &typeConverter, + const FrozenRewritePatternSet &patterns) { + // Remember existing unrealized casts. + SmallSet existingCasts; + op->walk( + [&](UnrealizedConversionCastOp castOp) { existingCasts.insert(castOp); }); + + // Apply provided conversion patterns. + if (failed(applyPatternsAndFoldGreedily(op, patterns))) + return failure(); + + // Find all newly inserted unrealized casts (that haven't folded away). + SmallVector worklist; + op->walk([&](UnrealizedConversionCastOp castOp) { + if (!existingCasts.contains(castOp)) + worklist.push_back(castOp); + }); + + // Replace new casts with user materializations. + IRRewriter rewriter(op->getContext()); + for (UnrealizedConversionCastOp castOp : worklist) { + // Create user materialization. 
+ TypeRange resultTypes = castOp->getResultTypes(); + rewriter.setInsertionPoint(castOp); + SmallVector materializedResults; + + // Determine whether operands or results are already legal to know which + // kind of materilization this is. + ValueRange operands = castOp.getOperands(); + bool areOperandTypesLegal = llvm::all_of( + operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); + bool areResultsTypesLegal = llvm::all_of( + resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); + + if (!areOperandTypesLegal && areResultsTypesLegal && operands.size() == 1) { + // This is a target materilization. + std::optional> maybeResults = + typeConverter.materializeTargetConversion( + rewriter, castOp->getLoc(), resultTypes, operands.front()); + if (!maybeResults) + return failure(); + materializedResults = maybeResults.value(); + } else if (areOperandTypesLegal && !areResultsTypesLegal && + resultTypes.size() == 1) { + // This is a source or an argument materialization. + std::optional maybeResult; + if (llvm::all_of(operands, [&](Value v) { return v.getDefiningOp(); })) { + // This is an source materialization. + maybeResult = typeConverter.materializeArgumentConversion( + rewriter, castOp->getLoc(), resultTypes.front(), + castOp.getOperands()); + } else { + // This is an argument materialization. + maybeResult = typeConverter.materializeSourceConversion( + rewriter, castOp->getLoc(), resultTypes.front(), + castOp.getOperands()); + } + if (!maybeResult.has_value() || !maybeResult.value()) + return failure(); + materializedResults = {maybeResult.value()}; + } else { + assert(false && "unexpected cast inserted"); + } + + // Replace cast with materialization. 
+ rewriter.replaceOp(castOp, materializedResults); + } + + return success(); +} + +} // namespace mlir diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.h @@ -0,0 +1,26 @@ +//===- OneToNTypeConversionFunc.h - 1:N type conversion for Func-*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H +#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H + +namespace mlir { +class TypeConverter; +class RewritePatternSet; +} // namespace mlir + +namespace mlir { + +// Populates the provided pattern set with patterns that do 1:N type conversions +// on func ops. This is intended to be used with applyOneToNConversion. +void populateFuncTypeConversionPatterns(TypeConverter &typeConverter, + RewritePatternSet &patterns); + +} // namespace mlir + +#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONFUNC_H diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionFunc.cpp @@ -0,0 +1,127 @@ +//===-- OneToNTypeConversionFunc.cpp - Func 1:N type conversion -*- C++ -*-===// +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The patterns in this file are heavily inspired by (and copied from)
+// convertFuncOpTypes in lib/Transforms/Utils/DialectConversion.cpp and the
+// patterns in lib/Dialect/Func/Transforms/FuncConversions.cpp but work for 1:N
+// type conversions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "OneToNTypeConversionFunc.h"
+
+#include "OneToNTypeConversion.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+
+using namespace mlir;
+using namespace mlir::func;
+
+class ConvertTypesInFuncCallOp : public OneToNOpConversionPattern<CallOp> {
+public:
+  using OneToNOpConversionPattern<CallOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(CallOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion &resultConversion,
+                  const SmallVector<Value> &convertedOperands) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!operandConversion.hasNonIdentityConversion() &&
+        !resultConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Create new CallOp.
+    auto newOp = rewriter.create<CallOp>(
+        loc, resultConversion.getConvertedTypes(), convertedOperands);
+    newOp->setAttrs(op->getAttrs());
+
+    return SmallVector<Value>(newOp->getResults());
+  }
+};
+
+class ConvertTypesInFuncFuncOp : public OneToNOpConversionPattern<FuncOp> {
+public:
+  using OneToNOpConversionPattern<FuncOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>> matchAndRewrite(
+      FuncOp op, PatternRewriter &rewriter,
+      const OneToNSignatureConversion & /*operandConversion*/,
+      const OneToNSignatureConversion & /*resultConversion*/,
+      const SmallVector<Value> & /*convertedOperands*/) const override {
+    // Construct conversion mapping for arguments.
+    OneToNSignatureConversion argumentConversion(op.getArgumentTypes());
+    if (failed(typeConverter->convertSignatureArgs(op.getArgumentTypes(),
+                                                   argumentConversion)))
+      return failure();
+
+    // Construct conversion mapping for results.
+    OneToNSignatureConversion funcResultConversion(op.getResultTypes());
+    if (failed(typeConverter->convertSignatureArgs(op.getResultTypes(),
+                                                   funcResultConversion)))
+      return failure();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!argumentConversion.hasNonIdentityConversion() &&
+        !funcResultConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Update the function signature in-place.
+    auto newType = FunctionType::get(rewriter.getContext(),
+                                     argumentConversion.getConvertedTypes(),
+                                     funcResultConversion.getConvertedTypes());
+    rewriter.updateRootInPlace(op, [&] { op.setType(newType); });
+
+    // Update block signatures.
+    if (!op.isExternal()) {
+      Region *region = &op.getBody();
+      Block *block = &region->front();
+      applySignatureConversion(block, argumentConversion, rewriter);
+    }
+
+    return SmallVector<Value>(op->getResults());
+  }
+};
+
+class ConvertTypesInFuncReturnOp : public OneToNOpConversionPattern<ReturnOp> {
+public:
+  using OneToNOpConversionPattern<ReturnOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(ReturnOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion & /*resultConversion*/,
+                  const SmallVector<Value> &convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return SmallVector<Value>(op->getResults());
+  }
+};
+
+namespace mlir {
+
+void populateFuncTypeConversionPatterns(TypeConverter &typeConverter,
+                                        RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertTypesInFuncCallOp,
+      ConvertTypesInFuncFuncOp,
+      ConvertTypesInFuncReturnOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+} // namespace mlir
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.h b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.h
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.h
@@ -0,0 +1,26 @@
+//===-- OneToNTypeConversionSCF.h - 1:N type conversion for scf -*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONSCF_H
+#define TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONSCF_H
+
+namespace mlir {
+class TypeConverter;
+class RewritePatternSet;
+} // namespace mlir
+
+namespace mlir {
+
+// Populates the provided pattern set with patterns that do 1:N type conversions
+// on (some) SCF ops. This is intended to be used with applyOneToNConversion.
+void populateSCFTypeConversionPatterns(TypeConverter &typeConverter,
+                                       RewritePatternSet &patterns);
+
+} // namespace mlir
+
+#endif // TEST_LIB_CONVERSION_ONETONTYPECONVERSION_ONETONTYPECONVERSIONSCF_H
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/OneToNTypeConversionSCF.cpp
@@ -0,0 +1,155 @@
+//===-- OneToNTypeConversionSCF.cpp - SCF 1:N type conversion ---*- C++ -*-===//
+//
+// Licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The patterns in this file are heavily inspired by (and copied from)
+// lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp but work for 1:N
+// type conversions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "OneToNTypeConversionSCF.h"
+
+#include "OneToNTypeConversion.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+
+using namespace mlir;
+using namespace mlir::scf;
+
+class ConvertTypesInSCFIfOp : public OneToNOpConversionPattern<IfOp> {
+public:
+  using OneToNOpConversionPattern<IfOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>> matchAndRewrite(
+      IfOp op, PatternRewriter &rewriter,
+      const OneToNSignatureConversion & /*operandConversion*/,
+      const OneToNSignatureConversion &resultConversion,
+      const SmallVector<Value> & /*convertedOperands*/) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if there is no non-identity conversion.
+    if (!resultConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Create new IfOp.
+    ArrayRef<Type> convertedResultTypes = resultConversion.getConvertedTypes();
+    auto newOp = rewriter.create<IfOp>(loc, convertedResultTypes,
+                                       op.getCondition(), true);
+    newOp->setAttrs(op->getAttrs());
+
+    // We do not need the empty blocks created by rewriter.
+    rewriter.eraseBlock(newOp.elseBlock());
+    rewriter.eraseBlock(newOp.thenBlock());
+
+    // Inline the blocks from the original operation.
+    rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
+                                newOp.getThenRegion().end());
+    rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
+                                newOp.getElseRegion().end());
+
+    return SmallVector<Value>(newOp->getResults());
+  }
+};
+
+class ConvertTypesInSCFWhileOp : public OneToNOpConversionPattern<WhileOp> {
+public:
+  using OneToNOpConversionPattern<WhileOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(WhileOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion &resultConversion,
+                  const SmallVector<Value> &convertedOperands) const override {
+    Location loc = op->getLoc();
+
+    // Nothing to do if the op doesn't have any non-identity conversions for its
+    // operands or results.
+    if (!operandConversion.hasNonIdentityConversion() &&
+        !resultConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Create new WhileOp.
+    ArrayRef<Type> convertedResultTypes = resultConversion.getConvertedTypes();
+
+    auto newOp =
+        rewriter.create<WhileOp>(loc, convertedResultTypes, convertedOperands);
+    newOp->setAttrs(op->getAttrs());
+
+    // Update block signatures.
+    std::array<OneToNSignatureConversion, 2> blockConversions = {
+        operandConversion, resultConversion};
+    for (unsigned int i : {0u, 1u}) {
+      Region *region = &op.getRegion(i);
+      Block *block = &region->front();
+
+      applySignatureConversion(block, blockConversions[i], rewriter);
+
+      // Move updated region to new WhileOp.
+      Region &dstRegion = newOp.getRegion(i);
+      rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
+    }
+
+    return SmallVector<Value>(newOp->getResults());
+  }
+};
+
+class ConvertTypesInSCFYieldOp : public OneToNOpConversionPattern<YieldOp> {
+public:
+  using OneToNOpConversionPattern<YieldOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(YieldOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion & /*resultConversion*/,
+                  const SmallVector<Value> &convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return SmallVector<Value>(op->getResults());
+  }
+};
+
+class ConvertTypesInSCFConditionOp
+    : public OneToNOpConversionPattern<ConditionOp> {
+public:
+  using OneToNOpConversionPattern<ConditionOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(ConditionOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion & /*resultConversion*/,
+                  const SmallVector<Value> &convertedOperands) const override {
+    // Nothing to do if there is no non-identity conversion.
+    if (!operandConversion.hasNonIdentityConversion())
+      return failure();
+
+    // Convert operands.
+    rewriter.updateRootInPlace(op, [&] { op->setOperands(convertedOperands); });
+
+    return SmallVector<Value>(op->getResults());
+  }
+};
+
+namespace mlir {
+
+void populateSCFTypeConversionPatterns(TypeConverter &typeConverter,
+                                       RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertTypesInSCFConditionOp,
+      ConvertTypesInSCFIfOp,
+      ConvertTypesInSCFWhileOp,
+      ConvertTypesInSCFYieldOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+} // namespace mlir
diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp
@@ -0,0 +1,225 @@
+//===- TestOneToNTypeConversionPass.cpp - Test 1:N type conversion utils --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "OneToNTypeConversion.h"
+#include "OneToNTypeConversionFunc.h"
+#include "OneToNTypeConversionSCF.h"
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+struct TestOneToNTypeConversionPass
+    : public PassWrapper<TestOneToNTypeConversionPass,
+                         OperationPass<ModuleOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneToNTypeConversionPass)
+
+  TestOneToNTypeConversionPass() = default;
+  TestOneToNTypeConversionPass(const TestOneToNTypeConversionPass &pass)
+      : PassWrapper(pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<::test::TestDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-one-to-n-type-conversion";
+  }
+
+  StringRef getDescription() const final {
+    return "Test pass for 1:N type conversion";
+  }
+
+  Option<bool> convertFuncOps{*this, "convert-func-ops",
+                              llvm::cl::desc("Enable conversion on func ops"),
+                              llvm::cl::init(false)};
+
+  Option<bool> convertSCFOps{*this, "convert-scf-ops",
+                             llvm::cl::desc("Enable conversion on scf ops"),
+                             llvm::cl::init(false)};
+
+  Option<bool> convertTupleOps{*this, "convert-tuple-ops",
+                               llvm::cl::desc("Enable conversion on tuple ops"),
+                               llvm::cl::init(false)};
+
+  void runOnOperation() override;
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestOneToNTypeConversionPass() {
+  PassRegistration<TestOneToNTypeConversionPass>();
+}
+} // namespace test
+} // namespace mlir
+
+class ConvertMakeTupleOp
+    : public OneToNOpConversionPattern<::test::MakeTupleOp> {
+public:
+  using OneToNOpConversionPattern<
+      ::test::MakeTupleOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(::test::MakeTupleOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion & /*resultConversion*/,
+                  const SmallVector<Value> &convertedOperands) const override {
+    // Simply forward converted operands.
+    return convertedOperands;
+  }
+};
+
+class ConvertGetTupleElementOp
+    : public OneToNOpConversionPattern<::test::GetTupleElementOp> {
+public:
+  using OneToNOpConversionPattern<
+      ::test::GetTupleElementOp>::OneToNOpConversionPattern;
+
+  FailureOr<SmallVector<Value>>
+  matchAndRewrite(::test::GetTupleElementOp op, PatternRewriter &rewriter,
+                  const OneToNSignatureConversion &operandConversion,
+                  const OneToNSignatureConversion & /*resultConversion*/,
+                  const SmallVector<Value> &convertedOperands) const override {
+    // Construct conversion mapping for field types.
+    auto stateType = op->getOperand(0).getType().cast<TupleType>();
+    TypeRange originalElementTypes = stateType.getTypes();
+    OneToNSignatureConversion elementConversion(originalElementTypes);
+    if (failed(typeConverter->convertSignatureArgs(originalElementTypes,
+                                                   elementConversion)))
+      return failure();
+
+    // Compute converted operands corresponding to original input tuple.
+    ArrayRef<Value> convertedTuple =
+        operandConversion.getConvertedValues(convertedOperands, 0);
+
+    // Get the converted operands that correspond to the index-th element of
+    // the original input tuple.
+    size_t index = op.getIndex();
+    ValueRange extractedElement =
+        elementConversion.getConvertedValues(convertedTuple, index);
+
+    return SmallVector<Value>(extractedElement);
+  }
+};
+
+static void populateDecomposeTuplesPatterns(TypeConverter &typeConverter,
+                                            RewritePatternSet &patterns) {
+  patterns.add<
+      // clang-format off
+      ConvertMakeTupleOp,
+      ConvertGetTupleElementOp
+      // clang-format on
+      >(typeConverter, patterns.getContext());
+}
+
+/// Creates a sequence of `test.get_tuple_element` ops for all elements of a
+/// given tuple value. If some tuple elements are, in turn, tuples, the elements
+/// of those are extracted recursively such that the returned values have the
+/// flattened types of `input`'s tuple type (expected to equal `resultTypes`).
+static std::optional<SmallVector<Value>>
+buildGetTupleElementOps(OpBuilder &builder, TypeRange resultTypes, Value input,
+                        Location loc) {
+  TupleType inputType = input.getType().dyn_cast<TupleType>();
+  if (!inputType)
+    return {};
+
+  SmallVector<Value> values;
+  for (auto [idx, elementType] : llvm::enumerate(inputType.getTypes())) {
+    Value element = builder.create<::test::GetTupleElementOp>(
+        loc, elementType, input, builder.getI32IntegerAttr(idx));
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Recurse if the current element is also a tuple.
+      SmallVector<Type> flatRecursiveTypes;
+      nestedTupleType.getFlattenedTypes(flatRecursiveTypes);
+      std::optional<SmallVector<Value>> recursiveValues =
+          buildGetTupleElementOps(builder, flatRecursiveTypes, element, loc);
+      if (!recursiveValues.has_value())
+        return {};
+      values.append(recursiveValues.value());
+    } else {
+      values.push_back(element);
+    }
+  }
+  return values;
+}
+
+/// Creates a `test.make_tuple` op out of the given inputs building a tuple of
+/// type `resultType`. If that type is nested, each nested tuple is built
+/// recursively with another `test.make_tuple` op.
+static std::optional<Value> buildMakeTupleOp(OpBuilder &builder,
+                                             TupleType resultType,
+                                             ValueRange inputs, Location loc) {
+  // Build one value for each element at this nesting level.
+  SmallVector<Value> elements;
+  elements.reserve(resultType.getTypes().size());
+  ValueRange::iterator inputIt = inputs.begin();
+  for (Type elementType : resultType.getTypes()) {
+    if (auto nestedTupleType = elementType.dyn_cast<TupleType>()) {
+      // Determine how many input values are needed for the nested elements of
+      // the nested TupleType and advance inputIt by that number.
+      // TODO: Only the *number* of nested types is needed, not the types.
+      //       Maybe it's worth adding a more efficient overload?
+      SmallVector<Type> nestedFlattenedTypes;
+      nestedTupleType.getFlattenedTypes(nestedFlattenedTypes);
+      size_t numNestedFlattenedTypes = nestedFlattenedTypes.size();
+      ValueRange nestedFlattenedElements(inputIt,
+                                         inputIt + numNestedFlattenedTypes);
+      inputIt += numNestedFlattenedTypes;
+
+      // Recurse on the values for the nested TupleType.
+      std::optional<Value> res = buildMakeTupleOp(builder, nestedTupleType,
+                                                  nestedFlattenedElements, loc);
+      if (!res.has_value())
+        return {};
+
+      // The tuple constructed by the conversion is the element value.
+      elements.push_back(res.value());
+    } else {
+      // Base case: take one input as is.
+      elements.push_back(*inputIt++);
+    }
+  }
+
+  // Assemble the tuple from the elements.
+  return builder.create<::test::MakeTupleOp>(loc, resultType, elements);
+}
+
+void TestOneToNTypeConversionPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  auto *context = &getContext();
+
+  // Assemble type converter.
+  OneToNTypeConverter typeConverter;
+
+  typeConverter.addConversion([](Type type) { return type; });
+  typeConverter.addConversion(
+      [](TupleType tupleType, SmallVectorImpl<Type> &types) {
+        tupleType.getFlattenedTypes(types);
+        return success();
+      });
+
+  typeConverter.addArgumentMaterialization(buildMakeTupleOp);
+  typeConverter.addSourceMaterialization(buildMakeTupleOp);
+  typeConverter.addTargetMaterialization(buildGetTupleElementOps);
+
+  // Assemble patterns.
+  RewritePatternSet patterns(context);
+  if (convertTupleOps)
+    populateDecomposeTuplesPatterns(typeConverter, patterns);
+  if (convertFuncOps)
+    populateFuncTypeConversionPatterns(typeConverter, patterns);
+  if (convertSCFOps)
+    populateSCFTypeConversionPatterns(typeConverter, patterns);
+
+  // Run conversion.
+  if (failed(applyOneToNConversion(module, typeConverter, std::move(patterns))))
+    return signalPassFailure();
+}
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -33,6 +33,7 @@
     MLIRTestDialect
     MLIRTestDynDialect
     MLIRTestIR
+    MLIRTestOneToNTypeConversionPass
     MLIRTestPass
     MLIRTestPDLL
     MLIRTestReducer
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
 void registerTestPDLByteCodePass();
@@ -215,6 +216,7 @@
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();
   mlir::test::registerTestPDLByteCodePass();