diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -285,6 +285,7 @@ // CHECK: [[V0]] return %2 : i32 } + // CHECK-LABEL: testConstOpMatchFailure func @testConstOpMatchFailure() -> (i64) { // CHECK-DAG: [[C0:%.+]] = constant 1 @@ -300,6 +301,20 @@ return %2 : i64 } +// CHECK-LABEL: testConstOpMatchNonConst +func @testConstOpMatchNonConst(%arg0 : i32) -> (i32) { + // CHECK-DAG: [[C0:%.+]] = constant 1 + %0 = "test.constant"() {value = 1 : i32} : () -> i32 + + // CHECK: [[V0:%.+]] = "test.op_r"([[C0]], %arg0) + %1 = "test.op_r"(%0, %arg0) : (i32, i32) -> i32 + + // CHECK: [[V0]] + return %1 : i32 +} + + + //===----------------------------------------------------------------------===// // Test Enum Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -259,6 +259,7 @@ raw_indented_ostream::DelimitedScope scope(os); + os << "if(!" << opName << ") return failure();\n"; for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { std::string argName = formatv("arg{0}_{1}", depth, i); if (DagNode argTree = tree.getArgAsNestedDag(i)) {