diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -22,12 +22,27 @@
 namespace mlir {
 namespace emitc {
 
+/// Selects whether translation produces plain free functions (the historical
+/// behavior) or a stateful class that owns its input/result buffers.
+enum class EmitCppKind { Stateless, Stateful };
+
+/// Selects which half of a .h/.cpp pair is emitted for stateful translation.
+enum class CppFileKind { Header, Cpp };
+
 /// Translates the given operation to C++ code. The operation or operations in
 /// the region of 'op' need almost all be in EmitC dialect. The parameter
 /// 'declareVariablesAtTop' enforces that all variables for op results and block
 /// arguments are declared at the beginning of the function.
+/// For EmitCppKind::Stateful, 'argNameAttr' names the argument attribute that
+/// carries the user-visible input names, 'modelName' is the generated class
+/// name, and 'onlyOneFnName', when non-empty, restricts emission to that one
+/// function.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
-                             bool declareVariablesAtTop = false);
+                             bool declareVariablesAtTop = false,
+                             EmitCppKind emitCppKind = EmitCppKind::Stateless,
+                             CppFileKind cppFileKind = CppFileKind::Header,
+                             std::string argNameAttr = "",
+                             std::string modelName = "",
+                             std::string onlyOneFnName = "");
 
 } // namespace emitc
 } // namespace mlir
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -32,12 +32,49 @@
       llvm::cl::desc("Declare variables at top when emitting C/C++"),
       llvm::cl::init(false));
 
+  static llvm::cl::opt<emitc::EmitCppKind> emitCppKind(
+      "emit-cpp-kind",
+      llvm::cl::desc("Emit stateless functions or stateful classes"),
+      llvm::cl::init(emitc::EmitCppKind::Stateless),
+      llvm::cl::values(
+          clEnumValN(emitc::EmitCppKind::Stateless, "stateless",
+                     "Emit a stateless function."),
+          clEnumValN(emitc::EmitCppKind::Stateful, "stateful",
+                     "Emit a 'stateful function' in the form of a class")));
+
+  static llvm::cl::opt<std::string> argNameAttr(
+      "emit-cpp-arg-name-attr",
+      llvm::cl::desc("(Stateful only) Attribute which holds the argument names "
+                     "in the MLIR block"));
+
+  static llvm::cl::opt<std::string> modelName(
+      "emit-cpp-model-name",
+      llvm::cl::desc("(Stateful only) Name of the model. Will be exposed in a "
+                     "name() method of the class."));
+
+  static llvm::cl::opt<std::string> onlyOneFnName(
+      "emit-cpp-only-one-fn",
+      llvm::cl::desc(
+          "Only translate one function in the MLIR module to C++. This "
+          "argument is the name of that function. If empty, translate all."));
+
+  static llvm::cl::opt<emitc::CppFileKind> cppFileKind(
+      "emit-cpp-file-kind",
+      llvm::cl::desc("If emitting stateful functions, should this call emit "
+                     "the .h or .cpp file."),
+      llvm::cl::init(emitc::CppFileKind::Header),
+      llvm::cl::values(clEnumValN(emitc::CppFileKind::Header, "header",
+                                  "Emit the .h file in a .h/.cpp pair"),
+                       clEnumValN(emitc::CppFileKind::Cpp, "cpp",
+                                  "Emit the .cpp file in a .h/.cpp pair")));
+
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
         return emitc::translateToCpp(
             op, output,
-            /*declareVariablesAtTop=*/declareVariablesAtTop);
+            /*declareVariablesAtTop=*/declareVariablesAtTop, emitCppKind,
+            cppFileKind, argNameAttr, modelName, onlyOneFnName);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -69,7 +69,10 @@
 namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
-  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
+  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
+                      emitc::EmitCppKind emitCppKind,
+                      emitc::CppFileKind cppFileKind, std::string argNameAttr,
+                      std::string modelName, std::string onlyOneFnName);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -77,6 +80,9 @@
-  /// Emits operation 'op' with/without training semicolon or returns failure.
+  /// Emits operation 'op' with/without trailing semicolon or returns failure.
   LogicalResult emitOperation(Operation &op, bool trailingSemicolon);
 
+  /// Emits the type of the "underlying buffer" pointer, or returns failure.
+  LogicalResult emitBufferPointerType(Location loc, Type type);
+
   /// Emits type 'type' or returns failure.
   LogicalResult emitType(Location loc, Type type);
 
@@ -84,11 +90,13 @@
   /// - emits void for an empty array;
   /// - emits the type of the only element for arrays of size one;
   /// - emits a std::tuple otherwise;
-  LogicalResult emitTypes(Location loc, ArrayRef<Type> types);
+  LogicalResult emitTypes(Location loc, ArrayRef<Type> types,
+                          bool useBufferPointerType = false);
 
   /// Emits array of types as a std::tuple of the emitted types independently of
   /// the array size.
-  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types);
+  LogicalResult emitTupleType(Location loc, ArrayRef<Type> types,
+                              bool useBufferPointerType = false);
 
   /// Emits an assignment for a variable which has been declared previously.
   LogicalResult emitVariableAssignment(OpResult result);
@@ -97,6 +105,10 @@
   LogicalResult emitVariableDeclaration(OpResult result,
                                         bool trailingSemicolon);
 
+  /// Emits a setter for a member variable corresponding to the variable in
+  /// 'result'.
+  LogicalResult emitVariableSetter(OpResult result, bool trailingSemicolon);
+
   /// Emits the variable declaration and assignment prefix for 'op'.
   /// - emits separate variable followed by std::tie for multi-valued operation;
   /// - emits single type followed by variable for single result;
@@ -157,6 +169,17 @@
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
+  emitc::EmitCppKind getEmitCppKind() const { return emitCppKind; }
+
+  emitc::CppFileKind getCppFileKind() const { return cppFileKind; }
+
+  std::string getArgNameAttr() const { return argNameAttr; }
+
+  std::string getModelName() const { return modelName; }
+
+  /// Name of the single function to translate; empty means "translate all".
+  std::string shouldOnlyPrintOneFn() const { return onlyOneFnName; }
+
 private:
   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
@@ -169,6 +192,21 @@
   /// includes results from ops located in nested regions.
   bool declareVariablesAtTop;
 
+  /// Which style of C++ to emit: stateless free functions or a stateful class.
+  emitc::EmitCppKind emitCppKind;
+
+  /// For stateful emission, which half of the .h/.cpp pair to print.
+  emitc::CppFileKind cppFileKind;
+
+  /// Argument attribute holding the user-visible input names.
+  std::string argNameAttr;
+
+  /// Class name used for the generated stateful class.
+  std::string modelName;
+
+  /// If non-empty, restricts translation to the named function.
+  std::string onlyOneFnName;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -592,33 +630,12 @@
   return success();
 }
 
-static LogicalResult printOperation(CppEmitter &emitter,
-                                    func::FuncOp functionOp) {
-  // We need to declare variables at top if the function has multiple blocks.
-  if (!emitter.shouldDeclareVariablesAtTop() &&
-      functionOp.getBlocks().size() > 1) {
-    return functionOp.emitOpError(
-        "with multiple blocks needs variables declared at top");
-  }
-
-  CppEmitter::Scope scope(emitter);
+/// Prints the body of a FuncOp (the statements between the braces), indenting
+/// on entry and unindenting on exit. Callers print the signature and the
+/// surrounding braces themselves.
+static LogicalResult printFuncOpBody(CppEmitter &emitter,
+                                     func::FuncOp functionOp) {
   raw_indented_ostream &os = emitter.ostream();
-  if (failed(emitter.emitTypes(functionOp.getLoc(),
-                               functionOp.getFunctionType().getResults())))
-    return failure();
-  os << " " << functionOp.getName();
-
-  os << "(";
-  if (failed(interleaveCommaWithError(
-          functionOp.getArguments(), os,
-          [&](BlockArgument arg) -> LogicalResult {
-            if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
-              return failure();
-            os << " " << emitter.getOrCreateName(arg);
-            return success();
-          })))
-    return failure();
-  os << ") {\n";
   os.indent();
   if (emitter.shouldDeclareVariablesAtTop()) {
     // Declare all variables that hold op results including those from nested
@@ -677,12 +694,285 @@
       return failure();
     }
   }
-  os.unindent() << "}\n";
+  os.unindent();
   return success();
 }
 
-CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
-    : os(os), declareVariablesAtTop(declareVariablesAtTop) {
+/// The goal of this method is to print something that looks like the following
+/// pseudocode:
+///
+///   class _MyFnImpl;
+///   class MyFn {
+///   private:
+///     std::unique_ptr<_MyFnImpl> impl;
+///   public:
+///     MyFn();
+///     ~MyFn();
+///     void* get_input_buffer(std::string_view name);
+///     float* run();
+///     static std::string_view name() { return "MyFn"; }
+///   };
+///
+/// This declaration follows the pimpl design pattern. We use pimpl because
+/// internally these methods communicate via a Tensor type, but we don't want
+/// this to be a part of the public interface for various project-specific
+/// reasons.
+///
+/// Users of the generated code are intended to set inputs by acquiring a
+/// non-owning, mutable pointer to the various input buffers, and setting the
+/// raw tensor data directly. The non-owning pointer returned by run() refers
+/// to the result of the computation using the aforementioned input buffers,
+/// and the result pointer is valid until the next time run() is called.
+///
+/// Note: by design, this class is intended to be used in a single-threaded
+/// context.
+static LogicalResult printStatefulFnDecl(CppEmitter &emitter,
+                                         func::FuncOp functionOp) {
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+
+  // Name of the impl class.
+  std::string pimpl = "_" + emitter.getModelName() + "Impl";
+
+  // Forward-declare the impl class.
+  os << "class " << pimpl << ";\n";
+
+  // Declare the main class.
+  os << "class " << emitter.getModelName();
+  os << " {\nprivate:\n  std::unique_ptr<" << pimpl << "> impl;\n";
+
+  os << "public:\n";
+  // Declare the constructor and destructor of the model. Note that we need to
+  // explicitly declare the constructor because we must allocate the memory
+  // for/construct the impl member, and we need the destructor for formal build
+  // reasons around the pimpl design pattern.
+  os << "  " << emitter.getModelName() << "();\n";
+  os << "  ~" << emitter.getModelName() << "();\n";
+
+  // Declare the get_input_buffer method.
+  os << "  void* get_input_buffer(std::string_view name);\n";
+
+  // Declare the static name() method, defined inline.
+  os << "  static std::string_view name()";
+  os << " { return \"" << emitter.getModelName() << "\"; }\n";
+
+  // Declare the run() method, returning a raw pointer into the result buffer.
+  os << "  ";
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults(),
+                               /*useBufferPointerType=*/true)))
+    return failure();
+  os << "run();\n};\n";
+  return success();
+}
+
+/// The goal of this method is to print something that looks like the following
+/// pseudocode:
+///
+///   class _MyFnImpl {
+///   private:
+///     Tensor result;
+///     Tensor v0;
+///     Tensor v1;
+///     ...
+///   public:
+///     void* get_input_buffer(std::string_view name) {
+///       if (name == "input_0") {
+///         return static_cast<void*>(v0.get());
+///       }
+///       if (name == "input_1") {
+///         return static_cast<void*>(v1.get());
+///       }
+///       assert(false && "Unknown input name!");
+///       return nullptr;
+///     }
+///     float* run() { result = runImpl(); return result.get(); }
+///     Tensor runImpl() {
+///       // insert the primary definition of the function
+///     }
+///   };
+///   void* MyFn::get_input_buffer(std::string_view name) {
+///     return impl->get_input_buffer(name);
+///   }
+///   float* MyFn::run() { return impl->run(); }
+///
+/// This is an implementation of the pimpl design pattern. In the _MyFnImpl
+/// class, we need both run and runImpl because we need objects of type
+/// _MyFnImpl to own a persistent buffer containing the result of the function,
+/// so that the external api can return a non-owning pointer to the result. As
+/// with the declaration code, these choices are motivated by the needs of the
+/// github.com/google/ml-compiler-opt project.
+static LogicalResult printStatefulFnDefn(CppEmitter &emitter,
+                                         func::FuncOp functionOp) {
+  // printFuncOpBody below has the same restriction as the stateless path.
+  if (!emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1) {
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+  }
+
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+
+  auto args = functionOp.getArguments();
+  auto argAttrs = functionOp.getArgAttrsAttr();
+  // The name lookup below requires per-argument attributes; diagnose their
+  // absence instead of crashing.
+  if (!argAttrs && functionOp.getNumArguments() != 0)
+    return functionOp.emitOpError("expected argument attributes carrying '")
+           << emitter.getArgNameAttr() << "' input names";
+
+  // Name of the impl class.
+  std::string pimpl = "_" + emitter.getModelName() + "Impl";
+
+  // Define the impl class.
+  os << "class " << pimpl << " {\n";
+  os << "private:\n";
+
+  // Declare the persistent result member.
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " result;\n";
+
+  // Declare one tensor member per input.
+  for (unsigned i = 0, e = functionOp.getNumArguments(); i < e; ++i) {
+    if (failed(emitter.emitType(functionOp.getLoc(), args[i].getType())))
+      return failure();
+    os << " " << emitter.getOrCreateName(args[i]) << ";\n";
+  }
+
+  os << "public:\n";
+
+  // Define the get_input_buffer method.
+  os << "void* get_input_buffer(std::string_view name) {\n";
+  for (unsigned i = 0, e = functionOp.getNumArguments(); i < e; ++i) {
+    auto attr = argAttrs[i].cast<DictionaryAttr>();
+    auto nameAttr = attr.get(emitter.getArgNameAttr());
+    if (!nameAttr)
+      return functionOp.emitOpError("argument ")
+             << i << " is missing the '" << emitter.getArgNameAttr()
+             << "' attribute";
+    auto name = nameAttr.cast<ArrayAttr>()[0];
+    auto varname = emitter.getOrCreateName(args[i]);
+    os << "  if (name == " << name << ") { return static_cast<void*>("
+       << varname << ".get()); }\n";
+  }
+  os << "  assert(false && \"Unknown input name!\");\n";
+  os << "  return nullptr;\n";
+  os << "}\n";
+
+  // Define the run() method: refresh the cached result and hand out a
+  // non-owning pointer into it.
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults(),
+                               /*useBufferPointerType=*/true)))
+    return failure();
+  os << "run() {\n";
+  os << "  result = runImpl();\n";
+  os << "  return result.get();\n";
+  os << "}\n";
+
+  // Define the runImpl() method, which holds the translated function body.
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults(),
+                               /*useBufferPointerType=*/false)))
+    return failure();
+  os << "runImpl() {\n";
+  if (failed(printFuncOpBody(emitter, functionOp)))
+    return failure();
+  os << "}\n";
+
+  os << "};\n";
+
+  // Define the model constructor.
+  os << emitter.getModelName() << "::" << emitter.getModelName()
+     << "() : impl{std::make_unique<" << pimpl << ">()} {}\n";
+  // Define the model destructor.
+  os << emitter.getModelName() << "::~" << emitter.getModelName() << "() {}\n";
+
+  // Define the forwarding get_input_buffer method for the model.
+  os << "void* " << emitter.getModelName()
+     << "::get_input_buffer(std::string_view name) { return "
+        "impl->get_input_buffer(name); }\n";
+
+  // Define the forwarding run() method for the model.
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults(),
+                               /*useBufferPointerType=*/true)))
+    return failure();
+  os << emitter.getModelName() << "::run() { return impl->run(); }\n";
+
+  return success();
+}
+
+static LogicalResult printStatefulFn(CppEmitter &emitter,
+                                     func::FuncOp functionOp) {
+  // Early-exit if this is not the (single) requested function; an empty name
+  // means "translate every function".
+  std::string printFnName = emitter.shouldOnlyPrintOneFn();
+  if (!printFnName.empty() && functionOp.getName() != printFnName)
+    return success();
+
+  switch (emitter.getCppFileKind()) {
+  case emitc::CppFileKind::Header:
+    return printStatefulFnDecl(emitter, functionOp);
+  case emitc::CppFileKind::Cpp:
+    return printStatefulFnDefn(emitter, functionOp);
+  }
+  llvm_unreachable("unhandled CppFileKind");
+}
+
+static LogicalResult printOperation(CppEmitter &emitter,
+                                    func::FuncOp functionOp) {
+  if (emitter.getEmitCppKind() == emitc::EmitCppKind::Stateful)
+    return printStatefulFn(emitter, functionOp);
+
+  // We need to declare variables at top if the function has multiple blocks.
+  if (!emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1) {
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+  }
+
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+  if (failed(emitter.emitTypes(functionOp.getLoc(),
+                               functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName();
+
+  os << "(";
+  if (failed(interleaveCommaWithError(
+          functionOp.getArguments(), os,
+          [&](BlockArgument arg) -> LogicalResult {
+            if (failed(emitter.emitType(functionOp.getLoc(), arg.getType())))
+              return failure();
+            os << " " << emitter.getOrCreateName(arg);
+            return success();
+          })))
+    return failure();
+  os << ") {\n";
+
+  if (failed(printFuncOpBody(emitter, functionOp)))
+    return failure();
+  os << "}\n";
+  return success();
+}
+
+CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
+                       emitc::EmitCppKind emitCppKind,
+                       emitc::CppFileKind cppFileKind, std::string argNameAttr,
+                       std::string modelName, std::string onlyOneFnName)
+    : os(os), declareVariablesAtTop(declareVariablesAtTop),
+      emitCppKind(emitCppKind), cppFileKind(cppFileKind),
+      argNameAttr(std::move(argNameAttr)), modelName(std::move(modelName)),
+      onlyOneFnName(std::move(onlyOneFnName)) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -885,6 +1179,27 @@
   return success();
 }
 
+LogicalResult CppEmitter::emitVariableSetter(OpResult result,
+                                             bool trailingSemicolon) {
+  // Fall back to a sentinel if the producing op does not carry the name
+  // attribute; the generated code will not compile, surfacing the problem.
+  std::string name = "EMPTY_NAME";
+  if (result.getOwner()->getAttrDictionary().contains(argNameAttr)) {
+    name = result.getOwner()
+               ->getAttr(argNameAttr)
+               .cast<ArrayAttr>()[0]
+               .cast<StringAttr>()
+               .str();
+  }
+
+  os << "void set_" << name << "(";
+  if (failed(emitType(result.getOwner()->getLoc(), result.getType())))
+    return failure();
+  os << " x) {\n" << name << " = x;\n}\n";
+
+  return success();
+}
+
 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
   switch (op.getNumResults()) {
   case 0:
@@ -956,6 +1271,18 @@
   return success();
 }
 
+/// Emits the element-pointer type (e.g. "float*") for a tensor-typed value;
+/// used by the buffer-returning methods of the generated stateful class.
+LogicalResult CppEmitter::emitBufferPointerType(Location loc, Type type) {
+  auto tType = type.dyn_cast<TensorType>();
+  if (!tType)
+    return emitError(loc, "cannot emit a buffer pointer type for non-tensors");
+  if (failed(emitType(loc, tType.getElementType())))
+    return failure();
+  os << "* ";
+  return success();
+}
+
 LogicalResult CppEmitter::emitType(Location loc, Type type) {
   if (auto iType = type.dyn_cast<IntegerType>()) {
     switch (iType.getWidth()) {
@@ -1016,29 +1343,39 @@
   return emitError(loc, "cannot emit type ") << type;
 }
 
-LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types) {
+LogicalResult CppEmitter::emitTypes(Location loc, ArrayRef<Type> types,
+                                    bool useBufferPointerType) {
   switch (types.size()) {
   case 0:
     os << "void";
    return success();
   case 1:
-    return emitType(loc, types.front());
+    if (useBufferPointerType)
+      return emitBufferPointerType(loc, types.front());
+    return emitType(loc, types.front());
   default:
-    return emitTupleType(loc, types);
+    return emitTupleType(loc, types, useBufferPointerType);
  }
 }
 
-LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
+LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types,
+                                        bool useBufferPointerType) {
   os << "std::tuple<";
-  if (failed(interleaveCommaWithError(
-          types, os, [&](Type type) { return emitType(loc, type); })))
+  if (failed(interleaveCommaWithError(types, os, [&](Type type) {
+        if (useBufferPointerType)
+          return emitBufferPointerType(loc, type);
+        return emitType(loc, type);
+      })))
     return failure();
   os << ">";
   return success();
 }
 
-LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
-                                    bool declareVariablesAtTop) {
-  CppEmitter emitter(os, declareVariablesAtTop);
+LogicalResult emitc::translateToCpp(
+    Operation *op, raw_ostream &os, bool declareVariablesAtTop,
+    emitc::EmitCppKind emitCppKind, emitc::CppFileKind cppFileKind,
+    std::string argNameAttr, std::string modelName, std::string onlyOneFnName) {
+  CppEmitter emitter(os, declareVariablesAtTop, emitCppKind, cppFileKind,
+                     argNameAttr, modelName, onlyOneFnName);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }
diff --git a/mlir/test/Target/Cpp/emit_stateful_fn.mlir b/mlir/test/Target/Cpp/emit_stateful_fn.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Target/Cpp/emit_stateful_fn.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-translate -mlir-to-cpp --emit-cpp-kind=stateful --emit-cpp-arg-name-attr=tf_saved_model.index_path --emit-cpp-model-name=test --emit-cpp-file-kind=header %s | FileCheck %s -check-prefix=HEADER
+// RUN: mlir-translate -mlir-to-cpp --emit-cpp-kind=stateful --emit-cpp-arg-name-attr=tf_saved_model.index_path --emit-cpp-model-name=test --emit-cpp-file-kind=cpp %s | FileCheck %s -check-prefix=CPP
+
+func.func @test(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["a"]},
+                %arg1: tensor<1xf32> {tf_saved_model.index_path = ["b"]})
+    -> (tensor<1xf32>) {
+  %0 = emitc.call "tosa::add"(%arg0, %arg1) {args = [0 : index, 1 : index]} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  return %0 : tensor<1xf32>
+}
+
+// HEADER: class _testImpl;
+// HEADER: class test {
+// HEADER: private:
+// HEADER: std::unique_ptr<_testImpl> impl;
+// HEADER: public:
+// HEADER: void* get_input_buffer(std::string_view name);
+// HEADER: float* run();
+// HEADER: };
+
+// CPP: class _testImpl {
+// CPP: private:
+// CPP: Tensor<float, 1> result;
+// CPP: Tensor<float, 1> v1;
+// CPP: Tensor<float, 1> v2;
+// CPP: public:
+// CPP: void* get_input_buffer(std::string_view name) {
+// CPP: if (name == "a") { return static_cast<void*>(v1.get()); }
+// CPP: if (name == "b") { return static_cast<void*>(v2.get()); }
+// CPP: assert(false && "Unknown input name!");
+// CPP: return nullptr;
+// CPP: }
+// CPP: float* run() {
+// CPP: result = runImpl();
+// CPP: return result.get();
+// CPP: }
+// CPP: Tensor<float, 1> runImpl() {
+// CPP: Tensor<float, 1> v3 = tosa::add(v1, v2);
+// CPP: return v3;
+// CPP: }
+// CPP: };
+// CPP: test::test() : impl{std::make_unique<_testImpl>()} {}
+// CPP: test::~test() {}
+// CPP: void* test::get_input_buffer(std::string_view name) { return impl->get_input_buffer(name); }
+// CPP: float* test::run() { return impl->run(); }