diff --git a/202209301111.patch b/202209301111.patch new file mode 100644 --- /dev/null +++ b/202209301111.patch @@ -0,0 +1,2050 @@ +diff --git a/mlir/examples/standalone/standalone-translate/standalone-translate.cpp b/mlir/examples/standalone/standalone-translate/standalone-translate.cpp +index 2c2f275..31ddef4 100644 +--- a/mlir/examples/standalone/standalone-translate/standalone-translate.cpp ++++ b/mlir/examples/standalone/standalone-translate/standalone-translate.cpp +@@ -1,27 +1,34 @@ + //===- standalone-translate.cpp ---------------------------------*- C++ -*-===// + // + // This file is licensed under the Apache License v2.0 with LLVM Exceptions. + // See https://llvm.org/LICENSE.txt for license information. + // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // + //===----------------------------------------------------------------------===// + // + // This is a command line utility that translates a file from/to MLIR using one + // of the registered translations. + // + //===----------------------------------------------------------------------===// + ++#include "Standalone/StandaloneDialect.h" ++#include "mlir/IR/BuiltinOps.h" + #include "mlir/InitAllTranslations.h" + #include "mlir/Support/LogicalResult.h" + #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" +- +-#include "Standalone/StandaloneDialect.h" ++#include "mlir/Tools/mlir-translate/Translation.h" + + int main(int argc, char **argv) { + mlir::registerAllTranslations(); + + // TODO: Register standalone translations here. ++ mlir::TranslateFromMLIRRegistration withdescription( ++ "option", "different from option", ++ [](mlir::ModuleOp op, llvm::raw_ostream &output) { ++ return mlir::LogicalResult::success(); ++ }, ++ [](mlir::DialectRegistry &a) {}); + + return failed( + mlir::mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool")); + } +diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h +index d91e479..c8b5b70 100644 +--- a/mlir/include/mlir/Tools/mlir-translate/Translation.h ++++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h +@@ -1,102 +1,103 @@ + //===- Translation.h - Translation registry ---------------------*- C++ -*-===// + // + // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + // See https://llvm.org/LICENSE.txt for license information. + // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // + //===----------------------------------------------------------------------===// + // + // Registry for user-provided translations. + // + //===----------------------------------------------------------------------===// + + #ifndef MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H + #define MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H + + #include "llvm/Support/CommandLine.h" + + namespace llvm { + class MemoryBuffer; + class SourceMgr; + class StringRef; + } // namespace llvm + + namespace mlir { + class DialectRegistry; + struct LogicalResult; + class MLIRContext; + class ModuleOp; + template + class OwningOpRef; + + /// Interface of the function that translates the sources managed by `sourceMgr` + /// to MLIR. The source manager has at least one buffer. The implementation + /// should create a new MLIR ModuleOp in the given context and return a pointer + /// to it, or a nullptr in case of any error. + using TranslateSourceMgrToMLIRFunction = std::function( + llvm::SourceMgr &sourceMgr, MLIRContext *)>; + + /// Interface of the function that translates the given string to MLIR. The + /// implementation should create a new MLIR ModuleOp in the given context. If + /// source-related error reporting is required from within the function, use + /// TranslateSourceMgrToMLIRFunction instead. + using TranslateStringRefToMLIRFunction = + std::function(llvm::StringRef, MLIRContext *)>; + + /// Interface of the function that translates MLIR to a different format and + /// outputs the result to a stream. It is allowed to modify the module. + using TranslateFromMLIRFunction = + std::function; + + /// Interface of the function that performs file-to-file translation involving + /// MLIR. The input file is held in the given MemoryBuffer; the output file + /// should be written to the given raw_ostream. The implementation should create + /// all MLIR constructs needed during the process inside the given context. This + /// can be used for round-tripping external formats through the MLIR system. + using TranslateFunction = std::function; + + /// Use Translate[ToMLIR|FromMLIR]Registration as an initializer that + /// registers a function and associates it with name. This requires that a + /// translation has not been registered to a given name. + /// + /// Usage: + /// + /// // At file scope. + /// namespace mlir { + /// void registerTRexToMLIRRegistration() { + /// TranslateToMLIRRegistration Unused(&MySubCommand, [] { ... }); + /// } + /// } // namespace mlir + /// + /// \{ + struct TranslateToMLIRRegistration { +- TranslateToMLIRRegistration(llvm::StringRef name, ++ TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description, + const TranslateSourceMgrToMLIRFunction &function); +- TranslateToMLIRRegistration(llvm::StringRef name, ++ TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description, + const TranslateStringRefToMLIRFunction &function); + }; + + struct TranslateFromMLIRRegistration { + TranslateFromMLIRRegistration( +- llvm::StringRef name, const TranslateFromMLIRFunction &function, ++ llvm::StringRef name, llvm::StringRef description, ++ const TranslateFromMLIRFunction &function, + const std::function &dialectRegistration = + [](DialectRegistry &) {}); + }; + struct TranslateRegistration { +- TranslateRegistration(llvm::StringRef name, ++ TranslateRegistration(llvm::StringRef name, llvm::StringRef description, + const TranslateFunction &function); + }; + /// \} + + /// A command line parser for translation functions. + struct TranslationParser : public llvm::cl::parser { + TranslationParser(llvm::cl::Option &opt); + + void printOptionInfo(const llvm::cl::Option &o, + size_t globalWidth) const override; + }; + + } // namespace mlir + + #endif // MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H +# diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp +# index c4d26ea..c85dd0d 100644 +# --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +# +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp +# @@ -1,54 +1,54 @@ +# //===- TranslateRegistration.cpp - Register translation -------------------===// +# // +# // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# // See https://llvm.org/LICENSE.txt for license information. +# // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# // +# //===----------------------------------------------------------------------===// + +# #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +# #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +# #include "mlir/Dialect/EmitC/IR/EmitC.h" +# #include "mlir/Dialect/Func/IR/FuncOps.h" +# #include "mlir/Dialect/Math/IR/Math.h" +# #include "mlir/Dialect/SCF/IR/SCF.h" +# #include "mlir/IR/BuiltinOps.h" +# #include "mlir/IR/Dialect.h" +# #include "mlir/Target/Cpp/CppEmitter.h" +# #include "mlir/Tools/mlir-translate/Translation.h" +# #include "llvm/Support/CommandLine.h" + +# using namespace mlir; + +# namespace mlir { + +# //===----------------------------------------------------------------------===// +# // Cpp registration +# //===----------------------------------------------------------------------===// + +# void registerToCppTranslation() { +# static llvm::cl::opt declareVariablesAtTop( +# "declare-variables-at-top", +# llvm::cl::desc("Declare variables at top when emitting C/C++"), +# llvm::cl::init(false)); + +# TranslateFromMLIRRegistration reg( +# - "mlir-to-cpp", +# + "mlir-to-cpp", "translate mlir to cpp", +# [](ModuleOp module, raw_ostream &output) { +# return emitc::translateToCpp( +# module, output, +# /*declareVariablesAtTop=*/declareVariablesAtTop); +# }, +# [](DialectRegistry ®istry) { +# // clang-format off +# registry.insert(); +# // clang-format on +# }); +# } + +# } // namespace mlir +# diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +# index 9c47d36..b839bba 100644 +# --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +# +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +# @@ -1,1415 +1,1416 @@ +# //===- ConvertFromLLVMIR.cpp - MLIR to LLVM IR conversion -----------------===// +# // +# // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# // See https://llvm.org/LICENSE.txt for license information. +# // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# // +# //===----------------------------------------------------------------------===// +# // +# // This file implements a translation between LLVM IR and the MLIR LLVM dialect. +# // +# //===----------------------------------------------------------------------===// + +# #include "mlir/Target/LLVMIR/Import.h" + +# #include "mlir/Dialect/DLTI/DLTI.h" +# #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +# #include "mlir/IR/Builders.h" +# #include "mlir/IR/BuiltinOps.h" +# #include "mlir/IR/BuiltinTypes.h" +# #include "mlir/IR/MLIRContext.h" +# #include "mlir/Interfaces/DataLayoutInterfaces.h" +# #include "mlir/Target/LLVMIR/TypeFromLLVM.h" +# #include "mlir/Tools/mlir-translate/Translation.h" + +# #include "llvm/ADT/StringSet.h" +# #include "llvm/ADT/TypeSwitch.h" +# #include "llvm/IR/Attributes.h" +# #include "llvm/IR/Constants.h" +# #include "llvm/IR/DerivedTypes.h" +# #include "llvm/IR/Function.h" +# #include "llvm/IR/InlineAsm.h" +# #include "llvm/IR/Instructions.h" +# #include "llvm/IR/Intrinsics.h" +# #include "llvm/IR/Type.h" +# #include "llvm/IRReader/IRReader.h" +# #include "llvm/Support/Error.h" +# #include "llvm/Support/SourceMgr.h" + +# using namespace mlir; +# using namespace mlir::LLVM; + +# #include "mlir/Dialect/LLVMIR/LLVMConversionEnumsFromLLVM.inc" + +# // Utility to print an LLVM value as a string for passing to emitError(). +# // FIXME: Diagnostic should be able to natively handle types that have +# // operator << (raw_ostream&) defined. +# static std::string diag(llvm::Value &v) { +# std::string s; +# llvm::raw_string_ostream os(s); +# os << v; +# return os.str(); +# } + +# /// Creates an attribute containing ABI and preferred alignment numbers parsed +# /// a string. The string may be either "abi:preferred" or just "abi". In the +# /// latter case, the prefrred alignment is considered equal to ABI alignment. +# static DenseIntElementsAttr parseDataLayoutAlignment(MLIRContext &ctx, +# StringRef spec) { +# auto i32 = IntegerType::get(&ctx, 32); + +# StringRef abiString, preferredString; +# std::tie(abiString, preferredString) = spec.split(':'); +# int abi, preferred; +# if (abiString.getAsInteger(/*Radix=*/10, abi)) +# return nullptr; + +# if (preferredString.empty()) +# preferred = abi; +# else if (preferredString.getAsInteger(/*Radix=*/10, preferred)) +# return nullptr; + +# return DenseIntElementsAttr::get(VectorType::get({2}, i32), {abi, preferred}); +# } + +# /// Returns a supported MLIR floating point type of the given bit width or null +# /// if the bit width is not supported. +# static FloatType getDLFloatType(MLIRContext &ctx, int32_t bitwidth) { +# switch (bitwidth) { +# case 16: +# return FloatType::getF16(&ctx); +# case 32: +# return FloatType::getF32(&ctx); +# case 64: +# return FloatType::getF64(&ctx); +# case 80: +# return FloatType::getF80(&ctx); +# case 128: +# return FloatType::getF128(&ctx); +# default: +# return nullptr; +# } +# } + +# DataLayoutSpecInterface +# mlir::translateDataLayout(const llvm::DataLayout &dataLayout, +# MLIRContext *context) { +# assert(context && "expected MLIR context"); +# std::string layoutstr = dataLayout.getStringRepresentation(); + +# // Remaining unhandled default layout defaults +# // e (little endian if not set) +# // p[n]:64:64:64 (non zero address spaces have 64-bit properties) +# std::string append = +# "p:64:64:64-S0-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:32:64-f16:16:16-f64:" +# "64:64-f128:128:128-v64:64:64-v128:128:128-a:0:64"; +# if (layoutstr.empty()) +# layoutstr = append; +# else +# layoutstr = layoutstr + "-" + append; + +# StringRef layout(layoutstr); + +# SmallVector entries; +# StringSet<> seen; +# while (!layout.empty()) { +# // Split at '-'. +# std::pair split = layout.split('-'); +# StringRef current; +# std::tie(current, layout) = split; + +# // Split at ':'. +# StringRef kind, spec; +# std::tie(kind, spec) = current.split(':'); +# if (seen.contains(kind)) +# continue; +# seen.insert(kind); + +# char symbol = kind.front(); +# StringRef parameter = kind.substr(1); + +# if (symbol == 'i' || symbol == 'f') { +# unsigned bitwidth; +# if (parameter.getAsInteger(/*Radix=*/10, bitwidth)) +# return nullptr; +# DenseIntElementsAttr params = parseDataLayoutAlignment(*context, spec); +# if (!params) +# return nullptr; +# auto entry = DataLayoutEntryAttr::get( +# symbol == 'i' ? static_cast(IntegerType::get(context, bitwidth)) +# : getDLFloatType(*context, bitwidth), +# params); +# entries.emplace_back(entry); +# } else if (symbol == 'e' || symbol == 'E') { +# auto value = StringAttr::get( +# context, symbol == 'e' ? DLTIDialect::kDataLayoutEndiannessLittle +# : DLTIDialect::kDataLayoutEndiannessBig); +# auto entry = DataLayoutEntryAttr::get( +# StringAttr::get(context, DLTIDialect::kDataLayoutEndiannessKey), +# value); +# entries.emplace_back(entry); +# } +# } + +# return DataLayoutSpecAttr::get(context, entries); +# } + +# // Handles importing globals and functions from an LLVM module. +# namespace { +# class Importer { +# public: +# Importer(MLIRContext *context, ModuleOp module) +# : b(context), context(context), module(module), +# unknownLoc(FileLineColLoc::get(context, "imported-bitcode", 0, 0)), +# typeTranslator(*context) { +# b.setInsertionPointToStart(module.getBody()); +# } + +# /// Imports `f` into the current module. +# LogicalResult processFunction(llvm::Function *f); + +# /// Converts function attributes of LLVM Function \p f +# /// into LLVM dialect attributes of LLVMFuncOp \p funcOp. +# void processFunctionAttributes(llvm::Function *f, LLVMFuncOp funcOp); + +# /// Imports GV as a GlobalOp, creating it if it doesn't exist. +# GlobalOp processGlobal(llvm::GlobalVariable *gv); + +# private: +# /// Returns personality of `f` as a FlatSymbolRefAttr. +# FlatSymbolRefAttr getPersonalityAsAttr(llvm::Function *f); +# /// Imports `bb` into `block`, which must be initially empty. +# LogicalResult processBasicBlock(llvm::BasicBlock *bb, Block *block); +# /// Imports `inst` and populates instMap[inst] with the imported Value. +# LogicalResult processInstruction(llvm::Instruction *inst); +# /// Creates an LLVM-compatible MLIR type for `type`. +# Type processType(llvm::Type *type); +# /// `value` is an SSA-use. Return the remapped version of `value` or a +# /// placeholder that will be remapped later if this is an instruction that +# /// has not yet been visited. +# Value processValue(llvm::Value *value); +# /// Create the most accurate Location possible using a llvm::DebugLoc and +# /// possibly an llvm::Instruction to narrow the Location if debug information +# /// is unavailable. +# Location processDebugLoc(const llvm::DebugLoc &loc, +# llvm::Instruction *inst = nullptr); +# /// `br` branches to `target`. Append the block arguments to attach to the +# /// generated branch op to `blockArguments`. These should be in the same order +# /// as the PHIs in `target`. +# LogicalResult processBranchArgs(llvm::Instruction *br, +# llvm::BasicBlock *target, +# SmallVectorImpl &blockArguments); +# /// Returns the builtin type equivalent to be used in attributes for the given +# /// LLVM IR dialect type. +# Type getStdTypeForAttr(Type type); +# /// Return `value` as an attribute to attach to a GlobalOp. +# Attribute getConstantAsAttr(llvm::Constant *value); +# /// Return `c` as an MLIR Value. This could either be a ConstantOp, or +# /// an expanded sequence of ops in the current function's entry block (for +# /// ConstantExprs or ConstantGEPs). +# Value processConstant(llvm::Constant *c); + +# /// The current builder, pointing at where the next Instruction should be +# /// generated. +# OpBuilder b; +# /// The current context. +# MLIRContext *context; +# /// The current module being created. +# ModuleOp module; +# /// The entry block of the current function being processed. +# Block *currentEntryBlock = nullptr; + +# /// Globals are inserted before the first function, if any. +# Block::iterator getGlobalInsertPt() { +# auto it = module.getBody()->begin(); +# auto endIt = module.getBody()->end(); +# while (it != endIt && !isa(it)) +# ++it; +# return it; +# } + +# /// Functions are always inserted before the module terminator. +# Block::iterator getFuncInsertPt() { +# return std::prev(module.getBody()->end()); +# } + +# /// Remapped blocks, for the current function. +# DenseMap blocks; +# /// Remapped values. These are function-local. +# DenseMap instMap; +# /// Instructions that had not been defined when first encountered as a use. +# /// Maps to the dummy Operation that was created in processValue(). +# DenseMap unknownInstMap; +# /// Uniquing map of GlobalVariables. +# DenseMap globals; +# /// Cached FileLineColLoc::get("imported-bitcode", 0, 0). +# Location unknownLoc; +# /// The stateful type translator (contains named structs). +# LLVM::TypeFromLLVMIRTranslator typeTranslator; +# }; +# } // namespace + +# Location Importer::processDebugLoc(const llvm::DebugLoc &loc, +# llvm::Instruction *inst) { +# if (!loc) +# return unknownLoc; + +# // FIXME: Obtain the filename from DILocationInfo. +# return FileLineColLoc::get(context, "imported-bitcode", loc.getLine(), +# loc.getCol()); +# } + +# Type Importer::processType(llvm::Type *type) { +# if (Type result = typeTranslator.translateType(type)) +# return result; + +# // FIXME: Diagnostic should be able to natively handle types that have +# // operator<<(raw_ostream&) defined. +# std::string s; +# llvm::raw_string_ostream os(s); +# os << *type; +# emitError(unknownLoc) << "unhandled type: " << os.str(); +# return nullptr; +# } + +# // We only need integers, floats, doubles, and vectors and tensors thereof for +# // attributes. Scalar and vector types are converted to the standard +# // equivalents. Array types are converted to ranked tensors; nested array types +# // are converted to multi-dimensional tensors or vectors, depending on the +# // innermost type being a scalar or a vector. +# Type Importer::getStdTypeForAttr(Type type) { +# if (!type) +# return nullptr; + +# if (type.isa()) +# return type; + +# // LLVM vectors can only contain scalars. +# if (LLVM::isCompatibleVectorType(type)) { +# auto numElements = LLVM::getVectorNumElements(type); +# if (numElements.isScalable()) { +# emitError(unknownLoc) << "scalable vectors not supported"; +# return nullptr; +# } +# Type elementType = getStdTypeForAttr(LLVM::getVectorElementType(type)); +# if (!elementType) +# return nullptr; +# return VectorType::get(numElements.getKnownMinValue(), elementType); +# } + +# // LLVM arrays can contain other arrays or vectors. +# if (auto arrayType = type.dyn_cast()) { +# // Recover the nested array shape. +# SmallVector shape; +# shape.push_back(arrayType.getNumElements()); +# while (arrayType.getElementType().isa()) { +# arrayType = arrayType.getElementType().cast(); +# shape.push_back(arrayType.getNumElements()); +# } + +# // If the innermost type is a vector, use the multi-dimensional vector as +# // attribute type. +# if (LLVM::isCompatibleVectorType(arrayType.getElementType())) { +# auto numElements = LLVM::getVectorNumElements(arrayType.getElementType()); +# if (numElements.isScalable()) { +# emitError(unknownLoc) << "scalable vectors not supported"; +# return nullptr; +# } +# shape.push_back(numElements.getKnownMinValue()); + +# Type elementType = getStdTypeForAttr( +# LLVM::getVectorElementType(arrayType.getElementType())); +# if (!elementType) +# return nullptr; +# return VectorType::get(shape, elementType); +# } + +# // Otherwise use a tensor. +# Type elementType = getStdTypeForAttr(arrayType.getElementType()); +# if (!elementType) +# return nullptr; +# return RankedTensorType::get(shape, elementType); +# } + +# return nullptr; +# } + +# // Get the given constant as an attribute. Not all constants can be represented +# // as attributes. +# Attribute Importer::getConstantAsAttr(llvm::Constant *value) { +# if (auto *ci = dyn_cast(value)) +# return b.getIntegerAttr( +# IntegerType::get(context, ci->getType()->getBitWidth()), +# ci->getValue()); +# if (auto *c = dyn_cast(value)) +# if (c->isString()) +# return b.getStringAttr(c->getAsString()); +# if (auto *c = dyn_cast(value)) { +# auto *type = c->getType(); +# FloatType floatTy; +# if (type->isBFloatTy()) +# floatTy = FloatType::getBF16(context); +# else +# floatTy = getDLFloatType(*context, type->getScalarSizeInBits()); +# assert(floatTy && "unsupported floating point type"); +# return b.getFloatAttr(floatTy, c->getValueAPF()); +# } +# if (auto *f = dyn_cast(value)) +# return SymbolRefAttr::get(b.getContext(), f->getName()); + +# // Convert constant data to a dense elements attribute. +# if (auto *cd = dyn_cast(value)) { +# Type type = processType(cd->getElementType()); +# if (!type) +# return nullptr; + +# auto attrType = getStdTypeForAttr(processType(cd->getType())) +# .dyn_cast_or_null(); +# if (!attrType) +# return nullptr; + +# if (type.isa()) { +# SmallVector values; +# values.reserve(cd->getNumElements()); +# for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) +# values.push_back(cd->getElementAsAPInt(i)); +# return DenseElementsAttr::get(attrType, values); +# } + +# if (type.isa()) { +# SmallVector values; +# values.reserve(cd->getNumElements()); +# for (unsigned i = 0, e = cd->getNumElements(); i < e; ++i) +# values.push_back(cd->getElementAsAPFloat(i)); +# return DenseElementsAttr::get(attrType, values); +# } + +# return nullptr; +# } + +# // Unpack constant aggregates to create dense elements attribute whenever +# // possible. Return nullptr (failure) otherwise. +# if (isa(value)) { +# auto outerType = getStdTypeForAttr(processType(value->getType())) +# .dyn_cast_or_null(); +# if (!outerType) +# return nullptr; + +# SmallVector values; +# SmallVector shape; + +# for (unsigned i = 0, e = value->getNumOperands(); i < e; ++i) { +# auto nested = getConstantAsAttr(value->getAggregateElement(i)) +# .dyn_cast_or_null(); +# if (!nested) +# return nullptr; + +# values.append(nested.value_begin(), +# nested.value_end()); +# } + +# return DenseElementsAttr::get(outerType, values); +# } + +# return nullptr; +# } + +# GlobalOp Importer::processGlobal(llvm::GlobalVariable *gv) { +# auto it = globals.find(gv); +# if (it != globals.end()) +# return it->second; + +# OpBuilder b(module.getBody(), getGlobalInsertPt()); +# Attribute valueAttr; +# if (gv->hasInitializer()) +# valueAttr = getConstantAsAttr(gv->getInitializer()); +# Type type = processType(gv->getValueType()); +# if (!type) +# return nullptr; + +# uint64_t alignment = 0; +# llvm::MaybeAlign maybeAlign = gv->getAlign(); +# if (maybeAlign.has_value()) { +# llvm::Align align = maybeAlign.value(); +# alignment = align.value(); +# } + +# GlobalOp op = b.create( +# UnknownLoc::get(context), type, gv->isConstant(), +# convertLinkageFromLLVM(gv->getLinkage()), gv->getName(), valueAttr, +# alignment, /*addr_space=*/gv->getAddressSpace(), +# /*dso_local=*/gv->isDSOLocal(), /*thread_local=*/gv->isThreadLocal()); + +# if (gv->hasInitializer() && !valueAttr) { +# Region &r = op.getInitializerRegion(); +# currentEntryBlock = b.createBlock(&r); +# b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); +# Value v = processConstant(gv->getInitializer()); +# if (!v) +# return nullptr; +# b.create(op.getLoc(), ArrayRef({v})); +# } +# if (gv->hasAtLeastLocalUnnamedAddr()) +# op.setUnnamedAddrAttr(UnnamedAddrAttr::get( +# context, convertUnnamedAddrFromLLVM(gv->getUnnamedAddr()))); +# if (gv->hasSection()) +# op.setSectionAttr(b.getStringAttr(gv->getSection())); + +# return globals[gv] = op; +# } + +# Value Importer::processConstant(llvm::Constant *c) { +# OpBuilder bEntry(currentEntryBlock, currentEntryBlock->begin()); +# if (Attribute attr = getConstantAsAttr(c)) { +# // These constants can be represented as attributes. +# OpBuilder b(currentEntryBlock, currentEntryBlock->begin()); +# Type type = processType(c->getType()); +# if (!type) +# return nullptr; +# if (auto symbolRef = attr.dyn_cast()) +# return bEntry.create(unknownLoc, type, symbolRef.getValue()); +# return bEntry.create(unknownLoc, type, attr); +# } +# if (auto *cn = dyn_cast(c)) { +# Type type = processType(cn->getType()); +# if (!type) +# return nullptr; +# return bEntry.create(unknownLoc, type); +# } +# if (auto *gv = dyn_cast(c)) +# return bEntry.create(UnknownLoc::get(context), +# processGlobal(gv)); + +# if (auto *ce = dyn_cast(c)) { +# llvm::Instruction *i = ce->getAsInstruction(); +# OpBuilder::InsertionGuard guard(b); +# b.setInsertionPoint(currentEntryBlock, currentEntryBlock->begin()); +# if (failed(processInstruction(i))) +# return nullptr; +# assert(instMap.count(i)); + +# // If we don't remove entry of `i` here, it's totally possible that the +# // next time llvm::ConstantExpr::getAsInstruction is called again, which +# // always allocates a new Instruction, memory address of the newly +# // created Instruction might be the same as `i`. Making processInstruction +# // falsely believe that the new Instruction has been processed before +# // and raised an assertion error. +# Value value = instMap[i]; +# instMap.erase(i); +# // Remove this zombie LLVM instruction now, leaving us only with the MLIR +# // op. +# i->deleteValue(); +# return value; +# } +# if (auto *ue = dyn_cast(c)) { +# Type type = processType(ue->getType()); +# if (!type) +# return nullptr; +# return bEntry.create(UnknownLoc::get(context), type); +# } + +# if (isa(c) || isa(c)) { +# unsigned numElements = c->getNumOperands(); +# std::function getElement = +# [&](unsigned index) -> llvm::Constant * { +# return c->getAggregateElement(index); +# }; +# // llvm::ConstantAggregateZero doesn't take any operand +# // so its getNumOperands is always zero. +# if (auto *caz = dyn_cast(c)) { +# numElements = caz->getElementCount().getFixedValue(); +# // We want to capture the pointer rather than reference +# // to the pointer since the latter will become dangling upon +# // exiting the scope. +# getElement = [=](unsigned index) -> llvm::Constant * { +# return caz->getElementValue(index); +# }; +# } + +# // Generate a llvm.undef as the root value first. +# Type rootType = processType(c->getType()); +# if (!rootType) +# return nullptr; +# bool useInsertValue = rootType.isa(); +# assert((useInsertValue || LLVM::isCompatibleVectorType(rootType)) && +# "unrecognized aggregate type"); +# Value root = bEntry.create(unknownLoc, rootType); +# for (unsigned i = 0; i < numElements; ++i) { +# llvm::Constant *element = getElement(i); +# Value elementValue = processConstant(element); +# if (!elementValue) +# return nullptr; +# if (useInsertValue) { +# root = bEntry.create(UnknownLoc::get(context), root, +# elementValue, i); +# } else { +# Attribute indexAttr = bEntry.getI32IntegerAttr(static_cast(i)); +# Value indexValue = bEntry.create( +# unknownLoc, bEntry.getI32Type(), indexAttr); +# if (!indexValue) +# return nullptr; +# root = bEntry.create( +# UnknownLoc::get(context), rootType, root, elementValue, indexValue); +# } +# } +# return root; +# } + +# emitError(unknownLoc) << "unhandled constant: " << diag(*c); +# return nullptr; +# } + +# Value Importer::processValue(llvm::Value *value) { +# auto it = instMap.find(value); +# if (it != instMap.end()) +# return it->second; + +# // We don't expect to see instructions in dominator order. If we haven't seen +# // this instruction yet, create an unknown op and remap it later. +# if (isa(value)) { +# Type type = processType(value->getType()); +# if (!type) +# return nullptr; +# unknownInstMap[value] = +# b.create(UnknownLoc::get(context), b.getStringAttr("llvm.unknown"), +# /*operands=*/{}, type); +# return unknownInstMap[value]->getResult(0); +# } + +# if (auto *c = dyn_cast(value)) +# return processConstant(c); + +# emitError(unknownLoc) << "unhandled value: " << diag(*value); +# return nullptr; +# } + +# /// Return the MLIR OperationName for the given LLVM opcode. +# static StringRef lookupOperationNameFromOpcode(unsigned opcode) { +# // Maps from LLVM opcode to MLIR OperationName. This is deliberately ordered +# // as in llvm/IR/Instructions.def to aid comprehension and spot missing +# // instructions. +# #define INST(llvm_n, mlir_n) \ +# { llvm::Instruction::llvm_n, LLVM::mlir_n##Op::getOperationName() } +# static const DenseMap opcMap = { +# // clang-format off +# INST(Ret, Return), +# // Br is handled specially. +# // Switch is handled specially. +# // FIXME: indirectbr +# // Invoke is handled specially. +# INST(Resume, Resume), +# INST(Unreachable, Unreachable), +# // FIXME: cleanupret +# // FIXME: catchret +# // FIXME: catchswitch +# // FIXME: callbr +# INST(FNeg, FNeg), +# INST(Add, Add), +# INST(FAdd, FAdd), +# INST(Sub, Sub), +# INST(FSub, FSub), +# INST(Mul, Mul), +# INST(FMul, FMul), +# INST(UDiv, UDiv), +# INST(SDiv, SDiv), +# INST(FDiv, FDiv), +# INST(URem, URem), +# INST(SRem, SRem), +# INST(FRem, FRem), +# INST(Shl, Shl), +# INST(LShr, LShr), +# INST(AShr, AShr), +# INST(And, And), +# INST(Or, Or), +# INST(Xor, XOr), +# INST(ExtractElement, ExtractElement), +# INST(InsertElement, InsertElement), +# // ShuffleVector is handled specially. +# // ExtractValue is handled specially. +# // InsertValue is handled specially. +# INST(Alloca, Alloca), +# INST(Load, Load), +# INST(Store, Store), +# INST(Fence, Fence), +# // AtomicCmpXchg is handled specially. +# // AtomicRMW is handled specially. +# // Getelementptr is handled specially. +# INST(Trunc, Trunc), +# INST(ZExt, ZExt), +# INST(SExt, SExt), +# INST(FPToUI, FPToUI), +# INST(FPToSI, FPToSI), +# INST(UIToFP, UIToFP), +# INST(SIToFP, SIToFP), +# INST(FPTrunc, FPTrunc), +# INST(FPExt, FPExt), +# INST(PtrToInt, PtrToInt), +# INST(IntToPtr, IntToPtr), +# INST(BitCast, Bitcast), +# INST(AddrSpaceCast, AddrSpaceCast), +# // ICmp is handled specially. +# // FCmp is handled specially. +# // PHI is handled specially. +# INST(Select, Select), +# INST(Freeze, Freeze), +# INST(Call, Call), +# // FIXME: vaarg +# // FIXME: landingpad +# // FIXME: catchpad +# // FIXME: cleanuppad +# // clang-format on +# }; +# #undef INST + +# return opcMap.lookup(opcode); +# } + +# /// Return the MLIR OperationName for the given LLVM intrinsic ID. +# static StringRef lookupOperationNameFromIntrinsicID(unsigned id) { +# // Maps from LLVM intrinsic ID to MLIR OperationName. +# static const DenseMap intrMap = { +# #include "mlir/Dialect/LLVMIR/LLVMIntrinsicToLLVMIROpPairs.inc" +# }; +# return intrMap.lookup(id); +# } + +# static ICmpPredicate getICmpPredicate(llvm::CmpInst::Predicate p) { +# switch (p) { +# default: +# llvm_unreachable("incorrect comparison predicate"); +# case llvm::CmpInst::Predicate::ICMP_EQ: +# return LLVM::ICmpPredicate::eq; +# case llvm::CmpInst::Predicate::ICMP_NE: +# return LLVM::ICmpPredicate::ne; +# case llvm::CmpInst::Predicate::ICMP_SLT: +# return LLVM::ICmpPredicate::slt; +# case llvm::CmpInst::Predicate::ICMP_SLE: +# return LLVM::ICmpPredicate::sle; +# case llvm::CmpInst::Predicate::ICMP_SGT: +# return LLVM::ICmpPredicate::sgt; +# case llvm::CmpInst::Predicate::ICMP_SGE: +# return LLVM::ICmpPredicate::sge; +# case llvm::CmpInst::Predicate::ICMP_ULT: +# return LLVM::ICmpPredicate::ult; +# case llvm::CmpInst::Predicate::ICMP_ULE: +# return LLVM::ICmpPredicate::ule; +# case llvm::CmpInst::Predicate::ICMP_UGT: +# return LLVM::ICmpPredicate::ugt; +# case llvm::CmpInst::Predicate::ICMP_UGE: +# return LLVM::ICmpPredicate::uge; +# } +# llvm_unreachable("incorrect integer comparison predicate"); +# } + +# static FCmpPredicate getFCmpPredicate(llvm::CmpInst::Predicate p) { +# switch (p) { +# default: +# llvm_unreachable("incorrect comparison predicate"); +# case llvm::CmpInst::Predicate::FCMP_FALSE: +# return LLVM::FCmpPredicate::_false; +# case llvm::CmpInst::Predicate::FCMP_TRUE: +# return LLVM::FCmpPredicate::_true; +# case llvm::CmpInst::Predicate::FCMP_OEQ: +# return LLVM::FCmpPredicate::oeq; +# case llvm::CmpInst::Predicate::FCMP_ONE: +# return LLVM::FCmpPredicate::one; +# case llvm::CmpInst::Predicate::FCMP_OLT: +# return LLVM::FCmpPredicate::olt; +# case llvm::CmpInst::Predicate::FCMP_OLE: +# return LLVM::FCmpPredicate::ole; +# case llvm::CmpInst::Predicate::FCMP_OGT: +# return LLVM::FCmpPredicate::ogt; +# case llvm::CmpInst::Predicate::FCMP_OGE: +# return LLVM::FCmpPredicate::oge; +# case llvm::CmpInst::Predicate::FCMP_ORD: +# return LLVM::FCmpPredicate::ord; +# case llvm::CmpInst::Predicate::FCMP_ULT: +# return LLVM::FCmpPredicate::ult; +# case llvm::CmpInst::Predicate::FCMP_ULE: +# return LLVM::FCmpPredicate::ule; +# case llvm::CmpInst::Predicate::FCMP_UGT: +# return LLVM::FCmpPredicate::ugt; +# case llvm::CmpInst::Predicate::FCMP_UGE: +# return LLVM::FCmpPredicate::uge; +# case llvm::CmpInst::Predicate::FCMP_UNO: +# return LLVM::FCmpPredicate::uno; +# case llvm::CmpInst::Predicate::FCMP_UEQ: +# return LLVM::FCmpPredicate::ueq; +# case llvm::CmpInst::Predicate::FCMP_UNE: +# return LLVM::FCmpPredicate::une; +# } +# llvm_unreachable("incorrect floating point comparison predicate"); +# } + +# static AtomicOrdering getLLVMAtomicOrdering(llvm::AtomicOrdering ordering) { +# switch (ordering) { +# case llvm::AtomicOrdering::NotAtomic: +# return LLVM::AtomicOrdering::not_atomic; +# case llvm::AtomicOrdering::Unordered: +# return LLVM::AtomicOrdering::unordered; +# case llvm::AtomicOrdering::Monotonic: +# return LLVM::AtomicOrdering::monotonic; +# case llvm::AtomicOrdering::Acquire: +# return LLVM::AtomicOrdering::acquire; +# case llvm::AtomicOrdering::Release: +# return LLVM::AtomicOrdering::release; +# case llvm::AtomicOrdering::AcquireRelease: +# return LLVM::AtomicOrdering::acq_rel; +# case llvm::AtomicOrdering::SequentiallyConsistent: +# return LLVM::AtomicOrdering::seq_cst; +# } +# llvm_unreachable("incorrect atomic ordering"); +# } + +# static AtomicBinOp getLLVMAtomicBinOp(llvm::AtomicRMWInst::BinOp binOp) { +# switch (binOp) { +# case llvm::AtomicRMWInst::Xchg: +# return LLVM::AtomicBinOp::xchg; +# case llvm::AtomicRMWInst::Add: +# return LLVM::AtomicBinOp::add; +# case llvm::AtomicRMWInst::Sub: +# return LLVM::AtomicBinOp::sub; +# case llvm::AtomicRMWInst::And: +# return LLVM::AtomicBinOp::_and; +# case llvm::AtomicRMWInst::Nand: +# return LLVM::AtomicBinOp::nand; +# case llvm::AtomicRMWInst::Or: +# return LLVM::AtomicBinOp::_or; +# case llvm::AtomicRMWInst::Xor: +# return LLVM::AtomicBinOp::_xor; +# case llvm::AtomicRMWInst::Max: +# return LLVM::AtomicBinOp::max; +# case llvm::AtomicRMWInst::Min: +# return LLVM::AtomicBinOp::min; +# case llvm::AtomicRMWInst::UMax: +# return LLVM::AtomicBinOp::umax; +# case llvm::AtomicRMWInst::UMin: +# return LLVM::AtomicBinOp::umin; +# case llvm::AtomicRMWInst::FAdd: +# return LLVM::AtomicBinOp::fadd; +# case llvm::AtomicRMWInst::FSub: +# return LLVM::AtomicBinOp::fsub; +# default: +# llvm_unreachable("unsupported atomic binary operation"); +# } +# } + +# // `br` branches to `target`. Return the branch arguments to `br`, in the +# // same order of the PHIs in `target`. +# LogicalResult +# Importer::processBranchArgs(llvm::Instruction *br, llvm::BasicBlock *target, +# SmallVectorImpl &blockArguments) { +# for (auto inst = target->begin(); isa(inst); ++inst) { +# auto *pn = cast(&*inst); +# Value value = processValue(pn->getIncomingValueForBlock(br->getParent())); +# if (!value) +# return failure(); +# blockArguments.push_back(value); +# } +# return success(); +# } + +# LogicalResult Importer::processInstruction(llvm::Instruction *inst) { +# // FIXME: Support uses of SubtargetData. Currently inbounds GEPs, fast-math +# // flags and call / operand attributes are not supported. +# Location loc = processDebugLoc(inst->getDebugLoc(), inst); +# assert(!instMap.count(inst) && +# "processInstruction must be called only once per instruction!"); +# switch (inst->getOpcode()) { +# default: +# return emitError(loc) << "unknown instruction: " << diag(*inst); +# case llvm::Instruction::Add: +# case llvm::Instruction::FAdd: +# case llvm::Instruction::Sub: +# case llvm::Instruction::FSub: +# case llvm::Instruction::Mul: +# case llvm::Instruction::FMul: +# case llvm::Instruction::UDiv: +# case llvm::Instruction::SDiv: +# case llvm::Instruction::FDiv: +# case llvm::Instruction::URem: +# case llvm::Instruction::SRem: +# case llvm::Instruction::FRem: +# case llvm::Instruction::Shl: +# case llvm::Instruction::LShr: +# case llvm::Instruction::AShr: +# case llvm::Instruction::And: +# case llvm::Instruction::Or: +# case llvm::Instruction::Xor: +# case llvm::Instruction::Load: +# case llvm::Instruction::Store: +# case llvm::Instruction::Ret: +# case llvm::Instruction::Resume: +# case llvm::Instruction::Trunc: +# case llvm::Instruction::ZExt: +# case llvm::Instruction::SExt: +# case llvm::Instruction::FPToUI: +# case llvm::Instruction::FPToSI: +# case llvm::Instruction::UIToFP: +# case llvm::Instruction::SIToFP: +# case llvm::Instruction::FPTrunc: +# case llvm::Instruction::FPExt: +# case llvm::Instruction::PtrToInt: +# case llvm::Instruction::IntToPtr: +# case llvm::Instruction::AddrSpaceCast: +# case llvm::Instruction::Freeze: +# case llvm::Instruction::BitCast: +# case llvm::Instruction::ExtractElement: +# case llvm::Instruction::InsertElement: +# case llvm::Instruction::Select: +# case llvm::Instruction::FNeg: +# case llvm::Instruction::Unreachable: { +# OperationState state(loc, lookupOperationNameFromOpcode(inst->getOpcode())); +# SmallVector ops; +# ops.reserve(inst->getNumOperands()); +# for (auto *op : inst->operand_values()) { +# Value value = processValue(op); +# if (!value) +# return failure(); +# ops.push_back(value); +# } +# state.addOperands(ops); +# if (!inst->getType()->isVoidTy()) { +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); +# state.addTypes(type); +# } +# Operation *op = b.create(state); +# if (!inst->getType()->isVoidTy()) +# instMap[inst] = op->getResult(0); +# return success(); +# } +# case llvm::Instruction::Alloca: { +# Value size = processValue(inst->getOperand(0)); +# if (!size) +# return failure(); + +# auto *allocaInst = cast(inst); +# instMap[inst] = +# b.create(loc, processType(inst->getType()), +# processType(allocaInst->getAllocatedType()), size, +# allocaInst->getAlign().value()); +# return success(); +# } +# case llvm::Instruction::ICmp: { +# Value lhs = processValue(inst->getOperand(0)); +# Value rhs = processValue(inst->getOperand(1)); +# if (!lhs || !rhs) +# return failure(); +# instMap[inst] = b.create( +# loc, getICmpPredicate(cast(inst)->getPredicate()), lhs, +# rhs); +# return success(); +# } +# case llvm::Instruction::FCmp: { +# Value lhs = processValue(inst->getOperand(0)); +# Value rhs = processValue(inst->getOperand(1)); +# if (!lhs || !rhs) +# return failure(); + +# if (lhs.getType() != rhs.getType()) +# return failure(); + +# Type boolType = b.getI1Type(); +# Type resType = boolType; +# if (LLVM::isCompatibleVectorType(lhs.getType())) { +# unsigned numElements = +# LLVM::getVectorNumElements(lhs.getType()).getFixedValue(); +# resType = VectorType::get({numElements}, boolType); +# } + +# instMap[inst] = b.create( +# loc, resType, +# getFCmpPredicate(cast(inst)->getPredicate()), lhs, rhs); +# return success(); +# } +# case llvm::Instruction::Br: { +# auto *brInst = cast(inst); +# OperationState state(loc, +# brInst->isConditional() ? "llvm.cond_br" : "llvm.br"); +# if (brInst->isConditional()) { +# Value condition = processValue(brInst->getCondition()); +# if (!condition) +# return failure(); +# state.addOperands(condition); +# } + +# std::array operandSegmentSizes = {1, 0, 0}; +# for (int i : llvm::seq(0, brInst->getNumSuccessors())) { +# auto *succ = brInst->getSuccessor(i); +# SmallVector blockArguments; +# if (failed(processBranchArgs(brInst, succ, blockArguments))) +# return failure(); +# state.addSuccessors(blocks[succ]); +# state.addOperands(blockArguments); +# operandSegmentSizes[i + 1] = blockArguments.size(); +# } + +# if (brInst->isConditional()) { +# state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(), +# b.getDenseI32ArrayAttr(operandSegmentSizes)); +# } + +# b.create(state); +# return success(); +# } +# case llvm::Instruction::Switch: { +# auto *swInst = cast(inst); +# // Process the condition value. +# Value condition = processValue(swInst->getCondition()); +# if (!condition) +# return failure(); + +# SmallVector defaultBlockArgs; +# // Process the default case. +# llvm::BasicBlock *defaultBB = swInst->getDefaultDest(); +# if (failed(processBranchArgs(swInst, defaultBB, defaultBlockArgs))) +# return failure(); + +# // Process the cases. +# unsigned numCases = swInst->getNumCases(); +# SmallVector> caseOperands(numCases); +# SmallVector caseOperandRefs(numCases); +# SmallVector caseValues(numCases); +# SmallVector caseBlocks(numCases); +# for (const auto &en : llvm::enumerate(swInst->cases())) { +# const llvm::SwitchInst::CaseHandle &caseHandle = en.value(); +# unsigned i = en.index(); +# llvm::BasicBlock *succBB = caseHandle.getCaseSuccessor(); +# if (failed(processBranchArgs(swInst, succBB, caseOperands[i]))) +# return failure(); +# caseOperandRefs[i] = caseOperands[i]; +# caseValues[i] = caseHandle.getCaseValue()->getSExtValue(); +# caseBlocks[i] = blocks[succBB]; +# } + +# b.create(loc, condition, blocks[defaultBB], defaultBlockArgs, +# caseValues, caseBlocks, caseOperandRefs); +# return success(); +# } +# case llvm::Instruction::PHI: { +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); +# instMap[inst] = b.getInsertionBlock()->addArgument( +# type, processDebugLoc(inst->getDebugLoc(), inst)); +# return success(); +# } +# case llvm::Instruction::Call: { +# llvm::CallInst *ci = cast(inst); +# SmallVector ops; +# ops.reserve(inst->getNumOperands()); +# for (auto &op : ci->args()) { +# Value arg = processValue(op.get()); +# if (!arg) +# return failure(); +# ops.push_back(arg); +# } + +# SmallVector tys; +# if (!ci->getType()->isVoidTy()) { +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); +# tys.push_back(type); +# } +# Operation *op; +# if (llvm::Function *callee = ci->getCalledFunction()) { +# // For all intrinsics, try to generate to the corresponding op. +# if (callee->isIntrinsic()) { +# auto id = callee->getIntrinsicID(); +# StringRef opName = lookupOperationNameFromIntrinsicID(id); +# if (!opName.empty()) { +# OperationState state(loc, opName); +# state.addOperands(ops); +# state.addTypes(tys); +# Operation *op = b.create(state); +# if (!inst->getType()->isVoidTy()) +# instMap[inst] = op->getResult(0); +# return success(); +# } +# } +# op = b.create( +# loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops); +# } else { +# Value calledValue = processValue(ci->getCalledOperand()); +# if (!calledValue) +# return failure(); +# ops.insert(ops.begin(), calledValue); +# op = b.create(loc, tys, ops); +# } +# if (!ci->getType()->isVoidTy()) +# instMap[inst] = op->getResult(0); +# return success(); +# } +# case llvm::Instruction::LandingPad: { +# llvm::LandingPadInst *lpi = cast(inst); +# SmallVector ops; + +# for (unsigned i = 0, ie = lpi->getNumClauses(); i < ie; i++) +# ops.push_back(processConstant(lpi->getClause(i))); + +# Type ty = processType(lpi->getType()); +# if (!ty) +# return failure(); + +# instMap[inst] = b.create(loc, ty, lpi->isCleanup(), ops); +# return success(); +# } +# case llvm::Instruction::Invoke: { +# llvm::InvokeInst *ii = cast(inst); + +# SmallVector tys; +# if (!ii->getType()->isVoidTy()) +# tys.push_back(processType(inst->getType())); + +# SmallVector ops; +# ops.reserve(inst->getNumOperands() + 1); +# for (auto &op : ii->args()) +# ops.push_back(processValue(op.get())); + +# SmallVector normalArgs, unwindArgs; +# (void)processBranchArgs(ii, ii->getNormalDest(), normalArgs); +# (void)processBranchArgs(ii, ii->getUnwindDest(), unwindArgs); + +# Operation *op; +# if (llvm::Function *callee = ii->getCalledFunction()) { +# op = b.create( +# loc, tys, SymbolRefAttr::get(b.getContext(), callee->getName()), ops, +# blocks[ii->getNormalDest()], normalArgs, blocks[ii->getUnwindDest()], +# unwindArgs); +# } else { +# ops.insert(ops.begin(), processValue(ii->getCalledOperand())); +# op = b.create(loc, tys, ops, blocks[ii->getNormalDest()], +# normalArgs, blocks[ii->getUnwindDest()], +# unwindArgs); +# } + +# if (!ii->getType()->isVoidTy()) +# instMap[inst] = op->getResult(0); +# return success(); +# } +# case llvm::Instruction::Fence: { +# StringRef syncscope; +# SmallVector ssNs; +# llvm::LLVMContext &llvmContext = inst->getContext(); +# llvm::FenceInst *fence = cast(inst); +# llvmContext.getSyncScopeNames(ssNs); +# int fenceSyncScopeID = fence->getSyncScopeID(); +# for (unsigned i = 0, e = ssNs.size(); i != e; i++) { +# if (fenceSyncScopeID == llvmContext.getOrInsertSyncScopeID(ssNs[i])) { +# syncscope = ssNs[i]; +# break; +# } +# } +# b.create(loc, getLLVMAtomicOrdering(fence->getOrdering()), +# syncscope); +# return success(); +# } +# case llvm::Instruction::AtomicRMW: { +# auto *atomicInst = cast(inst); +# Value ptr = processValue(atomicInst->getPointerOperand()); +# Value val = processValue(atomicInst->getValOperand()); +# if (!ptr || !val) +# return failure(); + +# LLVM::AtomicBinOp binOp = getLLVMAtomicBinOp(atomicInst->getOperation()); +# LLVM::AtomicOrdering ordering = +# getLLVMAtomicOrdering(atomicInst->getOrdering()); + +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); + +# instMap[inst] = b.create(loc, type, binOp, ptr, val, ordering); +# return success(); +# } +# case llvm::Instruction::AtomicCmpXchg: { +# auto *cmpXchgInst = cast(inst); +# Value ptr = processValue(cmpXchgInst->getPointerOperand()); +# Value cmpVal = processValue(cmpXchgInst->getCompareOperand()); +# Value newVal = processValue(cmpXchgInst->getNewValOperand()); +# if (!ptr || !cmpVal || !newVal) +# return failure(); + +# LLVM::AtomicOrdering ordering = +# getLLVMAtomicOrdering(cmpXchgInst->getSuccessOrdering()); +# LLVM::AtomicOrdering failOrdering = +# getLLVMAtomicOrdering(cmpXchgInst->getFailureOrdering()); + +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); + +# instMap[inst] = b.create(loc, type, ptr, cmpVal, newVal, +# ordering, failOrdering); +# return success(); +# } +# case llvm::Instruction::GetElementPtr: { +# // FIXME: Support inbounds GEPs. +# llvm::GetElementPtrInst *gep = cast(inst); +# Value basePtr = processValue(gep->getOperand(0)); +# Type sourceElementType = processType(gep->getSourceElementType()); + +# // Treat every indices as dynamic since GEPOp::build will refine those +# // indices into static attributes later. One small downside of this +# // approach is that many unused `llvm.mlir.constant` would be emitted +# // at first place. +# SmallVector indices; +# for (llvm::Value *operand : llvm::drop_begin(gep->operand_values())) { +# Value val = processValue(operand); +# if (!val) +# return failure(); +# indices.push_back(val); +# } + +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); +# instMap[inst] = +# b.create(loc, type, sourceElementType, basePtr, indices); +# return success(); +# } +# case llvm::Instruction::InsertValue: { +# auto *ivInst = cast(inst); +# Value inserted = processValue(ivInst->getInsertedValueOperand()); +# if (!inserted) +# return failure(); +# Value aggOperand = processValue(ivInst->getAggregateOperand()); +# if (!aggOperand) +# return failure(); + +# SmallVector indices; +# llvm::append_range(indices, ivInst->getIndices()); +# instMap[inst] = b.create(loc, aggOperand, inserted, indices); +# return success(); +# } +# case llvm::Instruction::ExtractValue: { +# auto *evInst = cast(inst); +# Value aggOperand = processValue(evInst->getAggregateOperand()); +# if (!aggOperand) +# return failure(); + +# Type type = processType(inst->getType()); +# if (!type) +# return failure(); + +# SmallVector indices; +# llvm::append_range(indices, evInst->getIndices()); +# instMap[inst] = b.create(loc, aggOperand, indices); +# return success(); +# } +# case llvm::Instruction::ShuffleVector: { +# auto *svInst = cast(inst); +# Value vec1 = processValue(svInst->getOperand(0)); +# if (!vec1) +# return failure(); +# Value vec2 = processValue(svInst->getOperand(1)); +# if (!vec2) +# return failure(); + +# SmallVector mask(svInst->getShuffleMask()); +# instMap[inst] = b.create(loc, vec1, vec2, mask); +# return success(); +# } +# } +# } + +# FlatSymbolRefAttr Importer::getPersonalityAsAttr(llvm::Function *f) { +# if (!f->hasPersonalityFn()) +# return nullptr; + +# llvm::Constant *pf = f->getPersonalityFn(); + +# // If it directly has a name, we can use it. +# if (pf->hasName()) +# return SymbolRefAttr::get(b.getContext(), pf->getName()); + +# // If it doesn't have a name, currently, only function pointers that are +# // bitcast to i8* are parsed. +# if (auto *ce = dyn_cast(pf)) { +# if (ce->getOpcode() == llvm::Instruction::BitCast && +# ce->getType() == llvm::Type::getInt8PtrTy(f->getContext())) { +# if (auto *func = dyn_cast(ce->getOperand(0))) +# return SymbolRefAttr::get(b.getContext(), func->getName()); +# } +# } +# return FlatSymbolRefAttr(); +# } + +# void Importer::processFunctionAttributes(llvm::Function *func, +# LLVMFuncOp funcOp) { +# auto addNamedUnitAttr = [&](StringRef name) { +# return funcOp->setAttr(name, UnitAttr::get(context)); +# }; +# if (func->hasFnAttribute(llvm::Attribute::ReadNone)) +# addNamedUnitAttr(LLVMDialect::getReadnoneAttrName()); +# } + +# LogicalResult Importer::processFunction(llvm::Function *f) { +# blocks.clear(); +# instMap.clear(); +# unknownInstMap.clear(); + +# auto functionType = +# processType(f->getFunctionType()).dyn_cast(); +# if (!functionType) +# return failure(); + +# if (f->isIntrinsic()) { +# StringRef opName = lookupOperationNameFromIntrinsicID(f->getIntrinsicID()); +# // Skip the intrinsic decleration if we could found a corresponding op. +# if (!opName.empty()) +# return success(); +# } + +# bool dsoLocal = f->hasLocalLinkage(); +# CConv cconv = convertCConvFromLLVM(f->getCallingConv()); + +# b.setInsertionPoint(module.getBody(), getFuncInsertPt()); +# LLVMFuncOp fop = b.create( +# UnknownLoc::get(context), f->getName(), functionType, +# convertLinkageFromLLVM(f->getLinkage()), dsoLocal, cconv); + +# for (const auto &arg : llvm::enumerate(functionType.getParams())) { +# llvm::SmallVector argAttrs; +# if (auto *type = f->getParamByValType(arg.index())) { +# auto mlirType = processType(type); +# argAttrs.push_back( +# NamedAttribute(b.getStringAttr(LLVMDialect::getByValAttrName()), +# TypeAttr::get(mlirType))); +# } +# if (auto *type = f->getParamByRefType(arg.index())) { +# auto mlirType = processType(type); +# argAttrs.push_back( +# NamedAttribute(b.getStringAttr(LLVMDialect::getByRefAttrName()), +# TypeAttr::get(mlirType))); +# } +# if (auto *type = f->getParamStructRetType(arg.index())) { +# auto mlirType = processType(type); +# argAttrs.push_back( +# NamedAttribute(b.getStringAttr(LLVMDialect::getStructRetAttrName()), +# TypeAttr::get(mlirType))); +# } +# if (auto *type = f->getParamInAllocaType(arg.index())) { +# auto mlirType = processType(type); +# argAttrs.push_back( +# NamedAttribute(b.getStringAttr(LLVMDialect::getInAllocaAttrName()), +# TypeAttr::get(mlirType))); +# } + +# fop.setArgAttrs(arg.index(), argAttrs); +# } + +# if (FlatSymbolRefAttr personality = getPersonalityAsAttr(f)) +# fop->setAttr(b.getStringAttr("personality"), personality); +# else if (f->hasPersonalityFn()) +# emitWarning(UnknownLoc::get(context), +# "could not deduce personality, skipping it"); + +# if (f->hasGC()) +# fop.setGarbageCollectorAttr(b.getStringAttr(f->getGC())); + +# // Handle Function attributes. +# processFunctionAttributes(f, fop); + +# if (f->isDeclaration()) +# return success(); + +# // Eagerly create all blocks. +# SmallVector blockList; +# for (llvm::BasicBlock &bb : *f) { +# blockList.push_back(b.createBlock(&fop.getBody(), fop.getBody().end())); +# blocks[&bb] = blockList.back(); +# } +# currentEntryBlock = blockList[0]; + +# // Add function arguments to the entry block. +# for (const auto &kv : llvm::enumerate(f->args())) { +# instMap[&kv.value()] = blockList[0]->addArgument( +# functionType.getParamType(kv.index()), fop.getLoc()); +# } + +# for (auto bbs : llvm::zip(*f, blockList)) { +# if (failed(processBasicBlock(&std::get<0>(bbs), std::get<1>(bbs)))) +# return failure(); +# } + +# // Now that all instructions are guaranteed to have been visited, ensure +# // any unknown uses we encountered are remapped. +# for (auto &llvmAndUnknown : unknownInstMap) { +# assert(instMap.count(llvmAndUnknown.first)); +# Value newValue = instMap[llvmAndUnknown.first]; +# Value oldValue = llvmAndUnknown.second->getResult(0); +# oldValue.replaceAllUsesWith(newValue); +# llvmAndUnknown.second->erase(); +# } +# return success(); +# } + +# LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) { +# b.setInsertionPointToStart(block); +# for (llvm::Instruction &inst : *bb) { +# if (failed(processInstruction(&inst))) +# return failure(); +# } +# return success(); +# } + +# OwningOpRef +# mlir::translateLLVMIRToModule(std::unique_ptr llvmModule, +# MLIRContext *context) { +# context->loadDialect(); +# context->loadDialect(); +# OwningOpRef module(ModuleOp::create( +# FileLineColLoc::get(context, "", /*line=*/0, /*column=*/0))); + +# DataLayoutSpecInterface dlSpec = +# translateDataLayout(llvmModule->getDataLayout(), context); +# if (!dlSpec) { +# emitError(UnknownLoc::get(context), "can't translate data layout"); +# return {}; +# } + +# module.get()->setAttr(DLTIDialect::kDataLayoutAttrName, dlSpec); + +# Importer deserializer(context, module.get()); +# for (llvm::GlobalVariable &gv : llvmModule->globals()) { +# if (!deserializer.processGlobal(&gv)) +# return {}; +# } +# for (llvm::Function &f : llvmModule->functions()) { +# if (failed(deserializer.processFunction(&f))) +# return {}; +# } + +# return module; +# } + + // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the + // LLVM dialect. + OwningOpRef translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, + MLIRContext *context) { + llvm::SMDiagnostic err; + llvm::LLVMContext llvmContext; + std::unique_ptr llvmModule = llvm::parseIR( + *sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), err, llvmContext); + if (!llvmModule) { + std::string errStr; + llvm::raw_string_ostream errStream(errStr); + err.print(/*ProgName=*/"", errStream); + emitError(UnknownLoc::get(context)) << errStream.str(); + return {}; + } + return translateLLVMIRToModule(std::move(llvmModule), context); + } + + namespace mlir { + void registerFromLLVMIRTranslation() { + TranslateToMLIRRegistration fromLLVM( +- "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { ++ "import-llvm", "from llvm to mlir", ++ [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { + return ::translateLLVMIRToModule(sourceMgr, context); + }); + } + } // namespace mlir +diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +index 70f86d2..1ae7b83 100644 +--- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp ++++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +@@ -1,42 +1,42 @@ + //===- ConvertToLLVMIR.cpp - MLIR to LLVM IR conversion -------------------===// + // + // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + // See https://llvm.org/LICENSE.txt for license information. + // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // + //===----------------------------------------------------------------------===// + // + // This file implements a translation between the MLIR LLVM dialect and LLVM IR. + // + //===----------------------------------------------------------------------===// + + #include "mlir/Dialect/DLTI/DLTI.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/Target/LLVMIR/Dialect/All.h" + #include "mlir/Target/LLVMIR/Export.h" + #include "mlir/Tools/mlir-translate/Translation.h" + #include "llvm/IR/LLVMContext.h" + #include "llvm/IR/Module.h" + + using namespace mlir; + + namespace mlir { + void registerToLLVMIRTranslation() { + TranslateFromMLIRRegistration registration( +- "mlir-to-llvmir", ++ "mlir-to-llvmir", "translate mlir to llvmir", + [](ModuleOp module, raw_ostream &output) { + llvm::LLVMContext llvmContext; + auto llvmModule = translateModuleToLLVMIR(module, llvmContext); + if (!llvmModule) + return failure(); + + llvmModule->print(output, nullptr); + return success(); + }, + [](DialectRegistry ®istry) { + registry.insert(); + registerAllToLLVMIRTranslations(registry); + }); + } + } // namespace mlir +diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +index 664b796..d24e578 100644 +--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp ++++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +@@ -1,184 +1,184 @@ + //===- TranslateRegistration.cpp - hooks to mlir-translate ----------------===// + // + // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + // See https://llvm.org/LICENSE.txt for license information. + // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // + //===----------------------------------------------------------------------===// + // + // This file implements a translation from SPIR-V binary module to MLIR SPIR-V + // ModuleOp. + // + //===----------------------------------------------------------------------===// + + #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" + #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" + #include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Dialect.h" + #include "mlir/IR/Verifier.h" + #include "mlir/Parser/Parser.h" + #include "mlir/Support/FileUtilities.h" + #include "mlir/Target/SPIRV/Deserialization.h" + #include "mlir/Target/SPIRV/Serialization.h" + #include "mlir/Tools/mlir-translate/Translation.h" + #include "llvm/ADT/StringRef.h" + #include "llvm/Support/MemoryBuffer.h" + #include "llvm/Support/SMLoc.h" + #include "llvm/Support/SourceMgr.h" + #include "llvm/Support/ToolOutputFile.h" + + using namespace mlir; + + //===----------------------------------------------------------------------===// + // Deserialization registration + //===----------------------------------------------------------------------===// + + // Deserializes the SPIR-V binary module stored in the file named as + // `inputFilename` and returns a module containing the SPIR-V module. + static OwningOpRef deserializeModule(const llvm::MemoryBuffer *input, + MLIRContext *context) { + context->loadDialect(); + + // Make sure the input stream can be treated as a stream of SPIR-V words + auto *start = input->getBufferStart(); + auto size = input->getBufferSize(); + if (size % sizeof(uint32_t) != 0) { + emitError(UnknownLoc::get(context)) + << "SPIR-V binary module must contain integral number of 32-bit words"; + return {}; + } + + auto binary = llvm::makeArrayRef(reinterpret_cast(start), + size / sizeof(uint32_t)); + + OwningOpRef spirvModule = + spirv::deserialize(binary, context); + if (!spirvModule) + return {}; + + OwningOpRef module(ModuleOp::create(FileLineColLoc::get( + context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0))); + module->getBody()->push_front(spirvModule.release()); + + return module; + } + + namespace mlir { + void registerFromSPIRVTranslation() { + TranslateToMLIRRegistration fromBinary( +- "deserialize-spirv", ++ "deserialize-spirv", "deserializes the SPIR-V module", + [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { + assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); + return deserializeModule( + sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()), context); + }); + } + } // namespace mlir + + //===----------------------------------------------------------------------===// + // Serialization registration + //===----------------------------------------------------------------------===// + + static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) { + if (!module) + return failure(); + + SmallVector binary; + + SmallVector spirvModules; + module.walk([&](spirv::ModuleOp op) { spirvModules.push_back(op); }); + + if (spirvModules.empty()) + return module.emitError("found no 'spirv.module' op"); + + if (spirvModules.size() != 1) + return module.emitError("found more than one 'spirv.module' op"); + + if (failed(spirv::serialize(spirvModules[0], binary))) + return failure(); + + output.write(reinterpret_cast(binary.data()), + binary.size() * sizeof(uint32_t)); + + return mlir::success(); + } + + namespace mlir { + void registerToSPIRVTranslation() { + TranslateFromMLIRRegistration toBinary( +- "serialize-spirv", ++ "serialize-spirv", "serialize SPIR-V dialect", + [](ModuleOp module, raw_ostream &output) { + return serializeModule(module, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); + } + } // namespace mlir + + //===----------------------------------------------------------------------===// + // Round-trip registration + //===----------------------------------------------------------------------===// + + static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo, + raw_ostream &output) { + SmallVector binary; + MLIRContext *context = srcModule.getContext(); + auto spirvModules = srcModule.getOps(); + + if (spirvModules.begin() == spirvModules.end()) + return srcModule.emitError("found no 'spirv.module' op"); + + if (std::next(spirvModules.begin()) != spirvModules.end()) + return srcModule.emitError("found more than one 'spirv.module' op"); + + spirv::SerializationOptions options; + options.emitDebugInfo = emitDebugInfo; + if (failed(spirv::serialize(*spirvModules.begin(), binary, options))) + return failure(); + + MLIRContext deserializationContext(context->getDialectRegistry()); + // TODO: we should only load the required dialects instead of all dialects. + deserializationContext.loadAllAvailableDialects(); + // Then deserialize to get back a SPIR-V module. + OwningOpRef spirvModule = + spirv::deserialize(binary, &deserializationContext); + if (!spirvModule) + return failure(); + + // Wrap around in a new MLIR module. + OwningOpRef dstModule(ModuleOp::create( + FileLineColLoc::get(&deserializationContext, + /*filename=*/"", /*line=*/0, /*column=*/0))); + dstModule->getBody()->push_front(spirvModule.release()); + if (failed(verify(*dstModule))) + return failure(); + dstModule->print(output); + + return mlir::success(); + } + + namespace mlir { + void registerTestRoundtripSPIRV() { + TranslateFromMLIRRegistration roundtrip( +- "test-spirv-roundtrip", ++ "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect", + [](ModuleOp module, raw_ostream &output) { + return roundTripModule(module, /*emitDebugInfo=*/false, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); + } + + void registerTestRoundtripDebugSPIRV() { + TranslateFromMLIRRegistration roundtrip( +- "test-spirv-roundtrip-debug", ++ "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V", + [](ModuleOp module, raw_ostream &output) { + return roundTripModule(module, /*emitDebugInfo=*/true, output); + }, + [](DialectRegistry ®istry) { + registry.insert(); + }); + } + } // namespace mlir +diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp +index 50b8547..7517ae0 100644 +--- a/mlir/lib/Tools/mlir-translate/Translation.cpp ++++ b/mlir/lib/Tools/mlir-translate/Translation.cpp +@@ -1,126 +1,140 @@ + //===- Translation.cpp - Translation registry -----------------------------===// + // + // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. + // See https://llvm.org/LICENSE.txt for license information. + // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + // + //===----------------------------------------------------------------------===// + // + // Definitions of the translation registry. + // + //===----------------------------------------------------------------------===// + + #include "mlir/Tools/mlir-translate/Translation.h" + #include "mlir/IR/AsmState.h" + #include "mlir/IR/BuiltinOps.h" + #include "mlir/IR/Dialect.h" + #include "mlir/IR/Verifier.h" + #include "mlir/Parser/Parser.h" + #include "llvm/Support/SourceMgr.h" + + using namespace mlir; + + //===----------------------------------------------------------------------===// + // Translation Registry + //===----------------------------------------------------------------------===// + ++struct TranslationBundle { ++ TranslateFunction translateFunction; ++ StringRef translateDescription; ++}; ++ + /// Get the mutable static map between registered file-to-file MLIR translations +-/// and the TranslateFunctions that perform those translations. +-static llvm::StringMap &getTranslationRegistry() { +- static llvm::StringMap translationRegistry; +- return translationRegistry; ++/// and TranslateFunctions with its description that perform those translations. ++static llvm::StringMap &getTranslationRegistry() { ++ static llvm::StringMap translationBundle; ++ return translationBundle; + } + + /// Register the given translation. +-static void registerTranslation(StringRef name, ++static void registerTranslation(StringRef name, StringRef description, + const TranslateFunction &function) { + auto &translationRegistry = getTranslationRegistry(); + if (translationRegistry.find(name) != translationRegistry.end()) + llvm::report_fatal_error( + "Attempting to overwrite an existing function"); + assert(function && + "Attempting to register an empty translate function"); +- translationRegistry[name] = function; ++ translationRegistry[name].translateFunction = function; ++ translationRegistry[name].translateDescription = description; + } + + TranslateRegistration::TranslateRegistration( +- StringRef name, const TranslateFunction &function) { +- registerTranslation(name, function); ++ StringRef name, StringRef description, const TranslateFunction &function) { ++ registerTranslation(name, description, function); + } + + //===----------------------------------------------------------------------===// + // Translation to MLIR + //===----------------------------------------------------------------------===// + + // Puts `function` into the to-MLIR translation registry unless there is already + // a function registered for the same name. + static void registerTranslateToMLIRFunction( +- StringRef name, const TranslateSourceMgrToMLIRFunction &function) { ++ StringRef name, StringRef description, ++ const TranslateSourceMgrToMLIRFunction &function) { + auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { + OwningOpRef module = function(sourceMgr, context); + if (!module || failed(verify(*module))) + return failure(); + module->print(output); + return success(); + }; +- registerTranslation(name, wrappedFn); ++ registerTranslation(name, description, wrappedFn); + } + + TranslateToMLIRRegistration::TranslateToMLIRRegistration( +- StringRef name, const TranslateSourceMgrToMLIRFunction &function) { +- registerTranslateToMLIRFunction(name, function); ++ StringRef name, StringRef description, ++ const TranslateSourceMgrToMLIRFunction &function) { ++ registerTranslateToMLIRFunction(name, description, function); + } +- + /// Wraps `function` with a lambda that extracts a StringRef from a source + /// manager and registers the wrapper lambda as a to-MLIR conversion. + TranslateToMLIRRegistration::TranslateToMLIRRegistration( +- StringRef name, const TranslateStringRefToMLIRFunction &function) { ++ StringRef name, StringRef description, ++ const TranslateStringRefToMLIRFunction &function) { + registerTranslateToMLIRFunction( +- name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { ++ name, description, ++ [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { + const llvm::MemoryBuffer *buffer = + sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); + return function(buffer->getBuffer(), ctx); + }); + } + + //===----------------------------------------------------------------------===// + // Translation from MLIR + //===----------------------------------------------------------------------===// + + TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( +- StringRef name, const TranslateFromMLIRFunction &function, ++ StringRef name, StringRef description, ++ const TranslateFromMLIRFunction &function, + const std::function &dialectRegistration) { +- registerTranslation(name, [function, dialectRegistration]( +- llvm::SourceMgr &sourceMgr, raw_ostream &output, +- MLIRContext *context) { +- DialectRegistry registry; +- dialectRegistration(registry); +- context->appendDialectRegistry(registry); +- auto module = parseSourceFile(sourceMgr, context); +- if (!module || failed(verify(*module))) +- return failure(); +- return function(module.get(), output); +- }); ++ registerTranslation(name, description, ++ [function, dialectRegistration]( ++ llvm::SourceMgr &sourceMgr, raw_ostream &output, ++ MLIRContext *context) { ++ DialectRegistry registry; ++ dialectRegistration(registry); ++ context->appendDialectRegistry(registry); ++ auto module = ++ parseSourceFile(sourceMgr, context); ++ if (!module || failed(verify(*module))) ++ return failure(); ++ return function(module.get(), output); ++ }); + } + + //===----------------------------------------------------------------------===// + // Translation Parser + //===----------------------------------------------------------------------===// + + TranslationParser::TranslationParser(llvm::cl::Option &opt) + : llvm::cl::parser(opt) { +- for (const auto &kv : getTranslationRegistry()) +- addLiteralOption(kv.first(), &kv.second, kv.first()); ++ for (const auto &kv : getTranslationRegistry()) { ++ addLiteralOption(kv.first(), &kv.second.translateFunction, ++ kv.second.translateDescription); ++ } + } + + void TranslationParser::printOptionInfo(const llvm::cl::Option &o, + size_t globalWidth) const { + TranslationParser *tp = const_cast(this); + llvm::array_pod_sort(tp->Values.begin(), tp->Values.end(), + [](const TranslationParser::OptionInfo *lhs, + const TranslationParser::OptionInfo *rhs) { + return lhs->Name.compare(rhs->Name); + }); + llvm::cl::parser::printOptionInfo(o, globalWidth); + } diff --git a/mlir/examples/standalone/standalone-translate/standalone-translate.cpp b/mlir/examples/standalone/standalone-translate/standalone-translate.cpp --- a/mlir/examples/standalone/standalone-translate/standalone-translate.cpp +++ b/mlir/examples/standalone/standalone-translate/standalone-translate.cpp @@ -11,16 +11,23 @@ // //===----------------------------------------------------------------------===// +#include "Standalone/StandaloneDialect.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/InitAllTranslations.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" - -#include "Standalone/StandaloneDialect.h" +#include "mlir/Tools/mlir-translate/Translation.h" int main(int argc, char **argv) { mlir::registerAllTranslations(); // TODO: Register standalone translations here. + mlir::TranslateFromMLIRRegistration withdescription( + "option", "different from option", + [](mlir::ModuleOp op, llvm::raw_ostream &output) { + return mlir::LogicalResult::success(); + }, + [](mlir::DialectRegistry &a) {}); return failed( mlir::mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool")); diff --git a/mlir/include/mlir/Tools/mlir-translate/Translation.h b/mlir/include/mlir/Tools/mlir-translate/Translation.h --- a/mlir/include/mlir/Tools/mlir-translate/Translation.h +++ b/mlir/include/mlir/Tools/mlir-translate/Translation.h @@ -71,20 +71,21 @@ /// /// \{ struct TranslateToMLIRRegistration { - TranslateToMLIRRegistration(llvm::StringRef name, + TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description, const TranslateSourceMgrToMLIRFunction &function); - TranslateToMLIRRegistration(llvm::StringRef name, + TranslateToMLIRRegistration(llvm::StringRef name, llvm::StringRef description, const TranslateStringRefToMLIRFunction &function); }; struct TranslateFromMLIRRegistration { TranslateFromMLIRRegistration( - llvm::StringRef name, const TranslateFromMLIRFunction &function, + llvm::StringRef name, llvm::StringRef description, + const TranslateFromMLIRFunction &function, const std::function &dialectRegistration = [](DialectRegistry &) {}); }; struct TranslateRegistration { - TranslateRegistration(llvm::StringRef name, + TranslateRegistration(llvm::StringRef name, llvm::StringRef description, const TranslateFunction &function); }; /// \} diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -33,7 +33,7 @@ llvm::cl::init(false)); TranslateFromMLIRRegistration reg( - "mlir-to-cpp", + "mlir-to-cpp", "translate from mlir to cpp", [](ModuleOp module, raw_ostream &output) { return emitc::translateToCpp( module, output, diff --git a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp @@ -1402,7 +1402,8 @@ namespace mlir { void registerFromLLVMIRTranslation() { TranslateToMLIRRegistration fromLLVM( - "import-llvm", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { + "import-llvm", "from llvm to mlir", + [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { return ::translateLLVMIRToModule(sourceMgr, context); }); } diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -24,7 +24,7 @@ namespace mlir { void registerToLLVMIRTranslation() { TranslateFromMLIRRegistration registration( - "mlir-to-llvmir", + "mlir-to-llvmir", "translate mlir to llvmir", [](ModuleOp module, raw_ostream &output) { llvm::LLVMContext llvmContext; auto llvmModule = translateModuleToLLVMIR(module, llvmContext); diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -67,7 +67,7 @@ namespace mlir { void registerFromSPIRVTranslation() { TranslateToMLIRRegistration fromBinary( - "deserialize-spirv", + "deserialize-spirv", "deserializes the SPIR-V module", [](llvm::SourceMgr &sourceMgr, MLIRContext *context) { assert(sourceMgr.getNumBuffers() == 1 && "expected one buffer"); return deserializeModule( @@ -107,7 +107,7 @@ namespace mlir { void registerToSPIRVTranslation() { TranslateFromMLIRRegistration toBinary( - "serialize-spirv", + "serialize-spirv", "serialize SPIR-V dialect", [](ModuleOp module, raw_ostream &output) { return serializeModule(module, output); }, @@ -162,7 +162,7 @@ namespace mlir { void registerTestRoundtripSPIRV() { TranslateFromMLIRRegistration roundtrip( - "test-spirv-roundtrip", + "test-spirv-roundtrip", "test roundtrip in SPIR-V dialect", [](ModuleOp module, raw_ostream &output) { return roundTripModule(module, /*emitDebugInfo=*/false, output); }, @@ -173,7 +173,7 @@ void registerTestRoundtripDebugSPIRV() { TranslateFromMLIRRegistration roundtrip( - "test-spirv-roundtrip-debug", + "test-spirv-roundtrip-debug", "test roundtrip debug in SPIR-V", [](ModuleOp module, raw_ostream &output) { return roundTripModule(module, /*emitDebugInfo=*/true, output); }, diff --git a/mlir/lib/Tools/mlir-translate/Translation.cpp b/mlir/lib/Tools/mlir-translate/Translation.cpp --- a/mlir/lib/Tools/mlir-translate/Translation.cpp +++ b/mlir/lib/Tools/mlir-translate/Translation.cpp @@ -24,15 +24,20 @@ // Translation Registry //===----------------------------------------------------------------------===// +struct TranslationBundle { + TranslateFunction translateFunction; + StringRef translateDescription; +}; + /// Get the mutable static map between registered file-to-file MLIR translations -/// and the TranslateFunctions that perform those translations. -static llvm::StringMap &getTranslationRegistry() { - static llvm::StringMap translationRegistry; - return translationRegistry; +/// and TranslateFunctions with its description that perform those translations. +static llvm::StringMap &getTranslationRegistry() { + static llvm::StringMap translationBundle; + return translationBundle; } /// Register the given translation. -static void registerTranslation(StringRef name, +static void registerTranslation(StringRef name, StringRef description, const TranslateFunction &function) { auto &translationRegistry = getTranslationRegistry(); if (translationRegistry.find(name) != translationRegistry.end()) @@ -40,12 +45,13 @@ "Attempting to overwrite an existing function"); assert(function && "Attempting to register an empty translate function"); - translationRegistry[name] = function; + translationRegistry[name].translateFunction = function; + translationRegistry[name].translateDescription = description; } TranslateRegistration::TranslateRegistration( - StringRef name, const TranslateFunction &function) { - registerTranslation(name, function); + StringRef name, StringRef description, const TranslateFunction &function) { + registerTranslation(name, description, function); } //===----------------------------------------------------------------------===// @@ -55,7 +61,8 @@ // Puts `function` into the to-MLIR translation registry unless there is already // a function registered for the same name. static void registerTranslateToMLIRFunction( - StringRef name, const TranslateSourceMgrToMLIRFunction &function) { + StringRef name, StringRef description, + const TranslateSourceMgrToMLIRFunction &function) { auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output, MLIRContext *context) { OwningOpRef module = function(sourceMgr, context); @@ -64,20 +71,22 @@ module->print(output); return success(); }; - registerTranslation(name, wrappedFn); + registerTranslation(name, description, wrappedFn); } TranslateToMLIRRegistration::TranslateToMLIRRegistration( - StringRef name, const TranslateSourceMgrToMLIRFunction &function) { - registerTranslateToMLIRFunction(name, function); + StringRef name, StringRef description, + const TranslateSourceMgrToMLIRFunction &function) { + registerTranslateToMLIRFunction(name, description, function); } - /// Wraps `function` with a lambda that extracts a StringRef from a source /// manager and registers the wrapper lambda as a to-MLIR conversion. TranslateToMLIRRegistration::TranslateToMLIRRegistration( - StringRef name, const TranslateStringRefToMLIRFunction &function) { + StringRef name, StringRef description, + const TranslateStringRefToMLIRFunction &function) { registerTranslateToMLIRFunction( - name, [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { + name, description, + [function](llvm::SourceMgr &sourceMgr, MLIRContext *ctx) { const llvm::MemoryBuffer *buffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); return function(buffer->getBuffer(), ctx); @@ -89,19 +98,22 @@ //===----------------------------------------------------------------------===// TranslateFromMLIRRegistration::TranslateFromMLIRRegistration( - StringRef name, const TranslateFromMLIRFunction &function, + StringRef name, StringRef description, + const TranslateFromMLIRFunction &function, const std::function &dialectRegistration) { - registerTranslation(name, [function, dialectRegistration]( - llvm::SourceMgr &sourceMgr, raw_ostream &output, - MLIRContext *context) { - DialectRegistry registry; - dialectRegistration(registry); - context->appendDialectRegistry(registry); - auto module = parseSourceFile(sourceMgr, context); - if (!module || failed(verify(*module))) - return failure(); - return function(module.get(), output); - }); + registerTranslation(name, description, + [function, dialectRegistration]( + llvm::SourceMgr &sourceMgr, raw_ostream &output, + MLIRContext *context) { + DialectRegistry registry; + dialectRegistration(registry); + context->appendDialectRegistry(registry); + auto module = + parseSourceFile(sourceMgr, context); + if (!module || failed(verify(*module))) + return failure(); + return function(module.get(), output); + }); } //===----------------------------------------------------------------------===// @@ -110,8 +122,10 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt) : llvm::cl::parser(opt) { - for (const auto &kv : getTranslationRegistry()) - addLiteralOption(kv.first(), &kv.second, kv.first()); + for (const auto &kv : getTranslationRegistry()) { + addLiteralOption(kv.first(), &kv.second.translateFunction, + kv.second.translateDescription); + } } void TranslationParser::printOptionInfo(const llvm::cl::Option &o,