diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -91,6 +91,56 @@
         llvm_unreachable("bufferizesToMemoryWrite not implemented");
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return `true` if the operation bufferizes to IR that performs only
+        element-wise accesses on all tensor operands. (All operands must have
+        the same shape.) The `bufferize` method must be implemented in such a
+        way that it is free of loop-carried dependences. I.e., all loads at a
+        position appear before all stores at the same position.
+
+        Example: Consider a hypothetical element-wise op, where the "ins"
+        bufferize to a memory read and the "outs" bufferize to a memory write.
+        ```
+        test.element_wise ins(%0), outs(%1) : tensor<3xf32>
+        ```
+
+        The following is a valid access pattern:
+        ```
+        load(%0[1])
+        store(%1[1])
+        load(%0[2])
+        store(%1[2])
+        load(%0[0])
+        store(%1[0])
+        ```
+
+        The following would be an invalid (not element-wise) access pattern:
+        ```
+        load(%0[1])
+        store(%0[1])
+        load(%0[1])
+        ...
+        ```
+
+        Element-wise ops can sometimes bufferize more efficiently: a RaW
+        conflict between two operands of the same op can be avoided if it is
+        guaranteed that an original element value is no longer needed after
+        writing a computed element value at the same location. E.g., such an
+        optimization is possible in the above example if %0 and %1 are
+        equivalent tensors. (It is not possible if %0 and %1 are merely
+        aliasing. It is not necessary if %0 and %1 are not aliasing at all,
+        because there would be no conflict anyway.)
+      }],
+      /*retType=*/"bool",
+      /*methodName=*/"bufferizesToElementwiseAccess",
+      /*args=*/(ins "const ::mlir::bufferization::AnalysisState &":$state),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // It is always safe to assume that the op is not element-wise.
+        return false;
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return `true` if the given OpResult bufferizes to a memory write.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -542,6 +542,22 @@
       }
     }
 
+    // Two equivalent operands of the same op are not conflicting if the op
+    // bufferizes to element-wise access. I.e., all loads at a position happen
+    // before all stores to the same position.
+    if (conflictingWritingOp == readingOp &&
+        state.areEquivalentBufferizedValues(uRead->get(),
+                                            uConflictingWrite->get())) {
+      if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
+        if (bufferizableOp.bufferizesToElementwiseAccess(state)) {
+          LLVM_DEBUG(
+              llvm::dbgs()
+              << "  no conflict: op bufferizes to element-wise access\n");
+          continue;
+        }
+      }
+    }
+
     // No conflict if the op interface says so.
     if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
       if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -95,8 +95,8 @@
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     // Operand is read if it is used in the computation.
-    auto genericOp = cast<linalg::LinalgOp>(op);
-    return genericOp.payloadUsesValueFromOperand(&opOperand);
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+    return linalgOp.payloadUsesValueFromOperand(&opOperand);
   }
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
@@ -106,6 +106,33 @@
     return dpsOp.isDpsInit(&opOperand);
   }
 
+  bool bufferizesToElementwiseAccess(Operation *op,
+                                     const AnalysisState &state) const {
+    auto linalgOp = cast<linalg::LinalgOp>(op);
+
+    // All loops must be parallel.
+    if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+      return false;
+
+    // All index maps of tensors must be identity maps.
+    SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+    assert(linalgOp->getNumOperands() == indexingMaps.size() &&
+           "unexpected number of indexing maps");
+    for (auto [operand, map] :
+         llvm::zip(linalgOp->getOperands(), indexingMaps)) {
+      // Non-tensors do not participate in bufferization, so they can be
+      // ignored.
+      if (!isa<TensorType>(operand.getType()))
+        continue;
+      // TODO: This could be generalized to other indexing maps. (All indexing
+      // must be the same.)
+      if (!map.isIdentity())
+        return false;
+    }
+
+    return true;
+  }
+
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     return bufferizeDestinationStyleOpInterface(
diff --git a/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/one-shot-bufferize-analysis.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt %s -one-shot-bufferize="allow-return-allocs bufferize-function-boundaries test-analysis-only" -split-input-file | FileCheck %s
+
+// CHECK-LABEL: @elementwise_no_conflict
+func.func @elementwise_no_conflict(%a: tensor<5xf32>,
+                                   %b: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: linalg.elemwise_binary
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+      ins(%a, %b : tensor<5xf32>, tensor<5xf32>)
+      outs(%a : tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @elementwise_no_conflict_2
+func.func @elementwise_no_conflict_2(%a: tensor<5xf32>) -> tensor<5xf32> {
+  // CHECK: linalg.elemwise_binary
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true", "true"], fun = #linalg.binary_fn<add>}
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+      ins(%a, %a : tensor<5xf32>, tensor<5xf32>)
+      outs(%a : tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @elementwise_no_conflict_3
+func.func @elementwise_no_conflict_3(%a: tensor<5xf32>) -> tensor<5xf32> {
+  %c0f = arith.constant 1.0 : f32
+  // CHECK: linalg.elemwise_binary
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "none", "true"], fun = #linalg.binary_fn<add>}
+  %0 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+      ins(%a, %c0f : tensor<5xf32>, f32)
+      outs(%a : tensor<5xf32>) -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+// -----
+
+func.func @not_elementwise(%a: tensor<5x6xf32>) -> tensor<5x6xf32> {
+  %cst = arith.constant 5.0 : f32
+  // CHECK: tensor.extract_slice
+  // CHECK-SAME: {__inplace_operands_attr__ = ["false"]}
+  %b = tensor.extract_slice %a[0, 0] [1, 6] [1, 1]
+      : tensor<5x6xf32> to tensor<6xf32>
+  // CHECK: linalg.generic
+  // CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
+  %0 = linalg.generic
+      { iterator_types = ["parallel", "parallel"],
+        indexing_maps = [ affine_map<(d0, d1) -> (d1)>,
+                          affine_map<(d0, d1) -> (d0, d1)>] }
+      ins(%b: tensor<6xf32>) outs(%a: tensor<5x6xf32>) {
+    ^bb0(%arg0: f32, %arg1: f32):
+      %r = arith.addf %arg0, %arg1 : f32
+      linalg.yield %r : f32
+  } -> tensor<5x6xf32>
+  return %0 : tensor<5x6xf32>
+}
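
Out-of-tree dialects can opt into the same relaxation by implementing the new interface method for their ops. Below is a rough sketch, not part of this patch: `mydialect`, `MyElementwiseOp`, and `registerBufferizableOpInterfaceExternalModels` are hypothetical placeholders, and the `bufferize` body is elided. Only the external-model boilerplate and the method signatures mirror what the patch introduces above; returning `true` from `bufferizesToElementwiseAccess` is only sound if the op's `bufferize` emits IR where, at every position, all loads happen before all stores.

```
// Hypothetical sketch (not part of this patch): an out-of-tree op opting into
// the element-wise relaxation via a BufferizableOpInterface external model.
// "mydialect" and "MyElementwiseOp" are placeholder names.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

using namespace mlir;
using namespace mlir::bufferization;

namespace {
struct MyElementwiseOpModel
    : public BufferizableOpInterface::ExternalModel<
          MyElementwiseOpModel, mydialect::MyElementwiseOp> {
  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                              const AnalysisState &state) const {
    // Assumption: the op reads all of its operands.
    return true;
  }

  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
    // Assumption: the op is destination-style and writes only its inits.
    return cast<DestinationStyleOpInterface>(op).isDpsInit(&opOperand);
  }

  // New hook from this patch: claiming element-wise access lets the analysis
  // treat equivalent in/out operands of this op as non-conflicting.
  bool bufferizesToElementwiseAccess(Operation *op,
                                     const AnalysisState &state) const {
    return true;
  }

  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                          const BufferizationOptions &options) const {
    // Elided: emit memref IR that accesses each element position
    // load-before-store, as required by the interface documentation.
    return failure();
  }
};
} // namespace

void mydialect::registerBufferizableOpInterfaceExternalModels(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx, mydialect::MyDialect *dialect) {
    mydialect::MyElementwiseOp::attachInterface<MyElementwiseOpModel>(*ctx);
  });
}
```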