diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -36,8 +36,8 @@ /// Options for ComprehensiveBufferize. struct BufferizationOptions { - using AllocationFn = std::function( - OpBuilder &, Location, MemRefType, ArrayRef)>; + using AllocationFn = std::function(OpBuilder &, Location, + MemRefType, ValueRange)>; using DeallocationFn = std::function; using MemCpyFn = @@ -298,15 +298,23 @@ MemRefType getDynamicMemRefType(RankedTensorType tensorType, unsigned addressSpace = 0); -/// Creates a memref allocation. +/// Creates a memref allocation with the given type and dynamic extents. FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape, + ValueRange dynShape, + const BufferizationOptions &options); + +/// Creates a memref allocation with the given type and dynamic extents. If +/// `deallocMemref`, a deallocation op is also inserted right before the +/// terminator of the block that contains the allocation. +FailureOr createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, bool deallocMemref, const BufferizationOptions &options); /// Creates a memref allocation for the given shaped value. This function may /// perform additional optimizations such as buffer allocation hoisting. If /// `createDealloc`, a deallocation op is inserted at the point where the /// allocation goes out of scope. +// TODO: Allocation hoisting should be a cleanup pass.
FailureOr createAlloc(OpBuilder &b, Location loc, Value shapedValue, bool deallocMemref, const BufferizationOptions &options); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -433,10 +433,10 @@ return casted; } -/// Create a memref allocation. +/// Create a memref allocation with the given type and dynamic extents. FailureOr bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape, + ValueRange dynShape, const BufferizationOptions &options) { if (options.allocationFn) return (*options.allocationFn)(b, loc, type, dynShape); @@ -447,6 +447,28 @@ return allocated; } +/// Create a memref allocation with the given type and dynamic extents. If +/// `deallocMemref`, also deallocate it right before its block's terminator. +FailureOr +bufferization::createAlloc(OpBuilder &b, Location loc, MemRefType type, + ValueRange dynShape, bool deallocMemref, + const BufferizationOptions &options) { + OpBuilder::InsertionGuard g(b); + + FailureOr alloc = createAlloc(b, loc, type, dynShape, options); + if (failed(alloc)) + return failure(); + + if (deallocMemref) { + // Dealloc right before the terminator of the block holding the alloc. + b.setInsertionPoint(alloc.getValue().getParentBlock()->getTerminator()); + if (failed(createDealloc(b, loc, *alloc, options))) + return failure(); + } + + return alloc; +} + /// Create a memref deallocation.
LogicalResult bufferization::createDealloc(OpBuilder &b, Location loc, Value allocatedBuffer, diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -64,6 +64,7 @@ MLIRBufferizableOpInterface MLIRIR MLIRMemRef + MLIRSCF MLIRTensor ) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" @@ -231,6 +232,65 @@ } }; +/// Bufferization of tensor.generate. +struct GenerateOpInterface + : public BufferizableOpInterface::ExternalModel { + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationState &state) const { + auto generateOp = cast(op); + + // Allocate the result buffer (deallocated if `createDeallocs` is set). + Location loc = op->getLoc(); + MemRefType memrefType = + getContiguousMemRefType(generateOp.getType().cast()); + FailureOr maybeResult = + createAlloc(rewriter, loc, memrefType, generateOp.dynamicExtents(), + /*deallocMemref=*/state.getOptions().createDeallocs, + state.getOptions()); + if (failed(maybeResult)) + return failure(); + Value result = *maybeResult; + + // Collect loop bounds: [0, dim size) with step 1 in every dimension.
+ int64_t rank = memrefType.getRank(); + Value zero = rewriter.create(loc, 0); + Value one = rewriter.create(loc, 1); + SmallVector lowerBounds(rank, zero); + SmallVector steps(rank, one); + SmallVector upperBounds; + int nextDynamicIndex = 0; + for (int i = 0; i < rank; i++) { + Value upperBound = memrefType.isDynamicDim(i) + ? generateOp.dynamicExtents()[nextDynamicIndex++] + : rewriter.create( + loc, memrefType.getDimSize(i)); + upperBounds.push_back(upperBound); + } + + // Generate tensor elements with a parallel loop that stores into + // each element of the resulting memref. We use mergeBlockBefore to "move" + // this op's body into the scf.parallel's body. + auto parallel = + rewriter.create(loc, lowerBounds, upperBounds, steps); + Block *parallelBody = parallel.getBody(); + rewriter.mergeBlockBefore(generateOp.getBody(), + parallelBody->getTerminator(), + parallelBody->getArguments()); + // Replace the inlined yield op with a store op. The scf.parallel's builder + // already populated an scf.yield at the end, so we don't need to worry + // about creating that. + Operation *elementYield = parallelBody->getTerminator()->getPrevNode(); + rewriter.setInsertionPointAfter(elementYield); + rewriter.replaceOpWithNewOp( + elementYield, elementYield->getOperands()[0], result, + parallelBody->getArguments()); + + replaceOpWithBufferizedValues(rewriter, op, result); + return success(); + } +}; + /// Bufferization of tensor.insert. Replace with memref.store.
struct InsertOpInterface : public BufferizableOpInterface::ExternalModel(); registry.addOpInterface(); + registry + .addOpInterface(); registry.addOpInterface(); registry.addOpInterface(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -73,7 +73,7 @@ static FailureOr allocationFnUsingAlloca(OpBuilder &b, Location loc, MemRefType type, - ArrayRef dynShape) { + ValueRange dynShape) { Value allocated = b.create( loc, type, dynShape, b.getI64IntegerAttr(kBufferAlignments)); return allocated; diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1348,3 +1348,23 @@ // CHECK: return %[[f]], %[[select]] return %f, %w : f32, tensor } + +// ----- + +// CHECK-LABEL: func @tensor_generate_static_and_dynamic( +// CHECK-SAME: %[[arg0:.*]]: index +func @tensor_generate_static_and_dynamic(%arg0: index) -> tensor<16x?xindex> { + // CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[c16:.*]] = arith.constant 16 : index + // CHECK: %[[alloc:.*]] = memref.alloc(%[[arg0]]) {{.*}} : memref<16x?xindex> + // CHECK: scf.parallel (%[[arg1:.*]], %[[arg2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c16]], %[[arg0]]) {{.*}} { + %result = tensor.generate %arg0 { + ^bb0(%i: index, %j: index): + %sum = arith.addi %i, %j : index + // CHECK: memref.store {{.*}}, %[[alloc]][%[[arg1]], %[[arg2]]] + // CHECK: scf.yield + tensor.yield %sum : index + } : tensor<16x?xindex> + // CHECK: } + return %result : tensor<16x?xindex> +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6697,6 +6697,7 @@ ":BufferizableOpInterface", ":IR", ":MemRefDialect", + ":SCFDialect", ":Support", ":TensorDialect", "//llvm:Support",