NOTE(review): This patch adds Comprehensive Bufferize support for a tensor
`arith.index_cast`, plus a FileCheck test for it.

- ArithInterfaceImpl.cpp: adds a new `IndexCastOpInterface` external model.
  - Its operand is reported as neither read nor written (both
    `bufferizesToMemoryRead` and `bufferizesToMemoryWrite` return false).
  - The operand aliases result #0, with an `Equivalent` buffer relation.
  - `bufferize` looks up the buffer of operand 0.
  - It then builds the result memref type from the cast's result type. That
    type keeps the source buffer's memory space. It also keeps the source
    layout, but only when the `rankedMemRefType` dyn_cast succeeds (presumably
    a ranked memref source).
  - Finally it replaces the op with a new bufferized op on that buffer. Per the
    test, that op is an `arith.index_cast` on memrefs.
  - The model is registered with a second `addOpInterface` call in
    `registerBufferizableOpInterfaceExternalModels`.
- comprehensive-function-bufferize.mlir: adds a new `@index_cast` test.
  - It checks that the tensor cast lowers to this sequence:
    `bufferization.to_memref` -> `arith.index_cast` on the memref ->
    `bufferization.to_tensor`.
  - The scalar i32 -> index cast in the same function is returned, but no
    CHECK line matches it; the `SCALAR` capture is never reused.

NOTE(review): this copy looks damaged in extraction. Confirm it against the
original patch before applying:
  * Line breaks and indentation are collapsed, so the hunk bodies no longer
    line up with their `@@` headers.
  * Angle-bracket contents appear to be stripped. Examples:
    - `ExternalModel {`, `cast(op)`, `getType().cast()`, `dyn_cast()`;
    - `replaceOpWithNewBufferizedOp(`, `addOpInterface()`;
    - `tensor` / `memref` written without an element type.
    This is presumably also why the captured `$MAP` is never referenced by a
    later CHECK line. As shown, the C++ part likely does not compile.
  * `DialectRegistry ®istry` is presumably `&registry`, mangled by
    HTML-entity decoding (`&reg` -> the registered-trademark sign).
NOTE(review): `*state.getBuffer(...)` dereferences the lookup result without
any check. If `getBuffer` can fail (e.g. it returns `FailureOr<Value>`),
consider propagating the failure instead. TODO: confirm its return type.

diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp @@ -57,6 +57,49 @@ } }; +struct IndexCastOpInterface + : public BufferizableOpInterface::ExternalModel { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return false; + } + + OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand, + const BufferizationState &state) const { + return op->getResult(0); + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const BufferizationState &state) const { + return BufferRelation::Equivalent; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationState &state) const { + auto castOp = cast(op); + + Value source = *state.getBuffer(rewriter, op->getOpOperand(0) /*in*/); + auto sourceType = source.getType().cast(); + + // The result memref type keeps the source buffer's memory space and, when the source is a ranked memref, its layout.
+ MemRefLayoutAttrInterface layout = {}; + if (auto rankedMemRefType = sourceType.dyn_cast()) + layout = rankedMemRefType.getLayout(); + Type resultType = + getMemRefType(castOp.getType().cast(), state.getOptions(), + layout, sourceType.getMemorySpace()); + + replaceOpWithNewBufferizedOp(rewriter, op, source, + resultType); + return success(); + } +}; } // namespace arith_ext } // namespace comprehensive_bufferize } // namespace linalg @@ -65,4 +108,6 @@ void mlir::linalg::comprehensive_bufferize::arith_ext:: registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addOpInterface(); + registry + .addOpInterface(); } diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir @@ -96,3 +96,19 @@ } return %5: tensor } + +// ----- + +// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0)> +// CHECK-LABEL: func @index_cast( +// CHECK-SAME: %[[TENSOR:.*]]: tensor, %[[SCALAR:.*]]: i32 +func @index_cast(%tensor: tensor, %scalar: i32) -> (tensor, index) { + %index_tensor = arith.index_cast %tensor : tensor to tensor + %index_scalar = arith.index_cast %scalar : i32 to index + return %index_tensor, %index_scalar : tensor, index +} +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[TENSOR]] : memref +// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = arith.index_cast %[[MEMREF]] +// CHECK-SAME: memref to memref +// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = bufferization.to_tensor %[[INDEX_MEMREF]] +// CHECK: return %[[INDEX_TENSOR]]