diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h @@ -19,9 +19,11 @@ namespace comprehensive_bufferize { namespace scf_ext { -/// Equivalence analysis for scf.for. Raise an error if iter_args are not -/// equivalent to their corresponding loop yield values. -struct AssertDestinationPassingStyle : public PostAnalysisStep { +/// Assert that yielded values of an scf.for op are aliasing their corresponding +/// bbArgs. This is required because the i-th OpResult of an scf.for op is +/// currently assumed to alias with the i-th iter_arg (in the absence of +/// conflicts). +struct AssertScfForAliasingProperties : public PostAnalysisStep { LogicalResult run(Operation *op, BufferizationState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) override; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt @@ -96,6 +96,7 @@ LINK_LIBS PUBLIC MLIRBufferizableOpInterface + MLIRControlFlowInterfaces MLIRInferTypeOpInterface MLIRIR MLIRMemRef diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -54,6 +54,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/Operation.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 
#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" @@ -559,6 +560,76 @@ }); } +/// Assert that IR is in destination-passing style. I.e., every value that is +/// returned or yielded from a block is: +/// * aliasing a bbArg of that block or a parent block, or +/// * aliasing an OpResult of an op in a parent block. +/// +/// Example: +/// ``` +/// %0 = "some_op" : tensor +/// %1 = scf.if %c -> (tensor) { +/// scf.yield %0 : tensor +/// } else { +/// %t = linalg.init_tensor : tensor +/// scf.yield %t : tensor +/// } +/// ``` +/// In the above example, the first scf.yield op satisfies destination-passing +/// style because the yielded value %0 is defined in the parent block. The +/// second scf.yield op does not satisfy destination-passing style because the +/// yielded value %t is defined in the same block as the scf.yield op. +// TODO: The current implementation checks for equivalent values instead of +// aliasing values, which is stricter than needed. We can currently not check +// for aliasing values because the analysis is a maybe-alias analysis and we +// need a must-alias analysis here. +struct AssertDestinationPassingStyle : public PostAnalysisStep { + LogicalResult run(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, + SmallVector &newOps) override { + LogicalResult status = success(); + DominanceInfo domInfo(op); + op->walk([&](Operation *returnOp) { + if (!isRegionReturnLike(returnOp)) + return WalkResult::advance(); + + for (OpOperand &returnValOperand : returnOp->getOpOperands()) { + Value returnVal = returnValOperand.get(); + // Skip non-tensor values. 
+ if (!returnVal.getType().isa()) + continue; + + bool foundEquivValue = false; + aliasInfo.applyOnEquivalenceClass(returnVal, [&](Value equivVal) { + if (auto bbArg = equivVal.dyn_cast()) { + Operation *definingOp = bbArg.getOwner()->getParentOp(); + if (definingOp->isProperAncestor(returnOp)) + foundEquivValue = true; + return; + } + + Operation *definingOp = equivVal.getDefiningOp(); + if (definingOp->getBlock()->findAncestorOpInBlock( + *returnOp->getParentOp())) + // Skip ops that happen after `returnOp` and parent ops. + if (happensBefore(definingOp, returnOp, domInfo)) + foundEquivValue = true; + }); + + if (!foundEquivValue) + status = + returnOp->emitError() + << "operand #" << returnValOperand.getOperandNumber() + << " of ReturnLike op does not satisfy destination passing style"; + } + + return WalkResult::advance(); + }); + + return status; + } +}; + /// Rewrite pattern that bufferizes bufferizable ops. struct BufferizationPattern : public OpInterfaceRewritePattern { @@ -643,6 +714,13 @@ equivalenceAnalysis(newOps, aliasInfo, state); } + if (!options.allowReturnMemref) { + SmallVector newOps; + if (failed( + AssertDestinationPassingStyle().run(op, state, aliasInfo, newOps))) + return failure(); + } + // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) annotateOpsWithBufferizationMarkers(op, aliasInfo, state); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp @@ -391,70 +391,37 @@ } }; -// TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that -// do not bufferize inplace. (Requires a few more changes for ConstantOp, -// InitTensorOp, CallOp.) 
-LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext:: - AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state, - BufferizationAliasInfo &aliasInfo, - SmallVector &newOps) { +LogicalResult +mlir::linalg::comprehensive_bufferize::scf_ext::AssertScfForAliasingProperties:: + run(Operation *op, BufferizationState &state, + BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { LogicalResult status = success(); - op->walk([&](scf::YieldOp yieldOp) { - if (auto forOp = dyn_cast(yieldOp->getParentOp())) { - for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - - OpOperand &forOperand = forOp.getOpOperandForResult( - forOp->getResult(operand.getOperandNumber())); - auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); - if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) { - // TODO: this could get resolved with copies but it can also turn into - // swaps so we need to be careful about order of copies. - status = - yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to an equivalent buffer to the matching" - << " enclosing scf::for operand"; - return WalkResult::interrupt(); - } - } - } - - if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { - // IfOps are in destination passing style if all yielded tensors are - // a value or equivalent to a value that is defined outside of the IfOp. 
- for (OpOperand &operand : yieldOp->getOpOperands()) { - auto tensorType = operand.get().getType().dyn_cast(); - if (!tensorType) - continue; - - bool foundOutsideEquivalent = false; - aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) { - Operation *valueOp = value.getDefiningOp(); - if (value.isa()) - valueOp = value.cast().getOwner()->getParentOp(); - - bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp); - bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp); - - if (!inThenBlock && !inElseBlock) - foundOutsideEquivalent = true; - }); - if (!foundOutsideEquivalent) { - status = yieldOp->emitError() - << "Yield operand #" << operand.getOperandNumber() - << " does not bufferize to a buffer that is equivalent to a" - << " buffer defined outside of the scf::if op"; - return WalkResult::interrupt(); - } + op->walk([&](scf::ForOp forOp) { + auto yieldOp = + cast(forOp.getLoopBody().front().getTerminator()); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = operand.get().getType().dyn_cast(); + if (!tensorType) + continue; + + OpOperand &forOperand = forOp.getOpOperandForResult( + forOp->getResult(operand.getOperandNumber())); + auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand); + if (!aliasInfo.areAliasingBufferizedValues(operand.get(), bbArg)) { + // TODO: this could get resolved with copies but it can also turn into + // swaps so we need to be careful about order of copies. 
+ status = + yieldOp->emitError() + << "Yield operand #" << operand.getOperandNumber() + << " does not bufferize to a buffer that is aliasing the matching" + << " enclosing scf::for operand"; + return WalkResult::interrupt(); } } - return WalkResult::advance(); }); + return status; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp @@ -95,8 +95,8 @@ linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>(); } - if (!allowReturnMemref) - options->addPostAnalysisStep(); + // Only certain scf.for ops are supported by the analysis. + options->addPostAnalysisStep(); ModuleOp moduleOp = getOperation(); applyEnablingTransformations(moduleOp); diff --git a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-function-bufferize.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s // Run fuzzer with different seeds. 
-// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null -// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null -// RUN: mlir-opt %s -test-comprehensive-function-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null // CHECK-LABEL: func @use_tensor_func_arg( // CHECK-SAME: %[[A:.*]]: tensor diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -41,16 +41,34 @@ %r = scf.if %cond -> (tensor) { scf.yield %t1 : tensor } else { - // This buffer aliases, but is not equivalent. + // This buffer aliases, but it is not equivalent. 
%t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor to tensor - // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}} + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} scf.yield %t2 : tensor } + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r : tensor } // ----- +func @scf_if_not_aliasing( + %cond: i1, %t1: tensor {linalg.inplaceable = true}, + %idx: index) -> f32 { + %r = scf.if %cond -> (tensor) { + scf.yield %t1 : tensor + } else { + // This buffer does not alias. + %t2 = linalg.init_tensor [%idx] : tensor + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} + scf.yield %t2 : tensor + } + %f = tensor.extract %r[%idx] : tensor + return %f : f32 +} + +// ----- + // expected-error @-3 {{expected callgraph to be free of circular dependencies}} func @foo() { @@ -80,7 +98,7 @@ // Throw a wrench in the system by swapping yielded values: this result in a // ping-pong of values at each iteration on which we currently want to fail. 
- // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}} + // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is aliasing}} scf.yield %ttB, %ttA : tensor, tensor } @@ -101,7 +119,7 @@ %c1 = arith.constant 1 : index %res = scf.for %arg0 = %c0 to %iters step %c1 iter_args(%bbarg = %A) -> (tensor) { %r = call @foo(%A) : (tensor) -> (tensor) - // expected-error @+1 {{Yield operand #0 does not bufferize to an equivalent buffer}} + // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is aliasing}} scf.yield %r : tensor } call @fun_with_side_effects(%res) : (tensor) -> () @@ -110,7 +128,6 @@ // ----- -// expected-error @+1 {{memref return type is unsupported}} func @extract_slice_fun(%A : tensor {linalg.inplaceable = true}) -> tensor<4xf32> { @@ -122,12 +139,12 @@ // argument aliasing). %r0 = tensor.extract_slice %A[0][4][1] : tensor to tensor<4xf32> + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r0: tensor<4xf32> } // ----- -// expected-error @+1 {{memref return type is unsupported}} func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32> { %r = scf.if %b -> (tensor<4xf32>) { @@ -135,6 +152,7 @@ } else { scf.yield %B : tensor<4xf32> } + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r: tensor<4xf32> } @@ -142,29 +160,31 @@ func @unknown_op(%A : tensor<4xf32>) -> tensor<4xf32> { - // expected-error @+1 {{op was not bufferized}} + // expected-error: @+1 {{op was not bufferized}} %r = "marklar"(%A) : (tensor<4xf32>) -> (tensor<4xf32>) + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r: tensor<4xf32> } // ----- -// expected-error @+1 {{memref return type is unsupported}} func @mini_test_case1() -> tensor<10x20xf32> { %f0 = arith.constant 0.0 : f32 %t = linalg.init_tensor [10, 20] : 
tensor<10x20xf32> %r = linalg.fill(%f0, %t) : f32, tensor<10x20xf32> -> tensor<10x20xf32> + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r : tensor<10x20xf32> } // ----- -// expected-error @+1 {{memref return type is unsupported}} func @main() -> tensor<4xi32> { %r = scf.execute_region -> tensor<4xi32> { %A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> scf.yield %A: tensor<4xi32> } + + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %r: tensor<4xi32> } @@ -203,12 +223,42 @@ func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) { %0 = linalg.init_tensor [5] : tensor<5xf32> + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} return %0 : tensor<5xf32> } +// Note: This function is not analyzed because there was an error in the +// previous one. func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) { - // expected-error @+2 {{call to FuncOp that returns non-equivalent tensors not supported}} - // expected-error @+1 {{op was not bufferized}} call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>) return } + +// ----- + +func @destination_passing_style_dominance_test_1(%cst : f32, %idx : index, + %idx2 : index) -> f32 { + %0 = scf.execute_region -> tensor { + %1 = linalg.init_tensor [%idx] : tensor + // expected-error @+1 {{operand #0 of ReturnLike op does not satisfy destination passing style}} + scf.yield %1 : tensor + } + %2 = tensor.insert %cst into %0[%idx] : tensor + %r = tensor.extract %2[%idx2] : tensor + return %r : f32 +} + +// ----- + +func @destination_passing_style_dominance_test_2(%cst : f32, %idx : index, + %idx2 : index) -> f32 { + %1 = linalg.init_tensor [%idx] : tensor + + %0 = scf.execute_region -> tensor { + // This YieldOp is in destination-passing style, thus no error. 
+ scf.yield %1 : tensor + } + %2 = tensor.insert %cst into %0[%idx] : tensor + %r = tensor.extract %2[%idx2] : tensor + return %r : f32 +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-partial.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref allow-unknown-ops" -split-input-file | FileCheck %s // Run fuzzer with different seeds. -// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null -// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null -// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -allow-unregistered-dialect -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null // RUN: mlir-opt %s -allow-unregistered-dialect -test-comprehensive-function-bufferize="dialect-filter=tensor allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-TENSOR // RUN: mlir-opt %s -allow-unregistered-dialect 
-test-comprehensive-function-bufferize="dialect-filter=scf allow-unknown-ops allow-return-memref" -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-SCF diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1,9 +1,9 @@ // RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=allow-return-memref -split-input-file | FileCheck %s // Run fuzzer with different seeds. -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null -// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null +// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="allow-return-memref test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null // CHECK-LABEL: func @transfer_read(%{{.*}}: memref) -> vector<4xf32> { func @transfer_read( diff --git a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp --- a/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp @@ -104,7 +104,7 @@ auto options = std::make_unique(); if (!allowReturnMemref) - options->addPostAnalysisStep(); + options->addPostAnalysisStep(); 
options->allowReturnMemref = allowReturnMemref; options->allowUnknownOps = allowUnknownOps; diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -6898,6 +6898,7 @@ deps = [ ":BufferizableOpInterface", ":BufferizationDialect", + ":ControlFlowInterfaces", ":DialectUtils", ":IR", ":InferTypeOpInterface",