diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1411,6 +1411,41 @@ (replaceWithValue $results__2), ConstantAttr)>; +// Test ops consuming a variadic list of i32 values and exactly two i32 values. +def FooVariadicConsumer : TEST_Op<"foo_variadic_consume"> { + let arguments = (ins Variadic<I32>:$arg); + let results = (outs I32); +} +def FooBinaryConsumer : TEST_Op<"foo_binary_consume"> { + let arguments = (ins I32:$arg0, I32:$arg1); + let results = (outs I32); +} + +// Native helpers returning the front of a value range and the rest of it. +def Front: NativeCodeCall<"$0.front()", 1>; +def DropFront: NativeCodeCall<"$0.drop_front()", 1>; +// Constraints on the number of values bound to a variadic operand. +class HasMoreValuesThan<int n>: Constraint<CPred<"$0.size() > " # n>, + "has more than " # n # " values">; +class HasExactValues<int n>: Constraint<CPred<"$0.size() == " # n>, + "has exactly " # n # " values">; +// More than 3 operands: variadic(a, rest...) -> binary(a, variadic(rest...)). +def : Pat<(FooVariadicConsumer $varg), + (FooBinaryConsumer (Front $varg), + (FooVariadicConsumer (DropFront $varg))), + [(HasMoreValuesThan<3> $varg)]>; +// Exactly 2 operands: use the binary op. This could also be a canonicalization. +def : Pat<(FooVariadicConsumer $varg), + (FooBinaryConsumer (Front $varg), + (Front (DropFront $varg))), + [(HasExactValues<2> $varg)]>; +// Fold binary(a, binary(b, c)) into variadic(a, b, c). +// This could be generalized (using a C++ helper) to handle an arbitrary number +// of operands, including variadic ones. +def Concat3: NativeCodeCall<"SmallVector<Value, 4>{$0, $1, $2}", 1>; +def : Pat<(FooBinaryConsumer $o1, (FooBinaryConsumer $o2, $o3)), + (FooVariadicConsumer (Concat3 $o1, $o2, $o3))>; + //===----------------------------------------------------------------------===// // Test Patterns (either) diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -533,6 +533,45 @@ return %0 : i32 } +//===----------------------------------------------------------------------===// +// Test rewrites of variadic operands using native calls.
+//===----------------------------------------------------------------------===// + +func @testVariadicToBinary1(%arg0: i32, %arg1: i32, %arg2: i32, %arg3: i32) { + %0 = "test.foo_variadic_consume"(%arg0, %arg1, %arg2, %arg3) : (i32, i32, i32, i32) -> (i32) + // CHECK-LABEL: func @testVariadicToBinary1( + // CHECK-SAME: %[[VAL_0:[^ ]*]]: i32 + // CHECK-SAME: %[[VAL_1:[^ ]*]]: i32 + // CHECK-SAME: %[[VAL_2:[^ ]*]]: i32 + // CHECK-SAME: %[[VAL_3:[^ ]*]]: i32 + // CHECK: %[[VAL_4:.*]] = "test.foo_variadic_consume"(%[[VAL_1]], %[[VAL_2]], %[[VAL_3]]) : (i32, i32, i32) -> i32 + // CHECK: %[[VAL_5:.*]] = "test.foo_binary_consume"(%[[VAL_0]], %[[VAL_4]]) : (i32, i32) -> i32 + return +} + +func @testVariadicToBinary2(%arg0: i32, %arg1: i32) { + %0 = "test.foo_variadic_consume"(%arg0, %arg1) : (i32, i32) -> (i32) + // CHECK-LABEL: func @testVariadicToBinary2( + // CHECK-SAME: %[[VAL_0:[^ ]*]]: i32 + // CHECK-SAME: %[[VAL_1:[^ ]*]]: i32 + // CHECK: %[[VAL_2:.*]] = "test.foo_binary_consume"(%[[VAL_0]], %[[VAL_1]]) : (i32, i32) -> i32 + return +} + +func @testBinaryToVariadic(%arg0: i32, %arg1: i32, %arg2: i32) { + %0 = "test.foo_binary_consume"(%arg1, %arg2) : (i32, i32) -> (i32) + %1 = "test.foo_binary_consume"(%arg0, %0) : (i32, i32) -> (i32) + // CHECK-LABEL: func @testBinaryToVariadic( + // CHECK-SAME: %[[VAL_0:[^ ]*]]: i32 + // CHECK-SAME: %[[VAL_1:[^ ]*]]: i32 + // CHECK-SAME: %[[VAL_2:[^ ]*]]: i32 + // The pattern does not erase the inner binary op (%0), so it is still + // expected here; if it is side-effect free, a later DCE would remove it. + // CHECK: %[[VAL_3:.*]] = "test.foo_binary_consume"(%[[VAL_1]], %[[VAL_2]]) : (i32, i32) -> i32 + // CHECK: %[[VAL_4:.*]] = "test.foo_variadic_consume"(%[[VAL_0]], %[[VAL_1]], %[[VAL_2]]) : (i32, i32, i32) -> i32 + return +} + //===----------------------------------------------------------------------===// // Test either directive //===----------------------------------------------------------------------===//