diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2508,7 +2508,7 @@ string expression = expr; } -def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">; +def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">; //===----------------------------------------------------------------------===// // Rewrite directives diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -232,7 +232,7 @@ getVarName(name))); } case Kind::Value: { - return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name)); + return std::string(formatv("::mlir::Value {0};\n", name)); } case Kind::Result: { // Use the op itself for captured results. @@ -626,11 +626,16 @@ if (tree.isNativeCodeCall()) { if (!treeName.empty()) { - PrintFatalError( - &def, - formatv( - "binding symbol '{0}' to native code call unsupported right now", - treeName)); + if (!isSrcPattern) { + LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " + << treeName << '\n'); + verifyBind(infoMap.bindValue(treeName), treeName); + } else { + PrintFatalError(&def, + formatv("binding symbol '{0}' to NativeCodeCall in " + "MatchPattern is not supported", + treeName)); + } } for (int i = 0; i != numTreeArgs; ++i) { @@ -649,24 +654,28 @@ // `$_` is a special symbol meaning ignore the current argument.
if (!treeArgName.empty() && treeArgName != "_") { - if (tree.isNestedDagArg(i)) { - auto err = formatv("cannot bind '{0}' for nested native call arg", - treeArgName); - PrintFatalError(&def, err); - } DagLeaf leaf = tree.getArgAsLeaf(i); - auto constraint = leaf.getAsConstraint(); - bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || - leaf.isConstantAttr() || - constraint.getKind() == Constraint::Kind::CK_Attr; - - if (isAttr) { - verifyBind(infoMap.bindAttr(treeArgName), treeArgName); - continue; - } - verifyBind(infoMap.bindValue(treeArgName), treeArgName); + // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), + if (leaf.isUnspecified()) { + // This is case of $c, a Value without any constraints. + verifyBind(infoMap.bindValue(treeArgName), treeArgName); + } else { + auto constraint = leaf.getAsConstraint(); + bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || + leaf.isConstantAttr() || + constraint.getKind() == Constraint::Kind::CK_Attr; + + if (isAttr) { + // This is case of $a, a binding to a certain attribute. + verifyBind(infoMap.bindAttr(treeArgName), treeArgName); + continue; + } + + // This is case of $b, a binding to a certain type.
+ verifyBind(infoMap.bindValue(treeArgName), treeArgName); + } } } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -837,6 +837,20 @@ [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input), (OpK)]>; +def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> { + let arguments = (ins I32:$input1); + let results = (outs I32:$output1, I32:$output2); +} +def OpNativeCodeCall5 : TEST_Op<"native_code_call5"> { + let arguments = (ins I32:$input1, I32:$input2); + let results = (outs I32:$output1, I32:$output2); +} + +def WriteReturnValueToArgument : NativeCodeCall<"success(writeReturnValueToArgument($_self, &$0))">; +def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">; +def : Pat<(OpNativeCodeCall4 (WriteReturnValueToArgument $ret)), + (OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>; + // Test AllAttrConstraintsOf. def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> { let arguments = (ins I64ArrayAttr:$attr); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -35,6 +35,13 @@ op.operand()); } +static bool writeReturnValueToArgument(Operation *op, Value *value) { + *value = op->getResult(0); + return true; +} + +static Value bindNativeCodeCallResult(Value value) { return value; } + // Test that natives calls are only called once during rewrites. // OpM_Test will return Pi, increased by 1 for each subsequent calls.
// This let us check the number of times OpM_Test was called by inspecting diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -88,6 +88,15 @@ return %0 : i32 } +// CHECK-LABEL: verifyNativeCodeCallBinding +func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) { + %0 = "test.op_k"() : () -> (i32) + // CHECK: %[[A:.*]], %[[B:.*]] = "test.native_code_call5"(%0, %0) + // CHECK: return %[[A]] + %1, %2 = "test.native_code_call4"(%0) : (i32) -> (i32, i32) + return %1 : i32 +} + // CHECK-LABEL: verifyAllAttrConstraintOf func @verifyAllAttrConstraintOf() -> (i32, i32, i32) { // CHECK: "test.all_attr_constraint_of2" diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td --- a/mlir/test/mlir-tblgen/rewriter-errors.td +++ b/mlir/test/mlir-tblgen/rewriter-errors.td @@ -1,5 +1,6 @@ // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s include "mlir/IR/OpBase.td" @@ -16,14 +17,21 @@ #ifdef ERROR1 def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; -// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now -def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), +// ERROR1: [[@LINE+1]]:1: error: NativeCodeCall must have $_self as argument for passing the defining Operation +def : Pat<(OpA (NativeMatcher $val), AnyI32Attr:$arg), (OpB $val, $arg)>; #endif #ifdef ERROR2 -def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; -// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for +def NativeMatcher :
NativeCodeCall<"success(nativeCall($_self, $0, $1))">; +// ERROR2: [[@LINE+1]]:1: error: binding symbol 'error' to NativeCodeCall in MatchPattern is not supported +def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), + (OpB $val, $arg)>; +#endif + +#ifdef ERROR3 +def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0, $1))">; +// ERROR3: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), (OpB $val, $arg)>; #endif diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -252,7 +252,6 @@ // TODO(suderman): iterate through arguments, determine their types, output // names. SmallVector capture; - capture.push_back(opName.str()); raw_indented_ostream::DelimitedScope scope(os); @@ -265,8 +264,8 @@ auto leaf = tree.getArgAsLeaf(i); if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { os << "Attribute " << argName << ";\n"; - } else if (leaf.isOperandMatcher()) { - os << "Operation " << argName << ";\n"; + } else { + os << "Value " << argName << ";\n"; } } @@ -278,20 +277,25 @@ std::tie(hasLocationDirective, locToUse) = getLocation(tree); auto fmt = tree.getNativeCodeTemplate(); - auto nativeCodeCall = - std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture)); + if (fmt.count("$_self") != 1) { + PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " + "passing the defining Operation"); + } + + auto nativeCodeCall = std::string(tgfmt( + fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture)); os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n"; for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto name = tree.getArgName(i); if (!name.empty() && name != "_") { - os << formatv("{0} = {1};\n", name,
capture[i]); } } for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { - std::string argName = capture[i + 1]; + std::string argName = capture[i]; // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { @@ -302,9 +306,18 @@ } DagLeaf leaf = tree.getArgAsLeaf(i); + + // The parameter for native function doesn't bind any constraints. + if (leaf.isUnspecified()) + continue; + auto constraint = leaf.getAsConstraint(); - auto self = formatv("{0}", argName); + std::string self; + if (leaf.isAttrMatcher() || leaf.isConstantAttr()) + self = argName; + else + self = formatv("{0}.getType()", argName); emitMatchCheck( opName, tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), @@ -929,7 +942,13 @@ << " replacement: " << attrs[i] << "\n"); } - return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs)); + std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs); + if (!tree.getSymbol().empty()) { + os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol); + symbol = tree.getSymbol().str(); + } + + return symbol; } int PatternEmitter::getNodeValueCount(DagNode node) {