diff --git a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h --- a/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h +++ b/mlir/include/mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h @@ -8,15 +8,34 @@ #ifndef MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ #define MLIR_CONVERSION_COMPLEXTOLLVM_COMPLEXTOLLVM_H_ -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { -class MLIRContext; +class LLVMTypeConverter; class ModuleOp; template class OperationPass; +class ComplexStructBuilder : public StructBuilder { +public: + /// Construct a helper for the given complex number value. + using StructBuilder::StructBuilder; + /// Build IR creating an `undef` value of the complex number type. + static ComplexStructBuilder undef(OpBuilder &builder, Location loc, + Type type); + + /// Build IR extracting the real value from the complex number struct. + Value real(OpBuilder &builder, Location loc); + /// Build IR inserting the real value into the complex number struct. + void setReal(OpBuilder &builder, Location loc, Value real); + + /// Build IR extracting the imaginary value from the complex number struct. + Value imaginary(OpBuilder &builder, Location loc); + /// Build IR inserting the imaginary value into the complex number struct. + void setImaginary(OpBuilder &builder, Location loc, Value imaginary); +}; + /// Populate the given list with patterns that convert from Complex to LLVM.
void populateComplexToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns); diff --git a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h --- a/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ b/mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -8,7 +8,7 @@ #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include namespace mlir { diff --git a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h --- a/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ b/mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -8,7 +8,7 @@ #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include namespace mlir { diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h @@ -0,0 +1,73 @@ +//===- LoweringOptions.h - Common config for lowering to LLVM ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a configuration shared by several conversions targeting the LLVM +// dialect.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H +#define MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H + +#include "llvm/IR/DataLayout.h" + +namespace mlir { + +class DataLayout; +class MLIRContext; + +/// Value to pass as bitwidth for the index type when the converter is expected +/// to derive the bitwidth from the LLVM data layout. +static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0; + +/// Options to control the lowering to the LLVM dialect. The class is used +/// to share lowering options between passes, patterns, and type converter. +class LowerToLLVMOptions { +public: + explicit LowerToLLVMOptions(MLIRContext *ctx); + LowerToLLVMOptions(MLIRContext *ctx, const DataLayout &dl); + + bool useBarePtrCallConv = false; + bool emitCWrappers = false; + + enum class AllocLowering { + /// Use malloc for heap allocations. + Malloc, + + /// Use aligned_alloc for heap allocations. + AlignedAlloc, + + /// Do not lower heap allocations. Users must provide their own patterns for + /// AllocOp and DeallocOp lowering. + None + }; + + AllocLowering allocLowering = AllocLowering::Malloc; + + /// The data layout of the module to produce. This must be consistent with the + /// data layout used in the upper levels of the lowering pipeline. + // TODO: this should be replaced by MLIR data layout when one exists. + llvm::DataLayout dataLayout = llvm::DataLayout(""); + + /// Set the index bitwidth to the given value. + void overrideIndexBitwidth(unsigned bitwidth) { + assert(bitwidth != kDeriveIndexBitwidthFromDataLayout && + "can only override to a concrete bitwidth"); + indexBitwidth = bitwidth; + } + + /// Get the index bitwidth.
+ unsigned getIndexBitwidth() const { return indexBitwidth; } + +private: + unsigned indexBitwidth; +}; + +} // namespace mlir + +#endif // MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H diff --git a/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/LLVMCommon/MemRefBuilder.h @@ -0,0 +1,245 @@ +//===- MemRefBuilder.h - Helper for LLVM MemRef equivalents -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a convenience API for emitting IR that inspects or constructs values +// of LLVM dialect structure type that correspond to ranked or unranked memrefs. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H +#define MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H + +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/IR/OperationSupport.h" + +namespace mlir { + +class LLVMTypeConverter; +class MemRefType; +class UnrankedMemRefType; + +namespace LLVM { +class LLVMPointerType; +} // namespace LLVM + +/// Helper class to produce LLVM dialect operations extracting or inserting +/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. +/// The Value may be null, in which case none of the operations are valid. +class MemRefDescriptor : public StructBuilder { +public: + /// Construct a helper for the given descriptor value. + explicit MemRefDescriptor(Value descriptor); + /// Builds IR creating an `undef` value of the descriptor type.
+ static MemRefDescriptor undef(OpBuilder &builder, Location loc, + Type descriptorType); + /// Builds IR creating a MemRef descriptor that represents `type` and + /// populates it with static shape and stride information extracted from the + /// type. + static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + MemRefType type, Value memory); + + /// Builds IR extracting the allocated pointer from the descriptor. + Value allocatedPtr(OpBuilder &builder, Location loc); + /// Builds IR inserting the allocated pointer into the descriptor. + void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); + + /// Builds IR extracting the aligned pointer from the descriptor. + Value alignedPtr(OpBuilder &builder, Location loc); + + /// Builds IR inserting the aligned pointer into the descriptor. + void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); + + /// Builds IR extracting the offset from the descriptor. + Value offset(OpBuilder &builder, Location loc); + + /// Builds IR inserting the offset into the descriptor. + void setOffset(OpBuilder &builder, Location loc, Value offset); + void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); + + /// Builds IR extracting the pos-th size from the descriptor. + Value size(OpBuilder &builder, Location loc, unsigned pos); + Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); + + /// Builds IR inserting the pos-th size into the descriptor. + void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); + void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, + uint64_t size); + + /// Builds IR extracting the pos-th stride from the descriptor.
+ Value stride(OpBuilder &builder, Location loc, unsigned pos); + + /// Builds IR inserting the pos-th stride into the descriptor. + void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); + void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, + uint64_t stride); + + /// Returns the (LLVM) pointer type this descriptor contains. + LLVM::LLVMPointerType getElementPtrType(); + + /// Builds IR populating a MemRef descriptor structure from a list of + /// individual values composing that descriptor, in the following order: + /// - allocated pointer; + /// - aligned pointer; + /// - offset; + /// - `rank` sizes; + /// - `rank` strides; + /// where `rank` is the MemRef rank as provided in `type`. + static Value pack(OpBuilder &builder, Location loc, + LLVMTypeConverter &converter, MemRefType type, + ValueRange values); + + /// Builds IR extracting individual elements of a MemRef descriptor structure + /// and returning them as `results` list. + static void unpack(OpBuilder &builder, Location loc, Value packed, + MemRefType type, SmallVectorImpl &results); + + /// Returns the number of non-aggregate values that would be produced by + /// `unpack`. + static unsigned getNumUnpackedValues(MemRefType type); + +private: + /// Cached index type. + Type indexType; +}; + +/// Helper class allowing the user to access a range of Values that correspond +/// to an unpacked memref descriptor using named accessors. This does not own +/// the values. +class MemRefDescriptorView { +public: + /// Constructs the view from a range of values. Infers the rank from the size + /// of the range. + explicit MemRefDescriptorView(ValueRange range); + + /// Returns the allocated pointer Value. + Value allocatedPtr(); + + /// Returns the aligned pointer Value. + Value alignedPtr(); + + /// Returns the offset Value. + Value offset(); + + /// Returns the pos-th size Value. + Value size(unsigned pos); + + /// Returns the pos-th stride Value.
+ Value stride(unsigned pos); + +private: + /// Rank of the memref the descriptor is pointing to. + int rank; + /// Underlying range of Values. + ValueRange elements; +}; + +class UnrankedMemRefDescriptor : public StructBuilder { +public: + /// Construct a helper for the given descriptor value. + explicit UnrankedMemRefDescriptor(Value descriptor); + /// Builds IR creating an `undef` value of the descriptor type. + static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, + Type descriptorType); + + /// Builds IR extracting the rank from the descriptor. + Value rank(OpBuilder &builder, Location loc); + /// Builds IR setting the rank in the descriptor. + void setRank(OpBuilder &builder, Location loc, Value value); + /// Builds IR extracting the ranked memref descriptor pointer. + Value memRefDescPtr(OpBuilder &builder, Location loc); + /// Builds IR setting the ranked memref descriptor pointer. + void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); + + /// Builds IR populating an unranked MemRef descriptor structure from a list + /// of individual constituent values in the following order: + /// - rank of the memref; + /// - pointer to the memref descriptor. + static Value pack(OpBuilder &builder, Location loc, + LLVMTypeConverter &converter, UnrankedMemRefType type, + ValueRange values); + + /// Builds IR extracting individual elements that compose an unranked memref + /// descriptor and returns them as `results` list. + static void unpack(OpBuilder &builder, Location loc, Value packed, + SmallVectorImpl &results); + + /// Returns the number of non-aggregate values that would be produced by + /// `unpack`. + static unsigned getNumUnpackedValues() { return 2; } + + /// Builds IR computing the sizes in bytes (suitable for opaque allocation) + /// and appends the corresponding values into `sizes`.
+ static void computeSizes(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + ArrayRef values, + SmallVectorImpl &sizes); + + /// TODO: The following accessors don't take alignment rules between elements + /// of the descriptor struct into account. For some architectures, it might be + /// necessary to extend them and to use `llvm::DataLayout` contained in + /// `LLVMTypeConverter`. + + /// Builds IR extracting the allocated pointer from the descriptor. + static Value allocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, Type elemPtrPtrType); + /// Builds IR inserting the allocated pointer into the descriptor. + static void setAllocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, Type elemPtrPtrType, + Value allocatedPtr); + + /// Builds IR extracting the aligned pointer from the descriptor. + static Value alignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + Type elemPtrPtrType); + /// Builds IR inserting the aligned pointer into the descriptor. + static void setAlignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, Type elemPtrPtrType, + Value alignedPtr); + + /// Builds IR extracting the offset from the descriptor. + static Value offset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + Type elemPtrPtrType); + /// Builds IR inserting the offset into the descriptor. + static void setOffset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + Type elemPtrPtrType, Value offset); + + /// Builds IR extracting the pointer to the first element of the size array. + static Value sizeBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMPointerType elemPtrPtrType); + /// Builds IR extracting the size[index] from the descriptor. NOTE(review): this and the following size/stride accessors take `typeConverter` by value (a copy) unlike the by-reference accessors above; confirm this is intended.
+ static Value size(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value sizeBasePtr, + Value index); + /// Builds IR inserting the size[index] into the descriptor. + static void setSize(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value sizeBasePtr, + Value index, Value size); + + /// Builds IR extracting the pointer to the first element of the stride array. + static Value strideBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank); + /// Builds IR extracting the stride[index] from the descriptor. NOTE(review): unlike `size`, this getter also takes a trailing `stride` value, which looks unneeded for a read; confirm against the definition. + static Value stride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value strideBasePtr, + Value index, Value stride); + /// Builds IR inserting the stride[index] into the descriptor. + static void setStride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value strideBasePtr, + Value index, Value stride); +}; + +} // namespace mlir + +#endif // MLIR_CONVERSION_LLVMCOMMON_MEMREFBUILDER_H diff --git a/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h b/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/LLVMCommon/StructBuilder.h @@ -0,0 +1,51 @@ +//===- StructBuilder.h - Helper for building LLVM structs -------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a convenience API for emitting IR that inspects or constructs values +// of LLVM dialect structure types.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_LLVMCOMMON_STRUCTBUILDER_H +#define MLIR_CONVERSION_LLVMCOMMON_STRUCTBUILDER_H + +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" + +namespace mlir { + +class OpBuilder; + +/// Helper class to produce LLVM dialect operations extracting or inserting +/// values into a struct. +class StructBuilder { +public: + /// Construct a helper for the given value. + explicit StructBuilder(Value v); + /// Builds IR creating an `undef` value of the descriptor type. + static StructBuilder undef(OpBuilder &builder, Location loc, + Type descriptorType); + + /*implicit*/ operator Value() { return value; } + +protected: + /// The wrapped LLVM struct value. + Value value; + /// Cached struct type. + Type structType; + +protected: + /// Builds IR to extract a value from the struct at position `pos`. + Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); + /// Builds IR to set a value in the struct at position `pos`. + void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); +}; + +} // namespace mlir + +#endif // MLIR_CONVERSION_LLVMCOMMON_STRUCTBUILDER_H diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -0,0 +1,227 @@ +//===- TypeConverter.h - Convert builtin to LLVM dialect types --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides a type converter configuration for converting most builtin types to +// LLVM dialect types.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H +#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H + +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +class DataLayoutAnalysis; +class LowerToLLVMOptions; + +namespace LLVM { +class LLVMDialect; +} // namespace LLVM + +/// Conversion from builtin types to LLVM IR dialect types. +class LLVMTypeConverter : public TypeConverter { + /// Give structFuncArgTypeConverter access to memref-specific functions. + friend LogicalResult + structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, + SmallVectorImpl &result); + +public: + using TypeConverter::convertType; + + /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. + /// Optionally takes a data layout analysis to use in conversions. + LLVMTypeConverter(MLIRContext *ctx, + const DataLayoutAnalysis *analysis = nullptr); + + /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally + /// takes a data layout analysis to use in conversions. + LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options, + const DataLayoutAnalysis *analysis = nullptr); + + /// Convert a function type. The arguments and results are converted one by + /// one and results are packed into a wrapped LLVM IR structure type. `result` + /// is populated with argument mapping. + Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, + SignatureConversion &result); + + /// Convert a non-empty list of types to be returned from a function into a + /// supported LLVM IR type. In particular, if more than one value is + /// returned, create an LLVM IR structure type with elements that correspond + /// to each of the MLIR types converted with `convertType`.
+ Type packFunctionResults(TypeRange types); + + /// Convert a type in the context of the default or bare pointer calling + /// convention. Calling convention sensitive types, such as MemRefType and + /// UnrankedMemRefType, are converted following the specific rules for the + /// calling convention. Calling convention independent types are converted + /// following the default LLVM type conversions. + Type convertCallingConventionType(Type type); + + /// Promote the bare pointers in 'values' that resulted from memrefs to + /// descriptors. 'stdTypes' holds the types of 'values' before the conversion + /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). + void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, + Location loc, ArrayRef stdTypes, + SmallVectorImpl &values); + + /// Returns the MLIR context. + MLIRContext &getContext(); + + /// Returns the LLVM dialect. + LLVM::LLVMDialect *getDialect() { return llvmDialect; } + + const LowerToLLVMOptions &getOptions() const { return options; } + + /// Promote the LLVM representation of all operands including promoting MemRef + /// descriptors to stack and use pointers to struct to avoid the complexity + /// of the platform-specific C/C++ ABI lowering related to struct argument + /// passing. + SmallVector promoteOperands(Location loc, ValueRange opOperands, + ValueRange operands, + OpBuilder &builder); + + /// Promote the LLVM struct representation of one MemRef descriptor to stack + /// and use pointer to struct to avoid the complexity of the platform-specific + /// C/C++ ABI lowering related to struct argument passing. + Value promoteOneMemRefDescriptor(Location loc, Value operand, + OpBuilder &builder); + + /// Converts the function type to a C-compatible format, in particular using + /// pointers to memref descriptors for arguments. Also converts the return + /// type to a pointer argument if it is a struct. Returns true if this + /// was the case.
+ std::pair convertFunctionTypeCWrapper(FunctionType type); + + /// Returns the data layout to use during and after conversion. + const llvm::DataLayout &getDataLayout() { return options.dataLayout; } + + /// Returns the data layout analysis to query during conversion. + const DataLayoutAnalysis *getDataLayoutAnalysis() const { + return dataLayoutAnalysis; + } + + /// Gets the LLVM representation of the index type. The returned type is an + /// integer type with the size configured for this type converter. + Type getIndexType(); + + /// Gets the bitwidth of the index type when converted to LLVM. + unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); } + + /// Gets the pointer bitwidth for the given address space. + unsigned getPointerBitwidth(unsigned addressSpace = 0); + + /// Returns the size of the memref descriptor object in bytes. + unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout); + + /// Returns the size of the unranked memref descriptor object in bytes. + unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, + const DataLayout &layout); + +protected: + /// Pointer to the LLVM dialect. + LLVM::LLVMDialect *llvmDialect; + +private: + /// Convert a function type. The arguments and results are converted one by + /// one. Additionally, if the function returns more than one value, pack the + /// results into an LLVM IR structure type so that the converted function type + /// returns at most one result. + Type convertFunctionType(FunctionType type); + + /// Convert the index type. Uses llvmModule data layout to create an integer + /// of the pointer bitwidth. + Type convertIndexType(IndexType type); + + /// Convert an integer type `i*` to `!llvm<"i*">`. + Type convertIntegerType(IntegerType type); + + /// Convert a floating point type: `f16` to `f16`, `f32` to + /// `f32` and `f64` to `f64`. `bf16` is not supported + /// by LLVM.
+ Type convertFloatType(FloatType type); + + /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`, + /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to + /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported. + Type convertComplexType(ComplexType type); + + /// Convert a memref type into an LLVM type that captures the relevant data. + Type convertMemRefType(MemRefType type); + + /// Convert a memref type into a list of LLVM IR types that will form the + /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` + /// arrays in the descriptors are unpacked to individual index-typed elements, + /// else they are kept as rank-sized arrays of index type. In particular, + /// the list will contain: + /// - two pointers to the memref element type, followed by + /// - an index-typed offset, followed by + /// - (if unpackAggregates = true) + /// - one index-typed size per dimension of the memref, followed by + /// - one index-typed stride per dimension of the memref. + /// - (if unpackAggregates = false) + /// - one rank-sized array of index-type for the size of each dimension + /// - one rank-sized array of index-type for the stride of each dimension + /// + /// For example, memref<?x?xf32> is converted to the following list: + /// - `!llvm<"float*">` (allocated pointer), + /// - `!llvm<"float*">` (aligned pointer), + /// - `i64` (offset), + /// - `i64`, `i64` (sizes), + /// - `i64`, `i64` (strides). + /// These types can be recomposed to a memref descriptor struct. + SmallVector getMemRefDescriptorFields(MemRefType type, + bool unpackAggregates); + + /// Convert an unranked memref type into a list of non-aggregate LLVM IR types + /// that will form the unranked memref descriptor. In particular, this list + /// contains: + /// - an integer rank, followed by + /// - a pointer to the memref descriptor struct.
+ /// For example, memref<*xf32> is converted to the following list: + /// i64 (rank) + /// !llvm<"i8*"> (type-erased pointer). + /// These types can be recomposed to an unranked memref descriptor struct. + SmallVector getUnrankedMemRefDescriptorFields(); + + /// Convert an unranked memref type to an LLVM type that captures the + /// runtime rank and a pointer to the static ranked memref descriptor. + Type convertUnrankedMemRefType(UnrankedMemRefType type); + + /// Convert a memref type to a bare pointer to the memref element type. + Type convertMemRefToBarePtr(BaseMemRefType type); + + /// Convert a 1D vector type into an LLVM vector type. + Type convertVectorType(VectorType type); + + /// Options for customizing the LLVM lowering. + LowerToLLVMOptions options; + + /// Data layout analysis mapping scopes to layouts active in them. + const DataLayoutAnalysis *dataLayoutAnalysis; +}; + +/// Callback to convert function argument types. It converts a MemRef function +/// argument to a list of non-aggregate types containing descriptor +/// information, and an UnrankedMemRef function argument to a list containing +/// the rank and a pointer to a descriptor struct. +LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, + Type type, + SmallVectorImpl &result); + +/// Callback to convert function argument types. It converts MemRef function +/// arguments to bare pointers to the MemRef element type.
+LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, + Type type, + SmallVectorImpl &result); + +} // namespace mlir + +#endif // MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -15,6 +15,8 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Transforms/DialectConversion.h" @@ -38,458 +40,7 @@ class LLVMPointerType; } // namespace LLVM -/// Callback to convert function argument types. It converts a MemRef function -/// argument to a list of non-aggregate types containing descriptor -/// information, and an UnrankedmemRef function argument to a list containing -/// the rank and a pointer to a descriptor struct. -LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, - Type type, - SmallVectorImpl &result); - -/// Callback to convert function argument types. It converts MemRef function -/// arguments to bare pointers to the MemRef element type. -LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, - Type type, - SmallVectorImpl &result); - -/// Conversion from types in the Standard dialect to the LLVM IR dialect. -class LLVMTypeConverter : public TypeConverter { - /// Give structFuncArgTypeConverter access to memref-specific functions. 
- friend LogicalResult - structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, - SmallVectorImpl &result); - -public: - using TypeConverter::convertType; - - /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. - /// Optionally takes a data layout analysis to use in conversions. - LLVMTypeConverter(MLIRContext *ctx, - const DataLayoutAnalysis *analysis = nullptr); - - /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. Optionally - /// takes a data layout analysis to use in conversions. - LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options, - const DataLayoutAnalysis *analysis = nullptr); - - /// Convert a function type. The arguments and results are converted one by - /// one and results are packed into a wrapped LLVM IR structure type. `result` - /// is populated with argument mapping. - Type convertFunctionSignature(FunctionType funcTy, bool isVariadic, - SignatureConversion &result); - - /// Convert a non-empty list of types to be returned from a function into a - /// supported LLVM IR type. In particular, if more than one value is - /// returned, create an LLVM IR structure type with elements that correspond - /// to each of the MLIR types converted with `convertType`. - Type packFunctionResults(TypeRange types); - - /// Convert a type in the context of the default or bare pointer calling - /// convention. Calling convention sensitive types, such as MemRefType and - /// UnrankedMemRefType, are converted following the specific rules for the - /// calling convention. Calling convention independent types are converted - /// following the default LLVM type conversions. - Type convertCallingConventionType(Type type); - - /// Promote the bare pointers in 'values' that resulted from memrefs to - /// descriptors. 'stdTypes' holds the types of 'values' before the conversion - /// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 
- void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, - Location loc, ArrayRef stdTypes, - SmallVectorImpl &values); - - /// Returns the MLIR context. - MLIRContext &getContext(); - - /// Returns the LLVM dialect. - LLVM::LLVMDialect *getDialect() { return llvmDialect; } - - const LowerToLLVMOptions &getOptions() const { return options; } - - /// Promote the LLVM representation of all operands including promoting MemRef - /// descriptors to stack and use pointers to struct to avoid the complexity - /// of the platform-specific C/C++ ABI lowering related to struct argument - /// passing. - SmallVector promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, - OpBuilder &builder); - - /// Promote the LLVM struct representation of one MemRef descriptor to stack - /// and use pointer to struct to avoid the complexity of the platform-specific - /// C/C++ ABI lowering related to struct argument passing. - Value promoteOneMemRefDescriptor(Location loc, Value operand, - OpBuilder &builder); - - /// Converts the function type to a C-compatible format, in particular using - /// pointers to memref descriptors for arguments. Also converts the return - /// type to a pointer argument if it is a struct. Returns true if this - /// was the case. - std::pair convertFunctionTypeCWrapper(FunctionType type); - - /// Returns the data layout to use during and after conversion. - const llvm::DataLayout &getDataLayout() { return options.dataLayout; } - - /// Returns the data layout analysis to query during conversion. - const DataLayoutAnalysis *getDataLayoutAnalysis() const { - return dataLayoutAnalysis; - } - - /// Gets the LLVM representation of the index type. The returned type is an - /// integer type with the size configured for this type converter. - Type getIndexType(); - - /// Gets the bitwidth of the index type when converted to LLVM. - unsigned getIndexTypeBitwidth() { return options.getIndexBitwidth(); } - - /// Gets the pointer bitwidth. 
- unsigned getPointerBitwidth(unsigned addressSpace = 0); - - /// Returns the size of the memref descriptor object in bytes. - unsigned getMemRefDescriptorSize(MemRefType type, const DataLayout &layout); - - /// Returns the size of the unranked memref descriptor object in bytes. - unsigned getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, - const DataLayout &layout); - -protected: - /// Pointer to the LLVM dialect. - LLVM::LLVMDialect *llvmDialect; - -private: - /// Convert a function type. The arguments and results are converted one by - /// one. Additionally, if the function returns more than one value, pack the - /// results into an LLVM IR structure type so that the converted function type - /// returns at most one result. - Type convertFunctionType(FunctionType type); - - /// Convert the index type. Uses llvmModule data layout to create an integer - /// of the pointer bitwidth. - Type convertIndexType(IndexType type); - - /// Convert an integer type `i*` to `!llvm<"i*">`. - Type convertIntegerType(IntegerType type); - - /// Convert a floating point type: `f16` to `f16`, `f32` to - /// `f32` and `f64` to `f64`. `bf16` is not supported - /// by LLVM. - Type convertFloatType(FloatType type); - - /// Convert complex number type: `complex` to `!llvm<"{ half, half }">`, - /// `complex` to `!llvm<"{ float, float }">`, and `complex` to - /// `!llvm<"{ double, double }">`. `complex` is not supported. - Type convertComplexType(ComplexType type); - - /// Convert a memref type into an LLVM type that captures the relevant data. - Type convertMemRefType(MemRefType type); - - /// Convert a memref type into a list of LLVM IR types that will form the - /// memref descriptor. If `unpackAggregates` is true the `sizes` and `strides` - /// arrays in the descriptors are unpacked to individual index-typed elements, - /// else they are are kept as rank-sized arrays of index type. 
In particular, - /// the list will contain: - /// - two pointers to the memref element type, followed by - /// - an index-typed offset, followed by - /// - (if unpackAggregates = true) - /// - one index-typed size per dimension of the memref, followed by - /// - one index-typed stride per dimension of the memref. - /// - (if unpackArrregates = false) - /// - one rank-sized array of index-type for the size of each dimension - /// - one rank-sized array of index-type for the stride of each dimension - /// - /// For example, memref is converted to the following list: - /// - `!llvm<"float*">` (allocated pointer), - /// - `!llvm<"float*">` (aligned pointer), - /// - `i64` (offset), - /// - `i64`, `i64` (sizes), - /// - `i64`, `i64` (strides). - /// These types can be recomposed to a memref descriptor struct. - SmallVector getMemRefDescriptorFields(MemRefType type, - bool unpackAggregates); - - /// Convert an unranked memref type into a list of non-aggregate LLVM IR types - /// that will form the unranked memref descriptor. In particular, this list - /// contains: - /// - an integer rank, followed by - /// - a pointer to the memref descriptor struct. - /// For example, memref<*xf32> is converted to the following list: - /// i64 (rank) - /// !llvm<"i8*"> (type-erased pointer). - /// These types can be recomposed to a unranked memref descriptor struct. - SmallVector getUnrankedMemRefDescriptorFields(); - - // Convert an unranked memref type to an LLVM type that captures the - // runtime rank and a pointer to the static ranked memref desc - Type convertUnrankedMemRefType(UnrankedMemRefType type); - - /// Convert a memref type to a bare pointer to the memref element type. - Type convertMemRefToBarePtr(BaseMemRefType type); - - /// Convert a 1D vector type into an LLVM vector type. - Type convertVectorType(VectorType type); - - /// Options for customizing the llvm lowering. - LowerToLLVMOptions options; - - /// Data layout analysis mapping scopes to layouts active in them. 
- const DataLayoutAnalysis *dataLayoutAnalysis; -}; - -/// Helper class to produce LLVM dialect operations extracting or inserting -/// values to a struct. -class StructBuilder { -public: - /// Construct a helper for the given value. - explicit StructBuilder(Value v); - /// Builds IR creating an `undef` value of the descriptor type. - static StructBuilder undef(OpBuilder &builder, Location loc, - Type descriptorType); - - /*implicit*/ operator Value() { return value; } - -protected: - // LLVM value - Value value; - // Cached struct type. - Type structType; - -protected: - /// Builds IR to extract a value from the struct at position pos - Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); - /// Builds IR to set a value in the struct at position pos - void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); -}; - -class ComplexStructBuilder : public StructBuilder { -public: - /// Construct a helper for the given complex number value. - using StructBuilder::StructBuilder; - /// Build IR creating an `undef` value of the complex number type. - static ComplexStructBuilder undef(OpBuilder &builder, Location loc, - Type type); - - // Build IR extracting the real value from the complex number struct. - Value real(OpBuilder &builder, Location loc); - // Build IR inserting the real value into the complex number struct. - void setReal(OpBuilder &builder, Location loc, Value real); - - // Build IR extracting the imaginary value from the complex number struct. - Value imaginary(OpBuilder &builder, Location loc); - // Build IR inserting the imaginary value into the complex number struct. - void setImaginary(OpBuilder &builder, Location loc, Value imaginary); -}; - -/// Helper class to produce LLVM dialect operations extracting or inserting -/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. -/// The Value may be null, in which case none of the operations are valid. 
-class MemRefDescriptor : public StructBuilder { -public: - /// Construct a helper for the given descriptor value. - explicit MemRefDescriptor(Value descriptor); - /// Builds IR creating an `undef` value of the descriptor type. - static MemRefDescriptor undef(OpBuilder &builder, Location loc, - Type descriptorType); - /// Builds IR creating a MemRef descriptor that represents `type` and - /// populates it with static shape and stride information extracted from the - /// type. - static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - MemRefType type, Value memory); - - /// Builds IR extracting the allocated pointer from the descriptor. - Value allocatedPtr(OpBuilder &builder, Location loc); - /// Builds IR inserting the allocated pointer into the descriptor. - void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); - - /// Builds IR extracting the aligned pointer from the descriptor. - Value alignedPtr(OpBuilder &builder, Location loc); - - /// Builds IR inserting the aligned pointer into the descriptor. - void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); - - /// Builds IR extracting the offset from the descriptor. - Value offset(OpBuilder &builder, Location loc); - - /// Builds IR inserting the offset into the descriptor. - void setOffset(OpBuilder &builder, Location loc, Value offset); - void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); - - /// Builds IR extracting the pos-th size from the descriptor. - Value size(OpBuilder &builder, Location loc, unsigned pos); - Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); - - /// Builds IR inserting the pos-th size into the descriptor - void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); - void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, - uint64_t size); - - /// Builds IR extracting the pos-th size from the descriptor. 
- Value stride(OpBuilder &builder, Location loc, unsigned pos); - - /// Builds IR inserting the pos-th stride into the descriptor - void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); - void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, - uint64_t stride); - - /// Returns the (LLVM) pointer type this descriptor contains. - LLVM::LLVMPointerType getElementPtrType(); - - /// Builds IR populating a MemRef descriptor structure from a list of - /// individual values composing that descriptor, in the following order: - /// - allocated pointer; - /// - aligned pointer; - /// - offset; - /// - sizes; - /// - shapes; - /// where is the MemRef rank as provided in `type`. - static Value pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, MemRefType type, - ValueRange values); - - /// Builds IR extracting individual elements of a MemRef descriptor structure - /// and returning them as `results` list. - static void unpack(OpBuilder &builder, Location loc, Value packed, - MemRefType type, SmallVectorImpl &results); - - /// Returns the number of non-aggregate values that would be produced by - /// `unpack`. - static unsigned getNumUnpackedValues(MemRefType type); - -private: - // Cached index type. - Type indexType; -}; - -/// Helper class allowing the user to access a range of Values that correspond -/// to an unpacked memref descriptor using named accessors. This does not own -/// the values. -class MemRefDescriptorView { -public: - /// Constructs the view from a range of values. Infers the rank from the size - /// of the range. - explicit MemRefDescriptorView(ValueRange range); - - /// Returns the allocated pointer Value. - Value allocatedPtr(); - - /// Returns the aligned pointer Value. - Value alignedPtr(); - - /// Returns the offset Value. - Value offset(); - - /// Returns the pos-th size Value. - Value size(unsigned pos); - - /// Returns the pos-th stride Value. 
- Value stride(unsigned pos); - -private: - /// Rank of the memref the descriptor is pointing to. - int rank; - /// Underlying range of Values. - ValueRange elements; -}; - -class UnrankedMemRefDescriptor : public StructBuilder { -public: - /// Construct a helper for the given descriptor value. - explicit UnrankedMemRefDescriptor(Value descriptor); - /// Builds IR creating an `undef` value of the descriptor type. - static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, - Type descriptorType); - - /// Builds IR extracting the rank from the descriptor - Value rank(OpBuilder &builder, Location loc); - /// Builds IR setting the rank in the descriptor - void setRank(OpBuilder &builder, Location loc, Value value); - /// Builds IR extracting ranked memref descriptor ptr - Value memRefDescPtr(OpBuilder &builder, Location loc); - /// Builds IR setting ranked memref descriptor ptr - void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); - - /// Builds IR populating an unranked MemRef descriptor structure from a list - /// of individual constituent values in the following order: - /// - rank of the memref; - /// - pointer to the memref descriptor. - static Value pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, UnrankedMemRefType type, - ValueRange values); - - /// Builds IR extracting individual elements that compose an unranked memref - /// descriptor and returns them as `results` list. - static void unpack(OpBuilder &builder, Location loc, Value packed, - SmallVectorImpl &results); - - /// Returns the number of non-aggregate values that would be produced by - /// `unpack`. - static unsigned getNumUnpackedValues() { return 2; } - - /// Builds IR computing the sizes in bytes (suitable for opaque allocation) - /// and appends the corresponding values into `sizes`. 
- static void computeSizes(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - ArrayRef values, - SmallVectorImpl &sizes); - - /// TODO: The following accessors don't take alignment rules between elements - /// of the descriptor struct into account. For some architectures, it might be - /// necessary to extend them and to use `llvm::DataLayout` contained in - /// `LLVMTypeConverter`. - - /// Builds IR extracting the allocated pointer from the descriptor. - static Value allocatedPtr(OpBuilder &builder, Location loc, - Value memRefDescPtr, Type elemPtrPtrType); - /// Builds IR inserting the allocated pointer into the descriptor. - static void setAllocatedPtr(OpBuilder &builder, Location loc, - Value memRefDescPtr, Type elemPtrPtrType, - Value allocatedPtr); - - /// Builds IR extracting the aligned pointer from the descriptor. - static Value alignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value memRefDescPtr, - Type elemPtrPtrType); - /// Builds IR inserting the aligned pointer into the descriptor. - static void setAlignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, Type elemPtrPtrType, - Value alignedPtr); - - /// Builds IR extracting the offset from the descriptor. - static Value offset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value memRefDescPtr, - Type elemPtrPtrType); - /// Builds IR inserting the offset into the descriptor. - static void setOffset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, Value memRefDescPtr, - Type elemPtrPtrType, Value offset); - - /// Builds IR extracting the pointer to the first element of the size array. - static Value sizeBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - LLVM::LLVMPointerType elemPtrPtrType); - /// Builds IR extracting the size[index] from the descriptor. 
- static Value size(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, Value sizeBasePtr, - Value index); - /// Builds IR inserting the size[index] into the descriptor. - static void setSize(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, Value sizeBasePtr, - Value index, Value size); - - /// Builds IR extracting the pointer to the first element of the stride array. - static Value strideBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value sizeBasePtr, Value rank); - /// Builds IR extracting the stride[index] from the descriptor. - static Value stride(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, Value strideBasePtr, - Value index, Value stride); - /// Builds IR inserting the stride[index] into the descriptor. - static void setStride(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, Value strideBasePtr, - Value index, Value stride); -}; +// ------------------ /// Base class for operation conversions targeting the LLVM IR dialect. 
It /// provides the conversion patterns with access to the LLVMTypeConverter and diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -9,67 +9,17 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVMPASS_H_ -#include "llvm/IR/DataLayout.h" - #include namespace mlir { -class DataLayout; class LLVMTypeConverter; -class MLIRContext; +class LowerToLLVMOptions; class ModuleOp; template class OperationPass; class RewritePatternSet; using OwningRewritePatternList = RewritePatternSet; -/// Value to pass as bitwidth for the index type when the converter is expected -/// to derive the bitwidth from the LLVM data layout. -static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0; - -/// Options to control the Standard dialect to LLVM lowering. The struct is used -/// to share lowering options between passes, patterns, and type converter. -class LowerToLLVMOptions { -public: - explicit LowerToLLVMOptions(MLIRContext *ctx); - explicit LowerToLLVMOptions(MLIRContext *ctx, const DataLayout &dl); - - bool useBarePtrCallConv = false; - bool emitCWrappers = false; - - enum class AllocLowering { - /// Use malloc for for heap allocations. - Malloc, - - /// Use aligned_alloc for heap allocations. - AlignedAlloc, - - /// Do not lower heap allocations. Users must provide their own patterns for - /// AllocOp and DeallocOp lowering. - None - }; - - AllocLowering allocLowering = AllocLowering::Malloc; - - /// The data layout of the module to produce. This must be consistent with the - /// data layout used in the upper levels of the lowering pipeline. - // TODO: this should be replaced by MLIR data layout when one exists. 
- llvm::DataLayout dataLayout = llvm::DataLayout(""); - - /// Set the index bitwidth to the given value. - void overrideIndexBitwidth(unsigned bitwidth) { - assert(bitwidth != kDeriveIndexBitwidthFromDataLayout && - "can only override to a concrete bitwidth"); - indexBitwidth = bitwidth; - } - - /// Get the index bitwidth. - unsigned getIndexBitwidth() const { return indexBitwidth; } - -private: - unsigned indexBitwidth; -}; - /// Collect a set of patterns to convert memory-related operations from the /// Standard dialect to the LLVM dialect, excluding non-memory-related /// operations and FuncOp. diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -11,6 +11,7 @@ add_subdirectory(LinalgToLLVM) add_subdirectory(LinalgToSPIRV) add_subdirectory(LinalgToStandard) +add_subdirectory(LLVMCommon) add_subdirectory(MathToLibm) add_subdirectory(OpenACCToLLVM) add_subdirectory(OpenACCToSCF) diff --git a/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/ComplexToLLVM/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRComplex + MLIRLLVMCommonConversion MLIRLLVMIR MLIRStandardOpsTransforms MLIRStandardToLLVM diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp --- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp +++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp @@ -9,12 +9,48 @@ #include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" #include "../PassDetail.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" using namespace mlir; using namespace mlir::LLVM; +//===----------------------------------------------------------------------===// +// 
ComplexStructBuilder implementation. +//===----------------------------------------------------------------------===// + +static constexpr unsigned kRealPosInComplexNumberStruct = 0; +static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; + +ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, + Location loc, Type type) { + Value val = builder.create(loc, type); + return ComplexStructBuilder(val); +} + +void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, + Value real) { + setPtr(builder, loc, kRealPosInComplexNumberStruct, real); +} + +Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kRealPosInComplexNumberStruct); +} + +void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, + Value imaginary) { + setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); +} + +Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); +} + +//===----------------------------------------------------------------------===// +// Conversion patterns. 
+//===----------------------------------------------------------------------===// + namespace { struct AbsOpConversion : public ConvertOpToLLVMPattern { diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRGPUOps MLIRGPUToGPURuntimeTransforms + MLIRLLVMCommonConversion MLIRLLVMIR MLIRMemRef MLIRNVVMIR diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -13,7 +13,9 @@ #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Dialect/GPU/GPUDialect.h" #include "mlir/Dialect/GPU/Passes.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt @@ -12,6 +12,7 @@ LINK_LIBS PUBLIC MLIRGPUOps MLIRGPUToGPURuntimeTransforms + MLIRLLVMCommonConversion MLIRLLVMIR MLIRROCDLIR MLIRPass diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -13,7 +13,9 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include 
"mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h" #include "mlir/Dialect/GPU/GPUDialect.h" diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_conversion_library(MLIRLLVMCommonConversion + LoweringOptions.cpp + MemRefBuilder.cpp + StructBuilder.cpp + TypeConverter.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSupport + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/LLVMCommon/LoweringOptions.cpp b/mlir/lib/Conversion/LLVMCommon/LoweringOptions.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/LoweringOptions.cpp @@ -0,0 +1,21 @@ +//===- LoweringOptions.cpp - Common config for lowering to LLVM ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" + +using namespace mlir; + +mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx) + : LowerToLLVMOptions(ctx, DataLayout()) {} + +mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx, + const DataLayout &dl) { + indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx)); +} diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -0,0 +1,525 @@ +//===- MemRefBuilder.cpp - Helper for LLVM MemRef equivalents -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "MemRefDescriptor.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/Support/MathExtras.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// MemRefDescriptor implementation +//===----------------------------------------------------------------------===// + +/// Construct a helper for the given descriptor value. 
+MemRefDescriptor::MemRefDescriptor(Value descriptor) + : StructBuilder(descriptor) { + assert(value != nullptr && "value cannot be null"); + indexType = value.getType() + .cast() + .getBody()[kOffsetPosInMemRefDescriptor]; +} + +/// Builds IR creating an `undef` value of the descriptor type. +MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, + Type descriptorType) { + + Value descriptor = builder.create(loc, descriptorType); + return MemRefDescriptor(descriptor); +} + +/// Builds IR creating a MemRef descriptor that represents `type` and +/// populates it with static shape and stride information extracted from the +/// type. +MemRefDescriptor +MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + MemRefType type, Value memory) { + assert(type.hasStaticShape() && "unexpected dynamic shape"); + + // Extract all strides and offsets and verify they are static. + int64_t offset; + SmallVector strides; + auto result = getStridesAndOffset(type, strides, offset); + (void)result; + assert(succeeded(result) && "unexpected failure in stride computation"); + assert(!MemRefType::isDynamicStrideOrOffset(offset) && + "expected static offset"); + assert(!llvm::any_of(strides, [](int64_t stride) { + return MemRefType::isDynamicStrideOrOffset(stride); + }) && "expected static strides"); + + auto convertedType = typeConverter.convertType(type); + assert(convertedType && "unexpected failure in memref type conversion"); + + auto descr = MemRefDescriptor::undef(builder, loc, convertedType); + descr.setAllocatedPtr(builder, loc, memory); + descr.setAlignedPtr(builder, loc, memory); + descr.setConstantOffset(builder, loc, offset); + + // Fill in sizes and strides + for (unsigned i = 0, e = type.getRank(); i != e; ++i) { + descr.setConstantSize(builder, loc, i, type.getDimSize(i)); + descr.setConstantStride(builder, loc, i, strides[i]); + } + return descr; +} + +/// Builds IR extracting the allocated pointer from 
the descriptor. +Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); +} + +/// Builds IR inserting the allocated pointer into the descriptor. +void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value ptr) { + setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); +} + +/// Builds IR extracting the aligned pointer from the descriptor. +Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); +} + +/// Builds IR inserting the aligned pointer into the descriptor. +void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + Value ptr) { + setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); +} + +// Creates a constant Op producing a value of `resultType` from an index-typed +// integer attribute. +static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + +/// Builds IR extracting the offset from the descriptor. +Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { + return builder.create( + loc, indexType, value, + builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); +} + +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + Value offset) { + value = builder.create( + loc, structType, value, offset, + builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); +} + +/// Builds IR inserting the offset into the descriptor. +void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, + uint64_t offset) { + setOffset(builder, loc, + createIndexAttrConstant(builder, loc, indexType, offset)); +} + +/// Builds IR extracting the pos-th size from the descriptor. 
+Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create( + loc, indexType, value, + builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); +} + +Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, + int64_t rank) { + auto indexPtrTy = LLVM::LLVMPointerType::get(indexType); + auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); + auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); + + // Copy size values to stack-allocated memory. + auto zero = createIndexAttrConstant(builder, loc, indexType, 0); + auto one = createIndexAttrConstant(builder, loc, indexType, 1); + auto sizes = builder.create( + loc, arrayTy, value, + builder.getI64ArrayAttr({kSizePosInMemRefDescriptor})); + auto sizesPtr = + builder.create(loc, arrayPtrTy, one, /*alignment=*/0); + builder.create(loc, sizes, sizesPtr); + + // Load and return size value of interest. + auto resultPtr = builder.create(loc, indexPtrTy, sizesPtr, + ValueRange({zero, pos})); + return builder.create(loc, resultPtr); +} + +/// Builds IR inserting the pos-th size into the descriptor +void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, + Value size) { + value = builder.create( + loc, structType, value, size, + builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); +} + +void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, + unsigned pos, uint64_t size) { + setSize(builder, loc, pos, + createIndexAttrConstant(builder, loc, indexType, size)); +} + +/// Builds IR extracting the pos-th stride from the descriptor.
+Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { + return builder.create( + loc, indexType, value, + builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); +} + +/// Builds IR inserting the pos-th stride into the descriptor +void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, + Value stride) { + value = builder.create( + loc, structType, value, stride, + builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); +} + +void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, + unsigned pos, uint64_t stride) { + setStride(builder, loc, pos, + createIndexAttrConstant(builder, loc, indexType, stride)); +} + +LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { + return value.getType() + .cast() + .getBody()[kAlignedPtrPosInMemRefDescriptor] + .cast(); +} + +/// Creates a MemRef descriptor structure from a list of individual values +/// composing that descriptor, in the following order: +/// - allocated pointer; +/// - aligned pointer; +/// - offset; +/// - sizes; +/// - strides; +/// where is the MemRef rank as provided in `type`. +Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, + LLVMTypeConverter &converter, MemRefType type, + ValueRange values) { + Type llvmType = converter.convertType(type); + auto d = MemRefDescriptor::undef(builder, loc, llvmType); + + d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); + d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); + d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); + + int64_t rank = type.getRank(); + for (unsigned i = 0; i < rank; ++i) { + d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); + d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); + } + + return d; +} + +/// Builds IR extracting individual elements of a MemRef descriptor structure +/// and returning them as `results` list.
+void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, + MemRefType type, + SmallVectorImpl &results) { + int64_t rank = type.getRank(); + results.reserve(results.size() + getNumUnpackedValues(type)); + + MemRefDescriptor d(packed); + results.push_back(d.allocatedPtr(builder, loc)); + results.push_back(d.alignedPtr(builder, loc)); + results.push_back(d.offset(builder, loc)); + for (int64_t i = 0; i < rank; ++i) + results.push_back(d.size(builder, loc, i)); + for (int64_t i = 0; i < rank; ++i) + results.push_back(d.stride(builder, loc, i)); +} + +/// Returns the number of non-aggregate values that would be produced by +/// `unpack`. +unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { + // Two pointers, offset, `rank` sizes and `rank` strides. + return 3 + 2 * type.getRank(); +} + +//===----------------------------------------------------------------------===// +// MemRefDescriptorView implementation. +//===----------------------------------------------------------------------===// + +/// Wraps a flat value list in the layout produced by `MemRefDescriptor::unpack` +/// (two pointers, offset, then `rank` sizes followed by `rank` strides). The +/// rank is recovered from the list length. +/// NOTE(review): the subtraction below is done on an unsigned count (TODO +/// confirm type of `range.size()`), so a list shorter than +/// kSizePosInMemRefDescriptor would wrap around; presumably callers always +/// pass a complete list. +MemRefDescriptorView::MemRefDescriptorView(ValueRange range) + : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} + +Value MemRefDescriptorView::allocatedPtr() { + return elements[kAllocatedPtrPosInMemRefDescriptor]; +} + +Value MemRefDescriptorView::alignedPtr() { + return elements[kAlignedPtrPosInMemRefDescriptor]; +} + +Value MemRefDescriptorView::offset() { + return elements[kOffsetPosInMemRefDescriptor]; +} + +Value MemRefDescriptorView::size(unsigned pos) { + return elements[kSizePosInMemRefDescriptor + pos]; +} + +/// Strides follow all `rank` sizes in the flat list, so they are indexed from +/// kSizePosInMemRefDescriptor + rank; kStridePosInMemRefDescriptor is a field +/// index in the packed struct and does not apply here. +Value MemRefDescriptorView::stride(unsigned pos) { + return elements[kSizePosInMemRefDescriptor + rank + pos]; +} + +//===----------------------------------------------------------------------===// +// UnrankedMemRefDescriptor implementation +//===----------------------------------------------------------------------===// + +/// Construct a helper for the given descriptor value.
+UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) + : StructBuilder(descriptor) {} + +/// Builds IR creating an `undef` value of the descriptor type. +UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, + Location loc, + Type descriptorType) { + Value descriptor = builder.create(loc, descriptorType); + return UnrankedMemRefDescriptor(descriptor); +} +/// Builds IR extracting the rank field. `extractPtr`/`setPtr` are generic +/// StructBuilder field accessors, so they also serve this integer field. +Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { + return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); +} +/// Builds IR inserting `v` as the rank field. +void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, + Value v) { + setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); +} +/// Builds IR extracting the pointer to the underlying ranked descriptor. +Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, + Location loc) { + return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); +} +/// Builds IR inserting `v` as the pointer to the underlying ranked descriptor. +void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, + Location loc, Value v) { + setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); +} + +/// Builds IR populating an unranked MemRef descriptor structure from a list +/// of individual constituent values in the following order: +/// - rank of the memref; +/// - pointer to the memref descriptor. +/// This is the inverse of `unpack`. +Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, + LLVMTypeConverter &converter, + UnrankedMemRefType type, + ValueRange values) { + Type llvmType = converter.convertType(type); + auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType); + + d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); + d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); + return d; +} + +/// Builds IR extracting individual elements that compose an unranked memref +/// descriptor and returns them as `results` list.
+void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, + Value packed, + SmallVectorImpl &results) { + UnrankedMemRefDescriptor d(packed); + results.reserve(results.size() + 2); + results.push_back(d.rank(builder, loc)); + results.push_back(d.memRefDescPtr(builder, loc)); +} + +void UnrankedMemRefDescriptor::computeSizes( + OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + ArrayRef values, SmallVectorImpl &sizes) { + if (values.empty()) + return; + + // Cache the index type. + Type indexType = typeConverter.getIndexType(); + + // Initialize shared constants. + Value one = createIndexAttrConstant(builder, loc, indexType, 1); + Value two = createIndexAttrConstant(builder, loc, indexType, 2); + Value pointerSize = createIndexAttrConstant( + builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8)); + Value indexSize = + createIndexAttrConstant(builder, loc, indexType, + ceilDiv(typeConverter.getIndexTypeBitwidth(), 8)); + + sizes.reserve(sizes.size() + values.size()); + for (UnrankedMemRefDescriptor desc : values) { + // Emit IR computing the memory necessary to store the descriptor. This + // assumes the descriptor to be + // { type*, type*, index, index[rank], index[rank] } + // and densely packed, so the total size is + // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). + // TODO: consider including the actual size (including eventual padding due + // to data layout) into the unranked descriptor. + Value doublePointerSize = + builder.create(loc, indexType, two, pointerSize); + + // (1 + 2 * rank) * sizeof(index) + Value rank = desc.rank(builder, loc); + Value doubleRank = builder.create(loc, indexType, two, rank); + Value doubleRankIncremented = + builder.create(loc, indexType, doubleRank, one); + Value rankIndexSize = builder.create( + loc, indexType, doubleRankIncremented, indexSize); + + // Total allocation size. 
+ Value allocationSize = builder.create( + loc, indexType, doublePointerSize, rankIndexSize); + sizes.push_back(allocationSize); + } +} + +Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + Type elemPtrPtrType) { + + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + return builder.create(loc, elementPtrPtr); +} + +void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + Type elemPtrPtrType, + Value allocatedPtr) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + builder.create(loc, allocatedPtr, elementPtrPtr); +} + +Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + Type elemPtrPtrType) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value one = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); + Value alignedGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); + return builder.create(loc, alignedGep); +} + +void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + Type elemPtrPtrType, + Value alignedPtr) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value one = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); + Value alignedGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); + builder.create(loc, alignedPtr, alignedGep); +} + +Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + Type elemPtrPtrType) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value two = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); + Value offsetGep = 
builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); + offsetGep = builder.create( + loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); + return builder.create(loc, offsetGep); +} + +void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + Type elemPtrPtrType, Value offset) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value two = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); + Value offsetGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); + offsetGep = builder.create( + loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); + builder.create(loc, offset, offsetGep); +} + +Value UnrankedMemRefDescriptor::sizeBasePtr( + OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) { + Type elemPtrTy = elemPtrPtrType.getElementType(); + Type indexTy = typeConverter.getIndexType(); + Type structPtrTy = + LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral( + indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy})); + Value structPtr = + builder.create(loc, structPtrTy, memRefDescPtr); + + Type int32_type = typeConverter.convertType(builder.getI32Type()); + Value zero = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); + Value three = builder.create(loc, int32_type, + builder.getI32IntegerAttr(3)); + return builder.create(loc, LLVM::LLVMPointerType::get(indexTy), + structPtr, ValueRange({zero, three})); +} + +Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value sizeBasePtr, Value index) { + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({index})); + return 
builder.create(loc, sizeStoreGep); +} + +void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value sizeBasePtr, Value index, + Value size) { + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({index})); + builder.create(loc, size, sizeStoreGep); +} + +Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank) { + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + return builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({rank})); +} + +Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value strideBasePtr, Value index, + Value stride) { + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Value strideStoreGep = builder.create( + loc, indexPtrTy, strideBasePtr, ValueRange({index})); + return builder.create(loc, strideStoreGep); +} + +void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value strideBasePtr, Value index, + Value stride) { + Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); + Value strideStoreGep = builder.create( + loc, indexPtrTy, strideBasePtr, ValueRange({index})); + builder.create(loc, stride, strideStoreGep); +} diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefDescriptor.h b/mlir/lib/Conversion/LLVMCommon/MemRefDescriptor.h new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/MemRefDescriptor.h @@ -0,0 +1,25 @@ +//===- MemRefDescriptor.h - MemRef descriptor constants ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Defines field-position constants for the LLVM dialect structs that represent +// ranked and unranked MemRef descriptors. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_LIB_CONVERSION_LLVMCOMMON_MEMREFDESCRIPTOR_H +#define MLIR_LIB_CONVERSION_LLVMCOMMON_MEMREFDESCRIPTOR_H + +// Field positions in the packed ranked descriptor struct +// {allocated ptr, aligned ptr, offset, sizes array, strides array}. +// In the flattened form produced by MemRefDescriptor::unpack, sizes start at +// kSizePosInMemRefDescriptor and strides start at +// kSizePosInMemRefDescriptor + rank; kStridePosInMemRefDescriptor only indexes +// the packed struct. +static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; +static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; +static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; +static constexpr unsigned kSizePosInMemRefDescriptor = 3; +static constexpr unsigned kStridePosInMemRefDescriptor = 4; + +// Field positions in the unranked descriptor struct +// {rank, pointer to a ranked descriptor}. +static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; +static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; + +#endif // MLIR_LIB_CONVERSION_LLVMCOMMON_MEMREFDESCRIPTOR_H diff --git a/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/StructBuilder.cpp @@ -0,0 +1,36 @@ +//===- StructBuilder.cpp - Helper for building LLVM structs --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/StructBuilder.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// StructBuilder implementation +//===----------------------------------------------------------------------===// + +/// Wraps the struct-typed SSA value `v`. Asserts that it is non-null and of an +/// LLVM-dialect-compatible type. +StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) { + assert(value != nullptr && "value cannot be null"); + assert(LLVM::isCompatibleType(structType) && "expected llvm type"); +} + +/// Builds IR extracting field `pos` of the wrapped struct. The result type is +/// read from the struct body, so despite the name this works for any field type +/// (e.g. the integer rank of an unranked descriptor). +Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, + unsigned pos) { + Type type = structType.cast().getBody()[pos]; + return builder.create(loc, type, value, + builder.getI64ArrayAttr(pos)); +} + +/// Builds IR inserting `ptr` into field `pos`. Insertion yields a new struct +/// value, so the wrapped `value` is updated to it. +void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, + Value ptr) { + value = builder.create(loc, structType, value, ptr, + builder.getI64ArrayAttr(pos)); +} diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -0,0 +1,492 @@ +//===- TypeConverter.cpp - Convert builtin to LLVM dialect types ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "MemRefDescriptor.h" +#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" + +using namespace mlir; + +/// Create an LLVMTypeConverter using default LowerToLLVMOptions. +LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, + const DataLayoutAnalysis *analysis) + : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {} + +/// Create an LLVMTypeConverter using custom LowerToLLVMOptions. +/// NOTE(review): the conversion and materialization callbacks registered below +/// capture `this` (via `[&]`), so a copy of this converter still dispatches to +/// the object it was copied from; such copies must not outlive the original. +LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, + const LowerToLLVMOptions &options, + const DataLayoutAnalysis *analysis) + : llvmDialect(ctx->getOrLoadDialect()), options(options), + dataLayoutAnalysis(analysis) { + assert(llvmDialect && "LLVM IR dialect is not registered"); + + // Register conversions for the builtin types. + addConversion([&](ComplexType type) { return convertComplexType(type); }); + addConversion([&](FloatType type) { return convertFloatType(type); }); + addConversion([&](FunctionType type) { return convertFunctionType(type); }); + addConversion([&](IndexType type) { return convertIndexType(type); }); + addConversion([&](IntegerType type) { return convertIntegerType(type); }); + addConversion([&](MemRefType type) { return convertMemRefType(type); }); + addConversion( + [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); + addConversion([&](VectorType type) { return convertVectorType(type); }); + + // LLVM-compatible types are legal, so add a pass-through conversion. + addConversion([](Type type) { + return LLVM::isCompatibleType(type) ? llvm::Optional(type) + : llvm::None; + }); + + // Materialization for memrefs creates descriptor structs from individual + // values constituting them, when descriptors are used, i.e.
more than one + // value represents a memref. A single input is not a list of descriptor + // pieces, so these callbacks decline (return llvm::None) in that case. + addArgumentMaterialization( + [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, + inputs); + }); + addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); + }); + // Add generic source and target materializations to handle cases where + // non-LLVM types persist after an LLVM conversion. Both only handle a single + // input value. + addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) + return llvm::None; + // FIXME: These should check LLVM::DialectCastOp can actually be constructed + // from the input and result. + return builder.create(loc, resultType, inputs[0]) + .getResult(); + }); + addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() != 1) + return llvm::None; + // FIXME: These should check LLVM::DialectCastOp can actually be constructed + // from the input and result. + return builder.create(loc, resultType, inputs[0]) + .getResult(); + }); +} + +/// Returns the MLIR context.
+MLIRContext &LLVMTypeConverter::getContext() { + return *getDialect()->getContext(); +} + +Type LLVMTypeConverter::getIndexType() { + return IntegerType::get(&getContext(), getIndexTypeBitwidth()); +} + +unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { + return options.dataLayout.getPointerSizeInBits(addressSpace); +} + +Type LLVMTypeConverter::convertIndexType(IndexType type) { + return getIndexType(); +} + +Type LLVMTypeConverter::convertIntegerType(IntegerType type) { + return IntegerType::get(&getContext(), type.getWidth()); +} + +Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; } + +// Convert a `ComplexType` to an LLVM type. The result is a complex number +// struct with entries for the +// 1. real part and for the +// 2. imaginary part. +// Returns a null type (conversion failure) if the element type cannot be +// converted, rather than building a struct with null members. +Type LLVMTypeConverter::convertComplexType(ComplexType type) { + auto elementType = convertType(type.getElementType()); + if (!elementType) + return {}; + return LLVM::LLVMStructType::getLiteral(&getContext(), + {elementType, elementType}); +} + +// Except for signatures, MLIR function types are converted into LLVM +// pointer-to-function types. convertFunctionSignature signals failure with a +// null type, which must be propagated as-is instead of being wrapped into a +// pointer type. +Type LLVMTypeConverter::convertFunctionType(FunctionType type) { + SignatureConversion conversion(type.getNumInputs()); + Type converted = + convertFunctionSignature(type, /*isVariadic=*/false, conversion); + if (!converted) + return {}; + return LLVM::LLVMPointerType::get(converted); +} + +// Function types are converted to LLVM Function types by recursively converting +// argument and result types. If MLIR Function has zero results, the LLVM +// Function has one VoidType result. If MLIR Function has more than one result, +// they are packed into an LLVM StructType in their order of appearance. +Type LLVMTypeConverter::convertFunctionSignature( + FunctionType funcTy, bool isVariadic, + LLVMTypeConverter::SignatureConversion &result) { + // Select the argument converter depending on the calling convention. + auto funcArgConverter = options.useBarePtrCallConv + ?
barePtrFuncArgTypeConverter + : structFuncArgTypeConverter; + // Convert argument types one by one and check for errors. + for (auto &en : llvm::enumerate(funcTy.getInputs())) { + Type type = en.value(); + SmallVector converted; + if (failed(funcArgConverter(*this, type, converted))) + return {}; + result.addInputs(en.index(), converted); + } + + SmallVector argTypes; + argTypes.reserve(llvm::size(result.getConvertedTypes())); + for (Type type : result.getConvertedTypes()) + argTypes.push_back(type); + + // If function does not return anything, create the void result type, + // if it returns one element, convert it, otherwise pack the result types into + // a struct. + Type resultType = funcTy.getNumResults() == 0 + ? LLVM::LLVMVoidType::get(&getContext()) + : packFunctionResults(funcTy.getResults()); + if (!resultType) + return {}; + return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic); +} + +/// Converts the function type to a C-compatible format, in particular using +/// pointers to memref descriptors for arguments. +std::pair +LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { + SmallVector inputs; + bool resultIsNowArg = false; + + Type resultType = type.getNumResults() == 0 + ? LLVM::LLVMVoidType::get(&getContext()) + : packFunctionResults(type.getResults()); + if (!resultType) + return {}; + + if (auto structType = resultType.dyn_cast()) { + // Struct types cannot be safely returned via C interface. Make this a + // pointer argument, instead.
+ inputs.push_back(LLVM::LLVMPointerType::get(structType)); + resultType = LLVM::LLVMVoidType::get(&getContext()); + resultIsNowArg = true; + } + + for (Type t : type.getInputs()) { + auto converted = convertType(t); + if (!converted || !LLVM::isCompatibleType(converted)) + return {}; + if (t.isa()) + converted = LLVM::LLVMPointerType::get(converted); + inputs.push_back(converted); + } + + return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg}; +} + +/// Convert a memref type into a list of LLVM IR types that will form the +/// memref descriptor. The result contains the following types: +/// 1. The pointer to the allocated data buffer, followed by +/// 2. The pointer to the aligned data buffer, followed by +/// 3. A lowered `index`-type integer containing the distance between the +/// beginning of the buffer and the first element to be accessed through the +/// view, followed by +/// 4. An array containing as many `index`-type integers as the rank of the +/// MemRef: the array represents the size, in number of elements, of the memref +/// along the given dimension. For constant MemRef dimensions, the +/// corresponding size entry is a constant whose runtime value must match the +/// static value, followed by +/// 5. A second array containing as many `index`-type integers as the rank of +/// the MemRef: the second array represents the "stride" (in tensor abstraction +/// sense), i.e. the number of consecutive elements of the underlying buffer. +/// TODO: add assertions for the static cases. +/// +/// If `unpackAggregates` is set to true, the arrays described in (4) and (5) +/// are expanded into individual index-type elements.
+/// +/// template +/// struct { +/// Elem *allocatedPtr; +/// Elem *alignedPtr; +/// Index offset; +/// Index sizes[Rank]; // omitted when rank == 0 +/// Index strides[Rank]; // omitted when rank == 0 +/// }; +SmallVector +LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, + bool unpackAggregates) { + assert(isStrided(type) && + "Non-strided layout maps must have been normalized away"); + + Type elementType = convertType(type.getElementType()); + if (!elementType) + return {}; + auto ptrTy = + LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); + auto indexTy = getIndexType(); + + SmallVector results = {ptrTy, ptrTy, indexTy}; + auto rank = type.getRank(); + if (rank == 0) + return results; + + if (unpackAggregates) + results.insert(results.end(), 2 * rank, indexTy); + else + results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank)); + return results; +} + +unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, + const DataLayout &layout) { + // Compute the descriptor size given that of its components indicated above. + unsigned space = type.getMemorySpaceAsInt(); + return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) + + (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType()); +} + +/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that +/// packs the descriptor fields as defined by `getMemRefDescriptorFields`. +Type LLVMTypeConverter::convertMemRefType(MemRefType type) { + // When converting a MemRefType to a struct with descriptor fields, do not + // unpack the `sizes` and `strides` arrays. + SmallVector types = + getMemRefDescriptorFields(type, /*unpackAggregates=*/false); + if (types.empty()) + return {}; + return LLVM::LLVMStructType::getLiteral(&getContext(), types); +} + +/// Convert an unranked memref type into a list of non-aggregate LLVM IR types +/// that will form the unranked memref descriptor.
In particular, the fields +/// for an unranked memref descriptor are: +/// 1. index-typed rank, the dynamic rank of this MemRef +/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be +/// a stack-allocated (alloca) copy of a MemRef descriptor that got cast to +/// be unranked. +SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { + return {getIndexType(), + LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))}; +} + +unsigned +LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, + const DataLayout &layout) { + // Compute the descriptor size given that of its components indicated above. + // NOTE(review): the pointer field built by getUnrankedMemRefDescriptorFields + // is created without an explicit address space, yet its size is computed + // here with the memref's memory space; these differ when address spaces have + // different pointer widths — confirm which is intended. + unsigned space = type.getMemorySpaceAsInt(); + return layout.getTypeSize(getIndexType()) + + llvm::divideCeil(getPointerBitwidth(space), 8); +} + +Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { + if (!convertType(type.getElementType())) + return {}; + return LLVM::LLVMStructType::getLiteral(&getContext(), + getUnrankedMemRefDescriptorFields()); +} + +/// Convert a memref type to a bare pointer to the memref element type. +Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { + if (type.isa()) + // Unranked memref is not supported in the bare pointer calling convention. + return {}; + + // Check that the memref has static shape, strides and offset. Otherwise, it + // cannot be lowered to a bare pointer.
+ auto memrefTy = type.cast(); + if (!memrefTy.hasStaticShape()) + return {}; + + int64_t offset = 0; + SmallVector strides; + if (failed(getStridesAndOffset(memrefTy, strides, offset))) + return {}; + + for (int64_t stride : strides) + if (ShapedType::isDynamicStrideOrOffset(stride)) + return {}; + + if (ShapedType::isDynamicStrideOrOffset(offset)) + return {}; + + Type elementType = convertType(type.getElementType()); + if (!elementType) + return {}; + return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); +} + +/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type +/// when n > 1. For example, `vector<4 x f32>` remains as is, while +/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`. +Type LLVMTypeConverter::convertVectorType(VectorType type) { + auto elementType = convertType(type.getElementType()); + if (!elementType) + return {}; + Type vectorType = VectorType::get(type.getShape().back(), elementType); + assert(LLVM::isCompatibleVectorType(vectorType) && + "expected vector type compatible with the LLVM dialect"); + auto shape = type.getShape(); + for (int i = shape.size() - 2; i >= 0; --i) + vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); + return vectorType; +} + +/// Convert a type in the context of the default or bare pointer calling +/// convention. Calling convention sensitive types, such as MemRefType and +/// UnrankedMemRefType, are converted following the specific rules for the +/// calling convention. Calling convention independent types are converted +/// following the default LLVM type conversions. +Type LLVMTypeConverter::convertCallingConventionType(Type type) { + if (options.useBarePtrCallConv) + if (auto memrefTy = type.dyn_cast()) + return convertMemRefToBarePtr(memrefTy); + + return convertType(type); +} + +/// Promote the bare pointers in 'values' that resulted from memrefs to +/// descriptors.
'stdTypes' holds the types of 'values' before the conversion +/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). +void LLVMTypeConverter::promoteBarePtrsToDescriptors( + ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, + SmallVectorImpl &values) { + assert(stdTypes.size() == values.size() && + "The number of types and values doesn't match"); + for (unsigned i = 0, end = values.size(); i < end; ++i) + if (auto memrefTy = stdTypes[i].dyn_cast()) + values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, + memrefTy, values[i]); +} + +/// Convert a non-empty list of types to be returned from a function into a +/// supported LLVM IR type. In particular, if more than one value is returned, +/// create an LLVM IR structure type with elements that correspond to each of +/// the MLIR types converted with `convertCallingConventionType`. +Type LLVMTypeConverter::packFunctionResults(TypeRange types) { + assert(!types.empty() && "expected non-empty list of type"); + + if (types.size() == 1) + return convertCallingConventionType(types.front()); + + SmallVector resultTypes; + resultTypes.reserve(types.size()); + for (auto t : types) { + auto converted = convertCallingConventionType(t); + if (!converted || !LLVM::isCompatibleType(converted)) + return {}; + resultTypes.push_back(converted); + } + + return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); +} + +Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, + OpBuilder &builder) { + auto *context = builder.getContext(); + auto int64Ty = IntegerType::get(builder.getContext(), 64); + auto indexType = IndexType::get(context); + // Alloca with proper alignment. We do not expect optimizations of this + // alloca op and so we omit allocating at the entry block.
+ auto ptrType = LLVM::LLVMPointerType::get(operand.getType()); + // NOTE(review): the constant's result type is i64 but its attribute is + // index-typed; confirm this mismatch is intended. + Value one = builder.create(loc, int64Ty, + IntegerAttr::get(indexType, 1)); + Value allocated = + builder.create(loc, ptrType, one, /*alignment=*/0); + // Store into the alloca'ed descriptor. + builder.create(loc, operand, allocated); + return allocated; +} + +SmallVector LLVMTypeConverter::promoteOperands(Location loc, + ValueRange opOperands, + ValueRange operands, + OpBuilder &builder) { + SmallVector promotedOperands; + promotedOperands.reserve(operands.size()); + for (auto it : llvm::zip(opOperands, operands)) { + auto operand = std::get<0>(it); + auto llvmOperand = std::get<1>(it); + + if (options.useBarePtrCallConv) { + // For the bare-ptr calling convention, we only have to extract the + // aligned pointer of a memref. + if (auto memrefType = operand.getType().dyn_cast()) { + MemRefDescriptor desc(llvmOperand); + llvmOperand = desc.alignedPtr(builder, loc); + } else if (operand.getType().isa()) { + // NOTE(review): this appears reachable from input IR that uses unranked + // memrefs with the bare-pointer convention; llvm_unreachable is UB in + // release builds, so a recoverable failure may be preferable. + llvm_unreachable("Unranked memrefs are not supported"); + } + } else { + if (operand.getType().isa()) { + UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, + promotedOperands); + continue; + } + if (auto memrefType = operand.getType().dyn_cast()) { + MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, + promotedOperands); + continue; + } + } + + promotedOperands.push_back(llvmOperand); + } + return promotedOperands; +} + +/// Callback to convert function argument types. It converts a MemRef function +/// argument to a list of non-aggregate types containing descriptor +/// information, and an UnrankedMemRef function argument to a list containing +/// the rank and a pointer to a descriptor struct. +LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, + Type type, + SmallVectorImpl &result) { + if (auto memref = type.dyn_cast()) { + // In signatures, Memref descriptors are expanded into lists of + // non-aggregate values.
+ auto converted = + converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true); + if (converted.empty()) + return failure(); + result.append(converted.begin(), converted.end()); + return success(); + } + if (type.isa()) { + auto converted = converter.getUnrankedMemRefDescriptorFields(); + if (converted.empty()) + return failure(); + result.append(converted.begin(), converted.end()); + return success(); + } + auto converted = converter.convertType(type); + if (!converted) + return failure(); + result.push_back(converted); + return success(); +} + +/// Callback to convert function argument types. It converts MemRef function +/// arguments to bare pointers to the MemRef element type; all other types go +/// through `convertCallingConventionType` as well. +LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, + Type type, + SmallVectorImpl &result) { + auto llvmTy = converter.convertCallingConventionType(type); + if (!llvmTy) + return failure(); + + result.push_back(llvmTy); + return success(); +} diff --git a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/SPIRVToLLVM/CMakeLists.txt @@ -14,6 +14,7 @@ MLIRGPUOps MLIRSPIRV MLIRSPIRVUtils + MLIRLLVMCommonConversion MLIRLLVMIR MLIRStandardToLLVM MLIRIR diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -12,6 +12,8 @@ //===----------------------------------------------------------------------===// #include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h" #include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h" #include
"mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" diff --git a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/StandardToLLVM/CMakeLists.txt @@ -14,6 +14,7 @@ LINK_LIBS PUBLIC MLIRAnalysis MLIRDataLayoutInterfaces + MLIRLLVMCommonConversion MLIRLLVMIR MLIRMath MLIRMemRef diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -46,994 +46,12 @@ #define PASS_NAME "convert-std-to-llvm" -// Extract an LLVM IR type from the LLVM IR dialect type. -static Type unwrap(Type type) { - if (!type) - return nullptr; - auto *mlirContext = type.getContext(); - if (!LLVM::isCompatibleType(type)) - emitError(UnknownLoc::get(mlirContext), - "conversion resulted in a non-LLVM type ") - << type; - return type; -} - -/// Callback to convert function argument types. It converts a MemRef function -/// argument to a list of non-aggregate types containing descriptor -/// information, and an UnrankedmemRef function argument to a list containing -/// the rank and a pointer to a descriptor struct. -LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, - Type type, - SmallVectorImpl &result) { - if (auto memref = type.dyn_cast()) { - // In signatures, Memref descriptors are expanded into lists of - // non-aggregate values. 
- auto converted = - converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true); - if (converted.empty()) - return failure(); - result.append(converted.begin(), converted.end()); - return success(); - } - if (type.isa()) { - auto converted = converter.getUnrankedMemRefDescriptorFields(); - if (converted.empty()) - return failure(); - result.append(converted.begin(), converted.end()); - return success(); - } - auto converted = converter.convertType(type); - if (!converted) - return failure(); - result.push_back(converted); - return success(); -} - -/// Callback to convert function argument types. It converts MemRef function -/// arguments to bare pointers to the MemRef element type. -LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, - Type type, - SmallVectorImpl &result) { - auto llvmTy = converter.convertCallingConventionType(type); - if (!llvmTy) - return failure(); - - result.push_back(llvmTy); - return success(); -} - -/// Create an LLVMTypeConverter using default LowerToLLVMOptions. -LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, - const DataLayoutAnalysis *analysis) - : LLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), analysis) {} - -/// Create an LLVMTypeConverter using custom LowerToLLVMOptions. -LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, - const LowerToLLVMOptions &options, - const DataLayoutAnalysis *analysis) - : llvmDialect(ctx->getOrLoadDialect()), options(options), - dataLayoutAnalysis(analysis) { - assert(llvmDialect && "LLVM IR dialect is not registered"); - - // Register conversions for the builtin types. 
- addConversion([&](ComplexType type) { return convertComplexType(type); }); - addConversion([&](FloatType type) { return convertFloatType(type); }); - addConversion([&](FunctionType type) { return convertFunctionType(type); }); - addConversion([&](IndexType type) { return convertIndexType(type); }); - addConversion([&](IntegerType type) { return convertIntegerType(type); }); - addConversion([&](MemRefType type) { return convertMemRefType(type); }); - addConversion( - [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); - addConversion([&](VectorType type) { return convertVectorType(type); }); - - // LLVM-compatible types are legal, so add a pass-through conversion. - addConversion([](Type type) { - return LLVM::isCompatibleType(type) ? llvm::Optional(type) - : llvm::None; - }); - - // Materialization for memrefs creates descriptor structs from individual - // values constituting them, when descriptors are used, i.e. more than one - // value represents a memref. - addArgumentMaterialization( - [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, - Location loc) -> Optional { - if (inputs.size() == 1) - return llvm::None; - return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, - inputs); - }); - addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, - ValueRange inputs, - Location loc) -> Optional { - if (inputs.size() == 1) - return llvm::None; - return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); - }); - // Add generic source and target materializations to handle cases where - // non-LLVM types persist after an LLVM conversion. - addSourceMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> Optional { - if (inputs.size() != 1) - return llvm::None; - // FIXME: These should check LLVM::DialectCastOp can actually be constructed - // from the input and result. 
- return builder.create(loc, resultType, inputs[0]) - .getResult(); - }); - addTargetMaterialization([&](OpBuilder &builder, Type resultType, - ValueRange inputs, - Location loc) -> Optional { - if (inputs.size() != 1) - return llvm::None; - // FIXME: These should check LLVM::DialectCastOp can actually be constructed - // from the input and result. - return builder.create(loc, resultType, inputs[0]) - .getResult(); - }); -} - -/// Returns the MLIR context. -MLIRContext &LLVMTypeConverter::getContext() { - return *getDialect()->getContext(); -} - -Type LLVMTypeConverter::getIndexType() { - return IntegerType::get(&getContext(), getIndexTypeBitwidth()); -} - -unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { - return options.dataLayout.getPointerSizeInBits(addressSpace); -} - -Type LLVMTypeConverter::convertIndexType(IndexType type) { - return getIndexType(); -} - -Type LLVMTypeConverter::convertIntegerType(IntegerType type) { - return IntegerType::get(&getContext(), type.getWidth()); -} - -Type LLVMTypeConverter::convertFloatType(FloatType type) { return type; } - -// Convert a `ComplexType` to an LLVM type. The result is a complex number -// struct with entries for the -// 1. real part and for the -// 2. imaginary part. -static constexpr unsigned kRealPosInComplexNumberStruct = 0; -static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; -Type LLVMTypeConverter::convertComplexType(ComplexType type) { - auto elementType = convertType(type.getElementType()); - return LLVM::LLVMStructType::getLiteral(&getContext(), - {elementType, elementType}); -} - -// Except for signatures, MLIR function types are converted into LLVM -// pointer-to-function types. 
-Type LLVMTypeConverter::convertFunctionType(FunctionType type) { - SignatureConversion conversion(type.getNumInputs()); - Type converted = - convertFunctionSignature(type, /*isVariadic=*/false, conversion); - return LLVM::LLVMPointerType::get(converted); -} - -// Function types are converted to LLVM Function types by recursively converting -// argument and result types. If MLIR Function has zero results, the LLVM -// Function has one VoidType result. If MLIR Function has more than one result, -// they are into an LLVM StructType in their order of appearance. -Type LLVMTypeConverter::convertFunctionSignature( - FunctionType funcTy, bool isVariadic, - LLVMTypeConverter::SignatureConversion &result) { - // Select the argument converter depending on the calling convention. - auto funcArgConverter = options.useBarePtrCallConv - ? barePtrFuncArgTypeConverter - : structFuncArgTypeConverter; - // Convert argument types one by one and check for errors. - for (auto &en : llvm::enumerate(funcTy.getInputs())) { - Type type = en.value(); - SmallVector converted; - if (failed(funcArgConverter(*this, type, converted))) - return {}; - result.addInputs(en.index(), converted); - } - - SmallVector argTypes; - argTypes.reserve(llvm::size(result.getConvertedTypes())); - for (Type type : result.getConvertedTypes()) - argTypes.push_back(unwrap(type)); - - // If function does not return anything, create the void result type, - // if it returns on element, convert it, otherwise pack the result types into - // a struct. - Type resultType = funcTy.getNumResults() == 0 - ? LLVM::LLVMVoidType::get(&getContext()) - : unwrap(packFunctionResults(funcTy.getResults())); - if (!resultType) - return {}; - return LLVM::LLVMFunctionType::get(resultType, argTypes, isVariadic); -} - -/// Converts the function type to a C-compatible format, in particular using -/// pointers to memref descriptors for arguments. 
-std::pair -LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { - SmallVector inputs; - bool resultIsNowArg = false; - - Type resultType = type.getNumResults() == 0 - ? LLVM::LLVMVoidType::get(&getContext()) - : unwrap(packFunctionResults(type.getResults())); - if (!resultType) - return {}; - - if (auto structType = resultType.dyn_cast()) { - // Struct types cannot be safely returned via C interface. Make this a - // pointer argument, instead. - inputs.push_back(LLVM::LLVMPointerType::get(structType)); - resultType = LLVM::LLVMVoidType::get(&getContext()); - resultIsNowArg = true; - } - - for (Type t : type.getInputs()) { - auto converted = convertType(t); - if (!converted || !LLVM::isCompatibleType(converted)) - return {}; - if (t.isa()) - converted = LLVM::LLVMPointerType::get(converted); - inputs.push_back(converted); - } - - return {LLVM::LLVMFunctionType::get(resultType, inputs), resultIsNowArg}; -} - -static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; -static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; -static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; -static constexpr unsigned kSizePosInMemRefDescriptor = 3; -static constexpr unsigned kStridePosInMemRefDescriptor = 4; - -/// Convert a memref type into a list of LLVM IR types that will form the -/// memref descriptor. The result contains the following types: -/// 1. The pointer to the allocated data buffer, followed by -/// 2. The pointer to the aligned data buffer, followed by -/// 3. A lowered `index`-type integer containing the distance between the -/// beginning of the buffer and the first element to be accessed through the -/// view, followed by -/// 4. An array containing as many `index`-type integers as the rank of the -/// MemRef: the array represents the size, in number of elements, of the memref -/// along the given dimension. 
For constant MemRef dimensions, the -/// corresponding size entry is a constant whose runtime value must match the -/// static value, followed by -/// 5. A second array containing as many `index`-type integers as the rank of -/// the MemRef: the second array represents the "stride" (in tensor abstraction -/// sense), i.e. the number of consecutive elements of the underlying buffer. -/// TODO: add assertions for the static cases. -/// -/// If `unpackAggregates` is set to true, the arrays described in (4) and (5) -/// are expanded into individual index-type elements. -/// -/// template -/// struct { -/// Elem *allocatedPtr; -/// Elem *alignedPtr; -/// Index offset; -/// Index sizes[Rank]; // omitted when rank == 0 -/// Index strides[Rank]; // omitted when rank == 0 -/// }; -SmallVector -LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type, - bool unpackAggregates) { - assert(isStrided(type) && - "Non-strided layout maps must have been normalized away"); - - Type elementType = unwrap(convertType(type.getElementType())); - if (!elementType) - return {}; - auto ptrTy = - LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); - auto indexTy = getIndexType(); - - SmallVector results = {ptrTy, ptrTy, indexTy}; - auto rank = type.getRank(); - if (rank == 0) - return results; - - if (unpackAggregates) - results.insert(results.end(), 2 * rank, indexTy); - else - results.insert(results.end(), 2, LLVM::LLVMArrayType::get(indexTy, rank)); - return results; -} - -unsigned LLVMTypeConverter::getMemRefDescriptorSize(MemRefType type, - const DataLayout &layout) { - // Compute the descriptor size given that of its components indicated above. - unsigned space = type.getMemorySpaceAsInt(); - return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) + - (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType()); -} - -/// Converts MemRefType to LLVMType. 
A MemRefType is converted to a struct that -/// packs the descriptor fields as defined by `getMemRefDescriptorFields`. -Type LLVMTypeConverter::convertMemRefType(MemRefType type) { - // When converting a MemRefType to a struct with descriptor fields, do not - // unpack the `sizes` and `strides` arrays. - SmallVector types = - getMemRefDescriptorFields(type, /*unpackAggregates=*/false); - if (types.empty()) - return {}; - return LLVM::LLVMStructType::getLiteral(&getContext(), types); -} - -static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; -static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; - -/// Convert an unranked memref type into a list of non-aggregate LLVM IR types -/// that will form the unranked memref descriptor. In particular, the fields -/// for an unranked memref descriptor are: -/// 1. index-typed rank, the dynamic rank of this MemRef -/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be -/// stack allocated (alloca) copy of a MemRef descriptor that got casted to -/// be unranked. -SmallVector LLVMTypeConverter::getUnrankedMemRefDescriptorFields() { - return {getIndexType(), - LLVM::LLVMPointerType::get(IntegerType::get(&getContext(), 8))}; -} - -unsigned -LLVMTypeConverter::getUnrankedMemRefDescriptorSize(UnrankedMemRefType type, - const DataLayout &layout) { - // Compute the descriptor size given that of its components indicated above. - unsigned space = type.getMemorySpaceAsInt(); - return layout.getTypeSize(getIndexType()) + - llvm::divideCeil(getPointerBitwidth(space), 8); -} - -Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { - if (!convertType(type.getElementType())) - return {}; - return LLVM::LLVMStructType::getLiteral(&getContext(), - getUnrankedMemRefDescriptorFields()); -} - -/// Convert a memref type to a bare pointer to the memref element type. 
-Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { - if (type.isa()) - // Unranked memref is not supported in the bare pointer calling convention. - return {}; - - // Check that the memref has static shape, strides and offset. Otherwise, it - // cannot be lowered to a bare pointer. - auto memrefTy = type.cast(); - if (!memrefTy.hasStaticShape()) - return {}; - - int64_t offset = 0; - SmallVector strides; - if (failed(getStridesAndOffset(memrefTy, strides, offset))) - return {}; - - for (int64_t stride : strides) - if (ShapedType::isDynamicStrideOrOffset(stride)) - return {}; - - if (ShapedType::isDynamicStrideOrOffset(offset)) - return {}; - - Type elementType = unwrap(convertType(type.getElementType())); - if (!elementType) - return {}; - return LLVM::LLVMPointerType::get(elementType, type.getMemorySpaceAsInt()); -} - -/// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type -/// when n > 1. For example, `vector<4 x f32>` remains as is while, -/// `vector<4x8x16xf32>` converts to `!llvm.array<4xarray<8 x vector<16xf32>>>`. -Type LLVMTypeConverter::convertVectorType(VectorType type) { - auto elementType = unwrap(convertType(type.getElementType())); - if (!elementType) - return {}; - Type vectorType = VectorType::get(type.getShape().back(), elementType); - assert(LLVM::isCompatibleVectorType(vectorType) && - "expected vector type compatible with the LLVM dialect"); - auto shape = type.getShape(); - for (int i = shape.size() - 2; i >= 0; --i) - vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); - return vectorType; -} - -/// Convert a type in the context of the default or bare pointer calling -/// convention. Calling convention sensitive types, such as MemRefType and -/// UnrankedMemRefType, are converted following the specific rules for the -/// calling convention. Calling convention independent types are converted -/// following the default LLVM type conversions. 
-Type LLVMTypeConverter::convertCallingConventionType(Type type) { - if (options.useBarePtrCallConv) - if (auto memrefTy = type.dyn_cast()) - return convertMemRefToBarePtr(memrefTy); - - return convertType(type); -} - -/// Promote the bare pointers in 'values' that resulted from memrefs to -/// descriptors. 'stdTypes' holds they types of 'values' before the conversion -/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). -void LLVMTypeConverter::promoteBarePtrsToDescriptors( - ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, - SmallVectorImpl &values) { - assert(stdTypes.size() == values.size() && - "The number of types and values doesn't match"); - for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = stdTypes[i].dyn_cast()) - values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, - memrefTy, values[i]); -} - ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit) : ConversionPattern(typeConverter, rootOpName, benefit, context) {} -//===----------------------------------------------------------------------===// -// StructBuilder implementation -//===----------------------------------------------------------------------===// - -StructBuilder::StructBuilder(Value v) : value(v), structType(v.getType()) { - assert(value != nullptr && "value cannot be null"); - assert(LLVM::isCompatibleType(structType) && "expected llvm type"); -} - -Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, - unsigned pos) { - Type type = structType.cast().getBody()[pos]; - return builder.create(loc, type, value, - builder.getI64ArrayAttr(pos)); -} - -void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, - Value ptr) { - value = builder.create(loc, structType, value, ptr, - builder.getI64ArrayAttr(pos)); -} - -//===----------------------------------------------------------------------===// 
-// ComplexStructBuilder implementation -//===----------------------------------------------------------------------===// - -ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, - Location loc, Type type) { - Value val = builder.create(loc, type); - return ComplexStructBuilder(val); -} - -void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, - Value real) { - setPtr(builder, loc, kRealPosInComplexNumberStruct, real); -} - -Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kRealPosInComplexNumberStruct); -} - -void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, - Value imaginary) { - setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); -} - -Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); -} - -//===----------------------------------------------------------------------===// -// MemRefDescriptor implementation -//===----------------------------------------------------------------------===// - -/// Construct a helper for the given descriptor value. -MemRefDescriptor::MemRefDescriptor(Value descriptor) - : StructBuilder(descriptor) { - assert(value != nullptr && "value cannot be null"); - indexType = value.getType() - .cast() - .getBody()[kOffsetPosInMemRefDescriptor]; -} - -/// Builds IR creating an `undef` value of the descriptor type. -MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, - Type descriptorType) { - - Value descriptor = builder.create(loc, descriptorType); - return MemRefDescriptor(descriptor); -} - -/// Builds IR creating a MemRef descriptor that represents `type` and -/// populates it with static shape and stride information extracted from the -/// type. 
-MemRefDescriptor -MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - MemRefType type, Value memory) { - assert(type.hasStaticShape() && "unexpected dynamic shape"); - - // Extract all strides and offsets and verify they are static. - int64_t offset; - SmallVector strides; - auto result = getStridesAndOffset(type, strides, offset); - (void)result; - assert(succeeded(result) && "unexpected failure in stride computation"); - assert(!MemRefType::isDynamicStrideOrOffset(offset) && - "expected static offset"); - assert(!llvm::any_of(strides, [](int64_t stride) { - return MemRefType::isDynamicStrideOrOffset(stride); - }) && "expected static strides"); - - auto convertedType = typeConverter.convertType(type); - assert(convertedType && "unexpected failure in memref type conversion"); - - auto descr = MemRefDescriptor::undef(builder, loc, convertedType); - descr.setAllocatedPtr(builder, loc, memory); - descr.setAlignedPtr(builder, loc, memory); - descr.setConstantOffset(builder, loc, offset); - - // Fill in sizes and strides - for (unsigned i = 0, e = type.getRank(); i != e; ++i) { - descr.setConstantSize(builder, loc, i, type.getDimSize(i)); - descr.setConstantStride(builder, loc, i, strides[i]); - } - return descr; -} - -/// Builds IR extracting the allocated pointer from the descriptor. -Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); -} - -/// Builds IR inserting the allocated pointer into the descriptor. -void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, - Value ptr) { - setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); -} - -/// Builds IR extracting the aligned pointer from the descriptor. 
-Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); -} - -/// Builds IR inserting the aligned pointer into the descriptor. -void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, - Value ptr) { - setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); -} - -// Creates a constant Op producing a value of `resultType` from an index-typed -// integer attribute. -static Value createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { - return builder.create( - loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); -} - -/// Builds IR extracting the offset from the descriptor. -Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { - return builder.create( - loc, indexType, value, - builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); -} - -/// Builds IR inserting the offset into the descriptor. -void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, - Value offset) { - value = builder.create( - loc, structType, value, offset, - builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); -} - -/// Builds IR inserting the offset into the descriptor. -void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, - uint64_t offset) { - setOffset(builder, loc, - createIndexAttrConstant(builder, loc, indexType, offset)); -} - -/// Builds IR extracting the pos-th size from the descriptor. 
-Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( - loc, indexType, value, - builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); -} - -Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, - int64_t rank) { - auto indexPtrTy = LLVM::LLVMPointerType::get(indexType); - auto arrayTy = LLVM::LLVMArrayType::get(indexType, rank); - auto arrayPtrTy = LLVM::LLVMPointerType::get(arrayTy); - - // Copy size values to stack-allocated memory. - auto zero = createIndexAttrConstant(builder, loc, indexType, 0); - auto one = createIndexAttrConstant(builder, loc, indexType, 1); - auto sizes = builder.create( - loc, arrayTy, value, - builder.getI64ArrayAttr({kSizePosInMemRefDescriptor})); - auto sizesPtr = - builder.create(loc, arrayPtrTy, one, /*alignment=*/0); - builder.create(loc, sizes, sizesPtr); - - // Load an return size value of interest. - auto resultPtr = builder.create(loc, indexPtrTy, sizesPtr, - ValueRange({zero, pos})); - return builder.create(loc, resultPtr); -} - -/// Builds IR inserting the pos-th size into the descriptor -void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, - Value size) { - value = builder.create( - loc, structType, value, size, - builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); -} - -void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, - unsigned pos, uint64_t size) { - setSize(builder, loc, pos, - createIndexAttrConstant(builder, loc, indexType, size)); -} - -/// Builds IR extracting the pos-th stride from the descriptor. 
-Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { - return builder.create( - loc, indexType, value, - builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); -} - -/// Builds IR inserting the pos-th stride into the descriptor -void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, - Value stride) { - value = builder.create( - loc, structType, value, stride, - builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); -} - -void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, - unsigned pos, uint64_t stride) { - setStride(builder, loc, pos, - createIndexAttrConstant(builder, loc, indexType, stride)); -} - -LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { - return value.getType() - .cast() - .getBody()[kAlignedPtrPosInMemRefDescriptor] - .cast(); -} - -/// Creates a MemRef descriptor structure from a list of individual values -/// composing that descriptor, in the following order: -/// - allocated pointer; -/// - aligned pointer; -/// - offset; -/// - sizes; -/// - shapes; -/// where is the MemRef rank as provided in `type`. -Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, MemRefType type, - ValueRange values) { - Type llvmType = converter.convertType(type); - auto d = MemRefDescriptor::undef(builder, loc, llvmType); - - d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); - d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); - d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); - - int64_t rank = type.getRank(); - for (unsigned i = 0; i < rank; ++i) { - d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); - d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); - } - - return d; -} - -/// Builds IR extracting individual elements of a MemRef descriptor structure -/// and returning them as `results` list. 
-void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, - MemRefType type, - SmallVectorImpl &results) { - int64_t rank = type.getRank(); - results.reserve(results.size() + getNumUnpackedValues(type)); - - MemRefDescriptor d(packed); - results.push_back(d.allocatedPtr(builder, loc)); - results.push_back(d.alignedPtr(builder, loc)); - results.push_back(d.offset(builder, loc)); - for (int64_t i = 0; i < rank; ++i) - results.push_back(d.size(builder, loc, i)); - for (int64_t i = 0; i < rank; ++i) - results.push_back(d.stride(builder, loc, i)); -} - -/// Returns the number of non-aggregate values that would be produced by -/// `unpack`. -unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { - // Two pointers, offset, sizes, shapes. - return 3 + 2 * type.getRank(); -} - -//===----------------------------------------------------------------------===// -// MemRefDescriptorView implementation. -//===----------------------------------------------------------------------===// - -MemRefDescriptorView::MemRefDescriptorView(ValueRange range) - : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} - -Value MemRefDescriptorView::allocatedPtr() { - return elements[kAllocatedPtrPosInMemRefDescriptor]; -} - -Value MemRefDescriptorView::alignedPtr() { - return elements[kAlignedPtrPosInMemRefDescriptor]; -} - -Value MemRefDescriptorView::offset() { - return elements[kOffsetPosInMemRefDescriptor]; -} - -Value MemRefDescriptorView::size(unsigned pos) { - return elements[kSizePosInMemRefDescriptor + pos]; -} - -Value MemRefDescriptorView::stride(unsigned pos) { - return elements[kSizePosInMemRefDescriptor + rank + pos]; -} - -//===----------------------------------------------------------------------===// -// UnrankedMemRefDescriptor implementation -//===----------------------------------------------------------------------===// - -/// Construct a helper for the given descriptor value. 
-UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) - : StructBuilder(descriptor) {} - -/// Builds IR creating an `undef` value of the descriptor type. -UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, - Location loc, - Type descriptorType) { - Value descriptor = builder.create(loc, descriptorType); - return UnrankedMemRefDescriptor(descriptor); -} -Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { - return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); -} -void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, - Value v) { - setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); -} -Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, - Location loc) { - return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); -} -void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, - Location loc, Value v) { - setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); -} - -/// Builds IR populating an unranked MemRef descriptor structure from a list -/// of individual constituent values in the following order: -/// - rank of the memref; -/// - pointer to the memref descriptor. -Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, - LLVMTypeConverter &converter, - UnrankedMemRefType type, - ValueRange values) { - Type llvmType = converter.convertType(type); - auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType); - - d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); - d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); - return d; -} - -/// Builds IR extracting individual elements that compose an unranked memref -/// descriptor and returns them as `results` list. 
-void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, - Value packed, - SmallVectorImpl &results) { - UnrankedMemRefDescriptor d(packed); - results.reserve(results.size() + 2); - results.push_back(d.rank(builder, loc)); - results.push_back(d.memRefDescPtr(builder, loc)); -} - -void UnrankedMemRefDescriptor::computeSizes( - OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - ArrayRef values, SmallVectorImpl &sizes) { - if (values.empty()) - return; - - // Cache the index type. - Type indexType = typeConverter.getIndexType(); - - // Initialize shared constants. - Value one = createIndexAttrConstant(builder, loc, indexType, 1); - Value two = createIndexAttrConstant(builder, loc, indexType, 2); - Value pointerSize = createIndexAttrConstant( - builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8)); - Value indexSize = - createIndexAttrConstant(builder, loc, indexType, - ceilDiv(typeConverter.getIndexTypeBitwidth(), 8)); - - sizes.reserve(sizes.size() + values.size()); - for (UnrankedMemRefDescriptor desc : values) { - // Emit IR computing the memory necessary to store the descriptor. This - // assumes the descriptor to be - // { type*, type*, index, index[rank], index[rank] } - // and densely packed, so the total size is - // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). - // TODO: consider including the actual size (including eventual padding due - // to data layout) into the unranked descriptor. - Value doublePointerSize = - builder.create(loc, indexType, two, pointerSize); - - // (1 + 2 * rank) * sizeof(index) - Value rank = desc.rank(builder, loc); - Value doubleRank = builder.create(loc, indexType, two, rank); - Value doubleRankIncremented = - builder.create(loc, indexType, doubleRank, one); - Value rankIndexSize = builder.create( - loc, indexType, doubleRankIncremented, indexSize); - - // Total allocation size. 
- Value allocationSize = builder.create( - loc, indexType, doublePointerSize, rankIndexSize); - sizes.push_back(allocationSize); - } -} - -Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, - Value memRefDescPtr, - Type elemPtrPtrType) { - - Value elementPtrPtr = - builder.create(loc, elemPtrPtrType, memRefDescPtr); - return builder.create(loc, elementPtrPtr); -} - -void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, - Value memRefDescPtr, - Type elemPtrPtrType, - Value allocatedPtr) { - Value elementPtrPtr = - builder.create(loc, elemPtrPtrType, memRefDescPtr); - builder.create(loc, allocatedPtr, elementPtrPtr); -} - -Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - Type elemPtrPtrType) { - Value elementPtrPtr = - builder.create(loc, elemPtrPtrType, memRefDescPtr); - - Value one = - createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); - Value alignedGep = builder.create( - loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); - return builder.create(loc, alignedGep); -} - -void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - Type elemPtrPtrType, - Value alignedPtr) { - Value elementPtrPtr = - builder.create(loc, elemPtrPtrType, memRefDescPtr); - - Value one = - createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); - Value alignedGep = builder.create( - loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); - builder.create(loc, alignedPtr, alignedGep); -} - -Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - Type elemPtrPtrType) { - Value elementPtrPtr = - builder.create(loc, elemPtrPtrType, memRefDescPtr); - - Value two = - createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); - Value offsetGep = 
builder.create( - loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); - offsetGep = builder.create( - loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); - return builder.create(loc, offsetGep); -} - -void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value memRefDescPtr, - Type elemPtrPtrType, Value offset) { - Value elementPtrPtr = - builder.create(loc, elemPtrPtrType, memRefDescPtr); - - Value two = - createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); - Value offsetGep = builder.create( - loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); - offsetGep = builder.create( - loc, LLVM::LLVMPointerType::get(typeConverter.getIndexType()), offsetGep); - builder.create(loc, offset, offsetGep); -} - -Value UnrankedMemRefDescriptor::sizeBasePtr( - OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, - Value memRefDescPtr, LLVM::LLVMPointerType elemPtrPtrType) { - Type elemPtrTy = elemPtrPtrType.getElementType(); - Type indexTy = typeConverter.getIndexType(); - Type structPtrTy = - LLVM::LLVMPointerType::get(LLVM::LLVMStructType::getLiteral( - indexTy.getContext(), {elemPtrTy, elemPtrTy, indexTy, indexTy})); - Value structPtr = - builder.create(loc, structPtrTy, memRefDescPtr); - - Type int32_type = unwrap(typeConverter.convertType(builder.getI32Type())); - Value zero = - createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); - Value three = builder.create(loc, int32_type, - builder.getI32IntegerAttr(3)); - return builder.create(loc, LLVM::LLVMPointerType::get(indexTy), - structPtr, ValueRange({zero, three})); -} - -Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, - Value sizeBasePtr, Value index) { - Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); - Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, - ValueRange({index})); - 
return builder.create(loc, sizeStoreGep); -} - -void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, - Value sizeBasePtr, Value index, - Value size) { - Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); - Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, - ValueRange({index})); - builder.create(loc, size, sizeStoreGep); -} - -Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, - LLVMTypeConverter &typeConverter, - Value sizeBasePtr, Value rank) { - Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); - return builder.create(loc, indexPtrTy, sizeBasePtr, - ValueRange({rank})); -} - -Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, - Value strideBasePtr, Value index, - Value stride) { - Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); - Value strideStoreGep = builder.create( - loc, indexPtrTy, strideBasePtr, ValueRange({index})); - return builder.create(loc, strideStoreGep); -} - -void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, - LLVMTypeConverter typeConverter, - Value strideBasePtr, Value index, - Value stride) { - Type indexPtrTy = LLVM::LLVMPointerType::get(typeConverter.getIndexType()); - Value strideStoreGep = builder.create( - loc, indexPtrTy, strideBasePtr, ValueRange({index})); - builder.create(loc, stride, strideStoreGep); -} LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { return static_cast( @@ -1062,6 +80,14 @@ IntegerType::get(&getTypeConverter()->getContext(), 8)); } +// Creates a constant Op producing a value of `resultType` from an index-typed +// integer attribute. 
+static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + Value ConvertToLLVMPattern::createIndexConstant( ConversionPatternRewriter &builder, Location loc, uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); @@ -1116,7 +142,7 @@ Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); - auto structElementType = unwrap(typeConverter->convertType(elementType)); + auto structElementType = typeConverter->convertType(elementType); return LLVM::LLVMPointerType::get(structElementType, type.getMemorySpaceAsInt()); } @@ -2276,7 +1302,7 @@ // this to a 1D array. However, for memref.global's with an initial value, // we do not intend to flatten the ElementsAttribute when going from std -> // LLVM dialect, so the LLVM type needs to me a multi-dimension array. - Type elementType = unwrap(typeConverter.convertType(type.getElementType())); + Type elementType = typeConverter.convertType(type.getElementType()); Type arrayTy = elementType; // Shape has the outermost dim at index 0, so need to walk it backwards for (int64_t dim : llvm::reverse(type.getShape())) @@ -2342,8 +1368,7 @@ // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. 
- Type elementType = - unwrap(typeConverter->convertType(type.getElementType())); + Type elementType = typeConverter->convertType(type.getElementType()); Type elementPtrType = LLVM::LLVMPointerType::get(elementType, memSpace); SmallVector operands = {addressOf}; @@ -2703,7 +1728,7 @@ unsigned memorySpace = operandType.cast().getMemorySpaceAsInt(); Type elementType = operandType.cast().getElementType(); - Type llvmElementType = unwrap(typeConverter.convertType(elementType)); + Type llvmElementType = typeConverter.convertType(elementType); Type elementPtrPtrType = LLVM::LLVMPointerType::get( LLVM::LLVMPointerType::get(llvmElementType, memorySpace)); @@ -2836,7 +1861,7 @@ // Create the unranked memref descriptor that holds the ranked one. The // inner descriptor is allocated on stack. auto targetDesc = UnrankedMemRefDescriptor::undef( - rewriter, loc, unwrap(typeConverter->convertType(targetType))); + rewriter, loc, typeConverter->convertType(targetType)); targetDesc.setRank(rewriter, loc, resultRank); SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), @@ -2852,7 +1877,7 @@ &allocatedPtr, &alignedPtr, &offset); // Set pointers and offset. - Type llvmElementType = unwrap(typeConverter->convertType(elementType)); + Type llvmElementType = typeConverter->convertType(elementType); auto elementPtrPtrType = LLVM::LLVMPointerType::get( LLVM::LLVMPointerType::get(llvmElementType, addressSpace)); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, @@ -4102,82 +3127,6 @@ populateStdToLLVMMemoryConversionPatterns(converter, patterns); } -/// Convert a non-empty list of types to be returned from a function into a -/// supported LLVM IR type. In particular, if more than one value is returned, -/// create an LLVM IR structure type with elements that correspond to each of -/// the MLIR types converted with `convertType`. 
-Type LLVMTypeConverter::packFunctionResults(TypeRange types) { - assert(!types.empty() && "expected non-empty list of type"); - - if (types.size() == 1) - return convertCallingConventionType(types.front()); - - SmallVector resultTypes; - resultTypes.reserve(types.size()); - for (auto t : types) { - auto converted = convertCallingConventionType(t); - if (!converted || !LLVM::isCompatibleType(converted)) - return {}; - resultTypes.push_back(converted); - } - - return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); -} - -Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, - OpBuilder &builder) { - auto *context = builder.getContext(); - auto int64Ty = IntegerType::get(builder.getContext(), 64); - auto indexType = IndexType::get(context); - // Alloca with proper alignment. We do not expect optimizations of this - // alloca op and so we omit allocating at the entry block. - auto ptrType = LLVM::LLVMPointerType::get(operand.getType()); - Value one = builder.create(loc, int64Ty, - IntegerAttr::get(indexType, 1)); - Value allocated = - builder.create(loc, ptrType, one, /*alignment=*/0); - // Store into the alloca'ed descriptor. - builder.create(loc, operand, allocated); - return allocated; -} - -SmallVector LLVMTypeConverter::promoteOperands(Location loc, - ValueRange opOperands, - ValueRange operands, - OpBuilder &builder) { - SmallVector promotedOperands; - promotedOperands.reserve(operands.size()); - for (auto it : llvm::zip(opOperands, operands)) { - auto operand = std::get<0>(it); - auto llvmOperand = std::get<1>(it); - - if (options.useBarePtrCallConv) { - // For the bare-ptr calling convention, we only have to extract the - // aligned pointer of a memref. 
- if (auto memrefType = operand.getType().dyn_cast()) { - MemRefDescriptor desc(llvmOperand); - llvmOperand = desc.alignedPtr(builder, loc); - } else if (operand.getType().isa()) { - llvm_unreachable("Unranked memrefs are not supported"); - } - } else { - if (operand.getType().isa()) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, - promotedOperands); - continue; - } - if (auto memrefType = operand.getType().dyn_cast()) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, - promotedOperands); - continue; - } - } - - promotedOperands.push_back(llvmOperand); - } - return promotedOperands; -} - namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ConvertStandardToLLVMBase { @@ -4304,10 +3253,3 @@ options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout); } -mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx) - : LowerToLLVMOptions(ctx, DataLayout()) {} - -mlir::LowerToLLVMOptions::LowerToLLVMOptions(MLIRContext *ctx, - const DataLayout &dl) { - indexBitwidth = dl.getTypeSizeInBits(IndexType::get(ctx)); -} diff --git a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt --- a/mlir/tools/mlir-vulkan-runner/CMakeLists.txt +++ b/mlir/tools/mlir-vulkan-runner/CMakeLists.txt @@ -61,6 +61,7 @@ MLIRIR MLIRJitRunner MLIRLLVMIR + MLIRLLVMCommonConversion MLIRLLVMToLLVMIRTranslation MLIRMemRef MLIRParser diff --git a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp --- a/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp +++ b/mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp @@ -14,6 +14,7 @@ #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h" #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h" +#include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include 
"mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h" #include "mlir/Dialect/GPU/GPUDialect.h"