Make RewriterGen bind the `$_op` placeholder (to "op") before emitting the
condition of each applied match constraint, and add a TableGen test pattern
(OpHook / CheckResultCount) plus FileCheck cases that exercise it.

diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td
--- a/mlir/test/lib/TestDialect/TestOps.td
+++ b/mlir/test/lib/TestDialect/TestOps.td
@@ -592,6 +592,21 @@
 def : Pattern<(OpNativeCodeCall3 $input),
               [(NativeCodeCall<"createOpI($_builder, $0)"> $input), (OpK)]>;
 
+// Test that the $_op hook works.
+def OpHook : TEST_Op<"op_hook"> {
+  let arguments = (ins
+    I32:$input,
+    I64Attr:$result_count
+  );
+  let results = (outs I32);
+}
+
+def CheckResultCount : Constraint<CPred<"$_op->getNumResults() == $0.getInt()">>;
+
+def : Pat<(OpHook:$op $input, $result_count),
+          (OpNativeCodeCall3 $input),
+          [(CheckResultCount $result_count)]>;
+
 // Test AllAttrConstraintsOf.
 def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> {
   let arguments = (ins I64ArrayAttr:$attr);
diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir
--- a/mlir/test/mlir-tblgen/pattern.mlir
+++ b/mlir/test/mlir-tblgen/pattern.mlir
@@ -76,6 +76,22 @@
   return %0 : i32
 }
 
+// CHECK-LABEL: verifyOpHookValid
+func @verifyOpHookValid(%arg0: i32) -> (i32) {
+  // CHECK: test.op_i
+  // CHECK: test.op_k
+  %0 = "test.op_hook"(%arg0) {result_count = 1} : (i32) -> (i32)
+  return %0 : i32
+}
+
+// CHECK-LABEL: verifyOpHookInvalid
+func @verifyOpHookInvalid(%arg0: i32) -> (i32) {
+  // The op has 1 result, not 2, so CheckResultCount fails and the op is not rewritten.
+  // CHECK: test.op_hook
+  %0 = "test.op_hook"(%arg0) {result_count = 2} : (i32) -> (i32)
+  return %0 : i32
+}
+
 // CHECK-LABEL: verifyAllAttrConstraintOf
 func @verifyAllAttrConstraintOf() -> (i32, i32, i32) {
   // CHECK: "test.all_attr_constraint_of2"
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -382,6 +382,8 @@
     auto &constraint = appliedConstraint.constraint;
     auto &entities = appliedConstraint.entities;
 
+    fmtCtx.withOp("op");
+
     auto condition = constraint.getConditionTemplate();
     auto cmd = "if (!({0})) return matchFailure();\n";
 