diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -684,6 +684,14 @@ Region *getNextEnclosingRepetitiveRegion(Region *region, const BufferizationOptions &options); +/// If `region` is a parallel region, return `region`. Otherwise, find the first +/// enclosing parallel region of `region`. If there is no such region, return +/// "nullptr". +/// +/// Note: Whether a region is parallel or sequential is queried from the +/// `BufferizableOpInterface`. +Region *getParallelRegion(Region *region, const BufferizationOptions &options); + namespace detail { /// This is the default implementation of /// BufferizableOpInterface::getAliasingOpOperands. Should not be called from diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -556,6 +556,25 @@ ::llvm::cast<BufferizableOpInterface>($_op.getOperation()), index); }] >, + InterfaceMethod< + /*desc=*/[{ + Return `true` if the given region of this op is parallel, i.e., + multiple instances of the region may be executing at the same time. + If a region is parallel, it must also be marked as "repetitive". + + The RaW conflict detection of One-Shot Analysis is more strict inside + parallel regions: Buffers may have to be privatized. + + By default, regions are assumed to be sequential.
+ }], + /*retType=*/"bool", + /*methodName=*/"isParallelRegion", + /*args=*/(ins "unsigned":$index), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }] + >, StaticInterfaceMethod< /*desc=*/[{ Return `true` if the op and this interface implementation supports diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -119,6 +119,21 @@ return region; } +Region *bufferization::getParallelRegion(Region *region, + const BufferizationOptions &options) { + while (region) { + auto bufferizableOp = options.dynCastBufferizableOp(region->getParentOp()); + if (bufferizableOp && + bufferizableOp.isParallelRegion(region->getRegionNumber())) { + assert(isRepetitiveRegion(region, options) && + "expected that all parallel regions are also repetitive regions"); + return region; + } + region = region->getParentRegion(); + } + return nullptr; +} + Operation *bufferization::getOwnerOfValue(Value value) { if (auto opResult = llvm::dyn_cast<OpResult>(value)) return opResult.getDefiningOp(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -545,6 +545,43 @@ OneShotAnalysisState &state) { const BufferizationOptions &options = state.getOptions(); + // Before going through the main RaW analysis, find cases where a buffer must + // be privatized due to parallelism. If the result of a write is never read, + // privatization is not necessary (and large parts of the IR are likely dead). + if (!usesRead.empty()) { + for (OpOperand *uConflictingWrite : usesWrite) { + // Find the allocation point or last write (definition) of the buffer.
+ // Note: In contrast to `findDefinitions`, this also returns results of + // ops that do not bufferize to memory write when no other definition + // could be found. E.g., "bufferization.alloc_tensor" would be included, + // even though that op just bufferizes to an allocation but does not + // define the contents of the buffer. + SetVector<Value> definitionsOrLeaves = + state.findValueInReverseUseDefChain( + uConflictingWrite->get(), + [&](Value v) { return state.bufferizesToMemoryWrite(v); }); + assert(!definitionsOrLeaves.empty() && + "expected at least one definition or leaf"); + + // The writing op must bufferize out-of-place if the definition is in a + // different parallel region than this write. + for (Value def : definitionsOrLeaves) { + if (getParallelRegion(def.getParentRegion(), options) != + getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(), + options)) { + LLVM_DEBUG( + llvm::dbgs() + << "\n- bufferizes out-of-place due to parallel region:\n"); + LLVM_DEBUG(llvm::dbgs() + << " uConflictingWrite = operand " + << uConflictingWrite->getOperandNumber() << " of " + << *uConflictingWrite->getOwner() << "\n"); + return true; + } + } + } + } + for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n"); diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1202,6 +1202,11 @@ } return false; } + + /// A region of this op is parallel if and only if it is repetitive. + bool isParallelRegion(Operation *op, unsigned index) const { + return isRepetitiveRegion(op, index); + } }; /// Nothing to do for InParallelOp.
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir @@ -798,3 +798,106 @@ } return } + +// ----- + +// CHECK-LABEL: func @parallel_region() +func.func @parallel_region() -> tensor<320xf32> +{ + %alloc0 = bufferization.alloc_tensor() : tensor<320xf32> + %alloc1 = bufferization.alloc_tensor() : tensor<1xf32> + %c320 = arith.constant 320 : index + // CHECK: scf.forall + %0 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %alloc0) -> (tensor<320xf32>) { + %val = "test.foo"() : () -> (f32) + // linalg.fill must bufferize out-of-place: %alloc1 is defined outside of + // the scf.forall, so every thread needs a private copy of it. + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]} + %fill = linalg.fill ins(%val : f32) outs(%alloc1 : tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]} + tensor.parallel_insert_slice %fill into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32> + } + } + // CHECK: } {__inplace_operands_attr__ = ["none", "true"]} + return %0 : tensor<320xf32> +} + +// ----- + +// CHECK-LABEL: func @parallel_region_mixed_def( +func.func @parallel_region_mixed_def(%c: i1) -> tensor<320xf32> +{ + %alloc0 = bufferization.alloc_tensor() : tensor<320xf32> + %alloc1 = bufferization.alloc_tensor() : tensor<1xf32> + %c320 = arith.constant 320 : index + // CHECK: scf.forall + %0 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %alloc0) -> (tensor<320xf32>) { + %alloc2 = bufferization.alloc_tensor() : tensor<1xf32> + %selected = scf.if %c -> tensor<1xf32> { + scf.yield %alloc1 : tensor<1xf32> + } else { + scf.yield %alloc2 : tensor<1xf32> + } + %val = "test.foo"() : () -> (f32) + // linalg.fill must bufferize out-of-place: %selected may be %alloc1, which + // is defined outside of the scf.forall; every thread needs a private
copy. + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]} + %fill = linalg.fill ins(%val : f32) outs(%selected : tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]} + tensor.parallel_insert_slice %fill into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32> + } + } + // CHECK: } {__inplace_operands_attr__ = ["none", "true"]} + return %0 : tensor<320xf32> +} + +// ----- + +// CHECK-LABEL: func @parallel_region_two_writes( +func.func @parallel_region_two_writes(%f: f32) -> tensor<320xf32> +{ + %alloc0 = bufferization.alloc_tensor() : tensor<320xf32> + %alloc1 = bufferization.alloc_tensor() : tensor<1xf32> + %c320 = arith.constant 320 : index + %c0 = arith.constant 0 : index + // CHECK: scf.forall + %0 = scf.forall (%arg0) in (%c320) shared_outs(%arg1 = %alloc0) -> (tensor<320xf32>) { + %val = "test.foo"() : () -> (f32) + // linalg.fill must bufferize out-of-place (each thread needs a private copy + // of %alloc1); tensor.insert then writes into %fill and can stay in-place.
+ // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "false"]} + %fill = linalg.fill ins(%val : f32) outs(%alloc1 : tensor<1xf32>) -> tensor<1xf32> + // CHECK: tensor.insert + // CHECK-SAME: __inplace_operands_attr__ = ["none", "true", "none"] + %inserted = tensor.insert %f into %fill[%c0] : tensor<1xf32> + + scf.forall.in_parallel { + // CHECK: tensor.parallel_insert_slice {{.*}} {__inplace_operands_attr__ = ["true", "true", "none"]} + tensor.parallel_insert_slice %inserted into %arg1[%arg0] [1] [1] : tensor<1xf32> into tensor<320xf32> + } + } + // CHECK: } {__inplace_operands_attr__ = ["none", "true"]} + return %0 : tensor<320xf32> +} + +// ----- + +// CHECK-LABEL: func @parallel_region_no_read() +func.func @parallel_region_no_read() +{ + %alloc0 = bufferization.alloc_tensor() : tensor<320xf32> + %alloc1 = bufferization.alloc_tensor() : tensor<1xf32> + %c320 = arith.constant 320 : index + // CHECK: scf.forall + scf.forall (%arg0) in (%c320) { + %val = "test.foo"() : () -> (f32) + // linalg.fill can bufferize in-place because no alias of %alloc1 is read. + // CHECK: linalg.fill {__inplace_operands_attr__ = ["none", "true"]} + %fill = linalg.fill ins(%val : f32) outs(%alloc1 : tensor<1xf32>) -> tensor<1xf32> + scf.forall.in_parallel { + } + } + return +}