diff --git a/mlir/include/mlir/Transforms/BufferPlacement.h b/mlir/include/mlir/Transforms/BufferPlacement.h --- a/mlir/include/mlir/Transforms/BufferPlacement.h +++ b/mlir/include/mlir/Transforms/BufferPlacement.h @@ -24,34 +24,6 @@ namespace mlir { -/// Prepares a buffer placement phase. It can place (user-defined) alloc -/// nodes. This simplifies the integration of the actual buffer-placement -/// pass. Sample usage: -/// BufferAssignmentPlacer baHelper(regionOp); -/// -> determine alloc positions -/// auto allocPosition = baHelper.computeAllocPosition(value); -/// -> place alloc -/// allocBuilder.setInsertionPoint(positions.getAllocPosition()); -/// -/// Note: this class is intended to be used during legalization. In order -/// to move alloc and dealloc nodes into the right places you can use the -/// createBufferPlacementPass() function. -class BufferAssignmentPlacer { -public: - /// Creates a new assignment builder. - explicit BufferAssignmentPlacer(Operation *op); - - /// Returns the operation this analysis was constructed from. - Operation *getOperation() const { return operation; } - - /// Computes the actual position to place allocs for the given result. - OpBuilder::InsertPoint computeAllocPosition(OpResult result); - -private: - /// The operation this analysis was constructed from. - Operation *operation; -}; - /// A helper type converter class for using inside Buffer Assignment operation /// conversion patterns. The default constructor keeps all the types intact /// except for the ranked-tensor types which is converted to memref types. @@ -157,31 +129,20 @@ SmallVector decomposeTypeConversions; }; -/// Helper conversion pattern that encapsulates a BufferAssignmentPlacer -/// instance. Sample usage: -/// class CustomConversionPattern : public -/// BufferAssignmentOpConversionPattern -/// { -/// ... matchAndRewrite(...) 
{ -/// -> Access stored BufferAssignmentPlacer -/// bufferAssignment->computeAllocPosition(resultOp); -/// } -/// }; +/// Helper conversion pattern that encapsulates a BufferAssignmentTypeConverter +/// instance. template class BufferAssignmentOpConversionPattern : public OpConversionPattern { public: explicit BufferAssignmentOpConversionPattern( - MLIRContext *context, BufferAssignmentPlacer *bufferAssignment = nullptr, - BufferAssignmentTypeConverter *converter = nullptr, + MLIRContext *context, BufferAssignmentTypeConverter *converter, PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), - bufferAssignment(bufferAssignment), converter(converter) { + : OpConversionPattern(context, benefit), converter(converter) { assert(converter && "The type converter has not been defined"); } protected: - BufferAssignmentPlacer *bufferAssignment; BufferAssignmentTypeConverter *converter; }; @@ -282,8 +243,7 @@ template static void populateWithBufferAssignmentOpConversionPatterns( - MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, + MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { // clang-format off patterns->insert< @@ -291,7 +251,7 @@ BufferAssignmentFuncOpConverter, BufferAssignmentReturnOpConverter - >(context, placer, converter); + >(context, converter); // clang-format on } } // end namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp @@ -51,11 +51,6 @@ return rewriter.notifyMatchFailure( op, "dynamic shapes not currently supported"); auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - - // Compute alloc position and insert a custom allocation node. 
- OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); auto alloc = rewriter.create(loc, memrefType); newArgs.push_back(alloc); newResults.push_back(alloc); @@ -99,13 +94,12 @@ /// Populate the given list with patterns to convert Linalg operations on /// tensors to buffers. static void populateConvertLinalgOnTensorsToBuffersPattern( - MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, + MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, - converter, patterns); - patterns->insert(context, placer, converter); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter, + patterns); + patterns->insert(context, converter); } /// Converts Linalg operations that work on tensor-type operands or results to @@ -119,6 +113,8 @@ // Mark all Standard operations legal. target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -144,16 +140,11 @@ converter.setResultConversionKind( BufferAssignmentTypeConverter::AppendToArgumentsList); - // Walk over all the functions to apply buffer assignment. 
- getOperation().walk([&](FuncOp function) -> WalkResult { - OwningRewritePatternList patterns; - BufferAssignmentPlacer placer(function); - populateConvertLinalgOnTensorsToBuffersPattern(&context, &placer, - &converter, &patterns); - - // Applying full conversion - return applyFullConversion(function, target, patterns); - }); + OwningRewritePatternList patterns; + populateConvertLinalgOnTensorsToBuffersPattern(&context, &converter, + &patterns); + if (failed(applyFullConversion(this->getOperation(), target, patterns))) + this->signalPassFailure(); } }; } // end anonymous namespace diff --git a/mlir/lib/Transforms/BufferPlacement.cpp b/mlir/lib/Transforms/BufferPlacement.cpp --- a/mlir/lib/Transforms/BufferPlacement.cpp +++ b/mlir/lib/Transforms/BufferPlacement.cpp @@ -681,20 +681,6 @@ } // end anonymous namespace -//===----------------------------------------------------------------------===// -// BufferAssignmentPlacer -//===----------------------------------------------------------------------===// - -/// Creates a new assignment placer. -BufferAssignmentPlacer::BufferAssignmentPlacer(Operation *op) : operation(op) {} - -/// Computes the actual position to place allocs for the given value. 
-OpBuilder::InsertPoint -BufferAssignmentPlacer::computeAllocPosition(OpResult result) { - Operation *owner = result.getOwner(); - return OpBuilder::InsertPoint(owner->getBlock(), Block::iterator(owner)); -} - //===----------------------------------------------------------------------===// // BufferAssignmentTypeConverter //===----------------------------------------------------------------------===// @@ -891,9 +877,6 @@ resultMapping.addMapping(newResultTypes.size() - 1); } else { // kind = BufferAssignmentTypeConverter::AppendToArgumentsList - OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result.value())); MemRefType memref = converted.dyn_cast(); if (!memref) return callOp.emitError("Cannot allocate for a non-Memref type"); diff --git a/mlir/test/lib/Transforms/TestBufferPlacement.cpp b/mlir/test/lib/Transforms/TestBufferPlacement.cpp --- a/mlir/test/lib/Transforms/TestBufferPlacement.cpp +++ b/mlir/test/lib/Transforms/TestBufferPlacement.cpp @@ -65,11 +65,6 @@ op, "dynamic shapes not currently supported"); auto memrefType = MemRefType::get(type.getShape(), type.getElementType()); - - // Compute alloc position and insert a custom allocation node. 
- OpBuilder::InsertionGuard guard(rewriter); - rewriter.restoreInsertionPoint( - bufferAssignment->computeAllocPosition(result)); auto alloc = rewriter.create(loc, memrefType); newArgs.push_back(alloc); newResults.push_back(alloc); @@ -110,13 +105,12 @@ }; void populateTensorLinalgToBufferLinalgConversionPattern( - MLIRContext *context, BufferAssignmentPlacer *placer, - BufferAssignmentTypeConverter *converter, + MLIRContext *context, BufferAssignmentTypeConverter *converter, OwningRewritePatternList *patterns) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, placer, - converter, patterns); - patterns->insert(context, placer, converter); + mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>(context, converter, + patterns); + patterns->insert(context, converter); } void getDependentDialects(DialectRegistry &registry) const override { @@ -133,6 +127,8 @@ target.addLegalDialect(); target.addLegalOp(); target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); // Mark all Linalg operations illegal as long as they work on tensors. auto isLegalOperation = [&](Operation *op) { @@ -191,16 +187,11 @@ return success(); }); - // Walk over all the functions to apply buffer assignment. - this->getOperation().walk([&](FuncOp function) -> WalkResult { - OwningRewritePatternList patterns; - BufferAssignmentPlacer placer(function); - populateTensorLinalgToBufferLinalgConversionPattern( - &context, &placer, &converter, &patterns); - - // Applying full conversion - return applyFullConversion(function, target, patterns); - }); + OwningRewritePatternList patterns; + populateTensorLinalgToBufferLinalgConversionPattern(&context, &converter, + &patterns); + if (failed(applyFullConversion(this->getOperation(), target, patterns))) + this->signalPassFailure(); }; }; } // end anonymous namespace