diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/StandardTypes.h"
 
 #include "llvm/ADT/StringSwitch.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/AsmParser/Parser.h"
 #include "llvm/Bitcode/BitcodeReader.h"
 #include "llvm/Bitcode/BitcodeWriter.h"
@@ -37,6 +38,161 @@
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsEnums.cpp.inc"
 
+namespace {
+/// Support for translating LLVM IR types to MLIR LLVM dialect types.
+class TypeFromLLVMIRTranslator {
+public:
+  /// Constructs a class creating types in the given MLIR context.
+  explicit TypeFromLLVMIRTranslator(MLIRContext &context) : context(context) {}
+
+  /// Translates the given type. The result is cached so that translating the
+  /// same LLVM IR type again returns the stored MLIR type.
+  LLVM::LLVMType translateType(llvm::Type *type) {
+    // Look the type up once instead of calling `count` and then `lookup`.
+    auto known = knownTranslations.find(type);
+    if (known != knownTranslations.end())
+      return known->second;
+
+    LLVM::LLVMType translated =
+        llvm::TypeSwitch<llvm::Type *, LLVM::LLVMType>(type)
+            .Case([this](llvm::ArrayType *type) {
+              return translateArrayType(type);
+            })
+            .Case([this](llvm::FunctionType *type) {
+              return translateFunctionType(type);
+            })
+            .Case([this](llvm::IntegerType *type) {
+              return translateIntegerType(type);
+            })
+            .Case([this](llvm::PointerType *type) {
+              return translatePointerType(type);
+            })
+            .Case([this](llvm::StructType *type) {
+              return translateStructType(type);
+            })
+            .Case([this](llvm::FixedVectorType *type) {
+              return translateFixedVectorType(type);
+            })
+            .Case([this](llvm::ScalableVectorType *type) {
+              return translateScalableVectorType(type);
+            })
+            .Default([this](llvm::Type *type) {
+              return translatePrimitiveType(type);
+            });
+    // No-op for identified structs, which translateStructType registers
+    // before translating their body.
+    knownTranslations.try_emplace(type, translated);
+    return translated;
+  }
+
+private:
+  /// Translates the given primitive, i.e. non-parametric in MLIR nomenclature,
+  /// type.
+  LLVM::LLVMType translatePrimitiveType(llvm::Type *type) {
+    if (type->isVoidTy())
+      return LLVM::LLVMVoidType::get(&context);
+    if (type->isHalfTy())
+      return LLVM::LLVMHalfType::get(&context);
+    if (type->isBFloatTy())
+      return LLVM::LLVMBFloatType::get(&context);
+    if (type->isFloatTy())
+      return LLVM::LLVMFloatType::get(&context);
+    if (type->isDoubleTy())
+      return LLVM::LLVMDoubleType::get(&context);
+    if (type->isFP128Ty())
+      return LLVM::LLVMFP128Type::get(&context);
+    if (type->isX86_FP80Ty())
+      return LLVM::LLVMX86FP80Type::get(&context);
+    if (type->isPPC_FP128Ty())
+      return LLVM::LLVMPPCFP128Type::get(&context);
+    if (type->isX86_MMXTy())
+      return LLVM::LLVMX86MMXType::get(&context);
+    if (type->isLabelTy())
+      return LLVM::LLVMLabelType::get(&context);
+    if (type->isMetadataTy())
+      return LLVM::LLVMMetadataType::get(&context);
+    llvm_unreachable("not a primitive type");
+  }
+
+  /// Translates the given array type.
+  LLVM::LLVMType translateArrayType(llvm::ArrayType *type) {
+    return LLVM::LLVMArrayType::get(translateType(type->getElementType()),
+                                    type->getNumElements());
+  }
+
+  /// Translates the given function type.
+  LLVM::LLVMType translateFunctionType(llvm::FunctionType *type) {
+    SmallVector<LLVM::LLVMType, 8> paramTypes;
+    translateTypes(type->params(), paramTypes);
+    return LLVM::LLVMFunctionType::get(translateType(type->getReturnType()),
+                                       paramTypes, type->isVarArg());
+  }
+
+  /// Translates the given integer type.
+  LLVM::LLVMType translateIntegerType(llvm::IntegerType *type) {
+    return LLVM::LLVMIntegerType::get(&context, type->getBitWidth());
+  }
+
+  /// Translates the given pointer type.
+  LLVM::LLVMType translatePointerType(llvm::PointerType *type) {
+    return LLVM::LLVMPointerType::get(translateType(type->getElementType()),
+                                      type->getAddressSpace());
+  }
+
+  /// Translates the given structure type.
+  LLVM::LLVMType translateStructType(llvm::StructType *type) {
+    SmallVector<LLVM::LLVMType, 8> subtypes;
+    if (type->isLiteral()) {
+      translateTypes(type->subtypes(), subtypes);
+      return LLVM::LLVMStructType::getLiteral(&context, subtypes,
+                                              type->isPacked());
+    }
+
+    if (type->isOpaque())
+      return LLVM::LLVMStructType::getOpaque(type->getName(), &context);
+
+    // Register the identified struct before translating its body so that
+    // recursive references to it are served from the cache.
+    LLVM::LLVMStructType translated =
+        LLVM::LLVMStructType::getIdentified(&context, type->getName());
+    knownTranslations.try_emplace(type, translated);
+    translateTypes(type->subtypes(), subtypes);
+    LogicalResult bodySet = translated.setBody(subtypes, type->isPacked());
+    assert(succeeded(bodySet) &&
+           "could not set the body of an identified struct");
+    (void)bodySet;
+    return translated;
+  }
+
+  /// Translates the given fixed-vector type.
+  LLVM::LLVMType translateFixedVectorType(llvm::FixedVectorType *type) {
+    return LLVM::LLVMFixedVectorType::get(translateType(type->getElementType()),
+                                          type->getNumElements());
+  }
+
+  /// Translates the given scalable-vector type.
+  LLVM::LLVMType translateScalableVectorType(llvm::ScalableVectorType *type) {
+    return LLVM::LLVMScalableVectorType::get(
+        translateType(type->getElementType()), type->getMinNumElements());
+  }
+
+  /// Translates a list of types.
+  void translateTypes(ArrayRef<llvm::Type *> types,
+                      SmallVectorImpl<LLVM::LLVMType> &result) {
+    result.reserve(result.size() + types.size());
+    for (llvm::Type *type : types)
+      result.push_back(translateType(type));
+  }
+
+  /// Map of known translations. Serves as a cache and as recursion stopper for
+  /// translating recursive structs.
+  llvm::DenseMap<llvm::Type *, LLVM::LLVMType> knownTranslations;
+
+  /// The context in which MLIR types are created.
+  MLIRContext &context;
+};
+} // end namespace
+
 //===----------------------------------------------------------------------===//
 // Printing/parsing for LLVM::CmpOp.
 //===----------------------------------------------------------------------===//
@@ -1746,7 +1902,20 @@
 
 /// Parse a type registered to this dialect.
 Type LLVMDialect::parseType(DialectAsmParser &parser) const {
-  return detail::parseType(parser);
+  std::string tyData = parser.getFullSymbolSpec().str();
+
+  // LLVM is not thread-safe, so lock access to it.
+  llvm::sys::SmartScopedLock<true> lock(impl->mutex);
+
+  llvm::SMDiagnostic errorMessage;
+  llvm::Type *type = llvm::parseType(tyData, errorMessage, impl->module);
+  if (!type) {
+    // Report why LLVM rejected the type instead of failing silently.
+    parser.emitError(parser.getNameLoc(), errorMessage.getMessage());
+    return nullptr;
+  }
+
+  return TypeFromLLVMIRTranslator(*getContext()).translateType(type);
 }
 
 /// Print a type registered to this dialect.
diff --git a/mlir/tools/CMakeLists.txt b/mlir/tools/CMakeLists.txt
--- a/mlir/tools/CMakeLists.txt
+++ b/mlir/tools/CMakeLists.txt
@@ -1,9 +1,10 @@
 add_subdirectory(mlir-cuda-runner)
 add_subdirectory(mlir-cpu-runner)
 add_subdirectory(mlir-linalg-ods-gen)
+add_subdirectory(mlir-modernize)
 add_subdirectory(mlir-opt)
 add_subdirectory(mlir-reduce)
 add_subdirectory(mlir-rocm-runner)
 add_subdirectory(mlir-shlib)
 add_subdirectory(mlir-translate)
-add_subdirectory(mlir-vulkan-runner)
\ No newline at end of file
+add_subdirectory(mlir-vulkan-runner)
diff --git a/mlir/tools/mlir-modernize/CMakeLists.txt b/mlir/tools/mlir-modernize/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-modernize/CMakeLists.txt
@@ -0,0 +1,22 @@
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+set(LLVM_LINK_COMPONENTS
+  Core
+  Support
+  AsmParser
+  )
+
+set(LIBS
+  ${dialect_libs}
+  MLIRIR
+  MLIRSupport
+  MLIRTransformUtils
+  )
+
+add_llvm_tool(mlir-modernize
+  mlir-modernize.cpp
+
+  DEPENDS
+  ${LIBS}
+  )
+target_link_libraries(mlir-modernize PRIVATE ${LIBS})
+llvm_update_compile_flags(mlir-modernize)
diff --git a/mlir/tools/mlir-modernize/mlir-modernize.cpp b/mlir/tools/mlir-modernize/mlir-modernize.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-modernize/mlir-modernize.cpp
@@ -0,0 +1,94 @@
+#include "mlir/InitAllDialects.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/raw_ostream.h"
+
+/// Returns the number of characters at the beginning of `str` that may belong
+/// to an LLVM dialect type, where `str` is the text following a "!llvm"
+/// prefix. Two spellings are recognized: ".name", which extends over
+/// alphanumeric characters and underscores, and "<...>", which extends up to
+/// and including the matching closing angle bracket (or to the end of `str`
+/// if the brackets are unbalanced). Returns 0 if `str` starts with neither.
+static size_t findEndPos(llvm::StringRef str) {
+  if (str.empty() || (str[0] != '.' && str[0] != '<'))
+    return 0;
+
+  if (str[0] == '.') {
+    size_t end = str.find_if_not(
+        [](char c) { return llvm::isAlnum(c) || c == '_'; }, 1);
+    // The identifier may run to the very end of the buffer.
+    return end == llvm::StringRef::npos ? str.size() : end;
+  }
+
+  // Track the nesting depth so that inner "<...>" pairs are skipped.
+  unsigned depth = 0;
+  for (size_t i = 1, e = str.size(); i < e; ++i) {
+    if (str[i] == '<') {
+      ++depth;
+    } else if (str[i] == '>') {
+      if (depth == 0)
+        return i + 1;
+      --depth;
+    }
+  }
+  return str.size();
+}
+
+int main(int argc, char **argv) {
+  llvm::cl::opt<std::string> input_filename(llvm::cl::Positional,
+                                            llvm::cl::desc("<input file>"),
+                                            llvm::cl::init("-"));
+  llvm::InitLLVM init(argc, argv);
+
+  mlir::registerAllDialects();
+  llvm::cl::ParseCommandLineOptions(argc, argv, "modernizer");
+
+  mlir::MLIRContext mlirContext;
+
+  std::string errorMessage;
+  auto inputFile = mlir::openInputFile(input_filename, &errorMessage);
+  if (!inputFile) {
+    llvm::errs() << errorMessage << "\n";
+    return EXIT_FAILURE;
+  }
+
+  // Copy the input to the output, replacing every "!llvm" type that parses
+  // with its printed form and leaving everything else untouched.
+  const llvm::StringRef prefix = "!llvm";
+  llvm::StringRef string = inputFile->getBuffer();
+  size_t pos = 0;
+  while (true) {
+    size_t start = string.find(prefix, pos);
+    if (start == llvm::StringRef::npos) {
+      llvm::outs() << string.substr(pos);
+      break;
+    }
+    llvm::outs() << string.substr(pos, start - pos);
+
+    size_t endPos = findEndPos(string.substr(start + prefix.size()));
+    std::string possibleType =
+        string.substr(start, prefix.size() + endPos).str();
+
+    mlir::Type type;
+    if (endPos != 0)
+      type = mlir::parseType(possibleType, &mlirContext);
+    if (!type) {
+      // Not a type: emit the prefix verbatim (it would otherwise be dropped
+      // from the output) and resume the search right after it.
+      llvm::outs() << prefix;
+      pos = start + prefix.size();
+      continue;
+    }
+
+    assert(type.isa<mlir::LLVM::LLVMType>());
+    type.print(llvm::outs());
+    pos = start + possibleType.size();
+  }
+
+  return EXIT_SUCCESS;
+}