Move the tensor.dim bufferization pattern (BufferizeDimOp) from std-bufferize to tensor-bufferize.

- StandardOps/Transforms/Bufferize.cpp: delete BufferizeDimOp and drop it from
  the pattern list in populateStdBufferizePatterns.
- Tensor/Transforms/Bufferize.cpp: add BufferizeDimOp (unchanged body: it builds
  the replacement dim op from the adaptor's source() and index() operands) and
  register it in populateTensorBufferizePatterns.
- Tests: move the @dim FileCheck case from test/Dialect/Standard/bufferize.mlir
  (run with -std-bufferize) to test/Dialect/Tensor/bufferize.mlir (run with
  -tensor-bufferize). It expects tensor.dim to become memref.buffer_cast +
  memref.dim. Also update the existing expectation inside the scf.parallel loop
  (hunk @@ -67,7 +78,8 @@). The tensor.dim on %[[ARG]] is now expected to be
  bufferized to memref.buffer_cast + memref.dim instead of remaining tensor.dim.

NOTE(review): in this copy of the patch, line breaks and leading whitespace
have been collapsed. Many `<...>` arguments also appear to have been stripped
(e.g. after OpConversionPattern, ArrayRef, replaceOpWithNewOp, patterns.add,
and several tensor/memref types). The patch therefore will not apply or compile
as-is; restore it from the original source before use.
NOTE(review): the new BufferizeDimOp is wrapped in its own `namespace { ... }`
immediately before another `namespace {`. Consider merging the two anonymous
namespaces.

diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -23,19 +23,6 @@ using namespace mlir; namespace { -class BufferizeDimOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(tensor::DimOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - tensor::DimOp::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp(op, adaptor.source(), - adaptor.index()); - return success(); - } -}; - class BufferizeIndexCastOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -70,8 +57,8 @@ void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); + patterns.add(typeConverter, + patterns.getContext()); } namespace { diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -35,6 +35,21 @@ }; } // namespace +namespace { +class BufferizeDimOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(tensor::DimOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + tensor::DimOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp(op, adaptor.source(), + adaptor.index()); + return success(); + } +}; +} // namespace + namespace { class BufferizeExtractOp : public OpConversionPattern { public: @@ -139,8 +154,9 @@ void mlir::populateTensorBufferizePatterns( BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add(typeConverter, patterns.getContext()); + 
patterns.add( + typeConverter, patterns.getContext()); } namespace { diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -1,16 +1,5 @@ // RUN: mlir-opt %s -std-bufferize | FileCheck %s -// CHECK-LABEL: func @dim( -// CHECK-SAME: %[[TENSOR:.*]]: tensor, -// CHECK-SAME: %[[INDEX:.*]]: index) -> index { -// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref -// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref -// CHECK: return %[[EXTENT]] : index -func @dim(%arg0: tensor, %arg1: index) -> index { - %0 = tensor.dim %arg0, %arg1 : tensor - return %0 : index -} - // CHECK-LABEL: func @select( // CHECK-SAME: %[[PRED:.*]]: i1, // CHECK-SAME: %[[TRUE_VAL:.*]]: tensor, diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -1,5 +1,16 @@ // RUN: mlir-opt %s -tensor-bufferize | FileCheck %s +// CHECK-LABEL: func @dim( +// CHECK-SAME: %[[TENSOR:.*]]: tensor, +// CHECK-SAME: %[[INDEX:.*]]: index) -> index { +// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref +// CHECK: %[[EXTENT:.*]] = memref.dim %[[MEMREF]], %[[INDEX]] : memref +// CHECK: return %[[EXTENT]] : index +func @dim(%arg0: tensor, %arg1: index) -> index { + %0 = tensor.dim %arg0, %arg1 : tensor + return %0 : index +} + // CHECK-LABEL: func @tensor.cast( // CHECK-SAME: %[[TENSOR:.*]]: tensor) -> tensor<2xindex> { // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] @@ -67,7 +78,8 @@ // CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C1:.*]] = constant 1 : index // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { -// CHECK: %[[ELEM:.*]] = tensor.dim %[[ARG]], %[[I]] : tensor<*xf32> +// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32> 
+// CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32> // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref // CHECK: scf.yield // CHECK: }