Bufferize std.index_cast on tensors.

Rewrites `index_cast` ops whose result is a ranked tensor into the same
`index_cast` on memrefs. Unranked tensor results are reported as a match
failure instead of asserting. `index_cast` ops on non-tensor types stay
legal and are left untouched. The IndexCastOp legality hook is placed
above the SelectOp comment so that comment keeps describing SelectOp only.

NOTE(review): the memref form relies on IndexCastOp accepting memref
operands; that op definition is not part of this diff.

diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -34,9 +34,30 @@
     return success();
   }
 };
-} // namespace
 
-namespace {
+/// Bufferizes `index_cast` ops whose result is a ranked tensor by emitting the
+/// same cast on the bufferized operand, producing a memref with the tensor's
+/// shape and element type.
+class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // An unranked tensor has no shape to build a MemRefType from, so report a
+    // match failure rather than asserting.
+    auto tensorType = op.getType().dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      return rewriter.notifyMatchFailure(op, "requires ranked tensor result");
+
+    IndexCastOp::Adaptor adaptor(operands);
+    rewriter.replaceOpWithNewOp<IndexCastOp>(
+        op, adaptor.in(),
+        MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
+    return success();
+  }
+};
+
 class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -56,8 +77,8 @@

 void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
                                         RewritePatternSet &patterns) {
-  patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
-                                                  patterns.getContext());
+  patterns.add<BufferizeDimOp, BufferizeIndexCastOp, BufferizeSelectOp>(
+      typeConverter, patterns.getContext());
 }

 namespace {
@@ -68,14 +89,16 @@
     RewritePatternSet patterns(context);
     ConversionTarget target(*context);

-    target.addLegalDialect<scf::SCFDialect>();
-    target.addLegalDialect<StandardOpsDialect>();
-    target.addLegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
+                           memref::MemRefDialect>();

     populateStdBufferizePatterns(typeConverter, patterns);
+    // Only index_cast ops with a tensor result need to be bufferized.
+    target.addDynamicallyLegalOp<IndexCastOp>(
+        [&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });
     // We only bufferize the case of tensor selected type and scalar condition,
     // as that boils down to a select over memref descriptors (don't need to
     // touch the data).
     target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
       return typeConverter.isLegal(op.getType()) ||
              !op.condition().getType().isa<TensorType>();
diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir
--- a/mlir/test/Dialect/Standard/bufferize.mlir
+++ b/mlir/test/Dialect/Standard/bufferize.mlir
@@ -24,3 +24,17 @@
   %0 = select %arg0, %arg1, %arg2 : tensor<f32>
   return %0 : tensor<f32>
 }
+
+// CHECK-LABEL: func @index_cast(
+// CHECK-SAME:  %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
+func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
+  %index_tensor = index_cast %tensor : tensor<i32> to tensor<index>
+  %index_scalar = index_cast %scalar : i32 to index
+  return %index_tensor, %index_scalar : tensor<index>, index
+}
+// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<i32>
+// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = index_cast %[[MEMREF]]
+// CHECK-SAME:   memref<i32> to memref<index>
+// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = memref.tensor_load %[[INDEX_MEMREF]]
+// CHECK: %[[INDEX_SCALAR:.*]] = index_cast %[[SCALAR]] : i32 to index
+// CHECK: return %[[INDEX_TENSOR]], %[[INDEX_SCALAR]]