diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -172,16 +172,13 @@ : public OpConversionPattern { public: explicit BufferAssignmentOpConversionPattern( - MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, - BufferAssignmentTypeConverter *converter = nullptr, + MLIRContext *context, BufferAssignmentTypeConverter *converter = nullptr, PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), - bufferAssignment(bufferAssignment), converter(converter) { + : OpConversionPattern(context, benefit), converter(converter) { assert(converter && "The type converter has not been defined"); } protected: - BufferAssignmentPlacer *bufferAssignment; BufferAssignmentTypeConverter *converter; }; @@ -282,8 +279,7 @@ template static void populateWithBufferAssignmentOpConversionPatterns( - MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, + MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { // clang-format off patterns->insert< @@ -291,7 +287,7 @@ BufferAssignmentFuncOpConverter, BufferAssignmentReturnOpConverter - >(context, placer, converter); + >(context, converter); // clang-format on } } // end namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -44,6 +44,7 @@ newResults.reserve(results.size()); // Update all types to memref types. + rewriter.setInsertionPoint(op); for (auto result : results) { auto type = result.getType().cast(); assert(type && "tensor to buffer conversion expects ranked results"); @@ -51,11 +52,6 @@ return rewriter.notifyMatchFailure( op, "dynamic shapes not currently supported"); auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - - // Compute alloc position and insert a custom allocation node. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); auto alloc = rewriter.create(loc, memrefType); newArgs.push_back(alloc); newResults.push_back(alloc); @@ -99,13 +95,12 @@ /// Populate the given list with patterns to convert Linalg operations on /// tensors to buffers. static void populateConvertLinalgOnTensorsToBuffersPattern( - MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, + MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, - converter, patterns); - patterns->insert(context, placer, converter); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter, + patterns); + patterns->insert(context, converter); } /// Converts Linalg operations that work on tensor-type operands or results to @@ -119,6 +114,8 @@ // Mark all Standard operations legal. target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -144,16 +141,11 @@ converter.setResultConversionKind( BufferAssignmentTypeConverter::AppendToArgumentsList); - // Walk over all the functions to apply buffer assignment. - getOperation().walk([&](FuncOp function) -> WalkResult { - OwningRewritePatternList patterns; - BufferAssignmentPlacer placer(function); - populateConvertLinalgOnTensorsToBuffersPattern(&context, &placer, - &converter, &patterns); - - // Applying full conversion - return applyFullConversion(function, target, patterns); - }); + OwningRewritePatternList patterns; + populateConvertLinalgOnTensorsToBuffersPattern(&context, &converter, + &patterns); + if (failed(applyFullConversion(this->getOperation(), target, patterns))) + this->signalPassFailure(); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -877,6 +877,7 @@ SmallVector newResultTypes; SmallVector mappings; mappings.resize(callOp.getNumResults()); + rewriter.setInsertionPoint(callOp); for (auto result : llvm::enumerate(callOp.getResults())) { SmallVector originTypes; converter->tryDecomposeType(result.value().getType(), originTypes); @@ -891,9 +892,6 @@ resultMapping.addMapping(newResultTypes.size() - 1); } else { // kind = BufferAssignmentTypeConverter::AppendToArgumentsList - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result.value())); MemRefType memref = converted.dyn_cast(); if (!memref) return callOp.emitError("Cannot allocate for a non-Memref type"); diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -56,6 +56,7 @@ newResults.reserve(results.size()); // Update all types to memref types. + rewriter.setInsertionPoint(op); for (auto result : results) { ShapedType type = result.getType().cast(); assert(type && "Generic operations with non-shaped typed results are " @@ -65,11 +66,6 @@ op, "dynamic shapes not currently supported"); auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - - // Compute alloc position and insert a custom allocation node. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); auto alloc = rewriter.create(loc, memrefType); newArgs.push_back(alloc); newResults.push_back(alloc); @@ -110,13 +106,12 @@ }; void populateTensorLinalgToBufferLinalgConversionPattern( - MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, + MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, - converter, patterns); - patterns->insert(context, placer, converter); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter, + patterns); + patterns->insert(context, converter); } void getDependentDialects(DialectRegistry ®istry) const override { @@ -133,6 +128,8 @@ target.addLegalDialect(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -191,16 +188,11 @@ return success(); }); - // Walk over all the functions to apply buffer assignment. - this->getOperation().walk([&](FuncOp function) -> WalkResult { - OwningRewritePatternList patterns; - BufferAssignmentPlacer placer(function); - populateTensorLinalgToBufferLinalgConversionPattern( - &context, &placer, &converter, &patterns); - - // Applying full conversion - return applyFullConversion(function, target, patterns); - }); + OwningRewritePatternList patterns; + populateTensorLinalgToBufferLinalgConversionPattern(&context, &converter, + &patterns); + if (failed(applyFullConversion(this->getOperation(), target, patterns))) + this->signalPassFailure(); }; }; } // end anonymous namespace