Index: mlir/docs/Canonicalization.md =================================================================== --- mlir/docs/Canonicalization.md +++ mlir/docs/Canonicalization.md @@ -156,7 +156,7 @@ /// of the operation. The caller will remove the operation and use that /// result instead. /// -OpFoldResult MyOp::fold(ArrayRef<Attribute> operands) { +OpFoldResult MyOp::fold(FoldAdaptor adaptor) { ... } ``` @@ -178,19 +178,19 @@ /// the operation and use those results instead. /// /// Note that this mechanism cannot be used to remove 0-result operations. -LogicalResult MyOp::fold(ArrayRef<Attribute> operands, +LogicalResult MyOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { ... } ``` -In the above, for each method an `ArrayRef<Attribute>` is provided that -corresponds to the constant attribute value of each of the operands. These +In the above, for each method a `FoldAdaptor` is provided with a getter for +each operand that returns the operand's constant attribute value. These operands are those that implement the `ConstantLike` trait. If any of the operands are non-constant, a null `Attribute` value is provided instead. For example, if MyOp provides three operands [`a`, `b`, `c`], but only `b` is -constant then `operands` will be of the form [Attribute(), b-value, -Attribute()]. +constant then `adaptor.getA()` and `adaptor.getC()` will return `Attribute()`, +while `adaptor.getB()` will return b-value. Also above, is the use of `OpFoldResult`.
This class represents the possible result of folding an operation result: either an SSA `Value`, or an Index: mlir/docs/DefiningDialects/_index.md =================================================================== --- mlir/docs/DefiningDialects/_index.md +++ mlir/docs/DefiningDialects/_index.md @@ -255,6 +255,31 @@ unsigned argIndex, NamedAttribute attribute); ``` +#### `useFoldAPI` + +This field selects the signature of the `fold` methods generated for the +dialect's operations. Two values are currently allowed: +* `kEmitFoldAdaptorFolder` generates a `fold` method making use of the op's + `FoldAdaptor` to allow access to operands via convenient getters. + + Generated code example: + ```cpp + OpFoldResult fold(FoldAdaptor adaptor); + // or + LogicalResult fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult>& results); + ``` +* `kEmitRawAttributesFolder` generates the deprecated legacy `fold` + method, containing `ArrayRef<Attribute>` in the parameter list instead of + the op's `FoldAdaptor`. This API is scheduled for removal and should not be + used by new dialects. + + Generated code example: + ```cpp + OpFoldResult fold(ArrayRef<Attribute> operands); + // or + LogicalResult fold(ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult>& results); + ``` + ### Operation Interface Fallback Some dialects have an open ecosystem and don't register all of the possible operations. In such Index: mlir/docs/Tutorials/Toy/Ch-7.md =================================================================== --- mlir/docs/Tutorials/Toy/Ch-7.md +++ mlir/docs/Tutorials/Toy/Ch-7.md @@ -458,16 +458,16 @@ ```c++ /// Fold constants. -OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { return value(); } +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return value(); } /// Fold struct constants. -OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) { +OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return value(); } /// Fold simple struct access operations that access into a constant.
-OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) { - auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>(); +OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { + auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>(); if (!structAttr) return nullptr; Index: mlir/examples/toy/Ch7/include/toy/Ops.td =================================================================== --- mlir/examples/toy/Ch7/include/toy/Ops.td +++ mlir/examples/toy/Ch7/include/toy/Ops.td @@ -33,6 +33,10 @@ // We set this bit to generate the declarations for the dialect's type parsing // and printing hooks. let useDefaultTypePrinterParser = 1; + + // We set this to generate 'fold' methods that take the op's FoldAdaptor + // instead of the legacy 'ArrayRef<Attribute>' operands. + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for toy dialect operations. This operation inherits from the base Index: mlir/examples/toy/Ch7/mlir/ToyCombine.cpp =================================================================== --- mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -24,18 +24,14 @@ } // namespace /// Fold constants. -OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) { - return getValue(); -} +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } /// Fold struct constants. -OpFoldResult StructConstantOp::fold(ArrayRef<Attribute> operands) { - return getValue(); -} +OpFoldResult StructConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } /// Fold simple struct access operations that access into a constant.
-OpFoldResult StructAccessOp::fold(ArrayRef<Attribute> operands) { - auto structAttr = operands.front().dyn_cast_or_null<mlir::ArrayAttr>(); +OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { + auto structAttr = adaptor.getInput().dyn_cast_or_null<mlir::ArrayAttr>(); if (!structAttr) return nullptr; Index: mlir/include/mlir/IR/DialectBase.td =================================================================== --- mlir/include/mlir/IR/DialectBase.td +++ mlir/include/mlir/IR/DialectBase.td @@ -17,6 +17,14 @@ // Dialect definitions //===----------------------------------------------------------------------===// +// Generate 'fold' method with 'ArrayRef<Attribute>' parameter. +// New code should prefer using 'kEmitFoldAdaptorFolder' and +// consider 'kEmitRawAttributesFolder' deprecated and to be +// removed in the future. +defvar kEmitRawAttributesFolder = 0; +// Generate 'fold' method with 'FoldAdaptor' parameter. +defvar kEmitFoldAdaptorFolder = 1; + class Dialect { // The name of the dialect. string name = ?; @@ -85,6 +93,9 @@ // If this dialect can be extended at runtime with new operations or types. bit isExtensible = 0; + + // Fold API used by this dialect's operations: one of the kEmit*Folder values. + int useFoldAPI = kEmitRawAttributesFolder; } #endif // DIALECTBASE_TD Index: mlir/include/mlir/IR/OpDefinition.h =================================================================== --- mlir/include/mlir/IR/OpDefinition.h +++ mlir/include/mlir/IR/OpDefinition.h @@ -1687,18 +1687,35 @@ private: /// Trait to check if T provides a 'fold' method for a single result op. template - using has_single_result_fold = + using has_single_result_fold_t = decltype(std::declval().fold(std::declval>())); template - using detect_has_single_result_fold = - llvm::is_detected; + constexpr static bool has_single_result_fold_v = + llvm::is_detected::value; /// Trait to check if T provides a general 'fold' method.
template - using has_fold = decltype(std::declval().fold( + using has_fold_t = decltype(std::declval().fold( std::declval>(), std::declval &>())); template - using detect_has_fold = llvm::is_detected; + constexpr static bool has_fold_v = llvm::is_detected::value; + /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a + /// single result op. + template + using has_fold_adaptor_single_result_fold_t = + decltype(std::declval().fold(std::declval())); + template + constexpr static bool has_fold_adaptor_single_result_v = + llvm::is_detected::value; + /// Trait to check if T provides a general 'fold' method with a FoldAdaptor. + template + using has_fold_adaptor_fold_t = decltype(std::declval().fold( + std::declval(), + std::declval &>())); + template + constexpr static bool has_fold_adaptor_v = + llvm::is_detected::value; + /// Trait to check if T provides a 'print' method. template using has_print = @@ -1747,13 +1764,14 @@ // If the operation is single result and defines a `fold` method. if constexpr (llvm::is_one_of, Traits...>::value && - detect_has_single_result_fold::value) + (has_single_result_fold_v || + has_fold_adaptor_single_result_v)) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldSingleResultHook(op, operands, results); }; // The operation is not single result and defines a `fold` method.
- if constexpr (detect_has_fold::value) + if constexpr (has_fold_v || has_fold_adaptor_v) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldHook(op, operands, results); @@ -1772,7 +1790,12 @@ static LogicalResult foldSingleResultHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - OpFoldResult result = cast(op).fold(operands); + OpFoldResult result; + if constexpr (has_fold_adaptor_single_result_v) + result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( + operands, op->getAttrDictionary(), op->getRegions())); + else + result = cast(op).fold(operands); // If the fold failed or was in-place, try to fold the traits of the // operation. @@ -1789,7 +1812,14 @@ template static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - LogicalResult result = cast(op).fold(operands, results); + auto result = LogicalResult::failure(); + if constexpr (has_fold_adaptor_v) + result = cast(op).fold( + typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), + op->getRegions()), + results); + else + result = cast(op).fold(operands, results); // If the fold failed or was in-place, try to fold the traits of the // operation. Index: mlir/include/mlir/TableGen/Dialect.h =================================================================== --- mlir/include/mlir/TableGen/Dialect.h +++ mlir/include/mlir/TableGen/Dialect.h @@ -86,6 +86,15 @@ /// operations or types. bool isExtensible() const; + enum class FolderAPI { + RawAttributes = 0, ///< fold method with ArrayRef<Attribute>. + FolderAdaptor = 1, ///< fold method with the operation's FoldAdaptor. + }; + + /// Returns the folder API that should be emitted for operations in this + /// dialect. + FolderAPI getFolderAPI() const; + // Returns whether two dialects are equal by checking the equality of the // underlying record.
bool operator==(const Dialect &other) const; Index: mlir/include/mlir/TableGen/Operator.h =================================================================== --- mlir/include/mlir/TableGen/Operator.h +++ mlir/include/mlir/TableGen/Operator.h @@ -314,6 +314,9 @@ /// Returns the remove name for the accessor of `name`. std::string getRemoverName(StringRef name) const; + /// Returns whether the operation has the 'hasFolder' bit set. + bool hasFolder() const; + private: /// Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); Index: mlir/lib/TableGen/Dialect.cpp =================================================================== --- mlir/lib/TableGen/Dialect.cpp +++ mlir/lib/TableGen/Dialect.cpp @@ -102,6 +102,15 @@ return def->getValueAsBit("isExtensible"); } +Dialect::FolderAPI Dialect::getFolderAPI() const { + int64_t value = def->getValueAsInt("useFoldAPI"); + if (value < static_cast<int64_t>(FolderAPI::RawAttributes) || + value > static_cast<int64_t>(FolderAPI::FolderAdaptor)) + llvm::PrintFatalError(def->getLoc(), "Invalid fold api value"); + + return static_cast<FolderAPI>(value); +} + bool Dialect::operator==(const Dialect &other) const { return def == other.def; } Index: mlir/lib/TableGen/Operator.cpp =================================================================== --- mlir/lib/TableGen/Operator.cpp +++ mlir/lib/TableGen/Operator.cpp @@ -745,3 +745,7 @@ std::string Operator::getRemoverName(StringRef name) const { return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); } + +bool Operator::hasFolder() const { + return def.getValueAsBit("hasFolder"); +} Index: mlir/test/IR/test-fold-adaptor.mlir =================================================================== --- /dev/null +++ mlir/test/IR/test-fold-adaptor.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s + +func.func @test() -> i32 { + %c5 = "test.constant"() {value = 5 : i32} : () -> i32 + %c1 = "test.constant"() {value = 1 : i32} : () -> i32 + %c2 = "test.constant"() {value = 2 : i32}
: () -> i32 %c3 = "test.constant"() {value = 3 : i32} : () -> i32 + %res = test.fold_with_fold_adaptor %c5, [ %c1, %c2], { (%c3), (%c3) } { + %c0 = "test.constant"() {value = 0 : i32} : () -> i32 + } + return %res : i32 +} + +// CHECK-LABEL: func.func @test +// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32} +// CHECK-NEXT: return %[[C]] Index: mlir/test/lib/Dialect/Test/TestDialect.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.cpp +++ mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -33,6 +33,8 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include <iterator> + // Include this before the using namespace lines below to // test that we don't have namespace dependencies. #include "TestOpsDialect.cpp.inc" @@ -1126,6 +1128,27 @@ return getOperand(); } +/// Folds to 1*op + 2*sum(variadic) + 3*sum(var_of_var) over the constant +/// operands, plus 4 per block in `body`. +OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { + int64_t sum = 0; + if (auto value = dyn_cast_or_null<IntegerAttr>(adaptor.getOp())) + sum += value.getValue().getSExtValue(); + + for (Attribute attr : adaptor.getVariadic()) + if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) + sum += 2 * value.getValue().getSExtValue(); + + for (ArrayRef<Attribute> attrs : adaptor.getVarOfVar()) + for (Attribute attr : attrs) + if (auto value = dyn_cast_or_null<IntegerAttr>(attr)) + sum += 3 * value.getValue().getSExtValue(); + + sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); + + return IntegerAttr::get(getType(), sum); +} + LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, std::optional<Location> location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1297,6 +1297,31 @@ }]; } +def TestOpFoldWithFoldAdaptor + : TEST_Op<"fold_with_fold_adaptor", + [AttrSizedOperandSegments, NoTerminator]> { + let
arguments = (ins + I32:$op, + DenseI32ArrayAttr:$attr, + Variadic<I32>:$variadic, + VariadicOfVariadic<I32, "attr">:$var_of_var + ); + + let results = (outs I32:$res); + + let regions = (region AnyRegion:$body); + + let assemblyFormat = [{ + $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword + }]; + + let hasFolder = 0; + + let extraClassDeclaration = [{ + ::mlir::OpFoldResult fold(FoldAdaptor adaptor); + }]; +} + // An op that always fold itself. def TestPassthroughFold : TEST_Op<"passthrough_fold"> { let arguments = (ins AnyType:$op); Index: mlir/test/mlir-tblgen/has-fold-invalid-values.td =================================================================== --- /dev/null +++ mlir/test/mlir-tblgen/has-fold-invalid-values.td @@ -0,0 +1,15 @@ +// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "NS"; + let useFoldAPI = 3; +} + +def InvalidValue_Op : Op { + let hasFolder = 1; +} + +// CHECK: Invalid fold api value Index: mlir/test/mlir-tblgen/op-decl-and-defs.td =================================================================== --- mlir/test/mlir-tblgen/op-decl-and-defs.td +++ mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -317,6 +317,29 @@ // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +def TestWithNewFold_Dialect : Dialect { + let name = "test"; + let cppNamespace = "::mlir::testWithFold"; + let useFoldAPI = kEmitFoldAdaptorFolder; +} + +def NS_MOp : Op { + let results = (outs AnyType:$res); + + let hasFolder = 1; +} + +// CHECK-LABEL: class MOp : +// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor
adaptor); + +def NS_NOp : Op { + let results = (outs AnyType:$res1, AnyType:$res2); + + let hasFolder = 1; +} + +// CHECK-LABEL: class NOp : +// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results); // Test that type defs have the proper namespaces when used as a constraint. // --- Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2323,25 +2323,30 @@ } void OpEmitter::genFolderDecls() { + if (!op.hasFolder()) + return; + + Dialect::FolderAPI folderApi = op.getDialect().getFolderAPI(); + SmallVector<MethodParameter> paramList; + // The dialect selects whether 'fold' takes raw attributes or a FoldAdaptor. + if (folderApi == Dialect::FolderAPI::RawAttributes) + paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); + else + paramList.emplace_back("FoldAdaptor", "adaptor"); + + StringRef retType; bool hasSingleResult = op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; - - if (def.getValueAsBit("hasFolder")) { - if (hasSingleResult) { - auto *m = opClass.declareMethod( - "::mlir::OpFoldResult", "fold", - MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands")); - ERROR_IF_PRUNED(m, "operands", op); - } else { - SmallVector<MethodParameter> paramList; - paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); - paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", - "results"); - auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold", - std::move(paramList)); - ERROR_IF_PRUNED(m, "fold", op); - } + if (hasSingleResult) { + retType = "::mlir::OpFoldResult"; + } else { + paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", + "results"); + retType = "::mlir::LogicalResult"; } + + auto *m = opClass.declareMethod(retType, "fold", std::move(paramList)); + ERROR_IF_PRUNED(m, "fold", op); } void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait
*opTrait) {