diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h @@ -84,6 +84,10 @@ Operation *opToBufferize, DenseSet &usesRead, DenseSet &usesWrite, const DominanceInfo &domInfo) const; + /// Return true if bufferizing `opResult` inplace would create a write to a + /// non-writable buffer. + bool wouldCreateWriteToNonWritableBuffer(OpResult opResult) const; + /// Assume that result bufferizes in-place with one of the operation's /// operands. Return true if it is possible to find an inplace write W (resp. /// a read R) among the uses of `aliasInfo[result]`, and a read R (resp. an diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp @@ -972,6 +972,34 @@ usesWrite, domInfo); } +/// Return true if bufferizing `opResult` inplace would create a write to a +/// non-writable buffer. +bool BufferizationAliasInfo::wouldCreateWriteToNonWritableBuffer( + OpResult opResult) const { + Optional maybeAliasingOperand = getAliasingOpOperand(opResult); + if (!maybeAliasingOperand || !*maybeAliasingOperand) + return false; + + // Certain buffers are not writable: + // 1. A function bbArg that is not inplaceable or + // 2. A constant op. + assert(!aliasesNonWritableBuffer(opResult) && + "expected that opResult does not alias non-writable buffer"); + bool nonWritable = aliasesNonWritableBuffer((*maybeAliasingOperand)->get()); + if (!nonWritable) + return false; + + // This is a problem only if the buffer is written to via some alias.
+ bool hasWrite = aliasesInPlaceWrite(opResult) || + aliasesInPlaceWrite((*maybeAliasingOperand)->get()) || + bufferizesToMemoryWrite(**maybeAliasingOperand); + if (!hasWrite) + return false; + + LDBG("->the corresponding buffer is not writeable\n"); + return true; +} + /// Return true if the source of a `insertSliceOp` bufferizes to an /// equivalent ExtractSliceOp that bufferizes inplace. bool BufferizationAliasInfo::isSourceEquivalentToAMatchingInplaceExtractSliceOp( @@ -2231,23 +2259,14 @@ LDBG("Inplace analysis for extract_slice: " << printOperationInfo(extractSliceOp) << '\n'); - // If `extractSliceOp` were to be bufferized inplace, it cannot end up - // aliasing a write into a non-writable buffer. - bool wouldCreateAliasingWriteToNonWritableBuffer = - aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) && - aliasInfo.aliasesNonWritableBuffer(extractSliceOp.source()); - - if (wouldCreateAliasingWriteToNonWritableBuffer) - LDBG("->the corresponding buffer is not writable\n"); - else - LDBG("->bufferizes to writable inplace buffer\n"); - - // In any of extractSliceOp.result's aliases, can we find 2 such that we hit - // an interfering write? OpResult r = extractSliceOp->getResult(0); OpOperand &s = extractSliceOp->getOpOperand(0); bool foundInterference = - wouldCreateAliasingWriteToNonWritableBuffer || + /* If `extractSliceOp` were to be bufferized inplace, it must not end up + aliasing a write into a non-writable buffer. */ + aliasInfo.wouldCreateWriteToNonWritableBuffer(r) || + /* In any of extractSliceOp.result's aliases, can we find 2 such that we + hit an interfering write? */ aliasInfo.wouldCreateReadAfterWriteInterference(r, domInfo); if (foundInterference) aliasInfo.bufferizeOutOfPlace(r); @@ -2282,20 +2301,8 @@ << operand.getOperandNumber() << " in " << printValueInfo(result) << '\n'); - // `result` must bufferize to a writable buffer to be a candidate. - // This means the operand must not alias either: - // 1.
a function bbArg that is not inplaceable or - // 2. a constant op. - // to be considered for inplace bufferization - bool wouldCreateAliasingWriteToNonWritableBuffer = - aliasInfo.aliasesNonWritableBuffer(operand.get()); - if (wouldCreateAliasingWriteToNonWritableBuffer) - LDBG("->the corresponding buffer is not writable\n"); - else - LDBG("->bufferizes to writable inplace buffer\n"); - bool foundInterference = - wouldCreateAliasingWriteToNonWritableBuffer || + aliasInfo.wouldCreateWriteToNonWritableBuffer(result) || aliasInfo.wouldCreateReadAfterWriteInterference(result, domInfo); if (foundInterference) diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -620,10 +620,28 @@ return } +// ----- + //===----------------------------------------------------------------------===// // Transitive cases through extract_slice. //===----------------------------------------------------------------------===// +// CHECK-LABEL: func @write_into_constant_via_alias +func @write_into_constant_via_alias(%v : vector<5xi32>, + %s1 : index, %s2 : index, + %s3 : index) -> tensor { + %A = constant dense<[1, 2, 3, 4]> : tensor<4xi32> + // CHECK: tensor.extract_slice + // CHECK-SAME: {__inplace_results_attr__ = ["false"]} + %b = tensor.extract_slice %A[%s1][%s2][1] : tensor<4xi32> to tensor + // CHECK: vector.transfer_write + // CHECK-SAME: {__inplace_results_attr__ = ["true"]} + %r = vector.transfer_write %v, %b[%s3] : vector<5xi32>, tensor + return %r : tensor +} + +// ----- + builtin.func @matmul_on_tensors( %arg0: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false}, %arg1: tensor<518x518xf32> {linalg.buffer_layout = affine_map<(d0, d1) -> (d0, d1)>, linalg.inplaceable = false},