diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h @@ -20,6 +20,18 @@ #include "mlir/Dialect/Bufferization/IR/BufferizationOpsDialect.h.inc" +namespace mlir { +class RewritePatternSet; +class MLIRContext; + +namespace bufferization { +/// Populate patterns for folding to_memref and to_tensor ops. +/// Note: to_memref(to_tensor(x)) without type changes are handled by a folder. +void populateBufferizationOpFoldingPatterns(RewritePatternSet &patterns, + MLIRContext *context); +} // namespace bufferization +} // namespace mlir + //===----------------------------------------------------------------------===// // Bufferization Dialect Operations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -240,7 +240,8 @@ if (resultType.getShape()[i] != ShapedType::kDynamicSize) continue; auto index = rewriter.createOrFold(loc, i); - Value size = rewriter.create(loc, memrefToTensor, index); + Value size = + rewriter.create(loc, memrefToTensor.memref(), index); dynamicOperands.push_back(size); } // TODO: Use alloc/memcpy callback from BufferizationOptions if called via @@ -309,6 +310,11 @@ context); } +void bufferization::populateBufferizationOpFoldingPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(context); +} + LogicalResult ToMemrefOp::bufferize(RewriterBase &rewriter, const BufferizationState &state) { // Fold to_memref(to_tensor(x)) to x. Insert a cast if necessary. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -92,6 +92,7 @@ BufferizeTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add(typeConverter, patterns.getContext()); + populateBufferizationOpFoldingPatterns(patterns, patterns.getContext()); } namespace { diff --git a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir @@ -1,11 +1,11 @@ // RUN: mlir-opt %s -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s -// CHECK-LABEL: func @eliminate_materializations( -// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: return %[[ARG]] : memref +// CHECK-LABEL: func @eliminate_materializations( +// CHECK-SAME: %[[ARG:.*]]: memref) -> memref { func @eliminate_materializations(%arg0: memref) -> memref { %0 = bufferization.to_tensor %arg0 : memref %1 = bufferization.to_memref %0 : memref + // CHECK: return %[[ARG]] : memref return %1 : memref } @@ -26,3 +26,37 @@ "test.sink"(%0) : (tensor) -> () return } + +// ----- + +// CHECK: #[[$MAP1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +#map1 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK-LABEL: func @insert_memref_cast( +// CHECK-SAME: %[[arg0:.*]]: memref +func @insert_memref_cast(%arg0: memref) -> memref { + %0 = bufferization.to_tensor %arg0 : memref + %1 = bufferization.to_memref %0 : memref + // CHECK: %[[r:.*]] = memref.cast %[[arg0]] : memref to memref + // CHECK: return %[[r]] + return %1 : memref +} + +// ----- + +// CHECK: #[[$MAP2:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +#map2 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> + +// CHECK-LABEL: func @insert_buffer_copy( +// CHECK-SAME: %[[arg0:.*]]: memref +func @insert_buffer_copy(%arg0: memref) -> memref { + // CHECK: %[[c0:.*]] = arith.constant 0 : index + // CHECK: %[[dim0:.*]] = memref.dim %[[arg0]], %[[c0]] + // CHECK: %[[alloc:.*]] = memref.alloc(%[[dim0]]) : memref + // CHECK: memref.copy %[[arg0]], %[[alloc]] + %0 = bufferization.to_tensor %arg0 : memref + %1 = bufferization.to_memref %0 : memref + + // CHECK: return %[[alloc]] + return %1 : memref +}