Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -602,6 +602,20 @@
 def OpK : TEST_Op<"op_k">, Arguments<(ins)>, Results<(outs I32)>;
 def : Pat<(OpJ), (OpK)>;
 
+// Test patterns for redundant computation of operands
+def OpM : TEST_Op<"op_m"> {
+  let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
+  let results = (outs I32);
+}
+// The pattern below fills the optional attribute with an increasing static
+// number produced by the OpM_Test native function. Inspecting that value in
+// the output tells us whether the native call ran once or twice.
+def OpM_GetNullAttr : NativeCodeCall<"Attribute()">;
+def OpM_AttributeIsNull : Constraint<CPred<"! ($_self)">, "Attribute is null">;
+def OpM_Val : NativeCodeCall<"OpM_Test($_builder, $0)">;
+def : Pat<(OpM $attr, $optAttr), (OpM $attr, (OpM_Val $attr) ),
+    [(OpM_AttributeIsNull:$optAttr)]>;
+
 // Test `$_` for ignoring op argument match.
 def TestIgnoreArgMatchSrcOp : TEST_Op<"ignore_arg_match_src"> {
   let arguments = (ins
Index: mlir/test/lib/Dialect/Test/TestPatterns.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -32,6 +32,16 @@
                                     op.operand());
 }
 
+// Native function for testing redundant calls to native code calls.
+// OpM_Test returns 314159265 (the digits of Pi) on its first call and a value
+// one larger on each subsequent call. This lets us check how many times
+// OpM_Test was called by inspecting the returned value in the MLIR output.
+static int64_t OpM_IncreasingValue = 314159265; // Next value to hand out.
+static Attribute OpM_Test(PatternRewriter &rewriter, Value val) {
+  int64_t i = OpM_IncreasingValue++; // Read the current value, then bump it.
+  return rewriter.getIntegerAttr(rewriter.getIntegerType(32), i);
+}
+
 namespace {
 #include "TestPatterns.inc"
 } // end anonymous namespace
Index: mlir/test/mlir-tblgen/pattern.mlir
===================================================================
--- mlir/test/mlir-tblgen/pattern.mlir
+++ mlir/test/mlir-tblgen/pattern.mlir
@@ -359,3 +359,14 @@
   %0 = "test.one_i32_out"() : () -> (i32)
   return %0 : i32
 }
+
+//===----------------------------------------------------------------------===//
+// Test Rewrite rules calling native calls only once
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: redundantTest
+func @redundantTest(%arg0: i32) -> i32 {
+  %0 = "test.op_m"(%arg0) : (i32) -> i32
+  // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32
+  return %0 : i32
+}
Index: mlir/tools/mlir-tblgen/RewriterGen.cpp
===================================================================
--- mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1046,9 +1046,9 @@
     for (int argIndex = 0, e = resultOp.getNumArgs(); argIndex < e;
          ++argIndex) {
       if (resultOp.getArg(argIndex).is<NamedAttribute *>()) {
-        const char *addAttrCmd = "if ({1}) {{"
-                                 " tblgen_attrs.emplace_back(rewriter."
-                                 "getIdentifier(\"{0}\"), {1}); }\n";
+        const char *addAttrCmd = "{{ if (auto tmpAttr = {1}) "
+                                 "tblgen_attrs.emplace_back(rewriter."
+                                 "getIdentifier(\"{0}\"), tmpAttr); }\n";
         // The argument in the op definition.
         auto opArgName = resultOp.getArgName(argIndex);
         if (auto subTree = node.getArgAsNestedDag(argIndex)) {