diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -392,26 +392,31 @@ * `$_builder` will be replaced by the current `mlir::PatternRewriter`. * `$_loc` will be replaced by the fused location or custom location (as determined by location directive). -* `$_self` will be replaced with the entity `NativeCodeCall` is attached to. +* `$_self` is a required placeholder that can only be used in a source pattern, + where it refers to the defining operation of the operand being matched. We have seen how `$_builder` can be used in the above; it allows us to pass a `mlir::Builder` (`mlir::PatternRewriter` is a subclass of `mlir::OpBuilder`, which is a subclass of `mlir::Builder`) to the C++ helper function to use the handy methods on `mlir::Builder`. -`$_self` is useful when we want to write something in the form of -`NativeCodeCall<"...">:$symbol`. For example, if we want to reverse the previous -example and decompose the array attribute into two attributes: +Here's an example of how to use `$_self` in a source pattern: ```tablegen -class getNthAttr : NativeCodeCall<"$_self[" # n # "]">; -def : Pat<(OneAttrOp $attr), - (TwoAttrOp (getNthAttr<0>:$attr), (getNthAttr<1>:$attr)>; +def : Pat<(OneAttrOp (NativeCodeCall<"Foo($_self, $0)"> I32Attr:$val)), + (TwoAttrOp $val, $val)>; ``` -In the above, `$_self` is substituted by the attribute bound by `$attr`, which -is `OneAttrOp`'s array attribute. +In the above, `$_self` is substituted by the defining operation of the first +operand of `OneAttrOp`. Note that binding a name to a `NativeCodeCall` is not +supported in a source pattern. To get values back from the helper function, +put the names in the argument list and `mlir-tblgen` will pass the variables +it declares for those names.
In this case, `$val` has an attribute constraint, +so it is declared as a variable of type `Attribute`, and the helper's parameter +for `$0` should be an `Attribute &` so that it can assign the desired attribute +to `$val`. An argument without an attribute constraint is declared as a +variable of type `Value`. Positional placeholders will be substituted by the `dag` object parameters at the `NativeCodeCall` use site. For example, if we define `SomeCall : diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2496,9 +2496,10 @@ // the wrapped expression can take special placeholders listed below: // // * `$_builder` will be replaced by the current `mlir::PatternRewriter`. -// * `$_self` will be replaced with the entity this transformer is attached to. -// E.g., with the definition `def transform : NativeCodeCall<"$_self...">`, -// `$_self` in `transform:$attr` will be replaced by the value for `$attr`. +// * `$_self` is a required placeholder and can only be used in `pattern`, where +// it refers to the defining operation of the operand being matched. E.g., in +// `(OneArgOp (NativeCodeCall<"Foo($_self, $0)"> I32Attr:$attr))`, `$_self` is +// replaced with the defining operation of the first operand of OneArgOp.
// // If used as a DAG node, i.e., `(NativeCodeCall<"..."> , ..., )`, // then positional placeholders are also supported; placeholder `$N` in the @@ -2508,7 +2509,7 @@ string expression = expr; } -def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($0->getResult(0), m_Constant(&$1)))">; +def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">; //===----------------------------------------------------------------------===// // Rewrite directives diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -232,7 +232,7 @@ getVarName(name))); } case Kind::Value: { - return std::string(formatv("::llvm::ArrayRef<::mlir::Value> {0};\n", name)); + return std::string(formatv("::mlir::Value {0};\n", name)); } case Kind::Result: { // Use the op itself for captured results. @@ -626,11 +626,15 @@ if (tree.isNativeCodeCall()) { if (!treeName.empty()) { - PrintFatalError( - &def, - formatv( - "binding symbol '{0}' to native code call unsupported right now", - treeName)); + if (!isSrcPattern) { + LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " + << treeName << '\n'); + verifyBind(infoMap.bindValue(treeName), treeName); + } else + PrintFatalError(&def, + formatv("binding symbol '{0}' to NativecodeCall in " + "MatchPattern is not supported", + treeName)); } for (int i = 0; i != numTreeArgs; ++i) { @@ -649,24 +653,27 @@ // `$_` is a special symbol meaning ignore the current argument. 
if (!treeArgName.empty() && treeArgName != "_") { - if (tree.isNestedDagArg(i)) { - auto err = formatv("cannot bind '{0}' for nested native call arg", - treeArgName); - PrintFatalError(&def, err); - } - DagLeaf leaf = tree.getArgAsLeaf(i); - auto constraint = leaf.getAsConstraint(); - bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || - leaf.isConstantAttr() || - constraint.getKind() == Constraint::Kind::CK_Attr; - - if (isAttr) { - verifyBind(infoMap.bindAttr(treeArgName), treeArgName); - continue; - } - verifyBind(infoMap.bindValue(treeArgName), treeArgName); + // In (NativeCodeCall<"Foo($_self, $0, $1, $2)"> I8Attr:$a, I8:$b, $c), + if (leaf.isUnspecified()) + // This is case of $c, a Value without any constraints. + verifyBind(infoMap.bindValue(treeArgName), treeArgName); + else { + auto constraint = leaf.getAsConstraint(); + bool isAttr = leaf.isAttrMatcher() || leaf.isEnumAttrCase() || + leaf.isConstantAttr() || + constraint.getKind() == Constraint::Kind::CK_Attr; + + if (isAttr) { + // This is case of $a, a binding to a certain attribute. + verifyBind(infoMap.bindAttr(treeArgName), treeArgName); + continue; + } + + // This is case of $b, a binding to a certain type. 
+ verifyBind(infoMap.bindValue(treeArgName), treeArgName); + } } } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -837,6 +837,20 @@ [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input), (OpK)]>; +def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> { + let arguments = (ins AnyType:$input1); + let results = (outs I32:$output1, I32:$output2); +} +def OpNativeCodeCall5 : TEST_Op<"native_code_call5"> { + let arguments = (ins I32:$input1, I32:$input2); + let results = (outs I32:$output1, I32:$output2); +} + +def GetFirstI32Result : NativeCodeCall<"success(getFirstI32Result($_self, &$0))">; +def BindNativeCodeCallResult : NativeCodeCall<"bindNativeCodeCallResult($0)">; +def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)), + (OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>; + // Test AllAttrConstraintsOf. def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> { let arguments = (ins I64ArrayAttr:$attr); diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -35,6 +35,15 @@ op.operand()); } +static bool getFirstI32Result(Operation *op, Value *value) { + if (!Type(op->getResult(0).getType()).isSignlessInteger(32)) + return false; + *value = op->getResult(0); + return true; +} + +static Value bindNativeCodeCallResult(Value value) { return value; } + // Test that natives calls are only called once during rewrites. // OpM_Test will return Pi, increased by 1 for each subsequent calls. 
// This let us check the number of times OpM_Test was called by inspecting diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -88,6 +88,20 @@ return %0 : i32 } +// CHECK-LABEL: verifyNativeCodeCallBinding +func @verifyNativeCodeCallBinding(%arg0 : i32) -> (i32) { + %0 = "test.op_k"() : () -> (i32) + // CHECK: %[[A:.*]], %[[B:.*]] = "test.native_code_call5" + %1, %2 = "test.native_code_call4"(%0) : (i32) -> (i32, i32) + %3 = "test.constant"() {value = 1 : i8} : () -> i8 + // %3 is i8 so it'll fail at GetFirstI32Result match. The operation should + // keep the same form. + // CHECK: %{{.*}}, %{{.*}} = "test.native_code_call4"({{%.*}}) : (i8) -> (i32, i32) + %4, %5 = "test.native_code_call4"(%3) : (i8) -> (i32, i32) + // CHECK: return %[[A]] + return %1 : i32 +} + // CHECK-LABEL: verifyAllAttrConstraintOf func @verifyAllAttrConstraintOf() -> (i32, i32, i32) { // CHECK: "test.all_attr_constraint_of2" diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td --- a/mlir/test/mlir-tblgen/rewriter-errors.td +++ b/mlir/test/mlir-tblgen/rewriter-errors.td @@ -1,5 +1,6 @@ // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s include "mlir/IR/OpBase.td" @@ -16,14 +17,21 @@ #ifdef ERROR1 def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; -// ERROR1: [[@LINE+1]]:1: error: binding symbol 'error' to native code call unsupported right now -def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), +// ERROR1: [[@LINE+1]]:1: error: NativeCodeCall must have $_self as argument for passing the defining Operation of certain operand 
+def : Pat<(OpA (NativeMatcher $val), AnyI32Attr:$arg), (OpB $val, $arg)>; #endif #ifdef ERROR2 -def NativeMatcher : NativeCodeCall<"success(nativeCall($0, $1))">; -// ERROR2: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for +def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0))">; +// ERROR2: [[@LINE+1]]:1: error: binding symbol 'error' to NativecodeCall in MatchPattern is not supported +def : Pat<(OpA (NativeMatcher:$error $val), AnyI32Attr:$arg), + (OpB $val, $arg)>; +#endif + +#ifdef ERROR3 +def NativeMatcher : NativeCodeCall<"success(nativeCall($_self, $0, $1))">; +// ERROR3: [[@LINE+1]]:1: error: Matching nested tree in NativeCodecall not support for def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), (OpB $val, $arg)>; #endif diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -252,7 +252,6 @@ // TODO(suderman): iterate through arguments, determine their types, output // names. 
SmallVector capture; - capture.push_back(opName.str()); raw_indented_ostream::DelimitedScope scope(os); @@ -265,8 +264,8 @@ auto leaf = tree.getArgAsLeaf(i); if (leaf.isAttrMatcher() || leaf.isConstantAttr()) { os << "Attribute " << argName << ";\n"; - } else if (leaf.isOperandMatcher()) { - os << "Operation " << argName << ";\n"; + } else { + os << "Value " << argName << ";\n"; } } @@ -278,20 +277,25 @@ std::tie(hasLocationDirective, locToUse) = getLocation(tree); auto fmt = tree.getNativeCodeTemplate(); - auto nativeCodeCall = - std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), capture)); + if (fmt.count("$_self") != 1) { + PrintFatalError(loc, "NativeCodeCall must have $_self as argument for " + "passing the defining Operation of certain operand"); + } + + auto nativeCodeCall = std::string(tgfmt( + fmt, &fmtCtx.addSubst("_loc", locToUse).withSelf(opName.str()), capture)); os << "if (failed(" << nativeCodeCall << ")) return ::mlir::failure();\n"; for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { auto name = tree.getArgName(i); if (!name.empty() && name != "_") { - os << formatv("{0} = {1};\n", name, capture[i + 1]); + os << formatv("{0} = {1};\n", name, capture[i]); } } for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { - std::string argName = capture[i + 1]; + std::string argName = capture[i]; // Handle nested DAG construct first if (DagNode argTree = tree.getArgAsNestedDag(i)) { @@ -302,9 +306,18 @@ } DagLeaf leaf = tree.getArgAsLeaf(i); + + // The parameter for native function doesn't bind any constraints. 
+ if (leaf.isUnspecified()) + continue; + auto constraint = leaf.getAsConstraint(); - auto self = formatv("{0}", argName); + std::string self; + if (leaf.isAttrMatcher() || leaf.isConstantAttr()) + self = argName; + else + self = formatv("{0}.getType()", argName); emitMatchCheck( opName, tgfmt(constraint.getConditionTemplate(), &fmtCtx.withSelf(self)), @@ -929,7 +942,13 @@ << " replacement: " << attrs[i] << "\n"); } - return std::string(tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs)); + std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs); + if (!tree.getSymbol().empty()) { + os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol); + symbol = tree.getSymbol().str(); + } + + return symbol; } int PatternEmitter::getNodeValueCount(DagNode node) {