diff --git a/mlir/include/mlir/Target/Cpp/CppEmitter.h b/mlir/include/mlir/Target/Cpp/CppEmitter.h
--- a/mlir/include/mlir/Target/Cpp/CppEmitter.h
+++ b/mlir/include/mlir/Target/Cpp/CppEmitter.h
@@ -22,12 +22,23 @@
 namespace mlir {
 namespace emitc {
 
+/// Selects which file of a stateful function translateToCpp emits: the header
+/// with the class declaration or the .cpp file with the definitions.
+enum class CppFileKind { Header, Cpp };
+
 /// Translates the given operation to C++ code. The operation or operations in
 /// the region of 'op' need almost all be in EmitC dialect. The parameter
 /// 'declareVariablesAtTop' enforces that all variables for op results and block
 /// arguments are declared at the beginning of the function.
+/// If 'emitStatefulFns' is set, func.func ops are emitted as classes with one
+/// setter per function argument. 'argNameAttr' is the name of the argument
+/// attribute that holds the setter names and 'cppFileKind' selects whether the
+/// class declaration or its definitions are emitted.
 LogicalResult translateToCpp(Operation *op, raw_ostream &os,
-                             bool declareVariablesAtTop = false);
+                             bool declareVariablesAtTop = false,
+                             bool emitStatefulFns = false,
+                             std::string argNameAttr = "",
+                             CppFileKind cppFileKind = CppFileKind::Header);
 } // namespace emitc
 } // namespace mlir
 
diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
--- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp
+++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp
@@ -32,12 +32,34 @@
       llvm::cl::desc("Declare variables at top when emitting C/C++"),
       llvm::cl::init(false));
 
+  static llvm::cl::opt<bool> emitStatefulFn(
+      "emit-stateful-fn",
+      llvm::cl::desc("Emit stateful versions of the MLIR functions"),
+      llvm::cl::init(false));
+
+  static llvm::cl::opt<std::string> argNameAttr(
+      "arg-name-attr",
+      llvm::cl::desc(
+          "Attribute which holds the argument names in the MLIR block"));
+
+  static llvm::cl::opt<emitc::CppFileKind> cppFileKind(
+      "cpp-file-kind",
+      llvm::cl::desc("If emitting stateful functions, should this call emit "
+                     "the .h or .cpp file."),
+      llvm::cl::values(
+          clEnumValN(emitc::CppFileKind::Header, "header", "Emit the .h file"),
+          clEnumValN(emitc::CppFileKind::Cpp, "cpp", "Emit the .cpp file")),
+      llvm::cl::init(emitc::CppFileKind::Header));
+
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
       [](Operation *op, raw_ostream &output) {
         return emitc::translateToCpp(
             op, output,
-            /*declareVariablesAtTop=*/declareVariablesAtTop);
+            /*declareVariablesAtTop=*/declareVariablesAtTop,
+            /*emitStatefulFns=*/emitStatefulFn,
+            /*argNameAttr=*/argNameAttr,
+            /*cppFileKind=*/cppFileKind);
       },
       [](DialectRegistry &registry) {
         // clang-format off
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -69,7 +69,9 @@
 namespace {
 /// Emitter that uses dialect specific emitters to emit C++ code.
 struct CppEmitter {
-  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop);
+  explicit CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
+                      bool emitStatefulFns, std::string argNameAttr,
+                      CppFileKind cppFileKind);
 
   /// Emits attribute or returns failure.
   LogicalResult emitAttribute(Location loc, Attribute attr);
@@ -97,6 +99,12 @@
   LogicalResult emitVariableDeclaration(OpResult result,
                                         bool trailingSemicolon);
 
+  /// Emits a setter of the form `void set_<name>(<type> x) { <name> = x; }` for
+  /// 'result' or returns failure. <name> is taken from the attribute of the
+  /// defining op that is named by getArgNameAttr(). 'trailingSemicolon' is
+  /// currently unused.
+  LogicalResult emitVariableSetter(OpResult result, bool trailingSemicolon);
+
   /// Emits the variable declaration and assignment prefix for 'op'.
   /// - emits separate variable followed by std::tie for multi-valued operation;
   /// - emits single type followed by variable for single result;
@@ -157,6 +165,15 @@
   /// be declared at the beginning of a function.
   bool shouldDeclareVariablesAtTop() { return declareVariablesAtTop; };
 
+  /// Returns whether func.func ops are emitted as stateful classes.
+  bool shouldEmitStatefulFns() const { return emitStatefulFns; }
+
+  /// Returns the name of the attribute that holds the argument names.
+  const std::string &getArgNameAttr() const { return argNameAttr; }
+
+  /// Returns whether the .h or the .cpp file of stateful functions is emitted.
+  CppFileKind getCppFileKind() const { return cppFileKind; }
+
 private:
   using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
   using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
@@ -169,6 +186,15 @@
   /// includes results from ops located in nested regions.
   bool declareVariablesAtTop;
 
+  /// Whether func.func ops are emitted as stateful classes.
+  bool emitStatefulFns;
+
+  /// Name of the attribute that holds the argument names.
+  std::string argNameAttr;
+
+  /// Whether the .h or the .cpp file of stateful functions is emitted.
+  CppFileKind cppFileKind;
+
   /// Map from value to name of C++ variable that contain the name.
   ValueMapper valueMapper;
 
@@ -592,8 +618,133 @@
   return success();
 }
 
+/// Returns the name of argument 'argNo' of 'functionOp', i.e. the first element
+/// of the array attribute 'argNameAttr' attached to that argument. Emits an
+/// error and returns failure if the argument carries no such string.
+static FailureOr<StringAttr> getStatefulArgName(func::FuncOp functionOp,
+                                                unsigned argNo,
+                                                StringRef argNameAttr) {
+  if (auto names = functionOp.getArgAttrOfType<ArrayAttr>(argNo, argNameAttr))
+    if (!names.empty())
+      if (auto name = names[0].dyn_cast<StringAttr>())
+        return name;
+
+  functionOp.emitOpError("argument #")
+      << argNo << " requires attribute '" << argNameAttr
+      << "' to be an array whose first element is a string";
+  return failure();
+}
+
+/// Emits 'functionOp' as a C++ class named after the function. Every function
+/// argument becomes a private member with a setter whose name is read from the
+/// argument attribute selected by getArgNameAttr(), and the function body
+/// becomes a function without parameters. Depending on getCppFileKind(), either
+/// the class declaration or the out-of-line definitions are emitted.
+// NOTE(review): The entry point is emitted under the name of the class and is
+// not class-qualified in the .cpp file. C++ treats a member named like its
+// class as a constructor, so this naming scheme should be revisited.
+// NOTE(review): Op result variables are not declared at the top of the body
+// here, even if declareVariablesAtTop is set - confirm that this is intended.
+static LogicalResult printStatefulFn(CppEmitter &emitter,
+                                     func::FuncOp functionOp) {
+  CppEmitter::Scope scope(emitter);
+  raw_indented_ostream &os = emitter.ostream();
+  Location loc = functionOp.getLoc();
+  const bool isHeader = emitter.getCppFileKind() == CppFileKind::Header;
+
+  // Mirror the check of the regular func.func emission: a function with
+  // multiple blocks needs its variables declared at the top.
+  if (!isHeader && !emitter.shouldDeclareVariablesAtTop() &&
+      functionOp.getBlocks().size() > 1)
+    return functionOp.emitOpError(
+        "with multiple blocks needs variables declared at top");
+
+  if (isHeader) {
+    // Every function argument is stored in a private member.
+    os << "class " << functionOp.getName() << " {\n private:\n";
+    for (BlockArgument arg : functionOp.getArguments()) {
+      if (failed(emitter.emitType(loc, arg.getType())))
+        return failure();
+      os << " " << emitter.getOrCreateName(arg) << ";\n";
+    }
+    os << "public:\n";
+  }
+
+  // One setter per argument: declared in the header, defined in the .cpp file.
+  for (BlockArgument arg : functionOp.getArguments()) {
+    FailureOr<StringAttr> name = getStatefulArgName(
+        functionOp, arg.getArgNumber(), emitter.getArgNameAttr());
+    if (failed(name))
+      return failure();
+
+    os << "void ";
+    if (!isHeader)
+      os << functionOp.getName() << "::";
+    os << "set_" << name->getValue() << "(";
+    if (failed(emitter.emitType(loc, arg.getType())))
+      return failure();
+    if (isHeader)
+      os << " x);\n";
+    else
+      os << " x) { " << emitter.getOrCreateName(arg) << " = std::move(x); }\n";
+  }
+
+  // Entry point: returns the function results and takes no parameters.
+  if (failed(emitter.emitTypes(loc, functionOp.getFunctionType().getResults())))
+    return failure();
+  os << " " << functionOp.getName() << "()";
+
+  if (isHeader) {
+    os << ";\n};\n";
+    return success();
+  }
+
+  os << "{\n";
+  os.indent();
+  Region::BlockListType &blocks = functionOp.getBlocks();
+  // Create label names for basic blocks.
+  for (Block &block : blocks)
+    emitter.getOrCreateName(block);
+
+  // Declare variables for basic block arguments.
+  for (Block &block : llvm::drop_begin(blocks)) {
+    for (BlockArgument &arg : block.getArguments()) {
+      if (emitter.hasValueInScope(arg))
+        return functionOp.emitOpError(" block argument #")
+               << arg.getArgNumber() << " is out of scope";
+      if (failed(
+              emitter.emitType(block.getParentOp()->getLoc(), arg.getType())))
+        return failure();
+      os << " " << emitter.getOrCreateName(arg) << ";\n";
+    }
+  }
+
+  for (Block &block : blocks) {
+    // Only print a label if the block has predecessors.
+    if (!block.hasNoPredecessors() && failed(emitter.emitLabel(block)))
+      return failure();
+
+    for (Operation &op : block.getOperations()) {
+      // When generating code for an scf.if or cf.cond_br op no semicolon needs
+      // to be printed after the closing brace. When generating code for an
+      // scf.for op, printing a trailing semicolon is handled within the
+      // printOperation function.
+      bool trailingSemicolon =
+          !isa<scf::IfOp, scf::ForOp, cf::CondBranchOp>(op);
+      if (failed(emitter.emitOperation(op, trailingSemicolon)))
+        return failure();
+    }
+  }
+  os.unindent() << "}\n";
+  return success();
+}
+
 static LogicalResult printOperation(CppEmitter &emitter,
                                     func::FuncOp functionOp) {
+  // Stateful emission replaces the regular function emission entirely.
+  if (emitter.shouldEmitStatefulFns())
+    return printStatefulFn(emitter, functionOp);
+
   // We need to declare variables at top if the function has multiple blocks.
   if (!emitter.shouldDeclareVariablesAtTop() &&
       functionOp.getBlocks().size() > 1) {
@@ -681,8 +832,12 @@
   return success();
 }
 
-CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
-    : os(os), declareVariablesAtTop(declareVariablesAtTop) {
+CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
+                       bool emitStatefulFns, std::string argNameAttr,
+                       CppFileKind cppFileKind)
+    : os(os), declareVariablesAtTop(declareVariablesAtTop),
+      emitStatefulFns(emitStatefulFns), argNameAttr(std::move(argNameAttr)),
+      cppFileKind(cppFileKind) {
   valueInScopeCount.push(0);
   labelInScopeCount.push(0);
 }
@@ -885,6 +1040,36 @@
   return success();
 }
 
+LogicalResult CppEmitter::emitVariableSetter(OpResult result,
+                                             bool trailingSemicolon) {
+  // The emitted setter is a complete definition, so no semicolon is printed.
+  (void)trailingSemicolon;
+
+  Operation *owner = result.getOwner();
+  // Ops without a name attribute get a placeholder name.
+  StringRef name = "EMPTY_NAME";
+  if (Attribute attr = owner->getAttr(argNameAttr)) {
+    // The name is the first element of the array attribute; reject anything
+    // else instead of casting blindly.
+    StringAttr first;
+    if (auto names = attr.dyn_cast<ArrayAttr>())
+      if (!names.empty())
+        first = names[0].dyn_cast<StringAttr>();
+    if (!first)
+      return owner->emitOpError("requires attribute '")
+             << argNameAttr
+             << "' to be an array whose first element is a string";
+    name = first.getValue();
+  }
+
+  os << "void set_" << name << "(";
+  if (failed(emitType(owner->getLoc(), result.getType())))
+    return failure();
+  os << " x) {\n" << name << " = x;\n}\n";
+
+  return success();
+}
+
 LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
   switch (op.getNumResults()) {
   case 0:
@@ -1038,7 +1223,11 @@
 }
 
 LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
-                                    bool declareVariablesAtTop) {
-  CppEmitter emitter(os, declareVariablesAtTop);
+                                    bool declareVariablesAtTop,
+                                    bool emitStatefulFns,
+                                    std::string argNameAttr,
+                                    CppFileKind cppFileKind) {
+  CppEmitter emitter(os, declareVariablesAtTop, emitStatefulFns,
+                     std::move(argNameAttr), cppFileKind);
   return emitter.emitOperation(*op, /*trailingSemicolon=*/false);
 }