diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -23,6 +23,7 @@
 namespace linalg {
 
+struct LinalgBufferizeOptions;
 struct LinalgFusionOptions;
 struct LinalgTilingOptions;
 
@@ -50,10 +51,25 @@
     MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
     ArrayRef<int64_t> tileSizes);
 
+/// Callback function type used to perform the allocation for bufferization.
+using BufferizeAllocCallbackFn = std::function<Value(
+    OpBuilder &b, Location loc, MemRefType type, ValueRange allocOperands)>;
+
+Value defaultBufferizationAllocFn(OpBuilder &b, Location loc, MemRefType type,
+                                  ValueRange allocOperands);
+
 /// Populates the given list with patterns to bufferize linalg ops.
-void populateLinalgBufferizePatterns(MLIRContext *context,
-                                     BufferizeTypeConverter &converter,
-                                     OwningRewritePatternList &patterns);
+struct LinalgBufferizeOptions {
+  BufferizeAllocCallbackFn allocationFn = defaultBufferizationAllocFn;
+  LinalgBufferizeOptions &setAllocFn(BufferizeAllocCallbackFn const &allocFn) {
+    allocationFn = allocFn;
+    return *this;
+  }
+};
+void populateLinalgBufferizePatterns(
+    MLIRContext *context, TypeConverter &converter,
+    OwningRewritePatternList &patterns,
+    LinalgBufferizeOptions options = LinalgBufferizeOptions());
 
 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
@@ -154,6 +170,7 @@
 /// smallest constant value for the size of the buffer needed for each
 /// dimension. If that is not possible, contains the dynamic size of the
 /// subview. The call back should return the buffer to use.
+// TODO: unify API with
 using AllocBufferCallbackFn = std::function<Optional<Value>(
     OpBuilder &b, SubViewOp subView, ArrayRef<Value> boundingSubViewSize,
     OperationFolder *folder)>;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -21,8 +21,8 @@
 using namespace ::mlir;
 using namespace ::mlir::linalg;
 
-static SmallVector<Range, 4> computeLoopRanges(Location loc, LinalgOp linalgOp,
-                                               OpBuilder &b) {
+static SmallVector<Range, 4> computeLoopRanges(OpBuilder &b, Location loc,
+                                               LinalgOp linalgOp) {
   auto indexingMaps = llvm::to_vector<4>(
       linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
   auto inputIndexingMaps =
@@ -40,7 +40,8 @@
   return b.create<IndexCastOp>(loc, val, b.getIndexType());
 }
 
-static Value cloneMemref(Location loc, Value memref, OpBuilder &b) {
+static Value cloneMemref(OpBuilder &b, Location loc, Value memref,
+                         LinalgBufferizeOptions options) {
   auto memrefType = memref.getType().cast<MemRefType>();
   SmallVector<Value, 4> dynOperands;
   for (auto dim : llvm::enumerate(memrefType.getShape())) {
@@ -48,7 +49,7 @@
       dynOperands.push_back(b.create<DimOp>(loc, memref, dim.index()));
     }
   }
-  auto alloc = b.create<AllocOp>(loc, memrefType, dynOperands);
+  Value alloc = options.allocationFn(b, loc, memrefType, dynOperands);
   b.create<linalg::CopyOp>(loc, memref, alloc);
   return alloc;
 }
@@ -56,7 +57,8 @@
 
 static LogicalResult
 allocateBuffersForResults(Location loc, LinalgOp linalgOp,
                           linalg::GenericOpAdaptor &adaptor,
-                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b) {
+                          SmallVectorImpl<Value> &resultBuffers, OpBuilder &b,
+                          LinalgBufferizeOptions options) {
   // Lazily compute loopRanges.
   SmallVector<Range, 4> loopRanges;
@@ -72,22 +74,24 @@
       return failure();
     }
     auto tensorShape = tensorType.getShape();
-    auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
+    auto memrefType =
+        MemRefType::get(tensorShape, tensorType.getElementType(), {});
 
     // Allocate buffers for init tensors that are assumed to fold onto the first
    // results.
    // TODO: update this assumption because the reality is more complex
    // under linalg on tensor based transformations.
+
     bool hasInitTensor = resultIndex < linalgOp.getNumInitTensors();
     if (hasInitTensor) {
       resultBuffers.push_back(
-          cloneMemref(loc, adaptor.init_tensors()[resultIndex], b));
+          cloneMemref(b, loc, adaptor.init_tensors()[resultIndex], options));
       continue;
     }
 
     // Allocate buffers for statically-shaped results.
     if (memrefType.hasStaticShape()) {
-      resultBuffers.push_back(b.create<AllocOp>(loc, memrefType));
+      resultBuffers.push_back(options.allocationFn(b, loc, memrefType, {}));
       continue;
     }
 
@@ -97,7 +101,7 @@
     auto resultIndexingMap = linalgOp.getOutputIndexingMap(resultIndex);
     for (auto shapeElement : llvm::enumerate(tensorType.getShape())) {
       if (loopRanges.empty())
-        loopRanges = computeLoopRanges(loc, linalgOp, b);
+        loopRanges = computeLoopRanges(b, loc, linalgOp);
       if (shapeElement.value() != ShapedType::kDynamicSize)
         continue;
 
@@ -114,7 +118,8 @@
         return failure();
      }
    }
-    resultBuffers.push_back(b.create<AllocOp>(loc, memrefType, dynOperands));
+    resultBuffers.push_back(
+        options.allocationFn(b, loc, memrefType, dynOperands));
  }
  return success();
 }
@@ -194,8 +199,10 @@
 /// instantiating one pattern for each LinalgOp.
 class BufferizeAnyLinalgOp : public ConversionPattern {
 public:
-  BufferizeAnyLinalgOp(TypeConverter &typeConverter)
-      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
+  BufferizeAnyLinalgOp(TypeConverter &typeConverter,
+                       LinalgBufferizeOptions options)
+      : ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()),
+        options(options) {}
 
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
@@ -214,8 +221,8 @@
     SmallVector<Value, 2> newOutputBuffers(adaptor.output_buffers().begin(),
                                            adaptor.output_buffers().end());
 
-    if (failed(allocateBuffersForResults(loc, linalgOp, adaptor,
-                                         newOutputBuffers, rewriter))) {
+    if (failed(allocateBuffersForResults(
+            loc, linalgOp, adaptor, newOutputBuffers, rewriter, options))) {
       linalgOp.emitOpError()
           << "Failed to allocate buffers for tensor results.";
       return failure();
@@ -232,6 +239,8 @@
                              newOutputBuffers);
     return success();
   }
+
+  LinalgBufferizeOptions options;
 };
 
 // Extract int64_t values from the assumed ArrayAttr of IntegerAttr.
@@ -255,6 +264,11 @@
 class SubTensorOpConverter : public OpConversionPattern<SubTensorOp> {
 public:
   using OpConversionPattern<SubTensorOp>::OpConversionPattern;
+  SubTensorOpConverter(TypeConverter &typeConverter,
+                       LinalgBufferizeOptions options, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit), options(options) {
+  }
 
   LogicalResult
   matchAndRewrite(SubTensorOp op, ArrayRef<Value> operands,
@@ -269,8 +283,8 @@
     // op.sizes() capture exactly the dynamic alloc operands matching the
     // subviewMemRefType thanks to subview/subtensor canonicalization and
     // verification.
-    Value alloc =
-        rewriter.create<AllocOp>(op.getLoc(), subviewMemRefType, op.sizes());
+    Value alloc = options.allocationFn(rewriter, op.getLoc(), subviewMemRefType,
+                                       op.sizes());
     Value subView = rewriter.create<SubViewOp>(
         op.getLoc(), sourceMemref, extractFromI64ArrayAttr(op.static_offsets()),
         extractFromI64ArrayAttr(op.static_sizes()),
@@ -280,11 +294,13 @@
     rewriter.replaceOp(op, alloc);
     return success();
   }
+
+  LinalgBufferizeOptions options;
 };
 
 /// Convert `subtensor_insert %source into %dest [offsets][sizes][strides] ->
 /// %t` to an tensor_to_memref + subview + copy + tensor_load pattern.
-/// tensor_to_memref and tensor_load are inserted automatically by the 
+/// tensor_to_memref and tensor_load are inserted automatically by the
 /// conversion infra:
 /// ```
 ///    %sv = subview %dest [offsets][sizes][strides]
@@ -298,6 +314,11 @@
     : public OpConversionPattern<SubTensorInsertOp> {
 public:
   using OpConversionPattern<SubTensorInsertOp>::OpConversionPattern;
+  SubTensorInsertOpConverter(TypeConverter &typeConverter,
+                             LinalgBufferizeOptions options,
+                             MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit), options(options) {
+  }
 
   LogicalResult
   matchAndRewrite(SubTensorInsertOp op, ArrayRef<Value> operands,
@@ -310,7 +331,8 @@
     // For now, be conservative and copy the converted input memref.
     // In general, the converted input memref here could be aliased or could
    // point into constant memory, so mutating it would lead to miscompilations.
-    Value destMemRef = cloneMemref(op.getLoc(), adaptor.dest(), rewriter);
+    Value destMemRef =
+        cloneMemref(rewriter, op.getLoc(), adaptor.dest(), options);
     assert(destMemRef.getType().isa<MemRefType>());
 
     // Take a subview to copy the small memref.
@@ -324,10 +346,10 @@
     rewriter.replaceOp(op, destMemRef);
     return success();
   }
+
+  LinalgBufferizeOptions options;
 };
-} // namespace
 
-namespace {
 /// Converts Linalg operations that work on tensor-type operands or results to
 /// work on buffers.
 struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
@@ -356,19 +378,25 @@
 };
 } // end anonymous namespace
 
+Value mlir::linalg::defaultBufferizationAllocFn(OpBuilder &b, Location loc,
+                                                MemRefType type,
+                                                ValueRange allocOperands) {
+  return b.create<AllocOp>(loc, type, allocOperands);
+}
+
 std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
   return std::make_unique<LinalgBufferizePass>();
 }
 
 void mlir::linalg::populateLinalgBufferizePatterns(
-    MLIRContext *context, BufferizeTypeConverter &typeConverter,
-    OwningRewritePatternList &patterns) {
-  patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
+    MLIRContext *context, TypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, LinalgBufferizeOptions options) {
+  patterns.insert<BufferizeAnyLinalgOp>(typeConverter, options);
   // TODO: Drop this once tensor constants work in standard.
   patterns.insert<
      // clang-format off
      SubTensorOpConverter,
      SubTensorInsertOpConverter
      // clang-format on
-      >(typeConverter, context);
+      >(typeConverter, options, context);
 }
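For context, a minimal usage sketch of the new hook added by this patch: a client builds a `LinalgBufferizeOptions`, overrides the allocation callback via `setAllocFn`, and passes the options to `populateLinalgBufferizePatterns`. The helper name `populateWithCustomAllocs` and the choice to simply emit a plain `std.alloc` inside the callback are illustrative assumptions, not part of the patch.

```cpp
// Illustrative sketch only (not part of the patch).
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical helper: wire a custom allocation callback into the linalg
// bufferization patterns. A real client could allocate from a pool or emit
// an aligned allocation inside the callback instead.
static void populateWithCustomAllocs(MLIRContext *context,
                                     TypeConverter &typeConverter,
                                     OwningRewritePatternList &patterns) {
  LinalgBufferizeOptions options;
  options.setAllocFn([](OpBuilder &b, Location loc, MemRefType type,
                        ValueRange allocOperands) -> Value {
    // Mirrors defaultBufferizationAllocFn: create a plain std.alloc of the
    // requested type with the given dynamic-size operands.
    return b.create<AllocOp>(loc, type, allocOperands);
  });
  populateLinalgBufferizePatterns(context, typeConverter, patterns, options);
}
```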