diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h index 36734f809175..c52de63224b1 100644 --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -1,578 +1,637 @@ //===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Provides a dialect conversion targeting the LLVM IR dialect. By default, it // converts Standard ops and types and provides hooks for dialect-specific // extensions to the conversion. // //===----------------------------------------------------------------------===// #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Transforms/DialectConversion.h" namespace llvm { class IntegerType; class LLVMContext; class Module; class Type; } // namespace llvm namespace mlir { class BaseMemRefType; class ComplexType; class LLVMTypeConverter; class UnrankedMemRefType; namespace LLVM { class LLVMDialect; class LLVMType; class LLVMPointerType; } // namespace LLVM /// Callback to convert function argument types. It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing /// the rank and a pointer to a descriptor struct. LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result); /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result); /// Conversion from types in the Standard dialect to the LLVM IR dialect. class LLVMTypeConverter : public TypeConverter { /// Give structFuncArgTypeConverter access to memref-specific functions. friend LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result); public: using TypeConverter::convertType; /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. LLVMTypeConverter(MLIRContext *ctx); /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic, SignatureConversion &result); /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one value is /// returned, create an LLVM IR structure type with elements that correspond /// to each of the MLIR types converted with `convertType`. Type packFunctionResults(ArrayRef types); /// Convert a type in the context of the default or bare pointer calling /// convention. 
Calling convention sensitive types, such as MemRefType and /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. Type convertCallingConventionType(Type type); /// Promote the bare pointers in 'values' that resulted from memrefs to /// descriptors. 'stdTypes' holds the types of 'values' before the conversion /// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type). void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, SmallVectorImpl &values); /// Returns the MLIR context. MLIRContext &getContext(); /// Returns the LLVM dialect. LLVM::LLVMDialect *getDialect() { return llvmDialect; } const LowerToLLVMOptions &getOptions() const { return options; } /// Promote the LLVM representation of all operands including promoting MemRef /// descriptors to stack and use pointers to struct to avoid the complexity /// of the platform-specific C/C++ ABI lowering related to struct argument /// passing. SmallVector promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder); /// Promote the LLVM struct representation of one MemRef descriptor to stack /// and use pointer to struct to avoid the complexity of the platform-specific /// C/C++ ABI lowering related to struct argument passing. Value promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder); /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type); /// Returns the data layout to use during and after conversion. const llvm::DataLayout &getDataLayout() { return options.dataLayout; } /// Gets the LLVM representation of the index type. The returned type is an /// integer type with the size configured for this type converter. LLVM::LLVMType getIndexType(); /// Gets the bitwidth of the index type when converted to LLVM. unsigned getIndexTypeBitwidth() { return options.indexBitwidth; } /// Gets the pointer bitwidth. unsigned getPointerBitwidth(unsigned addressSpace = 0); protected: /// Pointer to the LLVM dialect. LLVM::LLVMDialect *llvmDialect; private: /// Convert a function type. The arguments and results are converted one by /// one. Additionally, if the function returns more than one value, pack the /// results into an LLVM IR structure type so that the converted function type /// returns at most one result. Type convertFunctionType(FunctionType type); /// Convert the index type. Uses llvmModule data layout to create an integer /// of the pointer bitwidth. Type convertIndexType(IndexType type); /// Convert an integer type `i*` to `!llvm<"i*">`. Type convertIntegerType(IntegerType type); /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to /// `!llvm.float` and `f64` to `!llvm.double`. `bf16` is not supported /// by LLVM. Type convertFloatType(FloatType type); /// Convert complex number type: `complex` to `!llvm<"{ half, half }">`, /// `complex` to `!llvm<"{ float, float }">`, and `complex` to /// `!llvm<"{ double, double }">`. `complex` is not supported. Type convertComplexType(ComplexType type); /// Convert a memref type into an LLVM type that captures the relevant data. Type convertMemRefType(MemRefType type); /// Convert a memref type into a list of non-aggregate LLVM IR types that /// contain all the relevant data. 
In particular, the list will contain: /// - two pointers to the memref element type, followed by /// - an integer offset, followed by /// - one integer size per dimension of the memref, followed by /// - one integer stride per dimension of the memref. /// For example, memref is converted to the following list: /// - `!llvm<"float*">` (allocated pointer), /// - `!llvm<"float*">` (aligned pointer), /// - `!llvm.i64` (offset), /// - `!llvm.i64`, `!llvm.i64` (sizes), /// - `!llvm.i64`, `!llvm.i64` (strides). /// These types can be recomposed to a memref descriptor struct. SmallVector convertMemRefSignature(MemRefType type); /// Convert an unranked memref type into a list of non-aggregate LLVM IR types /// that contain all the relevant data. In particular, this list contains: /// - an integer rank, followed by /// - a pointer to the memref descriptor struct. /// For example, memref<*xf32> is converted to the following list: /// !llvm.i64 (rank) /// !llvm<"i8*"> (type-erased pointer). /// These types can be recomposed to a unranked memref descriptor struct. SmallVector convertUnrankedMemRefSignature(); // Convert an unranked memref type to an LLVM type that captures the // runtime rank and a pointer to the static ranked memref desc Type convertUnrankedMemRefType(UnrankedMemRefType type); /// Convert a memref type to a bare pointer to the memref element type. Type convertMemRefToBarePtr(BaseMemRefType type); // Convert a 1D vector type into an LLVM vector type. Type convertVectorType(VectorType type); /// Options for customizing the llvm lowering. LowerToLLVMOptions options; }; /// Helper class to produce LLVM dialect operations extracting or inserting /// values to a struct. class StructBuilder { public: /// Construct a helper for the given value. explicit StructBuilder(Value v); /// Builds IR creating an `undef` value of the descriptor type. static StructBuilder undef(OpBuilder &builder, Location loc, Type descriptorType); /*implicit*/ operator Value() { return value; } protected: // LLVM value Value value; // Cached struct type. Type structType; protected: /// Builds IR to extract a value from the struct at position pos Value extractPtr(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR to set a value in the struct at position pos void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr); }; class ComplexStructBuilder : public StructBuilder { public: /// Construct a helper for the given complex number value. using StructBuilder::StructBuilder; /// Build IR creating an `undef` value of the complex number type. static ComplexStructBuilder undef(OpBuilder &builder, Location loc, Type type); // Build IR extracting the real value from the complex number struct. Value real(OpBuilder &builder, Location loc); // Build IR inserting the real value into the complex number struct. void setReal(OpBuilder &builder, Location loc, Value real); // Build IR extracting the imaginary value from the complex number struct. Value imaginary(OpBuilder &builder, Location loc); // Build IR inserting the imaginary value into the complex number struct. void setImaginary(OpBuilder &builder, Location loc, Value imaginary); }; /// Helper class to produce LLVM dialect operations extracting or inserting /// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor. /// The Value may be null, in which case none of the operations are valid. class MemRefDescriptor : public StructBuilder { public: /// Construct a helper for the given descriptor value. 
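// A minimal illustrative sketch (not from the patch): the flattened argument
// list produced for memref signatures under the default (descriptor) calling
// convention, assuming a 64-bit index type. The helper names below are
// hypothetical; the real conversion is done by convertMemRefSignature and
// convertUnrankedMemRefSignature declared above.
#include <string>
#include <vector>

// Ranked memref: two element pointers, an offset, then one size and one
// stride per dimension, in that order.
std::vector<std::string> expandRankedMemRefSignature(const std::string &elemPtrTy,
                                                     unsigned rank,
                                                     const std::string &indexTy = "i64") {
  std::vector<std::string> types;
  types.push_back(elemPtrTy); // allocated pointer
  types.push_back(elemPtrTy); // aligned pointer
  types.push_back(indexTy);   // offset
  for (unsigned i = 0; i < rank; ++i)
    types.push_back(indexTy); // sizes
  for (unsigned i = 0; i < rank; ++i)
    types.push_back(indexTy); // strides
  return types;
}

// Unranked memref: the dynamic rank and a type-erased pointer to a
// stack-allocated ranked descriptor.
std::vector<std::string> expandUnrankedMemRefSignature(const std::string &indexTy = "i64") {
  return {indexTy /*rank*/, "i8*" /*descriptor pointer*/};
}

// e.g. expandRankedMemRefSignature("float*", 2) yields
//   {"float*", "float*", "i64", "i64", "i64", "i64", "i64"},
// matching the memref<?x?xf32> example in the comment above.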
explicit MemRefDescriptor(Value descriptor); /// Builds IR creating an `undef` value of the descriptor type. static MemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType); /// Builds IR creating a MemRef descriptor that represents `type` and /// populates it with static shape and stride information extracted from the /// type. static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory); /// Builds IR extracting the allocated pointer from the descriptor. Value allocatedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the allocated pointer into the descriptor. void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr); /// Builds IR extracting the aligned pointer from the descriptor. Value alignedPtr(OpBuilder &builder, Location loc); /// Builds IR inserting the aligned pointer into the descriptor. void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr); /// Builds IR extracting the offset from the descriptor. Value offset(OpBuilder &builder, Location loc); /// Builds IR inserting the offset into the descriptor. void setOffset(OpBuilder &builder, Location loc, Value offset); void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset); /// Builds IR extracting the pos-th size from the descriptor. Value size(OpBuilder &builder, Location loc, unsigned pos); Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank); /// Builds IR inserting the pos-th size into the descriptor void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size); void setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size); /// Builds IR extracting the pos-th size from the descriptor. Value stride(OpBuilder &builder, Location loc, unsigned pos); /// Builds IR inserting the pos-th stride into the descriptor void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride); void setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride); /// Returns the (LLVM) pointer type this descriptor contains. LLVM::LLVMPointerType getElementPtrType(); /// Builds IR populating a MemRef descriptor structure from a list of /// individual values composing that descriptor, in the following order: /// - allocated pointer; /// - aligned pointer; /// - offset; /// - sizes; /// - shapes; /// where is the MemRef rank as provided in `type`. static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, MemRefType type, ValueRange values); /// Builds IR extracting individual elements of a MemRef descriptor structure /// and returning them as `results` list. static void unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl &results); /// Returns the number of non-aggregate values that would be produced by /// `unpack`. static unsigned getNumUnpackedValues(MemRefType type); private: // Cached index type. Type indexType; }; /// Helper class allowing the user to access a range of Values that correspond /// to an unpacked memref descriptor using named accessors. This does not own /// the values. class MemRefDescriptorView { public: /// Constructs the view from a range of values. Infers the rank from the size /// of the range. explicit MemRefDescriptorView(ValueRange range); /// Returns the allocated pointer Value. Value allocatedPtr(); /// Returns the aligned pointer Value. Value alignedPtr(); /// Returns the offset Value. Value offset(); /// Returns the pos-th size Value. 
Value size(unsigned pos); /// Returns the pos-th stride Value. Value stride(unsigned pos); private: /// Rank of the memref the descriptor is pointing to. int rank; /// Underlying range of Values. ValueRange elements; }; class UnrankedMemRefDescriptor : public StructBuilder { public: /// Construct a helper for the given descriptor value. explicit UnrankedMemRefDescriptor(Value descriptor); /// Builds IR creating an `undef` value of the descriptor type. static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc, Type descriptorType); /// Builds IR extracting the rank from the descriptor Value rank(OpBuilder &builder, Location loc); /// Builds IR setting the rank in the descriptor void setRank(OpBuilder &builder, Location loc, Value value); /// Builds IR extracting ranked memref descriptor ptr Value memRefDescPtr(OpBuilder &builder, Location loc); /// Builds IR setting ranked memref descriptor ptr void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value); /// Builds IR populating an unranked MemRef descriptor structure from a list /// of individual constituent values in the following order: /// - rank of the memref; /// - pointer to the memref descriptor. static Value pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values); /// Builds IR extracting individual elements that compose an unranked memref /// descriptor and returns them as `results` list. static void unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl &results); /// Returns the number of non-aggregate values that would be produced by /// `unpack`. static unsigned getNumUnpackedValues() { return 2; } /// Builds IR computing the sizes in bytes (suitable for opaque allocation) /// and appends the corresponding values into `sizes`. static void computeSizes(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef values, SmallVectorImpl &sizes); + + /// TODO: The following accessors don't take alignment rules between elements + /// of the descriptor struct into account. For some architectures, it might be + /// necessary to extend them and to use `llvm::DataLayout` contained in + /// `LLVMTypeConverter`. + + /// Builds IR extracting the allocated pointer from the descriptor. + static Value allocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); + /// Builds IR inserting the allocated pointer into the descriptor. + static void setAllocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value allocatedPtr); + + /// Builds IR extracting the aligned pointer from the descriptor. + static Value alignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType); + /// Builds IR inserting the aligned pointer into the descriptor. + static void setAlignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, + Value alignedPtr); + + /// Builds IR extracting the offset from the descriptor. + static Value offset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType); + /// Builds IR inserting the offset into the descriptor. 
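// A minimal illustrative sketch (not from the patch) of the memory that the
// static accessors declared below address through the unranked descriptor's
// i8* field. The field offsets assume the ranked descriptor is densely packed
// and that pointers and the index type have the same size -- exactly the
// caveat raised in the TODO above. All names here are hypothetical.
#include <cstddef>
#include <cstdint>

using Index = int64_t; // stands in for the converted `index` type

// View of the ranked descriptor behind the unranked descriptor's pointer. The
// rank is only known at runtime, so sizes and strides are addressed with
// pointer arithmetic rather than declared arrays.
struct RankedDescriptorView {
  void *memory; // what UnrankedMemRefDescriptor::memRefDescPtr returns

  void **allocatedPtrAddr() { return reinterpret_cast<void **>(memory); }     // field 0
  void **alignedPtrAddr() { return reinterpret_cast<void **>(memory) + 1; }   // field 1
  Index *offsetAddr() {
    return reinterpret_cast<Index *>(reinterpret_cast<void **>(memory) + 2);  // field 2
  }
  // sizeBasePtr: first element of the size array (field 3).
  Index *sizeBase() { return offsetAddr() + 1; }
  // strideBasePtr: the stride array starts right after `rank` sizes.
  Index *strideBase(Index rank) { return sizeBase() + rank; }
};

// Allocation size used by computeSizes, assuming dense packing:
//   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
inline size_t rankedDescriptorBytes(Index rank) {
  return 2 * sizeof(void *) + (1 + 2 * rank) * sizeof(Index);
}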
+ static void setOffset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, Value offset); + + /// Builds IR extracting the pointer to the first element of the size array. + static Value sizeBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType); + /// Builds IR extracting the size[index] from the descriptor. + static Value size(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value sizeBasePtr, + Value index); + /// Builds IR inserting the size[index] into the descriptor. + static void setSize(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value sizeBasePtr, + Value index, Value size); + + /// Builds IR extracting the pointer to the first element of the stride array. + static Value strideBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank); + /// Builds IR extracting the stride[index] from the descriptor. + static Value stride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value strideBasePtr, + Value index, Value stride); + /// Builds IR inserting the stride[index] into the descriptor. + static void setStride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, Value strideBasePtr, + Value index, Value stride); }; /// Base class for operation conversions targeting the LLVM IR dialect. It /// provides the conversion patterns with access to the LLVMTypeConverter and /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the /// LowerToLLVMOptions by reference meaning the references have to remain alive /// during the entire pattern lifetime. class ConvertToLLVMPattern : public ConversionPattern { public: ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1); protected: /// Returns the LLVM dialect. LLVM::LLVMDialect &getDialect() const; /// Gets the MLIR type wrapping the LLVM integer type whose bit width is /// defined by the used type converter. LLVM::LLVMType getIndexType() const; /// Gets the MLIR type wrapping the LLVM integer type whose bit width /// corresponds to that of a LLVM pointer type. LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const; /// Gets the MLIR type wrapping the LLVM void type. LLVM::LLVMType getVoidType() const; /// Get the MLIR type wrapping the LLVM i8* type. LLVM::LLVMType getVoidPtrType() const; /// Create an LLVM dialect operation defining the given index constant. Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const; // Given subscript indices and array sizes in row-major order, // i_n, i_{n-1}, ..., i_1 // s_n, s_{n-1}, ..., s_1 // obtain a value that corresponds to the linearized subscript // \sum_k i_k * \prod_{j=1}^{k-1} s_j // by accumulating the running linearized value. // Note that `indices` and `allocSizes` are passed in the same order as they // appear in load/store operations and memref type declarations. Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, ArrayRef indices, ArrayRef allocSizes) const; // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. 
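// A minimal illustrative sketch (not from the patch) of the address arithmetic
// the two helpers documented above emit as LLVM dialect ops; `indices`,
// `strides` and `allocSizes` are stand-ins for the descriptor contents.
#include <cstdint>
#include <vector>

// getStridedElementPtr: the element address is the aligned base pointer
// advanced by (offset + sum_i indices[i] * strides[i]) elements.
inline int64_t stridedLinearIndex(int64_t offset,
                                  const std::vector<int64_t> &indices,
                                  const std::vector<int64_t> &strides) {
  int64_t linear = offset;
  for (size_t i = 0; i < indices.size(); ++i)
    linear += indices[i] * strides[i];
  return linear;
}

// linearizeSubscripts: row-major linearization of subscripts against
// allocation sizes, accumulating sum_k i_k * prod_{j<k} s_j as a running
// value (multiply the accumulator by the next size, then add the next index).
inline int64_t linearizeSubscriptsSketch(const std::vector<int64_t> &indices,
                                         const std::vector<int64_t> &allocSizes) {
  int64_t linear = indices.front(); // a 0-dimensional access is not allowed
  for (size_t i = 1; i < indices.size(); ++i)
    linear = linear * allocSizes[i] + indices[i];
  return linear;
}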
Value getStridedElementPtr(Location loc, Type elementTypePtr, Value descriptor, ValueRange indices, ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const; Value getDataPtr(Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const; /// Returns the type of a pointer to an element of the memref. Type getElementPtrType(MemRefType type) const; /// Determines sizes to be used in the memref descriptor. void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, ArrayRef dynSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes) const; /// Computes the size of type in bytes. Value getSizeInBytes(Location loc, Type type, ConversionPatternRewriter &rewriter) const; /// Computes total number of elements for the given shape. Value getNumElements(Location loc, ArrayRef shape, ConversionPatternRewriter &rewriter) const; /// Computes total size in bytes of to store the given shape. Value getCumulativeSizeInBytes(Location loc, Type elementType, ArrayRef shape, ConversionPatternRewriter &rewriter) const; /// Creates and populates the memref descriptor struct given all its fields. /// 'strides' can be either dynamic (kDynamicStrideOrOffset) or static, but /// not a mix of the two. MemRefDescriptor createMemRefDescriptor(Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, uint64_t offset, ArrayRef strides, ArrayRef sizes, ConversionPatternRewriter &rewriter) const; protected: /// Reference to the type converter, with potential extensions. LLVMTypeConverter &typeConverter; }; /// Utility class for operation conversions targeting the LLVM dialect that /// match exactly one source operation. template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) : ConvertToLLVMPattern(OpTy::getOperationName(), &typeConverter.getContext(), typeConverter, benefit) {} }; namespace LLVM { namespace detail { /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter); } // namespace detail } // namespace LLVM /// Generic implementation of one-to-one conversion from "SourceOp" to /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. /// Upholds a convention that multi-result operations get converted into an /// operation returning the LLVM IR structure type, in which case individual /// values must be extracted from using LLVM::ExtractValueOp before being used. template class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = OneToOneConvertToLLVMPattern; /// Converts the type of the result to an LLVM type, pass operands as is, /// preserve attributes. LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), operands, this->typeConverter, rewriter); } }; /// Basic lowering implementation to rewrite Ops with just one result to the /// LLVM Dialect. This supports higher-dimensional vector types. 
template class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = VectorConvertToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { static_assert( std::is_base_of, SourceOp>::value, "expected single result op"); static_assert(std::is_base_of, SourceOp>::value, "expected same operands and result type"); return LLVM::detail::vectorOneToOneRewrite(op, TargetOp::getOperationName(), operands, this->typeConverter, rewriter); } }; /// Derived class that automatically populates legalization information for /// different LLVM ops. class LLVMConversionTarget : public ConversionTarget { public: explicit LLVMConversionTarget(MLIRContext &ctx); }; } // namespace mlir #endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 152dbd1e990e..b6dc2cad4d37 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1,3813 +1,4116 @@ //===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements a pass to convert MLIR standard and builtin dialects // into the LLVM IR dialect. // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Transforms/Passes.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/Utils.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Type.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" #include using namespace mlir; #define PASS_NAME "convert-std-to-llvm" // Extract an LLVM IR type from the LLVM IR dialect type. static LLVM::LLVMType unwrap(Type type) { if (!type) return nullptr; auto *mlirContext = type.getContext(); auto wrappedLLVMType = type.dyn_cast(); if (!wrappedLLVMType) emitError(UnknownLoc::get(mlirContext), "conversion resulted in a non-LLVM type"); return wrappedLLVMType; } /// Callback to convert function argument types. It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing /// the rank and a pointer to a descriptor struct. 
LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { if (auto memref = type.dyn_cast()) { auto converted = converter.convertMemRefSignature(memref); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } if (type.isa()) { auto converted = converter.convertUnrankedMemRefSignature(); if (converted.empty()) return failure(); result.append(converted.begin(), converted.end()); return success(); } auto converted = converter.convertType(type); if (!converted) return failure(); result.push_back(converted); return success(); } /// Callback to convert function argument types. It converts MemRef function /// arguments to bare pointers to the MemRef element type. LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter, Type type, SmallVectorImpl &result) { auto llvmTy = converter.convertCallingConventionType(type); if (!llvmTy) return failure(); result.push_back(llvmTy); return success(); } /// Create an LLVMTypeConverter using default LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {} /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options) : llvmDialect(ctx->getOrLoadDialect()), options(options) { assert(llvmDialect && "LLVM IR dialect is not registered"); if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits(); // Register conversions for the standard types. addConversion([&](ComplexType type) { return convertComplexType(type); }); addConversion([&](FloatType type) { return convertFloatType(type); }); addConversion([&](FunctionType type) { return convertFunctionType(type); }); addConversion([&](IndexType type) { return convertIndexType(type); }); addConversion([&](IntegerType type) { return convertIntegerType(type); }); addConversion([&](MemRefType type) { return convertMemRefType(type); }); addConversion( [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); addConversion([&](VectorType type) { return convertVectorType(type); }); // LLVMType is legal, so add a pass-through conversion. addConversion([](LLVM::LLVMType type) { return type; }); // Materialization for memrefs creates descriptor structs from individual // values constituting them, when descriptors are used, i.e. more than one // value represents a memref. addArgumentMaterialization( [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs); }); addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() == 1) return llvm::None; return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs); }); // Add generic source and target materializations to handle cases where // non-LLVM types persist after an LLVM conversion. addSourceMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() != 1) return llvm::None; // FIXME: These should check LLVM::DialectCastOp can actually be constructed // from the input and result. 
return builder.create(loc, resultType, inputs[0]) .getResult(); }); addTargetMaterialization([&](OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) -> Optional { if (inputs.size() != 1) return llvm::None; // FIXME: These should check LLVM::DialectCastOp can actually be constructed // from the input and result. return builder.create(loc, resultType, inputs[0]) .getResult(); }); } /// Returns the MLIR context. MLIRContext &LLVMTypeConverter::getContext() { return *getDialect()->getContext(); } LLVM::LLVMType LLVMTypeConverter::getIndexType() { return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth()); } unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) { return options.dataLayout.getPointerSizeInBits(addressSpace); } Type LLVMTypeConverter::convertIndexType(IndexType type) { return getIndexType(); } Type LLVMTypeConverter::convertIntegerType(IntegerType type) { return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth()); } Type LLVMTypeConverter::convertFloatType(FloatType type) { if (type.isa()) return LLVM::LLVMType::getFloatTy(&getContext()); if (type.isa()) return LLVM::LLVMType::getDoubleTy(&getContext()); if (type.isa()) return LLVM::LLVMType::getHalfTy(&getContext()); if (type.isa()) return LLVM::LLVMType::getBFloatTy(&getContext()); llvm_unreachable("non-float type in convertFloatType"); } // Convert a `ComplexType` to an LLVM type. The result is a complex number // struct with entries for the // 1. real part and for the // 2. imaginary part. static constexpr unsigned kRealPosInComplexNumberStruct = 0; static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1; Type LLVMTypeConverter::convertComplexType(ComplexType type) { auto elementType = convertType(type.getElementType()).cast(); return LLVM::LLVMType::getStructTy(&getContext(), {elementType, elementType}); } // Except for signatures, MLIR function types are converted into LLVM // pointer-to-function types. Type LLVMTypeConverter::convertFunctionType(FunctionType type) { SignatureConversion conversion(type.getNumInputs()); LLVM::LLVMType converted = convertFunctionSignature(type, /*isVariadic=*/false, conversion); return converted.getPointerTo(); } /// In signatures, MemRef descriptors are expanded into lists of non-aggregate /// values. SmallVector LLVMTypeConverter::convertMemRefSignature(MemRefType type) { SmallVector results; assert(isStrided(type) && "Non-strided layout maps must have been normalized away"); LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto indexTy = getIndexType(); results.insert(results.begin(), 2, elementType.getPointerTo(type.getMemorySpace())); results.push_back(indexTy); auto rank = type.getRank(); results.insert(results.end(), 2 * rank, indexTy); return results; } /// In signatures, unranked MemRef descriptors are expanded into a pair "rank, /// pointer to descriptor". SmallVector LLVMTypeConverter::convertUnrankedMemRefSignature() { return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())}; } // Function types are converted to LLVM Function types by recursively converting // argument and result types. If MLIR Function has zero results, the LLVM // Function has one VoidType result. If MLIR Function has more than one result, // they are into an LLVM StructType in their order of appearance. 
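// A minimal illustrative sketch (not from the patch) of the result-packing
// rule described above: zero results become void, a single result is used
// as-is, and multiple results are wrapped into one LLVM struct type. Types
// are modelled as plain strings purely for illustration.
#include <string>
#include <vector>

std::string packFunctionResultsSketch(const std::vector<std::string> &convertedResults) {
  if (convertedResults.empty())
    return "void";
  if (convertedResults.size() == 1)
    return convertedResults.front();
  std::string packed = "{ ";
  for (size_t i = 0; i < convertedResults.size(); ++i) {
    packed += convertedResults[i];
    packed += (i + 1 < convertedResults.size()) ? ", " : " }";
  }
  return packed; // e.g. {"float", "i64"} -> "{ float, i64 }"
}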
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( FunctionType funcTy, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { // Select the argument converter depending on the calling convention. auto funcArgConverter = options.useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; // Convert argument types one by one and check for errors. for (auto &en : llvm::enumerate(funcTy.getInputs())) { Type type = en.value(); SmallVector converted; if (failed(funcArgConverter(*this, type, converted))) return {}; result.addInputs(en.index(), converted); } SmallVector argTypes; argTypes.reserve(llvm::size(result.getConvertedTypes())); for (Type type : result.getConvertedTypes()) argTypes.push_back(unwrap(type)); // If function does not return anything, create the void result type, // if it returns on element, convert it, otherwise pack the result types into // a struct. LLVM::LLVMType resultType = funcTy.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(&getContext()) : unwrap(packFunctionResults(funcTy.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic); } /// Converts the function type to a C-compatible format, in particular using /// pointers to memref descriptors for arguments. LLVM::LLVMType LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) { SmallVector inputs; for (Type t : type.getInputs()) { auto converted = convertType(t).dyn_cast_or_null(); if (!converted) return {}; if (t.isa()) converted = converted.getPointerTo(); inputs.push_back(converted); } LLVM::LLVMType resultType = type.getNumResults() == 0 ? LLVM::LLVMType::getVoidTy(&getContext()) : unwrap(packFunctionResults(type.getResults())); if (!resultType) return {}; return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); } // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which // contains: // 1. the pointer to the data buffer, followed by // 2. a lowered `index`-type integer containing the distance between the // beginning of the buffer and the first element to be accessed through the // view, followed by // 3. an array containing as many `index`-type integers as the rank of the // MemRef: the array represents the size, in number of elements, of the memref // along the given dimension. For constant MemRef dimensions, the // corresponding size entry is a constant whose runtime value must match the // static value, followed by // 4. a second array containing as many `index`-type integers as the rank of // the MemRef: the second array represents the "stride" (in tensor abstraction // sense), i.e. the number of consecutive elements of the underlying buffer. // TODO: add assertions for the static cases. 
// // template // struct { // Elem *allocatedPtr; // Elem *alignedPtr; // int64_t offset; // int64_t sizes[Rank]; // omitted when rank == 0 // int64_t strides[Rank]; // omitted when rank == 0 // }; static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0; static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1; static constexpr unsigned kOffsetPosInMemRefDescriptor = 2; static constexpr unsigned kSizePosInMemRefDescriptor = 3; static constexpr unsigned kStridePosInMemRefDescriptor = 4; Type LLVMTypeConverter::convertMemRefType(MemRefType type) { int64_t offset; SmallVector strides; bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset)); assert(strideSuccess && "Non-strided layout maps must have been normalized away"); (void)strideSuccess; LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto ptrTy = elementType.getPointerTo(type.getMemorySpace()); auto indexTy = getIndexType(); auto rank = type.getRank(); if (rank > 0) { auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, type.getRank()); return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy, arrayTy, arrayTy); } return LLVM::LLVMType::getStructTy(ptrTy, ptrTy, indexTy); } // Converts UnrankedMemRefType to LLVMType. The result is a descriptor which // contains: // 1. int64_t rank, the dynamic rank of this MemRef // 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be // stack allocated (alloca) copy of a MemRef descriptor that got casted to // be unranked. static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0; static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1; Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) { auto rankTy = getIndexType(); auto ptrTy = LLVM::LLVMType::getInt8PtrTy(&getContext()); return LLVM::LLVMType::getStructTy(rankTy, ptrTy); } /// Convert a memref type to a bare pointer to the memref element type. Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) { if (type.isa()) // Unranked memref is not supported in the bare pointer calling convention. return {}; // Check that the memref has static shape, strides and offset. Otherwise, it // cannot be lowered to a bare pointer. auto memrefTy = type.cast(); if (!memrefTy.hasStaticShape()) return {}; int64_t offset = 0; SmallVector strides; if (failed(getStridesAndOffset(memrefTy, strides, offset))) return {}; for (int64_t stride : strides) if (ShapedType::isDynamicStrideOrOffset(stride)) return {}; if (ShapedType::isDynamicStrideOrOffset(offset)) return {}; LLVM::LLVMType elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; return elementType.getPointerTo(type.getMemorySpace()); } // Convert an n-D vector type to an LLVM vector type via (n-1)-D array type when // n > 1. // For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and // `vector<4 x 8 x 16 f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`. Type LLVMTypeConverter::convertVectorType(VectorType type) { auto elementType = unwrap(convertType(type.getElementType())); if (!elementType) return {}; auto vectorType = LLVM::LLVMType::getVectorTy(elementType, type.getShape().back()); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]); return vectorType; } /// Convert a type in the context of the default or bare pointer calling /// convention. 
Calling convention sensitive types, such as MemRefType and /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. Type LLVMTypeConverter::convertCallingConventionType(Type type) { if (options.useBarePtrCallConv) if (auto memrefTy = type.dyn_cast()) return convertMemRefToBarePtr(memrefTy); return convertType(type); } /// Promote the bare pointers in 'values' that resulted from memrefs to /// descriptors. 'stdTypes' holds they types of 'values' before the conversion /// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type). void LLVMTypeConverter::promoteBarePtrsToDescriptors( ConversionPatternRewriter &rewriter, Location loc, ArrayRef stdTypes, SmallVectorImpl &values) { assert(stdTypes.size() == values.size() && "The number of types and values doesn't match"); for (unsigned i = 0, end = values.size(); i < end; ++i) if (auto memrefTy = stdTypes[i].dyn_cast()) values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, memrefTy, values[i]); } ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, typeConverter, context), typeConverter(typeConverter) {} /*============================================================================*/ /* StructBuilder implementation */ /*============================================================================*/ StructBuilder::StructBuilder(Value v) : value(v) { assert(value != nullptr && "value cannot be null"); structType = value.getType().dyn_cast(); assert(structType && "expected llvm type"); } Value StructBuilder::extractPtr(OpBuilder &builder, Location loc, unsigned pos) { Type type = structType.cast().getStructElementType(pos); return builder.create(loc, type, value, builder.getI64ArrayAttr(pos)); } void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr) { value = builder.create(loc, structType, value, ptr, builder.getI64ArrayAttr(pos)); } /*============================================================================*/ /* ComplexStructBuilder implementation */ /*============================================================================*/ ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder, Location loc, Type type) { Value val = builder.create(loc, type.cast()); return ComplexStructBuilder(val); } void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc, Value real) { setPtr(builder, loc, kRealPosInComplexNumberStruct, real); } Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRealPosInComplexNumberStruct); } void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc, Value imaginary) { setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary); } Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct); } /*============================================================================*/ /* MemRefDescriptor implementation */ /*============================================================================*/ /// Construct a helper for the given descriptor value. 
MemRefDescriptor::MemRefDescriptor(Value descriptor) : StructBuilder(descriptor) { assert(value != nullptr && "value cannot be null"); indexType = value.getType().cast().getStructElementType( kOffsetPosInMemRefDescriptor); } /// Builds IR creating an `undef` value of the descriptor type. MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { Value descriptor = builder.create(loc, descriptorType.cast()); return MemRefDescriptor(descriptor); } /// Builds IR creating a MemRef descriptor that represents `type` and /// populates it with static shape and stride information extracted from the /// type. MemRefDescriptor MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, MemRefType type, Value memory) { assert(type.hasStaticShape() && "unexpected dynamic shape"); // Extract all strides and offsets and verify they are static. int64_t offset; SmallVector strides; auto result = getStridesAndOffset(type, strides, offset); (void)result; assert(succeeded(result) && "unexpected failure in stride computation"); assert(offset != MemRefType::getDynamicStrideOrOffset() && "expected static offset"); assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) && "expected static strides"); auto convertedType = typeConverter.convertType(type); assert(convertedType && "unexpected failure in memref type conversion"); auto descr = MemRefDescriptor::undef(builder, loc, convertedType); descr.setAllocatedPtr(builder, loc, memory); descr.setAlignedPtr(builder, loc, memory); descr.setConstantOffset(builder, loc, offset); // Fill in sizes and strides for (unsigned i = 0, e = type.getRank(); i != e; ++i) { descr.setConstantSize(builder, loc, i, type.getDimSize(i)); descr.setConstantStride(builder, loc, i, strides[i]); } return descr; } /// Builds IR extracting the allocated pointer from the descriptor. Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor); } /// Builds IR inserting the allocated pointer into the descriptor. void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr) { setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr); } /// Builds IR extracting the aligned pointer from the descriptor. Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor); } /// Builds IR inserting the aligned pointer into the descriptor. void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, Value ptr) { setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr); } // Creates a constant Op producing a value of `resultType` from an index-typed // integer attribute. static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value) { return builder.create( loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); } /// Builds IR extracting the offset from the descriptor. Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc, Value offset) { value = builder.create( loc, structType, value, offset, builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor)); } /// Builds IR inserting the offset into the descriptor. 
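// A minimal illustrative sketch (not from the patch) of what
// MemRefDescriptor::fromStaticShape fills in, using a plain struct in place of
// the LLVM dialect struct. The stride computation below assumes the default
// identity (row-major) layout with zero offset; the real code takes strides
// and offset from getStridesAndOffset on the memref type.
#include <cstdint>
#include <vector>

struct StaticMemRefDescriptorSketch {
  void *allocatedPtr;
  void *alignedPtr;
  int64_t offset;
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
};

StaticMemRefDescriptorSketch fromStaticShapeSketch(void *memory,
                                                   const std::vector<int64_t> &shape) {
  StaticMemRefDescriptorSketch d;
  // Both pointers start out pointing at the same buffer.
  d.allocatedPtr = memory;
  d.alignedPtr = memory;
  d.offset = 0;
  d.sizes = shape;
  // Row-major strides: stride[i] is the product of all trailing dimensions.
  d.strides.assign(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    d.strides[i] = d.strides[i + 1] * shape[i + 1];
  return d;
}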
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset) { setOffset(builder, loc, createIndexAttrConstant(builder, loc, indexType, offset)); } /// Builds IR extracting the pos-th size from the descriptor. Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); } Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos, int64_t rank) { auto indexTy = indexType.cast(); auto indexPtrTy = indexTy.getPointerTo(); auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank); auto arrayPtrTy = arrayTy.getPointerTo(); // Copy size values to stack-allocated memory. auto zero = createIndexAttrConstant(builder, loc, indexType, 0); auto one = createIndexAttrConstant(builder, loc, indexType, 1); auto sizes = builder.create( loc, arrayTy, value, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor})); auto sizesPtr = builder.create(loc, arrayPtrTy, one, /*alignment=*/0); builder.create(loc, sizes, sizesPtr); // Load an return size value of interest. auto resultPtr = builder.create(loc, indexPtrTy, sizesPtr, ValueRange({zero, pos})); return builder.create(loc, resultPtr); } /// Builds IR inserting the pos-th size into the descriptor void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos, Value size) { value = builder.create( loc, structType, value, size, builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos})); } void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc, unsigned pos, uint64_t size) { setSize(builder, loc, pos, createIndexAttrConstant(builder, loc, indexType, size)); } /// Builds IR extracting the pos-th stride from the descriptor. Value MemRefDescriptor::stride(OpBuilder &builder, Location loc, unsigned pos) { return builder.create( loc, indexType, value, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } /// Builds IR inserting the pos-th stride into the descriptor void MemRefDescriptor::setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride) { value = builder.create( loc, structType, value, stride, builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos})); } void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc, unsigned pos, uint64_t stride) { setStride(builder, loc, pos, createIndexAttrConstant(builder, loc, indexType, stride)); } LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() { return value.getType() .cast() .getStructElementType(kAlignedPtrPosInMemRefDescriptor) .cast(); } /// Creates a MemRef descriptor structure from a list of individual values /// composing that descriptor, in the following order: /// - allocated pointer; /// - aligned pointer; /// - offset; /// - sizes; /// - shapes; /// where is the MemRef rank as provided in `type`. 
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, MemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); auto d = MemRefDescriptor::undef(builder, loc, llvmType); d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]); d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]); d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]); int64_t rank = type.getRank(); for (unsigned i = 0; i < rank; ++i) { d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]); d.setStride(builder, loc, i, values[kSizePosInMemRefDescriptor + rank + i]); } return d; } /// Builds IR extracting individual elements of a MemRef descriptor structure /// and returning them as `results` list. void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, MemRefType type, SmallVectorImpl &results) { int64_t rank = type.getRank(); results.reserve(results.size() + getNumUnpackedValues(type)); MemRefDescriptor d(packed); results.push_back(d.allocatedPtr(builder, loc)); results.push_back(d.alignedPtr(builder, loc)); results.push_back(d.offset(builder, loc)); for (int64_t i = 0; i < rank; ++i) results.push_back(d.size(builder, loc, i)); for (int64_t i = 0; i < rank; ++i) results.push_back(d.stride(builder, loc, i)); } /// Returns the number of non-aggregate values that would be produced by /// `unpack`. unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) { // Two pointers, offset, sizes, shapes. return 3 + 2 * type.getRank(); } /*============================================================================*/ /* MemRefDescriptorView implementation. */ /*============================================================================*/ MemRefDescriptorView::MemRefDescriptorView(ValueRange range) : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {} Value MemRefDescriptorView::allocatedPtr() { return elements[kAllocatedPtrPosInMemRefDescriptor]; } Value MemRefDescriptorView::alignedPtr() { return elements[kAlignedPtrPosInMemRefDescriptor]; } Value MemRefDescriptorView::offset() { return elements[kOffsetPosInMemRefDescriptor]; } Value MemRefDescriptorView::size(unsigned pos) { return elements[kSizePosInMemRefDescriptor + pos]; } Value MemRefDescriptorView::stride(unsigned pos) { return elements[kSizePosInMemRefDescriptor + rank + pos]; } /*============================================================================*/ /* UnrankedMemRefDescriptor implementation */ /*============================================================================*/ /// Construct a helper for the given descriptor value. UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor) : StructBuilder(descriptor) {} /// Builds IR creating an `undef` value of the descriptor type. 
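// A minimal illustrative sketch (not from the patch) of the flat value list
// that pack/unpack above operate on, and of how MemRefDescriptorView recovers
// the rank from the length of that list. Names are hypothetical.
#include <cassert>
#include <cstdint>
#include <vector>

// unpack produces: allocated ptr, aligned ptr, offset, <rank> sizes,
// <rank> strides -- i.e. 3 + 2 * rank values in total.
inline size_t numUnpackedValuesSketch(int64_t rank) { return 3 + 2 * rank; }

// MemRefDescriptorView infers the rank back from the range length.
inline int64_t rankFromUnpackedSize(size_t numValues) {
  assert(numValues >= 3 && (numValues - 3) % 2 == 0 && "malformed value list");
  return static_cast<int64_t>((numValues - 3) / 2);
}

// Named accessors over the flattened list, mirroring MemRefDescriptorView.
// Elements 0 and 1 stand in for the two pointers.
struct UnpackedDescriptorViewSketch {
  const std::vector<int64_t> &values;
  int64_t rank() const { return rankFromUnpackedSize(values.size()); }
  int64_t offset() const { return values[2]; }
  int64_t size(int64_t pos) const { return values[3 + pos]; }
  int64_t stride(int64_t pos) const { return values[3 + rank() + pos]; }
};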
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder, Location loc, Type descriptorType) { Value descriptor = builder.create(loc, descriptorType.cast()); return UnrankedMemRefDescriptor(descriptor); } Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc, Value v) { setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v); } Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder, Location loc) { return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor); } void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder, Location loc, Value v) { setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v); } /// Builds IR populating an unranked MemRef descriptor structure from a list /// of individual constituent values in the following order: /// - rank of the memref; /// - pointer to the memref descriptor. Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc, LLVMTypeConverter &converter, UnrankedMemRefType type, ValueRange values) { Type llvmType = converter.convertType(type); auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType); d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]); d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]); return d; } /// Builds IR extracting individual elements that compose an unranked memref /// descriptor and returns them as `results` list. void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed, SmallVectorImpl &results) { UnrankedMemRefDescriptor d(packed); results.reserve(results.size() + 2); results.push_back(d.rank(builder, loc)); results.push_back(d.memRefDescPtr(builder, loc)); } void UnrankedMemRefDescriptor::computeSizes( OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, ArrayRef values, SmallVectorImpl &sizes) { if (values.empty()) return; // Cache the index type. LLVM::LLVMType indexType = typeConverter.getIndexType(); // Initialize shared constants. Value one = createIndexAttrConstant(builder, loc, indexType, 1); Value two = createIndexAttrConstant(builder, loc, indexType, 2); Value pointerSize = createIndexAttrConstant( builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8)); Value indexSize = createIndexAttrConstant(builder, loc, indexType, ceilDiv(typeConverter.getIndexTypeBitwidth(), 8)); sizes.reserve(sizes.size() + values.size()); for (UnrankedMemRefDescriptor desc : values) { // Emit IR computing the memory necessary to store the descriptor. This // assumes the descriptor to be // { type*, type*, index, index[rank], index[rank] } // and densely packed, so the total size is // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). // TODO: consider including the actual size (including eventual padding due // to data layout) into the unranked descriptor. Value doublePointerSize = builder.create(loc, indexType, two, pointerSize); // (1 + 2 * rank) * sizeof(index) Value rank = desc.rank(builder, loc); Value doubleRank = builder.create(loc, indexType, two, rank); Value doubleRankIncremented = builder.create(loc, indexType, doubleRank, one); Value rankIndexSize = builder.create( loc, indexType, doubleRankIncremented, indexSize); // Total allocation size. 
Value allocationSize = builder.create( loc, indexType, doublePointerSize, rankIndexSize); sizes.push_back(allocationSize); } } +Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + return builder.create(loc, elementPtrPtr); +} + +void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value allocatedPtr) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + builder.create(loc, allocatedPtr, elementPtrPtr); +} + +Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value one = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); + Value alignedGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); + return builder.create(loc, alignedGep); +} + +void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value alignedPtr) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value one = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); + Value alignedGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); + builder.create(loc, alignedPtr, alignedGep); +} + +Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value two = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); + Value offsetGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); + offsetGep = builder.create( + loc, typeConverter.getIndexType().getPointerTo(), offsetGep); + return builder.create(loc, offsetGep); +} + +void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType, + Value offset) { + Value elementPtrPtr = + builder.create(loc, elemPtrPtrType, memRefDescPtr); + + Value two = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); + Value offsetGep = builder.create( + loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); + offsetGep = builder.create( + loc, typeConverter.getIndexType().getPointerTo(), offsetGep); + builder.create(loc, offset, offsetGep); +} + +Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value memRefDescPtr, + LLVM::LLVMType elemPtrPtrType) { + LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy(); + LLVM::LLVMType indexTy = typeConverter.getIndexType(); + LLVM::LLVMType structPtrTy = + LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy) + .getPointerTo(); + Value structPtr = + builder.create(loc, structPtrTy, memRefDescPtr); + + LLVM::LLVMType int32_type = + unwrap(typeConverter.convertType(builder.getI32Type())); + Value zero = + createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); + Value three = builder.create(loc, 
int32_type, + builder.getI32IntegerAttr(3)); + return builder.create(loc, indexTy.getPointerTo(), structPtr, + ValueRange({zero, three})); +} + +Value UnrankedMemRefDescriptor::size(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value sizeBasePtr, Value index) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({index})); + return builder.create(loc, sizeStoreGep); +} + +void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value sizeBasePtr, Value index, + Value size) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({index})); + builder.create(loc, size, sizeStoreGep); +} + +Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, + LLVMTypeConverter &typeConverter, + Value sizeBasePtr, Value rank) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + return builder.create(loc, indexPtrTy, sizeBasePtr, + ValueRange({rank})); +} + +Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value strideBasePtr, Value index, + Value stride) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value strideStoreGep = builder.create( + loc, indexPtrTy, strideBasePtr, ValueRange({index})); + return builder.create(loc, strideStoreGep); +} + +void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, + LLVMTypeConverter typeConverter, + Value strideBasePtr, Value index, + Value stride) { + LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); + Value strideStoreGep = builder.create( + loc, indexPtrTy, strideBasePtr, ValueRange({index})); + builder.create(loc, stride, strideStoreGep); +} + LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *typeConverter.getDialect(); } LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { return typeConverter.getIndexType(); } LLVM::LLVMType ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return LLVM::LLVMType::getIntNTy( &typeConverter.getContext(), typeConverter.getPointerBitwidth(addressSpace)); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { return LLVM::LLVMType::getVoidTy(&typeConverter.getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext()); } Value ConvertToLLVMPattern::createIndexConstant( ConversionPatternRewriter &builder, Location loc, uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } Value ConvertToLLVMPattern::linearizeSubscripts( ConversionPatternRewriter &builder, Location loc, ArrayRef indices, ArrayRef allocSizes) const { assert(indices.size() == allocSizes.size() && "mismatching number of indices and allocation sizes"); assert(!indices.empty() && "cannot linearize a 0-dimensional access"); Value linearized = indices.front(); for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) { linearized = builder.create( loc, this->getIndexType(), ArrayRef{linearized, allocSizes[i]}); linearized = builder.create( loc, this->getIndexType(), ArrayRef{linearized, indices[i]}); } return linearized; } Value ConvertToLLVMPattern::getStridedElementPtr( Location loc, Type elementTypePtr, Value descriptor, ValueRange indices, 
ArrayRef strides, int64_t offset, ConversionPatternRewriter &rewriter) const { MemRefDescriptor memRefDescriptor(descriptor); Value base = memRefDescriptor.alignedPtr(rewriter, loc); Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.offset(rewriter, loc) : createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() ? memRefDescriptor.stride(rewriter, loc, i) : createIndexConstant(rewriter, loc, strides[i]); Value additionalOffset = rewriter.create(loc, indices[i], stride); offsetValue = rewriter.create(loc, offsetValue, additionalOffset); } return rewriter.create(loc, elementTypePtr, base, offsetValue); } Value ConvertToLLVMPattern::getDataPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(type, strides, offset); assert(succeeded(successStrides) && "unexpected non-strided memref"); (void)successStrides; return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides, offset, rewriter); } Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = typeConverter.convertType(elementType); return structElementType.cast().getPointerTo( type.getMemorySpace()); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( Location loc, MemRefType memRefType, ArrayRef dynSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes) const { sizes.reserve(memRefType.getRank()); unsigned i = 0; for (int64_t s : memRefType.getShape()) sizes.push_back(s == ShapedType::kDynamicSize ? dynSizes[i++] : createIndexConstant(rewriter, loc, s)); } Value ConvertToLLVMPattern::getSizeInBytes( Location loc, Type type, ConversionPatternRewriter &rewriter) const { // Compute the size of an individual element. This emits the MLIR equivalent // of the following sizeof(...) implementation in LLVM IR: // %0 = getelementptr %elementType* null, %indexType 1 // %1 = ptrtoint %elementType* %0 to %indexType // which is a common pattern of getting the size of a type in bytes. auto convertedPtrType = typeConverter.convertType(type).cast().getPointerTo(); auto nullPtr = rewriter.create(loc, convertedPtrType); auto gep = rewriter.create( loc, convertedPtrType, ArrayRef{nullPtr, createIndexConstant(rewriter, loc, 1)}); return rewriter.create(loc, getIndexType(), gep); } Value ConvertToLLVMPattern::getNumElements( Location loc, ArrayRef shape, ConversionPatternRewriter &rewriter) const { // Compute the total number of memref elements. Value numElements = shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front(); for (unsigned i = 1, e = shape.size(); i < e; ++i) numElements = rewriter.create(loc, numElements, shape[i]); return numElements; } Value ConvertToLLVMPattern::getCumulativeSizeInBytes( Location loc, Type elementType, ArrayRef shape, ConversionPatternRewriter &rewriter) const { Value numElements = this->getNumElements(loc, shape, rewriter); Value elementSize = this->getSizeInBytes(loc, elementType, rewriter); return rewriter.create(loc, numElements, elementSize); } /// Creates and populates the memref descriptor struct given all its fields. 
MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, uint64_t offset, ArrayRef strides, ArrayRef sizes, ConversionPatternRewriter &rewriter) const { auto structType = typeConverter.convertType(memRefType); auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); // Field 1: Allocated pointer, used for malloc/free. memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); // Field 2: Actual aligned pointer to payload. memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); // Field 3: Offset in aligned pointer. memRefDescriptor.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, offset)); if (memRefType.getRank() == 0) // No size/stride descriptor in memref, return the descriptor value. return memRefDescriptor; // Fields 4 and 5: sizes and strides of the strided MemRef. // Store all sizes in the descriptor. Only dynamic sizes are passed in as // operands to AllocOp. Value runningStride = nullptr; // Iterate strides in reverse order, compute runningStride and strideValues. auto nStrides = strides.size(); SmallVector strideValues(nStrides, nullptr); for (unsigned i = 0; i < nStrides; ++i) { int64_t index = nStrides - 1 - i; if (strides[index] == MemRefType::getDynamicStrideOrOffset()) // Identity layout map is enforced in the match function, so we compute: // `runningStride *= sizes[index + 1]` runningStride = runningStride ? rewriter.create( loc, runningStride, sizes[index + 1]) : createIndexConstant(rewriter, loc, 1); else runningStride = createIndexConstant(rewriter, loc, strides[index]); strideValues[index] = runningStride; } // Fill size and stride descriptors in memref. for (auto indexedSize : llvm::enumerate(sizes)) { int64_t index = indexedSize.index(); memRefDescriptor.setSize(rewriter, loc, index, indexedSize.value()); memRefDescriptor.setStride(rewriter, loc, index, strideValues[index]); } return memRefDescriptor; } /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. static void filterFuncAttributes(ArrayRef attrs, bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" || (filterArgAttrs && impl::isArgAttrName(attr.first.strref()))) continue; result.push_back(attr); } } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. This function can be called from C /// by passing a pointer to a C struct corresponding to a memref descriptor. /// Internally, the auxiliary function unpacks the descriptor into individual /// components and forwards them to `newFuncOp`. 
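// Illustrative aside (not part of the lowering): the running-stride logic
// above, specialized to the identity (row-major) layout it assumes, as a
// standalone C++ sketch. The function name is hypothetical.
#include <cstdint>
#include <vector>

static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t runningStride = 1; // the innermost dimension is contiguous
  for (int64_t i = static_cast<int64_t>(sizes.size()) - 1; i >= 0; --i) {
    strides[i] = runningStride;
    runningStride *= sizes[i];
  }
  return strides;
}
// For example, sizes {2, 3, 4} yield strides {12, 4, 1}.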
static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getType(); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External, attributes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); SmallVector args; for (auto &en : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(en.index()); if (auto memrefType = en.value().dyn_cast()) { Value loaded = rewriter.create(loc, arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (en.value().isa()) { Value loaded = rewriter.create(loc, arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; } args.push_back(wrapperFuncOp.getArgument(en.index())); } auto call = rewriter.create(loc, newFuncOp, args); rewriter.create(loc, call.getResults()); } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. Creates a body for the (external) /// `newFuncOp` that allocates a memref descriptor on stack, packs the /// individual arguments into this descriptor and passes a pointer to it into /// the auxiliary function. This auxiliary external function is now compatible /// with functions defined in C using pointers to C structs corresponding to a /// memref descriptor. static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); LLVM::LLVMType wrapperType = typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); // This conversion can only fail if it could not convert one of the argument // types. But since it has been applies to a non-wrapper function before, it // should have failed earlier and not reach this point at all. assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, attributes); builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); // Get a ValueRange containing arguments. FunctionType type = funcOp.getType(); SmallVector args; args.reserve(type.getNumInputs()); ValueRange wrapperArgsRange(newFuncOp.getArguments()); // Iterate over the inputs of the original function and pack values into // memref descriptors if the original type is a memref. for (auto &en : llvm::enumerate(type.getInputs())) { Value arg; int numToDrop = 1; auto memRefType = en.value().dyn_cast(); auto unrankedMemRefType = en.value().dyn_cast(); if (memRefType || unrankedMemRefType) { numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) : UnrankedMemRefDescriptor::getNumUnpackedValues(); Value packed = memRefType ? 
MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, wrapperArgsRange.take_front(numToDrop)) : UnrankedMemRefDescriptor::pack( builder, loc, typeConverter, unrankedMemRefType, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = packed.getType().cast().getPointerTo(); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value allocated = builder.create(loc, ptrTy, one, /*alignment=*/0); builder.create(loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; } args.push_back(arg); wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); } assert(wrapperArgsRange.empty() && "did not map some of the arguments"); auto call = builder.create(loc, wrapperFunc, args); builder.create(loc, call.getResults()); } namespace { struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. LLVM::LLVMFuncOp convertFuncOpToLLVMFuncOp(FuncOp funcOp, ConversionPatternRewriter &rewriter) const { // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); if (!llvmType) return nullptr; // Propagate argument attributes to all converted arguments obtained after // converting a given original argument. SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true, attributes); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { auto attr = impl::getArgAttrDict(funcOp, i); if (!attr) continue; auto mapping = result.getInputMapping(i); assert(mapping.hasValue() && "unexpected deletion of function argument"); SmallString<8> name; for (size_t j = 0; j < mapping->size; ++j) { impl::getArgAttrName(mapping->inputNo + j, name); attributes.push_back(rewriter.getNamedAttr(name, attr)); } } // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, &result))) return nullptr; return newFuncOp; } }; /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. 
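// Illustrative aside (not part of the lowering): roughly what the generated
// "_mlir_ciface_" wrapper looks like from the C/C++ side for a function that
// takes a rank-1 memref of f32. The struct mirrors the descriptor layout
// { T*, T*, index, index[1], index[1] } with a 64-bit index type; the function
// name "foo" is hypothetical.
#include <cstdint>

struct MemRefDescriptor1D {
  float *allocated;  // pointer returned by the allocator, used for free
  float *aligned;    // pointer actually used for loads/stores
  int64_t offset;    // offset into the aligned buffer, in elements
  int64_t sizes[1];
  int64_t strides[1];
};

// A caller written in C/C++ passes a pointer to the descriptor struct:
extern "C" void _mlir_ciface_foo(MemRefDescriptor1D *arg0);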
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter) : FuncOpConversionBase(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); if (typeConverter.getOptions().emitCWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); else wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); } rewriter.eraseOp(op); return success(); } }; /// FuncOp legalization pattern that converts MemRef arguments to bare pointers /// to the MemRef element type. This will impact the calling convention and ABI. struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); // Store the type of memref-typed arguments before the conversion so that we // can promote them to MemRef descriptor at the beginning of the function. SmallVector oldArgTypes = llvm::to_vector<8>(funcOp.getType().getInputs()); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(op); return success(); } // Promote bare pointers from memref arguments to memref descriptors at the // beginning of the function so that all the memrefs in the function have a // uniform representation. Block *entryBlock = &newFuncOp.getBody().front(); auto blockArgs = entryBlock->getArguments(); assert(blockArgs.size() == oldArgTypes.size() && "The number of arguments and types doesn't match"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(entryBlock); for (auto it : llvm::zip(blockArgs, oldArgTypes)) { BlockArgument arg = std::get<0>(it); Type argTy = std::get<1>(it); // Unranked memrefs are not supported in the bare pointer calling // convention. We should have bailed out before in the presence of // unranked memrefs. assert(!argTy.isa() && "Unranked memref is not supported"); auto memrefTy = argTy.dyn_cast(); if (!memrefTy) continue; // Replace barePtr with a placeholder (undef), promote barePtr to a ranked // or unranked memref descriptor and replace placeholder with the last // instruction of the memref descriptor. // TODO: The placeholder is needed to avoid replacing barePtr uses in the // MemRef descriptor instructions. We may want to have a utility in the // rewriter to properly handle this use case. Location loc = op->getLoc(); auto placeholder = rewriter.create(loc, memrefTy); rewriter.replaceUsesOfBlockArgument(arg, placeholder); Value desc = MemRefDescriptor::fromStaticShape( rewriter, loc, typeConverter, memrefTy, arg); rewriter.replaceOp(placeholder, {desc}); } rewriter.eraseOp(op); return success(); } }; //////////////// Support for Lowering operations on n-D vectors //////////////// // Helper struct to "unroll" operations on n-D vectors in terms of operations on // 1-D LLVM vectors. struct NDVectorTypeInfo { // LLVM array struct which encodes n-D vectors. LLVM::LLVMType llvmArrayTy; // LLVM vector type which encodes the inner 1-D vector type. 
LLVM::LLVMType llvmVectorTy; // Multiplicity of llvmArrayTy to llvmVectorTy. SmallVector arraySizes; }; } // namespace // For >1-D vector types, extracts the necessary information to iterate over all // 1-D subvectors in the underlying llrepresentation of the n-D vector // Iterates on the llvm array type until we hit a non-array type (which is // asserted to be an llvm vector type). static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, LLVMTypeConverter &converter) { assert(vectorType.getRank() > 1 && "expected >1D vector type"); NDVectorTypeInfo info; info.llvmArrayTy = converter.convertType(vectorType).dyn_cast(); if (!info.llvmArrayTy) return info; info.arraySizes.reserve(vectorType.getRank() - 1); auto llvmTy = info.llvmArrayTy; while (llvmTy.isArrayTy()) { info.arraySizes.push_back(llvmTy.getArrayNumElements()); llvmTy = llvmTy.getArrayElementType(); } if (!llvmTy.isVectorTy()) return info; info.llvmVectorTy = llvmTy; return info; } // Express `linearIndex` in terms of coordinates of `basis`. // Returns the empty vector when linearIndex is out of the range [0, P] where // P is the product of all the basis coordinates. // // Prerequisites: // Basis is an array of nonnegative integers (signed type inherited from // vector shape type). static SmallVector getCoordinates(ArrayRef basis, unsigned linearIndex) { SmallVector res; res.reserve(basis.size()); for (unsigned basisElement : llvm::reverse(basis)) { res.push_back(linearIndex % basisElement); linearIndex = linearIndex / basisElement; } if (linearIndex > 0) return {}; std::reverse(res.begin(), res.end()); return res; } // Iterate of linear index, convert to coords space and insert splatted 1-D // vector in each position. template void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, Lambda fun) { unsigned ub = 1; for (auto s : info.arraySizes) ub *= s; for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { auto coords = getCoordinates(info.arraySizes, linearIndex); // Linear index is out of bounds, we are done. if (coords.empty()) break; assert(coords.size() == info.arraySizes.size()); auto position = builder.getI64ArrayAttr(coords); fun(position); } } ////////////// End Support for Lowering operations on n-D vectors ////////////// /// Replaces the given operation "op" with a new operation of type "targetOp" /// and given operands. LogicalResult LLVM::detail::oneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { unsigned numResults = op->getNumResults(); Type packedType; if (numResults != 0) { packedType = typeConverter.packFunctionResults(op->getResultTypes()); if (!packedType) return failure(); } // Create the operation through state since we don't know its C++ type. OperationState state(op->getLoc(), targetOp); state.addTypes(packedType); state.addOperands(operands); state.addAttributes(op->getAttrs()); Operation *newOp = rewriter.createOperation(state); // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) return rewriter.eraseOp(op), success(); if (numResults == 1) return rewriter.replaceOp(op, newOp->getResult(0)), success(); // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. 
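// Illustrative aside (not part of the lowering): what getCoordinates computes,
// as a standalone C++ sketch. It walks the basis from the innermost dimension
// outwards and returns an empty vector when the linear index is out of range.
// The function name is hypothetical.
#include <algorithm>
#include <cstdint>
#include <vector>

static std::vector<int64_t> delinearize(const std::vector<int64_t> &basis,
                                        int64_t linearIndex) {
  std::vector<int64_t> coords;
  coords.reserve(basis.size());
  for (auto it = basis.rbegin(); it != basis.rend(); ++it) {
    coords.push_back(linearIndex % *it);
    linearIndex /= *it;
  }
  if (linearIndex > 0)
    return {}; // out of the range [0, product of all basis elements)
  std::reverse(coords.begin(), coords.end());
  return coords;
}
// For example, basis {2, 3} and linearIndex 4 give coordinates {1, 1}.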
SmallVector results; results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); } rewriter.replaceOp(op, results); return success(); } static LogicalResult handleMultidimensionalVectors( Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, std::function createOperand, ConversionPatternRewriter &rewriter) { auto vectorType = op->getResult(0).getType().dyn_cast(); if (!vectorType) return failure(); auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter); auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; auto llvmArrayTy = operands[0].getType().cast(); if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy) return failure(); auto loc = op->getLoc(); Value desc = rewriter.create(loc, llvmArrayTy); nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { // For this unrolled `position` corresponding to the `linearIndex`^th // element, extract operand vectors SmallVector extractedOperands; for (auto operand : operands) extractedOperands.push_back(rewriter.create( loc, llvmVectorTy, operand, position)); Value newVal = createOperand(llvmVectorTy, extractedOperands); desc = rewriter.create(loc, llvmArrayTy, desc, newVal, position); }); rewriter.replaceOp(op, desc); return success(); } LogicalResult LLVM::detail::vectorOneToOneRewrite( Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { assert(!operands.empty()); // Cannot convert ops if their operands are not of LLVM type. if (!llvm::all_of(operands.getTypes(), [](Type t) { return t.isa(); })) return failure(); auto llvmArrayTy = operands[0].getType().cast(); if (!llvmArrayTy.isArrayTy()) return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy, ValueRange operands) { OperationState state(op->getLoc(), targetOp); state.addTypes(llvmVectorTy); state.addOperands(operands); state.addAttributes(op->getAttrs()); return rewriter.createOperation(state)->getResult(0); }; return handleMultidimensionalVectors(op, operands, typeConverter, callback, rewriter); } namespace { // Straightforward lowerings. 
using AbsFOpLowering = VectorConvertToLLVMPattern; using AddFOpLowering = VectorConvertToLLVMPattern; using AddIOpLowering = VectorConvertToLLVMPattern; using AndOpLowering = VectorConvertToLLVMPattern; using CeilFOpLowering = VectorConvertToLLVMPattern; using CopySignOpLowering = VectorConvertToLLVMPattern; using CosOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; using ExpOpLowering = VectorConvertToLLVMPattern; using Exp2OpLowering = VectorConvertToLLVMPattern; using FloorFOpLowering = VectorConvertToLLVMPattern; using Log10OpLowering = VectorConvertToLLVMPattern; using Log2OpLowering = VectorConvertToLLVMPattern; using LogOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; using OrOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; using SelectOpLowering = OneToOneConvertToLLVMPattern; using ShiftLeftOpLowering = OneToOneConvertToLLVMPattern; using SignedDivIOpLowering = VectorConvertToLLVMPattern; using SignedRemIOpLowering = VectorConvertToLLVMPattern; using SignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using SinOpLowering = VectorConvertToLLVMPattern; using SqrtOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; using SubIOpLowering = VectorConvertToLLVMPattern; using UnsignedDivIOpLowering = VectorConvertToLLVMPattern; using UnsignedRemIOpLowering = VectorConvertToLLVMPattern; using UnsignedShiftRightOpLowering = OneToOneConvertToLLVMPattern; using XOrOpLowering = VectorConvertToLLVMPattern; /// Lower `std.assert`. The default lowering calls the `abort` function if the /// assertion is violated and has no effect otherwise. The failure message is /// ignored by the default lowering but should be propagated by any custom /// lowering. struct AssertOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); AssertOp::Adaptor transformed(operands); // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); auto abortFunc = module.lookupSymbol("abort"); if (!abortFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); auto abortFuncTy = LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false); abortFunc = rewriter.create(rewriter.getUnknownLoc(), "abort", abortFuncTy); } // Split block at `assert` operation. Block *opBlock = rewriter.getInsertionBlock(); auto opPosition = rewriter.getInsertionPoint(); Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition); // Generate IR to call `abort`. Block *failureBlock = rewriter.createBlock(opBlock->getParent()); rewriter.create(loc, abortFunc, llvm::None); rewriter.create(loc); // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( op, transformed.arg(), continuationBlock, failureBlock); return success(); } }; // Lowerings for operations on complex numbers. 
struct CreateComplexOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto complexOp = cast(op); CreateComplexOp::Adaptor transformed(operands); // Pack real and imaginary part in a complex number struct. auto loc = op->getLoc(); auto structType = typeConverter.convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); complexStruct.setReal(rewriter, loc, transformed.real()); complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); rewriter.replaceOp(op, {complexStruct}); return success(); } }; struct ReOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ReOp::Adaptor transformed(operands); // Extract real part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); Value real = complexStruct.real(rewriter, op->getLoc()); rewriter.replaceOp(op, real); return success(); } }; struct ImOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ImOp::Adaptor transformed(operands); // Extract imaginary part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); Value imaginary = complexStruct.imaginary(rewriter, op->getLoc()); rewriter.replaceOp(op, imaginary); return success(); } }; struct BinaryComplexOperands { std::complex lhs, rhs; }; template BinaryComplexOperands unpackBinaryComplexOperands(OpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) { auto bop = cast(op); auto loc = bop.getLoc(); typename OpTy::Adaptor transformed(operands); // Extract real and imaginary values from operands. BinaryComplexOperands unpacked; ComplexStructBuilder lhs(transformed.lhs()); unpacked.lhs.real(lhs.real(rewriter, loc)); unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); ComplexStructBuilder rhs(transformed.rhs()); unpacked.rhs.real(rhs.real(rewriter, loc)); unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); return unpacked; } struct AddCFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = cast(operation); auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, operands, rewriter); // Initialize complex number struct for result. auto structType = this->typeConverter.convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. 
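// Illustrative aside (not part of the lowering): the componentwise arithmetic
// that the complex add/sub lowerings emit on the struct-of-two-scalars
// representation of a complex number. Names below are hypothetical.
struct ComplexF32 {
  float real, imag;
};

static ComplexF32 complexAdd(ComplexF32 lhs, ComplexF32 rhs) {
  return {lhs.real + rhs.real, lhs.imag + rhs.imag};
}

static ComplexF32 complexSub(ComplexF32 lhs, ComplexF32 rhs) {
  return {lhs.real - rhs.real, lhs.imag - rhs.imag};
}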
Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real()); Value imag = rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag()); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); rewriter.replaceOp(op, {result}); return success(); } }; struct SubCFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = cast(operation); auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, operands, rewriter); // Initialize complex number struct for result. auto structType = this->typeConverter.convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to substract complex numbers. Value real = rewriter.create(loc, arg.lhs.real(), arg.rhs.real()); Value imag = rewriter.create(loc, arg.lhs.imag(), arg.rhs.imag()); result.setReal(rewriter, loc, real); result.setImaginary(rewriter, loc, imag); rewriter.replaceOp(op, {result}); return success(); } }; struct ConstantOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = cast(operation); // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast()) { auto type = typeConverter.convertType(op.getResult().getType()) .dyn_cast_or_null(); if (!type) return rewriter.notifyMatchFailure(op, "failed to convert result type"); MutableDictionaryAttr attrs(op.getAttrs()); attrs.remove(rewriter.getIdentifier("value")); rewriter.replaceOpWithNewOp( op, type.cast(), symbolRef.getValue(), attrs.getAttrs()); return success(); } // Calling into other scopes (non-flat reference) is not supported in LLVM. if (op.getValue().isa()) return rewriter.notifyMatchFailure( op, "referring to a symbol outside of the current module"); return LLVM::detail::oneToOneRewrite(op, LLVM::ConstantOp::getOperationName(), operands, typeConverter, rewriter); } }; // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. static bool isSupportedMemRefType(MemRefType type) { return type.getAffineMaps().empty() || llvm::all_of(type.getAffineMaps(), [](AffineMap map) { return map.isIdentity(); }); } /// Lowering for AllocOp and AllocaOp. struct AllocLikeOpLowering : public ConvertToLLVMPattern { using ConvertToLLVMPattern::createIndexConstant; using ConvertToLLVMPattern::getIndexType; using ConvertToLLVMPattern::getVoidPtrType; using ConvertToLLVMPattern::typeConverter; explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter) : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {} protected: // Returns 'input' aligned up to 'alignment'. Computes // bumped = input + alignement - 1 // aligned = bumped - bumped % alignment static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, Value input, Value alignment) { Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); Value bump = rewriter.create(loc, alignment, one); Value bumped = rewriter.create(loc, input, bump); Value mod = rewriter.create(loc, bumped, alignment); return rewriter.create(loc, bumped, mod); } // Creates a call to an allocation function with params and casts the // resulting void pointer to ptrType. 
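// Illustrative aside (not part of the lowering): the align-up arithmetic that
// createAligned emits, as a standalone C++ sketch on integer addresses.
// Assumes a nonzero alignment; the function name is hypothetical.
#include <cstdint>

static uint64_t alignUp(uint64_t input, uint64_t alignment) {
  uint64_t bumped = input + alignment - 1;
  return bumped - bumped % alignment;
}
// For example, alignUp(13, 8) == 16 and alignUp(16, 8) == 16.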
Value createAllocCall(Location loc, StringRef name, Type ptrType, ArrayRef params, ModuleOp module, ConversionPatternRewriter &rewriter) const { SmallVector paramTypes; auto allocFuncOp = module.lookupSymbol(name); if (!allocFuncOp) { for (Value param : params) paramTypes.push_back(param.getType().cast()); auto allocFuncType = LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes, /*isVarArg=*/false); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); allocFuncOp = rewriter.create(rewriter.getUnknownLoc(), name, allocFuncType); } auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); auto allocatedPtr = rewriter .create(loc, getVoidPtrType(), allocFuncSymbol, params) .getResult(0); return rewriter.create(loc, ptrType, allocatedPtr); } /// Allocates the underlying buffer. Returns the allocated pointer and the /// aligned pointer. virtual std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value cumulativeSize, Operation *op) const = 0; private: static MemRefType getMemRefResultType(Operation *op) { return op->getResult(0).getType().cast(); } LogicalResult match(Operation *op) const override { MemRefType memRefType = getMemRefResultType(op); if (isSupportedMemRefType(memRefType)) return success(); int64_t offset; SmallVector strides; if (failed(getStridesAndOffset(memRefType, strides, offset))) return failure(); // Dynamic strides are ok if they can be deduced from dynamic sizes (which // is guaranteed when getStridesAndOffset succeeded. Dynamic offset however // can never be alloc'ed. if (offset == MemRefType::getDynamicStrideOrOffset()) return failure(); return success(); } // An `alloc` is converted into a definition of a memref descriptor value and // a call to `malloc` to allocate the underlying data buffer. The memref // descriptor is of the LLVM structure type where: // 1. the first element is a pointer to the allocated (typed) data buffer, // 2. the second element is a pointer to the (typed) payload, aligned to the // specified alignment, // 3. the remaining elements serve to store all the sizes and strides of the // memref using LLVM-converted `index` type. // // Alignment is performed by allocating `alignment` more bytes than // requested and shifting the aligned pointer relative to the allocated // memory. Note: `alignment - ` would actually be // sufficient. If alignment is unspecified, the two pointers are equal. // An `alloca` is converted into a definition of a memref descriptor value and // an llvm.alloca to allocate the underlying data buffer. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType memRefType = getMemRefResultType(op); auto loc = op->getLoc(); // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). SmallVector sizes; this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes); Value cumulativeSize = this->getCumulativeSizeInBytes( loc, memRefType.getElementType(), sizes, rewriter); // Allocate the underlying buffer. 
Value allocatedPtr; Value alignedPtr; std::tie(allocatedPtr, alignedPtr) = this->allocateBuffer(rewriter, loc, cumulativeSize, op); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); (void)successStrides; assert(succeeded(successStrides) && "unexpected non-strided memref"); assert(offset != MemRefType::getDynamicStrideOrOffset() && "unexpected dynamic offset"); // 0-D memref corner case: they have size 1. assert( ((memRefType.getRank() == 0 && strides.empty() && sizes.size() == 1) || (strides.size() == sizes.size())) && "unexpected number of strides"); // Create the MemRef descriptor. auto memRefDescriptor = this->createMemRefDescriptor(loc, memRefType, allocatedPtr, alignedPtr, offset, strides, sizes, rewriter); // Return the final value of the descriptor. rewriter.replaceOp(op, {memRefDescriptor}); } }; struct AllocOpLowering : public AllocLikeOpLowering { AllocOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {} std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value cumulativeSize, Operation *op) const override { // Heap allocations. AllocOp allocOp = cast(op); MemRefType memRefType = allocOp.getType(); Value alignment; if (auto alignmentAttr = allocOp.alignment()) { alignment = createIndexConstant(rewriter, loc, *alignmentAttr); } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { // In the case where no alignment is specified, we may want to override // `malloc's` behavior. `malloc` typically aligns at the size of the // biggest scalar on a target HW. For non-scalars, use the natural // alignment of the LLVM type given by the LLVM DataLayout. alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); } if (alignment) { // Adjust the allocation size to consider alignment. cumulativeSize = rewriter.create(loc, cumulativeSize, alignment); } // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Type elementPtrType = this->getElementPtrType(memRefType); Value allocatedPtr = createAllocCall(loc, "malloc", elementPtrType, {cumulativeSize}, allocOp.getParentOfType(), rewriter); Value alignedPtr = allocatedPtr; if (alignment) { auto intPtrType = getIntPtrType(memRefType.getMemorySpace()); // Compute the aligned type pointer. Value allocatedInt = rewriter.create(loc, intPtrType, allocatedPtr); Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); alignedPtr = rewriter.create(loc, elementPtrType, alignmentInt); } return std::make_tuple(allocatedPtr, alignedPtr); } }; struct AlignedAllocOpLowering : public AllocLikeOpLowering { AlignedAllocOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {} /// Returns the memref's element size in bytes. // TODO: there are other places where this is used. Expose publicly? static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } return llvm::divideCeil(sizeInBits, 8); } /// Returns true if the memref size in bytes is known to be a multiple of /// factor. 
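// Illustrative aside (not part of the lowering): the element-size computation
// getMemRefEltSizeInBytes performs, as a standalone C++ sketch. `bitWidth` is
// the scalar bit width and `vectorElements` is 1 for scalar element types;
// names are hypothetical.
#include <cstdint>

static uint64_t eltSizeInBytes(uint64_t bitWidth, uint64_t vectorElements) {
  uint64_t sizeInBits = bitWidth * vectorElements;
  return (sizeInBits + 7) / 8; // divide by 8, rounding up
}
// For example, f32 -> 4 bytes, vector<4xf16> -> 8 bytes, i1 -> 1 byte.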
static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) { uint64_t sizeDivisor = getMemRefEltSizeInBytes(type); for (unsigned i = 0, e = type.getRank(); i < e; i++) { if (type.isDynamic(type.getDimSize(i))) continue; sizeDivisor = sizeDivisor * type.getDimSize(i); } return sizeDivisor % factor == 0; } /// Returns the alignment to be used for the allocation call itself. /// aligned_alloc requires the allocation size to be a power of two, and the /// allocation size to be a multiple of alignment, int64_t getAllocationAlignment(AllocOp allocOp) const { if (Optional alignment = allocOp.alignment()) return *alignment; // Whenever we don't have alignment set, we will use an alignment // consistent with the element type; since the allocation size has to be a // power of two, we will bump to the next power of two if it already isn't. auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType()); return std::max(kMinAlignedAllocAlignment, llvm::PowerOf2Ceil(eltSizeBytes)); } std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value cumulativeSize, Operation *op) const override { // Heap allocations. AllocOp allocOp = cast(op); MemRefType memRefType = allocOp.getType(); int64_t alignment = getAllocationAlignment(allocOp); Value allocAlignment = createIndexConstant(rewriter, loc, alignment); // aligned_alloc requires size to be a multiple of alignment; we will pad // the size to the next multiple if necessary. if (!isMemRefSizeMultipleOf(memRefType, alignment)) cumulativeSize = createAligned(rewriter, loc, cumulativeSize, allocAlignment); Type elementPtrType = this->getElementPtrType(memRefType); Value allocatedPtr = createAllocCall( loc, "aligned_alloc", elementPtrType, {allocAlignment, cumulativeSize}, allocOp.getParentOfType(), rewriter); return std::make_tuple(allocatedPtr, allocatedPtr); } /// The minimum alignment to use with aligned_alloc (has to be a power of 2). static constexpr uint64_t kMinAlignedAllocAlignment = 16UL; }; // Out of line definition, required till C++17. constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment; struct AllocaOpLowering : public AllocLikeOpLowering { AllocaOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {} /// Allocates the underlying buffer using the right call. `allocatedBytePtr` /// is set to null for stack allocations. `accessAlignment` is set if /// alignment is needed post allocation (for eg. in conjunction with malloc). std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value cumulativeSize, Operation *op) const override { // With alloca, one gets a pointer to the element type right away. // For stack allocations. auto allocaOp = cast(op); auto elementPtrType = this->getElementPtrType(allocaOp.getType()); auto allocatedElementPtr = rewriter.create( loc, elementPtrType, cumulativeSize, allocaOp.alignment() ? *allocaOp.alignment() : 0); return std::make_tuple(allocatedElementPtr, allocatedElementPtr); } }; /// Copies the shaped descriptor part to (if `toDynamic` is set) or from /// (otherwise) the dynamically allocated memory for any operands that were /// unranked descriptors originally. static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, TypeRange origTypes, SmallVectorImpl &operands, bool toDynamic) { assert(origTypes.size() == operands.size() && "expected as may original types as operands"); // Find operands of unranked memref type and store them. 
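// Illustrative aside (not part of the lowering): how the aligned_alloc path
// chooses an alignment when none is specified, as a standalone C++ sketch.
// The chosen alignment is a power of two, and the allocation size is padded
// to a multiple of it using the same align-up as createAligned above, as
// aligned_alloc requires. Names below are hypothetical.
#include <algorithm>
#include <cstdint>

static uint64_t nextPowerOf2(uint64_t value) {
  uint64_t p = 1;
  while (p < value)
    p *= 2;
  return p;
}

static uint64_t allocationAlignment(uint64_t eltSizeInBytes) {
  // At least 16 bytes, and at least the element size rounded up to a power of
  // two.
  return std::max<uint64_t>(16, nextPowerOf2(eltSizeInBytes));
}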
SmallVector unrankedMemrefs; for (unsigned i = 0, e = operands.size(); i < e; ++i) if (origTypes[i].isa()) unrankedMemrefs.emplace_back(operands[i]); if (unrankedMemrefs.empty()) return success(); // Compute allocation sizes. SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter, unrankedMemrefs, sizes); // Get frequently used types. MLIRContext *context = builder.getContext(); auto voidType = LLVM::LLVMType::getVoidTy(context); auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context); auto i1Type = LLVM::LLVMType::getInt1Ty(context); LLVM::LLVMType indexType = typeConverter.getIndexType(); // Find the malloc and free, or declare them if necessary. auto module = builder.getInsertionPoint()->getParentOfType(); auto mallocFunc = module.lookupSymbol("malloc"); if (!mallocFunc && toDynamic) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); mallocFunc = builder.create( builder.getUnknownLoc(), "malloc", LLVM::LLVMType::getFunctionTy( voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false)); } auto freeFunc = module.lookupSymbol("free"); if (!freeFunc && !toDynamic) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); freeFunc = builder.create( builder.getUnknownLoc(), "free", LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType), /*isVarArg=*/false)); } // Initialize shared constants. Value zero = builder.create(loc, i1Type, builder.getBoolAttr(false)); unsigned unrankedMemrefPos = 0; for (unsigned i = 0, e = operands.size(); i < e; ++i) { Type type = origTypes[i]; if (!type.isa()) continue; Value allocationSize = sizes[unrankedMemrefPos++]; UnrankedMemRefDescriptor desc(operands[i]); // Allocate memory, copy, and free the source if necessary. Value memory = toDynamic ? builder.create(loc, mallocFunc, allocationSize) .getResult(0) : builder.create(loc, voidPtrType, allocationSize, /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); builder.create(loc, memory, source, allocationSize, zero); if (!toDynamic) builder.create(loc, freeFunc, source); // Create a new descriptor. The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks // (allocated twice and overwritten) or double frees (the caller does not // know if the descriptor points to the same memory). Type descriptorType = typeConverter.convertType(type); if (!descriptorType) return failure(); auto updatedDesc = UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); Value rank = desc.rank(builder, loc); updatedDesc.setRank(builder, loc, rank); updatedDesc.setMemRefDescPtr(builder, loc, memory); operands[i] = updatedDesc; } return success(); } // A CallOp automatically promotes MemRefType to a sequence of alloca/store and // passes the pointer to the MemRef across function boundaries. template struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering; using Base = ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { typename CallOpType::Adaptor transformed(operands); auto callOp = cast(op); // Pack the result types into a struct. 
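// Illustrative aside (not part of the lowering): a call with more than one
// result is converted to return a single struct, from which the individual
// results are then extracted. Roughly, at the C++ level (names hypothetical):
#include <cstdint>

struct PackedResults {
  int64_t r0;
  float r1;
};

static PackedResults packedCallee() { return {42, 1.0f}; }

static void unpackedCaller() {
  PackedResults packed = packedCallee(); // single struct-typed call result
  int64_t a = packed.r0;                 // extractvalue %packed[0]
  float b = packed.r1;                   // extractvalue %packed[1]
  (void)a;
  (void)b;
}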
Type packedResult; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); if (numResults != 0) { if (!(packedResult = this->typeConverter.packFunctionResults(resultTypes))) return failure(); } auto promoted = this->typeConverter.promoteOperands( op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter); auto newOp = rewriter.create(op->getLoc(), packedResult, promoted, op->getAttrs()); SmallVector results; if (numResults < 2) { // If < 2 results, packing did not do anything and we can just return. results.append(newOp.result_begin(), newOp.result_end()); } else { // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. results.reserve(numResults); for (unsigned i = 0; i < numResults; ++i) { auto type = this->typeConverter.convertType(op->getResult(i).getType()); results.push_back(rewriter.create( op->getLoc(), type, newOp.getOperation()->getResult(0), rewriter.getI64ArrayAttr(i))); } } if (this->typeConverter.getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, promote memref results to // descriptors. assert(results.size() == resultTypes.size() && "The number of arguments and types doesn't match"); this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(), resultTypes, results); } else if (failed(copyUnrankedDescriptors(rewriter, op->getLoc(), this->typeConverter, resultTypes, results, /*toDynamic=*/false))) { return failure(); } rewriter.replaceOp(op, results); return success(); } }; struct CallOpLowering : public CallOpInterfaceLowering { using Super::Super; }; struct CallIndirectOpLowering : public CallOpInterfaceLowering { using Super::Super; }; // A `dealloc` is converted into a call to `free` on the underlying data buffer. // The memref descriptor being an SSA value, there is no need to clean it up // in any way. struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; explicit DeallocOpLowering(LLVMTypeConverter &converter) : ConvertOpToLLVMPattern(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); DeallocOp::Adaptor transformed(operands); // Insert the `free` declaration if it is not already present. auto freeFunc = op->getParentOfType().lookupSymbol("free"); if (!freeFunc) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart( op->getParentOfType().getBody()); freeFunc = rewriter.create( rewriter.getUnknownLoc(), "free", LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(), /*isVarArg=*/false)); } MemRefDescriptor memref(transformed.memref()); Value casted = rewriter.create( op->getLoc(), getVoidPtrType(), memref.allocatedPtr(rewriter, op->getLoc())); rewriter.replaceOpWithNewOp( op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted); return success(); } }; // A `rsqrt` is converted into `1 / sqrt`. 
struct RsqrtOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { RsqrtOp::Adaptor transformed(operands); auto operandType = transformed.operand().getType().dyn_cast(); if (!operandType) return failure(); auto loc = op->getLoc(); auto resultType = *op->result_type_begin(); auto floatType = getElementTypeOrSelf(resultType).cast(); auto floatOne = rewriter.getFloatAttr(floatType, 1.0); if (!operandType.isArrayTy()) { LLVM::ConstantOp one; if (operandType.isVectorTy()) { one = rewriter.create( loc, operandType, SplatElementsAttr::get(resultType.cast(), floatOne)); } else { one = rewriter.create(loc, operandType, floatOne); } auto sqrt = rewriter.create(loc, transformed.operand()); rewriter.replaceOpWithNewOp(op, operandType, one, sqrt); return success(); } auto vectorType = resultType.dyn_cast(); if (!vectorType) return failure(); return handleMultidimensionalVectors( op, operands, typeConverter, [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( mlir::VectorType::get({llvmVectorTy.getVectorNumElements()}, floatType), floatOne); auto one = rewriter.create(loc, llvmVectorTy, splatAttr); auto sqrt = rewriter.create(loc, llvmVectorTy, operands[0]); return rewriter.create(loc, llvmVectorTy, one, sqrt); }, rewriter); } }; struct MemRefCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult match(Operation *op) const override { auto memRefCastOp = cast(op); Type srcType = memRefCastOp.getOperand().getType(); Type dstType = memRefCastOp.getType(); // MemRefCastOp reduce to bitcast in the ranked MemRef case and can be used // for type erasure. For now they must preserve underlying element type and // require source and result type to have the same rank. Therefore, perform // a sanity check that the underlying structs are the same. Once op // semantics are relaxed we can revisit. if (srcType.isa() && dstType.isa()) return success(typeConverter.convertType(srcType) == typeConverter.convertType(dstType)); // At least one of the operands is unranked type assert(srcType.isa() || dstType.isa()); // Unranked to unranked cast is disallowed return !(srcType.isa() && dstType.isa()) ? success() : failure(); } void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast(op); MemRefCastOp::Adaptor transformed(operands); auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); auto targetStructType = typeConverter.convertType(memRefCastOp.getType()); auto loc = op->getLoc(); // For ranked/ranked case, just keep the original descriptor. 
if (srcType.isa() && dstType.isa()) return rewriter.replaceOp(op, {transformed.source()}); if (srcType.isa() && dstType.isa()) { // Casting ranked to unranked memref type // Set the rank in the destination from the memref type // Allocate space on the stack and copy the src memref descriptor // Set the ptr in the destination to the stack space auto srcMemRefType = srcType.cast(); int64_t rank = srcMemRefType.getRank(); // ptr = AllocaOp sizeof(MemRefDescriptor) auto ptr = typeConverter.promoteOneMemRefDescriptor( loc, transformed.source(), rewriter); // voidptr = BitCastOp srcType* to void* auto voidPtr = rewriter.create(loc, getVoidPtrType(), ptr) .getResult(); // rank = ConstantOp srcRank auto rankVal = rewriter.create( loc, typeConverter.convertType(rewriter.getIntegerType(64)), rewriter.getI64IntegerAttr(rank)); // undef = UndefOp UnrankedMemRefDescriptor memRefDesc = UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType); // d1 = InsertValueOp undef, rank, 0 memRefDesc.setRank(rewriter, loc, rankVal); // d2 = InsertValueOp d1, voidptr, 1 memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr); rewriter.replaceOp(op, (Value)memRefDesc); } else if (srcType.isa() && dstType.isa()) { // Casting from unranked type to ranked. // The operation is assumed to be doing a correct cast. If the destination // type mismatches the unranked the type, it is undefined behavior. UnrankedMemRefDescriptor memRefDesc(transformed.source()); // ptr = ExtractValueOp src, 1 auto ptr = memRefDesc.memRefDescPtr(rewriter, loc); // castPtr = BitCastOp i8* to structTy* auto castPtr = rewriter .create( loc, targetStructType.cast().getPointerTo(), ptr) .getResult(); // struct = LoadOp castPtr auto loadOp = rewriter.create(loc, castPtr); rewriter.replaceOp(op, loadOp.getResult()); } else { llvm_unreachable("Unsupported unranked memref to unranked memref cast"); } } }; +/// Extracts allocated, aligned pointers and offset from a ranked or unranked +/// memref type. In unranked case, the fields are extracted from the underlying +/// ranked descriptor. +static void extractPointersAndOffset(Location loc, + ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, + Value originalOperand, + Value convertedOperand, + Value *allocatedPtr, Value *alignedPtr, + Value *offset = nullptr) { + Type operandType = originalOperand.getType(); + if (operandType.isa()) { + MemRefDescriptor desc(convertedOperand); + *allocatedPtr = desc.allocatedPtr(rewriter, loc); + *alignedPtr = desc.alignedPtr(rewriter, loc); + if (offset != nullptr) + *offset = desc.offset(rewriter, loc); + return; + } + + unsigned memorySpace = + operandType.cast().getMemorySpace(); + Type elementType = operandType.cast().getElementType(); + LLVM::LLVMType llvmElementType = + unwrap(typeConverter.convertType(elementType)); + LLVM::LLVMType elementPtrPtrType = + llvmElementType.getPointerTo(memorySpace).getPointerTo(); + + // Extract pointer to the underlying ranked memref descriptor and cast it to + // ElemType**. 
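// Illustrative aside (not part of the lowering): an unranked descriptor is
// just a (rank, pointer) pair. Casting a ranked memref to an unranked one
// stores the rank and a type-erased pointer to a stack copy of the ranked
// descriptor { T*, T*, index, index[rank], index[rank] }. Names below are
// hypothetical.
#include <cstdint>

struct UnrankedDescriptor {
  int64_t rank;           // rank of the ranked descriptor pointed to
  void *rankedDescriptor; // type-erased pointer to the ranked descriptor
};

static UnrankedDescriptor eraseRank(void *rankedDescStackCopy, int64_t rank) {
  return {rank, rankedDescStackCopy};
}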
+ UnrankedMemRefDescriptor unrankedDesc(convertedOperand); + Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); + + *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( + rewriter, loc, underlyingDescPtr, elementPtrPtrType); + *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( + rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + if (offset != nullptr) { + *offset = UnrankedMemRefDescriptor::offset( + rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + } +} + struct MemRefReinterpretCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary()); Type srcType = castOp.source().getType(); Value descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, adaptor, &descriptor))) return failure(); rewriter.replaceOp(op, {descriptor}); return success(); } private: LogicalResult convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, Type srcType, MemRefReinterpretCastOp castOp, MemRefReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { MemRefType targetMemRefType = castOp.getResult().getType().cast(); auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return failure(); // Create descriptor. Location loc = castOp.getLoc(); auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); // Set allocated and aligned pointers. Value allocatedPtr, alignedPtr; - extractPointers(loc, rewriter, castOp.source(), adaptor.source(), - &allocatedPtr, &alignedPtr); + extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(), + adaptor.source(), &allocatedPtr, &alignedPtr); desc.setAllocatedPtr(rewriter, loc, allocatedPtr); desc.setAlignedPtr(rewriter, loc, alignedPtr); // Set offset. if (castOp.isDynamicOffset(0)) desc.setOffset(rewriter, loc, adaptor.offsets()[0]); else desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); // Set sizes and strides. 
unsigned dynSizeId = 0; unsigned dynStrideId = 0; for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { if (castOp.isDynamicSize(i)) desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]); else desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); if (castOp.isDynamicStride(i)) desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); else desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); } *descriptor = desc; return success(); } +}; - void extractPointers(Location loc, ConversionPatternRewriter &rewriter, - Value originalOperand, Value convertedOperand, - Value *allocatedPtr, Value *alignedPtr) const { - Type operandType = originalOperand.getType(); - if (operandType.isa()) { - MemRefDescriptor desc(convertedOperand); - *allocatedPtr = desc.allocatedPtr(rewriter, loc); - *alignedPtr = desc.alignedPtr(rewriter, loc); - return; - } +struct MemRefReshapeOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - unsigned memorySpace = - operandType.cast().getMemorySpace(); - Type elementType = operandType.cast().getElementType(); - LLVM::LLVMType llvmElementType = - typeConverter.convertType(elementType).cast(); - LLVM::LLVMType elementPtrPtrType = - llvmElementType.getPointerTo(memorySpace).getPointerTo(); + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto reshapeOp = cast(op); - // Extract pointer to the underlying ranked memref descriptor and cast it to - // ElemType**. - UnrankedMemRefDescriptor unrankedDesc(convertedOperand); - Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); - Value elementPtrPtr = rewriter.create( - loc, elementPtrPtrType, underlyingDescPtr); + MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary()); + Type srcType = reshapeOp.source().getType(); - LLVM::LLVMType int32Type = - typeConverter.convertType(rewriter.getI32Type()).cast(); + Value descriptor; + if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, + adaptor, &descriptor))) + return failure(); + rewriter.replaceOp(op, {descriptor}); + return success(); + } + +private: + LogicalResult + convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, + Type srcType, MemRefReshapeOp reshapeOp, + MemRefReshapeOp::Adaptor adaptor, + Value *descriptor) const { + // Conversion for statically-known shape args is performed via + // `memref_reinterpret_cast`. + auto shapeMemRefType = reshapeOp.shape().getType().cast(); + if (shapeMemRefType.hasStaticShape()) + return failure(); - // Extract and set allocated pointer. - *allocatedPtr = rewriter.create(loc, elementPtrPtr); + // The shape is a rank-1 tensor with unknown length. + Location loc = reshapeOp.getLoc(); + MemRefDescriptor shapeDesc(adaptor.shape()); + Value resultRank = shapeDesc.size(rewriter, loc, 0); - // Extract and set aligned pointer. - Value one = rewriter.create( - loc, int32Type, rewriter.getI32IntegerAttr(1)); - Value alignedGep = rewriter.create( - loc, elementPtrPtrType, elementPtrPtr, ValueRange({one})); - *alignedPtr = rewriter.create(loc, alignedGep); + // Extract address space and element type. + auto targetType = + reshapeOp.getResult().getType().cast(); + unsigned addressSpace = targetType.getMemorySpace(); + Type elementType = targetType.getElementType(); + + // Create the unranked memref descriptor that holds the ranked one. The + // inner descriptor is allocated on stack. 
+ auto targetDesc = UnrankedMemRefDescriptor::undef( + rewriter, loc, unwrap(typeConverter.convertType(targetType))); + targetDesc.setRank(rewriter, loc, resultRank); + SmallVector sizes; + UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, + targetDesc, sizes); + Value underlyingDescPtr = rewriter.create( + loc, getVoidPtrType(), sizes.front(), llvm::None); + targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); + + // Extract pointers and offset from the source memref. + Value allocatedPtr, alignedPtr, offset; + extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(), + adaptor.source(), &allocatedPtr, &alignedPtr, + &offset); + + // Set pointers and offset. + LLVM::LLVMType llvmElementType = + unwrap(typeConverter.convertType(elementType)); + LLVM::LLVMType elementPtrPtrType = + llvmElementType.getPointerTo(addressSpace).getPointerTo(); + UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, + elementPtrPtrType, allocatedPtr); + UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter, + underlyingDescPtr, + elementPtrPtrType, alignedPtr); + UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter, + underlyingDescPtr, elementPtrPtrType, + offset); + + // Use the offset pointer as base for further addressing. Copy over the new + // shape and compute strides. For this, we create a loop from rank-1 to 0. + Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( + rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); + Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( + rewriter, loc, typeConverter, targetSizesBase, resultRank); + Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); + Value oneIndex = createIndexConstant(rewriter, loc, 1); + Value resultRankMinusOne = + rewriter.create(loc, resultRank, oneIndex); + + Block *initBlock = rewriter.getInsertionBlock(); + LLVM::LLVMType indexType = typeConverter.getIndexType(); + Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); + + Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, + {indexType, indexType}); + + // Iterate over the remaining ops in initBlock and move them to condBlock. + BlockAndValueMapping map; + for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) { + rewriter.clone(*it, map); + rewriter.eraseOp(&*it); + } + + rewriter.setInsertionPointToEnd(initBlock); + rewriter.create(loc, ValueRange({resultRankMinusOne, oneIndex}), + condBlock); + rewriter.setInsertionPointToStart(condBlock); + Value indexArg = condBlock->getArgument(0); + Value strideArg = condBlock->getArgument(1); + + Value zeroIndex = createIndexConstant(rewriter, loc, 0); + Value pred = rewriter.create( + loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), + LLVM::ICmpPredicate::sge, indexArg, zeroIndex); + + Block *bodyBlock = + rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); + rewriter.setInsertionPointToStart(bodyBlock); + + // Copy size from shape to descriptor. + LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo(); + Value sizeLoadGep = rewriter.create( + loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); + Value size = rewriter.create(loc, sizeLoadGep); + UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter, + targetSizesBase, indexArg, size); + + // Write stride value and compute next one. 
+ UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter, + targetStridesBase, indexArg, strideArg); + Value nextStride = rewriter.create(loc, strideArg, size); + + // Decrement loop counter and branch back. + Value decrement = rewriter.create(loc, indexArg, oneIndex); + rewriter.create(loc, ValueRange({decrement, nextStride}), + condBlock); + + Block *remainder = + rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); + + // Hook up the cond exit to the remainder. + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create(loc, pred, bodyBlock, llvm::None, remainder, + llvm::None); + + // Reset position to beginning of new remainder block. + rewriter.setInsertionPointToStart(remainder); + + *descriptor = targetDesc; + return success(); } }; struct DialectCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != typeConverter.convertType(castOp.getType())) { return failure(); } rewriter.replaceOp(op, transformed.in()); return success(); } }; // A `dim` is converted to a constant for static sizes and to an access to the // size stored in the memref descriptor for dynamic sizes. struct DimOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); Type operandType = dimOp.memrefOrTensor().getType(); if (operandType.isa()) { rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp, operands, rewriter)}); return success(); } if (operandType.isa()) { rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp, operands, rewriter)}); return success(); } return failure(); } private: Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); DimOp::Adaptor transformed(operands); auto unrankedMemRefType = operandType.cast(); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); unsigned addressSpace = unrankedMemRefType.getMemorySpace(); // Extract pointer to the underlying ranked descriptor and bitcast it to a // memref descriptor pointer to minimize the number of GEP // operations. UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor()); Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create( loc, typeConverter.convertType(scalarMemRefType) .cast() .getPointerTo(addressSpace), underlyingRankedDesc); // Get pointer to offset field of memref descriptor. Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace); Value two = rewriter.create( loc, typeConverter.convertType(rewriter.getI32Type()), rewriter.getI32IntegerAttr(2)); Value offsetPtr = rewriter.create( loc, indexPtrTy, scalarMemRefDescPtr, ValueRange({createIndexConstant(rewriter, loc, 0), two})); // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. 
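    // Layout sketch (why `index + 1` works): the converted ranked descriptor
    // is laid out as {allocated ptr, aligned ptr, offset, sizes[rank],
    // strides[rank]}, so size `i` sits `i + 1` index-sized slots past the
    // offset field. GEP-ing from the offset-field pointer by `index + 1`
    // therefore reaches the requested size without knowing the rank
    // statically.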
Value idxPlusOne = rewriter.create( loc, createIndexConstant(rewriter, loc, 1), transformed.index()); Value sizePtr = rewriter.create(loc, indexPtrTy, offsetPtr, ValueRange({idxPlusOne})); return rewriter.create(loc, sizePtr); } Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); DimOp::Adaptor transformed(operands); // Take advantage if index is constant. MemRefType memRefType = operandType.cast(); if (Optional index = dimOp.getConstantIndex()) { int64_t i = index.getValue(); if (memRefType.isDynamicDim(i)) { // extract dynamic size from the memref descriptor. MemRefDescriptor descriptor(transformed.memrefOrTensor()); return descriptor.size(rewriter, loc, i); } // Use constant for static size. int64_t dimSize = memRefType.getDimSize(i); return createIndexConstant(rewriter, loc, dimSize); } Value index = dimOp.index(); int64_t rank = memRefType.getRank(); MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor()); return memrefDescriptor.size(rewriter, loc, index, rank); } }; struct RankOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); Type operandType = cast(op).memrefOrTensor().getType(); if (auto unrankedMemRefType = operandType.dyn_cast()) { UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); return success(); } if (auto rankedMemRefType = operandType.dyn_cast()) { rewriter.replaceOp( op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); return success(); } return failure(); } }; // Common base for load and store operations on MemRefs. Restricts the match // to supported MemRef types. Provides functionality to emit code accessing a // specific element of the underlying data buffer. template struct LoadStoreOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Base = LoadStoreOpLowering; LogicalResult match(Operation *op) const override { MemRefType type = cast(op).getMemRefType(); return isSupportedMemRefType(type) ? success() : failure(); } }; // Load operation is lowered to obtaining a pointer to the indexed element // and loading it. struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); LoadOp::Adaptor transformed(operands); auto type = loadOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, dataPtr); return success(); } }; // Store operation is lowered to obtaining a pointer to the indexed element, // and storing the given value to it. 
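// As an illustrative sketch (not checked verbatim by the tests below), a
// `store %v, %m[%i, %j] : memref<42x?xf32>` becomes the same
// llvm.mul/llvm.add/llvm.getelementptr address computation emitted for loads,
// followed by an `llvm.store` of %v to the computed address.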
struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); StoreOp::Adaptor transformed(operands); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return success(); } }; // The prefetch operation is lowered in a way similar to the load operation // except that the llvm.prefetch operation is used for replacement. struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); PrefetchOp::Adaptor transformed(operands); auto type = prefetchOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); auto localityHint = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.localityHint())); auto isData = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); rewriter.replaceOpWithNewOp(op, dataPtr, isWrite, localityHint, isData); return success(); } }; // The lowering of index_cast becomes an integer conversion since index becomes // an integer. If the bit width of the source and target integer types is the // same, just erase the cast. If the target type is wider, sign-extend the // value, otherwise truncate it. struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpAdaptor transformed(operands); auto indexCastOp = cast(op); auto targetType = this->typeConverter.convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); unsigned targetBits = targetType.getIntegerBitWidth(); unsigned sourceBits = sourceType.getIntegerBitWidth(); if (targetBits == sourceBits) rewriter.replaceOp(op, transformed.in()); else if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); else rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); return success(); } }; // Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two // enums share the numerical values so just cast. 
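// For example, CmpIPredicate::slt and LLVM::ICmpPredicate::slt are assumed to
// carry the same underlying integer value, so the conversion below is a plain
// static_cast rather than a per-predicate switch.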
template static LLVMPredType convertCmpPredicate(StdPredType pred) { return static_cast(pred); } struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return success(); } }; struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); CmpFOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return success(); } }; struct SIToFPLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct UIToFPLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPExtLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPToSILowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPToUILowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPTruncLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct SignExtendIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct TruncateIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct ZeroExtendIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = OneToOneLLVMTerminatorLowering; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), op->getAttrs()); return success(); } }; // Special lowering pattern for `ReturnOps`. Unlike all other operations, // `ReturnOp` interacts with the function signature and must have as many // operands as the function has return values. Because in LLVM IR, functions // can only return 0 or 1 value, we pack multiple values into a structure type. // Emit `UndefOp` followed by `InsertValueOp`s to create such structure if // necessary before returning it struct ReturnOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); unsigned numArguments = op->getNumOperands(); SmallVector updatedOperands; if (typeConverter.getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. 
for (auto it : llvm::zip(op->getOperands(), operands)) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); if (oldTy.isa()) { MemRefDescriptor memrefDesc(newOperand); newOperand = memrefDesc.alignedPtr(rewriter, loc); } else if (oldTy.isa()) { // Unranked memref is not supported in the bare pointer calling // convention. return failure(); } updatedOperands.push_back(newOperand); } } else { updatedOperands = llvm::to_vector<4>(operands); copyUnrankedDescriptors(rewriter, loc, typeConverter, op->getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); } // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), op->getAttrs()); return success(); } if (numArguments == 1) { rewriter.replaceOpWithNewOp( op, TypeRange(), updatedOperands, op->getAttrs()); return success(); } // Otherwise, we need to pack the arguments into an LLVM struct type before // returning. auto packedType = typeConverter.packFunctionResults( llvm::to_vector<4>(op->getOperandTypes())); Value packed = rewriter.create(loc, packedType); for (unsigned i = 0; i < numArguments; ++i) { packed = rewriter.create( loc, packedType, packed, updatedOperands[i], rewriter.getI64ArrayAttr(i)); } rewriter.replaceOpWithNewOp(op, TypeRange(), packed, op->getAttrs()); return success(); } }; // FIXME: this should be tablegen'ed as well. struct BranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; struct CondBranchOpLowering : public OneToOneLLVMTerminatorLowering { using Super::Super; }; // The Splat operation is lowered to an insertelement + a shufflevector // operation. Splat to only 1-d vector result types are lowered. struct SplatOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) return failure(); // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter.convertType(splatOp.getType()); Value undef = rewriter.create(op->getLoc(), vectorType); auto zero = rewriter.create( op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); auto v = rewriter.create( op->getLoc(), vectorType, undef, splatOp.getOperand(), zero); int64_t width = splatOp.getType().cast().getDimSize(0); SmallVector zeroValues(width, 0); // Shuffle the value across the desired number of elements. ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); rewriter.replaceOpWithNewOp(op, v, undef, zeroAttrs); return success(); } }; // The Splat operation is lowered to an insertelement + a shufflevector // operation. Splat to only 2+-d vector result types are lowered by the // SplatNdOpLowering, the 1-d case is handled by SplatOpLowering. struct SplatNdOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); SplatOp::Adaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) return failure(); // First insert it into an undef vector so we can shuffle it. 
auto loc = op->getLoc(); auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter); auto llvmArrayTy = vectorTypeInfo.llvmArrayTy; auto llvmVectorTy = vectorTypeInfo.llvmVectorTy; if (!llvmArrayTy || !llvmVectorTy) return failure(); // Construct returned value. Value desc = rewriter.create(loc, llvmArrayTy); // Construct a 1-D vector with the splatted value that we insert in all the // places within the returned descriptor. Value vdesc = rewriter.create(loc, llvmVectorTy); auto zero = rewriter.create( loc, typeConverter.convertType(rewriter.getIntegerType(32)), rewriter.getZeroAttr(rewriter.getIntegerType(32))); Value v = rewriter.create(loc, llvmVectorTy, vdesc, adaptor.input(), zero); // Shuffle the value across the desired number of elements. int64_t width = resultType.getDimSize(resultType.getRank() - 1); SmallVector zeroValues(width, 0); ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues); v = rewriter.create(loc, v, v, zeroAttrs); // Iterate of linear index, convert to coords space and insert splatted 1-D // vector in each position. nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) { desc = rewriter.create(loc, llvmArrayTy, desc, v, position); }); rewriter.replaceOp(op, desc); return success(); } }; /// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. static SmallVector extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); } /// Conversion pattern that transforms a subview op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The subview op is replaced by the descriptor. struct SubViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto subViewOp = cast(op); auto sourceMemRefType = subViewOp.source().getType().cast(); auto sourceElementTy = typeConverter.convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null(); auto viewMemRefType = subViewOp.getType(); auto inferredType = SubViewOp::inferResultType( subViewOp.getSourceType(), extractFromI64ArrayAttr(subViewOp.static_offsets()), extractFromI64ArrayAttr(subViewOp.static_sizes()), extractFromI64ArrayAttr(subViewOp.static_strides())) .cast(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter.convertType(viewMemRefType) .dyn_cast_or_null(); if (!sourceElementTy || !targetDescTy) return failure(); // Extract the offset and strides from the type. int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(inferredType, strides, offset); if (failed(successStrides)) return failure(); // Create the descriptor. if (!operands.front().getType().isa()) return failure(); MemRefDescriptor sourceMemRef(operands.front()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Copy the buffer pointer from the old descriptor to the new one. Value extracted = sourceMemRef.allocatedPtr(rewriter, loc); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Copy the buffer pointer from the old descriptor to the new one. 
extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), extracted); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); auto shape = viewMemRefType.getShape(); auto inferredShape = inferredType.getShape(); size_t inferredShapeRank = inferredShape.size(); size_t resultShapeRank = shape.size(); SmallVector mask = computeRankReductionMask(inferredShape, shape).getValue(); // Extract strides needed to compute offset. SmallVector strideValues; strideValues.reserve(inferredShapeRank); for (unsigned i = 0; i < inferredShapeRank; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); // Offset. auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); if (!ShapedType::isDynamicStrideOrOffset(offset)) { targetMemRef.setConstantOffset(rewriter, loc, offset); } else { Value baseOffset = sourceMemRef.offset(rewriter, loc); for (unsigned i = 0; i < inferredShapeRank; ++i) { Value offset = subViewOp.isDynamicOffset(i) ? operands[subViewOp.getIndexOfDynamicOffset(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); Value mul = rewriter.create(loc, offset, strideValues[i]); baseOffset = rewriter.create(loc, baseOffset, mul); } targetMemRef.setOffset(rewriter, loc, baseOffset); } // Update sizes and strides. for (int i = inferredShapeRank - 1, j = resultShapeRank - 1; i >= 0 && j >= 0; --i) { if (!mask[i]) continue; Value size = subViewOp.isDynamicSize(i) ? operands[subViewOp.getIndexOfDynamicSize(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); targetMemRef.setSize(rewriter, loc, j, size); Value stride; if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { stride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); } else { stride = subViewOp.isDynamicStride(i) ? operands[subViewOp.getIndexOfDynamicStride(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i))); stride = rewriter.create(loc, stride, strideValues[i]); } targetMemRef.setStride(rewriter, loc, j, stride); j--; } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms a transpose op into: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size /// and stride. Size and stride are permutations of the original values. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The transpose op is replaced by the alloca'ed pointer. class TransposeOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); TransposeOpAdaptor adaptor(operands); MemRefDescriptor viewMemRef(adaptor.in()); auto transposeOp = cast(op); // No permutation, early exit. if (transposeOp.permutation().isIdentity()) return rewriter.replaceOp(op, {viewMemRef}), success(); auto targetMemRef = MemRefDescriptor::undef( rewriter, loc, typeConverter.convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. 
targetMemRef.setAllocatedPtr(rewriter, loc, viewMemRef.allocatedPtr(rewriter, loc)); targetMemRef.setAlignedPtr(rewriter, loc, viewMemRef.alignedPtr(rewriter, loc)); // Copy the offset pointer from the old descriptor to the new one. targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc)); // Iterate over the dimensions and apply size/stride permutation. for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { int sourcePos = en.index(); int targetPos = en.value().cast().getPosition(); targetMemRef.setSize(rewriter, loc, targetPos, viewMemRef.size(rewriter, loc, sourcePos)); targetMemRef.setStride(rewriter, loc, targetPos, viewMemRef.stride(rewriter, loc, sourcePos)); } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms an op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The view op is replaced by the descriptor. struct ViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. Value getSize(ConversionPatternRewriter &rewriter, Location loc, ArrayRef shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) { return ShapedType::isDynamic(v); }); return dynamicSizes[nDynamic]; } // Build and return the idx^th stride, either by returning the constant stride // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. The caller should keep a running stride and update it with the // result returned by this function. Value getStride(ConversionPatternRewriter &rewriter, Location loc, ArrayRef strides, Value nextSize, Value runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); if (nextSize) return runningStride ? rewriter.create(loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexConstant(rewriter, loc, 1); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); ViewOpAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter.convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return op->emitWarning("Target descriptor type not converted to LLVM"), failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) return op->emitWarning("cannot cast to non-strided shape"), failure(); assert(offset == 0 && "expected offset to be 0"); // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.source()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. 
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = viewOp.source().getType().cast(); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); alignedPtr = rewriter.create(loc, alignedPtr.getType(), alignedPtr, adaptor.byte_shift()); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: The offset in the resulting type must be 0. This is because of // the type change: an offset on srcType* may not be expressible as an // offset on dstType*. targetMemRef.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, offset)); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) return rewriter.replaceOp(op, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), failure(); Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. Value size = getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. stride = getStride(rewriter, loc, strides, nextSize, stride, i); targetMemRef.setStride(rewriter, loc, i, stride); nextSize = size; } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; struct AssumeAlignmentOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { AssumeAlignmentOp::Adaptor transformed(operands); Value memref = transformed.memref(); unsigned alignment = cast(op).alignment(); MemRefDescriptor memRefDescriptor(memref); Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that // the asserted memref.alignedPtr isn't used anywhere else, as the real // users like load/store/views always re-extract memref.alignedPtr as they // get lowered. // // This relies on LLVM's CSE optimization (potentially after SROA), since // after CSE all memref.alignedPtr instances get de-duplicated into the same // pointer SSA value. auto intPtrType = getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0); Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, alignment - 1); Value ptrValue = rewriter.create(op->getLoc(), intPtrType, ptr); rewriter.create( op->getLoc(), rewriter.create( op->getLoc(), LLVM::ICmpPredicate::eq, rewriter.create(op->getLoc(), ptrValue, mask), zero)); rewriter.eraseOp(op); return success(); } }; } // namespace /// Try to match the kind of a std.atomic_rmw to determine whether to use a /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. 
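/// For instance, the `addf` kind maps to `llvm.atomicrmw fadd` and `assign`
/// maps to an `xchg`; kinds without a single matching LLVM binop return
/// llvm::None, and the cmpxchg-based loop sketched below handles the general
/// case instead.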
static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { switch (atomicOp.kind()) { case AtomicRMWKind::addf: return LLVM::AtomicBinOp::fadd; case AtomicRMWKind::addi: return LLVM::AtomicBinOp::add; case AtomicRMWKind::assign: return LLVM::AtomicBinOp::xchg; case AtomicRMWKind::maxs: return LLVM::AtomicBinOp::max; case AtomicRMWKind::maxu: return LLVM::AtomicBinOp::umax; case AtomicRMWKind::mins: return LLVM::AtomicBinOp::min; case AtomicRMWKind::minu: return LLVM::AtomicBinOp::umin; default: return llvm::None; } llvm_unreachable("Invalid AtomicRMWKind"); } namespace { struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( op, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); return success(); } }; /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be /// retried until it succeeds in atomically storing a new value into memory. /// /// +---------------------------------+ /// | | /// | | /// | br loop(%loaded) | /// +---------------------------------+ /// | /// -------| | /// | v v /// | +--------------------------------+ /// | | loop(%loaded): | /// | | | /// | | %pair = cmpxchg | /// | | %ok = %pair[0] | /// | | %new = %pair[1] | /// | | cond_br %ok, end, loop(%new) | /// | +--------------------------------+ /// | | | /// |----------- | /// v /// +--------------------------------+ /// | end: | /// | | /// +--------------------------------+ /// struct GenericAtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto loc = op->getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = typeConverter.convertType(atomicOp.getResult().getType()) .cast(); // Split the block into initial, loop, and ending parts. auto *initBlock = rewriter.getInsertionBlock(); auto *loopBlock = rewriter.createBlock(initBlock->getParent(), std::next(Region::iterator(initBlock)), valueType); auto *endBlock = rewriter.createBlock( loopBlock->getParent(), std::next(Region::iterator(loopBlock))); // Operations range to be moved to `endBlock`. auto opsToMoveStart = atomicOp.getOperation()->getIterator(); auto opsToMoveEnd = initBlock->back().getIterator(); // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); auto memRefType = atomicOp.memref().getType().cast(); auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), adaptor.indices(), rewriter); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); // Clone the GenericAtomicRMWOp region and extract the result. 
auto loopArgument = loopBlock->getArgument(0); BlockAndValueMapping mapping; mapping.map(atomicOp.getCurrentValue(), loopArgument); Block &entryBlock = atomicOp.body().front(); for (auto &nestedOp : entryBlock.without_terminator()) { Operation *clone = rewriter.clone(nestedOp, mapping); mapping.map(nestedOp.getResults(), clone->getResults()); } Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); // Prepare the epilog of the loop block. // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, result, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. Value newLoaded = rewriter.create( loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); Value ok = rewriter.create( loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); // Conditionally branch to the end or back to the loop depending on %ok. rewriter.create(loc, ok, endBlock, ArrayRef(), loopBlock, newLoaded); rewriter.setInsertionPointToEnd(endBlock); moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart), std::next(opsToMoveEnd), rewriter); // The 'result' of the atomic_rmw op is the newly loaded value. rewriter.replaceOp(op, {newLoaded}); return success(); } private: // Clones a segment of ops [start, end) and erases the original. void moveOpsRange(ValueRange oldResult, ValueRange newResult, Block::iterator start, Block::iterator end, ConversionPatternRewriter &rewriter) const { BlockAndValueMapping mapping; mapping.map(oldResult, newResult); SmallVector opsToErase; for (auto it = start; it != end; ++it) { rewriter.clone(*it, mapping); opsToErase.push_back(&*it); } for (auto *it : opsToErase) rewriter.eraseOp(it); } }; } // namespace /// Collect a set of patterns to convert from the Standard dialect to LLVM. 
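/// These populate* entry points are typically used together; a usage sketch
/// mirroring the pass defined further below (assuming an existing `ModuleOp m`
/// and a configured `LLVMConversionTarget target`):
///
///   LLVMTypeConverter typeConverter(m.getContext());
///   OwningRewritePatternList patterns;
///   populateStdToLLVMConversionPatterns(typeConverter, patterns);
///   (void)applyPartialConversion(m, target, std::move(patterns));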
void mlir::populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed // clang-format off patterns.insert< AbsFOpLowering, AddCFOpLowering, AddFOpLowering, AddIOpLowering, AllocaOpLowering, AndOpLowering, AssertOpLowering, AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CeilFOpLowering, CmpFOpLowering, CmpIOpLowering, CondBranchOpLowering, CopySignOpLowering, CosOpLowering, ConstantOpLowering, CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, Exp2OpLowering, FloorFOpLowering, GenericAtomicRMWOpLowering, LogOpLowering, Log10OpLowering, Log2OpLowering, FPExtLowering, FPToSILowering, FPToUILowering, FPTruncLowering, ImOpLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrOpLowering, PrefetchOpLowering, ReOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, SIToFPLowering, SelectOpLowering, ShiftLeftOpLowering, SignExtendIOpLowering, SignedDivIOpLowering, SignedRemIOpLowering, SignedShiftRightOpLowering, SinOpLowering, SplatOpLowering, SplatNdOpLowering, SqrtOpLowering, SubCFOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering, UIToFPLowering, UnsignedDivIOpLowering, UnsignedRemIOpLowering, UnsignedShiftRightOpLowering, XOrOpLowering, ZeroExtendIOpLowering>(converter); // clang-format on } void mlir::populateStdToLLVMMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, DeallocOpLowering, DimOpLowering, LoadOpLowering, MemRefCastOpLowering, MemRefReinterpretCastOpLowering, + MemRefReshapeOpLowering, RankOpLowering, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, ViewOpLowering>(converter); // clang-format on if (converter.getOptions().useAlignedAlloc) patterns.insert(converter); else patterns.insert(converter); } void mlir::populateStdToLLVMFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { if (converter.getOptions().useBarePtrCallConv) patterns.insert(converter); else patterns.insert(converter); } void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatterns(converter, patterns); populateExpandMemRefReshapePattern(patterns, &converter.getContext()); } /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one value is returned, /// create an LLVM IR structure type with elements that correspond to each of /// the MLIR types converted with `convertType`. 
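/// For example (illustrative only): a result list `(f32, i64)` packs into
/// `!llvm.struct<(float, i64)>`, whereas a single result is converted and
/// returned directly without wrapping.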
Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) return convertCallingConventionType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { auto converted = convertCallingConventionType(t).dyn_cast_or_null(); if (!converted) return {}; resultTypes.push_back(converted); } return LLVM::LLVMType::getStructTy(&getContext(), resultTypes); } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand.getType().cast().getPointerTo(); Value one = builder.create(loc, int64Ty, IntegerAttr::get(indexType, 1)); Value allocated = builder.create(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. builder.create(loc, operand, allocated); return allocated; } SmallVector LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { SmallVector promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); if (options.useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor desc(llvmOperand); llvmOperand = desc.alignedPtr(builder, loc); } else if (operand.getType().isa()) { llvm_unreachable("Unranked memrefs are not supported"); } } else { if (operand.getType().isa()) { UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, promotedOperands); continue; } if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor::unpack(builder, loc, llvmOperand, operand.getType().cast(), promotedOperands); continue; } } promotedOperands.push_back(llvmOperand); } return promotedOperands; } namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ConvertStandardToLLVMBase { LLVMLoweringPass() = default; LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, const llvm::DataLayout &dataLayout) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; this->indexBitwidth = indexBitwidth; this->useAlignedAlloc = useAlignedAlloc; this->dataLayout = dataLayout.getStringRepresentation(); } /// Run the dialect converter on the module. 
void runOnOperation() override { if (useBarePtrCallConv && emitCWrappers) { getOperation().emitError() << "incompatible conversion options: bare-pointer calling convention " "and C wrapper emission"; signalPassFailure(); return; } if (failed(LLVM::LLVMDialect::verifyDataLayoutString( this->dataLayout, [this](const Twine &message) { getOperation().emitError() << message.str(); }))) { signalPassFailure(); return; } ModuleOp m = getOperation(); LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, llvm::DataLayout(this->dataLayout)}; LLVMTypeConverter typeConverter(&getContext(), options); OwningRewritePatternList patterns; populateStdToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), StringAttr::get(this->dataLayout, m.getContext())); } }; } // end namespace mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); this->addIllegalOp(); } std::unique_ptr> mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { return std::make_unique( options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth, options.useAlignedAlloc, options.dataLayout); } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir index 8447474484e2..de59472f0713 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -1,491 +1,562 @@ // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s // RUN: mlir-opt -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC // CHECK-LABEL: func @check_strided_memref_arguments( // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)->(20 * i + j + 1)>>, %dynamic : memref(M * i + j + 1)>>, %mixed : memref<10x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>) { return } // CHECK-LABEL: func @check_arguments // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref, %mixed : memref<10x?xf32>) { return } // CHECK-LABEL: func @mixed_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> { func @mixed_alloc(%arg0: index, %arg1: index) -> memref { // CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: llvm.mul %[[M]], %[[c42]] : !llvm.i64 // CHECK-NEXT: %[[sz:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast 
%{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[st1:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[c42]] : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[c42]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st2]], %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> %0 = alloc(%arg0, %arg1) : memref // CHECK-NEXT: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> return %0 : memref } // CHECK-LABEL: func @mixed_dealloc func @mixed_dealloc(%arg0: memref) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm.ptr) -> () dealloc %arg0 : memref // CHECK-NEXT: llvm.return return } // CHECK-LABEL: func @dynamic_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { func @dynamic_alloc(%arg0: index, %arg1: index) -> memref { // CHECK: %[[sz:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[sz]], %[[sizeof]] : !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, 
array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %0 = alloc(%arg0, %arg1) : memref // CHECK-NEXT: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> return %0 : memref } // ----- // CHECK-LABEL: func @dynamic_alloca // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { func @dynamic_alloca(%arg0: index, %arg1: index) -> memref { // CHECK: %[[num_elems:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %0 = alloca(%arg0, %arg1) : memref // Test with explicitly specified alignment. llvm.alloca takes care of the // alignment. The same pointer is thus used for allocation and aligned // accesses. 
// -----

// CHECK-LABEL: func @dynamic_alloca
// CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
func @dynamic_alloca(%arg0: index, %arg1: index) -> memref<?x?xf32> {
// CHECK: %[[num_elems:.*]] = llvm.mul %[[M]], %[[N]] : !llvm.i64
// CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr<float>
// CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// CHECK-NEXT: %[[sz_bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.mul %{{.*}}, %[[N]] : !llvm.i64
// CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  %0 = alloca(%arg0, %arg1) : memref<?x?xf32>

// Test with explicitly specified alignment. llvm.alloca takes care of the
// alignment. The same pointer is thus used for allocation and aligned
// accesses.
// CHECK: %[[alloca_aligned:.*]] = llvm.alloca %{{.*}} x !llvm.float {alignment = 32 : i64} : (!llvm.i64) -> !llvm.ptr<float>
// CHECK: %[[desc:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[desc1:.*]] = llvm.insertvalue %[[alloca_aligned]], %[[desc]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: llvm.insertvalue %[[alloca_aligned]], %[[desc1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
  alloca(%arg0, %arg1) {alignment = 32} : memref<?x?xf32>
  return %0 : memref<?x?xf32>
}

// CHECK-LABEL: func @dynamic_dealloc
func @dynamic_dealloc(%arg0: memref<?x?xf32>) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr<float> to !llvm.ptr<i8>
// CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm.ptr<i8>) -> ()
  dealloc %arg0 : memref<?x?xf32>
  return
}

// CHECK-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
// ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> {
func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> {
// ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mul %0, %1 : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr<float>
// ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// ALIGNED-ALLOC-NEXT: %[[sizeof:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr<float> to !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.mul %[[num_elems]], %[[sizeof]] : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
// ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm.ptr<i8>
// ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr<i8> to !llvm.ptr<float>
  %0 = alloc() {alignment = 32} : memref<32x18xf32>
  // Do another alloc just to test that we have a unique declaration for
  // aligned_alloc.
  // ALIGNED-ALLOC: llvm.call @aligned_alloc
  %1 = alloc() {alignment = 64} : memref<4096xf32>

  // Alignment is to element type boundaries (minimum 16 bytes).
  // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]]
  %2 = alloc() : memref<4096xvector<8xf32>>
  // The minimum alignment is 16 bytes unless explicitly specified.
  // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64
  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c16]],
  %3 = alloc() : memref<4096xvector<2xf32>>
  // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c8]],
  %4 = alloc() {alignment = 8} : memref<1024xvector<4xf32>>
  // Bump the memref allocation size if its size is not a multiple of alignment.
  // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64
  // ALIGNED-ALLOC-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64
  // ALIGNED-ALLOC-NEXT: llvm.sub
  // ALIGNED-ALLOC-NEXT: llvm.add
  // ALIGNED-ALLOC-NEXT: llvm.urem
  // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.sub
  // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]], %[[SIZE_ALIGNED]])
  %5 = alloc() {alignment = 32} : memref<100xf32>
  // Bump alignment to the next power of two if it isn't.
  // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : index) : !llvm.i64
  // ALIGNED-ALLOC: llvm.call @aligned_alloc(%[[c128]]
  %6 = alloc(%N) : memref>
  return %0 : memref<32x18xf32>
}
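
// As a worked example of the size bump above (numbers are illustrative): %5
// allocates 100 x f32 = 400 bytes with a requested alignment of 32, so the
// lowering emits
//   aligned_size = (400 + 32 - 1) - ((400 + 32 - 1) mod 32) = 431 - 15 = 416
// i.e. the byte size passed to aligned_alloc is rounded up to the next
// multiple of the alignment, as the aligned_alloc contract requires.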
// CHECK-LABEL: func @mixed_load(
// CHECK-COUNT-2: !llvm.ptr<float>,
// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64
// CHECK: %[[I:.*]]: !llvm.i64,
// CHECK: %[[J:.*]]: !llvm.i64)
func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
  %0 = load %mixed[%i, %j] : memref<42x?xf32>
  return
}

// CHECK-LABEL: func @dynamic_load(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr<float>
  %0 = load %dynamic[%i, %j] : memref<?x?xf32>
  return
}

// CHECK-LABEL: func @prefetch
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @prefetch(%A : memref<?x?xf32>, %i : index, %j : index) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK-NEXT: [[C3:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
// CHECK-NEXT: [[C1_1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK-NEXT: "llvm.intr.prefetch"(%[[addr]], [[C1]], [[C3]], [[C1_1]]) : (!llvm.ptr<float>, !llvm.i32, !llvm.i32, !llvm.i32) -> ()
  prefetch %A[%i, %j], write, locality<3>, data : memref<?x?xf32>
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: [[C0_1:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: [[C1_2:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: "llvm.intr.prefetch"(%{{.*}}, [[C0]], [[C0_1]], [[C1_2]]) : (!llvm.ptr<float>, !llvm.i32, !llvm.i32, !llvm.i32) -> ()
  prefetch %A[%i, %j], read, locality<0>, data : memref<?x?xf32>
// CHECK: [[C0_2:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : !llvm.i32
// CHECK: [[C0_3:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
// CHECK: "llvm.intr.prefetch"(%{{.*}}, [[C0_2]], [[C2]], [[C0_3]]) : (!llvm.ptr<float>, !llvm.i32, !llvm.i32, !llvm.i32) -> ()
  prefetch %A[%i, %j], read, locality<2>, instr : memref<?x?xf32>
  return
}

// CHECK-LABEL: func @dynamic_store
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
  store %val, %dynamic[%i, %j] : memref<?x?xf32>
  return
}
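
// The loads and stores above all linearize indices the same way. As a small
// worked example (values are illustrative): for a ?x?xf32 memref with sizes
// [4, 5], strides [5, 1] and offset 0, element (i, j) = (2, 3) is addressed as
//   addr = aligned_ptr + (0 + 2 * 5 + 3 * 1) = aligned_ptr + 13 elements,
// i.e. one getelementptr on the aligned pointer followed by the load or store.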
// CHECK-LABEL: func @mixed_store
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr<float>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64
// CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64
func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) {
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64
// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64
// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr<float>
  store %val, %mixed[%i, %j] : memref<42x?xf32>
  return
}

// CHECK-LABEL: func @memref_cast_static_to_dynamic
func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %static : memref<10x42xf32> to memref<?x?xf32>
  return
}

// CHECK-LABEL: func @memref_cast_static_to_mixed
func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %static : memref<10x42xf32> to memref<?x42xf32>
  return
}

// CHECK-LABEL: func @memref_cast_dynamic_to_static
func @memref_cast_dynamic_to_static(%dynamic : memref<?x?xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %dynamic : memref<?x?xf32> to memref<10x12xf32>
  return
}

// CHECK-LABEL: func @memref_cast_dynamic_to_mixed
func @memref_cast_dynamic_to_mixed(%dynamic : memref<?x?xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %dynamic : memref<?x?xf32> to memref<?x42xf32>
  return
}

// CHECK-LABEL: func @memref_cast_mixed_to_dynamic
func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x?xf32>
  return
}

// CHECK-LABEL: func @memref_cast_mixed_to_static
func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32>
  return
}

// CHECK-LABEL: func @memref_cast_mixed_to_mixed
func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) {
// CHECK-NOT: llvm.bitcast
  %0 = memref_cast %mixed : memref<42x?xf32> to memref<?x1xf32>
  return
}

// CHECK-LABEL: func @memref_cast_ranked_to_unranked
func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) {
// CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> : (!llvm.i64) -> !llvm.ptr<struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>>
// CHECK-DAG: llvm.store %{{.*}}, %[[p]] : !llvm.ptr<struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>>
// CHECK-DAG: %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm.ptr<struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>> to !llvm.ptr<i8>
// CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64
// CHECK : llvm.mlir.undef : !llvm.struct<(i64, ptr)>
// CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm.struct<(i64, ptr)>
// CHECK-DAG: llvm.insertvalue %[[p2]], %{{.*}}[1] : !llvm.struct<(i64, ptr)>
  %0 = memref_cast %arg : memref<42x2x?xf32> to memref<*xf32>
  return
}
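
// Note on the lowering above: an unranked memref is represented as the pair
// !llvm.struct<(i64, ptr)>, holding the rank and a type-erased pointer to the
// ranked descriptor. The cast therefore only spills the ranked descriptor to a
// stack slot and records { rank = 3, ptr = <the alloca> }; no data is copied.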
// CHECK-LABEL: func @memref_cast_unranked_to_ranked
func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) {
// CHECK: %[[p:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)>
// CHECK-NEXT: llvm.bitcast %[[p]] : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>>
  %0 = memref_cast %arg : memref<*xf32> to memref<?x?x?x?xf32>
  return
}

// CHECK-LABEL: func @mixed_memref_dim
func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) {
// CHECK: llvm.mlir.constant(42 : index) : !llvm.i64
  %c0 = constant 0 : index
  %0 = dim %mixed, %c0 : memref<42x?x?x13x?xf32>
// CHECK: llvm.extractvalue %[[ld:.*]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
  %c1 = constant 1 : index
  %1 = dim %mixed, %c1 : memref<42x?x?x13x?xf32>
// CHECK: llvm.extractvalue %[[ld]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
  %c2 = constant 2 : index
  %2 = dim %mixed, %c2 : memref<42x?x?x13x?xf32>
// CHECK: llvm.mlir.constant(13 : index) : !llvm.i64
  %c3 = constant 3 : index
  %3 = dim %mixed, %c3 : memref<42x?x?x13x?xf32>
// CHECK: llvm.extractvalue %[[ld]][3, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)>
  %c4 = constant 4 : index
  %4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32>
  return
}

// CHECK-LABEL: @memref_dim_with_dyn_index
// CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm.ptr<float>, %[[ALIGN_PTR:.*]]: !llvm.ptr<float>, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64
func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index {
// CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm.struct<\(ptr, ptr, i64, array<2 x i64>, array<2 x i64>\)>]]
// CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]]
// CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]]
// CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESCR2]][2] : [[DESCR_TY]]
// CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]], %[[DESCR3]][3, 0] : [[DESCR_TY]]
// CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESCR4]][4, 0] : [[DESCR_TY]]
// CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESCR5]][3, 1] : [[DESCR_TY]]
// CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESCR6]][4, 1] : [[DESCR_TY]]
// CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %[[DESCR7]][3] : [[DESCR_TY]]
// CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.array<2 x i64> : (!llvm.i64) -> !llvm.ptr<array<2 x i64>>
// CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm.ptr<array<2 x i64>>
// CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm.ptr<array<2 x i64>>, !llvm.i64, !llvm.i64) -> !llvm.ptr<i64>
// CHECK-DAG: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm.ptr<i64>
// CHECK-DAG: llvm.return %[[RESULT]] : !llvm.i64
  %result = dim %arg, %idx : memref<3x?xf32>
  return %result : index
}
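
// When the dimension index is not a compile-time constant, the lowering above
// cannot simply extract a known field, so it spills the sizes array of the
// descriptor to a stack slot and indexes it dynamically. Conceptually
// (illustrative pseudo-code, not additional checks):
//   i64 sizes[2] = descriptor.sizes;  // llvm.alloca + llvm.store
//   return sizes[idx];                // llvm.getelementptr + llvm.load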
// CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape
func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) {
  %output = memref_reinterpret_cast %input to
           offset: [0], sizes: [6, 1], strides: [1, 1]
           : memref<2x3xf32> to memref<6x1xf32>
  return
}
// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]]
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]]
// CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]]
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]]
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
// CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
// CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64
// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
// CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]

// CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape
func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index,
                                                        %size_0 : index,
                                                        %size_1 : index,
                                                        %stride_0 : index,
                                                        %stride_1 : index,
                                                        %input : memref<*xf32>) {
  %output = memref_reinterpret_cast %input to
           offset: [%offset], sizes: [%size_0, %size_1],
           strides: [%stride_0, %stride_1]
           : memref<*xf32> to memref<?x?xf32>
  return
}
// CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64,
// CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64,
// CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64,
// CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr)>
// CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]]
// CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr)>
// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
// CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
-// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
-// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR]]{{\[}}[[C1]]]
-// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i32) -> !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]]
+// CHECK-SAME: : (!llvm.ptr<ptr<float>>, !llvm.i64) -> !llvm.ptr<ptr<float>>
// CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
// CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]]
// CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]]
// CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]]
// CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]]
// CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]]
// CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]]
// CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]]
+
+// CHECK-LABEL: @memref_reshape
+func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
+  %output = memref_reshape %input(%shape)
+                : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+  return
+}
+// CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[INPUT_TY:!.*]]
+// CHECK: [[SHAPE:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : [[SHAPE_TY:!.*]]
+// CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]]
+// CHECK: [[UNRANKED_OUT_O:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)>
+// CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_O]][0] : !llvm.struct<(i64, ptr)>
+
+// Compute size in bytes to allocate result ranked descriptor
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
+// CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : !llvm.i64
+// CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}}
+// CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x !llvm.i8
+// CHECK: llvm.insertvalue [[UNDERLYING_DESC]], [[UNRANKED_OUT_1]][1]
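+
+// As an illustrative worked example for the size computed above, with the
+// rank 2 shape used in this test the ranked descriptor occupies
+//   2 * sizeof(ptr) + (1 + 2 * rank) * sizeof(index) = 2*8 + 5*8 = 56 bytes
+// (allocated pointer, aligned pointer, offset, 2 sizes, 2 strides); this is
+// the value that ends up in [[DESC_ALLOC_SIZE]].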
+
+// Set allocated, aligned pointers and offset.
+// CHECK: [[ALLOC_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[INPUT_TY]]
+// CHECK: [[ALIGN_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[INPUT_TY]]
+// CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] : [[INPUT_TY]]
+// CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: llvm.store [[ALLOC_PTR]], [[BASE_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]]
+// CHECK: llvm.store [[ALIGN_PTR]], [[ALIGNED_PTR_PTR]] : !llvm.ptr<ptr<float>>
+// CHECK: [[BASE_PTR_PTR__:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr<i8> to !llvm.ptr<ptr<float>>
+// CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
+// CHECK: [[OFFSET_PTR_:%.*]] = llvm.getelementptr [[BASE_PTR_PTR__]]{{\[}}[[C2]]]
+// CHECK: [[OFFSET_PTR:%.*]] = llvm.bitcast [[OFFSET_PTR_]]
+// CHECK: llvm.store [[OFFSET]], [[OFFSET_PTR]] : !llvm.ptr<i64>
+
+// Iterate over shape operand in reverse order and set sizes and strides.
+// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
+// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr, ptr, i64, i64)>>
+// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
+// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
+// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
+// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
+// CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
+// CHECK: [[RANK_MIN_1:%.*]] = llvm.sub [[RANK]], [[C1_]] : !llvm.i64
+// CHECK: llvm.br ^bb1([[RANK_MIN_1]], [[C1_]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb1([[DIM:%.*]]: !llvm.i64, [[CUR_STRIDE:%.*]]: !llvm.i64):
+// CHECK: [[C0_:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
+// CHECK: [[COND:%.*]] = llvm.icmp "sge" [[DIM]], [[C0_]] : !llvm.i64
+// CHECK: llvm.cond_br [[COND]], ^bb2, ^bb3
+
+// CHECK: ^bb2:
+// CHECK: [[SIZE_PTR:%.*]] = llvm.getelementptr [[SHAPE_IN_PTR]]{{\[}}[[DIM]]]
+// CHECK: [[SIZE:%.*]] = llvm.load [[SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK: [[TARGET_SIZE_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[DIM]]]
+// CHECK: llvm.store [[SIZE]], [[TARGET_SIZE_PTR]] : !llvm.ptr<i64>
+// CHECK: [[TARGET_STRIDE_PTR:%.*]] = llvm.getelementptr [[STRIDES_PTR]]{{\[}}[[DIM]]]
+// CHECK: llvm.store [[CUR_STRIDE]], [[TARGET_STRIDE_PTR]] : !llvm.ptr<i64>
+// CHECK: [[UPDATE_STRIDE:%.*]] = llvm.mul [[CUR_STRIDE]], [[SIZE]] : !llvm.i64
+// CHECK: [[STRIDE_COND:%.*]] = llvm.sub [[DIM]], [[C1_]] : !llvm.i64
+// CHECK: llvm.br ^bb1([[STRIDE_COND]], [[UPDATE_STRIDE]] : !llvm.i64, !llvm.i64)
+
+// CHECK: ^bb3:
+// CHECK: llvm.return
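+
+// Worked example for the loop above (numbers are illustrative): with a shape
+// operand holding [3, 2], the loop visits dim = 1 and then dim = 0 while
+// keeping a running stride:
+//   dim = 1: size = 2, stride = 1
+//   dim = 0: size = 3, stride = 1 * 2 = 2
+// so the resulting descriptor has sizes = [3, 2] and strides = [2, 1], the
+// row-major layout that the runner test below expects.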
diff --git a/mlir/test/mlir-cpu-runner/memref_reshape.mlir b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
index 96a8ae16ae6d..5ec3b31a76b2 100644
--- a/mlir/test/mlir-cpu-runner/memref_reshape.mlir
+++ b/mlir/test/mlir-cpu-runner/memref_reshape.mlir
@@ -1,72 +1,105 @@
-// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm --print-ir-after-all \
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }

func @main() -> () {
  %c0 = constant 0 : index
  %c1 = constant 1 : index

  // Initialize input.
  %input = alloc() : memref<2x3xf32>
  %dim_x = dim %input, %c0 : memref<2x3xf32>
  %dim_y = dim %input, %c1 : memref<2x3xf32>
  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
    %prod = muli %i, %dim_y : index
    %val = addi %prod, %j : index
    %val_i64 = index_cast %val : index to i64
    %val_f32 = sitofp %val_i64 : i64 to f32
    store %val_f32, %input[%i, %j] : memref<2x3xf32>
  }
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
  // CHECK-NEXT: [0, 1, 2]
  // CHECK-NEXT: [3, 4, 5]

  // Initialize shape.
  %shape = alloc() : memref<2xindex>
  %c2 = constant 2 : index
  %c3 = constant 3 : index
  store %c3, %shape[%c0] : memref<2xindex>
  store %c2, %shape[%c1] : memref<2xindex>

  // Test cases.
  call @reshape_ranked_memref_to_ranked(%input, %shape)
    : (memref<2x3xf32>, memref<2xindex>) -> ()
  call @reshape_unranked_memref_to_ranked(%input, %shape)
    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_ranked_memref_to_unranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
+  call @reshape_unranked_memref_to_unranked(%input, %shape)
+    : (memref<2x3xf32>, memref<2xindex>) -> ()
  return
}

func @reshape_ranked_memref_to_ranked(%input : memref<2x3xf32>,
                                      %shape : memref<2xindex>) {
  %output = memref_reshape %input(%shape)
                : (memref<2x3xf32>, memref<2xindex>) -> memref<?x?xf32>

  %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
-  // CHECK: [0, 1],
-  // CHECK: [2, 3],
-  // CHECK: [4, 5]
+  // CHECK: [0, 1],
+  // CHECK: [2, 3],
+  // CHECK: [4, 5]
  return
}

func @reshape_unranked_memref_to_ranked(%input : memref<2x3xf32>,
                                        %shape : memref<2xindex>) {
  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
  %output = memref_reshape %input(%shape)
                : (memref<2x3xf32>, memref<2xindex>) -> memref<?x?xf32>

  %unranked_output = memref_cast %output : memref<?x?xf32> to memref<*xf32>
  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
-  // CHECK: [0, 1],
-  // CHECK: [2, 3],
-  // CHECK: [4, 5]
+  // CHECK: [0, 1],
+  // CHECK: [2, 3],
+  // CHECK: [4, 5]
+  return
+}
+
+func @reshape_ranked_memref_to_unranked(%input : memref<2x3xf32>,
+                                        %shape : memref<2xindex>) {
+  %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+  %output = memref_reshape %input(%dyn_size_shape)
+                : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+
+  call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0, 1],
+  // CHECK: [2, 3],
+  // CHECK: [4, 5]
+  return
+}
+
+func @reshape_unranked_memref_to_unranked(%input : memref<2x3xf32>,
+                                          %shape : memref<2xindex>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %dyn_size_shape = memref_cast %shape : memref<2xindex> to memref<?xindex>
+  %output = memref_reshape %input(%dyn_size_shape)
+                : (memref<2x3xf32>, memref<?xindex>) -> memref<*xf32>
+
+  call @print_memref_f32(%output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [3, 2] strides = [2, 1] data =
+  // CHECK: [0, 1],
+  // CHECK: [2, 3],
+  // CHECK: [4, 5]
  return
}