diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -237,6 +237,12 @@ /// Return true if the given op has a tensor result or a tensor operand. static bool hasTensorSemantics(Operation *op) { + if (auto funcOp = dyn_cast(op)) { + bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor); + return hasTensorArg || hasTensorResult; + } + bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); return hasTensorResult || hasTensorOperand; @@ -341,6 +347,12 @@ // Bufferize the op and its nested ops. RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext(), bufferizationState); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + // Bufferize op without nested ops. + bool erased = false; + if (failed(applyOpPatternsAndFold(op, frozenPatterns, &erased))) + return failure(); // Bufferize ops top-to-bottom. When creating a new op, we should ideally // know the exact memref type of all operands. Otherwise, we have to use a @@ -354,13 +366,15 @@ GreedyRewriteConfig config; config.useTopDownTraversal = true; + // Bufferize all nested ops (unless the op was erased). // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This // would be more efficient because every bufferization pattern is guaranteed // to apply only a single time (otherwise, an assertion would be triggered). // However, there are restrictions wrt. erasing ops during a preorder walk, // which would likely require a larger refactoring. - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) - return failure(); + if (!erased) + if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config))) + return failure(); if (failed(checkBufferizationResult(op, bufferizationState.getOptions()))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// Module Bufferization is an extension of Comprehensive Bufferize that +// Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` // implementations for FuncOp, CallOp and ReturnOp. // @@ -337,8 +337,6 @@ } } // namespace -static bool isaTensor(Type t) { return t.isa(); } - /// If `value` is a memref::CastOp, return its source. Otherwise, return /// `value` directly. static Value getNonCastedValue(Value value) { @@ -366,26 +364,28 @@ SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } -/// Return the FunctionType with `argumentTypes` and `resultTypes` where each -/// tensor is replaced by the corresponding buffer type. -/// In order for all the callers to agree, this *must* bufferize to the most -/// dynamic buffer type supported. -/// A later pass across all CallOps in the module can decide whether to simplify -/// the types of to version according to some cost model. 
-static FunctionType -getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes, - TypeRange resultTypes, - const BufferizationOptions &options) { - auto rewrite = [&](Type t) -> Type { - // TODO: non-zero address space. - // TODO: layout information if relevant. - if (auto tensorType = t.dyn_cast()) - return getMemRefType(tensorType, options); - return t; - }; - auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite)); - auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite)); - return FunctionType::get(ctx, argTypes, retTypes); +/// Return the index-th bufferized function argument type. This assumes that the +/// specified argument is a tensor. If the tensor is ranked, a layout map may be +/// specified by the user. If no layout map is specified, a fully dynamic map is +/// used. +static BaseMemRefType +getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, + const BufferizationOptions &options) { + auto tensorType = + funcOp.getFunctionType().getInput(index).dyn_cast(); + assert(tensorType && "expected TensorType"); + BaseMemRefType memrefType = getMemRefType(tensorType, options); + + auto layoutAttr = funcOp.getArgAttrOfType( + index, BufferizableOpInterface::kBufferLayoutAttrName); + if (!layoutAttr) + return memrefType; + + auto rankedMemrefType = memrefType.dyn_cast(); + assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); + return MemRefType::get( + rankedMemrefType.getShape(), rankedMemrefType.getElementType(), + layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt()); } /// Gather equivalence info of CallOps. @@ -415,150 +415,6 @@ }); } -/// Rewrite the `funcOp` arguments analysis return values and terminator into -/// buffer form (using the canonical memref layout for now), according to the -/// inPlace-bufferizable information of the function arguments. -/// -/// This relies on a buffer equivalence analysis of each return operand. When a -/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be -/// dropped from the return values and becomes inplaceable at all callers. This -/// assumes all CallOp perform the necessary work to clone operands so as to -/// make them inplaceable. Reliance on this logic will need to be relaxed in the -/// future. -/// -/// Note: Returning a memref currently fails bufferization. If such memrefs -/// originate from an op with an Alloc effect, they could be hoisted in the -/// future. -static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, - RewriterBase &rewriter, - BufferizationState &state) { - const FuncAnalysisState &funcState = - getFuncAnalysisState(state.getAnalysisState()); - - // If nothing to do then we are done. - if (!llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) && - !llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor)) - return success(); - - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - // TODO: Atm we have 3 cases: - // 1. if a function is called from within the Module, it must have bufferized - // to inplaceable tensor results. - // 2. if it is bodiless, it must have bufferized and is not allowed to have - // result tensors. - // 3. if it is not called internally, it still must bufferize to inplaceable - // tensor results and we construct it now (e.g. top-level function called - // externally). - // -> Figure out a better layering. 
- TypeRange resultTypes; - - // Corner case: Bodiless FuncOp - // ============================ - // The body of such functions is assumed opaque and we can't know the - // bufferization contract they want to enforce atm. - // As a consequence, only support functions that don't return any tensor atm. - if (funcOp.getBody().empty()) { - if (llvm::any_of(funcOp.getFunctionType().getResults(), isaTensor)) - return funcOp->emitError() << "cannot bufferize bodiless function that " - << "returns a tensor"; - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getFunctionType().getInputs(), - funcOp.getFunctionType().getResults(), state.getOptions()); - funcOp.setType(bufferizedFuncType); - return success(); - } - - // Support only single return-terminated block in the function. - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - // 1. For each FuncOp result, keep track of which inplace argument it reuses. - SmallVector returnValues; - for (OpOperand &returnOperand : returnOp->getOpOperands()) { - Value returnVal = returnOperand.get(); - - // If not a renturn tensor type just forward it. - if (!returnVal.getType().isa()) { - returnValues.push_back(returnVal); - continue; - } - - // If return operand is equivalent to some bbArg, no need to return it. - auto funcOpIt = funcState.equivalentFuncArgs.find(funcOp); - if (funcOpIt != funcState.equivalentFuncArgs.end() && - funcOpIt->second.count(returnOperand.getOperandNumber())) - continue; - - // Cast values at the call site if necessary. - returnValues.push_back( - getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); - } - - // 2. Rewrite the terminator without the inPlace bufferizable values. - ValueRange retValues{returnValues}; - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getFunctionType().getInputs(), - retValues.getTypes(), state.getOptions()); - OpBuilder b(returnOp); - b.create(returnOp.getLoc(), returnValues); - returnOp->erase(); - - // 3. Rewrite the bbArgs. - // Iterate on the original `numArgs` and replace them in order. - // This guarantees the argument order still matches after the rewrite. - Block &frontBlock = funcOp.getBody().front(); - unsigned numArgs = frontBlock.getNumArguments(); - for (unsigned idx = 0; idx < numArgs; ++idx) { - auto bbArg = frontBlock.getArgument(0); - auto tensorType = bbArg.getType().dyn_cast(); - // Non-tensor types are just forwarded. - if (!tensorType) { - frontBlock.addArgument(bbArg.getType(), bbArg.getLoc()); - bbArg.replaceAllUsesWith(frontBlock.getArguments().back()); - frontBlock.eraseArgument(0); - continue; - } - - // Get the buffer type from the bufferized function type. - Type memrefType = bufferizedFuncType.getInput(idx); - Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc()); - OpBuilder b(funcOp->getContext()); - b.setInsertionPointToStart(&frontBlock); - // Replace all uses of bbArg through a ToMemRefOp. - for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) { - if (auto toMemrefOp = - dyn_cast(use.getOwner())) { - if (memref.getType() != toMemrefOp.memref().getType()) { - // Type has changed, insert a cast. 
- assert(memref::CastOp::areCastCompatible( - memref.getType(), toMemrefOp.memref().getType()) && - "bufferizeFuncOpBoundary: cast incompatible"); - auto castOp = b.create( - funcOp.getLoc(), toMemrefOp.memref().getType(), memref); - toMemrefOp.memref().replaceAllUsesWith(castOp); - } else { - // Type did not change, replace directly. - toMemrefOp.memref().replaceAllUsesWith(memref); - } - } - } - // Replace all remaining uses by a to_tensor. - if (!bbArg.use_empty()) { - auto toTensorOp = - b.create(funcOp.getLoc(), memref); - bbArg.replaceAllUsesWith(toTensorOp); - } - frontBlock.eraseArgument(0); - // TODO: add support to erase aliasInfo entries if deemed necessary. - } - - // 4. Rewrite the FuncOp type to buffer form. - funcOp.setType(bufferizedFuncType); - - return success(); -} - /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by /// callee-caller order (i.e. callees without callers first). /// Store the map of FuncOp to all its callers in `callerMap`. @@ -616,103 +472,6 @@ return success(); } -static void foreachCaller(const FuncCallerMap &callerMap, FuncOp callee, - llvm::function_ref doit) { - auto itCallers = callerMap.find(callee); - if (itCallers == callerMap.end()) - return; - for (Operation *caller : itCallers->second) - doit(caller); -} - -/// Postprocess the linalg.buffer_layout annotation across function boundaries. -/// This is a purely mechanical process that may later become part of a -/// separate pass with its own layout assignment heuristic. -static void layoutPostProcessing(ModuleOp moduleOp) { - SmallVector orderedFuncOps; - DenseMap> callerMap; - auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap); - (void)res; - assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure"); - - for (FuncOp funcOp : orderedFuncOps) { - DenseMap> operandsPerCaller; - foreachCaller(callerMap, funcOp, [&](Operation *caller) { - operandsPerCaller.try_emplace(caller, SmallVector()); - }); - - SmallVector argumentTypes; - // Iterate on each function argument and check it it was marked with a - // desired layout. - for (const auto &it : - llvm::enumerate(funcOp.getFunctionType().getInputs())) { - int argNumber = it.index(); - Type inputType = it.value(); - auto memrefType = inputType.dyn_cast(); - auto layoutAttr = funcOp.getArgAttrOfType( - argNumber, BufferizableOpInterface::kBufferLayoutAttrName); - AffineMap desiredLayoutMap = - layoutAttr ? layoutAttr.getValue() : AffineMap(); - AffineMap currentLayoutMap = - memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap(); - if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) { - argumentTypes.push_back(inputType); - foreachCaller(callerMap, funcOp, [&](Operation *caller) { - operandsPerCaller.find(caller)->getSecond().push_back( - caller->getOperand(argNumber)); - }); - continue; - } - - // Compute the buffer type with desired layout and add to input argument - // types. - MemRefType desiredMemrefType = MemRefType::get( - memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap); - argumentTypes.push_back(desiredMemrefType); - - // If funcOp's body is not empty, change the bbArg type and propagate. 
- if (!funcOp.getBody().empty()) { - BlockArgument bbArg = funcOp.getArgument(argNumber); - bbArg.setType(desiredMemrefType); - OpBuilder b(bbArg.getContext()); - b.setInsertionPointToStart(bbArg.getOwner()); - assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) && - "layoutPostProcessing: cast incompatible"); - // Cast back to the original memrefType and let it canonicalize. - Value cast = - b.create(funcOp.getLoc(), memrefType, bbArg); - bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp()); - } - - // Cast to desired buffer type on all callers to `funcOp`. - // TODO: on the callee side, this may even have to trigger a copy to - // change the layout. For now let the memref::CastOp fail to verify in - // such cases. - auto castArg = [&](Operation *caller) { - OpBuilder b(caller); - assert( - memref::CastOp::areCastCompatible( - caller->getOperand(argNumber).getType(), desiredMemrefType) && - "layoutPostProcessing.2: cast incompatible"); - Value newOperand = b.create( - funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber)); - operandsPerCaller.find(caller)->getSecond().push_back(newOperand); - }; - foreachCaller(callerMap, funcOp, castArg); - } - - // Set operands with cast buffer on all callers to `funcOp`. - foreachCaller(callerMap, funcOp, [&](Operation *caller) { - caller->setOperands(operandsPerCaller.lookup(caller)); - }); - - // Finally set the funcOp type to update the arguments. - auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes, - funcOp.getFunctionType().getResults()); - funcOp.setType(newFuncType); - } -} - namespace mlir { namespace linalg { namespace comprehensive_bufferize { @@ -826,9 +585,8 @@ return BufferRelation::Equivalent; } - /// In a first approximation, all the function arguments of a FuncOp are - /// marked inplaceable. For now, it is the responsibility of the `callOp` - /// bufferization to allow FuncOp that are inplaceable to write inPlace. + /// All function arguments are writable. It is the responsibility of the + /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { func::CallOp callOp = cast(op); @@ -871,7 +629,7 @@ for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { unsigned returnValIdx = it.index(); Type returnType = it.value(); - if (!isaTensor(returnType)) { + if (!returnType.isa()) { // Non-tensor values are returned. retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); @@ -904,11 +662,7 @@ } // 2. Compute bufferized FunctionType. - SmallVector argumentTypes{callOp->getOperandTypes()}; - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), argumentTypes, resultTypes, options); + FunctionType bufferizedFuncType = funcOp.getFunctionType(); // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. for (OpOperand &opOperand : callOp->getOpOperands()) { @@ -993,15 +747,138 @@ assert(isa(returnOp->getParentOp()) && "only support FuncOp parent for ReturnOp"); #endif // NDEBUG + + // ReturnOps are bufferized as part of FuncOps. return failure(); } }; struct FuncOpInterface : public BufferizableOpInterface::ExternalModel { + /// Rewrite function bbArgs and return values into buffer form (using the + /// canonical memref layout for now). This function bufferizes the function + /// signature and the ReturnOp. 
When the entire function body has been + /// bufferized, function return types can be switched to more concise memref + /// types as part of `foldMemRefCasts`. + /// + /// When a tensor function argument is known to be equivalent to a tensor + /// result, it is dropped from the return values. + /// + /// All function bbArgs are writable unless they are explicitly marked as + /// read-only. Callers must insert copies when needed. + /// + /// Note: Returning a memref is possible, but corresponding CallOp + /// bufferizations fail unless `allowReturnAllocs`. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { - return failure(); + auto funcOp = cast(op); + FunctionType funcType = funcOp.getFunctionType(); + const FuncAnalysisState &moduleState = + getFuncAnalysisState(state.getAnalysisState()); + const BufferizationOptions &options = state.getOptions(); + + // Construct the bufferized function type. + SmallVector argTypes; + for (const auto &it : llvm::enumerate(funcType.getInputs())) { + Type argType = it.value(); + if (auto tensorType = argType.dyn_cast()) { + argTypes.push_back( + getBufferizedFunctionArgType(funcOp, it.index(), options)); + continue; + } + argTypes.push_back(argType); + } + + // Bodiless functions are assumed opaque and we cannot know the + // bufferization contract they want to enforce. As a consequence, only + // support functions that don't return any tensors atm. + if (funcOp.getBody().empty()) { + SmallVector retTypes; + for (Type resultType : funcType.getResults()) { + if (resultType.isa()) + return funcOp->emitError() << "cannot bufferize bodiless function " + << "that returns a tensor"; + retTypes.push_back(resultType); + } + funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); + return success(); + } + + // TODO: Support functions with multiple returns. + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. + Block &frontBlock = funcOp.getBody().front(); + for (BlockArgument &bbArg : frontBlock.getArguments()) { + auto tensorType = bbArg.getType().dyn_cast(); + // Non-tensor types stay the same. + if (!tensorType) + continue; + + // Collect all uses of the bbArg. + SmallVector bbArgUses; + for (OpOperand &use : bbArg.getUses()) + bbArgUses.push_back(&use); + + // Change the bbArg type to memref. + Type memrefType = + getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + bbArg.setType(memrefType); + + // Replace all uses of the original tensor bbArg. + rewriter.setInsertionPointToStart(&frontBlock); + if (!bbArgUses.empty()) { + // Insert to_tensor because the remaining function body has not been + // bufferized yet. + Value toTensorOp = + rewriter.create(funcOp.getLoc(), bbArg); + for (OpOperand *use : bbArgUses) + use->set(toTensorOp); + } + } + + // 2. For each result, keep track of which inplace argument it reuses. + SmallVector returnValues; + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + + // If not a tensor type just forward it. + if (!returnVal.getType().isa()) { + returnValues.push_back(returnVal); + continue; + } + + // If return operand is equivalent to some bbArg, no need to return it. 
+ if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( + funcOp, moduleState, returnOperand.getOperandNumber())) { + rewriter.setInsertionPoint(returnOp); + Location loc = returnOp.getLoc(); + Value toMemrefOp = rewriter.create( + loc, getMemRefType(returnVal.getType().cast(), options), + returnVal); + BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); + // Note: This copy will fold away. It must be inserted here to ensure + // that `returnVal` still has at least one use and does not fold away. + if (failed( + createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + return funcOp->emitError("could not generate copy for bbArg"); + continue; + } + + // Cast values at the call site if necessary. + returnValues.push_back( + getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); + } + + // 3. Rewrite the terminator without the in-place bufferizable values. + returnOp.operandsMutable().assign(returnValues); + + // 4. Rewrite the FuncOp type to buffer form. + funcOp.setType(FunctionType::get(op->getContext(), argTypes, + ValueRange(returnValues).getTypes())); + + return success(); } /// Return `true` if the given function argument is writable. @@ -1057,6 +934,34 @@ setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); } +/// Fold return values that are memref casts and update function return types. +/// +/// During FuncOp bufferization, the exact type of the returned memrefs (if any) +/// is not known yet. Therefore, the bufferization uses memref types with the +/// most generic layout map as function return types. After bufferizing the +/// entire function body, a more concise memref type can potentially be used for +/// the return type of the function. +static void foldMemRefCasts(FuncOp funcOp) { + if (funcOp.getBody().empty()) + return; + + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + SmallVector resultTypes; + + for (OpOperand &operand : returnOp->getOpOperands()) { + if (auto castOp = operand.get().getDefiningOp()) { + operand.set(castOp.source()); + resultTypes.push_back(castOp.source().getType()); + } else { + resultTypes.push_back(operand.get().getType()); + } + } + + auto newFuncType = FunctionType::get( + funcOp.getContext(), funcOp.getFunctionType().getInputs(), resultTypes); + funcOp.setType(newFuncType); +} + LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize( ModuleOp moduleOp, OneShotBufferizationOptions options) { IRRewriter rewriter(moduleOp.getContext()); @@ -1107,15 +1012,11 @@ // Bufferize functions. for (FuncOp funcOp : orderedFuncOps) { - // No body => no analysis. - if (!funcOp.getBody().empty()) - if (failed(bufferizeOp(funcOp, bufferizationState))) - return failure(); - // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState))) + if (failed(bufferizeOp(funcOp, bufferizationState))) return failure(); + foldMemRefCasts(funcOp); } // Check result. @@ -1133,10 +1034,6 @@ if (failed(finalizeBuffers(moduleOp, options))) return failure(); - // Perform a post-processing pass of layout modification at function boundary - // according to the kBufferLayoutAttrName. - layoutPostProcessing(moduleOp); - // Post-pass cleanup of inplaceable and buffer_layout attributes. 
moduleOp.walk([&](FuncOp op) { for (BlockArgument bbArg : op.getArguments()) diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -11,6 +11,7 @@ // ----- +// expected-error @+2 {{op was not bufferized}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo() -> tensor @@ -212,6 +213,7 @@ // ----- +// expected-error @+2 {{op was not bufferized}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo(%t : tensor) -> (f32, tensor, f32) diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir @@ -468,15 +468,15 @@ // conflict. However, inside `entry`, the writes do cause a conflict because // %A, %B and %C are not inplaceable. This test case shows that this kind of // conflict detection has a "transitive" nature. -// CHECK: %[[ALLOC_C:.*]] = memref.alloc -// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]] -// CHECK: %[[ALLOC_B:.*]] = memref.alloc -// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] -// CHECK: %[[ALLOC_A:.*]] = memref.alloc -// CHECK: memref.copy %[[A]], %[[ALLOC_A]] -// CHECK: memref.copy %[[B]], %[[ALLOC_B]] -// CHECK: memref.copy %[[C]], %[[ALLOC_C]] -// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]] +// CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]] +// CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] +// CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]] +// CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]] +// CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]] +// CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]] // CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]]) call @callee(%A, %B, %C) : (tensor, tensor, tensor) -> () return
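
Illustrative example (not part of the patch or its test files; the function, element types, shapes, and layout maps below are hypothetical and depend on the chosen bufferization options): with the new FuncOpInterface-based path, a tensor bbArg is rewritten to a memref bbArg that uses the layout map given by a linalg.buffer_layout argument attribute if present and a fully dynamic layout otherwise, and a returned tensor that the analysis proves equivalent to a bbArg is dropped from the bufferized signature.

  // Before bufferization (sketch):
  func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(d0) -> (d0)>},
               %B : tensor<?xf32>, %pos : index, %val : f32) -> tensor<?xf32> {
    %0 = tensor.insert %val into %B[%pos] : tensor<?xf32>
    return %0 : tensor<?xf32>
  }

  // After One-Shot Module Bufferize (sketch): %A uses the identity layout
  // requested by the attribute, %B the fully dynamic layout, and the result is
  // dropped because it is equivalent to bbArg %B. The insert happens in place,
  // since function bbArgs are writable unless marked read-only.
  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  func @callee(%A : memref<?xf32>,
               %B : memref<?xf32, #map>, %pos : index, %val : f32) {
    memref.store %val, %B[%pos] : memref<?xf32, #map>
    return
  }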
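
A similar sketch (again hypothetical IR, assuming `allowReturnAllocs` so that returning a new buffer is accepted) of what foldMemRefCasts does once the function body has been bufferized: the FuncOp bufferization first gives returned buffers the most generic layout through a memref.cast, and foldMemRefCasts then reassigns the return operands to the cast sources and tightens the function result types; the now-dead casts are removed by later cleanups.

  // Before folding (sketch):
  #map = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
  func @returns_alloc(%sz : index) -> memref<?xf32, #map> {
    %alloc = memref.alloc(%sz) : memref<?xf32>
    %cast = memref.cast %alloc : memref<?xf32> to memref<?xf32, #map>
    return %cast : memref<?xf32, #map>
  }

  // After foldMemRefCasts (sketch): the return uses %alloc directly and the
  // result type matches it; the unused memref.cast is dropped later.
  func @returns_alloc(%sz : index) -> memref<?xf32> {
    %alloc = memref.alloc(%sz) : memref<?xf32>
    return %alloc : memref<?xf32>
  }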