diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -457,16 +457,22 @@
   Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc());
   OpBuilder b(funcOp->getContext());
   b.setInsertionPointToStart(&frontBlock);
-  // Replace all uses of bbArg through a ToMemRefOp by a memref::CastOp.
+  // Replace all uses of bbArg through a ToMemRefOp.
   for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
     if (auto toMemrefOp =
             dyn_cast<bufferization::ToMemrefOp>(use.getOwner())) {
-      assert(memref::CastOp::areCastCompatible(
-                 memref.getType(), toMemrefOp.memref().getType()) &&
-             "bufferizeFuncOpBoundary: cast incompatible");
-      auto castOp = b.create<memref::CastOp>(
-          funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
-      toMemrefOp.memref().replaceAllUsesWith(castOp);
+      if (memref.getType() != toMemrefOp.memref().getType()) {
+        // Type has changed, insert a cast.
+        assert(memref::CastOp::areCastCompatible(
+                   memref.getType(), toMemrefOp.memref().getType()) &&
+               "bufferizeFuncOpBoundary: cast incompatible");
+        auto castOp = b.create<memref::CastOp>(
+            funcOp.getLoc(), toMemrefOp.memref().getType(), memref);
+        toMemrefOp.memref().replaceAllUsesWith(castOp);
+      } else {
+        // Type did not change, replace directly.
+        toMemrefOp.memref().replaceAllUsesWith(memref);
+      }
     }
   }
   // Replace all remaining uses by a to_tensor.
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -599,6 +599,35 @@
   }
 };
 
+/// Bufferization of tensor.rank. Replace with memref.rank.
+struct RankOpInterface
+    : public BufferizableOpInterface::ExternalModel<RankOpInterface,
+                                                    tensor::RankOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const BufferizationState &state) const {
+    return true;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return false;
+  }
+
+  OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                               const BufferizationState &state) const {
+    return OpResult();
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          const BufferizationState &state) const {
+    auto rankOp = cast<tensor::RankOp>(op);
+    Value v = *state.getBuffer(rewriter, rankOp->getOpOperand(0) /*source*/);
+    replaceOpWithNewBufferizedOp<memref::RankOp>(rewriter, op,
+                                                 rankOp.getType(), v);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -613,4 +642,5 @@
   registry.addOpInterface<tensor::GenerateOp, GenerateOpInterface>();
   registry.addOpInterface<tensor::InsertOp, InsertOpInterface>();
   registry.addOpInterface<tensor::InsertSliceOp, InsertSliceOpInterface>();
+  registry.addOpInterface<tensor::RankOp, RankOpInterface>();
 }
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
--- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -1389,3 +1389,14 @@
 // CHECK: return %[[MEMREF]]
   return %0 : tensor<3x2xindex>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor_rank(
+// CHECK-SAME:    %[[arg0:.*]]: memref<*xf32>
+func @tensor_rank(%arg0: tensor<*xf32>) -> index {
+  // CHECK: %[[r:.*]] = memref.rank %[[arg0]]
+  %0 = tensor.rank %arg0 : tensor<*xf32>
+  // CHECK: return %[[r]] : index
+  return %0 : index
+}
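
---

Note (not part of the patch): a minimal sketch of how a client is expected to wire up these external models before running comprehensive bufferization. The entry point mlir::tensor::registerBufferizableOpInterfaceExternalModels is the function that the last hunk of BufferizableOpInterfaceImpl.cpp extends; the header paths and the setUpContext helper are assumptions for illustration and may vary across MLIR revisions.

  #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
  #include "mlir/IR/DialectRegistry.h"
  #include "mlir/IR/MLIRContext.h"

  // Attach the Tensor dialect's BufferizableOpInterface models (now including
  // RankOpInterface) to a context, so bufferization can rewrite tensor.rank
  // to memref.rank.
  static void setUpContext(mlir::MLIRContext &ctx) {
    mlir::DialectRegistry registry;
    mlir::tensor::registerBufferizableOpInterfaceExternalModels(registry);
    ctx.appendDialectRegistry(registry);
  }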