diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2599,6 +2599,47 @@ def ConstantLikeMatcher : NativeCodeCall<"success(matchPattern($_self->getResult(0), m_Constant(&$0)))">; +// Create an op in the replacement pattern with explicitly specified return +// types or type builders. This allows ops to be created without relying on +// type inference with OpTraits or an op builder with type deduction. +// +// ``` +// def : Pat<(SourceOp $val), +// (NewOp (OpWithResultType $val) +// )>; +// ``` +// +// Specify multiple explicit return types with `OpWithResultTypes`. +// +// ## Type Builders with NativeCodeCall +// +// By passing native code calls instead of a string, you can manipulate and/or +// create types using bound values in the pattern. +// +// ``` +// def SameTypeAs : NativeCodeCall<"$0.getType()">; +// +// def : Pat<(SourceOp $val), +// (NewOp (OpWithResultTypeBuilder $val) +// )>; +// ``` + +class OpWithResultTypeBuilders types> { + Op operation = op; + list typeBuilders = types; +} + +class OpWithResultTypeBuilder : + OpWithResultTypeBuilders; + +class OpWithResultTypes types> : + OpWithResultTypeBuilders))>; + +class OpWithResultType : + OpWithResultTypes; + //===----------------------------------------------------------------------===// // Rewrite directives //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/TableGen/Pattern.h b/mlir/include/mlir/TableGen/Pattern.h --- a/mlir/include/mlir/TableGen/Pattern.h +++ b/mlir/include/mlir/TableGen/Pattern.h @@ -147,6 +147,9 @@ // Returns the operator wrapper object corresponding to the dialect op matched // by this DAG. The operator wrapper will be queried from the given `mapper` // and created in it if not existing. + // + // If this DAG node is an OpWithResultTypeBuilders, return the wrapped + // operator object. 
Operator &getDialectOp(RecordOperatorMap *mapper) const; // Returns the number of operations recursively involved in the DAG tree @@ -179,7 +182,11 @@ // Returns true if this DAG node is wrapping native code call. bool isNativeCodeCall() const; - // Returns true if this DAG node is an operation. + // Returns true if this DAG node is an operation with a result type builder. + bool isOpWithResultTypeBuilders() const; + + // Returns true if this DAG node is an operation or an operation wrapped in an + // OpWithResultTypeBuilders object. bool isOperation() const; // Returns the native code call template inside this DAG node. @@ -191,6 +198,11 @@ // Precondition: isNativeCodeCall() int getNumReturnsOfNativeCode() const; + // Return the explicit type builders for the op in this DAG node. The type + // builders are native code calls that may have arguments. + // Precondition: isOpWithResultTypeBuilders() + llvm::SmallVector getOpResultTypeBuilders() const; + void print(raw_ostream &os) const; private: diff --git a/mlir/lib/TableGen/Pattern.cpp b/mlir/lib/TableGen/Pattern.cpp --- a/mlir/lib/TableGen/Pattern.cpp +++ b/mlir/lib/TableGen/Pattern.cpp @@ -54,9 +54,7 @@ return isSubClassOf("EnumAttrCaseInfo"); } -bool DagLeaf::isStringAttr() const { - return isa(def); -} +bool DagLeaf::isStringAttr() const { return isa(def); } Constraint DagLeaf::getAsConstraint() const { assert((isOperandMatcher() || isAttrMatcher()) && @@ -113,8 +111,18 @@ return false; } +bool DagNode::isOpWithResultTypeBuilders() const { + if (auto *defInit = dyn_cast_or_null(node->getOperator())) + return defInit->getDef()->isSubClassOf("OpWithResultTypeBuilders"); + return false; +} + bool DagNode::isOperation() const { - return !isNativeCodeCall() && !isReplaceWithValue() && !isLocationDirective(); + // Treat OpWithResultTypeBuilders as an Operator so that it inherits all the + // logic of using Operators inside patterns. 
+ return isOpWithResultTypeBuilders() || + (!isNativeCodeCall() && !isReplaceWithValue() && + !isLocationDirective()); } llvm::StringRef DagNode::getNativeCodeTemplate() const { @@ -135,6 +143,11 @@ Operator &DagNode::getDialectOp(RecordOperatorMap *mapper) const { llvm::Record *opDef = cast(node->getOperator())->getDef(); + + // Retrieve the wrapped operator object from an OpWithResultTypeBuilders + if (isOpWithResultTypeBuilders()) + opDef = opDef->getValueAsDef("operation"); + auto it = mapper->find(opDef); if (it != mapper->end()) return *it->second; @@ -142,6 +155,18 @@ .first->second; } +SmallVector DagNode::getOpResultTypeBuilders() const { + assert(isOpWithResultTypeBuilders() && + "the DAG node must be OpWithResultTypeBuilders"); + llvm::ListInit *builderDefs = cast(node->getOperator()) + ->getDef() + ->getValueAsListInit("typeBuilders"); + SmallVector builders; + for (auto *builderDef : *builderDefs) + builders.push_back(DagNode(cast(builderDef))); + return builders; +} + int DagNode::getNumOps() const { int count = isReplaceWithValue() ? 0 : 1; for (int i = 0, e = getNumArgs(); i != e; ++i) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1317,6 +1317,54 @@ (location "named")), (location "fused", $res2, $res3))>; +//===----------------------------------------------------------------------===// +// Test Patterns (Type Builders) + +def SourceOp : TEST_Op<"source_op"> { + let arguments = (ins AnyInteger:$arg, AnyI32Attr:$tag); + let results = (outs AnyInteger); +} + +def OpWithoutTypeDeduction : TEST_Op<"op_without_type_deduction"> { + let arguments = (ins AnyInteger:$input); + let results = (outs AnyInteger); +} + +// Test that ops without built-in type deduction can be created in the +// replacement DAG with an explicitly specified type. 
+def : Pat<(SourceOp $val, ConstantAttr:$attr), + (OpWithoutTypeDeduction + (OpWithResultType $val))>; + +// Test NativeCodeCall type builder can accept arguments. +def SameTypeAs : NativeCodeCall<"$0.getType()">; + +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (OpWithoutTypeDeduction + (OpWithResultTypeBuilder $val))>; + +// Test multiple return types. +def MakeI64Type : NativeCodeCall<"$_builder.getI64Type()">; +def MakeI32Type : NativeCodeCall<"$_builder.getI32Type()">; + +def OneToTwo : TEST_Op<"one_to_two"> { + let arguments = (ins AnyInteger); + let results = (outs AnyInteger, AnyInteger); +} + +def TwoToOne : TEST_Op<"two_to_one"> { + let arguments = (ins AnyInteger, AnyInteger); + let results = (outs AnyInteger); +} + +def : Pat<(SourceOp $val, ConstantAttr:$attr), + (TwoToOne + (OpWithResultTypeBuilder + (OpWithResultTypeBuilders:$res__0 $val)), + (OpWithResultTypeBuilder $res__1))>; + //===----------------------------------------------------------------------===// // Test Legalization //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -530,3 +530,30 @@ // CHECK: "test.op_m"(%arg0) {optional_attr = 314159265 : i32} : (i32) -> i32 return %0 : i32 } + +//===----------------------------------------------------------------------===// +// Test that ops without type deduction can be created with type builders. 
+//===----------------------------------------------------------------------===// + +func @explicitReturnTypeTest(%arg0 : i64) -> i8 { + %0 = "test.source_op"(%arg0) {tag = 11 : i32} : (i64) -> i8 + // CHECK: "test.op_without_type_deduction"(%arg0) : (i64) -> i32 + // CHECK: "test.op_without_type_deduction"(%0) : (i32) -> i8 + return %0 : i8 +} + +func @returnTypeBuilderTest(%arg0 : i1) -> i8 { + %0 = "test.source_op"(%arg0) {tag = 22 : i32} : (i1) -> i8 + // CHECK: "test.op_without_type_deduction"(%arg0) : (i1) -> i1 + // CHECK: "test.op_without_type_deduction"(%0) : (i1) -> i8 + return %0 : i8 +} + +func @multipleReturnTypeBuildTest(%arg0 : i1) -> i1 { + %0 = "test.source_op"(%arg0) {tag = 33 : i32} : (i1) -> i1 + // CHECK: "test.one_to_two"(%arg0) : (i1) -> (i64, i32) + // CHECK: "test.op_without_type_deduction"(%0#0) : (i64) -> i32 + // CHECK: "test.op_without_type_deduction"(%0#1) : (i32) -> i64 + // CHECK: "test.two_to_one"(%1, %2) : (i32, i64) -> i1 + return %0 : i1 +} diff --git a/mlir/test/mlir-tblgen/rewriter-errors.td b/mlir/test/mlir-tblgen/rewriter-errors.td --- a/mlir/test/mlir-tblgen/rewriter-errors.td +++ b/mlir/test/mlir-tblgen/rewriter-errors.td @@ -1,6 +1,7 @@ // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR1 %s 2>&1 | FileCheck --check-prefix=ERROR1 %s // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR2 %s 2>&1 | FileCheck --check-prefix=ERROR2 %s // RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR3 %s 2>&1 | FileCheck --check-prefix=ERROR3 %s +// RUN: not mlir-tblgen -gen-rewriters -I %S/../../include -DERROR4 %s 2>&1 | FileCheck --check-prefix=ERROR4 %s include "mlir/IR/OpBase.td" @@ -35,3 +36,10 @@ def : Pat<(OpA (NativeMatcher (OpB $val, $unused)), AnyI32Attr:$arg), (OpB $val, $arg)>; #endif + +#ifdef ERROR4 +// Check trying to pass op as DAG node inside OpWithResultTypeBuilder fails. 
+class OpAWithTypeBuilder : OpWithResultTypeBuilder; +// ERROR4: [[@LINE+1]]:1: error: expected NativeCodeCall as type builder +def : Pat<(OpB $val, AnyI32Attr:$attr), (OpA (OpAWithTypeBuilder<(OpA $val, $val)> $val, $val), $val)>; +#endif diff --git a/mlir/test/mlir-tblgen/rewriter-indexing.td b/mlir/test/mlir-tblgen/rewriter-indexing.td --- a/mlir/test/mlir-tblgen/rewriter-indexing.td +++ b/mlir/test/mlir-tblgen/rewriter-indexing.td @@ -90,3 +90,11 @@ // CHECK: foo(rewriter, (*v4.begin()), (*v5.begin()), (*v6.begin()), (*v7.begin()), (*v8.begin()), (*v9.begin()), (*v10.begin())) def test5 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), (NativeCodeCall<[{ foo($_builder, $3...) }]> $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10)>; + +// Check Pattern with return type builder. +def SameTypeAs : NativeCodeCall<"$0.getType()">; +// CHECK: struct test6 : public ::mlir::RewritePattern { +// CHECK: (*v2.begin()).getType() +// CHECK: tblgen_types.push_back(nativeVar_1); +def test6 : Pat<(DOp $v1, $v2, $v3, $v4, $v5, $v6, $v7, $v8, $v9, $v10), + (AOp (OpWithResultTypeBuilder $v1))>; diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1080,7 +1080,8 @@ bool useFirstAttr = resultOp.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"); - if (isSameOperandsAndResultType || useFirstAttr) { + if (!tree.isOpWithResultTypeBuilders() && + (isSameOperandsAndResultType || useFirstAttr)) { // We know how to deduce the result type for ops with these traits and we've // generated builders taking aggregate parameters. Use those builders to // create the ops. 
@@ -1097,7 +1098,8 @@ bool usePartialResults = valuePackName != resultValue; - if (usePartialResults || depth > 0 || resultIndex < 0) { + if (!tree.isOpWithResultTypeBuilders() && + (usePartialResults || depth > 0 || resultIndex < 0)) { // For these cases (broadcastable ops, op results used both as auxiliary // values and replacement values, ops in nested patterns, auxiliary ops), we // still need to supply the result types when building the op. But because @@ -1115,10 +1117,11 @@ return resultValue; } - // If depth == 0 and resultIndex >= 0, it means we are replacing the values - // generated from the source pattern root op. Then we can use the source - // pattern's value types to determine the value type of the generated op - // here. + // If we are provided explicit return types, use them to build the op. + // Otherwise, if depth == 0 and resultIndex >= 0, it means we are replacing + // the values generated from the source pattern root op. Then we can use the + // source pattern's value types to determine the value type of the generated + // op here. // First prepare local variables for op arguments used in builder call. createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); @@ -1128,11 +1131,24 @@ os.indent() << formatv("::mlir::SmallVector<::mlir::Type, 4> tblgen_types; " "(void)tblgen_types;\n"); int numResults = resultOp.getNumResults(); - if (numResults != 0) { - for (int i = 0; i < numResults; ++i) - os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" - " tblgen_types.push_back(v.getType());\n}\n", - resultIndex + i); + if (tree.isOpWithResultTypeBuilders()) { + // Use the explicitly specified types. 
+ for (auto &builder : tree.getOpResultTypeBuilders()) { + if (!builder.isNativeCodeCall()) + PrintFatalError(loc, "expected NativeCodeCall as type builder in " + "DAG node OpWithResultTypeBuilders"); + + auto typeVar = handleReplaceWithNativeCodeCall(builder, depth + 1); + os << formatv("tblgen_types.push_back({0});\n", typeVar); + } + } else { + if (numResults != 0) { + // Copy the result types from the source pattern. + for (int i = 0; i < numResults; ++i) + os << formatv("for (auto v: castedOp0.getODSResults({0})) {{\n" + " tblgen_types.push_back(v.getType());\n}\n", + resultIndex + i); + } } os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " "tblgen_values, tblgen_attrs);\n",