diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
@@ -21,6 +21,7 @@
   }];
   let dependentDialects = ["arith::ArithDialect"];
   let hasConstantMaterializer = 1;
+  let useFoldAPI = kEmitFoldAdaptorFolder;
 }

 #endif // MEMREF_BASE
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -808,7 +808,7 @@
   return false;
 }

-OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
 }

@@ -883,7 +883,7 @@
   results.add<FoldCopyOfCast, FoldSelfCopy>(context);
 }

-LogicalResult CopyOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
   /// copy(memrefcast) -> copy
   bool folded = false;
@@ -902,7 +902,7 @@
 // DeallocOp
 //===----------------------------------------------------------------------===//

-LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dealloc(memrefcast) -> dealloc
   return foldMemRefCast(*this);
@@ -1056,9 +1056,9 @@
   return *unusedDims;
 }

-OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
   // All forms of folding require a known index.
-  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
+  auto index = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
   if (!index)
     return {};

@@ -1322,7 +1322,7 @@
   return success();
 }

-LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
   /// dma_start(memrefcast) -> dma_start
   return foldMemRefCast(*this);
@@ -1332,7 +1332,7 @@
 // DmaWaitOp
 // ---------------------------------------------------------------------------

-LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
   /// dma_wait(memrefcast) -> dma_wait
   return foldMemRefCast(*this);
@@ -1433,7 +1433,7 @@
 }

 LogicalResult
-ExtractStridedMetadataOp::fold(ArrayRef<Attribute> cstOperands,
+ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
   OpBuilder builder(*this);

@@ -1677,7 +1677,7 @@
   return success();
 }

-OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
+OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
   /// load(memrefcast) -> load
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
@@ -1747,7 +1747,7 @@
   return success();
 }

-LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &results) {
   // prefetch(memrefcast) -> prefetch
   return foldMemRefCast(*this);
@@ -1757,7 +1757,7 @@
 // RankOp
 //===----------------------------------------------------------------------===//

-OpFoldResult RankOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
   // Constant fold rank when the rank of the operand is known.
   auto type = getOperand().getType();
   auto shapedType = type.dyn_cast<ShapedType>();
@@ -1881,7 +1881,7 @@
   return success();
 }

-OpFoldResult ReinterpretCastOp::fold(ArrayRef<Attribute> /*operands*/) {
+OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*adaptor*/) {
   Value src = getSource();
   auto getPrevSrc = [&]() -> Value {
     // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
@@ -2465,12 +2465,14 @@
               CollapseShapeOpMemRefCastFolder>(context);
 }

-OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
+OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
+                                                       adaptor.getOperands());
 }

-OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
-  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
+OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
+  return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
+                                                       adaptor.getOperands());
 }

 //===----------------------------------------------------------------------===//
@@ -2522,7 +2524,7 @@
   return success();
 }

-LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
   return foldMemRefCast(*this, getValueToStore());
@@ -3101,7 +3103,7 @@
            SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
 }

-OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
   auto resultShapedType = getResult().getType().cast<ShapedType>();
   auto sourceShapedType = getSource().getType().cast<ShapedType>();

@@ -3217,7 +3219,7 @@
   return success();
 }

-OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
+OpFoldResult TransposeOp::fold(FoldAdaptor) {
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
   return {};
@@ -3393,7 +3395,7 @@
   return success();
 }

-OpFoldResult AtomicRMWOp::fold(ArrayRef<Attribute> operands) {
+OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
   /// atomicrmw(memrefcast) -> atomicrmw
   if (succeeded(foldMemRefCast(*this, getValue())))
     return getResult();