# Move the LLVM data layout out of LLVMDialect state and into an
# `llvm.data_layout` string attribute that the Standard-to-LLVM lowering
# attaches to the module.
#
# What the hunks below do:
# * Passes.td: adds a `data-layout` string option (default "") described as
#   the LLVM-format data layout expected on the produced module.
# * ConvertStandardToLLVMPass.h: LowerToLLVMOptions gains
#   `llvm::DataLayout dataLayout = llvm::DataLayout("")`.
# * ConvertStandardToLLVM.h / StandardToLLVM.cpp: LLVMTypeConverter now reads
#   pointer bitwidths, and the index bitwidth when it is derived from the
#   data layout, from options.dataLayout instead of
#   llvmDialect->getDataLayout(). It exposes the layout via a new
#   getDataLayout() accessor, which ConvertVectorToLLVM.cpp now uses for
#   preferred alignment.
# * LLVMOpBase.td / LLVMDialect.cpp: LLVMDialect drops its destructor,
#   getDataLayout() and the private LLVMDialectImpl. It gains
#   kDataLayoutAttrName ("llvm.data_layout") and a static
#   verifyDataLayoutString() built on llvm::DataLayout::parse. Via
#   hasOperationAttrVerify it also gains verifyOperationAttribute(), which
#   rejects an `llvm.data_layout` attribute that is not a string or does not
#   parse.
# * StandardToLLVM.cpp (LLVMLoweringPass): the DataLayout-taking constructor
#   stores the layout as its string representation. runOnOperation fails the
#   pass if that string does not verify, lowers with
#   llvm::DataLayout(this->dataLayout), then sets `llvm.data_layout` on the
#   module.
# * ModuleTranslation.cpp: prepareLLVMModule no longer asserts that the LLVM
#   dialect is registered. It calls setDataLayout on the llvm::Module only
#   when the op carries `llvm.data_layout`.
# * Tests: invalid.mlir expects "invalid data layout descriptor" for
#   "#vjkr32". llvmir.mlir checks that translation works without the
#   attribute and that "E" yields `target datalayout = "E"`.
#
# NOTE(review): this copy appears damaged by text extraction. Line breaks are
# collapsed and text between `<` and `>` seems to be missing, e.g. `#include `
# with no header name, `std::make_unique(`, `llvm::Expected maybeDataLayout`,
# `dyn_cast()`, and `addTypeslayout; }`, which appears to have lost a large
# span. It will not apply as-is; regenerate it from the source commit instead
# of hand-repairing it.
# NOTE(review): in LLVMLoweringPass::runOnOperation, `llvm.data_layout` is set
# even when applyPartialConversion fails (signalPassFailure() is not followed
# by a return). Confirm this is intended.
# NOTE(review): the new accessor `getDataLayout()` is not a const member
# function; consider marking it `const`.
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -279,6 +279,10 @@ Option<"indexBitwidth", "index-bitwidth", "unsigned", /*default=kDeriveIndexBitwidthFromDataLayout*/"0", "Bitwidth of the index type, 0 to use size of machine word">, + Option<"dataLayout", "data-layout", "std::string", + /*default=*/"\"\"", + "String description (LLVM format) of the data layout that is " + "expected on the produced module"> ]; } diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -105,6 +105,9 @@ /// pointers to memref descriptors for arguments. LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type); + /// Returns the data layout to use during and after conversion. + const llvm::DataLayout &getDataLayout() { return options.dataLayout; } + /// Gets the LLVM representation of the index type. The returned type is an /// integer type with the size configured for this type converter. LLVM::LLVMType getIndexType(); diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -9,6 +9,8 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ +#include "llvm/IR/DataLayout.h" + #include namespace mlir { @@ -31,6 +33,11 @@ /// Use aligned_alloc for heap allocations. bool useAlignedAlloc = false; + /// The data layout of the module to produce.
This must be consistent with the + /// data layout used in the upper levels of the lowering pipeline. + // TODO: this should be replaced by MLIR data layout when one exists. + llvm::DataLayout dataLayout = llvm::DataLayout(""); + /// Get a statically allocated copy of the default LowerToLLVMOptions. static const LowerToLLVMOptions &getDefaultOptions() { static LowerToLLVMOptions options; diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -20,16 +20,15 @@ let name = "llvm"; let cppNamespace = "LLVM"; let hasRegionArgAttrVerify = 1; + let hasOperationAttrVerify = 1; let extraClassDeclaration = [{ - ~LLVMDialect(); - const llvm::DataLayout &getDataLayout(); + /// Name of the data layout attributes. + constexpr static StringRef kDataLayoutAttrName = "llvm.data_layout"; - private: - friend LLVMType; - - // This can't be a unique_ptr because the ctor is generated inline - // in the class definition at the moment. - detail::LLVMDialectImpl *impl; + /// Verifies if the given string is a well-formed data layout descriptor. + /// Uses `reportError` to report errors. + static LogicalResult verifyDataLayoutString( + StringRef descr, llvm::function_ref reportError); }]; } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -129,8 +129,7 @@ options(options) { assert(llvmDialect && "LLVM IR dialect is not registered"); if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) - this->options.indexBitwidth = - llvmDialect->getDataLayout().getPointerSizeInBits(); + this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits(); // Register conversions for the standard types.
addConversion([&](ComplexType type) { return convertComplexType(type); }); @@ -198,7 +197,7 @@ } unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { - return llvmDialect->getDataLayout().getPointerSizeInBits(addressSpace); + return options.dataLayout.getPointerSizeInBits(addressSpace); } Type LLVMTypeConverter::convertIndexType(IndexType type) { @@ -3429,11 +3428,13 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase { LLVMLoweringPass() = default; LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, - unsigned indexBitwidth, bool useAlignedAlloc) { + unsigned indexBitwidth, bool useAlignedAlloc, + const llvm::DataLayout &dataLayout) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; this->indexBitwidth = indexBitwidth; this->useAlignedAlloc = useAlignedAlloc; + this->dataLayout = dataLayout.getStringRepresentation(); } /// Run the dialect converter on the module. @@ -3445,11 +3446,19 @@ signalPassFailure(); return; } + if (failed(LLVM::LLVMDialect::verifyDataLayoutString( + this->dataLayout, [this](Twine message) { + getOperation().emitError() << message.str(); + }))) { + signalPassFailure(); + return; + } ModuleOp m = getOperation(); LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers, - indexBitwidth, useAlignedAlloc}; + indexBitwidth, useAlignedAlloc, + llvm::DataLayout(this->dataLayout)}; LLVMTypeConverter typeConverter(&getContext(), options); OwningRewritePatternList patterns; @@ -3458,6 +3467,8 @@ LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, patterns))) signalPassFailure(); + m.setAttr(LLVM::LLVMDialect::kDataLayoutAttrName, + StringAttr::get(this->dataLayout, m.getContext())); } }; } // end namespace @@ -3473,5 +3484,5 @@ mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { return std::make_unique( options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth, - options.useAlignedAlloc); +
options.useAlignedAlloc, options.dataLayout); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -128,11 +128,10 @@ // TODO: this should use the MLIR data layout when it becomes available and // stop depending on translation. - LLVM::LLVMDialect *dialect = typeConverter.getDialect(); llvm::LLVMContext llvmContext; align = LLVM::TypeToLLVMIRTranslator(llvmContext) .getPreferredAlignment(elementTy.cast(), - dialect->getDataLayout()); + typeConverter.getDataLayout()); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -1668,23 +1668,7 @@ // LLVMDialect initialization, type parsing, and registration. //===----------------------------------------------------------------------===// -namespace mlir { -namespace LLVM { -namespace detail { -struct LLVMDialectImpl { - LLVMDialectImpl() : layout("") {} - - /// Default data layout to use. - // TODO: this should be moved to some Op equivalent to LLVM module and - // eventually replaced with a proper MLIR data layout. - llvm::DataLayout layout; -}; -} // end namespace detail -} // end namespace LLVM -} // end namespace mlir - void LLVMDialect::initialize() { - impl = new detail::LLVMDialectImpl(); // clang-format off addTypeslayout; } - /// Parse a type registered to this dialect.
Type LLVMDialect::parseType(DialectAsmParser &parser) const { return detail::parseType(parser); @@ -1732,6 +1712,39 @@ return detail::printType(type.cast(), os); } +LogicalResult LLVMDialect::verifyDataLayoutString( + StringRef descr, llvm::function_ref reportError) { + llvm::Expected maybeDataLayout = + llvm::DataLayout::parse(descr); + if (maybeDataLayout) + return success(); + + std::string message; + llvm::raw_string_ostream messageStream(message); + llvm::logAllUnhandledErrors(maybeDataLayout.takeError(), messageStream); + reportError("invalid data layout descriptor: " + messageStream.str()); + return failure(); +} + +/// Verify LLVM dialect attributes. +LogicalResult LLVMDialect::verifyOperationAttribute(Operation *op, + NamedAttribute attr) { + // If the data layout attribute is present, it must use the LLVM data layout + // syntax. Try parsing it and report errors in case of failure. Users of this + // attribute may assume it is well-formed and can pass it to the (asserting) + // llvm::DataLayout constructor. + if (attr.first.strref() != LLVM::LLVMDialect::kDataLayoutAttrName) + return success(); + if (auto stringAttr = attr.second.dyn_cast()) + return verifyDataLayoutString(stringAttr.getValue(), [op](Twine message) { + op->emitOpError() << message.str(); + }); + + return op->emitOpError() << "expected '" + << LLVM::LLVMDialect::kDataLayoutAttrName + << "' to be a string attribute"; +} + /// Verify LLVMIR function argument attributes.
LogicalResult LLVMDialect::verifyRegionArgAttribute(Operation *op, unsigned regionIdx, diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -944,11 +944,10 @@ std::unique_ptr ModuleTranslation::prepareLLVMModule( Operation *m, llvm::LLVMContext &llvmContext, StringRef name) { - auto *dialect = m->getContext()->getRegisteredDialect(); - assert(dialect && "LLVM dialect must be registered"); - auto llvmModule = std::make_unique(name, llvmContext); - llvmModule->setDataLayout(dialect->getDataLayout()); + + if (auto dataLayoutAttr = m->getAttr(LLVM::LLVMDialect::kDataLayoutAttrName)) + llvmModule->setDataLayout(dataLayoutAttr.cast().getValue()); // Inject declarations for `malloc` and `free` functions that can be used in // memref allocation/deallocation coming from standard ops lowering. diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -603,3 +603,10 @@ // expected-error @+1 {{can be given only acquire, release, acq_rel, and seq_cst orderings}} llvm.fence syncscope("agent") monotonic } + +// ----- + +// expected-error @+1 {{invalid data layout descriptor}} +module attributes {llvm.data_layout = "#vjkr32"} { + func @invalid_data_layout() +} diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir --- a/mlir/test/Target/llvmir.mlir +++ b/mlir/test/Target/llvmir.mlir @@ -1295,3 +1295,19 @@ } // CHECK: ![[NODE]] = !{i32 1} + +// ----- + +// Check that the translation does not crash in absence of a data layout. +module { + // CHECK: declare void @module_default_layout + llvm.func @module_default_layout() +} + +// ----- + +// CHECK: target datalayout = "E" +module attributes {llvm.data_layout = "E"} { + llvm.func @module_big_endian() +} +