# Patch summary (MLIR bufferization):
#  * Adds AnalysisState::bufferizesToAllocation(OpResult): forwards to the
#    defining op's BufferizableOpInterface; returns true when the op is not
#    bufferizable.
#  * Adds AnalysisState::findAllocations(Value): walks the reverse use-def chain
#    (in-place edges only, leaves included) and collects the OpResults that
#    either bufferize to an allocation or have an out-of-place aliasing operand.
#  * Replaces the equivalence-class based return/yield check in
#    OneShotAnalysis.cpp with a check based on findAllocations.
#  * Moves test cases that no longer produce an error out of
#    one-shot-module-bufferize-invalid.mlir and adds positive tests.
#
# NOTE(review): this patch text was rebuilt from a whitespace-collapsed copy in
# which angle-bracket sequences had been stripped. Template arguments
# (`SetVector<...>`, `dyn_cast<...>`, `isa<...>`, `TypeID::get<...>`) and
# `tensor<?xf32>` types were restored from context, and indentation follows
# LLVM style. Hunk line counts match the hunk headers, but verify against the
# upstream revision before applying.
# NOTE(review): the new loop in OneShotAnalysis.cpp emits one diagnostic per
# matching allocation, so a single operand may be reported more than once.
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -402,6 +402,10 @@
   /// bufferizable.
   AliasingOpResultList getAliasingOpResults(OpOperand &opOperand) const;
 
+  /// Return `true` if `opResult` bufferizes to a memory allocation. Return
+  /// `true` if the op is not bufferizable.
+  bool bufferizesToAllocation(OpResult opResult) const;
+
   /// Return true if `opOperand` bufferizes to a memory read. Return `true` if
   /// the op is not bufferizable.
   bool bufferizesToMemoryRead(OpOperand &opOperand) const;
@@ -492,6 +496,11 @@
   /// be definitions.
   SetVector<Value> findDefinitions(Value value) const;
 
+  /// Find the OpResults in the reverse use-def chain that are a buffer
+  /// allocation. Either because they bufferize to an allocation or because
+  /// at least one of their aliasing OpOperands is out-of-place.
+  SetVector<OpResult> findAllocations(Value value) const;
+
   /// Return `true` if the given OpResult has been decided to bufferize inplace.
   virtual bool isInPlace(OpOperand &opOperand) const;
 
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -398,6 +398,15 @@
   return detail::unknownGetAliasingOpResults(opOperand);
 }
 
+bool AnalysisState::bufferizesToAllocation(OpResult opResult) const {
+  if (Operation *op = opResult.getDefiningOp())
+    if (auto bufferizableOp = getOptions().dynCastBufferizableOp(op))
+      return bufferizableOp.bufferizesToAllocation(opResult);
+
+  // The op is not bufferizable. Conservatively return true.
+  return true;
+}
+
 /// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
 /// op is not bufferizable.
 bool AnalysisState::bufferizesToMemoryRead(OpOperand &opOperand) const {
@@ -538,6 +547,34 @@
       value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, config);
 }
 
+llvm::SetVector<OpResult> AnalysisState::findAllocations(Value value) const {
+  TraversalConfig config;
+  config.alwaysIncludeLeaves = true;
+  config.followInPlaceOnly = true;
+  auto bufferizesToAlloc = [&](Value v) {
+    auto opResult = v.dyn_cast<OpResult>();
+    if (!opResult)
+      return true;
+    return bufferizesToAllocation(opResult);
+  };
+  SetVector<Value> potentialAllocs =
+      findValueInReverseUseDefChain(value, bufferizesToAlloc, config);
+  SetVector<OpResult> result;
+  for (Value v : potentialAllocs) {
+    auto opResult = v.dyn_cast<OpResult>();
+    if (!opResult)
+      continue;
+    if (bufferizesToAllocation(opResult))
+      result.insert(opResult);
+    if (llvm::any_of(getAliasingOpOperands(opResult),
+                     [&](AliasingOpOperand alias) {
+                       return !isInPlace(*alias.opOperand);
+                     }))
+      result.insert(opResult);
+  }
+  return result;
+}
+
 AnalysisState::AnalysisState(const BufferizationOptions &options)
     : AnalysisState(options, TypeID::get<AnalysisState>()) {}
 
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -1014,35 +1014,22 @@
         !state.getOptions().isOpAllowed(returnOp))
       return WalkResult::advance();
 
+    // The block from which the op is yielding.
+    Block *block = returnOp->getBlock();
+
     for (OpOperand &returnValOperand : returnOp->getOpOperands()) {
       Value returnVal = returnValOperand.get();
+
       // Skip non-tensor values.
       if (!returnVal.getType().isa<TensorType>())
         continue;
 
-      bool foundEquivValue = false;
-      state.applyOnEquivalenceClass(returnVal, [&](Value equivVal) {
-        if (auto bbArg = equivVal.dyn_cast<BlockArgument>()) {
-          Operation *definingOp = bbArg.getOwner()->getParentOp();
-          if (definingOp->isProperAncestor(returnOp))
-            foundEquivValue = true;
-          return;
-        }
-
-        Operation *definingOp = equivVal.getDefiningOp();
-        if (definingOp->getBlock()->findAncestorOpInBlock(
-                *returnOp->getParentOp()))
-          // Skip ops that happen after `returnOp` and parent ops.
-          if (happensBefore(definingOp, returnOp, domInfo))
-            foundEquivValue = true;
-      });
-
-      // Note: Returning/yielding buffer allocations is allowed only if
-      // `allowReturnAllocs` is set.
-      if (!foundEquivValue)
-        status = returnOp->emitError()
-                 << "operand #" << returnValOperand.getOperandNumber()
-                 << " may return/yield a new buffer allocation";
+      for (OpResult alloc : state.findAllocations(returnVal)) {
+        if (block->findAncestorOpInBlock(*alloc.getDefiningOp()))
+          status = returnOp->emitError()
+                   << "operand #" << returnValOperand.getOperandNumber()
+                   << " may return/yield a new buffer allocation";
+      }
     }
 
     return WalkResult::advance();
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir
@@ -56,3 +56,56 @@
   %2 = tensor.extract %0[%idx] : tensor<10xf32>
   return %2 : f32
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_if_definitely_aliasing(
+func.func @scf_if_definitely_aliasing(
+    %cond: i1, %t1: tensor<?xf32> {bufferization.writable = true},
+    %idx: index) -> tensor<?xf32> {
+  %r = scf.if %cond -> (tensor<?xf32>) {
+    scf.yield %t1 : tensor<?xf32>
+  } else {
+    // This buffer definitely aliases, so it can be yielded from the block.
+    // CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "none"]}
+    %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+    scf.yield %t2 : tensor<?xf32>
+  }
+  return %r : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @scf_yield_no_allocs(
+func.func @scf_yield_no_allocs(
+    %c1: i1, %c2: i1, %A: tensor<4xf32>, %B: tensor<4xf32>, %C: tensor<4xf32>)
+  -> tensor<4xf32>
+{
+  %r = scf.if %c1 -> (tensor<4xf32>) {
+    // CHECK: arith.select
+    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true", "true"]}
+    %0 = arith.select %c2, %A, %B : tensor<4xf32>
+    scf.yield %0 : tensor<4xf32>
+  } else {
+    scf.yield %C : tensor<4xf32>
+  }
+  return %r: tensor<4xf32>
+}
+
+// -----
+
+// Just make sure that bufferization succeeds. The arith.constant ops do not
+// allocate.
+
+// CHECK-LABEL: func @block_yield_constant(
+func.func @block_yield_constant(%c: i1) -> (tensor<3x4xf32>) {
+  %r = scf.if %c -> (tensor<3x4xf32>) {
+    %0 = arith.constant dense<7.0> : tensor<3x4xf32>
+    scf.yield %0 : tensor<3x4xf32>
+  } else {
+    %0 = arith.constant dense<8.0> : tensor<3x4xf32>
+    scf.yield %0 : tensor<3x4xf32>
+  }
+  return %r : tensor<3x4xf32>
+}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -25,23 +25,6 @@
 
 // -----
 
-func.func @scf_if_not_equivalent(
-    %cond: i1, %t1: tensor<?xf32> {bufferization.writable = true},
-    %idx: index) -> tensor<?xf32> {
-  %r = scf.if %cond -> (tensor<?xf32>) {
-    scf.yield %t1 : tensor<?xf32>
-  } else {
-    // This buffer aliases, but it is not equivalent.
-    %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
-    // expected-error @+1 {{operand #0 may return/yield a new buffer allocation}}
-    scf.yield %t2 : tensor<?xf32>
-  }
-  // expected-error @+1 {{operand #0 may return/yield a new buffer allocation}}
-  return %r : tensor<?xf32>
-}
-
-// -----
-
 func.func @scf_if_not_aliasing(
     %cond: i1, %t1: tensor<?xf32> {bufferization.writable = true},
     %idx: index) -> f32 {
@@ -168,36 +151,6 @@
 
 // -----
 
-func.func @extract_slice_fun(%A : tensor<?xf32> {bufferization.writable = true})
-  ->  tensor<4xf32>
-{
-  // This bufferizes to a pattern that the cross-function boundary pass needs to
-  // convert into a new memref argument at all call site; this may be either:
-  //   - an externally created aliasing subview (if we want to allow aliasing
-  //     function arguments).
-  //   - a new alloc + copy (more expensive but does not create new function
-  //     argument aliasing).
-  %r0 = tensor.extract_slice %A[0][4][1] : tensor<?xf32> to tensor<4xf32>
-
-  // expected-error @+1 {{operand #0 may return/yield a new buffer allocation}}
-  return %r0: tensor<4xf32>
-}
-
-// -----
-
-func.func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
-{
-  %r = scf.if %b -> (tensor<4xf32>) {
-    scf.yield %A : tensor<4xf32>
-  } else {
-    scf.yield %B : tensor<4xf32>
-  }
-  // expected-error @+1 {{operand #0 may return/yield a new buffer allocation}}
-  return %r: tensor<4xf32>
-}
-
-// -----
-
 func.func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32>
 {
   // expected-error: @+1 {{op was not bufferized}}