diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -39,10 +39,18 @@ return rewriter.notifyMatchFailure( op, "UnrankedMemRefType is not supported."); } + MemRefType memrefType = type.cast(); + MemRefLayoutAttrInterface layout; + auto allocType = + MemRefType::get(memrefType.getShape(), memrefType.getElementType(), + layout, memrefType.getMemorySpace()); + // Since this implementation always allocates, certain result types of the + // clone op cannot be lowered. + if (!memref::CastOp::areCastCompatible({allocType}, {memrefType})) + return failure(); // Transform a clone operation into alloc + copy operation and pay // attention to the shape dimensions. - MemRefType memrefType = type.cast(); Location loc = op->getLoc(); SmallVector dynamicOperands; for (int i = 0; i < memrefType.getRank(); ++i) { @@ -52,8 +60,14 @@ Value dim = rewriter.createOrFold(loc, op.input(), size); dynamicOperands.push_back(dim); } - Value alloc = rewriter.replaceOpWithNewOp(op, memrefType, - dynamicOperands); + + // Allocate a memref with identity layout. + Value alloc = rewriter.create(op->getLoc(), allocType, + dynamicOperands); + // Cast the allocation to the specified type if needed. + if (memrefType != allocType) + alloc = rewriter.create(op->getLoc(), memrefType, alloc); + rewriter.replaceOp(op, alloc); rewriter.create(loc, op.input(), alloc); return success(); } diff --git a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir --- a/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir +++ b/mlir/test/Conversion/BufferizationToMemRef/bufferization-to-memref.mlir @@ -2,9 +2,9 @@ // CHECK-LABEL: @conversion_static func @conversion_static(%arg0 : memref<2xf32>) -> memref<2xf32> { - %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32> - memref.dealloc %arg0 : memref<2xf32> - return %0 : memref<2xf32> + %0 = bufferization.clone %arg0 : memref<2xf32> to memref<2xf32> + memref.dealloc %arg0 : memref<2xf32> + return %0 : memref<2xf32> } // CHECK: %[[ALLOC:.*]] = memref.alloc @@ -16,9 +16,9 @@ // CHECK-LABEL: @conversion_dynamic func @conversion_dynamic(%arg0 : memref) -> memref { - %1 = bufferization.clone %arg0 : memref to memref - memref.dealloc %arg0 : memref - return %1 : memref + %1 = bufferization.clone %arg0 : memref to memref + memref.dealloc %arg0 : memref + return %1 : memref } // CHECK: %[[CONST:.*]] = arith.constant @@ -32,7 +32,40 @@ func @conversion_unknown(%arg0 : memref<*xf32>) -> memref<*xf32> { // expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}} - %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32> - memref.dealloc %arg0 : memref<*xf32> - return %1 : memref<*xf32> + %1 = bufferization.clone %arg0 : memref<*xf32> to memref<*xf32> + memref.dealloc %arg0 : memref<*xf32> + return %1 : memref<*xf32> +} + +// ----- + +// CHECK: #[[$MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +#map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> +// CHECK-LABEL: func @conversion_with_layout_map( +// CHECK-SAME: %[[ARG:.*]]: memref +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[DIM:.*]] = memref.dim %[[ARG]], %[[C0]] +// CHECK: %[[ALLOC:.*]] = memref.alloc(%[[DIM]]) : memref +// CHECK: %[[CASTED:.*]] = memref.cast %[[ALLOC]] : memref to memref +// CHECK: memref.copy +// CHECK: memref.dealloc +// CHECK: return %[[CASTED]] +func @conversion_with_layout_map(%arg0 : memref) -> memref { + %1 = bufferization.clone %arg0 : memref to memref + memref.dealloc %arg0 : memref + return %1 : memref +} + +// ----- + +// This bufferization.clone cannot be lowered because a buffer with this layout +// map cannot be allocated (or casted to). + +#map2 = affine_map<(d0)[s0] -> (d0 * 10 + s0)> +func @conversion_with_invalid_layout_map(%arg0 : memref) + -> memref { +// expected-error@+1 {{failed to legalize operation 'bufferization.clone' that was explicitly marked illegal}} + %1 = bufferization.clone %arg0 : memref to memref + memref.dealloc %arg0 : memref + return %1 : memref }