diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -80,6 +80,14 @@
 void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
                                      RewritePatternSet &patterns);
 
+/// Create linalg op on buffers given the original tensor-based operation and
+/// the buffers for the outputs. The new op has no results and `linalgOp` is
+/// left in place: the caller is responsible for replacing the results of
+/// `linalgOp` (typically with `outputs`).
+LinalgOp createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
+                                 LinalgOp linalgOp, ValueRange inputs,
+                                 ValueRange outputs);
+
 /// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
 /// tensors.
 void populateFoldUnitExtentDimsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -73,56 +73,49 @@
   return success();
 }
 
-/// Specialization for `linalg::GenericOp`.
-/// A pattern to convert Generic Linalg operations which work on tensors to
-/// use buffers. BufferPlacement pass should be later used to move
-/// Alloc operations to the correct positions and insert the missing Dealloc
-/// operations in the correct places.
-static void
-finalizeBufferAllocationForGenericOp(ConversionPatternRewriter &rewriter,
-                                     GenericOp genericOp, ValueRange inputs,
-                                     ValueRange outputs) {
-  // Generate a new linalg operation that works on buffers.
-  auto newGenericOp = rewriter.create<GenericOp>(
-      genericOp.getLoc(),
-      /*resultTensorTypes=*/llvm::None,
-      /*inputs=*/inputs,
-      /*outputs=*/outputs, genericOp.indexing_maps(),
-      genericOp.iterator_types(), genericOp.docAttr(),
-      genericOp.library_callAttr());
-
-  // Create a new block in the region of the new Generic Op.
-  Block *oldBlock = genericOp.getBody();
-  Region &newRegion = newGenericOp.region();
-  Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
-                                         oldBlock->getArgumentTypes());
-
-  // Clone the body of the old block to the new block.
-  BlockAndValueMapping mapping;
-  mapping.map(oldBlock->getArguments(), newBlock->getArguments());
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToEnd(newBlock);
-  for (auto &op : oldBlock->getOperations()) {
-    Operation *clonedOp = rewriter.clone(op, mapping);
-    mapping.map(op.getResults(), clonedOp->getResults());
+/// Create linalg op on buffers given the original tensor-based operation and
+/// the buffers for the outputs. The new op has no results and `linalgOp` is
+/// left in place: the caller is responsible for replacing the results of
+/// `linalgOp` (typically with `outputs`).
+LinalgOp
+mlir::linalg::createLinalgOpOnBuffers(ConversionPatternRewriter &rewriter,
+                                      LinalgOp linalgOp, ValueRange inputs,
+                                      ValueRange outputs) {
+  if (auto genericOp = mlir::dyn_cast<GenericOp>(*linalgOp)) {
+    // Generate a new linalg operation that works on buffers.
+    auto newGenericOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(),
+        /*resultTensorTypes=*/llvm::None,
+        /*inputs=*/inputs,
+        /*outputs=*/outputs, genericOp.indexing_maps(),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+
+    // Save the caller's insertion point before `createBlock` moves it into
+    // the new block, so it is restored (not left in the new body) on return.
+    OpBuilder::InsertionGuard guard(rewriter);
+
+    // Create a new block in the region of the new Generic Op.
+    Block *oldBlock = genericOp.getBody();
+    Region &newRegion = newGenericOp.region();
+    Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
+                                           oldBlock->getArgumentTypes());
+
+    // Clone the body of the old block to the new block.
+    BlockAndValueMapping mapping;
+    mapping.map(oldBlock->getArguments(), newBlock->getArguments());
+
+    rewriter.setInsertionPointToEnd(newBlock);
+    for (auto &op : oldBlock->getOperations()) {
+      Operation *clonedOp = rewriter.clone(op, mapping);
+      mapping.map(op.getResults(), clonedOp->getResults());
+    }
+    return newGenericOp;
   }
-
-  // Replace the results of the old op with the new output buffers.
-  rewriter.replaceOp(genericOp, outputs);
-}
-
-/// Specialization for all other `linalg::LinalgOp`.
-static void finalizeBufferAllocation(ConversionPatternRewriter &rewriter,
-                                     linalg::LinalgOp linalgOp,
-                                     ValueRange inputs, ValueRange outputs) {
-  assert(!isa<linalg::GenericOp>(linalgOp.getOperation()));
   SmallVector<Value, 8> newOperands = inputs;
   newOperands.append(outputs.begin(), outputs.end());
-  linalgOp.clone(rewriter, linalgOp.getLoc(),
-                 /*resultTypes=*/ArrayRef<Type>{}, newOperands);
-  // Replace the results of the old op with the new output buffers.
-  rewriter.replaceOp(linalgOp, outputs);
+  return linalgOp.clone(rewriter, linalgOp.getLoc(),
+                        /*resultTypes=*/ArrayRef<Type>{}, newOperands);
 }
 
 //===----------------------------------------------------------------------===//
@@ -218,15 +211,9 @@
       return op.emitOpError()
              << "Failed to allocate buffers for tensor results.";
     }
-
-    // Delegate to the linalg generic pattern.
-    if (auto genericOp = dyn_cast<GenericOp>(*op)) {
-      finalizeBufferAllocationForGenericOp(rewriter, genericOp,
-                                           adaptor.inputs(), newOutputBuffers);
-      return success();
-    }
-
-    finalizeBufferAllocation(rewriter, op, adaptor.inputs(), newOutputBuffers);
+    createLinalgOpOnBuffers(rewriter, op, adaptor.inputs(), newOutputBuffers);
+    // Replace the results of the old op with the new output buffers.
+    rewriter.replaceOp(op, newOutputBuffers);
     return success();
   }
 };