diff --git a/mlir/include/mlir/Target/Cpp.h b/mlir/include/mlir/Target/Cpp.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/Cpp.h @@ -0,0 +1,220 @@ +//===- Cpp.h - Helpers to create C++ emitter --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file define a emitter class to translate between an MLIR function with +// ops and C++. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_CPP_H +#define MLIR_TARGET_CPP_H + +#include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { + +class DialectCppEmitter; + +/// Emitter that uses dialect specific emitters to emit C++ code. +/// +/// This emitter only does some trivial checks for valid C++ code emitted and +/// relies on the dialect emitters to produce valid C++ code. It does provide +/// some common utility functions that the DialectCppEmitters could use. +/// +/// All dialects are required to be registered for dialect specific C++ +/// emission. There is very little hardcoded and dialect emitters have freedom +/// to do arbitrary lowerings (this is also a weakness). +/// +/// Known limitations: +/// - Doesn't understand C++ semantics and so may emit invalid code; +/// - Doesn't attempt to ensure the function names are valid C++ names; +/// - Doesn't insert includes etc. +struct CppEmitter { + /// Map from dialect to DialectCppEmitter. + using DialectCppEmitters = + llvm::StringMap>; + /// A registry function that adds emitter to the given emitters map. 
+ using EmitterRegistryFunction = + std::function; + + CppEmitter() = delete; + + /// Emits the given operation with provided dialect emitters on 'os'. + /// Optionally emit a trailing semicolon. + static LogicalResult emit(DialectCppEmitters &emitters, Operation &op, + raw_ostream &os, bool trailingSemicolon = false); + static LogicalResult + emit(const llvm::StringMap &emitters, Operation &op, + raw_ostream &os, bool trailingSemicolon = false); + + /// Emits attribute or returns failure. + LogicalResult emitAttribute(Attribute attribute); + + /// Emits operation 'op' with/without training semicolon or returns failure. + LogicalResult emitOperation(Operation &op, bool trailingSemicolon = true); + + /// Emits type 'type' or returns failure. + LogicalResult emitType(Type type); + + /// Emits array of types as a std::tuple of the emitted types. + LogicalResult emitTypes(ArrayRef types); + + /// Emits a simple operation (one without regions) by + /// - emitting assign prefix (see below) + /// - splitting op name into dialect :: remainder (where it is an error if + /// remainder has a period); + /// - followed by all operands and attributes (see below). + LogicalResult emitOperationFallback(Operation &op, + bool trailingSemicolon = true); + + /// Emits the variable declaration and assignment prefix for 'op'. + /// - emits separate variable followed by std::tie for multi-valued operation; + /// - emits single type followed by variable for single result; + /// - emits nothing if no value produced by op; + /// Emits final '=' operator where a type is produced. Returns failure if + /// any result type could not be converted. + LogicalResult emitAssignPrefix(Operation &op); + + /// Emits the operands and atttributes of the operation. All operands are + /// emitted first and then all attributes in alphabetical order. + LogicalResult emitOperandsAndAttributes(Operation &op, + ArrayRef exclude = {}); + + /// Return the existing or a new name for a Value. 
+ StringRef getOrCreateName(Value val); + + /// RAII helper function to manage entering/exiting C++ scopes. + struct Scope { + Scope(CppEmitter &emitter) : mapperScope(emitter.mapper), emitter(emitter) { + emitter.valueInScopeCount.push(emitter.valueInScopeCount.top()); + } + ~Scope() { emitter.valueInScopeCount.pop(); } + + private: + llvm::ScopedHashTableScope mapperScope; + CppEmitter &emitter; + }; + + /// Returns wether the Value is assigned to a C++ variable in the scope. + bool hasValueInScope(Value val); + + /// Returns the output stream. + raw_ostream &ostream() { return os; }; + +private: + friend struct Scope; + using ValMapper = llvm::ScopedHashTable; + + /// Private constructor, the static emit function is entry point for use. + explicit CppEmitter(raw_ostream &os); + + /// Returns dialect emitter for the given dialect. + DialectCppEmitter *getDialectEmitter(Dialect *dialect); + + /// Map from dialect to DialectCppEmitter. + llvm::DenseMap emitters; + + /// Output stream to emit to. + raw_ostream &os; + + /// Map from value to name of C++ variable that contain the name. + ValMapper mapper; + + /// The number of values in the current scope. This is used to declare the + /// names of values in a scope. + std::stack valueInScopeCount; +}; + +/// Interface for C++ emitter for a dialect. This interface can be implemented +/// by a dialect to define behavior specific to a dialect (e.g., how a dialect +/// attribute should be emitted to C++/if it can). +class DialectCppEmitter { +public: + DialectCppEmitter() = default; + virtual ~DialectCppEmitter() = default; + + virtual LogicalResult printAttribute(CppEmitter &emitter, Attribute val) = 0; + virtual LogicalResult printOperation(CppEmitter &emitter, Operation &op) = 0; + virtual LogicalResult printType(CppEmitter &emitter, Type type) = 0; +}; + +/// Register a EmitterRegistryFunction function for a dialect. Typically used +/// via DialectCppRegistration. 
+/// Note: the registry is used for testing along with mlir-translate and while
+/// this registry's dialect emitters are registered at link time, the
+/// CppEmitter need not use the registry.
+void registerCppEmitter(StringRef dialect,
+                        const CppEmitter::EmitterRegistryFunction &function);
+
+/// DialectCppEmitterRegistration provides a global initializer that registers
+/// a DialectCppEmitter allocation routine for a concrete emitter instance.
+///
+/// Usage:
+///
+///   // At namespace scope.
+///   static DialectCppEmitterRegistration<MyEmitter> reg("mydialect");
+template <typename ConcreteEmitter> struct DialectCppEmitterRegistration {
+  DialectCppEmitterRegistration(StringRef dialect) {
+    // The callback is stored in a global registry and runs later, so it must
+    // own a copy of the dialect name rather than a StringRef into the
+    // caller's (possibly temporary) storage.
+    registerCppEmitter(
+        dialect,
+        [name = dialect.str()](
+            CppEmitter::DialectCppEmitters &registry) -> LogicalResult {
+          // Refuse to replace an emitter already present for this dialect.
+          if (registry.count(name))
+            return failure();
+          registry[name] = std::make_unique<ConcreteEmitter>();
+          return success();
+        });
+  }
+};
+
+/// Returns all the EmitterRegistryFunctions registered via
+/// DialectCppEmitterRegistration.
+const llvm::StringMap<CppEmitter::EmitterRegistryFunction> &
+getCppEmitterRegistry();
+
+/// Convenience functions to produce interleaved output with functions returning
+/// a LogicalResult. Unlike llvm::interleave, the per-element callback returns a
+/// LogicalResult and iteration stops at the first failure.
+template +inline LogicalResult +interleaveWithError(ForwardIterator begin, ForwardIterator end, + UnaryFunctor each_fn, NullaryFunctor between_fn) { + if (begin == end) + return success(); + if (failed(each_fn(*begin))) + return failure(); + ++begin; + for (; begin != end; ++begin) { + between_fn(); + if (failed(each_fn(*begin))) + return failure(); + } + return success(); +} + +template +inline LogicalResult interleaveWithError(const Container &c, + UnaryFunctor each_fn, + NullaryFunctor between_fn) { + return interleaveWithError(c.begin(), c.end(), each_fn, between_fn); +} + +template +inline LogicalResult interleaveCommaWithError(const Container &c, + raw_ostream &os, + UnaryFunctor each_fn) { + return interleaveWithError(c.begin(), c.end(), each_fn, + [&]() { os << ", "; }); +} + +} // end namespace mlir + +#endif // MLIR_TARGET_CPP_H diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -1,3 +1,25 @@ +add_mlir_library(MLIRTargetCpp + Cpp/TranslateToCpp.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/Cpp + ) +target_link_libraries(MLIRTargetCpp + PUBLIC + MLIRIR + LLVMSupport + ) +add_mlir_library(MLIRTargetCppRegistration + Cpp/TranslateToCppRegistration.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/Cpp + ) +target_link_libraries(MLIRTargetCppRegistration + PUBLIC + MLIRStandardOps + MLIRTargetCpp + ) add_mlir_library(MLIRTargetLLVMIRModuleTranslation LLVMIR/DebugTranslation.cpp LLVMIR/ModuleTranslation.cpp diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -0,0 +1,212 @@ +//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/Cpp.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" + +#define DEBUG_TYPE "translate-to-cpp" + +using namespace mlir; +using llvm::formatv; + +LogicalResult CppEmitter::emit(DialectCppEmitters &emitters, Operation &op, + raw_ostream &os, bool trailingSemicolon) { + CppEmitter emitter(os); + auto *context = op.getContext(); + for (auto &it : emitters) { + auto *regDialect = context->getRegisteredDialect(it.first()); + if (regDialect == nullptr) + return op.emitError() + << "dialect '" << it.first() << "' is not registered"; + emitter.emitters[regDialect] = it.second.get(); + } + return emitter.emitOperation(op, trailingSemicolon); +} + +LogicalResult +CppEmitter::emit(const llvm::StringMap &emitters, + Operation &op, raw_ostream &os, bool trailingSemicolon) { + DialectCppEmitters allocatedEmitters; + for (auto &it : emitters) { + if (failed(it.second(allocatedEmitters))) + return op.emitError() << "dialect '" << it.first() + << "'s dialect emitter failed to construct\n"; + } + return emit(allocatedEmitters, op, os, trailingSemicolon); +} + +CppEmitter::CppEmitter(raw_ostream &os) : os(os) { valueInScopeCount.push(0); } + +/// Return the existing or a new name for a Value*. 
+StringRef CppEmitter::getOrCreateName(Value val) { + if (!mapper.count(val)) + mapper.insert(val, formatv("v{0}", ++valueInScopeCount.top())); + return *mapper.begin(val); +} + +bool CppEmitter::hasValueInScope(Value val) { return mapper.count(val); } + +DialectCppEmitter *CppEmitter::getDialectEmitter(Dialect *dialect) { + return emitters.lookup(dialect); +} + +LogicalResult CppEmitter::emitAttribute(Attribute attribute) { + if (auto *dialectEmitter = getDialectEmitter(&attribute.getDialect())) + return dialectEmitter->printAttribute(*this, attribute); + LLVM_DEBUG(llvm::dbgs() << "no dialect emitter found for " << attribute + << " of dialect '" + << attribute.getDialect().getNamespace() + << "' trying fallback\n"); + return failure(); +} + +LogicalResult +CppEmitter::emitOperandsAndAttributes(Operation &op, + ArrayRef exclude) { + // Emit operands. + auto emitOperandName = [&](Value result) -> LogicalResult { + if (!hasValueInScope(result)) + return op.emitError() << "operand value not in scope"; + os << getOrCreateName(result); + return success(); + }; + if (failed(interleaveCommaWithError(op.getOperands(), os, emitOperandName))) + return failure(); + // Insert comma in between operands and non-filtered attributes if needed. + if (op.getNumOperands() > 0) { + for (auto attr : op.getAttrs()) { + if (!llvm::is_contained(exclude, attr.first.strref())) { + os << ", "; + break; + } + } + } + // Emit attributes. 
+ auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult { + if (llvm::is_contained(exclude, attr.first.strref())) + return success(); + os << "/* " << attr.first << " */"; + if (failed(emitAttribute(attr.second))) + return op.emitError() << "unable to emit attribute " << attr.second; + return success(); + }; + return interleaveCommaWithError(op.getAttrs(), os, emitNamedAttribute); +} + +LogicalResult CppEmitter::emitAssignPrefix(Operation &op) { + switch (op.getNumResults()) { + case 0: + break; + case 1: { + auto result = op.getResult(0); + if (failed(emitType(result.getType()))) + return failure(); + os << " " << getOrCreateName(result) << " = "; + break; + } + default: + for (auto result : op.getResults()) { + if (failed(emitType(result.getType()))) + return failure(); + os << " " << getOrCreateName(result) << ";\n"; + } + os << "std::tie("; + interleaveComma(op.getResults(), os, + [&](Value result) { os << getOrCreateName(result); }); + os << ") = "; + } + return success(); +} + +LogicalResult CppEmitter::emitOperationFallback(Operation &op, + bool trailingSemicolon) { + if (op.getNumRegions() != 0) + return op.emitOpError() + << "only simple ops supported by fallback op emitter"; + if (failed(emitAssignPrefix(op))) + return failure(); + + auto dialectAndName = op.getName().getStringRef().split('.'); + if (dialectAndName.second.count('.') != 0) + return op.emitOpError() << "unable to convert to C++ call"; + os << dialectAndName.first << "::" << dialectAndName.second << "("; + + if (failed(emitOperandsAndAttributes(op))) + return failure(); + if (trailingSemicolon) + os << ");\n"; + else + os << ")\n"; + return success(); +} + +LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { + if (auto *dialectEmitter = getDialectEmitter(op.getDialect())) { + if (succeeded(dialectEmitter->printOperation(*this, op))) { + os << (trailingSemicolon ? 
";\n" : "\n"); + return success(); + } + return failure(); + } + LLVM_DEBUG({ + llvm::dbgs() << "no dialect emitter found for " << op << " of dialect "; + if (auto *dialect = op.getDialect()) + llvm::dbgs() << "'" << dialect->getNamespace() << "'"; + else + llvm::dbgs() << ""; + llvm::dbgs() << " trying fallback\n"; + }); + return emitOperationFallback(op, trailingSemicolon); +} + +LogicalResult CppEmitter::emitType(Type type) { + if (auto *dialectEmitter = getDialectEmitter(&type.getDialect())) + return dialectEmitter->printType(*this, type); + LLVM_DEBUG(llvm::dbgs() << "no dialect emitter found for " << type + << " of dialect '" << type.getDialect().getNamespace() + << "' trying fallback\n"); + return failure(); +} + +LogicalResult CppEmitter::emitTypes(ArrayRef types) { + switch (types.size()) { + case 0: + os << "void"; + return success(); + case 1: + return emitType(types.front()); + default: + os << "std::tuple<"; + if (failed(interleaveCommaWithError( + types, os, [&](Type type) { return emitType(type); }))) + return failure(); + os << ">"; + return success(); + } +} + +llvm::StringMap & +getMutableCppEmitterRegistry() { + static llvm::StringMap registry; + return registry; +} + +void mlir::registerCppEmitter( + StringRef dialect, const CppEmitter::EmitterRegistryFunction &function) { + assert(function && "Attempting to register empty emitter function"); + getMutableCppEmitterRegistry()[dialect] = function; +} + +const llvm::StringMap & +mlir::getCppEmitterRegistry() { + return getMutableCppEmitterRegistry(); +} diff --git a/mlir/lib/Target/Cpp/TranslateToCppRegistration.cpp b/mlir/lib/Target/Cpp/TranslateToCppRegistration.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/Cpp/TranslateToCppRegistration.cpp @@ -0,0 +1,218 @@ +//===- TranslateToCppRegistration.cpp - Register for mlir-translate ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Module.h" +#include "mlir/Target/Cpp.h" +#include "mlir/Translation.h" +// Headers only here due to example emitter registration. +// TODO: Remove once the registrations are moved closer to their dialects. +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/StandardTypes.h" + +using namespace mlir; + +// TODO: Move out of registration here and closer to dialect. This is just here +// locally as an example for the initial development. +namespace { +class StdEmitter : public DialectCppEmitter { +public: + LogicalResult printAttribute(CppEmitter &emitter, Attribute attr) override; + LogicalResult printOperation(CppEmitter &emitter, Operation &op) override; + LogicalResult printType(CppEmitter &emitter, Type type) override; +}; +} // namespace + +LogicalResult StdEmitter::printAttribute(CppEmitter &emitter, Attribute attr) { + return failure(); +} + +static LogicalResult printConstantOp(CppEmitter &emitter, + ConstantOp constantOp) { + auto &os = emitter.ostream(); + emitter.emitType(constantOp.getType()); + os << " " << emitter.getOrCreateName(constantOp.getResult()) << '{'; + if (failed(emitter.emitAttribute(constantOp.getValue()))) + return constantOp.emitError("unable to emit constant value"); + os << '}'; + return success(); +} + +static LogicalResult printCallOp(CppEmitter &emitter, CallOp callOp) { + if (failed(emitter.emitAssignPrefix(*callOp.getOperation()))) + return failure(); + + auto &os = emitter.ostream(); + os << callOp.getCallee() << "("; + if (failed(emitter.emitOperandsAndAttributes(*callOp.getOperation(), + /*exclude=*/{"callee"}))) + return failure(); + os << ")"; + return success(); +} + +static LogicalResult printReturnOp(CppEmitter &emitter, ReturnOp returnOp) { + auto &os = 
emitter.ostream(); + os << "return "; + switch (returnOp.getNumOperands()) { + case 0: + return success(); + case 1: + os << emitter.getOrCreateName(returnOp.getOperand(0)); + return success(emitter.hasValueInScope(returnOp.getOperand(0))); + default: + os << "std::make_tuple("; + if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation()))) + return failure(); + os << ")"; + return success(); + } +} + +LogicalResult StdEmitter::printOperation(CppEmitter &emitter, Operation &op) { + if (auto callOp = dyn_cast(op)) + return printCallOp(emitter, callOp); + if (auto constantOp = dyn_cast(op)) + return printConstantOp(emitter, constantOp); + if (auto returnOp = dyn_cast(op)) + return printReturnOp(emitter, returnOp); + + return op.emitOpError() << "unable to find printer for op"; +} + +LogicalResult StdEmitter::printType(CppEmitter &emitter, Type type) { + emitter.ostream() << type; + return success(); +} +static DialectCppEmitterRegistration stdReg("std"); + +namespace { +class BuiltinEmitter : public DialectCppEmitter { +public: + LogicalResult printAttribute(CppEmitter &emitter, Attribute attr) override; + LogicalResult printOperation(CppEmitter &emitter, Operation &op) override; + LogicalResult printType(CppEmitter &emitter, Type type) override; +}; +} // namespace + +LogicalResult BuiltinEmitter::printAttribute(CppEmitter &emitter, + Attribute attr) { + auto &os = emitter.ostream(); + if (auto iAttr = attr.dyn_cast()) { + os << iAttr.getValue(); + return success(); + } + if (auto dense = attr.dyn_cast()) { + os << '{'; + interleaveComma(dense.getIntValues(), os); + os << '}'; + return success(); + } + return failure(); +} + +static void printFunctionVisibility(FuncOp funcOp, raw_ostream &os) { + if (funcOp.getVisibility() == SymbolTable::Visibility::Private) + os << "static "; +} + +static LogicalResult printModule(CppEmitter &emitter, ModuleOp moduleOp) { + CppEmitter::Scope scope(emitter); + auto &os = emitter.ostream(); + os << "// Forward declare 
functions.\n"; + for (FuncOp funcOp : moduleOp.getOps()) { + printFunctionVisibility(funcOp, os); + if (failed(emitter.emitTypes(funcOp.getType().getResults()))) + return failure(); + os << " " << funcOp.getName() << "("; + if (failed(interleaveCommaWithError( + funcOp.getArguments(), os, [&](BlockArgument arg) { + return emitter.emitType(arg.getType()); + }))) + return failure(); + os << ");\n"; + } + os << "\n"; + + for (Operation &op : moduleOp) { + if (failed(emitter.emitOperation(op, /*trailingSemiColon=*/false))) + return failure(); + } + return success(); +} + +static LogicalResult printFunction(CppEmitter &emitter, FuncOp functionOp) { + auto &blocks = functionOp.getBlocks(); + if (blocks.size() != 1) + return functionOp.emitOpError() << "only single block functions supported"; + + CppEmitter::Scope scope(emitter); + auto &os = emitter.ostream(); + printFunctionVisibility(functionOp, os); + if (failed(emitter.emitTypes(functionOp.getType().getResults()))) + return functionOp.emitError() << "unable to emit all types"; + os << " " << functionOp.getName(); + + os << "("; + if (failed(interleaveCommaWithError( + functionOp.getArguments(), os, + [&](BlockArgument arg) -> LogicalResult { + if (failed(emitter.emitType(arg.getType()))) + return functionOp.emitError() << "unable to emit arg " + << arg.getArgNumber() << "'s type"; + os << " " << emitter.getOrCreateName(arg); + return success(); + }))) + return failure(); + os << ") {\n"; + + for (Operation &op : functionOp.front()) { + if (failed(emitter.emitOperation(op))) + return failure(); + } + os << "}\n"; + return success(); +} + +LogicalResult BuiltinEmitter::printOperation(CppEmitter &emitter, + Operation &op) { + if (auto moduleOp = dyn_cast(op)) + return printModule(emitter, moduleOp); + if (auto funcOp = dyn_cast(op)) + return printFunction(emitter, funcOp); + if (isa(op)) + return success(); + + return op.emitOpError() << "unable to find printer for op"; +} + +LogicalResult 
BuiltinEmitter::printType(CppEmitter &emitter, Type type) { + auto &os = emitter.ostream(); + if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 32: + return (os << "int32_t"), success(); + case 64: + return (os << "int64_t"), success(); + default: + return failure(); + } + } + return failure(); +} +static DialectCppEmitterRegistration builtinReg(""); + +static LogicalResult MlirToCppTranslateFunction(ModuleOp module, + llvm::raw_ostream &output) { + return CppEmitter::emit(getCppEmitterRegistry(), *module.getOperation(), + output, + /*trailingSemiColon=*/false); +} + +static TranslateFromMLIRRegistration reg("mlir-to-cpp", + MlirToCppTranslateFunction); diff --git a/mlir/test/Target/cpp-calls.mlir b/mlir/test/Target/cpp-calls.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/cpp-calls.mlir @@ -0,0 +1,31 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s --dump-input-on-failure + +// CHECK: // Forward declare functions. +// CHECK: void test_foo_print(); +// CHECK: int32_t test_single_return(int32_t); +// CHECK: std::tuple test_multiple_return(); + +// CHECK: void test_foo_print() +func @test_foo_print() { + // CHECK: [[V1:[^ ]*]] = foo::constant(/* value */{0, 1}); + %0 = "foo.constant"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> (i32) + // CHECK: foo::print([[V1]]); + "foo.print"(%0) : (i32) -> () + return +} + +// CHECK: int32_t test_single_return(int32_t [[V2:.*]]) +func @test_single_return(%arg0 : i32) -> i32 { + // CHECK: return [[V2]] + return %arg0 : i32 +} + +// CHECK: std::tuple test_multiple_return() +func @test_multiple_return() -> (i32, i32) { + // CHECK: std::tie([[V3:.*]], [[V4:.*]]) = foo::blah(); + %0:2 = "foo.blah"() : () -> (i32, i32) + // CHECK: [[V5:[^ ]*]] = test_single_return([[V3]]); + %1 = call @test_single_return(%0#0) : (i32) -> i32 + // CHECK: return std::make_tuple([[V5]], [[V4]]); + return %1, %0#1 : i32, i32 +} diff --git a/mlir/tools/mlir-translate/CMakeLists.txt 
b/mlir/tools/mlir-translate/CMakeLists.txt --- a/mlir/tools/mlir-translate/CMakeLists.txt +++ b/mlir/tools/mlir-translate/CMakeLists.txt @@ -13,6 +13,8 @@ ) set(FULL_LIBS MLIRSPIRVSerialization + MLIRTargetCppRegistration + MLIRTargetCpp MLIRTargetLLVMIR MLIRTargetNVVMIR MLIRTargetROCDLIR