diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -602,6 +602,20 @@ def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>; def : Pat<(OpJ), (OpK)>; +// Test that native calls are only called once during rewrites. +def OpM : TEST_Op<"op_m"> { + let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr); + let results = (outs I32); +} +// Rewrite OpM so that its optional attribute holds the result of OpMTest, +// which returns an increasing static counter. The attribute value therefore +// shows whether the native call was emitted once or twice. +def OpMGetNullAttr : NativeCodeCall<"Attribute()">; +def OpMAttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">; +def OpMVal : NativeCodeCall<"OpMTest($_builder, $0)">; +def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpMVal $attr) ), + [(OpMAttributeIsNull:$optAttr)]>; + // Test `$_` for ignoring op argument match. def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> { let arguments = (ins diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -32,6 +32,16 @@ op.operand()); } +// Test that native calls are only called once during rewrites. +// OpMTest returns 314159265 (the leading digits of Pi) on its first call and +// a value increased by 1 on each subsequent call. This lets us count the calls +// to OpMTest by inspecting the returned value in the MLIR output.
+static int64_t opMIncreasingValue = 314159265; +static Attribute OpMTest(PatternRewriter &rewriter, Value val) { + int64_t i = opMIncreasingValue++; + return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i); +} + namespace { #include "TestPatterns.inc" } // end anonymous namespace diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -359,3 +359,14 @@ %0 = "test.one_i32_out"() : () -> (i32) return %0 : i32 } + +//===----------------------------------------------------------------------===// +// Test that native calls are only called once during rewrites. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: redundantTest +func @redundantTest(%arg0: i32) -> i32 { + %0 = "test.op_m"(%arg0) : (i32) -> i32 + // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32 + return %0 : i32 +} diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1044,11 +1044,13 @@ os.indent(6) << formatv( "SmallVector<NamedAttribute, 4> tblgen_attrs; (void)tblgen_attrs;\n"); + // Bind the attribute expression to a local before testing it so that the + // expression (e.g. a NativeCodeCall) is evaluated only once. + const char *addAttrCmd = + "if (auto tmpAttr = {1}) " + "tblgen_attrs.emplace_back(rewriter.getIdentifier(\"{0}\"), tmpAttr);\n"; for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e; ++argIndex) { if (resultOp.getArg(argIndex).is<NamedAttribute *>()) { - const char *addAttrCmd = "if ({1}) {{" - " tblgen_attrs.emplace_back(rewriter." - "getIdentifier(\"{0}\"), {1}); }\n"; // The argument in the op definition. auto opArgName = resultOp.getArgName(argIndex); if (auto subTree = node.getArgAsNestedDag(argIndex)) {