diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -71,6 +71,8 @@ SmallVector &newOps) = 0; }; +using PostAnalysisStepList = std::vector>; + /// Options for ComprehensiveBufferize. struct BufferizationOptions { BufferizationOptions(); @@ -107,7 +109,7 @@ bool testAnalysisOnly = false; /// Registered post analysis steps. - std::vector> postAnalysisSteps; + PostAnalysisStepList postAnalysisSteps; }; /// Specify fine-grain relationship between buffers to enable more analysis. diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h @@ -18,13 +18,16 @@ struct BufferizationOptions; struct BufferizationState; +struct PostAnalysisStep; /// Bufferize the given function. Does not bufferize the function boundary. +/// Reuses an existing BufferizationState object. // TODO: This function is meant to be called from ModuleBufferize and not can // not yet be called standalone. -LogicalResult runComprehensiveBufferize(FuncOp funcOp, - const BufferizationOptions &options, - BufferizationState &state); +LogicalResult runComprehensiveBufferize( + FuncOp funcOp, const BufferizationOptions &options, + BufferizationState &state, + const std::vector> &extraSteps); } // namespace comprehensive_bufferize } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -726,7 +726,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize( FuncOp funcOp, const BufferizationOptions &options, - BufferizationState &state) { + BufferizationState &state, const PostAnalysisStepList &extraSteps) { DominanceInfo domInfo(funcOp); BufferizationAliasInfo &aliasInfo = state.aliasInfo; @@ -744,16 +744,23 @@ return failure(); equivalenceAnalysis(op, aliasInfo); - for (const std::unique_ptr &step : - options.postAnalysisSteps) { - SmallVector newOps; - if (failed(step->run(funcOp, state, newOps))) - return failure(); - // Analyze ops that were created by the PostAnalysisStep. - if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo))) - return failure(); - equivalenceAnalysis(newOps, aliasInfo); - } + auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) { + for (const std::unique_ptr &step : steps) { + SmallVector newOps; + if (failed(step->run(funcOp, state, newOps))) + return failure(); + // Analyze ops that were created by the PostAnalysisStep. + if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo))) + return failure(); + equivalenceAnalysis(newOps, aliasInfo); + } + return success(); + }; + + if (failed(runPostAnalysisSteps(extraSteps))) + return failure(); + if (failed(runPostAnalysisSteps(options.postAnalysisSteps))) + return failure(); // Annotate operations if we only want to report the analysis. if (options.testAnalysisOnly) { diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -33,8 +33,9 @@ /// A map for looking up bufferized function types. DenseMap bufferizedFunctionTypes; - /// A mapping of return values to equivalent BlockArguments. - DenseMap equivalentReturnValToBBArg; + /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg + /// indices. + DenseMap> equivalentFuncArgs; }; } // namespace @@ -44,6 +45,70 @@ StandardOpsDialect::getDialectNamespace()); } +/// Return the unique ReturnOp that terminates `funcOp`. +/// Return nullptr if there is no such unique ReturnOp. +static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { + ReturnOp returnOp; + for (Block &b : funcOp.body()) { + if (auto candidateOp = dyn_cast(b.getTerminator())) { + if (returnOp) + return nullptr; + returnOp = candidateOp; + } + } + return returnOp; +} + +namespace { +/// Store function BlockArguments that are equivalent to a returned value in +/// ModuleBufferizationState. +struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep { + /// Annotate IR with the results of the analysis. For testing purposes only. + static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) { + const char *kEquivalentArgsAttr = "__equivalent_func_args__"; + Operation *op = returnVal.getOwner(); + + SmallVector equivBbArgs; + if (op->hasAttr(kEquivalentArgsAttr)) { + auto attr = op->getAttr(kEquivalentArgsAttr).cast(); + equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) { + return a.cast().getValue().getSExtValue(); + })); + } else { + equivBbArgs.append(op->getNumOperands(), -1); + } + equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber(); + + OpBuilder b(op->getContext()); + op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs)); + } + + LogicalResult run(FuncOp funcOp, BufferizationState &state, + SmallVector &newOps) override { + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + + // Support only single return-terminated block in the function. + ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + for (OpOperand &returnVal : returnOp->getOpOperands()) + if (returnVal.get().getType().isa()) + for (BlockArgument bbArg : funcOp.getArguments()) + if (bbArg.getType().isa()) + if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(), + bbArg)) { + moduleState + .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] = + bbArg.getArgNumber(); + if (state.options.testAnalysisOnly) + annotateReturnOp(returnVal, bbArg); + } + + return success(); + } +}; +} // namespace + static bool isaTensor(Type t) { return t.isa(); } /// If `value` is a memref::CastOp, return its source. Otherwise, return @@ -73,20 +138,6 @@ SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } -/// Return the unique ReturnOp that terminates `funcOp`. -/// Return nullptr if there is no such unique ReturnOp. -static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) { - ReturnOp returnOp; - for (Block &b : funcOp.body()) { - if (auto candidateOp = dyn_cast(b.getTerminator())) { - if (returnOp) - return nullptr; - returnOp = candidateOp; - } - } - return returnOp; -} - /// Return the FunctionType with `argumentTypes` and `resultTypes` where each /// tensor is replaced by the corresponding buffer type. /// In order for all the callers to agree, this *must* bufferize to the most @@ -128,22 +179,30 @@ return it2.first->second; } -/// Store function BlockArguments that are equivalent to a returned value in -/// the given ModuleBufferizationState. -static void populateEquivalentFuncOpBBArgs(FuncOp funcOp, - BufferizationState &state) { - ModuleBufferizationState &moduleState = getModuleBufferizationState(state); - - // Support only single return-terminated block in the function. - ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); +/// Gather equivalence info of CallOps. +/// Note: This only adds new equivalence info if `funcOp` was already analyzed. +// TODO: This does not handle cyclic function call graphs etc. +static void equivalenceAnalysis(FuncOp funcOp, + BufferizationAliasInfo &aliasInfo, + ModuleBufferizationState &moduleState) { + funcOp->walk([&](CallOp callOp) { + FuncOp calledFunction = getCalledFunction(callOp); + assert(calledFunction && "could not retrieved called FuncOp"); + + // No equivalence info available for the called function. + if (!moduleState.equivalentFuncArgs.count(calledFunction)) + return WalkResult::skip(); + + for (auto it : moduleState.equivalentFuncArgs[calledFunction]) { + int64_t returnIdx = it.first; + int64_t bbargIdx = it.second; + Value returnVal = callOp.getResult(returnIdx); + Value argVal = callOp->getOperand(bbargIdx); + aliasInfo.unionEquivalenceClasses(returnVal, argVal); + } - for (Value returnVal : returnOp.operands()) - if (returnVal.getType().isa()) - for (BlockArgument bbArg : funcOp.getArguments()) - if (bbArg.getType().isa()) - if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg)) - moduleState.equivalentReturnValToBBArg[returnVal] = bbArg; + return WalkResult::advance(); + }); } /// Rewrite the `funcOp` arguments analysis return values and terminator into @@ -217,7 +276,8 @@ } // If return operand is equivalent to some bbArg, no need to return it. - if (moduleState.equivalentReturnValToBBArg.count(returnVal)) + if (moduleState.equivalentFuncArgs[funcOp].count( + returnOperand.getOperandNumber())) continue; // Cast values at the call site if necessary. @@ -499,12 +559,12 @@ } // If return operand is equivalent to some bbArg, no need to return it. - Value returnVal = returnOperand.get(); - if (moduleState.equivalentReturnValToBBArg.count(returnVal)) { - BlockArgument bbArg = - moduleState.equivalentReturnValToBBArg[returnVal]; + if (moduleState.equivalentFuncArgs[funcOp].count( + returnOperand.getOperandNumber())) { + int64_t idx = + moduleState + .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()]; Value oldRes = callOp->getResult(returnOperand.getOperandNumber()); - int64_t idx = bbArg.getArgNumber(); Value buffer = state.lookupBuffer(callOp->getOperand(idx)); // Add CallOp operand/result equivalence: this is interprocedural // info. @@ -667,6 +727,7 @@ return failure(); BufferizationState state(moduleOp, options); + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); BufferizationAliasInfo &aliasInfo = state.aliasInfo; // Interestingly, all function args that are not visible outside of a module @@ -698,11 +759,17 @@ aliasInfo.setBufferizesToWritableMemory(bbArg); } + // Register extra post analysis steps. These cannot be stored in `options` + // because `options` is immutable. + PostAnalysisStepList extraSteps; + extraSteps.emplace_back(std::make_unique()); + + // Gather equivalence info for CallOps. + equivalenceAnalysis(funcOp, aliasInfo, moduleState); + // Analyze and bufferize funcOp. - if (failed(runComprehensiveBufferize(funcOp, options, state))) + if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps))) return failure(); - - populateEquivalentFuncOpBBArgs(funcOp, state); } if (options.testAnalysisOnly) diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-analysis.mlir @@ -40,15 +40,17 @@ -> (tensor, tensor) { // must bufferize out of place. - // CHECK: tensor.insert_slice + // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["false"]} %r0 = tensor.insert_slice %C into %A[0][4][1] : tensor<4xf32> into tensor // bufferizes inplace. - // CHECK: tensor.insert_slice + // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"]} %r1 = tensor.insert_slice %C into %B[0][4][1] : tensor<4xf32> into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]} return %r0, %r1: tensor, tensor } @@ -81,6 +83,8 @@ outs(%B: tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, -1, 1]} return %C, %D, %E: tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32> } @@ -136,6 +140,8 @@ // CHECK: {__inplace_results_attr__ = ["false"]} %r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %r1, %r3: tensor, tensor } @@ -172,6 +178,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["false"]} %r3 = tensor.insert_slice %r2 into %B[%idx][4][1] : tensor<4xf32> into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %r1, %r3: tensor, tensor } @@ -208,6 +216,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["false"]} %r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %r1, %r3: tensor, tensor } @@ -234,6 +244,9 @@ %2 = tensor.insert_slice %1 into %A[%idx][%idx][1] : tensor into tensor %3 = vector.transfer_read %1[%idx2], %cst2 : tensor, vector<5xf32> + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %2, %3 : tensor, vector<5xf32> } @@ -274,6 +287,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"]} %6 = tensor.insert_slice %5 into %2[%idx3][%idx3][1] : tensor into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %6, %3 : tensor, vector<5xf32> } @@ -306,6 +321,8 @@ outs(%C: tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, 2]} return %D, %E: tensor<4x4xf32>, tensor<4x4xf32> } @@ -372,6 +389,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"]} %20 = tensor.insert_slice %19 into %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor into tensor<30x20xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [6]} return %20 : tensor<30x20xf32> } @@ -502,6 +521,8 @@ %rsC = tensor.insert_slice %FC into %sC[0, 0][12345, 67890][1, 1] : tensor<4x4xf32> into tensor %rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, 1, 2]} return %rA, %rB, %rC: tensor, tensor, tensor } @@ -531,6 +552,8 @@ scf.yield %t : tensor } + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]} return %r0, %r1: tensor, tensor } @@ -562,6 +585,8 @@ scf.yield %ttA, %ttB : tensor, tensor } + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, 1]} return %r0#0, %r0#1: tensor, tensor } @@ -621,6 +646,8 @@ linalg.yield %t : tensor } + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, 1]} return %r1, %r3: tensor, tensor } @@ -766,6 +793,8 @@ ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [2]} return %r : tensor<256x256xf32> } @@ -811,6 +840,8 @@ ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>) outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [2]} return %r : tensor<256x256xf32> } @@ -856,6 +887,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"] %15 = tensor.insert_slice %14 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [4]} return %15 : tensor<62x90xf32> } @@ -881,6 +914,9 @@ %t3 = tensor.insert_slice %t2 into %arg1[%x, 0] [5, %y] [1, 1] : tensor<5x?xf32> into tensor<10x20xf32> scf.yield %t3 : tensor<10x20xf32> } + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %r : tensor<10x20xf32> } @@ -908,6 +944,9 @@ ^bb(%0: f32, %1: f32, %2 : f32) : linalg.yield %0, %0 : f32, f32 } -> (tensor, tensor) + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [1, -1]} return %o#0, %o#1 : tensor, tensor } @@ -949,6 +988,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"] %15 = tensor.insert_slice %14 into %e[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [2, -1]} return %8, %15 : tensor<62x90xf32>, tensor } @@ -978,6 +1019,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"] %15 = tensor.insert_slice %10 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %15 : tensor<62x90xf32> } @@ -1007,6 +1050,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"] %15 = tensor.insert_slice %10 into %8[31, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %15 : tensor<62x90xf32> } @@ -1029,6 +1074,8 @@ // CHECK-SAME: {__inplace_results_attr__ = ["true"] %15 = tensor.insert_slice %2 into %8[15, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32> + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %15 : tensor<62x90xf32> } @@ -1130,6 +1177,8 @@ linalg.yield %cst : f32 } -> (tensor) + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %o, %v3 : tensor, vector<5xf32> } @@ -1158,6 +1207,9 @@ // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"] %3 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor into tensor + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [-1, 0]} return %2, %3 : tensor, tensor } @@ -1178,6 +1230,9 @@ // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"] %2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor into tensor + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, 0]} return %2, %2 : tensor, tensor } @@ -1212,6 +1267,8 @@ %t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor scf.yield %t2 : tensor } + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %r : tensor } @@ -1261,6 +1318,9 @@ scf.yield %r : tensor } %v2 = vector.transfer_read %r_alias[%idx], %cst : tensor, vector<10xf32> + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0, -1]} return %r_alias, %v2 : tensor, vector<10xf32> } @@ -1286,6 +1346,9 @@ // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"] %r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor into tensor + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %r2 : tensor } @@ -1316,6 +1379,9 @@ %t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor scf.yield %t3 : tensor } + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %r : tensor } @@ -1394,6 +1460,9 @@ // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"] %r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor into tensor + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %r2 : tensor } @@ -1418,6 +1487,9 @@ // CHECK: tensor.insert_slice // CHECK-SAME: {__inplace_results_attr__ = ["true"] %r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor into tensor + + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} return %r2 : tensor } @@ -1531,3 +1603,44 @@ return %r1, %r2 : vector<5xf32>, vector<5xf32> } + +// ----- + +// CHECK-LABEL: func @inner_func +func @inner_func(%t: tensor) -> tensor { + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} + return %t : tensor +} + +func @equivalent_func_arg(%c0: index, %c10: index, %c1: index, %t0: tensor) -> tensor { + // This test does not check IR. It just asserts there is no failure due to + // non-equivalent scf.for yield values. + %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor) { + %3 = call @inner_func(%t1) : (tensor) -> tensor + scf.yield %3 : tensor + } + return %1: tensor +} + +// ----- + +// CHECK-LABEL: func @inner_func_2 +func @inner_func_2(%t: tensor) -> tensor { + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %0 = tensor.insert %f into %t[%c0] : tensor + // CHECK: return + // CHECK-SAME: {__equivalent_func_args__ = [0]} + return %0 : tensor +} + +func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor) -> tensor { + // This test does not check IR. It just asserts there is no failure due to + // non-equivalent scf.for yield values. + %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor) { + %3 = call @inner_func_2(%t1) : (tensor) -> tensor + scf.yield %3 : tensor + } + return %1: tensor +} diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -928,3 +928,54 @@ // CHECK: return return %0 : tensor } + +// ----- + +// CHECK-LABEL: func @inner_func( +// CHECK-SAME: %[[arg0:.*]]: memref) -> tensor { + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor + return %0 : tensor +} + +// CHECK-LABEL: func @equivalent_func_arg( +// CHECK-SAME: %[[arg0:.*]]: memref {linalg.inplaceable = true}, + %c0: index, %c10: index, %c1: index) -> tensor { + // CHECK-NOT: copy + %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor) { + // CHECK: call @inner_func(%[[arg0]]) + %3 = call @inner_func(%t1) : (tensor) -> tensor + scf.yield %3 : tensor + } + return %1: tensor +} + +// ----- + +// CHECK-LABEL: func @inner_func_2( +// CHECK-SAME: %[[arg0:.*]]: memref) -> tensor { + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor + return %0 : tensor +} + +// CHECK-LABEL: func @equivalent_func_arg_2( +// CHECK-SAME: %[[arg0:.*]]: memref {linalg.inplaceable = true}, + %c0: index, %c10: index, %c1: index) -> tensor { + %1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor) { + // TODO: There should be a memory copy here. This is a bug in CallOp + // bufferization. + // CHECK: call @inner_func_2(%[[arg0]]) + %3 = call @inner_func_2(%t1) : (tensor) -> tensor + scf.yield %t1 : tensor + } + return %1: tensor +}