diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -237,7 +237,7 @@
 }
 
 /// Return the func::FuncOp called by `callOp`.
-static func::FuncOp getCalledFunction(CallOpInterface callOp) {
+static func::FuncOp getCalledFunction(func::CallOp callOp) {
   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
   if (!sym)
     return nullptr;
@@ -278,15 +278,15 @@
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
 /// Return `failure()` if a cycle of calls is detected or if we are unable to
-/// retrieve the called FuncOp from any CallOpInterface.
+/// retrieve the called FuncOp from any func::CallOp.
 static LogicalResult
 getFuncOpsOrderedByCalls(ModuleOp moduleOp,
                          SmallVectorImpl<func::FuncOp> &orderedFuncOps,
                          FuncCallerMap &callerMap) {
   // For each FuncOp, the set of functions called by it (i.e. the union of
-  // symbols of all nested CallOpInterfaceOp).
+  // symbols of all nested func::CallOp).
   DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
-  // For each FuncOp, the number of CallOpInterface it contains.
+  // For each FuncOp, the number of func::CallOp it contains.
   DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
   WalkResult res = moduleOp.walk([&](func::FuncOp funcOp) -> WalkResult {
     if (!funcOp.getBody().empty()) {
@@ -298,10 +298,7 @@
     }
 
     numberCallOpsContainedInFuncOp[funcOp] = 0;
-    return funcOp.walk([&](CallOpInterface callOp) -> WalkResult {
-      // Only support CallOp for now.
-      if (!isa<func::CallOp>(callOp.getOperation()))
-        return callOp->emitError() << "expected a CallOp";
+    return funcOp.walk([&](func::CallOp callOp) -> WalkResult {
       func::FuncOp calledFunction = getCalledFunction(callOp);
       assert(calledFunction && "could not retrieved called func::FuncOp");
       callerMap[calledFunction].insert(callOp);
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -1,16 +1,5 @@
 // RUN: mlir-opt %s -allow-unregistered-dialect -one-shot-bufferize="bufferize-function-boundaries=1" -split-input-file -verify-diagnostics
 
-func.func private @foo() -> tensor<?xf32>
-
-func.func @bar() -> tensor<?xf32> {
-  %foo = constant @foo : () -> (tensor<?xf32>)
-// expected-error @+1 {{expected a CallOp}}
-  %res = call_indirect %foo() : () -> (tensor<?xf32>)
-  return %res : tensor<?xf32>
-}
-
-// -----
-
 // expected-error @+2 {{cannot bufferize bodiless function that returns a tensor}}
 // expected-error @+1 {{failed to bufferize op}}
 func.func private @foo() -> tensor<?xf32>
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -625,3 +625,14 @@
 
 // This function may write to buffer(%ptr).
 func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
+
+// -----
+
+// Test if other callables are left intact and don't cause trouble.
+
+llvm.func @llvm_func()
+
+func.func @call_llvm_func() {
+  llvm.call @llvm_func() : () -> ()
+  return
+}