diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -376,21 +376,6 @@ return nullptr; } -/// For each given value, find the closest enclosing repetitive region. If this -/// is the same region for each value, return it. Otherwise return None. -/// Note: If there is no enclosing repetitive region, return nullptr. -static Optional -getCommonEnclosingRepetitiveRegion(ArrayRef values, - const BufferizationOptions &options) { - if (values.empty()) - return None; - Region *r = getEnclosingRepetitiveRegion(values.front(), options); - for (Value value : values.drop_front()) - if (getEnclosingRepetitiveRegion(value, options) != r) - return None; - return r; -} - /// Return `true` if the given tensor value is a memory write. Most values are /// tensor writes, but ops that define a tensor SSA value without specifying its /// contents (e.g., alloc_tensor) are not. @@ -404,6 +389,118 @@ return bufferizableOp.isMemoryWrite(opResult, state); } +/// Return `true` if op dominance can be used to rule out read-after-write +/// conflicts wrt. the given reads and writes. +/// +/// Op dominance can often be used to rule out potential conflicts such as +/// "read" happens before "write". E.g., the following IR is not a RaW conflict +/// because the read happens *before* the write. +/// +/// %0 = ... : tensor +/// "reading_op"(%0) : tensor +/// %1 = "writing_op"(%0) : tensor -> tensor +/// +/// This is no longer true inside loops (or repetitive regions). In such cases, +/// there may not be a meaningful `happensBefore` relationship because ops +/// could be executed multiple times. E.g.: +/// +/// %0 = ... : tensor +/// scf.for ... { +/// "reading_op"(%0) : tensor +/// %1 = "writing_op"(%0) : tensor -> tensor +/// ... 
+/// } +/// +/// In the above example, reading_op happens before writing_op according to +/// op dominance. However, both ops may happen multiple times; in +/// particular, the second execution of reading_op happens after the first +/// execution of writing_op. This is problematic because the tensor %0 they +/// operate on (i.e., the "definition") is defined outside of the loop. +/// +/// Counter example: +/// +/// scf.for ... { +/// %0 = ... : tensor +/// "reading_op"(%0) : tensor +/// %1 = "writing_op"(%0) : tensor -> tensor +/// ... +/// } +/// +/// In this example, the definition %0 is in the same repetitive region as +/// "writing_op", so op dominance can be used to compute the `happensBefore` +/// relationship. +/// +/// This function finds the closest enclosing repetitive region of all buffer +/// writes wrt. the given tensor reads and writes. If this is the same +/// region (nullptr in case of "no repetitive region" found at all), op +/// dominance can be used. Otherwise, it cannot be used. +/// +/// Example: The common enclosing repetitive region is the scf.for loop. +/// Op dominance can be used. +/// scf.for ... { +/// %0 = tensor.generate +/// "read"(%0) +/// } +/// +/// Example: The common enclosing repetitive region is nullptr: There is no +/// repetitive region around the tensor.generate. Op dominance can be +/// used. +/// %0 = tensor.generate +/// scf.for ... { "read"(%0) } +/// +/// Example: The common enclosing repetitive regions of tensor.generate and +/// "write" differ. Op dominance cannot be used. +/// %0 = tensor.generate +/// scf.for ... { +/// "read"(%0) +/// "write"(%0) +/// } +/// +/// Example: The common enclosing repetitive regions of tensor.generate and +/// "write" differ, but there is no read of %0, so op dominance can be +/// used. +/// %0 = tensor.generate +/// scf.for ... 
{ +/// "write"(%0) +/// } +/// +/// Note: iter_args of loops are not aliases of their respective block +/// arguments, so op dominance can be used when analyzing ops that operate +/// on them. +bool canUseOpDominance(const DenseSet<OpOperand *> &usesRead, + const DenseSet<OpOperand *> &usesWrite, + const AnalysisState &state) { + const BufferizationOptions &options = state.getOptions(); + Optional<Region *> commonEnclosingRegion = None; + + // In case of a write, take the region in which the write takes place. + for (OpOperand *uWrite : usesWrite) { + Region *r = getEnclosingRepetitiveRegion(uWrite->getOwner(), options); + if (!commonEnclosingRegion.has_value()) { + commonEnclosingRegion = r; + continue; + } + if (*commonEnclosingRegion != r) + return false; + } + + // In case of a read, take the region in which the read value is defined. + for (OpOperand *uRead : usesRead) { + // Optimization: Skip reads of values that have no defined contents. + if (!isMemoryWrite(uRead->get(), state)) + continue; + Region *r = getEnclosingRepetitiveRegion(uRead->get(), options); + if (!commonEnclosingRegion.has_value()) { + commonEnclosingRegion = r; + continue; + } + if (*commonEnclosingRegion != r) + return false; + } + + return commonEnclosingRegion.has_value(); // False if no region was recorded. +} + /// Annotate IR with details about the detected RaW conflict. static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, Value lastWrite) { @@ -450,15 +547,8 @@ AnalysisState &state, const BufferizationAliasInfo &aliasInfo) { const BufferizationOptions &options = state.getOptions(); - // Gather all written aliases. Skip over aliases that are not actual writes. - SmallVector writtenAliases; - for (OpOperand *uWrite : usesWrite) - if (isMemoryWrite(uWrite->get(), state)) - writtenAliases.push_back(uWrite->get()); - // Find the inner-most enclosing repetitive region of each alias. If this is - // the same region for every alias, save it in `repetitiveRegionOfWrites`. 
- Optional repetitiveRegionOfWrites = - getCommonEnclosingRepetitiveRegion(writtenAliases, options); + // Check if op dominance can be used to rule out read-after-write conflicts. + bool useDominance = canUseOpDominance(usesRead, usesWrite, state); for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); @@ -482,55 +572,12 @@ // met for uConflictingWrite to be an actual conflict. Operation *conflictingWritingOp = uConflictingWrite->getOwner(); - // Check if conflictingWritingOp is in the same repetitive region as all - // written aliases. If this is not the case, there is no meaningful - // `happensBefore` relationship because conflictingWritingOp may be - // executed multiple times. E.g.: - // - // %0 = ... : tensor - // scf.for ... { - // "reading_op"(%0) : tensor - // %1 = "writing_op"(%0) : tensor -> tensor - // ... - // } - // - // In the above example, reading_op happens before writing_op according to - // op dominance. However, both ops may happen multiple times; in - // particular, the second execution of reading_op happens after the first - // execution of writing_op. This is problematic if the tensor they operate - // on (%0) is defined outside of the loop. - // - // Counter example: - // - // scf.for ... { - // %0 = ... : tensor - // "reading_op"(%0) : tensor - // %1 = "writing_op"(%0) : tensor -> tensor - // ... - // } - // - // In this example, %0 is in the same repetitive region as - // conflictingWritingOp, so op dominance can be used to compute the - // `happensBefore` relationship. - // - // Note: iter_args of loops are not aliases of their respective block - // arguments, so op domanice can be used when analyzing ops that operate - // on them. - // - // Note: If `writtenAliases` is empty, there are no memory writes outside - // of the repetitive region of conflictingWritingOp, which means that all - // relevant aliases are inside the same repetitive region. 
- bool canUseOpDominance = - writtenAliases.empty() || - repetitiveRegionOfWrites == - getEnclosingRepetitiveRegion(conflictingWritingOp, options); - // No conflict if the readingOp dominates conflictingWritingOp, i.e., the // write is not visible when reading. // // Note: If ops are executed multiple times (e.g., because they are inside // a loop), there may be no meaningful `happensBefore` relationship. - if (canUseOpDominance && + if (useDominance && happensBefore(readingOp, conflictingWritingOp, domInfo)) continue; @@ -540,7 +587,7 @@ // Note: Just being the same op is not enough. It has to be the same use. // Note: If the op is executed multiple times (e.g., because it is inside // a loop), it may be conflicting with itself. - if (canUseOpDominance && uConflictingWrite == uRead) + if (useDominance && uConflictingWrite == uRead) continue; // No conflict if the op interface says so. @@ -559,7 +606,7 @@ // Note: If ops are executed multiple times (e.g., because they are inside // a loop), mutually exclusive regions may be executed multiple // times. 
- if (canUseOpDominance && + if (useDominance && insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) continue; diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir @@ -630,3 +630,69 @@ } {thread_dim_mapping = []} return %4 : tensor<320xf32> } + +// ----- + +// CHECK-LABEL: different_repetitive_region_via_alias +func.func @different_repetitive_region_via_alias(%arg0: tensor<4xf32>, + %arg1: tensor<4xf32>, + %arg2: index, + %arg3: index, + %arg4: index) + -> (tensor<4xf32>) +{ + %cst = arith.constant 0.000000e+00 : f32 + %cst2 = arith.constant 1.000000e+00 : f32 + %0 = bufferization.alloc_tensor() : tensor<4xf32> + + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]} + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32> + + %2 = scf.for %arg5 = %arg2 to %arg3 step %arg4 iter_args(%arg6 = %arg1) -> (tensor<4xf32>) { + // CHECK: tensor.extract {{.*}} {__inplace_operands_attr__ = ["true", "none"]} + %4 = tensor.extract %1[%arg4] : tensor<4xf32> + vector.print %4 : f32 + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]} + %5 = linalg.fill ins(%cst2 : f32) outs(%0 : tensor<4xf32>) -> tensor<4xf32> + scf.yield %5 : tensor<4xf32> + } + + return %2 : tensor<4xf32> +} + +// ----- + +// CHECK-LABEL: no_raw_conflict_after_repetitive_use +func.func @no_raw_conflict_after_repetitive_use(%arg0: tensor<4xf32>, + %arg1: tensor<4xf32>, + %arg2: index, + %arg3: index, + %arg4: index) + -> (tensor<4xf32>, tensor<4xf32>) +{ + %cst = arith.constant 0.000000e+00 : f32 + %cst2 = arith.constant 1.000000e+00 : f32 + %cst3 = arith.constant 2.000000e+00 : f32 + %0 = bufferization.alloc_tensor() : tensor<4xf32> + + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]} + %1 = linalg.fill ins(%cst : f32) outs(%0 : 
tensor<4xf32>) -> tensor<4xf32> + + %2 = scf.for %arg5 = %arg2 to %arg3 step %arg4 iter_args(%arg6 = %arg1) -> (tensor<4xf32>) { + // CHECK: tensor.extract {{.*}} {__inplace_operands_attr__ = ["true", "none"]} + %4 = tensor.extract %1[%arg4] : tensor<4xf32> + vector.print %4 : f32 + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]} + %5 = linalg.fill ins(%cst2 : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> + scf.yield %5 : tensor<4xf32> + } + + // The following is *not* a RaW conflict. + // CHECK: tensor.extract {{.*}} {__inplace_operands_attr__ = ["true", "none"]} + %6 = tensor.extract %1[%arg4] : tensor<4xf32> + vector.print %6 : f32 + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]} + %7 = linalg.fill ins(%cst3 : f32) outs(%1 : tensor<4xf32>) -> tensor<4xf32> + + return %2, %7 : tensor<4xf32>, tensor<4xf32> +}