diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h
@@ -55,8 +55,7 @@
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
 LogicalResult foldToMemrefToTensorPair(RewriterBase &rewriter,
-                                       ToMemrefOp toMemref,
-                                       bool allowSameType = true);
+                                       ToMemrefOp toMemref);
 
 } // namespace bufferization
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -82,8 +82,9 @@
 
 /// Try to fold to_memref(to_tensor(x)). If x's type and the result type of the
 /// to_memref op are different, a memref.cast is needed.
-LogicalResult mlir::bufferization::foldToMemrefToTensorPair(
-    RewriterBase &rewriter, ToMemrefOp toMemref, bool allowSameType) {
+LogicalResult
+mlir::bufferization::foldToMemrefToTensorPair(RewriterBase &rewriter,
+                                              ToMemrefOp toMemref) {
   auto memrefToTensor = toMemref.getTensor().getDefiningOp<ToTensorOp>();
   if (!memrefToTensor)
     return failure();
@@ -93,9 +94,6 @@
 
   // Directly rewrite if the type did not change.
   if (srcType == destType) {
-    // Function can be configured to only handle cases where a cast is needed.
-    if (!allowSameType)
-      return failure();
     rewriter.replaceOp(toMemref, memrefToTensor.getMemref());
     return success();
   }
@@ -501,6 +499,19 @@
 }
 
 namespace {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref.
+struct ToTensorToMemrefFolding : public OpRewritePattern<ToTensorOp> {
+  using OpRewritePattern<ToTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToTensorOp toTensorOp,
+                                PatternRewriter &rewriter) const final {
+    auto toMemrefOp = toTensorOp.getMemref().getDefiningOp<ToMemrefOp>();
+    if (!toMemrefOp)
+      return failure();
+    rewriter.replaceOp(toTensorOp, toMemrefOp.getTensor());
+    return success();
+  }
+};
 
 struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
@@ -516,12 +527,11 @@
     return success();
   }
 };
-
 } // namespace
 
 void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfToTensorFolder>(context);
+  results.add<DimOfToTensorFolder, ToTensorToMemrefFolding>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -561,17 +571,14 @@
   }
 };
 
-/// Canonicalize bufferization.to_tensor + bufferization.to_memref to
-/// memref.cast when type mismatches prevent `ToMemrefOp::fold` to kick in.
-struct TensorLoadToMemref : public OpRewritePattern<ToMemrefOp> {
+/// Canonicalize bufferization.to_tensor + bufferization.to_memref. Insert a
+/// cast if necessary.
+struct ToMemrefToTensorFolding : public OpRewritePattern<ToMemrefOp> {
   using OpRewritePattern<ToMemrefOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToMemrefOp toMemref,
                                 PatternRewriter &rewriter) const final {
-    // Only handle cases where a cast is needed. The other case is handled by
-    // the folder.
-    return foldToMemrefToTensorPair(rewriter, toMemref,
-                                    /*allowSameType=*/false);
+    return foldToMemrefToTensorPair(rewriter, toMemref);
   }
 };
 
@@ -611,8 +618,8 @@
 
 void ToMemrefOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<DimOfCastOp, LoadOfToMemref, TensorLoadToMemref, ToMemrefOfCast>(
-      context);
+  results.add<DimOfCastOp, LoadOfToMemref, ToMemrefOfCast,
+              ToMemrefToTensorFolding>(context);
 }
 
 LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter,
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -787,8 +787,7 @@
   }
 
   // CHECK-NEXT: %[[R0:.*]] = bufferization.to_tensor %[[M0]] : memref<128x128xf32>
-  // CHECK-NEXT: %[[R1:.*]] = bufferization.to_tensor %[[M1]] : memref<128x128xf32>
-  // CHECK-NEXT: return %[[R0]], %[[R1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
+  // CHECK-NEXT: return %[[R0]], %[[T1]], %[[FOR_RES]] : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
   return %0#0, %0#1, %0#2 : tensor<128x128xf32>, tensor<128x128xf32>, tensor<128x128xf32>
 }
 
diff --git a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir
@@ -109,8 +109,7 @@
 // CHECK:             scf.yield %[[VAL_84]] : f64
 // CHECK:           }
 // CHECK:           memref.store %[[VAL_86:.*]], %[[VAL_15]][] : memref<f64>
-// CHECK:           %[[VAL_87:.*]] = bufferization.to_tensor %[[VAL_15]] : memref<f64>
-// CHECK:           return %[[VAL_87]] : tensor<f64>
+// CHECK:           return %[[VAL_0]] : tensor<f64>
// CHECK:         }
 func.func @sparse_matrix_sum(%argx: tensor<f64> {linalg.inplaceable = true},
                              %arga: tensor<64x32xf64, #SparseMatrix>,
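
Note: with this change, a to_tensor/to_memref round-trip canonicalizes away even
when no memref.cast is involved (previously that same-type case was left to the
folder, per the removed comment). A minimal sketch of the memref-side fold,
assuming mlir-opt with -canonicalize; the function and SSA value names are
illustrative, not taken from the patch:

  // Illustrative input, not part of the patch.
  func.func @round_trip(%m: memref<4xf32>) -> memref<4xf32> {
    %t = bufferization.to_tensor %m : memref<4xf32>
    // Same-typed round-trip: ToMemrefToTensorFolding calls
    // foldToMemrefToTensorPair, which now replaces %m2 with %m directly.
    %m2 = bufferization.to_memref %t : memref<4xf32>
    return %m2 : memref<4xf32>
  }

If the to_memref result type differed from %m's type (e.g. in layout),
foldToMemrefToTensorPair would insert a memref.cast instead. The mirrored case,
to_tensor(to_memref(x)) -> x, is handled by the new ToTensorToMemrefFolding
pattern; that fold is what lets the two tests above drop their
bufferization.to_tensor check lines.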