Index: mlir/docs/Canonicalization.md =================================================================== --- mlir/docs/Canonicalization.md +++ mlir/docs/Canonicalization.md @@ -119,13 +119,13 @@ "local" transformation, and can be invoked without the need for a pattern rewriter. -In [ODS](DefiningDialects/Operations.md), an operation can set the `hasFolder` bit to generate -a declaration for the `fold` method. This method takes on a different form, -depending on the structure of the operation. +In [ODS](DefiningDialects/Operations.md), an operation can set the `hasFolder` +field to `kEmitFoldAdaptorFolder` to generate a declaration for the `fold` method. +This method takes on a different form, depending on the structure of the operation. ```tablegen def MyOp : ... { - let hasFolder = 1; + let hasFolder = kEmitFoldAdaptorFolder; } ``` @@ -143,7 +143,7 @@ /// of the operation. The caller will remove the operation and use that /// result instead. /// -OpFoldResult MyOp::fold(ArrayRef<Attribute> operands) { +OpFoldResult MyOp::fold(FoldAdaptor adaptor) { ... } ``` @@ -165,19 +165,19 @@ /// the operation and use those results instead. /// /// Note that this mechanism cannot be used to remove 0-result operations. -LogicalResult MyOp::fold(ArrayRef<Attribute> operands, +LogicalResult MyOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { ... } ``` -In the above, for each method an `ArrayRef<Attribute>` is provided that -corresponds to the constant attribute value of each of the operands. These +In the above, for each method a `FoldAdaptor` is provided with getters for +each of the operands, returning the corresponding constant attribute. These operands are those that implement the `ConstantLike` trait. If any of the operands are non-constant, a null `Attribute` value is provided instead. For example, if MyOp provides three operands [`a`, `b`, `c`], but only `b` is -constant then `operands` will be of the form [Attribute(), b-value, -Attribute()]. 
+constant then `adaptor` will return Attribute() for `getA()` and `getC()`, +and b-value for `getB()`. Also above, is the use of `OpFoldResult`. This class represents the possible result of folding an operation result: either an SSA `Value`, or an Index: mlir/docs/DefiningDialects/Operations.md =================================================================== --- mlir/docs/DefiningDialects/Operations.md +++ mlir/docs/DefiningDialects/Operations.md @@ -1019,8 +1019,18 @@ ### `hasFolder` -This boolean field indicate whether general folding rules have been defined for -this operation. If it is `1`, then `::fold()` should be defined. +This int field indicates whether general folding rules have been defined for +this operation. + +There are currently 3 possible values that are allowed to be assigned to this +field: + +* `kEmitNoFolder` or `0`, the default, causes no `fold` method to be defined. +* `kEmitFoldAdaptorFolder` should be used to generate a `fold` method. Its + signature makes use of the op's `FoldAdaptor`. +* `kEmitRawAttributesFolder` or `1` generates the deprecated legacy `fold` + method, containing `ArrayRef<Attribute>` in the parameter list instead of + the op's `FoldAdaptor`. ### Extra declarations Index: mlir/docs/Tutorials/Toy/Ch-7.md =================================================================== --- mlir/docs/Tutorials/Toy/Ch-7.md +++ mlir/docs/Tutorials/Toy/Ch-7.md @@ -452,22 +452,22 @@ We have several `toy.struct_access` operations that access into a `toy.struct_constant`. As detailed in [chapter 3](Ch-3.md) (FoldConstantReshape), -we can add folders for these `toy` operations by setting the `hasFolder` bit -on the operation definition and providing a definition of the `*Op::fold` -method. +we can add folders for these `toy` operations by setting the `hasFolder` value +on the operation definition to `kEmitFoldAdaptorFolder` and providing a definition +of the `*Op::fold` method. ```c++ /// Fold constants.
-OpFoldResult ConstantOp::fold(ArrayRef operands) { return value(); } +OpFoldResult ConstantOp::fold(FoldAdaptor) { return value(); } /// Fold struct constants. -OpFoldResult StructConstantOp::fold(ArrayRef operands) { +OpFoldResult StructConstantOp::fold(FoldAdaptor) { return value(); } /// Fold simple struct access operations that access into a constant. -OpFoldResult StructAccessOp::fold(ArrayRef operands) { - auto structAttr = operands.front().dyn_cast_or_null(); +OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { + auto structAttr = adaptor.getInput().dyn_cast_or_null(); if (!structAttr) return nullptr; Index: mlir/examples/toy/Ch7/include/toy/Ops.td =================================================================== --- mlir/examples/toy/Ch7/include/toy/Ops.td +++ mlir/examples/toy/Ch7/include/toy/Ops.td @@ -107,7 +107,7 @@ let hasVerifier = 1; // Set the folder bit so that we can implement constant folders. - let hasFolder = 1; + let hasFolder = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// @@ -390,7 +390,7 @@ let hasVerifier = 1; // Set the folder bit so that we can fold constant accesses. - let hasFolder = 1; + let hasFolder = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// @@ -418,7 +418,7 @@ // Indicate that additional verification for this operation is necessary. let hasVerifier = 1; - let hasFolder = 1; + let hasFolder = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// Index: mlir/examples/toy/Ch7/mlir/ToyCombine.cpp =================================================================== --- mlir/examples/toy/Ch7/mlir/ToyCombine.cpp +++ mlir/examples/toy/Ch7/mlir/ToyCombine.cpp @@ -24,18 +24,14 @@ } // namespace /// Fold constants. 
-OpFoldResult ConstantOp::fold(ArrayRef operands) { - return getValue(); -} +OpFoldResult ConstantOp::fold(FoldAdaptor) { return getValue(); } /// Fold struct constants. -OpFoldResult StructConstantOp::fold(ArrayRef operands) { - return getValue(); -} +OpFoldResult StructConstantOp::fold(FoldAdaptor) { return getValue(); } /// Fold simple struct access operations that access into a constant. -OpFoldResult StructAccessOp::fold(ArrayRef operands) { - auto structAttr = operands.front().dyn_cast_or_null(); +OpFoldResult StructAccessOp::fold(FoldAdaptor adaptor) { + auto structAttr = adaptor.getInput().dyn_cast_or_null(); if (!structAttr) return nullptr; Index: mlir/include/mlir/IR/OpBase.td =================================================================== --- mlir/include/mlir/IR/OpBase.td +++ mlir/include/mlir/IR/OpBase.td @@ -2191,6 +2191,17 @@ list decorators = []> : OpVariable; +// "Enum" values for 'hasFolder' of 'Op'. +// Generate no 'fold' method (Default). +defvar kEmitNoFolder = 0; +// Generate 'fold' method with 'ArrayRef' parameter. +// New code should prefer using 'kEmitFoldAdaptorFolder' and +// consider 'kEmitRawAttributesFolder' deprecated and to be +// removed in the future. +defvar kEmitRawAttributesFolder = 1; +// Generate 'fold' method with 'FoldAdaptor' parameter. +defvar kEmitFoldAdaptorFolder = 2; + // Base class for all ops. class Op props = []> { // The dialect of the op. @@ -2291,7 +2302,7 @@ bit hasCanonicalizeMethod = 0; // Whether this op has a folder. - bit hasFolder = 0; + int hasFolder = kEmitNoFolder; // Op traits. // Note: The list of traits will be uniqued by ODS. 
Index: mlir/include/mlir/IR/OpDefinition.h =================================================================== --- mlir/include/mlir/IR/OpDefinition.h +++ mlir/include/mlir/IR/OpDefinition.h @@ -1699,6 +1699,23 @@ std::declval &>())); template using detect_has_fold = llvm::is_detected; + /// Trait to check if T provides a 'fold' method with a FoldAdaptor for a + /// single result op. + template + using has_fold_adaptor_single_result_fold = + decltype(std::declval().fold(std::declval())); + template + using detect_has_fold_adaptor_single_result_fold = + llvm::is_detected; + /// Trait to check if T provides a general 'fold' method with a FoldAdaptor. + template + using has_fold_adaptor_fold = decltype(std::declval().fold( + std::declval(), + std::declval &>())); + template + using detect_has_fold_adaptor_fold = + llvm::is_detected; + /// Trait to check if T provides a 'print' method. template using has_print = @@ -1747,13 +1764,17 @@ // If the operation is single result and defines a `fold` method. if constexpr (llvm::is_one_of, Traits...>::value && - detect_has_single_result_fold::value) + std::disjunction_v< + detect_has_single_result_fold, + detect_has_fold_adaptor_single_result_fold>) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldSingleResultHook(op, operands, results); }; // The operation is not single result and defines a `fold` method. 
- if constexpr (detect_has_fold::value) + if constexpr (std::disjunction_v< + detect_has_fold, + detect_has_fold_adaptor_fold>) return [](Operation *op, ArrayRef operands, SmallVectorImpl &results) { return foldHook(op, operands, results); @@ -1772,7 +1793,13 @@ static LogicalResult foldSingleResultHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - OpFoldResult result = cast(op).fold(operands); + OpFoldResult result; + if constexpr (detect_has_fold_adaptor_single_result_fold< + ConcreteOpT>::value) + result = cast(op).fold(typename ConcreteOpT::FoldAdaptor( + operands, op->getAttrDictionary(), op->getRegions())); + else + result = cast(op).fold(operands); // If the fold failed or was in-place, try to fold the traits of the // operation. @@ -1789,7 +1816,14 @@ template static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - LogicalResult result = cast(op).fold(operands, results); + auto result = LogicalResult::failure(); + if constexpr (detect_has_fold_adaptor_fold::value) + result = cast(op).fold( + typename ConcreteOpT::FoldAdaptor(operands, op->getAttrDictionary(), + op->getRegions()), + results); + else + result = cast(op).fold(operands, results); // If the fold failed or was in-place, try to fold the traits of the // operation. Index: mlir/include/mlir/TableGen/Operator.h =================================================================== --- mlir/include/mlir/TableGen/Operator.h +++ mlir/include/mlir/TableGen/Operator.h @@ -314,6 +314,15 @@ /// Returns the remove name for the accessor of `name`. std::string getRemoverName(StringRef name) const; + enum class FolderAPI { + None = 0, /// No fold method should be emitted. + RawAttributes = 1, /// fold method with ArrayRef. + FolderAdaptor = 2, /// fold method with the operation's FoldAdaptor. + }; + + /// Returns the folder API that should be emitted for this operation. 
+ FolderAPI getFolderAPI() const; + private: /// Populates the vectors containing operands, attributes, results and traits. void populateOpStructure(); Index: mlir/lib/TableGen/Operator.cpp =================================================================== --- mlir/lib/TableGen/Operator.cpp +++ mlir/lib/TableGen/Operator.cpp @@ -745,3 +745,12 @@ std::string Operator::getRemoverName(StringRef name) const { return "remove" + convertToCamelFromSnakeCase(name, /*capitalizeFirst=*/true); } + +Operator::FolderAPI Operator::getFolderAPI() const { + int64_t value = def.getValueAsInt("hasFolder"); + if (value < static_cast(FolderAPI::None) || + value > static_cast(FolderAPI::FolderAdaptor)) + llvm::PrintFatalError(def.getLoc(), "Invalid folder api value"); + + return static_cast(value); +} Index: mlir/test/IR/test-fold-adaptor.mlir =================================================================== --- /dev/null +++ mlir/test/IR/test-fold-adaptor.mlir @@ -0,0 +1,16 @@ +// RUN: mlir-opt %s -canonicalize -split-input-file | FileCheck %s + +func.func @test() -> i32 { + %c5 = arith.constant 5 : i32 + %c1 = arith.constant 1 : i32 + %c2 = arith.constant 2 : i32 + %c3 = arith.constant 3 : i32 + %res = test.fold_with_fold_adaptor %c5, [ %c1, %c2], { (%c3), (%c3) } { + %c0 = arith.constant 0 : i32 + } + return %res : i32 +} + +// CHECK-LABEL: func.func @test +// CHECK-NEXT: %[[C:.*]] = "test.constant"() {value = 33 : i32} +// CHECK-NEXT: return %[[C]] \ No newline at end of file Index: mlir/test/lib/Dialect/Test/TestDialect.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.cpp +++ mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -33,6 +33,8 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSwitch.h" +#include + // Include this before the using namespace lines below to // test that we don't have namespace dependencies. 
#include "TestOpsDialect.cpp.inc" @@ -1126,6 +1128,25 @@ return getOperand(); } +OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { + int64_t sum = 0; + if (auto value = dyn_cast_or_null(adaptor.getOp())) + sum += value.getValue().getSExtValue(); + + for (Attribute attr : adaptor.getVariadic()) + if (auto value = dyn_cast_or_null(attr)) + sum += 2 * value.getValue().getSExtValue(); + + for (ArrayRef attrs : adaptor.getVarOfVar()) + for (Attribute attr : attrs) + if (auto value = dyn_cast_or_null(attr)) + sum += 3 * value.getValue().getSExtValue(); + + sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); + + return IntegerAttr::get(getType(), sum); +} + LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( MLIRContext *, std::optional location, ValueRange operands, DictionaryAttr attributes, RegionRange regions, Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1297,6 +1297,27 @@ }]; } +def TestOpFoldWithFoldAdaptor + : TEST_Op<"fold_with_fold_adaptor", + [AttrSizedOperandSegments, NoTerminator]> { + let arguments = (ins + I32:$op, + DenseI32ArrayAttr:$attr, + Variadic:$variadic, + VariadicOfVariadic:$var_of_var + ); + + let results = (outs I32:$res); + + let regions = (region AnyRegion:$body); + + let hasFolder = kEmitFoldAdaptorFolder; + + let assemblyFormat = [{ + $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword + }]; +} + // An op that always fold itself. 
def TestPassthroughFold : TEST_Op<"passthrough_fold"> { let arguments = (ins AnyType:$op); Index: mlir/test/mlir-tblgen/has-fold-invalid-values.td =================================================================== --- /dev/null +++ mlir/test/mlir-tblgen/has-fold-invalid-values.td @@ -0,0 +1,14 @@ +// RUN: not mlir-tblgen -gen-op-decls -I %S/../../include %s 2>&1 | FileCheck %s + +include "mlir/IR/OpBase.td" + +def Test_Dialect : Dialect { + let name = "test"; + let cppNamespace = "NS"; +} + +def InvalidValue_Op : Op { + let hasFolder = 3; +} + +// CHECK: Invalid folder api value \ No newline at end of file Index: mlir/test/mlir-tblgen/op-decl-and-defs.td =================================================================== --- mlir/test/mlir-tblgen/op-decl-and-defs.td +++ mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -317,6 +317,23 @@ // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +def NS_MOp : NS_Op<"op_with_single_result_and_fold_adaptor_fold", []> { + let results = (outs AnyType:$res); + + let hasFolder = kEmitFoldAdaptorFolder; +} + +// CHECK-LABEL: class MOp : +// CHECK: ::mlir::OpFoldResult fold(FoldAdaptor adaptor); + +def NS_NOp : NS_Op<"op_with_multiple_results_and_fold_adaptor_fold", []> { + let results = (outs AnyType:$res1, AnyType:$res2); + + let hasFolder = kEmitFoldAdaptorFolder; +} + +// CHECK-LABEL: class NOp : +// CHECK: ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results); // Test that type defs have the proper namespaces when used as a constraint. 
// --- Index: mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp =================================================================== --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2324,25 +2324,28 @@ } void OpEmitter::genFolderDecls() { + if (op.getFolderAPI() == Operator::FolderAPI::None) + return; + + SmallVector paramList; + if (op.getFolderAPI() == Operator::FolderAPI::RawAttributes) + paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); + else + paramList.emplace_back("FoldAdaptor", "adaptor"); + + StringRef retType; bool hasSingleResult = op.getNumResults() == 1 && op.getNumVariableLengthResults() == 0; - - if (def.getValueAsBit("hasFolder")) { - if (hasSingleResult) { - auto *m = opClass.declareMethod( - "::mlir::OpFoldResult", "fold", - MethodParameter("::llvm::ArrayRef<::mlir::Attribute>", "operands")); - ERROR_IF_PRUNED(m, "operands", op); - } else { - SmallVector paramList; - paramList.emplace_back("::llvm::ArrayRef<::mlir::Attribute>", "operands"); - paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", - "results"); - auto *m = opClass.declareMethod("::mlir::LogicalResult", "fold", - std::move(paramList)); - ERROR_IF_PRUNED(m, "fold", op); - } + if (hasSingleResult) + retType = "::mlir::OpFoldResult"; + else { + paramList.emplace_back("::llvm::SmallVectorImpl<::mlir::OpFoldResult> &", + "results"); + retType = "::mlir::LogicalResult"; } + + auto *m = opClass.declareMethod(retType, "fold", std::move(paramList)); + ERROR_IF_PRUNED(m, "fold", op); } void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) {