diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1325,26 +1325,6 @@ results.add>(context); } -//===----------------------------------------------------------------------===// -// Common canonicalization pattern support logic -//===----------------------------------------------------------------------===// - -/// This is a common class used for patterns of the form -/// "someop(memrefcast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. -static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto cast = operand.get().getDefiningOp(); - if (cast && operand.get() != ignore && - !cast.getOperand().getType().isa()) { - operand.set(cast.getOperand()); - folded = true; - } - } - return success(folded); -} - //===----------------------------------------------------------------------===// // AffineDmaStartOp //===----------------------------------------------------------------------===// @@ -1511,7 +1491,7 @@ LogicalResult AffineDmaStartOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_start(memrefcast) -> dma_start - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -1589,7 +1569,7 @@ LogicalResult AffineDmaWaitOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -2821,7 +2801,7 @@ OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { /// load(memrefcast) -> load - if (succeeded(foldMemRefCast(*this))) + if (succeeded(memref::foldMemRefCast(*this))) return getResult(); // Fold load from a global constant memref. @@ -2939,7 +2919,7 @@ LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store - return foldMemRefCast(*this, getValueToStore()); + return memref::foldMemRefCast(*this, getValueToStore()); } //===----------------------------------------------------------------------===// @@ -3392,7 +3372,7 @@ LogicalResult AffinePrefetchOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// prefetch(memrefcast) -> prefetch - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1266,29 +1266,14 @@ return success(); } -/// This is a common class used for patterns of the form -/// "someop(memrefcast) -> someop". It folds the source of any memref.cast -/// into the root operation directly. -static LogicalResult foldMemRefCast(Operation *op) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto cast = operand.get().getDefiningOp(); - if (cast) { - operand.set(cast.getOperand()); - folded = true; - } - } - return success(folded); -} - LogicalResult MemcpyOp::fold(ArrayRef operands, SmallVectorImpl<::mlir::OpFoldResult> &results) { - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } LogicalResult MemsetOp::fold(ArrayRef operands, SmallVectorImpl<::mlir::OpFoldResult> &results) { - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -254,23 +254,6 @@ // Region is elided. } -/// This is a common class used for patterns of the form -/// ``` -/// someop(memrefcast(%src)) -> someop(%src) -/// ``` -/// It folds the source of the memref.cast into the root operation directly. -static LogicalResult foldMemRefCast(Operation *op) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto castOp = operand.get().getDefiningOp(); - if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - folded = true; - } - } - return success(folded); -} - //===----------------------------------------------------------------------===// // Region builder helper. // TODO: Move this to a utility library. @@ -1290,7 +1273,7 @@ LogicalResult GenericOp::fold(ArrayRef, SmallVectorImpl &) { - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -3085,35 +3085,6 @@ [&](Twine t) { return emitOpError(t); }); } -/// This is a common class used for patterns of the form -/// ``` -/// someop(memrefcast) -> someop -/// ``` -/// It folds the source of the memref.cast into the root operation directly. -static LogicalResult foldMemRefCast(Operation *op) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto castOp = operand.get().getDefiningOp(); - if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - folded = true; - } - } - return success(folded); -} - -static LogicalResult foldTensorCast(Operation *op) { - bool folded = false; - for (OpOperand &operand : op->getOpOperands()) { - auto castOp = operand.get().getDefiningOp(); - if (castOp && tensor::canFoldIntoConsumerOp(castOp)) { - operand.set(castOp.getOperand()); - folded = true; - } - } - return success(folded); -} - template static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { // TODO: support more aggressive createOrFold on: @@ -3198,9 +3169,9 @@ /// transfer_read(memrefcast) -> transfer_read if (succeeded(foldTransferInBoundsAttribute(*this))) return getResult(); - if (succeeded(foldMemRefCast(*this))) + if (succeeded(memref::foldMemRefCast(*this))) return getResult(); - if (succeeded(foldTensorCast(*this))) + if (succeeded(tensor::foldTensorCast(*this))) return getResult(); return OpFoldResult(); } @@ -3648,7 +3619,7 @@ return success(); if (succeeded(foldTransferInBoundsAttribute(*this))) return success(); - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } Optional> TransferWriteOp::getShapeForUnroll() { @@ -3948,7 +3919,7 @@ } OpFoldResult LoadOp::fold(ArrayRef) { - if (succeeded(foldMemRefCast(*this))) + if (succeeded(memref::foldMemRefCast(*this))) return getResult(); return OpFoldResult(); } @@ -3982,7 +3953,7 @@ LogicalResult StoreOp::fold(ArrayRef operands, SmallVectorImpl &results) { - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -4034,7 +4005,7 @@ } OpFoldResult MaskedLoadOp::fold(ArrayRef) { - if (succeeded(foldMemRefCast(*this))) + if (succeeded(memref::foldMemRefCast(*this))) return getResult(); return OpFoldResult(); } @@ -4086,7 +4057,7 @@ LogicalResult MaskedStoreOp::fold(ArrayRef operands, SmallVectorImpl &results) { - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } //===----------------------------------------------------------------------===// diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -655,7 +655,7 @@ const char structuredOpFoldersFormat[] = R"FMT( LogicalResult {0}::fold(ArrayRef, SmallVectorImpl &) {{ - return foldMemRefCast(*this); + return memref::foldMemRefCast(*this); } void {0}::getEffects(SmallVectorImpl< SideEffects::EffectInstance >&effects) {{