diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2603,15 +2603,49 @@ // Rewrite directives //===----------------------------------------------------------------------===// -// Directive used in result pattern to specify the location of the generated -// op. This directive must be used as the last argument to the op creation -// DAG construct. The arguments to location must be previously captured symbol. -def location; - // Directive used in result pattern to indicate that no new op are generated, // so to replace the matched DAG with an existing SSA value. def replaceWithValue; +// Directive used in result patterns to specify the location of the generated +// op. This directive must be used as a trailing argument to op creation or +// native code calls. +// +// Usage: +// * Create a named location: `(location "myLocation")` +// * Copy the location of a captured symbol: `(location $arg)` +// * Create a fused location: `(location "metadata", $arg0, $arg1)` + +def location; + +// Directive used in result patterns to specify return types for a created op. +// This allows ops to be created without relying on type inference with +// `OpTraits` or an op builder with deduction. It must be used as a trailing +// argument to op creation, and not on an op whose results directly replace +// the source pattern's root op. +// +// Specify one return type with a string literal: +// +// ``` +// (AnOp $val, (returnType "$_builder.getI32Type()")) +// ``` +// +// Pass a captured value to copy its return type: +// +// ``` +// (AnOp $val, (returnType $val)) +// ``` +// +// Pass a native code call inside a DAG to create a new type with arguments. +// +// ``` +// (AnOp $val, +// (returnType (NativeCodeCall<"$_builder.getTupleType({$0})"> $val))) +// ``` +// +// Specify multiple return types by passing several of any of the above. 
+ +def returnType; //===----------------------------------------------------------------------===// // Attribute and Type generation diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -176,6 +176,9 @@ // Returns whether this DAG represents the location of an op creation. bool isLocationDirective() const; + // Returns whether this DAG represents a `returnType` directive. + bool isReturnTypeDirective() const; + // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -54,9 +54,7 @@ return isSubClassOf("EnumAttrCaseInfo"); } -bool DagLeaf::isStringAttr() const { - return isa(def); -} +bool DagLeaf::isStringAttr() const { return isa(def); } Constraint DagLeaf::getAsConstraint() const { assert((isOperandMatcher() || isAttrMatcher()) && @@ -114,7 +112,8 @@ } bool DagNode::isOperation() const { - return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); + return !isNativeCodeCall() && !isReplaceWithValue() && + !isLocationDirective() && !isReturnTypeDirective(); } llvm::StringRef DagNode::getNativeCodeTemplate() const { @@ -180,6 +179,11 @@ return dagOpDef->getName() == "location"; } +bool DagNode::isReturnTypeDirective() const { + auto *dagOpDef = cast(node->getOperator())->getDef(); + return dagOpDef->getName() == "returnType"; +} + void DagNode::print(raw_ostream &os) const { if (node) node->print(os); @@ -753,14 +757,18 @@ auto &op = getDialectOp(tree); auto numOpArgs = op.getNumArgs(); - // The pattern might have the last argument specifying the location. 
- bool hasLocDirective = false; - if (numTreeArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numTreeArgs - 1)) - hasLocDirective = lastArg.isLocationDirective(); + // Count trailing directives; stop at the first leaf or non-directive DAG. + int numDirectives = 0; + for (int i = numTreeArgs - 1; i >= 0; --i) { + auto dagArg = tree.getArgAsNestedDag(i); + // Leaves end the scan too, matching RewriterGen's getTrailingDirectives. + if (!dagArg || + !(dagArg.isLocationDirective() || dagArg.isReturnTypeDirective())) + break; + ++numDirectives; } - if (numOpArgs != numTreeArgs - hasLocDirective) { + if (numOpArgs != numTreeArgs - numDirectives) { auto err = formatv("op '{0}' argument number mismatch: " "{1} in pattern vs. {2} in definition", op.getOperationName(), numTreeArgs, numOpArgs); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1317,6 +1317,64 @@ (location "named")), (location "fused", $res2, $res3))>; +//===----------------------------------------------------------------------===// +// Test Patterns (Type Builders) + +def SourceOp : TEST_Op<"source_op"> { + let arguments = (ins AnyInteger:$arg, AnyI32Attr:$tag); + let results = (outs AnyInteger); +} + +// An op without return type deduction. +def OpX : TEST_Op<"op_x"> { + let arguments = (ins AnyInteger:$input); + let results = (outs AnyInteger); +} + +// Test that ops without built-in type deduction can be created in the +// replacement DAG with an explicitly specified type. +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (OpX (OpX $val, (returnType "$_builder.getI32Type()")))>; +// Test that a NativeCodeCall type builder can accept arguments. +def SameTypeAs : NativeCodeCall<"$0.getType()">; + +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (OpX (OpX $val, (returnType (SameTypeAs $val))))>; + +// Test multiple return types. 
+def MakeI64Type : NativeCodeCall<"$_builder.getI64Type()">; +def MakeI32Type : NativeCodeCall<"$_builder.getI32Type()">; + +def OneToTwo : TEST_Op<"one_to_two"> { + let arguments = (ins AnyInteger); + let results = (outs AnyInteger, AnyInteger); +} + +def TwoToOne : TEST_Op<"two_to_one"> { + let arguments = (ins AnyInteger, AnyInteger); + let results = (outs AnyInteger); +} + +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (TwoToOne (OpX (OneToTwo:$res__0 $val, (returnType (MakeI64Type), (MakeI32Type))), (returnType (MakeI32Type))), + (OpX $res__1, (returnType (MakeI64Type))))>; + +// Test copying the type of a captured value. +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (OpX (OpX $val, (returnType $val)))>; + +// Test creating multiple return types with different methods. +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (TwoToOne (OneToTwo:$res__0 $val, (returnType $val, "$_builder.getI64Type()")), $res__1)>; + +//===----------------------------------------------------------------------===// +// Test Patterns (Trailing Directives) + +// Test that we can specify both `location` and `returnType` directives. +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (TwoToOne (OpX $val, (returnType $val), (location "loc1")), + (OpX $val, (location "loc2"), (returnType $val)))>; + //===----------------------------------------------------------------------===// // Test Legalization //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -530,3 +530,56 @@ // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32 return %0 : i32 } + +//===----------------------------------------------------------------------===// +// Test that ops without type deduction can be created with type builders. 
+//===----------------------------------------------------------------------===// + +func @explicitReturnTypeTest(%arg0 : i64) -> i8 { + %0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8 + // CHECK: "test.op_x"(%arg0) : (i64) -> i32 + // CHECK: "test.op_x"(%0) : (i32) -> i8 + return %0 : i8 +} + +func @returnTypeBuilderTest(%arg0 : i1) -> i8 { + %0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8 + // CHECK: "test.op_x"(%arg0) : (i1) -> i1 + // CHECK: "test.op_x"(%0) : (i1) -> i8 + return %0 : i8 +} + +func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 { + %0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1 + // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32) + // CHECK: "test.op_x"(%0#0) : (i64) -> i32 + // CHECK: "test.op_x"(%0#1) : (i32) -> i64 + // CHECK: "test.two_to_one"(%1, %2) : (i32, i64) -> i1 + return %0 : i1 +} + +func @copyValueType(%arg0 : i8) -> i32 { + %0 = "test.source_op"(%arg0) {tag = 44 : i32} : (i8) -> i32 + // CHECK: "test.op_x"(%arg0) : (i8) -> i8 + // CHECK: "test.op_x"(%0) : (i8) -> i32 + return %0 : i32 +} + +func @multipleReturnTypeDifferent(%arg0 : i1) -> i64 { + %0 = "test.source_op"(%arg0) {tag = 55 : i32} : (i1) -> i64 + // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i1, i64) + // CHECK: "test.two_to_one"(%0#0, %0#1) : (i1, i64) -> i64 + return %0 : i64 +} + +//===----------------------------------------------------------------------===// +// Test that multiple trailing directives can be mixed in patterns. 
+//===----------------------------------------------------------------------===// + +func @returnTypeAndLocation(%arg0 : i32) -> i1 { + %0 = "test.source_op"(%arg0) {tag = 66 : i32} : (i32) -> i1 + // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc1") + // CHECK: "test.op_x"(%arg0) : (i32) -> i32 loc("loc2") + // CHECK: "test.two_to_one"(%0, %1) : (i32, i32) -> i1 + return %0 : i1 +} diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td --- a/mlir/test/mlir-tblgen/rewriter-errors.td +++ b/mlir/test/mlir-tblgen/rewriter-errors.td @@ -1,6 +1,8 @@ // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR4 %s 2>&1 | FileCheck --check-prefix=ERROR4 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR5 %s 2>&1 | FileCheck --check-prefix=ERROR5 %s include "mlir/IR/OpBase.td" @@ -35,3 +37,15 @@ def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), (OpB $val, $arg)>; #endif + +#ifdef ERROR4 +// Check that passing an op as a DAG node inside `returnType` fails. +// ERROR4: [[@LINE+1]]:1: error: nested DAG in `returnType` must be a native code +def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA (OpA $val, $val, (returnType (OpA $val, $val))), $val)>; +#endif + +#ifdef ERROR5 +// Check that specifying explicit return types on the root result op fails. 
+// ERROR5: [[@LINE+1]]:1: error: Cannot specify explicit return types in an op +def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA $val, $val, (returnType "someType()"))>; +#endif diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td --- a/mlir/test/mlir-tblgen/rewriter-indexing.td +++ b/mlir/test/mlir-tblgen/rewriter-indexing.td @@ -90,3 +90,13 @@ // CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin())) def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>; + +// Check a pattern that mixes all three kinds of return type builders. +def SameTypeAs : NativeCodeCall<"$0.getType()">; +// CHECK: struct test6 : public ::mlir::RewritePattern { +// CHECK: tblgen_types.push_back((*v2.begin()).getType()) +// CHECK: tblgen_types.push_back(rewriter.getI32Type()) +// CHECK: nativeVar_1 = doSomething((*v3.begin())) +// CHECK: tblgen_types.push_back(nativeVar_1) +def test6 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), + (AOp (AOp $v1, (returnType $v2, "$_builder.getI32Type()", (NativeCodeCall<"doSomething($0)"> $v3))))>; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -127,12 +127,31 @@ // Returns the symbol of the old value serving as the replacement. StringRef handleReplaceWithValue(DagNode tree); + // Trailing directives are used at the end of DAG node argument lists to + // specify additional behaviour for op matchers and creators, etc. + struct TrailingDirectives { + // DAG node containing the `location` directive. Null if there is none. + DagNode location; + + // DAG node containing the `returnType` directive. Null if there is none. 
+ DagNode returnType; + + // Number of trailing directives found.
+ int numDirectives; + }; + + // Collects trailing `location`/`returnType` directives; each at most once. + TrailingDirectives getTrailingDirectives(DagNode tree); + // Returns the location value to use. - std::pair getLocation(DagNode tree); + std::string getLocation(TrailingDirectives &tail); // Returns the location value to use. std::string handleLocationDirective(DagNode tree); + // Returns a C++ expression for the `i`-th type in a `returnType` DAG. + std::string handleReturnTypeArg(DagNode returnType, int i, int depth); + // Emits the C++ statement to build a new op out of the given DAG `tree` and // returns the variable name that this op is assigned to. If the root op in // DAG `tree` has a specified name, the created op will be assigned to a @@ -271,9 +290,10 @@ capture.push_back(std::move(argName)); } - bool hasLocationDirective; - std::string locToUse; - std::tie(hasLocationDirective, locToUse) = getLocation(tree); + auto tail = getTrailingDirectives(tree); + if (tail.returnType) + PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); + auto locToUse = getLocation(tail); auto fmt = tree.getNativeCodeTemplate(); if (fmt.count("$_self") != 1) @@ -286,14 +306,14 @@ emitMatchCheck(opName, formatv("!failed({0})", nativeCodeCall), formatv("\"{0} return failure\"", nativeCodeCall)); - for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { auto name = tree.getArgName(i); if (!name.empty() && name != "_") { os << formatv("{0} = {1};\n", name, capture[i]); } } - for (int i = 0, e = tree.getNumArgs(); i != e; ++i) { + for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { std::string argName = capture[i]; // Handle nested DAG construct first @@ -884,6 +904,24 @@ return os.str(); } +std::string PatternEmitter::handleReturnTypeArg(DagNode returnType, int i, + int depth) { + // Nested DAG: must be a NativeCodeCall that produces the type.
+ if (auto dagNode = returnType.getArgAsNestedDag(i)) { + if (!dagNode.isNativeCodeCall()) + PrintFatalError(loc, "nested DAG in `returnType` must be a native code " + "call"); + return handleReplaceWithNativeCodeCall(dagNode, depth); + } + // Leaf: a C++ type expression string, or a value whose type is copied. + auto dagLeaf = returnType.getArgAsLeaf(i); + if (dagLeaf.isStringAttr()) + return tgfmt(dagLeaf.getStringAttr(), &fmtCtx); + return tgfmt( + "$0.getType()", &fmtCtx, + handleOpArgument(returnType.getArgAsLeaf(i), returnType.getArgName(i))); +} + std::string PatternEmitter::handleOpArgument(DagLeaf leaf, StringRef patArgName) { if (leaf.isStringAttr()) @@ -929,11 +967,12 @@ SmallVector attrs; - bool hasLocationDirective; - std::string locToUse; - std::tie(hasLocationDirective, locToUse) = getLocation(tree); + auto tail = getTrailingDirectives(tree); + if (tail.returnType) + PrintFatalError(loc, "`NativeCodeCall` cannot have return type specifier"); + auto locToUse = getLocation(tail); - for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { + for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { if (tree.isNestedDagArg(i)) { attrs.push_back( handleResultPattern(tree.getArgAsNestedDag(i), i, depth + 1)); @@ -1002,18 +1041,49 @@ return 1; } -std::pair PatternEmitter::getLocation(DagNode tree) { - auto numPatArgs = tree.getNumArgs(); +PatternEmitter::TrailingDirectives +PatternEmitter::getTrailingDirectives(DagNode tree) { + TrailingDirectives tail = {DagNode(nullptr), DagNode(nullptr), 0}; - if (numPatArgs != 0) { - if (auto lastArg = tree.getArgAsNestedDag(numPatArgs - 1)) - if (lastArg.isLocationDirective()) { - return std::make_pair(true, handleLocationDirective(lastArg)); - } + // Look backwards through the arguments. + auto numPatArgs = tree.getNumArgs(); + for (int i = numPatArgs - 1; i >= 0; --i) { + auto dagArg = tree.getArgAsNestedDag(i); + // A leaf is not a directive. Stop looking.
+ if (!dagArg) + break; + + auto isLocation = dagArg.isLocationDirective(); + auto isReturnType = dagArg.isReturnTypeDirective(); + // Any other DAG node is not a trailing directive. Stop looking. + if (!(isLocation || isReturnType)) + break; + // Save the directive, but error if one of the same type was already + // found. + ++tail.numDirectives; + if (isLocation) { + if (tail.location) + PrintFatalError(loc, "`location` directive can only be specified " + "once"); + tail.location = dagArg; + } else if (isReturnType) { + if (tail.returnType) + PrintFatalError(loc, "`returnType` directive can only be specified " + "once"); + tail.returnType = dagArg; + } } + return tail; +} + +std::string +PatternEmitter::getLocation(PatternEmitter::TrailingDirectives &tail) { + if (tail.location) + return handleLocationDirective(tail.location); + // If no explicit location is given, use the default, all fused, location. - return std::make_pair(false, "odsLoc"); + return "odsLoc"; } std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, @@ -1026,11 +1096,10 @@ auto numOpArgs = resultOp.getNumArgs(); auto numPatArgs = tree.getNumArgs(); - bool hasLocationDirective; - std::string locToUse; - std::tie(hasLocationDirective, locToUse) = getLocation(tree); + auto tail = getTrailingDirectives(tree); + auto locToUse = getLocation(tail); - auto inPattern = numPatArgs - hasLocationDirective; + auto inPattern = numPatArgs - tail.numDirectives; if (numOpArgs != inPattern) { PrintFatalError(loc, formatv("resultant op '{0}' argument number mismatch: " @@ -1045,7 +1114,7 @@ // First go through all the child nodes who are nested DAG constructs to // create ops for them and remember the symbol names for them, so that we can // use the results in the current node. This happens in a recursive manner.
- for (int i = 0, e = tree.getNumArgs() - hasLocationDirective; i != e; ++i) { + for (int i = 0, e = tree.getNumArgs() - tail.numDirectives; i != e; ++i) { if (auto child = tree.getArgAsNestedDag(i)) childNodeNames[i] = handleResultPattern(child, i, depth + 1); } @@ -1080,7 +1149,7 @@ bool useFirstAttr = resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); - if (isSameOperandsAndResultType || useFirstAttr) { + if (!tail.returnType && (isSameOperandsAndResultType || useFirstAttr)) { // We know how to deduce the result type for ops with these traits and we've // generated builders taking aggregate parameters. Use those builders to // create the ops. @@ -1097,7 +1166,7 @@ bool usePartialResults = valuePackName != resultValue; - if (usePartialResults || depth > 0 || resultIndex < 0) { + if (!tail.returnType && (usePartialResults || depth > 0 || resultIndex < 0)) { // For these cases (broadcastable ops, op results used both as auxiliary // values and replacement values, ops in nested patterns, auxiliary ops), we // still need to supply the result types when building the op. But because @@ -1115,10 +1184,14 @@ return resultValue; } - // If depth == 0 and resultIndex >= 0, it means we are replacing the values - // generated from the source pattern root op. Then we can use the source - // pattern's value types to determine the value type of the generated op - // here. + // If explicit return types are provided, use them to build the op. + // However, if depth == 0 and resultIndex >= 0, it means we are replacing + // the values generated from the source pattern root op. Then we must use the + // source pattern's value types to determine the value type of the generated + // op here. + if (depth == 0 && resultIndex >= 0 && tail.returnType) + PrintFatalError(loc, "Cannot specify explicit return types in an op whose " + "return values replace the source pattern's root op"); // First prepare local variables for op arguments used in builder call. 
createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); @@ -1128,11 +1201,20 @@ os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " "(void)tblgen_types;\n"); int numResults = resultOp.getNumResults(); - if (numResults != 0) { - for (int i = 0; i < numResults; ++i) - os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" - " tblgen_types.push_back(v.getType());\n}\n", - resultIndex + i); + if (tail.returnType) { + auto numRetTys = tail.returnType.getNumArgs(); + for (int i = 0; i < numRetTys; ++i) { + auto varName = handleReturnTypeArg(tail.returnType, i, depth + 1); + os << "tblgen_types.push_back(" << varName << ");\n"; + } + } else { + if (numResults != 0) { + // Copy the result types from the source pattern. + for (int i = 0; i < numResults; ++i) + os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" + " tblgen_types.push_back(v.getType());\n}\n", + resultIndex + i); + } } os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " "tblgen_values, tblgen_attrs);\n",