Fix ODS operand indexing for nested DAG arguments in generated DRR matchers.

While emitting the match code for a source pattern, the argument loop walks
every DAG argument of the op, attributes included. For a nested DAG argument
the generated code fetches the defining op through getODSOperands(N), which
does not count attributes, yet N used to be the DAG argument index `i`. Any
attribute placed before the nested DAG therefore shifted the lookup to the
wrong operand. The loop now keeps a separate `nextOperand` counter that only
advances for nested DAGs and operand leaves. emitOperandMatch still receives
`i`, because its own indexing counts attributes.

The new test covers ops that interleave attributes and operands and checks
the operand indices that appear in the generated matcher.

NOTE(review): this patch text was rebuilt from a whitespace-collapsed copy in
which angle-bracket contents had been stripped. The template argument lists
of `class NS_Op<...> : Op<...>` and of the two `opArg.is<...>()` checks, the
indentation, and the position of blank context lines were restored from
context -- confirm against the upstream sources before applying.

diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/rewriter-indexing.td
@@ -0,0 +1,50 @@
+// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s
+
+include "mlir/IR/OpBase.td"
+
+def Test_Dialect : Dialect {
+  let name = "test";
+}
+class NS_Op<string mnemonic, list<OpTrait> traits> :
+    Op<Test_Dialect, mnemonic, traits>;
+
+// Tests dag operand indexing for ops with mixed attr and operand.
+// ---
+
+def AOp : NS_Op<"a_op", []> {
+  let arguments = (ins
+    AnyInteger:$any_integer
+  );
+
+  let results = (outs AnyInteger);
+}
+
+def BOp : NS_Op<"b_op", []> {
+  let arguments = (ins
+    AnyAttr: $any_attr,
+    AnyInteger
+  );
+}
+
+def COp : NS_Op<"c_op", []> {
+  let arguments = (ins
+    AnyAttr: $any_attr1,
+    AnyInteger,
+    AnyAttr: $any_attr2,
+    AnyInteger
+  );
+}
+
+// Only operand 0 should be addressed during matching.
+// CHECK: struct test1 : public ::mlir::RewritePattern {
+// CHECK: castedOp0.getODSOperands(0).begin()).getDefiningOp()
+def test1 : Pat<(BOp $attr, (AOp $input)),
+                (BOp $attr, $input)>;
+
+// Only operands 0 and 1 should be addressed during matching.
+
+// CHECK: struct test2 : public ::mlir::RewritePattern {
+// CHECK: castedOp0.getODSOperands(0);
+// CHECK: castedOp0.getODSOperands(1).begin()).getDefiningOp()
+def test2 : Pat<(COp $attr1, $op1, $attr2, (AOp $op2)),
+                (BOp $attr1, $op2)>;
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -349,7 +349,7 @@
   if (!name.empty())
     os << formatv("{0} = {1};\n", name, castedName);
 
-  for (int i = 0, e = tree.getNumArgs(); i != e; ++i) {
+  for (int i = 0, e = tree.getNumArgs(), nextOperand = 0; i != e; ++i) {
     auto opArg = op.getArg(i);
     std::string argName = formatv("op{0}", depth + 1);
 
@@ -365,10 +365,11 @@
       }
       os << "{\n";
 
+      // Attributes don't count for getODSOperands.
       os.indent() << formatv(
           "auto *{0} = "
           "(*{1}.getODSOperands({2}).begin()).getDefiningOp();\n",
-          argName, castedName, i);
+          argName, castedName, nextOperand++);
       emitMatch(argTree, argName, depth + 1);
       os << formatv("tblgen_ops[{0}] = {1};\n", ++opCounter, argName);
       os.unindent() << "}\n";
@@ -377,7 +378,9 @@
 
     // Next handle DAG leaf: operand or attribute
     if (opArg.is<NamedTypeConstraint *>()) {
+      // emitOperandMatch's argument indexing counts attributes.
       emitOperandMatch(tree, castedName, i, depth);
+      ++nextOperand;
     } else if (opArg.is<NamedAttribute *>()) {
       emitAttributeMatch(tree, opName, i, depth);
     } else {