diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp
@@ -15,20 +15,49 @@
 
 using namespace mlir;
 
+/// Return `true` if the given MemRef type has a fully dynamic layout.
+static bool hasFullyDynamicLayoutMap(MemRefType type) {
+  int64_t offset;
+  SmallVector<int64_t> strides;
+  if (failed(getStridesAndOffset(type, strides, offset)))
+    return false;
+  if (!llvm::all_of(strides, [](int64_t stride) {
+        return ShapedType::isDynamicStrideOrOffset(stride);
+      }))
+    return false;
+  if (!ShapedType::isDynamicStrideOrOffset(offset))
+    return false;
+  return true;
+}
+
+/// Return `true` if the given MemRef type has a static identity layout (i.e.,
+/// no layout).
+static bool hasStaticIdentityLayout(MemRefType type) {
+  return type.getLayout().isIdentity();
+}
+
 // Updates the func op and entry block.
 //
 // Any args appended to the entry block are added to `appendedEntryArgs`.
-static void updateFuncOp(func::FuncOp func,
-                         SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
+static LogicalResult
+updateFuncOp(func::FuncOp func,
+             SmallVectorImpl<BlockArgument> &appendedEntryArgs) {
   auto functionType = func.getFunctionType();
 
   // Collect information about the results will become appended arguments.
   SmallVector<Type, 6> erasedResultTypes;
   BitVector erasedResultIndices(functionType.getNumResults());
   for (const auto &resultType : llvm::enumerate(functionType.getResults())) {
-    if (resultType.value().isa<BaseMemRefType>()) {
+    if (auto memrefType = resultType.value().dyn_cast<MemRefType>()) {
+      if (!hasStaticIdentityLayout(memrefType) &&
+          !hasFullyDynamicLayoutMap(memrefType))
+        // Only buffers with static identity layout can be allocated. These can
+        // be cast to memrefs with fully dynamic layout map. Other layout maps
+        // are not supported.
+        return func->emitError()
+               << "cannot create out param for result with unsupported layout";
       erasedResultIndices.set(resultType.index());
-      erasedResultTypes.push_back(resultType.value());
+      erasedResultTypes.push_back(memrefType);
     }
   }
 
@@ -51,10 +80,12 @@
 
   // Add the new arguments to the entry block if the function is not external.
   if (func.isExternal())
-    return;
+    return success();
   Location loc = func.getLoc();
   for (Type type : erasedResultTypes)
     appendedEntryArgs.push_back(func.front().addArgument(type, loc));
+
+  return success();
 }
 
 // Updates all ReturnOps in the scope of the given func::FuncOp by either
@@ -66,7 +97,7 @@
     SmallVector<Value, 6> copyIntoOutParams;
     SmallVector<Value, 6> keepAsReturnOperands;
     for (Value operand : op.getOperands()) {
-      if (operand.getType().isa<BaseMemRefType>())
+      if (operand.getType().isa<MemRefType>())
         copyIntoOutParams.push_back(operand);
       else
         keepAsReturnOperands.push_back(operand);
@@ -88,7 +119,7 @@
     SmallVector<Value, 6> replaceWithNewCallResults;
     SmallVector<Value, 6> replaceWithOutParams;
     for (OpResult result : op.getResults()) {
-      if (result.getType().isa<BaseMemRefType>())
+      if (result.getType().isa<MemRefType>())
         replaceWithOutParams.push_back(result);
       else
         replaceWithNewCallResults.push_back(result);
@@ -96,14 +127,24 @@
     SmallVector<Value, 6> outParams;
     OpBuilder builder(op);
     for (Value memref : replaceWithOutParams) {
-      if (!memref.getType().cast<BaseMemRefType>().hasStaticShape()) {
+      if (!memref.getType().cast<MemRefType>().hasStaticShape()) {
         op.emitError()
             << "cannot create out param for dynamically shaped result";
         didFail = true;
         return;
       }
-      Value outParam = builder.create<memref::AllocOp>(
-          op.getLoc(), memref.getType().cast<MemRefType>());
+      auto memrefType = memref.getType().cast<MemRefType>();
+      auto allocType =
+          MemRefType::get(memrefType.getShape(), memrefType.getElementType(),
+                          AffineMap(), memrefType.getMemorySpace());
+      Value outParam = builder.create<memref::AllocOp>(op.getLoc(), allocType);
+      if (!hasStaticIdentityLayout(memrefType)) {
+        // Layout maps are already checked in `updateFuncOp`.
+        assert(hasFullyDynamicLayoutMap(memrefType) &&
+               "layout map not supported");
+        outParam =
+            builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam);
+      }
       memref.replaceAllUsesWith(outParam);
       outParams.push_back(outParam);
     }
@@ -126,7 +167,8 @@
 mlir::bufferization::promoteBufferResultsToOutParams(ModuleOp module) {
   for (auto func : module.getOps<func::FuncOp>()) {
     SmallVector<BlockArgument, 6> appendedEntryArgs;
-    updateFuncOp(func, appendedEntryArgs);
+    if (failed(updateFuncOp(func, appendedEntryArgs)))
+      return failure();
     if (func.isExternal())
       continue;
     updateReturnOps(func, appendedEntryArgs);
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -285,19 +285,18 @@
                                    &hasLeakingAllocs)))
     return failure();
 
-  if (hasLeakingAllocs) {
-    // Promote returned buffers to "out" parameters.
-    // TODO: Pass options to support custom dealloc ops.
-    if (options.promoteBufferResultsToOutParams && isa<ModuleOp>(op) &&
-        failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
-      return failure();
-
-    // Create deallocation ops for all "leaking buffers" and all buffer
-    // allocations that were added during the above promotion process.
-    // TODO: Pass options to support custom dealloc ops.
-    if (options.createDeallocs && failed(deallocateBuffers(op)))
-      return failure();
-  }
+  // Promote returned buffers to "out" parameters.
+  // TODO: Pass options to support custom dealloc ops.
+  if (options.promoteBufferResultsToOutParams && isa<ModuleOp>(op) &&
+      failed(promoteBufferResultsToOutParams(cast<ModuleOp>(op))))
+    return failure();
+
+  // Create deallocation ops for all "leaking buffers" and all buffer
+  // allocations that were added during the above promotion process.
+  // TODO: Pass options to support custom dealloc ops.
+  if (hasLeakingAllocs && options.createDeallocs &&
+      failed(deallocateBuffers(op)))
+    return failure();
 
   // Deallocate all remaining buffers at the end of their parent blocks.
   if (failed(createAllocDeallocOps(op, options)))
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir
@@ -1,11 +1,16 @@
-// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs promote-buffer-results-to-out-params" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs promote-buffer-results-to-out-params function-boundary-type-conversion=fully-dynamic-layout-map" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs promote-buffer-results-to-out-params function-boundary-type-conversion=identity-layout-map" -split-input-file | FileCheck %s --check-prefix=CHECK-NO-LAYOUT
+// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries allow-return-allocs function-boundary-type-conversion=infer-layout-map" -split-input-file | FileCheck %s --check-prefix=CHECK-BASELINE
+
+// Note: function-boundary-type-conversion=infer-layout-map with
+// promote-buffer-results-to-out-params is an unsupported combination.
 
 // Note: This bufferization is not very efficient yet, but it works.
 
 // CHECK: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
 // CHECK-LABEL: func @callee(
 // CHECK-SAME: %[[arg0:.*]]: memref<5xf32, #[[$map1]]>,
-// CHECK-SAME: %[[arg1:.*]]: memref<5xf32>) {
+// CHECK-SAME: %[[arg1:.*]]: memref<5xf32, #[[$map1]]>) {
 // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
 // CHECK: memref.copy %[[arg0]], %[[alloc]]
 // CHECK: memref.store %{{.*}}, %[[alloc]]
@@ -13,21 +18,45 @@
 // CHECK: memref.dealloc %[[alloc]]
 // CHECK: return
 // CHECK: }
+
+// CHECK-NO-LAYOUT-LABEL: func @callee(%{{.*}}: memref<5xf32>,
+// CHECK-NO-LAYOUT-SAME: %{{.*}}: memref<5xf32>) {
+// CHECK-NO-LAYOUT: memref.alloc
+// CHECK-NO-LAYOUT: memref.copy
+// CHECK-NO-LAYOUT: memref.store
+// CHECK-NO-LAYOUT: memref.copy
+// CHECK-NO-LAYOUT: memref.dealloc
+
+// CHECK-BASELINE: #[[$map1:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+// CHECK-BASELINE-LABEL: func @callee(
+// CHECK-BASELINE-SAME: %[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> memref<5xf32> {
+// CHECK-BASELINE: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32>
+// CHECK-BASELINE: memref.copy %[[arg0]], %[[alloc]]
+// CHECK-BASELINE: memref.store {{.*}}, %[[alloc]]
+// CHECK-BASELINE: return %[[alloc]]
 func.func @callee(%t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
   %c0 = arith.constant 0 : index
   %cst = arith.constant 8.0 : f32
+  // This must bufferize out-of-place.
   %1 = tensor.insert %cst into %t[%c0] : tensor<5xf32>
+  // Instead of returning %1, copy into new out param. %t will disappear
+  // entirely because the buffer is equivalent to a bbArg.
   return %t, %1 : tensor<5xf32>, tensor<5xf32>
 }
 
 // CHECK: func @main(%[[arg0:.*]]: memref<5xf32, #[[$map1]]>) -> (f32, f32) {
 // CHECK: %[[alloc:.*]] = memref.alloc() : memref<5xf32>
-// CHECK: call @callee(%[[arg0]], %[[alloc]])
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<5xf32> to memref<5xf32, #[[$map1]]>
+// CHECK: call @callee(%[[arg0]], %[[casted]])
 // CHECK: %[[l1:.*]] = memref.load %[[arg0]]
 // CHECK: %[[l2:.*]] = memref.load %[[alloc]]
 // CHECK: memref.dealloc %[[alloc]]
 // CHECK: return %[[l1]], %[[l2]]
 // CHECK: }
+
+// CHECK-NO-LAYOUT-LABEL: func @main(%{{.*}}: memref<5xf32>) -> (f32, f32) {
+// CHECK-NO-LAYOUT: %[[alloc:.*]] = memref.alloc() : memref<5xf32>
+// CHECK-NO-LAYOUT: call @callee(%{{.*}}, %[[alloc]])
 func.func @main(%t: tensor<5xf32>) -> (f32, f32) {
   %c0 = arith.constant 0 : index
   %0, %1 = func.call @callee(%t)
@@ -37,3 +66,63 @@
   return %2, %3 : f32, f32
 }
 
+// -----
+
+// CHECK: #[[$map2a:.*]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+// CHECK: #[[$map2b:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
+// CHECK-LABEL: func @callee(
+// CHECK-SAME: %{{.*}}: index,
+// CHECK-SAME: %[[r:.*]]: memref<2x5xf32, #[[$map2a]]>) {
+// CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
+// CHECK: %[[subview:.*]] = memref.subview %[[alloc]]{{.*}} : memref<10x20xf32> to memref<2x5xf32, #[[$map2b]]>
+// CHECK: memref.copy %[[subview]], %[[r]]
+// CHECK: memref.dealloc %[[alloc]]
+
+// CHECK-NO-LAYOUT-LABEL: func @callee(
+// CHECK-NO-LAYOUT-SAME: %{{.*}}: index,
+// CHECK-NO-LAYOUT-SAME: %[[r:.*]]: memref<2x5xf32>) {
+// CHECK-NO-LAYOUT: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
+// CHECK-NO-LAYOUT: %[[subview:.*]] = memref.subview %[[alloc]]
+// Note: This alloc is not needed, but it is inserted before the returned buffer
+// is promoted to an out param to reconcile mismatching layout maps on return
+// value and function signature.
+// CHECK-NO-LAYOUT: %[[alloc2:.*]] = memref.alloc() : memref<2x5xf32>
+// CHECK-NO-LAYOUT: memref.copy %[[subview]], %[[alloc2]]
+// CHECK-NO-LAYOUT: memref.dealloc %[[alloc]]
+// CHECK-NO-LAYOUT: memref.copy %[[alloc2]], %[[r]]
+// CHECK-NO-LAYOUT: memref.dealloc %[[alloc2]]
+
+// CHECK-BASELINE: #[[$map2:.*]] = affine_map<(d0, d1)[s0] -> (d0 * 20 + s0 + d1)>
+// CHECK-BASELINE-LABEL: func @callee(
+// CHECK-BASELINE-SAME: %{{.*}}: index) -> memref<2x5xf32, #[[$map2]]> {
+// CHECK-BASELINE: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<10x20xf32>
+// CHECK-BASELINE: %[[subview:.*]] = memref.subview %[[alloc]]
+// CHECK-BASELINE: return %[[subview]]
+func.func @callee(%idx: index) -> tensor<2x5xf32> {
+  %0 = linalg.init_tensor [10, 20] : tensor<10x20xf32>
+  %1 = tensor.extract_slice %0[%idx, %idx][2, 5][1, 1] : tensor<10x20xf32> to tensor<2x5xf32>
+  return %1 : tensor<2x5xf32>
+}
+
+// CHECK: func @main(
+// CHECK: %[[alloc:.*]] = memref.alloc() : memref<2x5xf32>
+// CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<2x5xf32> to memref<2x5xf32, #[[$map2a]]>
+// CHECK: call @callee(%{{.*}}, %[[casted]])
+// CHECK: memref.load %[[alloc]]
+// CHECK: memref.dealloc %[[alloc]]
+
+// CHECK-NO-LAYOUT: func @main(
+// CHECK-NO-LAYOUT: %[[alloc:.*]] = memref.alloc() : memref<2x5xf32>
+// CHECK-NO-LAYOUT: call @callee(%{{.*}}, %[[alloc]])
+// CHECK-NO-LAYOUT: memref.load %[[alloc]]
+// CHECK-NO-LAYOUT: memref.dealloc
+
+// CHECK-BASELINE: func @main(
+// CHECK-BASELINE: %[[call:.*]] = call @callee
+// CHECK-BASELINE: memref.load %[[call]]
+func.func @main(%idx: index) -> f32 {
+  %c0 = arith.constant 0 : index
+  %0 = func.call @callee(%idx) : (index) -> (tensor<2x5xf32>)
+  %1 = tensor.extract %0[%c0, %c0] : tensor<2x5xf32>
+  return %1 : f32
+}