diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Affine) add_subdirectory(AVX512) +add_subdirectory(EmitC) add_subdirectory(FxpMathOps) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_dialect(EmitC emitc EmitC) diff --git a/mlir/include/mlir/Dialect/EmitC/EmitC.td b/mlir/include/mlir/Dialect/EmitC/EmitC.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/EmitC.td @@ -0,0 +1,46 @@ +//===- EmitC.td - EmitC operation definitions --------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines operations with a trivial mapping to C/C++ constructs. +// +//===----------------------------------------------------------------------===// + +#ifndef EMITC_OPS +#define EMITC_OPS + +include "mlir/Interfaces/SideEffects.td" + +def EmitC_Dialect : Dialect { + let name = "emitc"; + let cppNamespace = "emitc"; +} + +// Base class for EmitC dialect ops. +class EmitC_Op traits = []> : + Op; + +// Call of the C/C++ function named by the `callee` attribute. +def EmitC_CallOp : EmitC_Op<"call", []> { + let summary = "call operation"; + let description = [{ + The "call" operation represents a C++ function call.
The call allows + specifying order of operands and attributes in the call as follows: + + - integer value of index type refers to an operand; + - attribute which will get lowered to constant value in call; + }]; + let arguments = (ins + Arg:$callee, + Arg, "the order of operands and attributes">:$args, + Variadic:$operands); + let results = (outs Variadic); + let assemblyFormat = [{ + $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) + }]; +} + +#endif // EMITC_OPS diff --git a/mlir/include/mlir/Dialect/EmitC/EmitCDialect.h b/mlir/include/mlir/Dialect/EmitC/EmitCDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/EmitCDialect.h @@ -0,0 +1,31 @@ +//===- EmitCDialect.h - MLIR Dialect for EmitC ------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the EmitC dialect and its operations in MLIR.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_EMITC_EMITCDIALECT_H_ +#define MLIR_DIALECT_EMITC_EMITCDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffects.h" + +namespace mlir { +namespace emitc { + +// Operation classes generated by TableGen from EmitC.td. +#define GET_OP_CLASSES +#include "mlir/Dialect/EmitC/EmitC.h.inc" + +// Dialect class generated by TableGen from EmitC.td. +#include "mlir/Dialect/EmitC/EmitCDialect.h.inc" + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_DIALECT_EMITC_EMITCDIALECT_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -16,6 +16,7 @@ #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/EmitC/EmitCDialect.h" #include "mlir/Dialect/FxpMathOps/FxpMathOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/LLVMIR/LLVMAVX512Dialect.h" @@ -41,6 +42,7 @@ static bool init_once = []() { registerDialect(); registerDialect(); + registerDialect(); registerDialect(); registerDialect(); registerDialect(); diff --git a/mlir/include/mlir/Target/Cpp.h b/mlir/include/mlir/Target/Cpp.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/Cpp.h @@ -0,0 +1,30 @@ +//===- Cpp.h - Helpers to create C++ emitter --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares helpers to emit C++ code using the EmitC dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_CPP_H +#define MLIR_TARGET_CPP_H + +#include "mlir/IR/Value.h" +#include "llvm/Support/raw_ostream.h" + +namespace mlir { +namespace emitc { + +/// Translates the given operation to C++ code and writes it to 'os'. The +/// operation, or the operations nested in its region, are expected to be in +/// the EmitC dialect; the builtin module and the standard func, call, constant +/// and return ops are supported as well. If 'trailingSemicolon' is true, the +/// code emitted for 'op' ends with ";\n" rather than "\n". Returns failure if +/// any operation cannot be translated. +LogicalResult TranslateToCpp(Operation &op, raw_ostream &os, + bool trailingSemicolon = false); + +} // namespace emitc +} // namespace mlir + +#endif // MLIR_TARGET_CPP_H diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Affine) add_subdirectory(AVX512) +add_subdirectory(EmitC) add_subdirectory(FxpMathOps) add_subdirectory(GPU) add_subdirectory(Linalg) diff --git a/mlir/lib/Dialect/EmitC/CMakeLists.txt b/mlir/lib/Dialect/EmitC/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_dialect_library(MLIREmitC + IR/EmitCDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC + + DEPENDS + MLIREmitCIncGen + ) +target_link_libraries(MLIREmitC + PUBLIC + MLIRIR + MLIRSideEffects + LLVMSupport + ) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitCDialect.cpp b/mlir/lib/Dialect/EmitC/IR/EmitCDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/EmitC/IR/EmitCDialect.cpp @@ -0,0 +1,35 @@ +//===- EmitCDialect.cpp - MLIR EmitC ops implementation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the EmitC dialect and its operations.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/EmitCDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +/// Registers the TableGen-generated EmitC operations with the dialect. +emitc::EmitCDialect::EmitCDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/EmitC/EmitC.cpp.inc" + >(); + // Accept unregistered types and operations in the emitc namespace, e.g. + // opaque types such as !emitc<"int32_t">. + allowUnknownTypes(); + allowUnknownOperations(); +} + +namespace mlir { +namespace emitc { +#define GET_OP_CLASSES +#include "mlir/Dialect/EmitC/EmitC.cpp.inc" +} // namespace emitc +} // namespace mlir diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -1,3 +1,26 @@ +add_mlir_library(MLIRTargetCpp + Cpp/TranslateToCpp.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/Cpp + ) +target_link_libraries(MLIRTargetCpp + PUBLIC + MLIREmitC + MLIRIR + MLIRStandardOps + LLVMSupport + ) +add_mlir_library(MLIRTargetCppRegistration + Cpp/TranslateToCppRegistration.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/Cpp + ) +target_link_libraries(MLIRTargetCppRegistration + PUBLIC + MLIRTargetCpp + ) add_mlir_library(MLIRTargetLLVMIRModuleTranslation LLVMIR/DebugTranslation.cpp LLVMIR/ModuleTranslation.cpp diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -0,0 +1,422 @@ +//===- TranslateToCpp.cpp - Translating to C++ calls ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the translation of the EmitC dialect (together with the
+// builtin module and the standard func, call, constant and return ops) to C++.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/EmitC/EmitCDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/Target/Cpp.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/ScopedHashTable.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <stack>
+
+#define DEBUG_TYPE "translate-to-cpp"
+
+using namespace mlir;
+using llvm::formatv;
+
+/// Convenience functions to produce interleaved output with functions returning
+/// a LogicalResult. Unlike llvm::interleave, the per-element function reports
+/// success or failure, and iteration stops at the first failure.
+template <typename ForwardIterator, typename UnaryFunctor,
+          typename NullaryFunctor>
+inline LogicalResult
+interleaveWithError(ForwardIterator begin, ForwardIterator end,
+                    UnaryFunctor each_fn, NullaryFunctor between_fn) {
+  if (begin == end)
+    return success();
+  if (failed(each_fn(*begin)))
+    return failure();
+  ++begin;
+  for (; begin != end; ++begin) {
+    between_fn();
+    if (failed(each_fn(*begin)))
+      return failure();
+  }
+  return success();
+}
+
+template <typename Container, typename UnaryFunctor, typename NullaryFunctor>
+inline LogicalResult interleaveWithError(const Container &c,
+                                         UnaryFunctor each_fn,
+                                         NullaryFunctor between_fn) {
+  return interleaveWithError(c.begin(), c.end(), each_fn, between_fn);
+}
+
+/// Same as interleaveWithError, separating elements with ", " written to 'os'.
+template <typename Container, typename UnaryFunctor>
+inline LogicalResult interleaveCommaWithError(const Container &c,
+                                              raw_ostream &os,
+                                              UnaryFunctor each_fn) {
+  return interleaveWithError(c.begin(), c.end(), each_fn,
+                             [&]() { os << ", "; });
+}
+
+namespace {
+/// Emitter that uses dialect specific emitters to emit C++ code.
+struct CppEmitter {
+  explicit CppEmitter(raw_ostream &os);
+
+  /// Emits attribute or returns failure.
+  LogicalResult emitAttribute(Attribute attr);
+
+  /// Emits operation 'op' with/without trailing semicolon or returns failure.
+  LogicalResult emitOperation(Operation &op, bool trailingSemicolon = true);
+
+  /// Emits type 'type' or returns failure.
+  LogicalResult emitType(Type type);
+
+  /// Emits an array of types: "void" when empty, the single type when there is
+  /// one, and a std::tuple of the emitted types otherwise.
+  LogicalResult emitTypes(ArrayRef<Type> types);
+
+  /// Emits the variable declaration and assignment prefix for 'op'.
+  /// - emits separate variable followed by std::tie for multi-valued operation;
+  /// - emits single type followed by variable for single result;
+  /// - emits nothing if no value produced by op;
+  /// Emits final '=' operator where a type is produced. Returns failure if
+  /// any result type could not be converted.
+  LogicalResult emitAssignPrefix(Operation &op);
+
+  /// Emits the operands and attributes of the operation. All operands are
+  /// emitted first, then every attribute whose name is not in 'exclude', in
+  /// the order given by Operation::getAttrs().
+  LogicalResult emitOperandsAndAttributes(Operation &op,
+                                          ArrayRef<StringRef> exclude = {});
+
+  /// Emits the operands of the operation. All operands are emitted in order.
+  LogicalResult emitOperands(Operation &op);
+
+  /// Return the existing or a new name for a Value.
+  StringRef getOrCreateName(Value val);
+
+  /// RAII helper to manage entering/exiting C++ scopes. Names created while a
+  /// Scope is alive are dropped when it is destroyed.
+  struct Scope {
+    Scope(CppEmitter &emitter) : mapperScope(emitter.mapper), emitter(emitter) {
+      emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
+    }
+    ~Scope() { emitter.valueInScopeCount.pop(); }
+
+  private:
+    llvm::ScopedHashTableScope<Value, std::string> mapperScope;
+    CppEmitter &emitter;
+  };
+
+  /// Returns whether the Value is assigned to a C++ variable in the scope.
+  bool hasValueInScope(Value val);
+
+  /// Returns the output stream.
+  raw_ostream &ostream() { return os; }
+
+private:
+  using ValMapper = llvm::ScopedHashTable<Value, std::string>;
+
+  /// Output stream to emit to.
+  raw_ostream &os;
+
+  /// Map from value to the name of the C++ variable that holds it.
+  ValMapper mapper;
+
+  /// The number of values in the current scope. This is used to declare the
+  /// names of values in a scope.
+  std::stack<int64_t> valueInScopeCount;
+};
+} // namespace
+
+/// Prints 'value' as a C++ integer literal. 1-bit values are printed as
+/// true/false: the default (signed) APInt printing renders an i1 `true` as -1,
+/// which is a narrowing error in brace initialization such as `bool v{-1}`.
+static void printIntLiteral(raw_ostream &os, const APInt &value) {
+  if (value.getBitWidth() == 1)
+    os << (value.getBoolValue() ? "true" : "false");
+  else
+    os << value;
+}
+
+/// Emits `<type> <name>{<value>}` for a std.constant op.
+static LogicalResult printConstantOp(CppEmitter &emitter,
+                                     ConstantOp constantOp) {
+  auto &os = emitter.ostream();
+  if (failed(emitter.emitType(constantOp.getType())))
+    return constantOp.emitError("unable to emit constant type");
+  os << " " << emitter.getOrCreateName(constantOp.getResult()) << '{';
+  if (failed(emitter.emitAttribute(constantOp.getValue())))
+    return constantOp.emitError("unable to emit constant value");
+  os << '}';
+  return success();
+}
+
+/// Emits a std.call as a call of the function with the callee's name, passing
+/// all operands in order.
+static LogicalResult printCallOp(CppEmitter &emitter, CallOp callOp) {
+  if (failed(emitter.emitAssignPrefix(*callOp.getOperation())))
+    return failure();
+
+  auto &os = emitter.ostream();
+  os << callOp.getCallee() << "(";
+  if (failed(emitter.emitOperands(*callOp.getOperation())))
+    return failure();
+  os << ")";
+  return success();
+}
+
+/// Emits an emitc.call. Without an `args` attribute all operands are passed in
+/// order; with one, each index-typed integer attribute selects an operand and
+/// every other attribute is emitted as a literal argument.
+static LogicalResult printCallOp(CppEmitter &emitter, emitc::CallOp callOp) {
+  auto &os = emitter.ostream();
+  auto &op = *callOp.getOperation();
+  if (failed(emitter.emitAssignPrefix(op)))
+    return failure();
+  os << callOp.callee() << "(";
+
+  auto emitArgs = [&](Attribute attr) -> LogicalResult {
+    if (auto t = attr.dyn_cast<IntegerAttr>()) {
+      // Index attributes are treated specially as operand index.
+      if (t.getType().isIndex()) {
+        auto idx = t.getInt();
+        if ((idx < 0) || (idx >= op.getNumOperands()))
+          return op.emitOpError() << "invalid operand index";
+        if (!emitter.hasValueInScope(op.getOperand(idx)))
+          return op.emitOpError()
+                 << "operand " << idx << "'s value not defined in scope";
+        os << emitter.getOrCreateName(op.getOperand(idx));
+        return success();
+      }
+    }
+    if (failed(emitter.emitAttribute(attr)))
+      return op.emitOpError() << "unable to emit attribute " << attr;
+    return success();
+  };
+
+  auto emittedArgs =
+      callOp.args() ? interleaveCommaWithError(*callOp.args(), os, emitArgs)
+                    : emitter.emitOperands(op);
+  if (failed(emittedArgs))
+    return failure();
+  os << ")";
+  return success();
+}
+
+/// Emits a return statement; multiple values are returned as a std::tuple.
+static LogicalResult printReturnOp(CppEmitter &emitter, ReturnOp returnOp) {
+  auto &os = emitter.ostream();
+  os << "return ";
+  switch (returnOp.getNumOperands()) {
+  case 0:
+    return success();
+  case 1: {
+    Value operand = returnOp.getOperand(0);
+    // Check the scope before naming the value: getOrCreateName() would
+    // otherwise register it and make the check trivially succeed.
+    if (!emitter.hasValueInScope(operand))
+      return returnOp.emitOpError() << "operand value not in scope";
+    os << emitter.getOrCreateName(operand);
+    return success();
+  }
+  default:
+    os << "std::make_tuple(";
+    if (failed(emitter.emitOperandsAndAttributes(*returnOp.getOperation())))
+      return failure();
+    os << ")";
+    return success();
+  }
+}
+
+/// Emits forward declarations for all functions in the module, followed by
+/// the translation of every operation in the module body.
+static LogicalResult printModule(CppEmitter &emitter, ModuleOp moduleOp) {
+  CppEmitter::Scope scope(emitter);
+  auto &os = emitter.ostream();
+  os << "// Forward declare functions.\n";
+  for (FuncOp funcOp : moduleOp.getOps<FuncOp>()) {
+    if (failed(emitter.emitTypes(funcOp.getType().getResults())))
+      return funcOp.emitError() << "failed to convert result type";
+    os << " " << funcOp.getName() << "(";
+    // Use the function type rather than the entry block arguments so that
+    // external (bodiless) functions can be declared as well.
+    if (failed(interleaveCommaWithError(
+            funcOp.getType().getInputs(), os, [&](Type type) -> LogicalResult {
+              if (failed(emitter.emitType(type)))
+                return funcOp.emitError() << "failed to convert argument type";
+              return success();
+            })))
+      return failure();
+    os << ");\n";
+  }
+  os << "\n";
+
+  for (Operation &op : moduleOp) {
+    if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
+      return failure();
+  }
+  return success();
+}
+
+/// Emits the definition of a single-block function.
+static LogicalResult printFunction(CppEmitter &emitter, FuncOp functionOp) {
+  // External functions have no body to define; their declaration is part of
+  // the forward declarations emitted by printModule.
+  if (functionOp.isExternal())
+    return success();
+
+  auto &blocks = functionOp.getBlocks();
+  if (blocks.size() != 1)
+    return functionOp.emitOpError() << "only single block functions supported";
+
+  CppEmitter::Scope scope(emitter);
+  auto &os = emitter.ostream();
+  if (failed(emitter.emitTypes(functionOp.getType().getResults())))
+    return functionOp.emitError() << "unable to emit all types";
+  os << " " << functionOp.getName();
+
+  os << "(";
+  if (failed(interleaveCommaWithError(
+          functionOp.getArguments(), os,
+          [&](BlockArgument arg) -> LogicalResult {
+            if (failed(emitter.emitType(arg.getType())))
+              return functionOp.emitError() << "unable to emit arg "
+                                            << arg.getArgNumber() << "'s type";
+            os << " " << emitter.getOrCreateName(arg);
+            return success();
+          })))
+    return failure();
+  os << ") {\n";
+
+  for (Operation &op : functionOp.front()) {
+    if (failed(emitter.emitOperation(op)))
+      return failure();
+  }
+  os << "}\n";
+  return success();
+}
+
+CppEmitter::CppEmitter(raw_ostream &os) : os(os) { valueInScopeCount.push(0); }
+
+/// Return the existing or a new name for a Value.
+StringRef CppEmitter::getOrCreateName(Value val) {
+  if (!mapper.count(val))
+    mapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));
+  return *mapper.begin(val);
+}
+
+bool CppEmitter::hasValueInScope(Value val) { return mapper.count(val); }
+
+LogicalResult CppEmitter::emitAttribute(Attribute attr) {
+  if (auto iAttr = attr.dyn_cast<IntegerAttr>()) {
+    printIntLiteral(os, iAttr.getValue());
+    return success();
+  }
+  if (auto dense = attr.dyn_cast<DenseIntElementsAttr>()) {
+    os << '{';
+    interleaveComma(dense.getIntValues(), os,
+                    [&](const APInt &value) { printIntLiteral(os, value); });
+    os << '}';
+    return success();
+  }
+  return failure();
+}
+
+LogicalResult CppEmitter::emitOperands(Operation &op) {
+  auto emitOperandName = [&](Value result) -> LogicalResult {
+    if (!hasValueInScope(result))
+      return op.emitError() << "operand value not in scope";
+    os << getOrCreateName(result);
+    return success();
+  };
+  return interleaveCommaWithError(op.getOperands(), os, emitOperandName);
+}
+
+LogicalResult
+CppEmitter::emitOperandsAndAttributes(Operation &op,
+                                      ArrayRef<StringRef> exclude) {
+  if (failed(emitOperands(op)))
+    return failure();
+  // Filter the attributes up front so that separators are only emitted
+  // between values that are actually printed.
+  SmallVector<NamedAttribute, 4> attrs;
+  for (NamedAttribute attr : op.getAttrs())
+    if (!llvm::is_contained(exclude, attr.first.strref()))
+      attrs.push_back(attr);
+  // Insert comma in between operands and attributes if needed.
+  if (op.getNumOperands() > 0 && !attrs.empty())
+    os << ", ";
+  // Emit attributes.
+  auto emitNamedAttribute = [&](NamedAttribute attr) -> LogicalResult {
+    os << "/* " << attr.first << " */";
+    if (failed(emitAttribute(attr.second)))
+      return op.emitError() << "unable to emit attribute " << attr.second;
+    return success();
+  };
+  return interleaveCommaWithError(attrs, os, emitNamedAttribute);
+}
+
+LogicalResult CppEmitter::emitAssignPrefix(Operation &op) {
+  switch (op.getNumResults()) {
+  case 0:
+    break;
+  case 1: {
+    auto result = op.getResult(0);
+    if (failed(emitType(result.getType())))
+      return failure();
+    os << " " << getOrCreateName(result) << " = ";
+    break;
+  }
+  default:
+    for (auto result : op.getResults()) {
+      if (failed(emitType(result.getType())))
+        return failure();
+      os << " " << getOrCreateName(result) << ";\n";
+    }
+    os << "std::tie(";
+    interleaveComma(op.getResults(), os,
+                    [&](Value result) { os << getOrCreateName(result); });
+    os << ") = ";
+  }
+  return success();
+}
+
+/// Dispatches 'op' to the printer for its operation kind.
+static LogicalResult printOperation(CppEmitter &emitter, Operation &op) {
+  if (auto callOp = dyn_cast<CallOp>(op))
+    return printCallOp(emitter, callOp);
+  if (auto callOp = dyn_cast<emitc::CallOp>(op))
+    return printCallOp(emitter, callOp);
+  if (auto constantOp = dyn_cast<ConstantOp>(op))
+    return printConstantOp(emitter, constantOp);
+  if (auto returnOp = dyn_cast<ReturnOp>(op))
+    return printReturnOp(emitter, returnOp);
+  if (auto moduleOp = dyn_cast<ModuleOp>(op))
+    return printModule(emitter, moduleOp);
+  if (auto funcOp = dyn_cast<FuncOp>(op))
+    return printFunction(emitter, funcOp);
+  if (isa<ModuleTerminatorOp>(op))
+    return success();
+
+  return op.emitOpError() << "unable to find printer for op";
+}
+
+LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
+  if (failed(printOperation(*this, op)))
+    return failure();
+  os << (trailingSemicolon ? ";\n" : "\n");
+  return success();
+}
+
+LogicalResult CppEmitter::emitType(Type type) {
+  if (auto itype = type.dyn_cast<IntegerType>()) {
+    switch (itype.getWidth()) {
+    case 1:
+      return (os << "bool"), success();
+    case 8:
+      return (os << "int8_t"), success();
+    case 16:
+      return (os << "int16_t"), success();
+    case 32:
+      return (os << "int32_t"), success();
+    case 64:
+      return (os << "int64_t"), success();
+    default:
+      return failure();
+    }
+  }
+  // TODO: Change to be EmitC specific.
+  if (auto ot = type.dyn_cast<OpaqueType>()) {
+    os << ot.getTypeData();
+    return success();
+  }
+  return failure();
+}
+
+LogicalResult CppEmitter::emitTypes(ArrayRef<Type> types) {
+  switch (types.size()) {
+  case 0:
+    os << "void";
+    return success();
+  case 1:
+    return emitType(types.front());
+  default:
+    os << "std::tuple<";
+    if (failed(interleaveCommaWithError(
+            types, os, [&](Type type) { return emitType(type); })))
+      return failure();
+    os << ">";
+    return success();
+  }
+}
+
+LogicalResult emitc::TranslateToCpp(Operation &op, raw_ostream &os,
+                                    bool trailingSemicolon) {
+  CppEmitter emitter(os);
+  return emitter.emitOperation(op, trailingSemicolon);
+}
diff --git a/mlir/lib/Target/Cpp/TranslateToCppRegistration.cpp b/mlir/lib/Target/Cpp/TranslateToCppRegistration.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Target/Cpp/TranslateToCppRegistration.cpp
@@ -0,0 +1,22 @@
+//===- TranslateToCppRegistration.cpp - Register for mlir-translate ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Module.h" +#include "mlir/Target/Cpp.h" +#include "mlir/Translation.h" + +using namespace mlir; + +/// Translates the whole module to C++, without a trailing semicolon after it. +static LogicalResult MlirToCppTranslateFunction(ModuleOp module, + llvm::raw_ostream &output) { + return emitc::TranslateToCpp(*module.getOperation(), output, + /*trailingSemicolon=*/false); +} + +// Exposes the translation as `mlir-translate -mlir-to-cpp`. +static TranslateFromMLIRRegistration reg("mlir-to-cpp", + MlirToCppTranslateFunction); diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -0,0 +1,11 @@ +// RUN: mlir-opt -verify-diagnostics %s | FileCheck %s + +// CHECK-LABEL: func @f(%{{.*}}: i32, %{{.*}}: !emitc.int32_t) -> i1 { +func @f(%arg0: i32, %f: !emitc<"int32_t">) -> i1 { + %1 = "emitc.call"() {callee = "blah"} : () -> i64 + emitc.call "foo" (%1) {args = [ + 0 : index, dense<[0, 1]> : tensor<2xi32>, 0 : index + ]} : (i64) -> () + %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1) + return %2#1 : i1 +} diff --git a/mlir/test/Target/cpp-calls.mlir b/mlir/test/Target/cpp-calls.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/cpp-calls.mlir @@ -0,0 +1,35 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s --dump-input-on-failure + +// CHECK: // Forward declare functions.
+// CHECK: void test_foo_print(); +// CHECK: int32_t test_single_return(int32_t); +// CHECK: std::tuple test_multiple_return(); + +// Index attributes in `args` select operands; other attributes become literals. +// CHECK: void test_foo_print() +func @test_foo_print() { + // CHECK: [[V1:[^ ]*]] = foo::constant({0, 1}); + %0 = emitc.call "foo::constant"() {args = [dense<[0, 1]> : tensor<2xi32>]} : () -> (i32) + // CHECK: [[V2:[^ ]*]] = foo::op_and_attr({0, 1}, [[V1]]); + %1 = emitc.call "foo::op_and_attr"(%0) {args = [dense<[0, 1]> : tensor<2xi32>, 0 : index]} : (i32) -> (i32) + // CHECK: [[V3:[^ ]*]] = foo::op_and_attr([[V2]], {0, 1}); + %2 = emitc.call "foo::op_and_attr"(%1) {args = [0 : index, dense<[0, 1]> : tensor<2xi32>]} : (i32) -> (i32) + // CHECK: foo::print([[V3]]); + emitc.call "foo::print"(%2): (i32) -> () + return +} + +// CHECK: int32_t test_single_return(int32_t [[V2:.*]]) +func @test_single_return(%arg0 : i32) -> i32 { + // CHECK: return [[V2]] + return %arg0 : i32 +} + +// Multiple results are unpacked with std::tie and returned with std::make_tuple. +// CHECK: std::tuple test_multiple_return() +func @test_multiple_return() -> (i32, i32) { + // CHECK: std::tie([[V3:.*]], [[V4:.*]]) = foo::blah(); + %0:2 = emitc.call "foo::blah"() : () -> (i32, i32) + // CHECK: [[V5:[^ ]*]] = test_single_return([[V3]]); + %1 = call @test_single_return(%0#0) : (i32) -> i32 + // CHECK: return std::make_tuple([[V5]], [[V4]]); + return %1, %0#1 : i32, i32 +} diff --git a/mlir/tools/mlir-translate/CMakeLists.txt b/mlir/tools/mlir-translate/CMakeLists.txt --- a/mlir/tools/mlir-translate/CMakeLists.txt +++ b/mlir/tools/mlir-translate/CMakeLists.txt @@ -15,6 +15,8 @@ set(FULL_LIBS MLIRSPIRVSerialization MLIRTargetAVX512 + MLIRTargetCppRegistration + MLIRTargetCpp MLIRTargetLLVMIR MLIRTargetNVVMIR MLIRTargetROCDLIR