[mlir][bufferization] Reads of undefined tensor contents are not conflicts

`AnalysisState::findDefinitions` no longer takes an `alwaysIncludeLeaves`
parameter. It now always calls `findValueInReverseUseDefChain` with
`alwaysIncludeLeaves=false`. As a result, ends of the reverse use-def chain
that do not bufferize to a memory write (e.g. `tensor.empty`) are no longer
reported as definitions.

In OneShotAnalysis conflict detection:
- A read whose value has no definitions is skipped up front (fast path:
  no conflict).
- The "check conflict" debug output is printed once per read instead of once
  per (read, write) pair.
- The local `writingOp` is renamed to `defOp`.

Tests:
- New case `read_of_undef_is_not_a_conflict`.
- `scf_if_memory_space` now fills the allocation with `linalg.fill` before it
  is yielded, inserted into and extracted from.

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -412,14 +412,16 @@ /// in the operands) because their defining ops do not define the contents of /// the tensor. /// + /// Example: + /// %a = tensor.empty() : tensor<10xf32> + /// %b = arith.constant ... : tensor<10xf32> + /// %r = arith.select %cond, %a, %b : tensor<10xf32> + /// findDefinitions(%r) = {%b}. %a is excluded because it does not define the + /// contents of the tensor. + /// /// Note: OpResults of unknown ops are handled conservatively and assumed to /// be definitions. - /// - /// Note: When reaching an end of the reverse SSA use-def chain, that value - /// is included regardless of whether it is a definition or not unless - /// `alwaysIncludeLeaves` is unset. - SetVector findDefinitions(Value value, - bool alwaysIncludeLeaves = true) const; + SetVector findDefinitions(Value value) const; /// Return `true` if the given OpResult has been decided to bufferize inplace. virtual bool isInPlace(OpOperand &opOperand) const; diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -494,11 +494,10 @@ } // Find the values that define the contents of the given value. 
-llvm::SetVector -AnalysisState::findDefinitions(Value value, bool alwaysIncludeLeaves) const { +llvm::SetVector AnalysisState::findDefinitions(Value value) const { return findValueInReverseUseDefChain( value, [&](Value v) { return this->bufferizesToMemoryWrite(v); }, - /*followEquivalentOnly=*/false, alwaysIncludeLeaves); + /*followEquivalentOnly=*/false, /*alwaysIncludeLeaves=*/false); } AnalysisState::AnalysisState(const BufferizationOptions &options) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -272,7 +272,7 @@ // If there is no preceding definition, the tensor contents are // undefined. - if (findDefinitions(opResult, /*alwaysIncludeLeaves=*/false).empty()) + if (findDefinitions(opResult).empty()) for (OpOperand &use : opResult.getUses()) undefinedTensorUses.insert(&use); } @@ -513,8 +513,11 @@ for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); + LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n"); + LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber() + << " of " << *readingOp << "\n"); - // Find most recent writes of uRead by following the SSA use-def chain. + // Find the definition of uRead by following the SSA use-def chain. // E.g.: // // %0 = "writing_op"(%t) : tensor -> tensor @@ -525,14 +528,16 @@ // definition is %0. Note that operations that create an alias but do not // bufferize to a memory write (such as ExtractSliceOp) are skipped. SetVector definitions = state.findDefinitions(uRead->get()); + if (definitions.empty()) { + // Fast path: No conflict if there are no definitions. + LLVM_DEBUG(llvm::dbgs() + << " no conflict: read value has no definitions\n"); + continue; + } // Look for conflicting memory writes. 
Potential conflicts are writes to an // alias that have been decided to bufferize inplace. for (OpOperand *uConflictingWrite : usesWrite) { - LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n"); - LLVM_DEBUG(llvm::dbgs() - << " uRead = operand " << uRead->getOperandNumber() << " of " - << *uRead->getOwner() << "\n"); LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand " << uConflictingWrite->getOperandNumber() << " of " << *uConflictingWrite->getOwner() << "\n"); @@ -608,15 +613,15 @@ LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n"); // No conflict if the conflicting write happens before the definition. - if (Operation *writingOp = definition.getDefiningOp()) { - if (happensBefore(conflictingWritingOp, writingOp, domInfo)) { - // conflictingWritingOp happens before writingOp. No conflict. + if (Operation *defOp = definition.getDefiningOp()) { + if (happensBefore(conflictingWritingOp, defOp, domInfo)) { + // conflictingWritingOp happens before defOp. No conflict. LLVM_DEBUG(llvm::dbgs() << " no conflict: write happens before definition\n"); continue; } - // No conflict if conflictingWritingOp is contained in writingOp. - if (writingOp->isProperAncestor(conflictingWritingOp)) { + // No conflict if conflictingWritingOp is contained in defOp. 
+ if (defOp->isProperAncestor(conflictingWritingOp)) { LLVM_DEBUG( llvm::dbgs() << " no conflict: write is contained in definition\n"); diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-analysis.mlir @@ -32,3 +32,15 @@ %3 = tensor.extract %1[%pos] : tensor<10xf32> return %3 : f32 } + +// ----- + +// CHECK-LABEL: func @read_of_undef_is_not_a_conflict( +func.func @read_of_undef_is_not_a_conflict(%f: f32, %idx: index) -> f32 { + %0 = tensor.empty() : tensor<10xf32> + // This can be in-place because the read below reads undefined data. + // CHECK: tensor.insert {{.*}} {__inplace_operands_attr__ = ["none", "true", "none"]} + %1 = tensor.insert %f into %0[%idx] : tensor<10xf32> + %2 = tensor.extract %0[%idx] : tensor<10xf32> + return %2 : f32 +} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -713,26 +713,28 @@ // ----- // CHECK-LABEL: func @scf_if_memory_space -func.func @scf_if_memory_space(%c: i1, %f: f32) -> (f32, f32) +func.func @scf_if_memory_space(%c: i1, %f: f32, %cst: f32) -> (f32, f32) { %c0 = arith.constant 0 : index // CHECK: %[[alloc:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1> - %0 = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32> + %alloc = bufferization.alloc_tensor() {memory_space = 1 : i64} : tensor<5xf32> + // CHECK: linalg.fill {{.*}} outs(%[[alloc]] : memref<5xf32, 1>) + %filled = linalg.fill ins(%cst : f32) outs(%alloc : tensor<5xf32>) -> tensor<5xf32> // CHECK: scf.if %{{.*}} -> (memref<5xf32, 1>) { %1 = scf.if %c -> tensor<5xf32> { // CHECK: %[[cloned:.*]] = bufferization.clone %[[alloc]] // 
CHECK: scf.yield %[[cloned]] - scf.yield %0 : tensor<5xf32> + scf.yield %filled : tensor<5xf32> } else { // CHECK: %[[alloc2:.*]] = memref.alloc() {{.*}} : memref<5xf32, 1> // CHECK: memref.store %{{.*}}, %[[alloc2]] // CHECK: %[[cloned2:.*]] = bufferization.clone %[[alloc2]] // CHECK: memref.dealloc %[[alloc2]] // CHECK: scf.yield %[[cloned2]] - %2 = tensor.insert %f into %0[%c0] : tensor<5xf32> + %2 = tensor.insert %f into %filled[%c0] : tensor<5xf32> scf.yield %2 : tensor<5xf32> } - %r0 = tensor.extract %0[%c0] : tensor<5xf32> + %r0 = tensor.extract %filled[%c0] : tensor<5xf32> %r1 = tensor.extract %1[%c0] : tensor<5xf32> return %r0, %r1 : f32, f32 }