diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -93,10 +93,10 @@ InterfaceMethod<
         /*desc=*/[{
           Return `true` if the operation bufferizes to IR that performs only
-          element-wise accesses on all tensor operands. (All operands must have
-          the same shape.) The `bufferize` method must be implemented in such a
-          way that it is free of loop-carried dependences. I.e., all loads at a
-          position appear before all stores at the same position.
+          element-wise accesses on the specified tensor operands. (The operands
+          must have the same shape.) The `bufferize` method must be implemented
+          in such a way that it is free of loop-carried dependences. I.e., all
+          loads at a position appear before all stores at the same position.
 
           Example: Consider a hypothetical element-wise op, where the "ins"
           bufferize to a memory read and the "outs" bufferize to a memory
           write.
@@ -130,10 +130,15 @@
           equivalent tensors. (It is not possible, if %0 and %1 are merely
           aliasing. It is not necessary if %0 and %1 are not aliasing at all,
           because there would be no conflict anyway.)
+
+          Note: Tensor operands that are not included in `opOperands` can be
+          ignored. A conservative implementation of this interface method may
+          always return "false".
         }],
         /*retType=*/"bool",
         /*methodName=*/"bufferizesToElementwiseAccess",
-        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
+        /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state,
+                      "ArrayRef<OpOperand *>":$opOperands),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           // It is always safe to assume that the op is not element-wise.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -446,6 +446,21 @@
   }
 }
 
+/// Return 'true' if a tensor that is equivalent to `other` can be found in the
+/// reverse use-def chain of `start`. Note: If an OpOperand bufferizes out of
+/// place along that use-def chain, the two tensors may not materialize as
+/// equivalent buffers (but as separate allocations).
+static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
+                                                   Value start, Value other) {
+  TraversalConfig config;
+  config.followEquivalentOnly = true;
+  config.alwaysIncludeLeaves = false;
+  return !state
+              .findValueInReverseUseDefChain(
+                  start, [&](Value v) { return v == other; }, config)
+              .empty();
+}
+
 /// Given sets of uses and writes, return true if there is a RaW conflict under
 /// the assumption that all given reads/writes alias the same buffer and that
 /// all given writes bufferize inplace.
@@ -545,15 +560,19 @@
       // Two equivalent operands of the same op are not conflicting if the op
      // bufferizes to element-wise access. I.e., all loads at a position happen
       // before all stores to the same position.
-      if (conflictingWritingOp == readingOp &&
-          state.areEquivalentBufferizedValues(uRead->get(),
-                                              uConflictingWrite->get())) {
+      if (conflictingWritingOp == readingOp) {
         if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
-          if (bufferizableOp.bufferizesToElementwiseAccess(state)) {
-            LLVM_DEBUG(
-                llvm::dbgs()
-                << "  no conflict: op bufferizes to element-wise access\n");
-            continue;
+          if (bufferizableOp.bufferizesToElementwiseAccess(
+                  state, {uRead, uConflictingWrite})) {
+            if (hasEquivalentValueInReverseUseDefChain(
+                    state, uRead->get(), uConflictingWrite->get()) ||
+                hasEquivalentValueInReverseUseDefChain(
+                    state, uConflictingWrite->get(), uRead->get())) {
+              LLVM_DEBUG(
+                  llvm::dbgs()
+                  << "  no conflict: op bufferizes to element-wise access\n");
+              continue;
+            }
           }
         }
       }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -106,8 +106,8 @@
     return dpsOp.isDpsInit(&opOperand);
   }
 
-  bool bufferizesToElementwiseAccess(Operation *op,
-                                     const AnalysisState &state) const {
+  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
+                                     ArrayRef<OpOperand *> opOperands) const {
     auto linalgOp = cast<linalg::LinalgOp>(op);
 
     // All loops must be parallel.
@@ -119,10 +119,13 @@
     assert(linalgOp->getNumOperands() == indexingMaps.size() &&
            "unexpected number of indexing maps");
     for (auto [operand, map] :
-         llvm::zip(linalgOp->getOperands(), indexingMaps)) {
+         llvm::zip(linalgOp->getOpOperands(), indexingMaps)) {
       // Non-tensors do not participate in bufferization, so they can be
       // ignored.
-      if (!isa<TensorType>(operand.getType()))
+      if (!isa<TensorType>(operand.get().getType()))
         continue;
+      // Only consider operands in `opOperands`.
+      if (llvm::find(opOperands, &operand) == opOperands.end())
+        continue;
       // TODO: This could be generalized to other indexing maps. (All indexing
       // must be the same.)
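The new `opOperands` parameter also matters for downstream dialects that implement this interface for their own ops. The following rough sketch is not part of this patch; `mydialect::AddOp` and the shape-based rule are assumptions for illustration only. It shows how an external model can restrict its check to the requested operands and otherwise fall back to the conservative "false" answer:

// Hypothetical external model for a downstream element-wise op. Only the
// method signature below is taken from the patch; the rest is illustrative.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

struct AddOpBufferizationModel
    : public BufferizableOpInterface::ExternalModel<AddOpBufferizationModel,
                                                    mydialect::AddOp> {
  // The other interface methods (bufferizesToMemoryRead, etc.) are omitted.
  bool bufferizesToElementwiseAccess(Operation *op, const AnalysisState &state,
                                     ArrayRef<OpOperand *> opOperands) const {
    // Only the operands in `opOperands` need to be inspected; all other
    // tensor operands of the op may be ignored. Returning "false" is always
    // a safe, conservative answer.
    auto resultType = dyn_cast<RankedTensorType>(op->getResult(0).getType());
    if (!resultType || !resultType.hasStaticShape())
      return false;
    for (OpOperand *operand : opOperands) {
      // Assumption: this op accesses an operand element-wise exactly when the
      // operand has the same static shape as the result.
      auto tensorType = dyn_cast<RankedTensorType>(operand->get().getType());
      if (!tensorType || tensorType.getShape() != resultType.getShape())
        return false;
    }
    return true;
  }
};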
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
@@ -57,3 +57,53 @@
   } -> tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+
+// CHECK-LABEL: @elementwise_no_conflict_4
+func.func @elementwise_no_conflict_4(%arg0: tensor<8x32x32x32xf32>, %arg1: tensor<32x32x32xf32>) -> tensor<8x32x32x32xf32> {
+  %cst = arith.constant dense<3.000000e-02> : tensor<32x32x32xf32>
+  %cst_0 = arith.constant dense<6.000000e-01> : tensor<32xf32>
+  %cst_1 = arith.constant 0.000000e+00 : f32
+  %r = scf.forall (%arg2, %arg3) in (8, 32) shared_outs(%arg4 = %arg0) -> (tensor<8x32x32x32xf32>) {
+    // CHECK: tensor.extract_slice
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "none"]}
+    %extracted_slice = tensor.extract_slice %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<8x32x32x32xf32> to tensor<32x32xf32>
+
+    // CHECK: linalg.fill
+    // CHECK-SAME: {__inplace_operands_attr__ = ["none", "true"]}
+    %4 = linalg.fill ins(%cst_1 : f32) outs(%extracted_slice : tensor<32x32xf32>) -> tensor<32x32xf32>
+
+    // CHECK: linalg.batch_reduce_matmul
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
+    %5 = linalg.batch_reduce_matmul ins(%arg1, %cst : tensor<32x32x32xf32>, tensor<32x32x32xf32>) outs(%4 : tensor<32x32xf32>) -> tensor<32x32xf32>
+
+    // CHECK: linalg.generic
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"]}
+    // %cst_0 has a non-identity indexing map, but %5 and %extracted_slice
+    // still bufferize to element-wise access.
+    %6 = linalg.generic {indexing_maps = [#map, #map1, #map], iterator_types = ["parallel", "parallel"]} ins(%5, %cst_0 : tensor<32x32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+    ^bb0(%in: f32, %in_4: f32, %out: f32):
+      %8 = arith.addf %in, %in_4 : f32
+      linalg.yield %8 : f32
+    } -> tensor<32x32xf32>
+
+    // CHECK: linalg.generic
+    // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+    // %6 and %extracted_slice are different SSA values but equivalent buffers.
+    %7 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%6 : tensor<32x32xf32>) outs(%extracted_slice : tensor<32x32xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %8 = arith.maxf %in, %cst_1 : f32
+      linalg.yield %8 : f32
+    } -> tensor<32x32xf32>
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice
+      // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "none", "none"]}
+      tensor.parallel_insert_slice %7 into %arg4[%arg2, %arg3, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : tensor<32x32xf32> into tensor<8x32x32x32xf32>
+    }
+  }
+  return %r : tensor<8x32x32x32xf32>
+}
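For contrast with the new test above, here is a hypothetical counter-example (not part of the patch) of the "merely aliasing" case called out in the interface documentation: %a and %b below alias overlapping slices of %t but are not equivalent, so the element-wise relaxation should not apply and the analysis must still treat the read of %a and the write to %b as a potential RaW conflict.

// Hypothetical example (not included in the patch). %a and %b merely alias:
// writing %b[d0] in place overwrites the element that the read of %a[d0 + 1]
// still needs, so the element-wise exception must not suppress the conflict.
func.func @elementwise_merely_aliasing(%t: tensor<10xf32>) -> tensor<4xf32> {
  %a = tensor.extract_slice %t[0] [4] [1] : tensor<10xf32> to tensor<4xf32>
  %b = tensor.extract_slice %t[1] [4] [1] : tensor<10xf32> to tensor<4xf32>
  %0 = linalg.generic
      {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
       iterator_types = ["parallel"]}
      ins(%a : tensor<4xf32>) outs(%b : tensor<4xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<4xf32>
  return %0 : tensor<4xf32>
}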