diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -377,9 +377,6 @@ C++ object expected at the `NativeCodeCall` site (here it would be expecting an array attribute). Typically the string should be a function call. -Note that currently `NativeCodeCall` must return no more than one value or -attribute. This might change in the future. - ##### `NativeCodeCall` placeholders In `NativeCodeCall`, we can use placeholders like `$_builder`, `$N` and `$N...`. @@ -428,6 +425,30 @@ `SomeCall : NativeCodeCall<"someFn($1...)">` and use it like `(SomeCall $in0, $in1, $in2)`, then this will be translated into C++ call `someFn($in1, $in2)`. +##### `NativeCodeCall` binding multi-results + +To bind multi-results and access the N-th result with `$<name>__N`, specify the +number of return values in the template. Note that only `Value` type is +supported for multiple results binding. For example, + +```tablegen + +def PackAttrs : NativeCodeCall<"packAttrs($0, $1)", 2>; +def : Pattern<(TwoResultOp $attr1, $attr2), + [(OneResultOp (PackAttrs:$res__0 $attr1, $attr2)), + (OneResultOp $res__1)]>; + +``` + +Use `NativeCodeCallVoid` for the case that has no return value. + +Specifying the correct number of returned values in NativeCodeCall is important. +It will be used to verify the consistency of the number of result values. +Additionally, `mlir-tblgen` will try to capture the return value of +NativeCodeCall in the generated code so that it will trigger a later compilation +error if a NativeCodeCall that doesn't return a result isn't labeled with 0 +returns. 
+ ##### Customizing entire op building `NativeCodeCall` is not only limited to transforming arguments for building an diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2565,11 +2565,20 @@ // If used as a DAG node, i.e., `(NativeCodeCall<"..."> <arg0>, ..., <argN>)`, // then positional placeholders are also supported; placeholder `$N` in the // wrapped C++ expression will be replaced by `<argN>`. +// +// ## Bind multiple results +// +// To bind multi-results and access the N-th result with `$<name>__N`, specify +// the number of return values in the template. Note that only `Value` type is +// supported for multiple results binding. -class NativeCodeCall<string expr> { +class NativeCodeCall<string expr, int returns = 1> { string expression = expr; + int numReturns = returns; } +class NativeCodeCallVoid<string expr> : NativeCodeCall<expr, 0>; + def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -100,6 +100,11 @@ // Precondition: isNativeCodeCall() StringRef getNativeCodeTemplate() const; + // Returns the number of values that will be returned by the native helper + // function. + // Precondition: isNativeCodeCall() + int getNumReturnsOfNativeCode() const; + // Returns the string associated with the leaf. // Precondition: isStringAttr() std::string getStringAttr() const; @@ -181,6 +186,11 @@ // Precondition: isNativeCodeCall() StringRef getNativeCodeTemplate() const; + // Returns the number of values that will be returned by the native helper + // function. 
+ // Precondition: isNativeCodeCall() + int getNumReturnsOfNativeCode() const; + void print(raw_ostream &os) const; private: @@ -242,30 +252,32 @@ // DagNode and DagLeaf are accessed by value which means it can't be used as // identifier here. Use an opaque pointer type instead. - using DagAndIndex = std::pair<const void *, int>; + using DagAndConstant = std::pair<const void *, int>; // What kind of entity this symbol represents: // * Attr: op attribute // * Operand: op operand // * Result: op result // * Value: a value not attached to an op (e.g., from NativeCodeCall) - enum class Kind : uint8_t { Attr, Operand, Result, Value }; + // * MultipleValues: a pack of values not attached to an op (e.g., from + // NativeCodeCall). This kind supports indexing. + enum class Kind : uint8_t { Attr, Operand, Result, Value, MultipleValues }; - // Creates a SymbolInfo instance. `dagAndIndex` is only used for `Attr` and - // `Operand` so should be llvm::None for `Result` and `Value` kind. + // Creates a SymbolInfo instance. `dagAndConstant` is only used for `Attr`, + // `Operand` and `MultipleValues`, so it should be llvm::None for `Result` + // and `Value` kind. SymbolInfo(const Operator *op, Kind kind, - Optional<DagAndIndex> dagAndIndex); + Optional<DagAndConstant> dagAndConstant); // Static methods for creating SymbolInfo. 
static SymbolInfo getAttr(const Operator *op, int index) { - return SymbolInfo(op, Kind::Attr, DagAndIndex(nullptr, index)); + return SymbolInfo(op, Kind::Attr, DagAndConstant(nullptr, index)); } static SymbolInfo getAttr() { return SymbolInfo(nullptr, Kind::Attr, llvm::None); } static SymbolInfo getOperand(DagNode node, const Operator *op, int index) { return SymbolInfo(op, Kind::Operand, - DagAndIndex(node.getAsOpaquePointer(), index)); + DagAndConstant(node.getAsOpaquePointer(), index)); } static SymbolInfo getResult(const Operator *op) { return SymbolInfo(op, Kind::Result, llvm::None); @@ -273,6 +285,10 @@ static SymbolInfo getValue() { return SymbolInfo(nullptr, Kind::Value, llvm::None); } + static SymbolInfo getMultipleValues(int numValues) { + return SymbolInfo(nullptr, Kind::MultipleValues, + DagAndConstant(nullptr, numValues)); + } // Returns the number of static values this symbol corresponds to. // A static value is an operand/result declared in ODS. Normally a symbol @@ -298,13 +314,21 @@ std::string getAllRangeUse(StringRef name, int index, const char *fmt, const char *separator) const; + // The argument index (for `Attr` and `Operand` only) + int getArgIndex() const { return (*dagAndConstant).second; } + + // The number of values in the MultipleValue + int getSize() const { return (*dagAndConstant).second; } + const Operator *op; // The op where the bound entity belongs Kind kind; // The kind of the bound entity - // The pair of DagNode pointer and argument index (for `Attr` and `Operand` - // only). Note that operands may be bound to the same symbol, use the - // DagNode and index to distinguish them. For `Attr`, the Dag part will be - // nullptr. - Optional dagAndIndex; + + // The pair of DagNode pointer and constant value (for `Attr`, `Operand` and + // the size of MultipleValue symbol). Note that operands may be bound to the + // same symbol, use the DagNode and index to distinguish them. 
For `Attr` + // and MultipleValue, the Dag part will be nullptr. + Optional dagAndConstant; + // Alternative name for the symbol. It is used in case the name // is not unique. Applicable for `Operand` only. Optional alternativeName; @@ -331,10 +355,17 @@ // `symbol` is already bound. bool bindOpResult(StringRef symbol, const Operator &op); - // Registers the given `symbol` as bound to a value. Returns false if `symbol` - // is already bound. + // A helper function for dispatching target value binding functions. + bool bindValues(StringRef symbol, int numValues = 1); + + // Registers the given `symbol` as bound to the Value(s). Returns false if + // `symbol` is already bound. bool bindValue(StringRef symbol); + // Registers the given `symbol` as bound to a MultipleValue. Return false if + // `symbol` is already bound. + bool bindMultipleValues(StringRef symbol, int numValues); + // Registers the given `symbol` as bound to an attr. Returns false if `symbol` // is already bound. bool bindAttr(StringRef symbol); diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -83,6 +83,11 @@ return cast(def)->getDef()->getValueAsString("expression"); } +int DagLeaf::getNumReturnsOfNativeCode() const { + assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); + return cast(def)->getDef()->getValueAsInt("numReturns"); +} + std::string DagLeaf::getStringAttr() const { assert(isStringAttr() && "the DAG leaf must be string attribute"); return def->getAsUnquotedString(); @@ -119,6 +124,13 @@ ->getValueAsString("expression"); } +int DagNode::getNumReturnsOfNativeCode() const { + assert(isNativeCodeCall() && "the DAG leaf must be NativeCodeCall"); + return cast(node->getOperator()) + ->getDef() + ->getValueAsInt("numReturns"); +} + llvm::StringRef DagNode::getSymbol() const { return node->getNameStr(); } Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { @@ -193,8 
+205,8 @@ } SymbolInfoMap::SymbolInfo::SymbolInfo(const Operator *op, SymbolInfo::Kind kind, - Optional dagAndIndex) - : op(op), kind(kind), dagAndIndex(dagAndIndex) {} + Optional dagAndConstant) + : op(op), kind(kind), dagAndConstant(dagAndConstant) {} int SymbolInfoMap::SymbolInfo::getStaticValueCount() const { switch (kind) { @@ -204,6 +216,8 @@ return 1; case Kind::Result: return op->getNumResults(); + case Kind::MultipleValues: + return getSize(); } llvm_unreachable("unknown kind"); } @@ -217,7 +231,7 @@ switch (kind) { case Kind::Attr: { if (op) { - auto type = op->getArg((*dagAndIndex).second) + auto type = op->getArg(getArgIndex()) .get() ->attr.getStorageType(); return std::string(formatv("{0} {1};\n", type, name)); @@ -235,6 +249,14 @@ case Kind::Value: { return std::string(formatv("::mlir::Value {0};\n", name)); } + case Kind::MultipleValues: { + // This is for the variable used in the source pattern. Each named value in + // source pattern will only be bound to a Value. The others in the result + // pattern may be associated with multiple Values as we will use `auto` to + // do the type inference. + return std::string(formatv( + "::mlir::Value {0}_raw; ::mlir::ValueRange {0}({0}_raw);\n", name)); + } case Kind::Result: { // Use the op itself for captured results. return std::string(formatv("{0} {1};\n", op->getQualCppClassName(), name)); @@ -255,8 +277,7 @@ } case Kind::Operand: { assert(index < 0); - auto *operand = - op->getArg((*dagAndIndex).second).get(); + auto *operand = op->getArg(getArgIndex()).get(); // If this operand is variadic, then return a range. Otherwise, return the // value itself. 
if (operand->isVariableLength()) { @@ -311,6 +332,21 @@ LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); return std::string(repl); } + case Kind::MultipleValues: { + assert(op == nullptr); + assert(index < getSize()); + if (index >= 0) { + std::string repl = + formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); + LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + return repl; + } + // If it doesn't specify certain element, unpack them all. + auto repl = + formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); + LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + return std::string(repl); + } } llvm_unreachable("unknown kind"); } @@ -353,6 +389,20 @@ LLVM_DEBUG(llvm::dbgs() << repl << " (Value)\n"); return std::string(repl); } + case Kind::MultipleValues: { + assert(op == nullptr); + assert(index < getSize()); + if (index >= 0) { + std::string repl = + formatv(fmt, std::string(formatv("{0}[{1}]", name, index))); + LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + return repl; + } + auto repl = + formatv(fmt, std::string(formatv("{0}.begin(), {0}.end()", name))); + LLVM_DEBUG(llvm::dbgs() << repl << " (MultipleValues)\n"); + return std::string(repl); + } } llvm_unreachable("unknown kind"); } @@ -395,11 +445,25 @@ return symbolInfoMap.count(inserted->first) == 1; } +bool SymbolInfoMap::bindValues(StringRef symbol, int numValues) { + std::string name = getValuePackName(symbol).str(); + if (numValues > 1) + return bindMultipleValues(name, numValues); + return bindValue(name); +} + bool SymbolInfoMap::bindValue(StringRef symbol) { auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getValue()); return symbolInfoMap.count(inserted->first) == 1; } +bool SymbolInfoMap::bindMultipleValues(StringRef symbol, int numValues) { + std::string name = getValuePackName(symbol).str(); + auto inserted = + symbolInfoMap.emplace(name, SymbolInfo::getMultipleValues(numValues)); + return symbolInfoMap.count(inserted->first) 
== 1; +} + bool SymbolInfoMap::bindAttr(StringRef symbol) { auto inserted = symbolInfoMap.emplace(symbol.str(), SymbolInfo::getAttr()); return symbolInfoMap.count(inserted->first) == 1; @@ -423,11 +487,9 @@ const auto symbolInfo = SymbolInfo::getOperand(node, &op, argIndex); - for (auto it = range.first; it != range.second; ++it) { - if (it->second.dagAndIndex == symbolInfo.dagAndIndex) { + for (auto it = range.first; it != range.second; ++it) + if (it->second.dagAndConstant == symbolInfo.dagAndConstant) return it; - } - } return symbolInfoMap.end(); } @@ -633,7 +695,9 @@ if (!isSrcPattern) { LLVM_DEBUG(llvm::dbgs() << "found symbol bound to NativeCodeCall: " << treeName << '\n'); - verifyBind(infoMap.bindValue(treeName), treeName); + verifyBind( + infoMap.bindValues(treeName, tree.getNumReturnsOfNativeCode()), + treeName); } else { PrintFatalError(&def, formatv("binding symbol '{0}' to NativecodeCall in " diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -857,7 +857,7 @@ // Test that NativeCodeCall is not ignored if it is not used to directly // replace the matched root op. 
def : Pattern<(OpNativeCodeCall3 $input), - [(NativeCodeCall<"createOpI($_builder, $_loc, $0)"> $input), + [(NativeCodeCallVoid<"createOpI($_builder, $_loc, $0)"> $input), (OpK)]>; def OpNativeCodeCall4 : TEST_Op<"native_code_call4"> { @@ -874,6 +874,19 @@ def : Pat<(OpNativeCodeCall4 (GetFirstI32Result $ret)), (OpNativeCodeCall5 (BindNativeCodeCallResult:$native $ret), $native)>; +def OpNativeCodeCall6 : TEST_Op<"native_code_call6"> { + let arguments = (ins I32:$input1, I32:$input2); + let results = (outs I32:$output1, I32:$output2); +} +def OpNativeCodeCall7 : TEST_Op<"native_code_call7"> { + let arguments = (ins I32:$input); + let results = (outs I32); +} +def BindMultipleNativeCodeCallResult : NativeCodeCall<"bindMultipleNativeCodeCallResult($0, $1)", 2>; +def : Pattern<(OpNativeCodeCall6 $arg1, $arg2), + [(OpNativeCodeCall7 (BindMultipleNativeCodeCallResult:$native__0 $arg1, $arg2)), + (OpNativeCodeCall7 $native__1)]>; + // Test AllAttrConstraintsOf. def OpAllAttrConstraint1 : TEST_Op<"all_attr_constraint_of1"> { let arguments = (ins I64ArrayAttr:$attr); @@ -1033,7 +1046,7 @@ // Test that we can bind to an op without results and reference it later. def : Pat<(OpSymbolBindingNoResult:$op $operand), - (NativeCodeCall<"handleNoResultOp($_builder, $0)"> $op)>; + (NativeCodeCallVoid<"handleNoResultOp($_builder, $0)"> $op)>; //===----------------------------------------------------------------------===// // Test Patterns (Attributes) diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -44,6 +44,11 @@ static Value bindNativeCodeCallResult(Value value) { return value; } +static SmallVector bindMultipleNativeCodeCallResult(Value input1, + Value input2) { + return SmallVector({input2, input1}); +} + // Test that natives calls are only called once during rewrites. 
// OpM_Test will return Pi, increased by 1 for each subsequent calls. // This let us check the number of times OpM_Test was called by inspecting diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -102,6 +102,16 @@ return %1 : i32 } +// CHECK-LABEL: verifyMultipleNativeCodeCallBinding +func@verifyMultipleNativeCodeCallBinding(%arg0 : i32) -> (i32) { + %0 = "test.op_k"() : () -> (i32) + %1 = "test.op_k"() : () -> (i32) + // CHECK: %[[A:.*]] = "test.native_code_call7"(%1) : (i32) -> i32 + // CHECK: %[[A:.*]] = "test.native_code_call7"(%0) : (i32) -> i32 + %2, %3 = "test.native_code_call6"(%0, %1) : (i32, i32) -> (i32, i32) + return %2 : i32 +} + // CHECK-LABEL: verifyAllAttrConstraintOf func @verifyAllAttrConstraintOf() -> (i32, i32, i32) { // CHECK: "test.all_attr_constraint_of2" diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -754,7 +754,8 @@ // NativeCodeCall will only be materialized to `os` if it is used. Here // we are handling auxiliary patterns so we want the side effect even if // NativeCodeCall is not replacing matched root op's results. 
- if (resultTree.isNativeCodeCall()) + if (resultTree.isNativeCodeCall() && + resultTree.getNumReturnsOfNativeCode() == 0) os << val << ";\n"; } @@ -804,11 +805,8 @@ "location directive can only be used with op creation"); } - if (resultTree.isNativeCodeCall()) { - auto symbol = handleReplaceWithNativeCodeCall(resultTree, depth); - symbolInfoMap.bindValue(symbol); - return symbol; - } + if (resultTree.isNativeCodeCall()) + return handleReplaceWithNativeCodeCall(resultTree, depth); if (resultTree.isReplaceWithValue()) return handleReplaceWithValue(resultTree).str(); @@ -948,9 +946,39 @@ } std::string symbol = tgfmt(fmt, &fmtCtx.addSubst("_loc", locToUse), attrs); - if (!tree.getSymbol().empty()) { - os << formatv("auto {0} = {1};\n", tree.getSymbol(), symbol); - symbol = tree.getSymbol().str(); + + // In general, NativeCodeCall without naming binding don't need this. To + // ensure void helper function has been correctly labeled, i.e., use + // NativeCodeCallVoid, we cache the result to a local variable so that we will + // get a compilation error in the auto-generated file. + // Example. + // // In the td file + // Pat<(...), (NativeCodeCall ...)> + // + // --- + // + // // In the auto-generated .cpp + // ... + // // Causes compilation error if Foo() returns void. + // auto nativeVar = Foo(); + // ... + if (tree.getNumReturnsOfNativeCode() != 0) { + // Determine the local variable name for return value. + std::string varName = + SymbolInfoMap::getValuePackName(tree.getSymbol()).str(); + if (varName.empty()) { + varName = formatv("nativeVar_{0}", nextValueId++); + // Register the local variable for later uses. + symbolInfoMap.bindValues(varName, tree.getNumReturnsOfNativeCode()); + } + + // Catch the return value of helper function. 
+ os << formatv("auto {0} = {1}; (void){0};\n", varName, symbol); + + if (!tree.getSymbol().empty()) + symbol = tree.getSymbol().str(); + else + symbol = varName; } return symbol; @@ -967,8 +995,10 @@ // Otherwise this is an unbound op; we will use all its results. return pattern.getDialectOp(node).getNumResults(); } - // TODO: This considers all NativeCodeCall as returning one - // value. Enhance if multi-value ones are needed. + + if (node.isNativeCodeCall()) + return node.getNumReturnsOfNativeCode(); + return 1; } @@ -1191,8 +1221,7 @@ if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); - os << formatv("/*{0}=*/{1}", opArgName, - handleReplaceWithNativeCodeCall(subTree, depth)); + os << formatv("/*{0}=*/{1}", opArgName, childNodeNames.lookup(argIndex)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. @@ -1233,8 +1262,7 @@ if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); - os << formatv(addAttrCmd, opArgName, - handleReplaceWithNativeCodeCall(subTree, depth + 1)); + os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern.