diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -148,10 +148,23 @@
 ///
 /// This function should be called by all bufferization passes using
 /// BufferizeTypeConverter so that materializations work properly. One exception
-/// is bufferization passes doing "full" conversions, where it can be desirable
-/// for even the materializations to remain illegal so that they are eliminated.
+/// is "finalizing" bufferization passes (such as those doing "full"
+/// conversions), where it can be desirable for even the materializations to
+/// remain illegal so that they are eliminated. For "finalizing" passes,
+/// populateBufferizeEliminateMaterializationsPatternsAndLegality is
+/// recommended.
 void populateBufferizeMaterializationLegality(ConversionTarget &target);
 
+/// Populate patterns and set up legality for eliminating materialization ops.
+///
+/// Some bufferization passes are "finalizing". They expect that all tensors
+/// are converted to memrefs, including function arguments, basic block
+/// arguments, etc. In these cases, all materializations that may have been
+/// inserted for structural legality of the IR need to be eliminated.
+void populateBufferizeEliminateMaterializationsPatternsAndLegality(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target);
+
 /// Helper conversion pattern that encapsulates a BufferizeTypeConverter
 /// instance.
 template <typename SourceOp>
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -76,6 +76,46 @@
   target.addLegalOp<TensorLoadOp, TensorToMemrefOp>();
 };
 
+//===----------------------------------------------------------------------===//
+// populateBufferizeEliminateMaterializationsPatternsAndLegality
+//===----------------------------------------------------------------------===//
+
+namespace {
+class BufferizeTensorLoadOp : public OpConversionPattern<TensorLoadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorLoadOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorLoadOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.memref());
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class BufferizeTensorToMemrefOp : public OpConversionPattern<TensorToMemrefOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(TensorToMemrefOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorToMemrefOp::Adaptor adaptor(operands);
+    rewriter.replaceOp(op, adaptor.tensor());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::populateBufferizeEliminateMaterializationsPatternsAndLegality(
+    MLIRContext *context, BufferizeTypeConverter &typeConverter,
+    OwningRewritePatternList &patterns, ConversionTarget &target) {
+  patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
+      typeConverter, context);
+  target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizeFuncOpConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
--- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp
+++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp
@@ -14,6 +14,7 @@
"TestDialect.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/IR/Function.h" #include "mlir/IR/Operation.h" #include "mlir/Pass/Pass.h" @@ -37,124 +38,6 @@ TestBufferPlacementPreparationPass, OperationPass> { - /// Converts tensor-type generic linalg operations to memref ones using - /// bufferize. - /// TODO: Avoid the copy-pasta by exposing the pattern from BufferPlacement.h - /// This is limited by not wanting BufferPlacement to depend on Linalg. Fixing - /// this probably requires an OpConversionPattern over generic Operation*. For - /// now only RewritePattern but not ConversionPattern allow this. - - class GenericOpConverter - : public BufferizeOpConversionPattern { - public: - using BufferizeOpConversionPattern< - linalg::GenericOp>::BufferizeOpConversionPattern; - - LogicalResult - matchAndRewrite(linalg::GenericOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - linalg::GenericOpAdaptor adaptor(operands, - op.getOperation()->getAttrDictionary()); - - // All inputs need to be turned into buffers first. Until then, bail out. - if (llvm::any_of(adaptor.inputs(), [](Value in) { - return !in.getType().isa(); - })) - return failure(); - - // All init_tensors need to be turned into buffers first. Until then, bail - // out. - if (llvm::any_of(adaptor.init_tensors(), [](Value in) { - return !in.getType().isa(); - })) - return failure(); - - Location loc = op.getLoc(); - SmallVector newOutputBuffers; - newOutputBuffers.reserve(op.getNumOutputs()); - newOutputBuffers.append(adaptor.output_buffers().begin(), - adaptor.output_buffers().end()); - - // Update all types to memref types. - // Assume the init tensors fold onto the first results. - // TODO: update this assumption because the reality is more complex under - // linalg on tensor based transformations. - for (auto en : llvm::enumerate(op.getResultTypes())) { - auto type = en.value().cast(); - if (!type.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "dynamic shapes not currently supported"); - auto memrefType = - MemRefType::get(type.getShape(), type.getElementType()); - bool foldedInitTensor = en.index() < op.getNumInitTensors(); - if (foldedInitTensor) { - // Dealing with an init tensor requires distinguishing between 1-use - // and many-use cases which would create aliasing and WAR hazards. - Value initTensor = op.getInitTensor(en.index()); - Value initBuffer = adaptor.init_tensors()[en.index()]; - if (initTensor.hasOneUse()) { - newOutputBuffers.push_back(initBuffer); - continue; - } - auto alloc = rewriter.create(loc, memrefType); - rewriter.create(loc, initBuffer, alloc); - newOutputBuffers.push_back(alloc); - } else { - auto alloc = rewriter.create(loc, memrefType); - newOutputBuffers.push_back(alloc); - } - } - - // Generate a new linalg operation that works on buffers. - auto linalgOp = rewriter.create( - loc, - /*resultTensorTypes=*/ArrayRef{}, - /*inputs=*/adaptor.inputs(), - /*outputBuffers=*/newOutputBuffers, - /*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(), - op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr()); - - // Create a new block in the region of the new Generic Op. 
-      Block &oldBlock = op.getRegion().front();
-      Region &newRegion = linalgOp.region();
-      Block *newBlock = rewriter.createBlock(&newRegion, newRegion.begin(),
-                                             oldBlock.getArgumentTypes());
-
-      // Add the result arguments that do not come from init_tensors to the new
-      // block.
-      // TODO: update this assumption because the reality is more complex under
-      // linalg on tensor based transformations.
-      for (Value v : ValueRange(newOutputBuffers)
-                         .drop_front(adaptor.init_tensors().size()))
-        newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());
-
-      // Clone the body of the old block to the new block.
-      BlockAndValueMapping mapping;
-      for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
-        mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));
-
-      OpBuilder::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToEnd(newBlock);
-      for (auto &op : oldBlock.getOperations()) {
-        Operation *clonedOp = rewriter.clone(op, mapping);
-        mapping.map(op.getResults(), clonedOp->getResults());
-      }
-
-      // Replace the results of the old op with the new output buffers.
-      rewriter.replaceOp(op, newOutputBuffers);
-      return success();
-    }
-  };
-
-  void populateTensorLinalgToBufferLinalgConversionPattern(
-      MLIRContext *context, BufferizeTypeConverter &converter,
-      OwningRewritePatternList &patterns) {
-    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
-                                              linalg::CopyOp>(
-        context, converter, patterns);
-    patterns.insert<GenericOpConverter>(context, converter);
-  }
-
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<TestDialect>();
     registry.insert<linalg::LinalgDialect>();
@@ -230,8 +113,12 @@
     });
 
     OwningRewritePatternList patterns;
-    populateTensorLinalgToBufferLinalgConversionPattern(&context, converter,
-                                                        patterns);
+    populateWithBufferizeOpConversionPatterns<mlir::ReturnOp, mlir::ReturnOp,
+                                              linalg::CopyOp>(
+        &context, converter, patterns);
+    linalg::populateLinalgBufferizePatterns(&context, converter, patterns);
+    populateBufferizeEliminateMaterializationsPatternsAndLegality(
+        &context, converter, patterns, target);
     if (failed(applyFullConversion(this->getOperation(), target, patterns)))
       this->signalPassFailure();
   };
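Note (not part of the patch): a minimal sketch of how a downstream "finalizing" bufferization pass might wire the new helper into a full conversion, mirroring the TestBufferPlacement changes above. The pass name `FinalizingBufferizePass` and the surrounding pass boilerplate are illustrative assumptions, not APIs added by this patch; a real pass would also configure legality for the ops it expects to remain, as TestBufferPlacement.cpp does before this point.

```cpp
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

namespace {
/// Illustrative "finalizing" pass: it assumes earlier partial bufferization
/// passes already rewrote the ops themselves and left tensor_load /
/// tensor_to_memref materializations at the seams.
struct FinalizingBufferizePass
    : public PassWrapper<FinalizingBufferizePass, FunctionPass> {
  void runOnFunction() override {
    MLIRContext *context = &getContext();
    BufferizeTypeConverter typeConverter;
    ConversionTarget target(*context);
    OwningRewritePatternList patterns;

    // Configure legality for the remaining ops here (e.g. mark the memref
    // forms legal), as TestBufferPlacement.cpp does above.

    // Unlike partial bufferization passes, do *not* call
    // populateBufferizeMaterializationLegality: the materializations must
    // stay illegal so the patterns registered below fold them away.
    populateBufferizeEliminateMaterializationsPatternsAndLegality(
        context, typeConverter, patterns, target);

    if (failed(applyFullConversion(getFunction(), target, patterns)))
      signalPassFailure();
  }
};
} // namespace
```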