diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -942,11 +942,14 @@
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+/// This is a common helper used for patterns of the form
+/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
+/// whose source is not an unranked memref into the root operation directly.
+/// Operands equal to `ignore` are left untouched: store ops pass the value
+/// being stored, since folding it would change the stored value's type.
+static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<memref::CastOp>();
-    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+    if (cast && operand.get() != ignore &&
+        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
     }
@@ -2270,7 +2273,7 @@
 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                   SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this);
+  return foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -73,11 +73,14 @@
-/// This is a common class used for patterns of the form
-/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
-/// into the root operation directly.
-static LogicalResult foldMemRefCast(Operation *op) {
+/// This is a common helper used for patterns of the form
+/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
+/// whose source is not an unranked memref into the root operation directly.
+/// Operands equal to `ignore` are left untouched: store ops pass the value
+/// being stored, since folding it would change the stored value's type.
+static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
   bool folded = false;
   for (OpOperand &operand : op->getOpOperands()) {
     auto cast = operand.get().getDefiningOp<CastOp>();
-    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
+    if (cast && operand.get() != ignore &&
+        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
       operand.set(cast.getOperand());
       folded = true;
     }
@@ -1425,7 +1428,7 @@
 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
                             SmallVectorImpl<OpFoldResult> &results) {
   /// store(memrefcast) -> store
-  return foldMemRefCast(*this);
+  return foldMemRefCast(*this, getValueToStore());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -924,3 +924,14 @@
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+//       CHECK:   %[[CST:.*]] = memref.cast %{{.*}} : memref<32xi8> to memref<?xi8>
+//       CHECK:   affine.store %[[CST]], %{{.*}}[] : memref<memref<?xi8>>
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+  affine.store %0, %holder[] : memref<memref<?xi8>>
+  return
+}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -206,3 +206,13 @@
   return %1 : index
 }
 
+// -----
+
+// CHECK-LABEL: func @no_fold_of_store
+//       CHECK:   %[[CST:.*]] = memref.cast %{{.*}} : memref<32xi8> to memref<?xi8>
+//       CHECK:   memref.store %[[CST]], %{{.*}}[] : memref<memref<?xi8>>
+func @no_fold_of_store(%arg : memref<32xi8>, %holder: memref<memref<?xi8>>) {
+  %0 = memref.cast %arg : memref<32xi8> to memref<?xi8>
+  memref.store %0, %holder[] : memref<memref<?xi8>>
+  return
+}