Index: mlir/include/mlir/Dialect/Affine/IR/AffineOps.td =================================================================== --- mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -24,6 +24,7 @@ let cppNamespace = "mlir"; let hasConstantMaterializer = 1; let dependentDialects = ["arith::ArithDialect"]; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for Affine dialect ops. Index: mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td =================================================================== --- mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td +++ mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td @@ -69,6 +69,7 @@ kEscapeAttrName = "bufferization.escape"; }]; let hasOperationAttrVerify = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // BUFFERIZATION_BASE Index: mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td =================================================================== --- mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td +++ mlir/include/mlir/Dialect/Complex/IR/ComplexBase.td @@ -22,6 +22,7 @@ let dependentDialects = ["arith::ArithDialect"]; let hasConstantMaterializer = 1; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // COMPLEX_BASE Index: mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td =================================================================== --- mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td +++ mlir/include/mlir/Dialect/EmitC/IR/EmitCBase.td @@ -31,6 +31,7 @@ let hasConstantMaterializer = 1; let useDefaultTypePrinterParser = 1; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // MLIR_DIALECT_EMITC_IR_EMITCBASE Index: mlir/include/mlir/Dialect/Func/IR/FuncOps.td =================================================================== --- mlir/include/mlir/Dialect/Func/IR/FuncOps.td +++ mlir/include/mlir/Dialect/Func/IR/FuncOps.td @@ -23,6 +23,7 @@ let cppNamespace = "::mlir::func"; let dependentDialects = ["cf::ControlFlowDialect"]; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for Func dialect ops. Index: mlir/include/mlir/Dialect/GPU/IR/GPUBase.td =================================================================== --- mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -56,6 +56,7 @@ let dependentDialects = ["arith::ArithDialect"]; let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } def GPU_AsyncToken : DialectType< Index: mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -31,6 +31,7 @@ let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; let hasOperationAttrVerify = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let extraClassDeclaration = [{ /// Name of the data layout attributes. Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -46,6 +46,7 @@ let hasCanonicalizer = 1; let hasOperationAttrVerify = 1; let hasConstantMaterializer = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let extraClassDeclaration = [{ /// Attribute name used to to memoize indexing maps for named ops. constexpr const static ::llvm::StringLiteral Index: mlir/include/mlir/Dialect/Quant/QuantOpsBase.td =================================================================== --- mlir/include/mlir/Dialect/Quant/QuantOpsBase.td +++ mlir/include/mlir/Dialect/Quant/QuantOpsBase.td @@ -20,6 +20,7 @@ let cppNamespace = "::mlir::quant"; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Dialect/SCF/IR/SCFOps.td =================================================================== --- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -25,6 +25,7 @@ let name = "scf"; let cppNamespace = "::mlir::scf"; let dependentDialects = ["arith::ArithDialect"]; + let useFoldAPI = kEmitFoldAdaptorFolder; } // Base class for SCF dialect ops. Index: mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td =================================================================== --- mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td +++ mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorBase.td @@ -83,6 +83,7 @@ let useDefaultAttributePrinterParser = 1; let useDefaultTypePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // SPARSETENSOR_BASE Index: mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td =================================================================== --- mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -23,6 +23,8 @@ "::mlir::pdl_interp::PDLInterpDialect", ]; + let useFoldAPI = kEmitFoldAdaptorFolder; + let extraClassDeclaration = [{ /// Returns the named PDL constraint functions available in the dialect /// as a map from their name to the function. Index: mlir/include/mlir/IR/BuiltinDialect.td =================================================================== --- mlir/include/mlir/IR/BuiltinDialect.td +++ mlir/include/mlir/IR/BuiltinDialect.td @@ -34,6 +34,8 @@ public: }]; + + let useFoldAPI = kEmitFoldAdaptorFolder; } #endif // BUILTIN_BASE Index: mlir/lib/Dialect/Affine/IR/AffineOps.cpp =================================================================== --- mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -562,7 +562,7 @@ }); } -OpFoldResult AffineApplyOp::fold(ArrayRef operands) { +OpFoldResult AffineApplyOp::fold(FoldAdaptor adaptor) { auto map = getAffineMap(); // Fold dims and symbols to existing values. @@ -574,7 +574,7 @@ // Otherwise, default to folding the map. SmallVector result; - if (failed(map.constantFold(operands, result))) + if (failed(map.constantFold(adaptor.getMapOperands(), result))) return {}; return result[0]; } @@ -2135,7 +2135,7 @@ return tripCount && *tripCount == 0; } -LogicalResult AffineForOp::fold(ArrayRef operands, +LogicalResult AffineForOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { bool folded = succeeded(foldLoopBounds(*this)); folded |= succeeded(canonicalizeLoopBounds(*this)); @@ -2723,7 +2723,7 @@ } /// Canonicalize an affine if op's conditional (integer set + operands). -LogicalResult AffineIfOp::fold(ArrayRef, +LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl &) { auto set = getIntegerSet(); SmallVector operands(getOperands()); @@ -2858,7 +2858,7 @@ results.add>(context); } -OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { +OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) { /// load(memrefcast) -> load if (succeeded(memref::foldMemRefCast(*this))) return getResult(); @@ -2975,7 +2975,7 @@ results.add>(context); } -LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, +LogicalResult AffineStoreOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { /// store(memrefcast) -> store return memref::foldMemRefCast(*this, getValueToStore()); @@ -3282,8 +3282,8 @@ // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) // -OpFoldResult AffineMinOp::fold(ArrayRef operands) { - return foldMinMaxOp(*this, operands); +OpFoldResult AffineMinOp::fold(FoldAdaptor adaptor) { + return foldMinMaxOp(*this, adaptor.getOperands()); } void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -3310,8 +3310,8 @@ // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) // -OpFoldResult AffineMaxOp::fold(ArrayRef operands) { - return foldMinMaxOp(*this, operands); +OpFoldResult AffineMaxOp::fold(FoldAdaptor adaptor) { + return foldMinMaxOp(*this, adaptor.getOperands()); } void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -3431,7 +3431,7 @@ results.add>(context); } -LogicalResult AffinePrefetchOp::fold(ArrayRef cstOperands, +LogicalResult AffinePrefetchOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { /// prefetch(memrefcast) -> prefetch return memref::foldMemRefCast(*this); @@ -3705,7 +3705,7 @@ return success(); } -LogicalResult AffineParallelOp::fold(ArrayRef operands, +LogicalResult AffineParallelOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { return canonicalizeLoopBounds(*this); } Index: mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp =================================================================== --- mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -458,7 +458,7 @@ SideEffects::DefaultResource::get()); } -OpFoldResult CloneOp::fold(ArrayRef operands) { +OpFoldResult CloneOp::fold(FoldAdaptor adaptor) { return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value(); } @@ -560,7 +560,7 @@ // ToTensorOp //===----------------------------------------------------------------------===// -OpFoldResult ToTensorOp::fold(ArrayRef) { +OpFoldResult ToTensorOp::fold(FoldAdaptor) { if (auto toMemref = getMemref().getDefiningOp()) // Approximate alias analysis by conservatively folding only when no there // is no interleaved operation. @@ -596,7 +596,7 @@ // ToMemrefOp //===----------------------------------------------------------------------===// -OpFoldResult ToMemrefOp::fold(ArrayRef) { +OpFoldResult ToMemrefOp::fold(FoldAdaptor) { if (auto memrefToTensor = getTensor().getDefiningOp()) if (memrefToTensor.getMemref().getType() == getType()) return memrefToTensor.getMemref(); Index: mlir/lib/Dialect/Complex/IR/ComplexOps.cpp =================================================================== --- mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -17,8 +17,7 @@ // ConstantOp //===----------------------------------------------------------------------===// -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } @@ -68,8 +67,7 @@ // CreateOp //===----------------------------------------------------------------------===// -OpFoldResult CreateOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary op takes two operands"); +OpFoldResult CreateOp::fold(FoldAdaptor adaptor) { // Fold complex.create(complex.re(op), complex.im(op)). if (auto reOp = getOperand(0).getDefiningOp()) { if (auto imOp = getOperand(1).getDefiningOp()) { @@ -85,9 +83,8 @@ // ImOp //===----------------------------------------------------------------------===// -OpFoldResult ImOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - ArrayAttr arrayAttr = operands[0].dyn_cast_or_null(); +OpFoldResult ImOp::fold(FoldAdaptor adaptor) { + ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null(); if (arrayAttr && arrayAttr.size() == 2) return arrayAttr[1]; if (auto createOp = getOperand().getDefiningOp()) @@ -99,9 +96,8 @@ // ReOp //===----------------------------------------------------------------------===// -OpFoldResult ReOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - ArrayAttr arrayAttr = operands[0].dyn_cast_or_null(); +OpFoldResult ReOp::fold(FoldAdaptor adaptor) { + ArrayAttr arrayAttr = adaptor.getComplex().dyn_cast_or_null(); if (arrayAttr && arrayAttr.size() == 2) return arrayAttr[0]; if (auto createOp = getOperand().getDefiningOp()) @@ -113,9 +109,7 @@ // AddOp //===----------------------------------------------------------------------===// -OpFoldResult AddOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary op takes 2 operands"); - +OpFoldResult AddOp::fold(FoldAdaptor adaptor) { // complex.add(complex.sub(a, b), b) -> a if (auto sub = getLhs().getDefiningOp()) if (getRhs() == sub.getRhs()) @@ -142,9 +136,7 @@ // SubOp //===----------------------------------------------------------------------===// -OpFoldResult SubOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "binary op takes 2 operands"); - +OpFoldResult SubOp::fold(FoldAdaptor adaptor) { // complex.sub(complex.add(a, b), b) -> a if (auto add = getLhs().getDefiningOp()) if (getRhs() == add.getRhs()) @@ -166,9 +158,7 @@ // NegOp //===----------------------------------------------------------------------===// -OpFoldResult NegOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult NegOp::fold(FoldAdaptor adaptor) { // complex.neg(complex.neg(a)) -> a if (auto negOp = getOperand().getDefiningOp()) return negOp.getOperand(); @@ -180,9 +170,7 @@ // LogOp //===----------------------------------------------------------------------===// -OpFoldResult LogOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult LogOp::fold(FoldAdaptor adaptor) { // complex.log(complex.exp(a)) -> a if (auto expOp = getOperand().getDefiningOp()) return expOp.getOperand(); @@ -194,9 +182,7 @@ // ExpOp //===----------------------------------------------------------------------===// -OpFoldResult ExpOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult ExpOp::fold(FoldAdaptor adaptor) { // complex.exp(complex.log(a)) -> a if (auto logOp = getOperand().getDefiningOp()) return logOp.getOperand(); @@ -208,9 +194,7 @@ // ConjOp //===----------------------------------------------------------------------===// -OpFoldResult ConjOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "unary op takes 1 operand"); - +OpFoldResult ConjOp::fold(FoldAdaptor adaptor) { // complex.conj(complex.conj(a)) -> a if (auto conjOp = getOperand().getDefiningOp()) return conjOp.getOperand(); Index: mlir/lib/Dialect/EmitC/IR/EmitC.cpp =================================================================== --- mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -129,8 +129,7 @@ return success(); } -OpFoldResult emitc::ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } Index: mlir/lib/Dialect/Func/IR/FuncOps.cpp =================================================================== --- mlir/lib/Dialect/Func/IR/FuncOps.cpp +++ mlir/lib/Dialect/Func/IR/FuncOps.cpp @@ -201,8 +201,7 @@ return success(); } -OpFoldResult ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "constant has no operands"); +OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); } Index: mlir/lib/Dialect/GPU/IR/GPUDialect.cpp =================================================================== --- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1286,12 +1286,12 @@ return success(); } -LogicalResult MemcpyOp::fold(ArrayRef operands, +LogicalResult MemcpyOp::fold(FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) { return memref::foldMemRefCast(*this); } -LogicalResult MemsetOp::fold(ArrayRef operands, +LogicalResult MemsetOp::fold(FoldAdaptor adaptor, SmallVectorImpl<::mlir::OpFoldResult> &results) { return memref::foldMemRefCast(*this); } Index: mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp =================================================================== --- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1441,7 +1441,7 @@ return llvmType; } -OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef operands) { +OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) { auto insertValueOp = getContainer().getDefiningOp(); OpFoldResult result = {}; while (insertValueOp) { @@ -2275,7 +2275,7 @@ } // Constant op constant-folds to its value. -OpFoldResult LLVM::ConstantOp::fold(ArrayRef) { return getValue(); } +OpFoldResult LLVM::ConstantOp::fold(FoldAdaptor) { return getValue(); } //===----------------------------------------------------------------------===// // Utility functions for parsing atomic ops @@ -2513,7 +2513,7 @@ // Folder for LLVM::BitcastOp //===----------------------------------------------------------------------===// -OpFoldResult LLVM::BitcastOp::fold(ArrayRef operands) { +OpFoldResult LLVM::BitcastOp::fold(FoldAdaptor adaptor) { // bitcast(x : T0, T0) -> x if (getArg().getType() == getType()) return getArg(); @@ -2528,7 +2528,7 @@ // Folder for LLVM::AddrSpaceCastOp //===----------------------------------------------------------------------===// -OpFoldResult LLVM::AddrSpaceCastOp::fold(ArrayRef operands) { +OpFoldResult LLVM::AddrSpaceCastOp::fold(FoldAdaptor adaptor) { // addrcast(x : T0, T0) -> x if (getArg().getType() == getType()) return getArg(); @@ -2543,9 +2543,9 @@ // Folder for LLVM::GEPOp //===----------------------------------------------------------------------===// -OpFoldResult LLVM::GEPOp::fold(ArrayRef operands) { +OpFoldResult LLVM::GEPOp::fold(FoldAdaptor adaptor) { GEPIndicesAdaptor> indices(getRawConstantIndicesAttr(), - operands.drop_front()); + adaptor.getDynamicIndices()); // gep %x:T, 0 -> %x if (getBase().getType() == getType() && indices.size() == 1) Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -980,8 +980,7 @@ results.add(context); } -LogicalResult GenericOp::fold(ArrayRef, - SmallVectorImpl &) { +LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl &) { return memref::foldMemRefCast(*this); } Index: mlir/lib/Dialect/Quant/IR/QuantOps.cpp =================================================================== --- mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -36,7 +36,7 @@ addBytecodeInterface(this); } -OpFoldResult StorageCastOp::fold(ArrayRef operands) { +OpFoldResult StorageCastOp::fold(FoldAdaptor adaptor) { // Matches x -> [scast -> scast] -> y, replacing the second scast with the // value of x if the casts invert each other. auto srcScastOp = getArg().getDefiningOp(); Index: mlir/lib/Dialect/SCF/IR/SCF.cpp =================================================================== --- mlir/lib/Dialect/SCF/IR/SCF.cpp +++ mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1598,7 +1598,7 @@ regions.push_back(RegionSuccessor(condition ? &getThenRegion() : elseRegion)); } -LogicalResult IfOp::fold(ArrayRef operands, +LogicalResult IfOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { // if (!c) then A() else B() -> if c then B() else A() if (getElseRegion().empty()) Index: mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp =================================================================== --- mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -467,7 +467,7 @@ return emitError("unexpected type in convert"); } -OpFoldResult ConvertOp::fold(ArrayRef operands) { +OpFoldResult ConvertOp::fold(FoldAdaptor adaptor) { Type dstType = getType(); // Fold trivial dense-to-dense convert and leave trivial sparse-to-sparse // convert for codegen to remove. This is because we use trivial @@ -531,7 +531,7 @@ return op.getSpecifier().template getDefiningOp(); } -OpFoldResult GetStorageSpecifierOp::fold(ArrayRef operands) { +OpFoldResult GetStorageSpecifierOp::fold(FoldAdaptor adaptor) { StorageSpecifierKind kind = getSpecifierKind(); std::optional dim = getDim(); for (auto op = getSpecifierSetDef(*this); op; op = getSpecifierSetDef(op)) Index: mlir/lib/Dialect/Transform/IR/TransformOps.cpp =================================================================== --- mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -463,7 +463,7 @@ // manipulation. } -OpFoldResult transform::MergeHandlesOp::fold(ArrayRef operands) { +OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { if (getDeduplicate() || getHandles().size() != 1) return {}; Index: mlir/lib/IR/BuiltinDialect.cpp =================================================================== --- mlir/lib/IR/BuiltinDialect.cpp +++ mlir/lib/IR/BuiltinDialect.cpp @@ -190,7 +190,7 @@ //===----------------------------------------------------------------------===// LogicalResult -UnrealizedConversionCastOp::fold(ArrayRef attrOperands, +UnrealizedConversionCastOp::fold(FoldAdaptor adaptor, SmallVectorImpl &foldResults) { OperandRange operands = getInputs(); ResultRange results = getOutputs(); Index: mlir/test/lib/Dialect/Test/TestDialect.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.cpp +++ mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1099,32 +1099,31 @@ results.add(context); } -OpFoldResult TestOpWithRegionFold::fold(ArrayRef operands) { +OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { return getOperand(); } -OpFoldResult TestOpConstant::fold(ArrayRef operands) { +OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } LogicalResult TestOpWithVariadicResultsAndFolder::fold( - ArrayRef operands, SmallVectorImpl &results) { + FoldAdaptor adaptor, SmallVectorImpl &results) { for (Value input : this->getOperands()) { results.push_back(input); } return success(); } -OpFoldResult TestOpInPlaceFold::fold(ArrayRef operands) { - assert(operands.size() == 1); - if (operands.front()) { - (*this)->setAttr("attr", operands.front()); +OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { + if (adaptor.getOp()) { + (*this)->setAttr("attr", adaptor.getOp()); return getResult(); } return {}; } -OpFoldResult TestPassthroughFold::fold(ArrayRef operands) { +OpFoldResult TestPassthroughFold::fold(FoldAdaptor adaptor) { return getOperand(); } Index: mlir/test/lib/Dialect/Test/TestDialect.td =================================================================== --- mlir/test/lib/Dialect/Test/TestDialect.td +++ mlir/test/lib/Dialect/Test/TestDialect.td @@ -23,6 +23,7 @@ let hasNonDefaultDestructor = 1; let useDefaultTypePrinterParser = 0; let useDefaultAttributePrinterParser = 1; + let useFoldAPI = kEmitFoldAdaptorFolder; let isExtensible = 1; let dependentDialects = ["::mlir::DLTIDialect"]; Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -1290,7 +1290,7 @@ let results = (outs Variadic); let hasFolder = 1; let extraClassDefinition = [{ - ::mlir::LogicalResult $cppClass::fold(ArrayRef operands, + ::mlir::LogicalResult $cppClass::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { return success(); } @@ -1315,11 +1315,7 @@ $op `,` `[` $variadic `]` `,` `{` $var_of_var `}` $body attr-dict-with-keyword }]; - let hasFolder = 0; - - let extraClassDeclaration = [{ - ::mlir::OpFoldResult fold(FoldAdaptor adaptor); - }]; + let hasFolder = 1; } // An op that always fold itself. Index: mlir/test/lib/Dialect/Test/TestTraits.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestTraits.cpp +++ mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -18,13 +18,13 @@ //===----------------------------------------------------------------------===// OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold( - ArrayRef operands) { + FoldAdaptor adaptor) { // This failure should cause the trait fold to run instead. return {}; } OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold( - ArrayRef operands) { + FoldAdaptor adaptor) { auto argumentOp = getOperand(); // The success case should cause the trait fold to be supressed. return argumentOp.getDefiningOp() ? argumentOp : OpFoldResult{}; Index: mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp =================================================================== --- mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -654,7 +654,7 @@ // Parameters: // {0}: Class name const char structuredOpFoldersFormat[] = R"FMT( -LogicalResult {0}::fold(ArrayRef, +LogicalResult {0}::fold(FoldAdaptor, SmallVectorImpl &) {{ return memref::foldMemRefCast(*this); }