diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
index dbf38de56a0a..919a93ac84a2 100644
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -1,647 +1,647 @@
//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Provides a dialect conversion targeting the LLVM IR dialect. By default, it
// converts Standard ops and types and provides hooks for dialect-specific
// extensions to the conversion.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace llvm {
class IntegerType;
class LLVMContext;
class Module;
class Type;
} // namespace llvm

namespace mlir {

class BaseMemRefType;
class ComplexType;
class LLVMTypeConverter;
class UnrankedMemRefType;

namespace LLVM {
class LLVMDialect;
class LLVMType;
class LLVMPointerType;
} // namespace LLVM

/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult structFuncArgTypeConverter(LLVMTypeConverter &converter,
                                         Type type,
                                         SmallVectorImpl<Type> &result);

/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
                                          Type type,
                                          SmallVectorImpl<Type> &result);

/// Conversion from types in the Standard dialect to the LLVM IR dialect.
class LLVMTypeConverter : public TypeConverter {
  /// Give structFuncArgTypeConverter access to memref-specific functions.
  friend LogicalResult
  structFuncArgTypeConverter(LLVMTypeConverter &converter, Type type,
                             SmallVectorImpl<Type> &result);

public:
  using TypeConverter::convertType;

  /// Create an LLVMTypeConverter using the default LowerToLLVMOptions.
  LLVMTypeConverter(MLIRContext *ctx);

  /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
  LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);

  /// Convert a function type. The arguments and results are converted one by
  /// one and results are packed into a wrapped LLVM IR structure type.
  /// `result` is populated with argument mapping.
  LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
                                          SignatureConversion &result);

  /// Convert a non-empty list of types to be returned from a function into a
  /// supported LLVM IR type. In particular, if more than one value is
  /// returned, create an LLVM IR structure type with elements that correspond
  /// to each of the MLIR types converted with `convertType`.
  Type packFunctionResults(ArrayRef<Type> types);

  /// Convert a type in the context of the default or bare pointer calling
  /// convention.
  /// Calling convention sensitive types, such as MemRefType and
  /// UnrankedMemRefType, are converted following the specific rules for the
  /// calling convention. Calling convention independent types are converted
  /// following the default LLVM type conversions.
  Type convertCallingConventionType(Type type);

  /// Promote the bare pointers in 'values' that resulted from memrefs to
  /// descriptors. 'stdTypes' holds the types of 'values' before the conversion
  /// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
  void promoteBarePtrsToDescriptors(ConversionPatternRewriter &rewriter,
                                    Location loc, ArrayRef<Type> stdTypes,
                                    SmallVectorImpl<Value> &values);

  /// Returns the MLIR context.
  MLIRContext &getContext();

  /// Returns the LLVM dialect.
  LLVM::LLVMDialect *getDialect() { return llvmDialect; }

  const LowerToLLVMOptions &getOptions() const { return options; }

  /// Promote the LLVM representation of all operands including promoting
  /// MemRef descriptors to stack and use pointers to struct to avoid the
  /// complexity of the platform-specific C/C++ ABI lowering related to struct
  /// argument passing.
  SmallVector<Value, 4> promoteOperands(Location loc, ValueRange opOperands,
                                        ValueRange operands,
                                        OpBuilder &builder);

  /// Promote the LLVM struct representation of one MemRef descriptor to stack
  /// and use pointer to struct to avoid the complexity of the
  /// platform-specific C/C++ ABI lowering related to struct argument passing.
  Value promoteOneMemRefDescriptor(Location loc, Value operand,
                                   OpBuilder &builder);

  /// Converts the function type to a C-compatible format, in particular using
  /// pointers to memref descriptors for arguments.
  LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type);

  /// Returns the data layout to use during and after conversion.
  const llvm::DataLayout &getDataLayout() { return options.dataLayout; }

  /// Gets the LLVM representation of the index type. The returned type is an
  /// integer type with the size configured for this type converter.
  LLVM::LLVMType getIndexType();

  /// Gets the bitwidth of the index type when converted to LLVM.
  unsigned getIndexTypeBitwidth() { return options.indexBitwidth; }

  /// Gets the pointer bitwidth.
  unsigned getPointerBitwidth(unsigned addressSpace = 0);

protected:
  /// Pointer to the LLVM dialect.
  LLVM::LLVMDialect *llvmDialect;

private:
  /// Convert a function type. The arguments and results are converted one by
  /// one. Additionally, if the function returns more than one value, pack the
  /// results into an LLVM IR structure type so that the converted function
  /// type returns at most one result.
  Type convertFunctionType(FunctionType type);

  /// Convert the index type. Uses llvmModule data layout to create an integer
  /// of the pointer bitwidth.
  Type convertIndexType(IndexType type);

  /// Convert an integer type `i*` to `!llvm<"i*">`.
  Type convertIntegerType(IntegerType type);

  /// Convert a floating point type: `f16` to `!llvm.half`, `f32` to
  /// `!llvm.float` and `f64` to `!llvm.double`. `bf16` is converted to
  /// `!llvm.bfloat`.
  Type convertFloatType(FloatType type);

  /// Convert complex number type: `complex<f16>` to `!llvm<"{ half, half }">`,
  /// `complex<f32>` to `!llvm<"{ float, float }">`, and `complex<f64>` to
  /// `!llvm<"{ double, double }">`. `complex<bf16>` is not supported.
  Type convertComplexType(ComplexType type);

  /// Convert a memref type into an LLVM type that captures the relevant data.
  Type convertMemRefType(MemRefType type);

  /// Convert a memref type into a list of LLVM IR types that will form the
  /// memref descriptor.
  /// If `unpackAggregates` is true, the `sizes` and `strides`
  /// arrays in the descriptors are unpacked to individual index-typed
  /// elements, else they are kept as rank-sized arrays of index type. In
  /// particular, the list will contain:
  /// - two pointers to the memref element type, followed by
  /// - an index-typed offset, followed by
  /// - (if unpackAggregates = true)
  ///   - one index-typed size per dimension of the memref, followed by
  ///   - one index-typed stride per dimension of the memref.
  /// - (if unpackAggregates = false)
  ///   - one rank-sized array of index-type for the size of each dimension,
  ///   - one rank-sized array of index-type for the stride of each dimension.
  ///
  /// For example, memref<?x?xf32> is converted to the following list:
  /// - `!llvm<"float*">` (allocated pointer),
  /// - `!llvm<"float*">` (aligned pointer),
  /// - `!llvm.i64` (offset),
  /// - `!llvm.i64`, `!llvm.i64` (sizes),
  /// - `!llvm.i64`, `!llvm.i64` (strides).
  /// These types can be recomposed to a memref descriptor struct.
  SmallVector<LLVM::LLVMType, 5>
  getMemRefDescriptorFields(MemRefType type, bool unpackAggregates);

  /// Convert an unranked memref type into a list of non-aggregate LLVM IR
  /// types that will form the unranked memref descriptor. In particular, this
  /// list contains:
  /// - an integer rank, followed by
  /// - a pointer to the memref descriptor struct.
  /// For example, memref<*xf32> is converted to the following list:
  /// !llvm.i64 (rank)
  /// !llvm<"i8*"> (type-erased pointer).
  /// These types can be recomposed to an unranked memref descriptor struct.
  SmallVector<LLVM::LLVMType, 2> getUnrankedMemRefDescriptorFields();

  // Convert an unranked memref type to an LLVM type that captures the
  // runtime rank and a pointer to the static ranked memref descriptor.
  Type convertUnrankedMemRefType(UnrankedMemRefType type);

  /// Convert a memref type to a bare pointer to the memref element type.
  Type convertMemRefToBarePtr(BaseMemRefType type);

  // Convert a 1D vector type into an LLVM vector type.
  Type convertVectorType(VectorType type);

  /// Options for customizing the llvm lowering.
  LowerToLLVMOptions options;
};

/// Helper class to produce LLVM dialect operations extracting or inserting
/// values to a struct.
class StructBuilder {
public:
  /// Construct a helper for the given value.
  explicit StructBuilder(Value v);
  /// Builds IR creating an `undef` value of the descriptor type.
  static StructBuilder undef(OpBuilder &builder, Location loc,
                             Type descriptorType);

  /*implicit*/ operator Value() { return value; }

protected:
  // LLVM value
  Value value;
  // Cached struct type.
  Type structType;

protected:
  /// Builds IR to extract a value from the struct at position pos
  Value extractPtr(OpBuilder &builder, Location loc, unsigned pos);
  /// Builds IR to set a value in the struct at position pos
  void setPtr(OpBuilder &builder, Location loc, unsigned pos, Value ptr);
};

class ComplexStructBuilder : public StructBuilder {
public:
  /// Construct a helper for the given complex number value.
  using StructBuilder::StructBuilder;
  /// Build IR creating an `undef` value of the complex number type.
  static ComplexStructBuilder undef(OpBuilder &builder, Location loc,
                                    Type type);

  // Build IR extracting the real value from the complex number struct.
  Value real(OpBuilder &builder, Location loc);
  // Build IR inserting the real value into the complex number struct.
  void setReal(OpBuilder &builder, Location loc, Value real);

  // Build IR extracting the imaginary value from the complex number struct.
  Value imaginary(OpBuilder &builder, Location loc);
  // Build IR inserting the imaginary value into the complex number struct.
  void setImaginary(OpBuilder &builder, Location loc, Value imaginary);
};

/// Helper class to produce LLVM dialect operations extracting or inserting
/// elements of a MemRef descriptor. Wraps a Value pointing to the descriptor.
/// The Value may be null, in which case none of the operations are valid.
class MemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit MemRefDescriptor(Value descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static MemRefDescriptor undef(OpBuilder &builder, Location loc,
                                Type descriptorType);
  /// Builds IR creating a MemRef descriptor that represents `type` and
  /// populates it with static shape and stride information extracted from the
  /// type.
  static MemRefDescriptor fromStaticShape(OpBuilder &builder, Location loc,
                                          LLVMTypeConverter &typeConverter,
                                          MemRefType type, Value memory);

  /// Builds IR extracting the allocated pointer from the descriptor.
  Value allocatedPtr(OpBuilder &builder, Location loc);
  /// Builds IR inserting the allocated pointer into the descriptor.
  void setAllocatedPtr(OpBuilder &builder, Location loc, Value ptr);

  /// Builds IR extracting the aligned pointer from the descriptor.
  Value alignedPtr(OpBuilder &builder, Location loc);
  /// Builds IR inserting the aligned pointer into the descriptor.
  void setAlignedPtr(OpBuilder &builder, Location loc, Value ptr);

  /// Builds IR extracting the offset from the descriptor.
  Value offset(OpBuilder &builder, Location loc);
  /// Builds IR inserting the offset into the descriptor.
  void setOffset(OpBuilder &builder, Location loc, Value offset);
  void setConstantOffset(OpBuilder &builder, Location loc, uint64_t offset);

  /// Builds IR extracting the pos-th size from the descriptor.
  Value size(OpBuilder &builder, Location loc, unsigned pos);
  Value size(OpBuilder &builder, Location loc, Value pos, int64_t rank);
  /// Builds IR inserting the pos-th size into the descriptor
  void setSize(OpBuilder &builder, Location loc, unsigned pos, Value size);
  void setConstantSize(OpBuilder &builder, Location loc, unsigned pos,
                       uint64_t size);

  /// Builds IR extracting the pos-th stride from the descriptor.
  Value stride(OpBuilder &builder, Location loc, unsigned pos);
  /// Builds IR inserting the pos-th stride into the descriptor
  void setStride(OpBuilder &builder, Location loc, unsigned pos, Value stride);
  void setConstantStride(OpBuilder &builder, Location loc, unsigned pos,
                         uint64_t stride);

  /// Returns the (LLVM) pointer type this descriptor contains.
  LLVM::LLVMPointerType getElementPtrType();

  /// Builds IR populating a MemRef descriptor structure from a list of
  /// individual values composing that descriptor, in the following order:
  /// - allocated pointer;
  /// - aligned pointer;
  /// - offset;
  /// - <rank> sizes;
  /// - <rank> strides;
  /// where <rank> is the MemRef rank as provided in `type`.
  static Value pack(OpBuilder &builder, Location loc,
                    LLVMTypeConverter &converter, MemRefType type,
                    ValueRange values);

  /// Builds IR extracting individual elements of a MemRef descriptor
  /// structure and returning them as `results` list.
  static void unpack(OpBuilder &builder, Location loc, Value packed,
                     MemRefType type, SmallVectorImpl<Value> &results);

  /// Returns the number of non-aggregate values that would be produced by
  /// `unpack`.
  static unsigned getNumUnpackedValues(MemRefType type);

private:
  // Cached index type.
  Type indexType;
};

/// Helper class allowing the user to access a range of Values that correspond
/// to an unpacked memref descriptor using named accessors. This does not own
/// the values.
class MemRefDescriptorView {
public:
  /// Constructs the view from a range of values. Infers the rank from the
  /// size of the range.
  explicit MemRefDescriptorView(ValueRange range);

  /// Returns the allocated pointer Value.
  Value allocatedPtr();

  /// Returns the aligned pointer Value.
  Value alignedPtr();

  /// Returns the offset Value.
  Value offset();

  /// Returns the pos-th size Value.
  Value size(unsigned pos);

  /// Returns the pos-th stride Value.
  Value stride(unsigned pos);

private:
  /// Rank of the memref the descriptor is pointing to.
  int rank;
  /// Underlying range of Values.
  ValueRange elements;
};

class UnrankedMemRefDescriptor : public StructBuilder {
public:
  /// Construct a helper for the given descriptor value.
  explicit UnrankedMemRefDescriptor(Value descriptor);
  /// Builds IR creating an `undef` value of the descriptor type.
  static UnrankedMemRefDescriptor undef(OpBuilder &builder, Location loc,
                                        Type descriptorType);

  /// Builds IR extracting the rank from the descriptor
  Value rank(OpBuilder &builder, Location loc);
  /// Builds IR setting the rank in the descriptor
  void setRank(OpBuilder &builder, Location loc, Value value);
  /// Builds IR extracting ranked memref descriptor ptr
  Value memRefDescPtr(OpBuilder &builder, Location loc);
  /// Builds IR setting ranked memref descriptor ptr
  void setMemRefDescPtr(OpBuilder &builder, Location loc, Value value);

  /// Builds IR populating an unranked MemRef descriptor structure from a list
  /// of individual constituent values in the following order:
  /// - rank of the memref;
  /// - pointer to the memref descriptor.
  static Value pack(OpBuilder &builder, Location loc,
                    LLVMTypeConverter &converter, UnrankedMemRefType type,
                    ValueRange values);

  /// Builds IR extracting individual elements that compose an unranked memref
  /// descriptor and returns them as `results` list.
  static void unpack(OpBuilder &builder, Location loc, Value packed,
                     SmallVectorImpl<Value> &results);

  /// Returns the number of non-aggregate values that would be produced by
  /// `unpack`.
  static unsigned getNumUnpackedValues() { return 2; }

  /// Builds IR computing the sizes in bytes (suitable for opaque allocation)
  /// and appends the corresponding values into `sizes`.
  static void computeSizes(OpBuilder &builder, Location loc,
                           LLVMTypeConverter &typeConverter,
                           ArrayRef<UnrankedMemRefDescriptor> values,
                           SmallVectorImpl<Value> &sizes);

  /// TODO: The following accessors don't take alignment rules between
  /// elements of the descriptor struct into account. For some architectures,
  /// it might be necessary to extend them and to use `llvm::DataLayout`
  /// contained in `LLVMTypeConverter`.

  /// Builds IR extracting the allocated pointer from the descriptor.
  static Value allocatedPtr(OpBuilder &builder, Location loc,
                            Value memRefDescPtr,
                            LLVM::LLVMType elemPtrPtrType);
  /// Builds IR inserting the allocated pointer into the descriptor.
  static void setAllocatedPtr(OpBuilder &builder, Location loc,
                              Value memRefDescPtr,
                              LLVM::LLVMType elemPtrPtrType,
                              Value allocatedPtr);

  /// Builds IR extracting the aligned pointer from the descriptor.
  static Value alignedPtr(OpBuilder &builder, Location loc,
                          LLVMTypeConverter &typeConverter,
                          Value memRefDescPtr,
                          LLVM::LLVMType elemPtrPtrType);
  /// Builds IR inserting the aligned pointer into the descriptor.
  static void setAlignedPtr(OpBuilder &builder, Location loc,
                            LLVMTypeConverter &typeConverter,
                            Value memRefDescPtr,
                            LLVM::LLVMType elemPtrPtrType, Value alignedPtr);

  /// Builds IR extracting the offset from the descriptor.
  static Value offset(OpBuilder &builder, Location loc,
                      LLVMTypeConverter &typeConverter, Value memRefDescPtr,
                      LLVM::LLVMType elemPtrPtrType);
  /// Builds IR inserting the offset into the descriptor.
  static void setOffset(OpBuilder &builder, Location loc,
                        LLVMTypeConverter &typeConverter, Value memRefDescPtr,
                        LLVM::LLVMType elemPtrPtrType, Value offset);

  /// Builds IR extracting the pointer to the first element of the size array.
  static Value sizeBasePtr(OpBuilder &builder, Location loc,
                           LLVMTypeConverter &typeConverter,
                           Value memRefDescPtr,
                           LLVM::LLVMType elemPtrPtrType);
  /// Builds IR extracting the size[index] from the descriptor.
  static Value size(OpBuilder &builder, Location loc,
                    LLVMTypeConverter typeConverter, Value sizeBasePtr,
                    Value index);
  /// Builds IR inserting the size[index] into the descriptor.
  static void setSize(OpBuilder &builder, Location loc,
                      LLVMTypeConverter typeConverter, Value sizeBasePtr,
                      Value index, Value size);

  /// Builds IR extracting the pointer to the first element of the stride
  /// array.
  static Value strideBasePtr(OpBuilder &builder, Location loc,
                             LLVMTypeConverter &typeConverter,
                             Value sizeBasePtr, Value rank);
  /// Builds IR extracting the stride[index] from the descriptor.
  static Value stride(OpBuilder &builder, Location loc,
                      LLVMTypeConverter typeConverter, Value strideBasePtr,
                      Value index, Value stride);
  /// Builds IR inserting the stride[index] into the descriptor.
  static void setStride(OpBuilder &builder, Location loc,
                        LLVMTypeConverter typeConverter, Value strideBasePtr,
                        Value index, Value stride);
};

/// Base class for operation conversions targeting the LLVM IR dialect. It
/// provides the conversion patterns with access to the LLVMTypeConverter and
/// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the
/// LowerToLLVMOptions by reference meaning the references have to remain
/// alive during the entire pattern lifetime.
class ConvertToLLVMPattern : public ConversionPattern {
public:
  ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
                       LLVMTypeConverter &typeConverter,
                       PatternBenefit benefit = 1);

protected:
  /// Returns the LLVM dialect.
  LLVM::LLVMDialect &getDialect() const;

  /// Gets the MLIR type wrapping the LLVM integer type whose bit width is
  /// defined by the used type converter.
  LLVM::LLVMType getIndexType() const;

  /// Gets the MLIR type wrapping the LLVM integer type whose bit width
  /// corresponds to that of a LLVM pointer type.
  LLVM::LLVMType getIntPtrType(unsigned addressSpace = 0) const;

  /// Gets the MLIR type wrapping the LLVM void type.
  LLVM::LLVMType getVoidType() const;

  /// Get the MLIR type wrapping the LLVM i8* type.
  LLVM::LLVMType getVoidPtrType() const;

  /// Create an LLVM dialect operation defining the given index constant.
  Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
                            uint64_t value) const;

  // This is a strided getElementPtr variant that linearizes subscripts as:
  // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
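  // A sketch of the emitted IR (value names and the f32 element type are
  // illustrative, assuming a 64-bit index): a load from %arg[%i, %j] on a
  // memref<?x?xf32> with strides [%s0, 1] and zero offset linearizes as
  //   %tmp  = llvm.mul %i, %s0 : !llvm.i64
  //   %idx  = llvm.add %tmp, %j : !llvm.i64
  //   %addr = llvm.getelementptr %aligned[%idx]
  //               : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">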
-  Value getStridedElementPtr(Location loc, Type elementTypePtr,
-                             Value descriptor, ValueRange indices,
-                             ArrayRef<int64_t> strides, int64_t offset,
+  Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc,
+                             ValueRange indices,
                              ConversionPatternRewriter &rewriter) const;

-  /// Returns if the givem memref type is supported.
-  bool isSupportedMemRefType(MemRefType type) const;
-
+  // Forwards to getStridedElementPtr. TODO: remove.
   Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
                    ValueRange indices,
                    ConversionPatternRewriter &rewriter) const;

+  /// Returns if the given memref type is supported.
+  bool isSupportedMemRefType(MemRefType type) const;
+
  /// Returns the type of a pointer to an element of the memref.
  Type getElementPtrType(MemRefType type) const;

  /// Computes sizes, strides and buffer size in bytes of `memRefType` with
  /// identity layout. Emits constant ops for the static sizes of `memRefType`,
  /// and uses `dynamicSizes` for the others. Emits instructions to compute
  /// strides and buffer size from these sizes.
  ///
  /// For example, memref<4x?xf32> emits:
  /// `sizes[0]`   = llvm.mlir.constant(4 : index) : !llvm.i64
  /// `sizes[1]`   = `dynamicSizes[0]`
  /// `strides[1]` = llvm.mlir.constant(1 : index) : !llvm.i64
  /// `strides[0]` = `sizes[0]`
  /// %size        = llvm.mul `sizes[0]`, `sizes[1]` : !llvm.i64
  /// %nullptr     = llvm.mlir.null : !llvm.ptr<float>
  /// %gep         = llvm.getelementptr %nullptr[%size]
  ///                    : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
  /// `sizeBytes`  = llvm.ptrtoint %gep : !llvm.ptr<float> to !llvm.i64
  void getMemRefDescriptorSizes(Location loc, MemRefType memRefType,
                                ArrayRef<Value> dynamicSizes,
                                ConversionPatternRewriter &rewriter,
                                SmallVectorImpl<Value> &sizes,
                                SmallVectorImpl<Value> &strides,
                                Value &sizeBytes) const;

  /// Computes the size of type in bytes.
  Value getSizeInBytes(Location loc, Type type,
                       ConversionPatternRewriter &rewriter) const;

  /// Computes total number of elements for the given shape.
  Value getNumElements(Location loc, ArrayRef<Value> shape,
                       ConversionPatternRewriter &rewriter) const;

  /// Creates and populates a canonical memref descriptor struct.
  MemRefDescriptor
  createMemRefDescriptor(Location loc, MemRefType memRefType,
                         Value allocatedPtr, Value alignedPtr,
                         ArrayRef<Value> sizes, ArrayRef<Value> strides,
                         ConversionPatternRewriter &rewriter) const;

protected:
  /// Reference to the type converter, with potential extensions.
  LLVMTypeConverter &typeConverter;
};

/// Utility class for operation conversions targeting the LLVM dialect that
/// match exactly one source operation.
template <typename OpTy>
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
public:
  ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                         PatternBenefit benefit = 1)
      : ConvertToLLVMPattern(OpTy::getOperationName(),
                             &typeConverter.getContext(), typeConverter,
                             benefit) {}
};

namespace LLVM {
namespace detail {
/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
                              ValueRange operands,
                              LLVMTypeConverter &typeConverter,
                              ConversionPatternRewriter &rewriter);

LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
                                    ValueRange operands,
                                    LLVMTypeConverter &typeConverter,
                                    ConversionPatternRewriter &rewriter);
} // namespace detail
} // namespace LLVM

/// Generic implementation of one-to-one conversion from "SourceOp" to
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
/// Upholds a convention that multi-result operations get converted into an
/// operation returning the LLVM IR structure type, in which case individual
/// values must be extracted from it using LLVM::ExtractValueOp before being
/// used.
template <typename SourceOp, typename TargetOp>
class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  using Super = OneToOneConvertToLLVMPattern<SourceOp, TargetOp>;

  /// Converts the type of the result to an LLVM type, pass operands as is,
  /// preserve attributes.
  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(),
                                         operands, this->typeConverter,
                                         rewriter);
  }
};

/// Basic lowering implementation to rewrite Ops with just one result to the
/// LLVM Dialect. This supports higher-dimensional vector types.
template <typename SourceOp, typename TargetOp>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
  using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
  using Super = VectorConvertToLLVMPattern<SourceOp, TargetOp>;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    static_assert(
        std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
        "expected single result op");
    static_assert(
        std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                        SourceOp>::value,
        "expected same operands and result type");
    return LLVM::detail::vectorOneToOneRewrite(
        op, TargetOp::getOperationName(), operands, this->typeConverter,
        rewriter);
  }
};

/// Derived class that automatically populates legalization information for
/// different LLVM ops.
class LLVMConversionTarget : public ConversionTarget {
public:
  explicit LLVMConversionTarget(MLIRContext &ctx);
};

} // namespace mlir

#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 870b91034361..17187e933cfa 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -1,4177 +1,4187 @@
//===- StandardToLLVM.cpp - Standard to LLVM dialect conversion -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert MLIR standard and builtin dialects
// into the LLVM IR dialect.
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/Utils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"

#include <functional>

using namespace mlir;

#define PASS_NAME "convert-std-to-llvm"

// Extract an LLVM IR type from the LLVM IR dialect type.
static LLVM::LLVMType unwrap(Type type) {
  if (!type)
    return nullptr;
  auto *mlirContext = type.getContext();
  auto wrappedLLVMType = type.dyn_cast<LLVM::LLVMType>();
  if (!wrappedLLVMType)
    emitError(UnknownLoc::get(mlirContext),
              "conversion resulted in a non-LLVM type");
  return wrappedLLVMType;
}

/// Callback to convert function argument types. It converts a MemRef function
/// argument to a list of non-aggregate types containing descriptor
/// information, and an UnrankedMemRef function argument to a list containing
/// the rank and a pointer to a descriptor struct.
LogicalResult mlir::structFuncArgTypeConverter(LLVMTypeConverter &converter,
                                               Type type,
                                               SmallVectorImpl<Type> &result) {
  if (auto memref = type.dyn_cast<MemRefType>()) {
    // In signatures, Memref descriptors are expanded into lists of
    // non-aggregate values.
    auto converted =
        converter.getMemRefDescriptorFields(memref, /*unpackAggregates=*/true);
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  if (type.isa<UnrankedMemRefType>()) {
    auto converted = converter.getUnrankedMemRefDescriptorFields();
    if (converted.empty())
      return failure();
    result.append(converted.begin(), converted.end());
    return success();
  }
  auto converted = converter.convertType(type);
  if (!converted)
    return failure();
  result.push_back(converted);
  return success();
}

/// Callback to convert function argument types. It converts MemRef function
/// arguments to bare pointers to the MemRef element type.
LogicalResult mlir::barePtrFuncArgTypeConverter(LLVMTypeConverter &converter,
                                                Type type,
                                                SmallVectorImpl<Type> &result) {
  auto llvmTy = converter.convertCallingConventionType(type);
  if (!llvmTy)
    return failure();

  result.push_back(llvmTy);
  return success();
}

/// Create an LLVMTypeConverter using default LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
    : LLVMTypeConverter(ctx, LowerToLLVMOptions::getDefaultOptions()) {}

/// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                     const LowerToLLVMOptions &options)
    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
      options(options) {
  assert(llvmDialect && "LLVM IR dialect is not registered");
  if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
    this->options.indexBitwidth = options.dataLayout.getPointerSizeInBits();

  // Register conversions for the standard types.
  addConversion([&](ComplexType type) { return convertComplexType(type); });
  addConversion([&](FloatType type) { return convertFloatType(type); });
  addConversion([&](FunctionType type) { return convertFunctionType(type); });
  addConversion([&](IndexType type) { return convertIndexType(type); });
  addConversion([&](IntegerType type) { return convertIntegerType(type); });
  addConversion([&](MemRefType type) { return convertMemRefType(type); });
  addConversion(
      [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); });
  addConversion([&](VectorType type) { return convertVectorType(type); });

  // LLVMType is legal, so add a pass-through conversion.
  addConversion([](LLVM::LLVMType type) { return type; });

  // Materialization for memrefs creates descriptor structs from individual
  // values constituting them, when descriptors are used, i.e. more than one
  // value represents a memref.
  addArgumentMaterialization(
      [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
          Location loc) -> Optional<Value> {
        if (inputs.size() == 1)
          return llvm::None;
        return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
                                              inputs);
      });
  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                 ValueRange inputs,
                                 Location loc) -> Optional<Value> {
    if (inputs.size() == 1)
      return llvm::None;
    return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
  });
  // Add generic source and target materializations to handle cases where
  // non-LLVM types persist after an LLVM conversion.
  addSourceMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs,
                               Location loc) -> Optional<Value> {
    if (inputs.size() != 1)
      return llvm::None;
    // FIXME: These should check LLVM::DialectCastOp can actually be
    // constructed from the input and result.
    return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
        .getResult();
  });
  addTargetMaterialization([&](OpBuilder &builder, Type resultType,
                               ValueRange inputs,
                               Location loc) -> Optional<Value> {
    if (inputs.size() != 1)
      return llvm::None;
    // FIXME: These should check LLVM::DialectCastOp can actually be
    // constructed from the input and result.
    return builder.create<LLVM::DialectCastOp>(loc, resultType, inputs[0])
        .getResult();
  });
}

/// Returns the MLIR context.
MLIRContext &LLVMTypeConverter::getContext() {
  return *getDialect()->getContext();
}

LLVM::LLVMType LLVMTypeConverter::getIndexType() {
  return LLVM::LLVMType::getIntNTy(&getContext(), getIndexTypeBitwidth());
}

unsigned LLVMTypeConverter::getPointerBitwidth(unsigned addressSpace) {
  return options.dataLayout.getPointerSizeInBits(addressSpace);
}

Type LLVMTypeConverter::convertIndexType(IndexType type) {
  return getIndexType();
}

Type LLVMTypeConverter::convertIntegerType(IntegerType type) {
  return LLVM::LLVMType::getIntNTy(&getContext(), type.getWidth());
}

Type LLVMTypeConverter::convertFloatType(FloatType type) {
  if (type.isa<Float32Type>())
    return LLVM::LLVMType::getFloatTy(&getContext());
  if (type.isa<Float64Type>())
    return LLVM::LLVMType::getDoubleTy(&getContext());
  if (type.isa<Float16Type>())
    return LLVM::LLVMType::getHalfTy(&getContext());
  if (type.isa<BFloat16Type>())
    return LLVM::LLVMType::getBFloatTy(&getContext());
  llvm_unreachable("non-float type in convertFloatType");
}

// Convert a `ComplexType` to an LLVM type. The result is a complex number
// struct with entries for the
//   1. real part and for the
//   2. imaginary part.
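// For example (an illustrative sketch, not taken from a test), a
// `complex<f32>` value lowers to `!llvm<"{ float, float }">`, and its parts
// are accessed with the position constants defined below:
//   %re = llvm.extractvalue %c[0] : !llvm<"{ float, float }">
//   %im = llvm.extractvalue %c[1] : !llvm<"{ float, float }">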
static constexpr unsigned kRealPosInComplexNumberStruct = 0;
static constexpr unsigned kImaginaryPosInComplexNumberStruct = 1;
Type LLVMTypeConverter::convertComplexType(ComplexType type) {
  auto elementType = convertType(type.getElementType()).cast<LLVM::LLVMType>();
  return LLVM::LLVMType::getStructTy(&getContext(),
                                     {elementType, elementType});
}

// Except for signatures, MLIR function types are converted into LLVM
// pointer-to-function types.
Type LLVMTypeConverter::convertFunctionType(FunctionType type) {
  SignatureConversion conversion(type.getNumInputs());
  LLVM::LLVMType converted =
      convertFunctionSignature(type, /*isVariadic=*/false, conversion);
  return converted.getPointerTo();
}

// Function types are converted to LLVM Function types by recursively
// converting argument and result types. If MLIR Function has zero results,
// the LLVM Function has one VoidType result. If MLIR Function has more than
// one result, they are packed into an LLVM StructType in their order of
// appearance.
LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature(
    FunctionType funcTy, bool isVariadic,
    LLVMTypeConverter::SignatureConversion &result) {
  // Select the argument converter depending on the calling convention.
  auto funcArgConverter = options.useBarePtrCallConv
                              ? barePtrFuncArgTypeConverter
                              : structFuncArgTypeConverter;
  // Convert argument types one by one and check for errors.
  for (auto &en : llvm::enumerate(funcTy.getInputs())) {
    Type type = en.value();
    SmallVector<Type, 8> converted;
    if (failed(funcArgConverter(*this, type, converted)))
      return {};
    result.addInputs(en.index(), converted);
  }

  SmallVector<LLVM::LLVMType, 8> argTypes;
  argTypes.reserve(llvm::size(result.getConvertedTypes()));
  for (Type type : result.getConvertedTypes())
    argTypes.push_back(unwrap(type));

  // If the function does not return anything, create the void result type;
  // if it returns one element, convert it; otherwise pack the result types
  // into a struct.
  LLVM::LLVMType resultType =
      funcTy.getNumResults() == 0
          ? LLVM::LLVMType::getVoidTy(&getContext())
          : unwrap(packFunctionResults(funcTy.getResults()));
  if (!resultType)
    return {};
  return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}

/// Converts the function type to a C-compatible format, in particular using
/// pointers to memref descriptors for arguments.
LLVM::LLVMType
LLVMTypeConverter::convertFunctionTypeCWrapper(FunctionType type) {
  SmallVector<LLVM::LLVMType, 4> inputs;

  for (Type t : type.getInputs()) {
    auto converted = convertType(t).dyn_cast_or_null<LLVM::LLVMType>();
    if (!converted)
      return {};
    if (t.isa<MemRefType, UnrankedMemRefType>())
      converted = converted.getPointerTo();
    inputs.push_back(converted);
  }

  LLVM::LLVMType resultType =
      type.getNumResults() == 0
          ? LLVM::LLVMType::getVoidTy(&getContext())
          : unwrap(packFunctionResults(type.getResults()));
  if (!resultType)
    return {};

  return LLVM::LLVMType::getFunctionTy(resultType, inputs, false);
}

static constexpr unsigned kAllocatedPtrPosInMemRefDescriptor = 0;
static constexpr unsigned kAlignedPtrPosInMemRefDescriptor = 1;
static constexpr unsigned kOffsetPosInMemRefDescriptor = 2;
static constexpr unsigned kSizePosInMemRefDescriptor = 3;
static constexpr unsigned kStridePosInMemRefDescriptor = 4;
/// Convert a memref type into a list of LLVM IR types that will form the
/// memref descriptor. The result contains the following types:
///  1. The pointer to the allocated data buffer, followed by
///  2. The pointer to the aligned data buffer, followed by
///  3. A lowered `index`-type integer containing the distance between the
///  beginning of the buffer and the first element to be accessed through the
///  view, followed by
///  4. An array containing as many `index`-type integers as the rank of the
///  MemRef: the array represents the size, in number of elements, of the
///  memref along the given dimension. For constant MemRef dimensions, the
///  corresponding size entry is a constant whose runtime value must match the
///  static value, followed by
///  5. A second array containing as many `index`-type integers as the rank of
///  the MemRef: the second array represents the "stride" (in tensor
///  abstraction sense), i.e. the number of consecutive elements of the
///  underlying buffer.
///  TODO: add assertions for the static cases.
///
///  If `unpackAggregates` is set to true, the arrays described in (4) and (5)
///  are expanded into individual index-type elements.
///
///  template <typename Elem, typename Index, size_t Rank>
///  struct {
///    Elem *allocatedPtr;
///    Elem *alignedPtr;
///    Index offset;
///    Index sizes[Rank];   // omitted when rank == 0
///    Index strides[Rank]; // omitted when rank == 0
///  };
SmallVector<LLVM::LLVMType, 5>
LLVMTypeConverter::getMemRefDescriptorFields(MemRefType type,
                                             bool unpackAggregates) {
  assert(isStrided(type) &&
         "Non-strided layout maps must have been normalized away");

  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
  auto indexTy = getIndexType();

  SmallVector<LLVM::LLVMType, 5> results = {ptrTy, ptrTy, indexTy};
  auto rank = type.getRank();
  if (rank == 0)
    return results;

  if (unpackAggregates)
    results.insert(results.end(), 2 * rank, indexTy);
  else
    results.insert(results.end(), 2,
                   LLVM::LLVMType::getArrayTy(indexTy, rank));
  return results;
}

/// Converts MemRefType to LLVMType. A MemRefType is converted to a struct
/// that packs the descriptor fields as defined by
/// `getMemRefDescriptorFields`.
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
  // When converting a MemRefType to a struct with descriptor fields, do not
  // unpack the `sizes` and `strides` arrays.
  SmallVector<LLVM::LLVMType, 5> types =
      getMemRefDescriptorFields(type, /*unpackAggregates=*/false);
  return LLVM::LLVMType::getStructTy(&getContext(), types);
}

static constexpr unsigned kRankInUnrankedMemRefDescriptor = 0;
static constexpr unsigned kPtrInUnrankedMemRefDescriptor = 1;

/// Convert an unranked memref type into a list of non-aggregate LLVM IR types
/// that will form the unranked memref descriptor. In particular, the fields
/// for an unranked memref descriptor are:
/// 1. index-typed rank, the dynamic rank of this MemRef
/// 2. void* ptr, pointer to the static ranked MemRef descriptor. This will be
///    a stack allocated (alloca) copy of a MemRef descriptor that got casted
///    to be unranked.
SmallVector<LLVM::LLVMType, 2>
LLVMTypeConverter::getUnrankedMemRefDescriptorFields() {
  return {getIndexType(), LLVM::LLVMType::getInt8PtrTy(&getContext())};
}

Type LLVMTypeConverter::convertUnrankedMemRefType(UnrankedMemRefType type) {
  return LLVM::LLVMType::getStructTy(&getContext(),
                                     getUnrankedMemRefDescriptorFields());
}

/// Convert a memref type to a bare pointer to the memref element type.
Type LLVMTypeConverter::convertMemRefToBarePtr(BaseMemRefType type) {
  if (type.isa<UnrankedMemRefType>())
    // Unranked memref is not supported in the bare pointer calling
    // convention.
    return {};

  // Check that the memref has static shape, strides and offset. Otherwise, it
  // cannot be lowered to a bare pointer.
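  // For example (illustrative): memref<4x8xf32> has a static shape, a static
  // zero offset and static strides [8, 1], so it lowers to the bare pointer
  // !llvm<"float*">; memref<?xf32> has a dynamic size and is rejected by the
  // checks below.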
  auto memrefTy = type.cast<MemRefType>();
  if (!memrefTy.hasStaticShape())
    return {};

  int64_t offset = 0;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memrefTy, strides, offset)))
    return {};

  for (int64_t stride : strides)
    if (ShapedType::isDynamicStrideOrOffset(stride))
      return {};

  if (ShapedType::isDynamicStrideOrOffset(offset))
    return {};

  LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  return elementType.getPointerTo(type.getMemorySpace());
}

// Convert an n-D vector type to an LLVM vector type via (n-1)-D array type
// when n > 1.
// For example, `vector<4 x f32>` converts to `!llvm.type<"<4 x float>">` and
// `vector<4 x 8 x 16 x f32>` converts to `!llvm<"[4 x [8 x <16 x float>]]">`.
Type LLVMTypeConverter::convertVectorType(VectorType type) {
  auto elementType = unwrap(convertType(type.getElementType()));
  if (!elementType)
    return {};
  auto vectorType =
      LLVM::LLVMType::getVectorTy(elementType, type.getShape().back());
  auto shape = type.getShape();
  for (int i = shape.size() - 2; i >= 0; --i)
    vectorType = LLVM::LLVMType::getArrayTy(vectorType, shape[i]);
  return vectorType;
}

/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
Type LLVMTypeConverter::convertCallingConventionType(Type type) {
  if (options.useBarePtrCallConv)
    if (auto memrefTy = type.dyn_cast<BaseMemRefType>())
      return convertMemRefToBarePtr(memrefTy);

  return convertType(type);
}

/// Promote the bare pointers in 'values' that resulted from memrefs to
/// descriptors. 'stdTypes' holds the types of 'values' before the conversion
/// to the LLVM-IR dialect (i.e., MemRefType, or any other Standard type).
void LLVMTypeConverter::promoteBarePtrsToDescriptors(
    ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
    SmallVectorImpl<Value> &values) {
  assert(stdTypes.size() == values.size() &&
         "The number of types and values doesn't match");
  for (unsigned i = 0, end = values.size(); i < end; ++i)
    if (auto memrefTy = stdTypes[i].dyn_cast<MemRefType>())
      values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
                                                    memrefTy, values[i]);
}

ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName,
                                           MLIRContext *context,
                                           LLVMTypeConverter &typeConverter,
                                           PatternBenefit benefit)
    : ConversionPattern(rootOpName, benefit, typeConverter, context),
      typeConverter(typeConverter) {}

//===----------------------------------------------------------------------===//
// StructBuilder implementation
//===----------------------------------------------------------------------===//

StructBuilder::StructBuilder(Value v) : value(v) {
  assert(value != nullptr && "value cannot be null");
  structType = value.getType().dyn_cast<LLVM::LLVMType>();
  assert(structType && "expected llvm type");
}

Value StructBuilder::extractPtr(OpBuilder &builder, Location loc,
                                unsigned pos) {
  Type type = structType.cast<LLVM::LLVMType>().getStructElementType(pos);
  return builder.create<LLVM::ExtractValueOp>(loc, type, value,
                                              builder.getI64ArrayAttr(pos));
}

void StructBuilder::setPtr(OpBuilder &builder, Location loc, unsigned pos,
                           Value ptr) {
  value = builder.create<LLVM::InsertValueOp>(loc, structType, value, ptr,
                                              builder.getI64ArrayAttr(pos));
}

//===----------------------------------------------------------------------===//
// ComplexStructBuilder implementation
//===----------------------------------------------------------------------===//

ComplexStructBuilder ComplexStructBuilder::undef(OpBuilder &builder,
                                                 Location loc, Type type) {
  Value val = builder.create<LLVM::UndefOp>(loc, type.cast<LLVM::LLVMType>());
  return ComplexStructBuilder(val);
}

void ComplexStructBuilder::setReal(OpBuilder &builder, Location loc,
                                   Value real) {
  setPtr(builder, loc, kRealPosInComplexNumberStruct, real);
}

Value ComplexStructBuilder::real(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kRealPosInComplexNumberStruct);
}

void ComplexStructBuilder::setImaginary(OpBuilder &builder, Location loc,
                                        Value imaginary) {
  setPtr(builder, loc, kImaginaryPosInComplexNumberStruct, imaginary);
}

Value ComplexStructBuilder::imaginary(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kImaginaryPosInComplexNumberStruct);
}

//===----------------------------------------------------------------------===//
// MemRefDescriptor implementation
//===----------------------------------------------------------------------===//

/// Construct a helper for the given descriptor value.
MemRefDescriptor::MemRefDescriptor(Value descriptor)
    : StructBuilder(descriptor) {
  assert(value != nullptr && "value cannot be null");
  indexType = value.getType().cast<LLVM::LLVMType>().getStructElementType(
      kOffsetPosInMemRefDescriptor);
}

/// Builds IR creating an `undef` value of the descriptor type.
MemRefDescriptor MemRefDescriptor::undef(OpBuilder &builder, Location loc,
                                         Type descriptorType) {
  Value descriptor = builder.create<LLVM::UndefOp>(
      loc, descriptorType.cast<LLVM::LLVMType>());
  return MemRefDescriptor(descriptor);
}

/// Builds IR creating a MemRef descriptor that represents `type` and
/// populates it with static shape and stride information extracted from the
/// type.
MemRefDescriptor
MemRefDescriptor::fromStaticShape(OpBuilder &builder, Location loc,
                                  LLVMTypeConverter &typeConverter,
                                  MemRefType type, Value memory) {
  assert(type.hasStaticShape() && "unexpected dynamic shape");

  // Extract all strides and offsets and verify they are static.
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  auto result = getStridesAndOffset(type, strides, offset);
  (void)result;
  assert(succeeded(result) && "unexpected failure in stride computation");
  assert(offset != MemRefType::getDynamicStrideOrOffset() &&
         "expected static offset");
  assert(!llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) &&
         "expected static strides");

  auto convertedType = typeConverter.convertType(type);
  assert(convertedType && "unexpected failure in memref type conversion");

  auto descr = MemRefDescriptor::undef(builder, loc, convertedType);
  descr.setAllocatedPtr(builder, loc, memory);
  descr.setAlignedPtr(builder, loc, memory);
  descr.setConstantOffset(builder, loc, offset);

  // Fill in sizes and strides
  for (unsigned i = 0, e = type.getRank(); i != e; ++i) {
    descr.setConstantSize(builder, loc, i, type.getDimSize(i));
    descr.setConstantStride(builder, loc, i, strides[i]);
  }
  return descr;
}

/// Builds IR extracting the allocated pointer from the descriptor.
Value MemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor);
}

/// Builds IR inserting the allocated pointer into the descriptor.
void MemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc,
                                       Value ptr) {
  setPtr(builder, loc, kAllocatedPtrPosInMemRefDescriptor, ptr);
}

/// Builds IR extracting the aligned pointer from the descriptor.
Value MemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor);
}

/// Builds IR inserting the aligned pointer into the descriptor.
void MemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc,
                                     Value ptr) {
  setPtr(builder, loc, kAlignedPtrPosInMemRefDescriptor, ptr);
}

// Creates a constant Op producing a value of `resultType` from an index-typed
// integer attribute.
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
                                     Type resultType, int64_t value) {
  return builder.create<LLVM::ConstantOp>(
      loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value));
}

/// Builds IR extracting the offset from the descriptor.
Value MemRefDescriptor::offset(OpBuilder &builder, Location loc) {
  return builder.create<LLVM::ExtractValueOp>(
      loc, indexType, value,
      builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}

/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setOffset(OpBuilder &builder, Location loc,
                                 Value offset) {
  value = builder.create<LLVM::InsertValueOp>(
      loc, structType, value, offset,
      builder.getI64ArrayAttr(kOffsetPosInMemRefDescriptor));
}

/// Builds IR inserting the offset into the descriptor.
void MemRefDescriptor::setConstantOffset(OpBuilder &builder, Location loc,
                                         uint64_t offset) {
  setOffset(builder, loc,
            createIndexAttrConstant(builder, loc, indexType, offset));
}

/// Builds IR extracting the pos-th size from the descriptor.
Value MemRefDescriptor::size(OpBuilder &builder, Location loc, unsigned pos) {
  return builder.create<LLVM::ExtractValueOp>(
      loc, indexType, value,
      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
                             int64_t rank) {
  auto indexTy = indexType.cast<LLVM::LLVMType>();
  auto indexPtrTy = indexTy.getPointerTo();
  auto arrayTy = LLVM::LLVMType::getArrayTy(indexTy, rank);
  auto arrayPtrTy = arrayTy.getPointerTo();

  // Copy size values to stack-allocated memory.
  auto zero = createIndexAttrConstant(builder, loc, indexType, 0);
  auto one = createIndexAttrConstant(builder, loc, indexType, 1);
  auto sizes = builder.create<LLVM::ExtractValueOp>(
      loc, arrayTy, value,
      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor}));
  auto sizesPtr =
      builder.create<LLVM::AllocaOp>(loc, arrayPtrTy, one, /*alignment=*/0);
  builder.create<LLVM::StoreOp>(loc, sizes, sizesPtr);

  // Load and return the size value of interest.
  auto resultPtr = builder.create<LLVM::GEPOp>(loc, indexPtrTy, sizesPtr,
                                               ValueRange({zero, pos}));
  return builder.create<LLVM::LoadOp>(loc, resultPtr);
}

/// Builds IR inserting the pos-th size into the descriptor
void MemRefDescriptor::setSize(OpBuilder &builder, Location loc, unsigned pos,
                               Value size) {
  value = builder.create<LLVM::InsertValueOp>(
      loc, structType, value, size,
      builder.getI64ArrayAttr({kSizePosInMemRefDescriptor, pos}));
}

void MemRefDescriptor::setConstantSize(OpBuilder &builder, Location loc,
                                       unsigned pos, uint64_t size) {
  setSize(builder, loc, pos,
          createIndexAttrConstant(builder, loc, indexType, size));
}

/// Builds IR extracting the pos-th stride from the descriptor.
Value MemRefDescriptor::stride(OpBuilder &builder, Location loc,
                               unsigned pos) {
  return builder.create<LLVM::ExtractValueOp>(
      loc, indexType, value,
      builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

/// Builds IR inserting the pos-th stride into the descriptor
void MemRefDescriptor::setStride(OpBuilder &builder, Location loc,
                                 unsigned pos, Value stride) {
  value = builder.create<LLVM::InsertValueOp>(
      loc, structType, value, stride,
      builder.getI64ArrayAttr({kStridePosInMemRefDescriptor, pos}));
}

void MemRefDescriptor::setConstantStride(OpBuilder &builder, Location loc,
                                         unsigned pos, uint64_t stride) {
  setStride(builder, loc, pos,
            createIndexAttrConstant(builder, loc, indexType, stride));
}

LLVM::LLVMPointerType MemRefDescriptor::getElementPtrType() {
  return value.getType()
      .cast<LLVM::LLVMType>()
      .getStructElementType(kAlignedPtrPosInMemRefDescriptor)
      .cast<LLVM::LLVMPointerType>();
}

/// Creates a MemRef descriptor structure from a list of individual values
/// composing that descriptor, in the following order:
/// - allocated pointer;
/// - aligned pointer;
/// - offset;
/// - <rank> sizes;
/// - <rank> strides;
/// where <rank> is the MemRef rank as provided in `type`.
Value MemRefDescriptor::pack(OpBuilder &builder, Location loc,
                             LLVMTypeConverter &converter, MemRefType type,
                             ValueRange values) {
  Type llvmType = converter.convertType(type);
  auto d = MemRefDescriptor::undef(builder, loc, llvmType);

  d.setAllocatedPtr(builder, loc, values[kAllocatedPtrPosInMemRefDescriptor]);
  d.setAlignedPtr(builder, loc, values[kAlignedPtrPosInMemRefDescriptor]);
  d.setOffset(builder, loc, values[kOffsetPosInMemRefDescriptor]);

  int64_t rank = type.getRank();
  for (unsigned i = 0; i < rank; ++i) {
    d.setSize(builder, loc, i, values[kSizePosInMemRefDescriptor + i]);
    d.setStride(builder, loc, i,
                values[kSizePosInMemRefDescriptor + rank + i]);
  }

  return d;
}

/// Builds IR extracting individual elements of a MemRef descriptor structure
/// and returning them as `results` list.
void MemRefDescriptor::unpack(OpBuilder &builder, Location loc, Value packed,
                              MemRefType type,
                              SmallVectorImpl<Value> &results) {
  int64_t rank = type.getRank();
  results.reserve(results.size() + getNumUnpackedValues(type));

  MemRefDescriptor d(packed);
  results.push_back(d.allocatedPtr(builder, loc));
  results.push_back(d.alignedPtr(builder, loc));
  results.push_back(d.offset(builder, loc));
  for (int64_t i = 0; i < rank; ++i)
    results.push_back(d.size(builder, loc, i));
  for (int64_t i = 0; i < rank; ++i)
    results.push_back(d.stride(builder, loc, i));
}

/// Returns the number of non-aggregate values that would be produced by
/// `unpack`.
unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
  // Two pointers, offset, sizes, strides.
  return 3 + 2 * type.getRank();
}

//===----------------------------------------------------------------------===//
// MemRefDescriptorView implementation.
//===----------------------------------------------------------------------===//

MemRefDescriptorView::MemRefDescriptorView(ValueRange range)
    : rank((range.size() - kSizePosInMemRefDescriptor) / 2), elements(range) {}

Value MemRefDescriptorView::allocatedPtr() {
  return elements[kAllocatedPtrPosInMemRefDescriptor];
}

Value MemRefDescriptorView::alignedPtr() {
  return elements[kAlignedPtrPosInMemRefDescriptor];
}

Value MemRefDescriptorView::offset() {
  return elements[kOffsetPosInMemRefDescriptor];
}

Value MemRefDescriptorView::size(unsigned pos) {
  return elements[kSizePosInMemRefDescriptor + pos];
}

Value MemRefDescriptorView::stride(unsigned pos) {
  return elements[kSizePosInMemRefDescriptor + rank + pos];
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefDescriptor implementation
//===----------------------------------------------------------------------===//

/// Construct a helper for the given descriptor value.
UnrankedMemRefDescriptor::UnrankedMemRefDescriptor(Value descriptor)
    : StructBuilder(descriptor) {}

/// Builds IR creating an `undef` value of the descriptor type.
UnrankedMemRefDescriptor UnrankedMemRefDescriptor::undef(OpBuilder &builder,
                                                         Location loc,
                                                         Type descriptorType) {
  Value descriptor = builder.create<LLVM::UndefOp>(
      loc, descriptorType.cast<LLVM::LLVMType>());
  return UnrankedMemRefDescriptor(descriptor);
}

Value UnrankedMemRefDescriptor::rank(OpBuilder &builder, Location loc) {
  return extractPtr(builder, loc, kRankInUnrankedMemRefDescriptor);
}

void UnrankedMemRefDescriptor::setRank(OpBuilder &builder, Location loc,
                                       Value v) {
  setPtr(builder, loc, kRankInUnrankedMemRefDescriptor, v);
}

Value UnrankedMemRefDescriptor::memRefDescPtr(OpBuilder &builder,
                                              Location loc) {
  return extractPtr(builder, loc, kPtrInUnrankedMemRefDescriptor);
}

void UnrankedMemRefDescriptor::setMemRefDescPtr(OpBuilder &builder,
                                                Location loc, Value v) {
  setPtr(builder, loc, kPtrInUnrankedMemRefDescriptor, v);
}

/// Builds IR populating an unranked MemRef descriptor structure from a list
/// of individual constituent values in the following order:
/// - rank of the memref;
/// - pointer to the memref descriptor.
Value UnrankedMemRefDescriptor::pack(OpBuilder &builder, Location loc,
                                     LLVMTypeConverter &converter,
                                     UnrankedMemRefType type,
                                     ValueRange values) {
  Type llvmType = converter.convertType(type);
  auto d = UnrankedMemRefDescriptor::undef(builder, loc, llvmType);

  d.setRank(builder, loc, values[kRankInUnrankedMemRefDescriptor]);
  d.setMemRefDescPtr(builder, loc, values[kPtrInUnrankedMemRefDescriptor]);
  return d;
}

/// Builds IR extracting individual elements that compose an unranked memref
/// descriptor and returns them as `results` list.
void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
                                      Value packed,
                                      SmallVectorImpl<Value> &results) {
  UnrankedMemRefDescriptor d(packed);
  results.reserve(results.size() + 2);
  results.push_back(d.rank(builder, loc));
  results.push_back(d.memRefDescPtr(builder, loc));
}

void UnrankedMemRefDescriptor::computeSizes(
    OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter,
    ArrayRef<UnrankedMemRefDescriptor> values, SmallVectorImpl<Value> &sizes) {
  if (values.empty())
    return;

  // Cache the index type.
  LLVM::LLVMType indexType = typeConverter.getIndexType();

  // Initialize shared constants.
  Value one = createIndexAttrConstant(builder, loc, indexType, 1);
  Value two = createIndexAttrConstant(builder, loc, indexType, 2);
  Value pointerSize = createIndexAttrConstant(
      builder, loc, indexType, ceilDiv(typeConverter.getPointerBitwidth(), 8));
  Value indexSize = createIndexAttrConstant(
      builder, loc, indexType,
      ceilDiv(typeConverter.getIndexTypeBitwidth(), 8));

  sizes.reserve(sizes.size() + values.size());
  for (UnrankedMemRefDescriptor desc : values) {
    // Emit IR computing the memory necessary to store the descriptor. This
    // assumes the descriptor to be
    //   { type*, type*, index, index[rank], index[rank] }
    // and densely packed, so the total size is
    //   2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
    // TODO: consider including the actual size (including eventual padding
    // due to data layout) into the unranked descriptor.
    Value doublePointerSize =
        builder.create<LLVM::MulOp>(loc, indexType, two, pointerSize);

    // (1 + 2 * rank) * sizeof(index)
    Value rank = desc.rank(builder, loc);
    Value doubleRank = builder.create<LLVM::MulOp>(loc, indexType, two, rank);
    Value doubleRankIncremented =
        builder.create<LLVM::AddOp>(loc, indexType, doubleRank, one);
    Value rankIndexSize = builder.create<LLVM::MulOp>(
        loc, indexType, doubleRankIncremented, indexSize);
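    // As a worked example (assuming 64-bit pointers and a 64-bit index type),
    // a rank-2 descriptor occupies 2 * 8 + (1 + 2 * 2) * 8 = 56 bytes.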
Value allocationSize = builder.create( loc, indexType, doublePointerSize, rankIndexSize); sizes.push_back(allocationSize); } } Value UnrankedMemRefDescriptor::allocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); return builder.create(loc, elementPtrPtr); } void UnrankedMemRefDescriptor::setAllocatedPtr(OpBuilder &builder, Location loc, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, Value allocatedPtr) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); builder.create(loc, allocatedPtr, elementPtrPtr); } Value UnrankedMemRefDescriptor::alignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value one = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); Value alignedGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); return builder.create(loc, alignedGep); } void UnrankedMemRefDescriptor::setAlignedPtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, Value alignedPtr) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value one = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 1); Value alignedGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({one})); builder.create(loc, alignedPtr, alignedGep); } Value UnrankedMemRefDescriptor::offset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value two = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); Value offsetGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); offsetGep = builder.create( loc, typeConverter.getIndexType().getPointerTo(), offsetGep); return builder.create(loc, offsetGep); } void UnrankedMemRefDescriptor::setOffset(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType, Value offset) { Value elementPtrPtr = builder.create(loc, elemPtrPtrType, memRefDescPtr); Value two = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 2); Value offsetGep = builder.create( loc, elemPtrPtrType, elementPtrPtr, ValueRange({two})); offsetGep = builder.create( loc, typeConverter.getIndexType().getPointerTo(), offsetGep); builder.create(loc, offset, offsetGep); } Value UnrankedMemRefDescriptor::sizeBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value memRefDescPtr, LLVM::LLVMType elemPtrPtrType) { LLVM::LLVMType elemPtrTy = elemPtrPtrType.getPointerElementTy(); LLVM::LLVMType indexTy = typeConverter.getIndexType(); LLVM::LLVMType structPtrTy = LLVM::LLVMType::getStructTy(elemPtrTy, elemPtrTy, indexTy, indexTy) .getPointerTo(); Value structPtr = builder.create(loc, structPtrTy, memRefDescPtr); LLVM::LLVMType int32_type = unwrap(typeConverter.convertType(builder.getI32Type())); Value zero = createIndexAttrConstant(builder, loc, typeConverter.getIndexType(), 0); Value three = builder.create(loc, int32_type, builder.getI32IntegerAttr(3)); return builder.create(loc, indexTy.getPointerTo(), structPtr, ValueRange({zero, three})); } Value UnrankedMemRefDescriptor::size(OpBuilder 
&builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); return builder.create(loc, sizeStoreGep); } void UnrankedMemRefDescriptor::setSize(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value sizeBasePtr, Value index, Value size) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value sizeStoreGep = builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({index})); builder.create(loc, size, sizeStoreGep); } Value UnrankedMemRefDescriptor::strideBasePtr(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, Value sizeBasePtr, Value rank) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); return builder.create(loc, indexPtrTy, sizeBasePtr, ValueRange({rank})); } Value UnrankedMemRefDescriptor::stride(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); return builder.create(loc, strideStoreGep); } void UnrankedMemRefDescriptor::setStride(OpBuilder &builder, Location loc, LLVMTypeConverter typeConverter, Value strideBasePtr, Value index, Value stride) { LLVM::LLVMType indexPtrTy = typeConverter.getIndexType().getPointerTo(); Value strideStoreGep = builder.create( loc, indexPtrTy, strideBasePtr, ValueRange({index})); builder.create(loc, stride, strideStoreGep); } LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { return *typeConverter.getDialect(); } LLVM::LLVMType ConvertToLLVMPattern::getIndexType() const { return typeConverter.getIndexType(); } LLVM::LLVMType ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return LLVM::LLVMType::getIntNTy( &typeConverter.getContext(), typeConverter.getPointerBitwidth(addressSpace)); } LLVM::LLVMType ConvertToLLVMPattern::getVoidType() const { return LLVM::LLVMType::getVoidTy(&typeConverter.getContext()); } LLVM::LLVMType ConvertToLLVMPattern::getVoidPtrType() const { return LLVM::LLVMType::getInt8PtrTy(&typeConverter.getContext()); } Value ConvertToLLVMPattern::createIndexConstant( ConversionPatternRewriter &builder, Location loc, uint64_t value) const { return createIndexAttrConstant(builder, loc, getIndexType(), value); } Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, Type elementTypePtr, Value descriptor, ValueRange indices, - ArrayRef strides, int64_t offset, + Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { - MemRefDescriptor memRefDescriptor(descriptor); + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + + MemRefDescriptor memRefDescriptor(memRefDesc); Value base = memRefDescriptor.alignedPtr(rewriter, loc); - Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.offset(rewriter, loc) - : createIndexConstant(rewriter, loc, offset); + + Value index; + if (offset != 0) // Skip if offset is zero. + index = offset == MemRefType::getDynamicStrideOrOffset() + ? 
memRefDescriptor.offset(rewriter, loc) + : createIndexConstant(rewriter, loc, offset); for (int i = 0, e = indices.size(); i < e; ++i) { - Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() - ? memRefDescriptor.stride(rewriter, loc, i) - : createIndexConstant(rewriter, loc, strides[i]); - Value additionalOffset = - rewriter.create(loc, indices[i], stride); - offsetValue = - rewriter.create(loc, offsetValue, additionalOffset); + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset() + ? memRefDescriptor.stride(rewriter, loc, i) + : createIndexConstant(rewriter, loc, strides[i]); + increment = rewriter.create(loc, increment, stride); + } + index = + index ? rewriter.create(loc, index, increment) : increment; } - return rewriter.create(loc, elementTypePtr, base, offsetValue); + + LLVM::LLVMType elementPtrType = memRefDescriptor.getElementPtrType(); + return index ? rewriter.create(loc, elementPtrType, base, index) + : base; } Value ConvertToLLVMPattern::getDataPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { - LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementPtrType(); - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - assert(succeeded(successStrides) && "unexpected non-strided memref"); - (void)successStrides; - return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides, - offset, rewriter); + return getStridedElementPtr(loc, type, memRefDesc, indices, rewriter); } // Check if the MemRefType `type` is supported by the lowering. We currently // only support memrefs with identity maps. bool ConvertToLLVMPattern::isSupportedMemRefType(MemRefType type) const { if (!typeConverter.convertType(type.getElementType())) return false; return type.getAffineMaps().empty() || llvm::all_of(type.getAffineMaps(), [](AffineMap map) { return map.isIdentity(); }); } Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { auto elementType = type.getElementType(); auto structElementType = unwrap(typeConverter.convertType(elementType)); return structElementType.getPointerTo(type.getMemorySpace()); } void ConvertToLLVMPattern::getMemRefDescriptorSizes( Location loc, MemRefType memRefType, ArrayRef dynamicSizes, ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, SmallVectorImpl &strides, Value &sizeBytes) const { assert(isSupportedMemRefType(memRefType) && "layout maps must have been normalized away"); sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; for (int64_t size : memRefType.getShape()) { sizes.push_back(size == ShapedType::kDynamicSize ? dynamicSizes[dynamicIndex++] : createIndexConstant(rewriter, loc, size)); } // Strides: iterate sizes in reverse order and multiply. 
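  // Worked example (illustrative): a static shape [2, 3, 4] yields strides
  // [12, 4, 1]. Once a dynamic dimension is encountered, the strides outside
  // it are no longer compile-time constants and are computed below with
  // llvm.mul on the runtime sizes.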
  int64_t stride = 1;
  Value runningStride = createIndexConstant(rewriter, loc, 1);
  strides.resize(memRefType.getRank());
  for (auto i = memRefType.getRank(); i-- > 0;) {
    strides[i] = runningStride;

    int64_t size = memRefType.getShape()[i];
    if (size == 0)
      continue;
    bool useSizeAsStride = stride == 1;
    if (size == ShapedType::kDynamicSize)
      stride = ShapedType::kDynamicSize;
    if (stride != ShapedType::kDynamicSize)
      stride *= size;

    if (useSizeAsStride)
      runningStride = sizes[i];
    else if (stride == ShapedType::kDynamicSize)
      runningStride =
          rewriter.create<LLVM::MulOp>(loc, runningStride, sizes[i]);
    else
      runningStride = createIndexConstant(rewriter, loc, stride);
  }

  // Buffer size in bytes.
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, runningStride});
  sizeBytes = rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
}

Value ConvertToLLVMPattern::getSizeInBytes(
    Location loc, Type type, ConversionPatternRewriter &rewriter) const {
  // Compute the size of an individual element. This emits the MLIR equivalent
  // of the following sizeof(...) implementation in LLVM IR:
  //   %0 = getelementptr %elementType* null, %indexType 1
  //   %1 = ptrtoint %elementType* %0 to %indexType
  // which is a common pattern of getting the size of a type in bytes.
  auto convertedPtrType =
      typeConverter.convertType(type).cast<LLVM::LLVMType>().getPointerTo();
  auto nullPtr = rewriter.create<LLVM::NullOp>(loc, convertedPtrType);
  auto gep = rewriter.create<LLVM::GEPOp>(
      loc, convertedPtrType,
      ArrayRef<Value>{nullPtr, createIndexConstant(rewriter, loc, 1)});
  return rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gep);
}

Value ConvertToLLVMPattern::getNumElements(
    Location loc, ArrayRef<Value> shape,
    ConversionPatternRewriter &rewriter) const {
  // Compute the total number of memref elements.
  Value numElements =
      shape.empty() ? createIndexConstant(rewriter, loc, 1) : shape.front();
  for (unsigned i = 1, e = shape.size(); i < e; ++i)
    numElements = rewriter.create<LLVM::MulOp>(loc, numElements, shape[i]);
  return numElements;
}

/// Creates and populates the memref descriptor struct given all its fields.
MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
    Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr,
    ArrayRef<Value> sizes, ArrayRef<Value> strides,
    ConversionPatternRewriter &rewriter) const {
  auto structType = typeConverter.convertType(memRefType);
  auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType);

  // Field 1: Allocated pointer, used for malloc/free.
  memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr);

  // Field 2: Actual aligned pointer to payload.
  memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr);

  // Field 3: Offset in aligned pointer.
  memRefDescriptor.setOffset(rewriter, loc,
                             createIndexConstant(rewriter, loc, 0));

  // Field 4: Sizes.
  for (auto en : llvm::enumerate(sizes))
    memRefDescriptor.setSize(rewriter, loc, en.index(), en.value());

  // Field 5: Strides.
  for (auto en : llvm::enumerate(strides))
    memRefDescriptor.setStride(rewriter, loc, en.index(), en.value());

  return memRefDescriptor;
}
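// Illustrative mirror (hypothetical, not used by the lowering): for a fixed
// rank of 2 and a 64-bit index type, the descriptor populated above
// corresponds to the following host-side layout.
template <typename T>
struct Rank2MemRefDescriptorLayout {
  T *allocated;       // Field 1: result of the allocation call.
  T *aligned;         // Field 2: payload pointer, suitably aligned.
  int64_t offset;     // Field 3: linear offset into the payload.
  int64_t sizes[2];   // Field 4: one extent per dimension.
  int64_t strides[2]; // Field 5: one stride per dimension.
};

/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.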
static void filterFuncAttributes(ArrayRef attrs, bool filterArgAttrs, SmallVectorImpl &result) { for (const auto &attr : attrs) { if (attr.first == SymbolTable::getSymbolAttrName() || attr.first == impl::getTypeAttrName() || attr.first == "std.varargs" || (filterArgAttrs && impl::isArgAttrName(attr.first.strref()))) continue; result.push_back(attr); } } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. This function can be called from C /// by passing a pointer to a C struct corresponding to a memref descriptor. /// Internally, the auxiliary function unpacks the descriptor into individual /// components and forwards them to `newFuncOp`. static void wrapForExternalCallers(OpBuilder &rewriter, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { auto type = funcOp.getType(); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); auto wrapperFuncOp = rewriter.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), typeConverter.convertFunctionTypeCWrapper(type), LLVM::Linkage::External, attributes); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(wrapperFuncOp.addEntryBlock()); SmallVector args; for (auto &en : llvm::enumerate(type.getInputs())) { Value arg = wrapperFuncOp.getArgument(en.index()); if (auto memrefType = en.value().dyn_cast()) { Value loaded = rewriter.create(loc, arg); MemRefDescriptor::unpack(rewriter, loc, loaded, memrefType, args); continue; } if (en.value().isa()) { Value loaded = rewriter.create(loc, arg); UnrankedMemRefDescriptor::unpack(rewriter, loc, loaded, args); continue; } args.push_back(wrapperFuncOp.getArgument(en.index())); } auto call = rewriter.create(loc, newFuncOp, args); rewriter.create(loc, call.getResults()); } /// Creates an auxiliary function with pointer-to-memref-descriptor-struct /// arguments instead of unpacked arguments. Creates a body for the (external) /// `newFuncOp` that allocates a memref descriptor on stack, packs the /// individual arguments into this descriptor and passes a pointer to it into /// the auxiliary function. This auxiliary external function is now compatible /// with functions defined in C using pointers to C structs corresponding to a /// memref descriptor. static void wrapExternalFunction(OpBuilder &builder, Location loc, LLVMTypeConverter &typeConverter, FuncOp funcOp, LLVM::LLVMFuncOp newFuncOp) { OpBuilder::InsertionGuard guard(builder); LLVM::LLVMType wrapperType = typeConverter.convertFunctionTypeCWrapper(funcOp.getType()); // This conversion can only fail if it could not convert one of the argument // types. But since it has been applies to a non-wrapper function before, it // should have failed earlier and not reach this point at all. assert(wrapperType && "unexpected type conversion failure"); SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/false, attributes); // Create the auxiliary function. auto wrapperFunc = builder.create( loc, llvm::formatv("_mlir_ciface_{0}", funcOp.getName()).str(), wrapperType, LLVM::Linkage::External, attributes); builder.setInsertionPointToStart(newFuncOp.addEntryBlock()); // Get a ValueRange containing arguments. 
FunctionType type = funcOp.getType(); SmallVector args; args.reserve(type.getNumInputs()); ValueRange wrapperArgsRange(newFuncOp.getArguments()); // Iterate over the inputs of the original function and pack values into // memref descriptors if the original type is a memref. for (auto &en : llvm::enumerate(type.getInputs())) { Value arg; int numToDrop = 1; auto memRefType = en.value().dyn_cast(); auto unrankedMemRefType = en.value().dyn_cast(); if (memRefType || unrankedMemRefType) { numToDrop = memRefType ? MemRefDescriptor::getNumUnpackedValues(memRefType) : UnrankedMemRefDescriptor::getNumUnpackedValues(); Value packed = memRefType ? MemRefDescriptor::pack(builder, loc, typeConverter, memRefType, wrapperArgsRange.take_front(numToDrop)) : UnrankedMemRefDescriptor::pack( builder, loc, typeConverter, unrankedMemRefType, wrapperArgsRange.take_front(numToDrop)); auto ptrTy = packed.getType().cast().getPointerTo(); Value one = builder.create( loc, typeConverter.convertType(builder.getIndexType()), builder.getIntegerAttr(builder.getIndexType(), 1)); Value allocated = builder.create(loc, ptrTy, one, /*alignment=*/0); builder.create(loc, packed, allocated); arg = allocated; } else { arg = wrapperArgsRange[0]; } args.push_back(arg); wrapperArgsRange = wrapperArgsRange.drop_front(numToDrop); } assert(wrapperArgsRange.empty() && "did not map some of the arguments"); auto call = builder.create(loc, wrapperFunc, args); builder.create(loc, call.getResults()); } namespace { struct FuncOpConversionBase : public ConvertOpToLLVMPattern { protected: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided // to this legalization pattern. LLVM::LLVMFuncOp convertFuncOpToLLVMFuncOp(FuncOp funcOp, ConversionPatternRewriter &rewriter) const { // Convert the original function arguments. They are converted using the // LLVMTypeConverter provided to this legalization pattern. auto varargsAttr = funcOp.getAttrOfType("std.varargs"); TypeConverter::SignatureConversion result(funcOp.getNumArguments()); auto llvmType = typeConverter.convertFunctionSignature( funcOp.getType(), varargsAttr && varargsAttr.getValue(), result); if (!llvmType) return nullptr; // Propagate argument attributes to all converted arguments obtained after // converting a given original argument. SmallVector attributes; filterFuncAttributes(funcOp.getAttrs(), /*filterArgAttrs=*/true, attributes); for (unsigned i = 0, e = funcOp.getNumArguments(); i < e; ++i) { auto attr = impl::getArgAttrDict(funcOp, i); if (!attr) continue; auto mapping = result.getInputMapping(i); assert(mapping.hasValue() && "unexpected deletion of function argument"); SmallString<8> name; for (size_t j = 0; j < mapping->size; ++j) { impl::getArgAttrName(mapping->inputNo + j, name); attributes.push_back(rewriter.getNamedAttr(name, attr)); } } // Create an LLVM function, use external linkage by default until MLIR // functions have linkage. auto newFuncOp = rewriter.create( funcOp.getLoc(), funcOp.getName(), llvmType, LLVM::Linkage::External, attributes); rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(), newFuncOp.end()); if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter, &result))) return nullptr; return newFuncOp; } }; /// FuncOp legalization pattern that converts MemRef arguments to pointers to /// MemRef descriptors (LLVM struct data types) containing all the MemRef type /// information. 
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter) : FuncOpConversionBase(converter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); if (typeConverter.getOptions().emitCWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); else wrapForExternalCallers(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); } rewriter.eraseOp(op); return success(); } }; /// FuncOp legalization pattern that converts MemRef arguments to bare pointers /// to the MemRef element type. This will impact the calling convention and ABI. struct BarePtrFuncOpConversion : public FuncOpConversionBase { using FuncOpConversionBase::FuncOpConversionBase; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto funcOp = cast(op); // Store the type of memref-typed arguments before the conversion so that we // can promote them to MemRef descriptor at the beginning of the function. SmallVector oldArgTypes = llvm::to_vector<8>(funcOp.getType().getInputs()); auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) return failure(); if (newFuncOp.getBody().empty()) { rewriter.eraseOp(op); return success(); } // Promote bare pointers from memref arguments to memref descriptors at the // beginning of the function so that all the memrefs in the function have a // uniform representation. Block *entryBlock = &newFuncOp.getBody().front(); auto blockArgs = entryBlock->getArguments(); assert(blockArgs.size() == oldArgTypes.size() && "The number of arguments and types doesn't match"); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(entryBlock); for (auto it : llvm::zip(blockArgs, oldArgTypes)) { BlockArgument arg = std::get<0>(it); Type argTy = std::get<1>(it); // Unranked memrefs are not supported in the bare pointer calling // convention. We should have bailed out before in the presence of // unranked memrefs. assert(!argTy.isa() && "Unranked memref is not supported"); auto memrefTy = argTy.dyn_cast(); if (!memrefTy) continue; // Replace barePtr with a placeholder (undef), promote barePtr to a ranked // or unranked memref descriptor and replace placeholder with the last // instruction of the memref descriptor. // TODO: The placeholder is needed to avoid replacing barePtr uses in the // MemRef descriptor instructions. We may want to have a utility in the // rewriter to properly handle this use case. Location loc = op->getLoc(); auto placeholder = rewriter.create(loc, memrefTy); rewriter.replaceUsesOfBlockArgument(arg, placeholder); Value desc = MemRefDescriptor::fromStaticShape( rewriter, loc, typeConverter, memrefTy, arg); rewriter.replaceOp(placeholder, {desc}); } rewriter.eraseOp(op); return success(); } }; //////////////// Support for Lowering operations on n-D vectors //////////////// // Helper struct to "unroll" operations on n-D vectors in terms of operations on // 1-D LLVM vectors. struct NDVectorTypeInfo { // LLVM array struct which encodes n-D vectors. LLVM::LLVMType llvmArrayTy; // LLVM vector type which encodes the inner 1-D vector type. 
  LLVM::LLVMType llvmVectorTy;
  // Multiplicity of llvmArrayTy to llvmVectorTy.
  SmallVector<int64_t, 4> arraySizes;
};
} // namespace

// For >1-D vector types, extracts the necessary information to iterate over
// all 1-D subvectors in the underlying LLVM representation of the n-D vector.
// Iterates on the llvm array type until we hit a non-array type (which is
// asserted to be an llvm vector type).
static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
                                                LLVMTypeConverter &converter) {
  assert(vectorType.getRank() > 1 && "expected >1D vector type");
  NDVectorTypeInfo info;
  info.llvmArrayTy =
      converter.convertType(vectorType).dyn_cast<LLVM::LLVMType>();
  if (!info.llvmArrayTy)
    return info;
  info.arraySizes.reserve(vectorType.getRank() - 1);
  auto llvmTy = info.llvmArrayTy;
  while (llvmTy.isArrayTy()) {
    info.arraySizes.push_back(llvmTy.getArrayNumElements());
    llvmTy = llvmTy.getArrayElementType();
  }
  if (!llvmTy.isVectorTy())
    return info;
  info.llvmVectorTy = llvmTy;
  return info;
}

// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is outside the range [0, P),
// where P is the product of all the basis coordinates.
//
// Prerequisites:
//   Basis is an array of nonnegative integers (signed type inherited from
//   vector shape type).
static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
                                              unsigned linearIndex) {
  SmallVector<int64_t, 4> res;
  res.reserve(basis.size());
  for (unsigned basisElement : llvm::reverse(basis)) {
    res.push_back(linearIndex % basisElement);
    linearIndex = linearIndex / basisElement;
  }
  if (linearIndex > 0)
    return {};
  std::reverse(res.begin(), res.end());
  return res;
}

// Iterate over the linear index space, convert each linear index to the
// coordinate space, and invoke `fun` with the corresponding position
// attribute.
template <typename Lambda>
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
                     Lambda fun) {
  unsigned ub = 1;
  for (auto s : info.arraySizes)
    ub *= s;
  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
    auto coords = getCoordinates(info.arraySizes, linearIndex);
    // Linear index is out of bounds, we are done.
    if (coords.empty())
      break;
    assert(coords.size() == info.arraySizes.size());
    auto position = builder.getI64ArrayAttr(coords);
    fun(position);
  }
}

////////////// End Support for Lowering operations on n-D vectors //////////////

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult LLVM::detail::oneToOneRewrite(
    Operation *op, StringRef targetOp, ValueRange operands,
    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  unsigned numResults = op->getNumResults();

  Type packedType;
  if (numResults != 0) {
    packedType = typeConverter.packFunctionResults(op->getResultTypes());
    if (!packedType)
      return failure();
  }

  // Create the operation through state since we don't know its C++ type.
  OperationState state(op->getLoc(), targetOp);
  state.addTypes(packedType);
  state.addOperands(operands);
  state.addAttributes(op->getAttrs());
  Operation *newOp = rewriter.createOperation(state);

  // If the operation produced 0 or 1 result, return them immediately.
  if (numResults == 0)
    return rewriter.eraseOp(op), success();
  if (numResults == 1)
    return rewriter.replaceOp(op, newOp->getResult(0)), success();

  // Otherwise, it had been converted to an operation producing a structure.
  // Extract individual results from the structure and return them as list.
  SmallVector<Value, 4> results;
  results.reserve(numResults);
  for (unsigned i = 0; i < numResults; ++i) {
    auto type = typeConverter.convertType(op->getResult(i).getType());
    results.push_back(rewriter.create<LLVM::ExtractValueOp>(
        op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i)));
  }
  rewriter.replaceOp(op, results);
  return success();
}

static LogicalResult handleMultidimensionalVectors(
    Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
    std::function<Value(LLVM::LLVMType, ValueRange)> createOperand,
    ConversionPatternRewriter &rewriter) {
  auto vectorType = op->getResult(0).getType().dyn_cast<VectorType>();
  if (!vectorType)
    return failure();
  auto vectorTypeInfo = extractNDVectorTypeInfo(vectorType, typeConverter);
  auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
  if (!llvmVectorTy || llvmArrayTy != vectorTypeInfo.llvmArrayTy)
    return failure();

  auto loc = op->getLoc();
  Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);
  nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
    // For this unrolled `position` corresponding to the `linearIndex`^th
    // element, extract operand vectors
    SmallVector<Value, 4> extractedOperands;
    for (auto operand : operands)
      extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
          loc, llvmVectorTy, operand, position));
    Value newVal = createOperand(llvmVectorTy, extractedOperands);
    desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, newVal,
                                                position);
  });
  rewriter.replaceOp(op, desc);
  return success();
}

LogicalResult LLVM::detail::vectorOneToOneRewrite(
    Operation *op, StringRef targetOp, ValueRange operands,
    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  assert(!operands.empty());

  // Cannot convert ops if their operands are not of LLVM type.
  if (!llvm::all_of(operands.getTypes(),
                    [](Type t) { return t.isa<LLVM::LLVMType>(); }))
    return failure();

  auto llvmArrayTy = operands[0].getType().cast<LLVM::LLVMType>();
  if (!llvmArrayTy.isArrayTy())
    return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);

  auto callback = [op, targetOp, &rewriter](LLVM::LLVMType llvmVectorTy,
                                            ValueRange operands) {
    OperationState state(op->getLoc(), targetOp);
    state.addTypes(llvmVectorTy);
    state.addOperands(operands);
    state.addAttributes(op->getAttrs());
    return rewriter.createOperation(state)->getResult(0);
  };

  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
                                       rewriter);
}
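// Worked example (illustrative; type spelling is schematic): for an op on
// vector<4x8xf32>, the converted operand type is an LLVM array of four 1-D
// vectors of eight floats, so arraySizes = [4]. nDVectorIterate then visits
// linear indices 0..3, and for each position one extractvalue per operand,
// one scalar 1-D vector op, and one insertvalue are emitted;
// getCoordinates([4], 2) returns {2}.

namespace {

// Straightforward lowerings.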
using AbsFOpLowering = VectorConvertToLLVMPattern<AbsFOp, LLVM::FAbsOp>;
using AddFOpLowering = VectorConvertToLLVMPattern<AddFOp, LLVM::FAddOp>;
using AddIOpLowering = VectorConvertToLLVMPattern<AddIOp, LLVM::AddOp>;
using AndOpLowering = VectorConvertToLLVMPattern<AndOp, LLVM::AndOp>;
using CeilFOpLowering = VectorConvertToLLVMPattern<CeilFOp, LLVM::FCeilOp>;
using CopySignOpLowering =
    VectorConvertToLLVMPattern<CopySignOp, LLVM::CopySignOp>;
using CosOpLowering = VectorConvertToLLVMPattern<CosOp, LLVM::CosOp>;
using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
using ExpOpLowering = VectorConvertToLLVMPattern<ExpOp, LLVM::ExpOp>;
using Exp2OpLowering = VectorConvertToLLVMPattern<Exp2Op, LLVM::Exp2Op>;
using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
using Log10OpLowering = VectorConvertToLLVMPattern<Log10Op, LLVM::Log10Op>;
using Log2OpLowering = VectorConvertToLLVMPattern<Log2Op, LLVM::Log2Op>;
using LogOpLowering = VectorConvertToLLVMPattern<LogOp, LLVM::LogOp>;
using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
using MulIOpLowering = VectorConvertToLLVMPattern<MulIOp, LLVM::MulOp>;
using NegFOpLowering = VectorConvertToLLVMPattern<NegFOp, LLVM::FNegOp>;
using OrOpLowering = VectorConvertToLLVMPattern<OrOp, LLVM::OrOp>;
using RemFOpLowering = VectorConvertToLLVMPattern<RemFOp, LLVM::FRemOp>;
using SelectOpLowering = OneToOneConvertToLLVMPattern<SelectOp, LLVM::SelectOp>;
using ShiftLeftOpLowering =
    OneToOneConvertToLLVMPattern<ShiftLeftOp, LLVM::ShlOp>;
using SignedDivIOpLowering =
    VectorConvertToLLVMPattern<SignedDivIOp, LLVM::SDivOp>;
using SignedRemIOpLowering =
    VectorConvertToLLVMPattern<SignedRemIOp, LLVM::SRemOp>;
using SignedShiftRightOpLowering =
    OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
using SinOpLowering = VectorConvertToLLVMPattern<SinOp, LLVM::SinOp>;
using SqrtOpLowering = VectorConvertToLLVMPattern<SqrtOp, LLVM::SqrtOp>;
using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
using UnsignedDivIOpLowering =
    VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
using UnsignedRemIOpLowering =
    VectorConvertToLLVMPattern<UnsignedRemIOp, LLVM::URemOp>;
using UnsignedShiftRightOpLowering =
    OneToOneConvertToLLVMPattern<UnsignedShiftRightOp, LLVM::LShrOp>;
using XOrOpLowering = VectorConvertToLLVMPattern<XOrOp, LLVM::XOrOp>;

/// Lower `std.assert`. The default lowering calls the `abort` function if the
/// assertion is violated and has no effect otherwise. The failure message is
/// ignored by the default lowering but should be propagated by any custom
/// lowering.
struct AssertOpLowering : public ConvertOpToLLVMPattern<AssertOp> {
  using ConvertOpToLLVMPattern<AssertOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    AssertOp::Adaptor transformed(operands);

    // Insert the `abort` declaration if necessary.
    auto module = op->getParentOfType<ModuleOp>();
    auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
    if (!abortFunc) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      auto abortFuncTy =
          LLVM::LLVMType::getFunctionTy(getVoidType(), {}, /*isVarArg=*/false);
      abortFunc = rewriter.create<LLVM::LLVMFuncOp>(rewriter.getUnknownLoc(),
                                                    "abort", abortFuncTy);
    }

    // Split block at `assert` operation.
    Block *opBlock = rewriter.getInsertionBlock();
    auto opPosition = rewriter.getInsertionPoint();
    Block *continuationBlock = rewriter.splitBlock(opBlock, opPosition);

    // Generate IR to call `abort`.
    Block *failureBlock = rewriter.createBlock(opBlock->getParent());
    rewriter.create<LLVM::CallOp>(loc, abortFunc, llvm::None);
    rewriter.create<LLVM::UnreachableOp>(loc);

    // Generate assertion test.
    rewriter.setInsertionPointToEnd(opBlock);
    rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
        op, transformed.arg(), continuationBlock, failureBlock);

    return success();
  }
};
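// Resulting structure (illustrative sketch): `std.assert %cond, "msg"`
// becomes
//   ^bb0:      ... llvm.cond_br %cond, ^continue, ^abort
//   ^abort:    llvm.call @abort() : () -> ()
//              llvm.unreachable
//   ^continue: <code that followed the assert>
// The message is dropped by this default lowering, as noted above.

// Lowerings for operations on complex numbers.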
struct CreateComplexOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto complexOp = cast(op); CreateComplexOp::Adaptor transformed(operands); // Pack real and imaginary part in a complex number struct. auto loc = op->getLoc(); auto structType = typeConverter.convertType(complexOp.getType()); auto complexStruct = ComplexStructBuilder::undef(rewriter, loc, structType); complexStruct.setReal(rewriter, loc, transformed.real()); complexStruct.setImaginary(rewriter, loc, transformed.imaginary()); rewriter.replaceOp(op, {complexStruct}); return success(); } }; struct ReOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ReOp::Adaptor transformed(operands); // Extract real part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); Value real = complexStruct.real(rewriter, op->getLoc()); rewriter.replaceOp(op, real); return success(); } }; struct ImOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { ImOp::Adaptor transformed(operands); // Extract imaginary part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); Value imaginary = complexStruct.imaginary(rewriter, op->getLoc()); rewriter.replaceOp(op, imaginary); return success(); } }; struct BinaryComplexOperands { std::complex lhs, rhs; }; template BinaryComplexOperands unpackBinaryComplexOperands(OpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) { auto bop = cast(op); auto loc = bop.getLoc(); typename OpTy::Adaptor transformed(operands); // Extract real and imaginary values from operands. BinaryComplexOperands unpacked; ComplexStructBuilder lhs(transformed.lhs()); unpacked.lhs.real(lhs.real(rewriter, loc)); unpacked.lhs.imag(lhs.imaginary(rewriter, loc)); ComplexStructBuilder rhs(transformed.rhs()); unpacked.rhs.real(rhs.real(rewriter, loc)); unpacked.rhs.imag(rhs.imaginary(rewriter, loc)); return unpacked; } struct AddCFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *operation, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto op = cast(operation); auto loc = op.getLoc(); BinaryComplexOperands arg = unpackBinaryComplexOperands(op, operands, rewriter); // Initialize complex number struct for result. auto structType = this->typeConverter.convertType(op.getType()); auto result = ComplexStructBuilder::undef(rewriter, loc, structType); // Emit IR to add complex numbers. 
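    // Componentwise (illustrative): (a + bi) + (c + di) = (a + c) + (b + d)i,
    // hence one floating-point add for the real parts and one for the
    // imaginary parts.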
    Value real =
        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.real(), arg.rhs.real());
    Value imag =
        rewriter.create<LLVM::FAddOp>(loc, arg.lhs.imag(), arg.rhs.imag());
    result.setReal(rewriter, loc, real);
    result.setImaginary(rewriter, loc, imag);

    rewriter.replaceOp(op, {result});
    return success();
  }
};

struct SubCFOpLowering : public ConvertOpToLLVMPattern<SubCFOp> {
  using ConvertOpToLLVMPattern<SubCFOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto op = cast<SubCFOp>(operation);
    auto loc = op.getLoc();
    BinaryComplexOperands arg =
        unpackBinaryComplexOperands<SubCFOp>(op, operands, rewriter);

    // Initialize complex number struct for result.
    auto structType = this->typeConverter.convertType(op.getType());
    auto result = ComplexStructBuilder::undef(rewriter, loc, structType);

    // Emit IR to subtract complex numbers.
    Value real =
        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.real(), arg.rhs.real());
    Value imag =
        rewriter.create<LLVM::FSubOp>(loc, arg.lhs.imag(), arg.rhs.imag());
    result.setReal(rewriter, loc, real);
    result.setImaginary(rewriter, loc, imag);

    rewriter.replaceOp(op, {result});
    return success();
  }
};

struct ConstantOpLowering : public ConvertOpToLLVMPattern<ConstantOp> {
  using ConvertOpToLLVMPattern<ConstantOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *operation, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto op = cast<ConstantOp>(operation);
    // If constant refers to a function, convert it to "addressof".
    if (auto symbolRef = op.getValue().dyn_cast<FlatSymbolRefAttr>()) {
      auto type = typeConverter.convertType(op.getResult().getType())
                      .dyn_cast_or_null<LLVM::LLVMType>();
      if (!type)
        return rewriter.notifyMatchFailure(op, "failed to convert result type");

      MutableDictionaryAttr attrs(op.getAttrs());
      attrs.remove(rewriter.getIdentifier("value"));
      rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(
          op, type.cast<LLVM::LLVMType>(), symbolRef.getValue(),
          attrs.getAttrs());
      return success();
    }

    // Calling into other scopes (non-flat reference) is not supported in LLVM.
    if (op.getValue().isa<SymbolRefAttr>())
      return rewriter.notifyMatchFailure(
          op, "referring to a symbol outside of the current module");

    return LLVM::detail::oneToOneRewrite(
        op, LLVM::ConstantOp::getOperationName(), operands, typeConverter,
        rewriter);
  }
};

/// Lowering for AllocOp and AllocaOp.
struct AllocLikeOpLowering : public ConvertToLLVMPattern {
  using ConvertToLLVMPattern::createIndexConstant;
  using ConvertToLLVMPattern::getIndexType;
  using ConvertToLLVMPattern::getVoidPtrType;
  using ConvertToLLVMPattern::typeConverter;

  explicit AllocLikeOpLowering(StringRef opName, LLVMTypeConverter &converter)
      : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {}

protected:
  // Returns 'input' aligned up to 'alignment'. Computes
  //   bumped = input + alignment - 1
  //   aligned = bumped - bumped % alignment
  static Value createAligned(ConversionPatternRewriter &rewriter, Location loc,
                             Value input, Value alignment) {
    Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1);
    Value bump = rewriter.create<LLVM::SubOp>(loc, alignment, one);
    Value bumped = rewriter.create<LLVM::AddOp>(loc, input, bump);
    Value mod = rewriter.create<LLVM::URemOp>(loc, bumped, alignment);
    return rewriter.create<LLVM::SubOp>(loc, bumped, mod);
  }
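  // Worked example (illustrative): createAligned(input = 13, alignment = 8)
  // computes bumped = 13 + 7 = 20 and aligned = 20 - 20 % 8 = 16, the
  // smallest multiple of 8 that is >= 13.

  // Creates a call to an allocation function with params and casts the
  // resulting void pointer to ptrType.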
Value createAllocCall(Location loc, StringRef name, Type ptrType, ArrayRef params, ModuleOp module, ConversionPatternRewriter &rewriter) const { SmallVector paramTypes; auto allocFuncOp = module.lookupSymbol(name); if (!allocFuncOp) { for (Value param : params) paramTypes.push_back(param.getType().cast()); auto allocFuncType = LLVM::LLVMType::getFunctionTy(getVoidPtrType(), paramTypes, /*isVarArg=*/false); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); allocFuncOp = rewriter.create(rewriter.getUnknownLoc(), name, allocFuncType); } auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); auto allocatedPtr = rewriter .create(loc, getVoidPtrType(), allocFuncSymbol, params) .getResult(0); return rewriter.create(loc, ptrType, allocatedPtr); } /// Allocates the underlying buffer. Returns the allocated pointer and the /// aligned pointer. virtual std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op) const = 0; private: static MemRefType getMemRefResultType(Operation *op) { return op->getResult(0).getType().cast(); } LogicalResult match(Operation *op) const override { MemRefType memRefType = getMemRefResultType(op); return success(isSupportedMemRefType(memRefType)); } // An `alloc` is converted into a definition of a memref descriptor value and // a call to `malloc` to allocate the underlying data buffer. The memref // descriptor is of the LLVM structure type where: // 1. the first element is a pointer to the allocated (typed) data buffer, // 2. the second element is a pointer to the (typed) payload, aligned to the // specified alignment, // 3. the remaining elements serve to store all the sizes and strides of the // memref using LLVM-converted `index` type. // // Alignment is performed by allocating `alignment` more bytes than // requested and shifting the aligned pointer relative to the allocated // memory. Note: `alignment - ` would actually be // sufficient. If alignment is unspecified, the two pointers are equal. // An `alloca` is converted into a definition of a memref descriptor value and // an llvm.alloca to allocate the underlying data buffer. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { MemRefType memRefType = getMemRefResultType(op); auto loc = op->getLoc(); // Get actual sizes of the memref as values: static sizes are constant // values and dynamic sizes are passed to 'alloc' as operands. In case of // zero-dimensional memref, assume a scalar (size 1). SmallVector sizes; SmallVector strides; Value sizeBytes; this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes, strides, sizeBytes); // Allocate the underlying buffer. Value allocatedPtr; Value alignedPtr; std::tie(allocatedPtr, alignedPtr) = this->allocateBuffer(rewriter, loc, sizeBytes, op); // Create the MemRef descriptor. auto memRefDescriptor = this->createMemRefDescriptor( loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); // Return the final value of the descriptor. rewriter.replaceOp(op, {memRefDescriptor}); } }; struct AllocOpLowering : public AllocLikeOpLowering { AllocOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {} std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op) const override { // Heap allocations. 
AllocOp allocOp = cast(op); MemRefType memRefType = allocOp.getType(); Value alignment; if (auto alignmentAttr = allocOp.alignment()) { alignment = createIndexConstant(rewriter, loc, *alignmentAttr); } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { // In the case where no alignment is specified, we may want to override // `malloc's` behavior. `malloc` typically aligns at the size of the // biggest scalar on a target HW. For non-scalars, use the natural // alignment of the LLVM type given by the LLVM DataLayout. alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); } if (alignment) { // Adjust the allocation size to consider alignment. sizeBytes = rewriter.create(loc, sizeBytes, alignment); } // Allocate the underlying buffer and store a pointer to it in the MemRef // descriptor. Type elementPtrType = this->getElementPtrType(memRefType); Value allocatedPtr = createAllocCall(loc, "malloc", elementPtrType, {sizeBytes}, allocOp.getParentOfType(), rewriter); Value alignedPtr = allocatedPtr; if (alignment) { auto intPtrType = getIntPtrType(memRefType.getMemorySpace()); // Compute the aligned type pointer. Value allocatedInt = rewriter.create(loc, intPtrType, allocatedPtr); Value alignmentInt = createAligned(rewriter, loc, allocatedInt, alignment); alignedPtr = rewriter.create(loc, elementPtrType, alignmentInt); } return std::make_tuple(allocatedPtr, alignedPtr); } }; struct AlignedAllocOpLowering : public AllocLikeOpLowering { AlignedAllocOpLowering(LLVMTypeConverter &converter) : AllocLikeOpLowering(AllocOp::getOperationName(), converter) {} /// Returns the memref's element size in bytes. // TODO: there are other places where this is used. Expose publicly? static unsigned getMemRefEltSizeInBytes(MemRefType memRefType) { auto elementType = memRefType.getElementType(); unsigned sizeInBits; if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); sizeInBits = vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); } return llvm::divideCeil(sizeInBits, 8); } /// Returns true if the memref size in bytes is known to be a multiple of /// factor. static bool isMemRefSizeMultipleOf(MemRefType type, uint64_t factor) { uint64_t sizeDivisor = getMemRefEltSizeInBytes(type); for (unsigned i = 0, e = type.getRank(); i < e; i++) { if (type.isDynamic(type.getDimSize(i))) continue; sizeDivisor = sizeDivisor * type.getDimSize(i); } return sizeDivisor % factor == 0; } /// Returns the alignment to be used for the allocation call itself. /// aligned_alloc requires the allocation size to be a power of two, and the /// allocation size to be a multiple of alignment, int64_t getAllocationAlignment(AllocOp allocOp) const { if (Optional alignment = allocOp.alignment()) return *alignment; // Whenever we don't have alignment set, we will use an alignment // consistent with the element type; since the allocation size has to be a // power of two, we will bump to the next power of two if it already isn't. auto eltSizeBytes = getMemRefEltSizeInBytes(allocOp.getType()); return std::max(kMinAlignedAllocAlignment, llvm::PowerOf2Ceil(eltSizeBytes)); } std::tuple allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op) const override { // Heap allocations. 
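    // Worked example (illustrative): for memref<64xf32> with no explicit
    // alignment attribute, getMemRefEltSizeInBytes returns 4, so the
    // allocation alignment is max(kMinAlignedAllocAlignment, 4) = 16; the
    // 256-byte buffer size is already a multiple of 16, so no padding is
    // added.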
    AllocOp allocOp = cast<AllocOp>(op);
    MemRefType memRefType = allocOp.getType();
    int64_t alignment = getAllocationAlignment(allocOp);
    Value allocAlignment = createIndexConstant(rewriter, loc, alignment);

    // aligned_alloc requires size to be a multiple of alignment; we will pad
    // the size to the next multiple if necessary.
    if (!isMemRefSizeMultipleOf(memRefType, alignment))
      sizeBytes = createAligned(rewriter, loc, sizeBytes, allocAlignment);

    Type elementPtrType = this->getElementPtrType(memRefType);
    Value allocatedPtr = createAllocCall(
        loc, "aligned_alloc", elementPtrType, {allocAlignment, sizeBytes},
        allocOp.getParentOfType<ModuleOp>(), rewriter);

    return std::make_tuple(allocatedPtr, allocatedPtr);
  }

  /// The minimum alignment to use with aligned_alloc (has to be a power of 2).
  static constexpr uint64_t kMinAlignedAllocAlignment = 16UL;
};

// Out of line definition, required till C++17.
constexpr uint64_t AlignedAllocOpLowering::kMinAlignedAllocAlignment;

struct AllocaOpLowering : public AllocLikeOpLowering {
  AllocaOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(AllocaOp::getOperationName(), converter) {}

  /// Allocates the underlying buffer on the stack with an llvm.alloca; both
  /// returned pointers are the resulting element pointer.
  std::tuple<Value, Value>
  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
                 Value sizeBytes, Operation *op) const override {

    // With alloca, one gets a pointer to the element type right away.
    // For stack allocations.
    auto allocaOp = cast<AllocaOp>(op);
    auto elementPtrType = this->getElementPtrType(allocaOp.getType());

    auto allocatedElementPtr = rewriter.create<LLVM::AllocaOp>(
        loc, elementPtrType, sizeBytes,
        allocaOp.alignment() ? *allocaOp.alignment() : 0);

    return std::make_tuple(allocatedElementPtr, allocatedElementPtr);
  }
};

/// Copies the shaped descriptor part to (if `toDynamic` is set) or from
/// (otherwise) the dynamically allocated memory for any operands that were
/// unranked descriptors originally.
static LogicalResult copyUnrankedDescriptors(OpBuilder &builder, Location loc,
                                             LLVMTypeConverter &typeConverter,
                                             TypeRange origTypes,
                                             SmallVectorImpl<Value> &operands,
                                             bool toDynamic) {
  assert(origTypes.size() == operands.size() &&
         "expected as many original types as operands");

  // Find operands of unranked memref type and store them.
  SmallVector<UnrankedMemRefDescriptor, 4> unrankedMemrefs;
  for (unsigned i = 0, e = operands.size(); i < e; ++i)
    if (origTypes[i].isa<UnrankedMemRefType>())
      unrankedMemrefs.emplace_back(operands[i]);

  if (unrankedMemrefs.empty())
    return success();

  // Compute allocation sizes.
  SmallVector<Value, 4> sizes;
  UnrankedMemRefDescriptor::computeSizes(builder, loc, typeConverter,
                                         unrankedMemrefs, sizes);

  // Get frequently used types.
  MLIRContext *context = builder.getContext();
  auto voidType = LLVM::LLVMType::getVoidTy(context);
  auto voidPtrType = LLVM::LLVMType::getInt8PtrTy(context);
  auto i1Type = LLVM::LLVMType::getInt1Ty(context);
  LLVM::LLVMType indexType = typeConverter.getIndexType();

  // Find the malloc and free, or declare them if necessary.
auto module = builder.getInsertionPoint()->getParentOfType(); auto mallocFunc = module.lookupSymbol("malloc"); if (!mallocFunc && toDynamic) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); mallocFunc = builder.create( builder.getUnknownLoc(), "malloc", LLVM::LLVMType::getFunctionTy( voidPtrType, llvm::makeArrayRef(indexType), /*isVarArg=*/false)); } auto freeFunc = module.lookupSymbol("free"); if (!freeFunc && !toDynamic) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(module.getBody()); freeFunc = builder.create( builder.getUnknownLoc(), "free", LLVM::LLVMType::getFunctionTy(voidType, llvm::makeArrayRef(voidPtrType), /*isVarArg=*/false)); } // Initialize shared constants. Value zero = builder.create(loc, i1Type, builder.getBoolAttr(false)); unsigned unrankedMemrefPos = 0; for (unsigned i = 0, e = operands.size(); i < e; ++i) { Type type = origTypes[i]; if (!type.isa()) continue; Value allocationSize = sizes[unrankedMemrefPos++]; UnrankedMemRefDescriptor desc(operands[i]); // Allocate memory, copy, and free the source if necessary. Value memory = toDynamic ? builder.create(loc, mallocFunc, allocationSize) .getResult(0) : builder.create(loc, voidPtrType, allocationSize, /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); builder.create(loc, memory, source, allocationSize, zero); if (!toDynamic) builder.create(loc, freeFunc, source); // Create a new descriptor. The same descriptor can be returned multiple // times, attempting to modify its pointer can lead to memory leaks // (allocated twice and overwritten) or double frees (the caller does not // know if the descriptor points to the same memory). Type descriptorType = typeConverter.convertType(type); if (!descriptorType) return failure(); auto updatedDesc = UnrankedMemRefDescriptor::undef(builder, loc, descriptorType); Value rank = desc.rank(builder, loc); updatedDesc.setRank(builder, loc, rank); updatedDesc.setMemRefDescPtr(builder, loc, memory); operands[i] = updatedDesc; } return success(); } // A CallOp automatically promotes MemRefType to a sequence of alloca/store and // passes the pointer to the MemRef across function boundaries. template struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering; using Base = ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { typename CallOpType::Adaptor transformed(operands); auto callOp = cast(op); // Pack the result types into a struct. Type packedResult; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); if (numResults != 0) { if (!(packedResult = this->typeConverter.packFunctionResults(resultTypes))) return failure(); } auto promoted = this->typeConverter.promoteOperands( op->getLoc(), /*opOperands=*/op->getOperands(), operands, rewriter); auto newOp = rewriter.create(op->getLoc(), packedResult, promoted, op->getAttrs()); SmallVector results; if (numResults < 2) { // If < 2 results, packing did not do anything and we can just return. results.append(newOp.result_begin(), newOp.result_end()); } else { // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. 
      results.reserve(numResults);
      for (unsigned i = 0; i < numResults; ++i) {
        auto type =
            this->typeConverter.convertType(op->getResult(i).getType());
        results.push_back(rewriter.create<LLVM::ExtractValueOp>(
            op->getLoc(), type, newOp.getOperation()->getResult(0),
            rewriter.getI64ArrayAttr(i)));
      }
    }

    if (this->typeConverter.getOptions().useBarePtrCallConv) {
      // For the bare-ptr calling convention, promote memref results to
      // descriptors.
      assert(results.size() == resultTypes.size() &&
             "The number of arguments and types doesn't match");
      this->typeConverter.promoteBarePtrsToDescriptors(rewriter, op->getLoc(),
                                                       resultTypes, results);
    } else if (failed(copyUnrankedDescriptors(
                   rewriter, op->getLoc(), this->typeConverter, resultTypes,
                   results, /*toDynamic=*/false))) {
      return failure();
    }

    rewriter.replaceOp(op, results);
    return success();
  }
};

struct CallOpLowering : public CallOpInterfaceLowering<CallOp> {
  using Super::Super;
};

struct CallIndirectOpLowering
    : public CallOpInterfaceLowering<CallIndirectOp> {
  using Super::Super;
};

// A `dealloc` is converted into a call to `free` on the underlying data
// buffer. The memref descriptor being an SSA value, there is no need to clean
// it up in any way.
struct DeallocOpLowering : public ConvertOpToLLVMPattern<DeallocOp> {
  using ConvertOpToLLVMPattern<DeallocOp>::ConvertOpToLLVMPattern;

  explicit DeallocOpLowering(LLVMTypeConverter &converter)
      : ConvertOpToLLVMPattern<DeallocOp>(converter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    assert(operands.size() == 1 && "dealloc takes one operand");
    DeallocOp::Adaptor transformed(operands);

    // Insert the `free` declaration if it is not already present.
    auto freeFunc =
        op->getParentOfType<ModuleOp>().lookupSymbol<LLVM::LLVMFuncOp>("free");
    if (!freeFunc) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(
          op->getParentOfType<ModuleOp>().getBody());
      freeFunc = rewriter.create<LLVM::LLVMFuncOp>(
          rewriter.getUnknownLoc(), "free",
          LLVM::LLVMType::getFunctionTy(getVoidType(), getVoidPtrType(),
                                        /*isVarArg=*/false));
    }

    MemRefDescriptor memref(transformed.memref());
    Value casted = rewriter.create<LLVM::BitcastOp>(
        op->getLoc(), getVoidPtrType(),
        memref.allocatedPtr(rewriter, op->getLoc()));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted);
    return success();
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static LLVM::LLVMType
convertGlobalMemrefTypeToLLVM(MemRefType type,
                              LLVMTypeConverter &typeConverter) {
  // LLVM type for a global memref will be a multi-dimensional array. For
  // declarations or uninitialized global memrefs, we can potentially flatten
  // this to a 1D array. However, for global_memref's with an initial value,
  // we do not intend to flatten the ElementsAttribute when going from std ->
  // LLVM dialect, so the LLVM type needs to be a multi-dimensional array.
  LLVM::LLVMType elementType =
      unwrap(typeConverter.convertType(type.getElementType()));
  LLVM::LLVMType arrayTy = elementType;
  // Shape has the outermost dim at index 0, so need to walk it backwards.
  for (int64_t dim : llvm::reverse(type.getShape()))
    arrayTy = LLVM::LLVMType::getArrayTy(arrayTy, dim);
  return arrayTy;
}
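// Worked example (illustrative; type spelling is schematic): for
// memref<2x3xf32> the loop above builds the nested LLVM array type
// "array<2 x array<3 x float>>"; the outermost memref dimension becomes the
// outermost array dimension.

/// GlobalMemrefOp is lowered to a LLVM Global Variable.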
struct GlobalMemrefOpLowering : public ConvertOpToLLVMPattern<GlobalMemrefOp> {
  using ConvertOpToLLVMPattern<GlobalMemrefOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto global = cast<GlobalMemrefOp>(op);
    MemRefType type = global.type().cast<MemRefType>();
    if (!isSupportedMemRefType(type))
      return failure();

    LLVM::LLVMType arrayTy =
        convertGlobalMemrefTypeToLLVM(type, typeConverter);

    LLVM::Linkage linkage =
        global.isPublic() ? LLVM::Linkage::External : LLVM::Linkage::Private;

    Attribute initialValue = nullptr;
    if (!global.isExternal() && !global.isUninitialized()) {
      auto elementsAttr = global.initial_value()->cast<ElementsAttr>();
      initialValue = elementsAttr;

      // For scalar memrefs, the global variable created is of the element
      // type, so unpack the elements attribute to extract the value.
      if (type.getRank() == 0)
        initialValue = elementsAttr.getValue({});
    }

    rewriter.replaceOpWithNewOp<LLVM::GlobalOp>(
        op, arrayTy, global.constant(), linkage, global.sym_name(),
        initialValue, type.getMemorySpace());
    return success();
  }
};

/// GetGlobalMemrefOp is lowered into a Memref descriptor with the pointer to
/// the first element stashed into the descriptor. This reuses
/// `AllocLikeOpLowering` to reuse the Memref descriptor construction.
struct GetGlobalMemrefOpLowering : public AllocLikeOpLowering {
  GetGlobalMemrefOpLowering(LLVMTypeConverter &converter)
      : AllocLikeOpLowering(GetGlobalMemrefOp::getOperationName(), converter) {
  }

  /// Buffer "allocation" for get_global_memref op is getting the address of
  /// the global variable referenced.
  std::tuple<Value, Value>
  allocateBuffer(ConversionPatternRewriter &rewriter, Location loc,
                 Value sizeBytes, Operation *op) const override {
    auto getGlobalOp = cast<GetGlobalMemrefOp>(op);
    MemRefType type = getGlobalOp.result().getType().cast<MemRefType>();
    unsigned memSpace = type.getMemorySpace();

    LLVM::LLVMType arrayTy =
        convertGlobalMemrefTypeToLLVM(type, typeConverter);
    auto addressOf = rewriter.create<LLVM::AddressOfOp>(
        loc, arrayTy.getPointerTo(memSpace), getGlobalOp.name());

    // Get the address of the first element in the array by creating a GEP with
    // the address of the GV as the base, and (rank + 1) number of 0 indices.
    LLVM::LLVMType elementType =
        unwrap(typeConverter.convertType(type.getElementType()));
    LLVM::LLVMType elementPtrType = elementType.getPointerTo(memSpace);

    SmallVector<Value, 4> operands = {addressOf};
    operands.insert(operands.end(), type.getRank() + 1,
                    createIndexConstant(rewriter, loc, 0));
    auto gep = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, operands);

    // We do not expect the memref obtained using `get_global_memref` to be
    // ever deallocated. Set the allocated pointer to a known bad value to
    // help debug if that ever happens.
    auto intPtrType = getIntPtrType(memSpace);
    Value deadBeefConst = createIndexAttrConstant(rewriter, op->getLoc(),
                                                  intPtrType, 0xdeadbeef);
    auto deadBeefPtr =
        rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);

    // Both allocated and aligned pointers are the same. We could potentially
    // stash a nullptr for the allocated pointer since we do not expect any
    // dealloc.
    return std::make_tuple(deadBeefPtr, gep);
  }
};
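// Worked example (illustrative): for a global memref<2x3xf32>, the GEP above
// uses rank + 1 = 3 zero indices, i.e. (0, 0, 0), to step from a pointer to
// the nested "array<2 x array<3 x float>>" global down to a pointer to its
// first `float` element.

// A `rsqrt` is converted into `1 / sqrt`.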
struct RsqrtOpLowering : public ConvertOpToLLVMPattern<RsqrtOp> {
  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    RsqrtOp::Adaptor transformed(operands);
    auto operandType =
        transformed.operand().getType().dyn_cast<LLVM::LLVMType>();

    if (!operandType)
      return failure();

    auto loc = op->getLoc();
    auto resultType = *op->result_type_begin();
    auto floatType = getElementTypeOrSelf(resultType).cast<FloatType>();
    auto floatOne = rewriter.getFloatAttr(floatType, 1.0);

    if (!operandType.isArrayTy()) {
      LLVM::ConstantOp one;
      if (operandType.isVectorTy()) {
        one = rewriter.create<LLVM::ConstantOp>(
            loc, operandType,
            SplatElementsAttr::get(resultType.cast<ShapedType>(), floatOne));
      } else {
        one = rewriter.create<LLVM::ConstantOp>(loc, operandType, floatOne);
      }
      auto sqrt = rewriter.create<LLVM::SqrtOp>(loc, transformed.operand());
      rewriter.replaceOpWithNewOp<LLVM::FDivOp>(op, operandType, one, sqrt);
      return success();
    }

    auto vectorType = resultType.dyn_cast<VectorType>();
    if (!vectorType)
      return failure();

    return handleMultidimensionalVectors(
        op, operands, typeConverter,
        [&](LLVM::LLVMType llvmVectorTy, ValueRange operands) {
          auto splatAttr = SplatElementsAttr::get(
              mlir::VectorType::get({llvmVectorTy.getVectorNumElements()},
                                    floatType),
              floatOne);
          auto one =
              rewriter.create<LLVM::ConstantOp>(loc, llvmVectorTy, splatAttr);
          auto sqrt =
              rewriter.create<LLVM::SqrtOp>(loc, llvmVectorTy, operands[0]);
          return rewriter.create<LLVM::FDivOp>(loc, llvmVectorTy, one, sqrt);
        },
        rewriter);
  }
};

struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<MemRefCastOp> {
  using ConvertOpToLLVMPattern<MemRefCastOp>::ConvertOpToLLVMPattern;

  LogicalResult match(Operation *op) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    Type srcType = memRefCastOp.getOperand().getType();
    Type dstType = memRefCastOp.getType();

    // MemRefCastOp reduces to a bitcast in the ranked MemRef case and can be
    // used for type erasure. For now it must preserve the underlying element
    // type and requires the source and result types to have the same rank.
    // Therefore, perform a sanity check that the underlying structs are the
    // same. Once op semantics are relaxed we can revisit.
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return success(typeConverter.convertType(srcType) ==
                     typeConverter.convertType(dstType));

    // At least one of the operands has an unranked type.
    assert(srcType.isa<UnrankedMemRefType>() ||
           dstType.isa<UnrankedMemRefType>());

    // Unranked-to-unranked casts are disallowed.
    return !(srcType.isa<UnrankedMemRefType>() &&
             dstType.isa<UnrankedMemRefType>())
               ? success()
               : failure();
  }

  void rewrite(Operation *op, ArrayRef<Value> operands,
               ConversionPatternRewriter &rewriter) const override {
    auto memRefCastOp = cast<MemRefCastOp>(op);
    MemRefCastOp::Adaptor transformed(operands);

    auto srcType = memRefCastOp.getOperand().getType();
    auto dstType = memRefCastOp.getType();
    auto targetStructType = typeConverter.convertType(memRefCastOp.getType());
    auto loc = op->getLoc();

    // For the ranked/ranked case, just keep the original descriptor.
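    // E.g. a cast like memref<4xf32> to memref<?xf32> converts to the same
    // descriptor struct on both sides, so replacing the op with the source
    // descriptor suffices.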
    if (srcType.isa<MemRefType>() && dstType.isa<MemRefType>())
      return rewriter.replaceOp(op, {transformed.source()});

    if (srcType.isa<MemRefType>() && dstType.isa<UnrankedMemRefType>()) {
      // Casting ranked to unranked memref type:
      // Set the rank in the destination from the memref type.
      // Allocate space on the stack and copy the src memref descriptor.
      // Set the ptr in the destination to the stack space.
      auto srcMemRefType = srcType.cast<MemRefType>();
      int64_t rank = srcMemRefType.getRank();
      // ptr = AllocaOp sizeof(MemRefDescriptor)
      auto ptr = typeConverter.promoteOneMemRefDescriptor(
          loc, transformed.source(), rewriter);
      // voidptr = BitCastOp srcType* to void*
      auto voidPtr =
          rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr)
              .getResult();
      // rank = ConstantOp srcRank
      auto rankVal = rewriter.create<LLVM::ConstantOp>(
          loc, typeConverter.convertType(rewriter.getIntegerType(64)),
          rewriter.getI64IntegerAttr(rank));
      // undef = UndefOp
      UnrankedMemRefDescriptor memRefDesc =
          UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
      // d1 = InsertValueOp undef, rank, 0
      memRefDesc.setRank(rewriter, loc, rankVal);
      // d2 = InsertValueOp d1, voidptr, 1
      memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
      rewriter.replaceOp(op, (Value)memRefDesc);
    } else if (srcType.isa<UnrankedMemRefType>() &&
               dstType.isa<MemRefType>()) {
      // Casting from unranked type to ranked. The operation is assumed to be
      // doing a correct cast. If the destination type mismatches the unranked
      // type, the behavior is undefined.
      UnrankedMemRefDescriptor memRefDesc(transformed.source());
      // ptr = ExtractValueOp src, 1
      auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
      // castPtr = BitCastOp i8* to structTy*
      auto castPtr =
          rewriter
              .create<LLVM::BitcastOp>(
                  loc,
                  targetStructType.cast<LLVM::LLVMType>().getPointerTo(), ptr)
              .getResult();
      // struct = LoadOp castPtr
      auto loadOp = rewriter.create<LLVM::LoadOp>(loc, castPtr);
      rewriter.replaceOp(op, loadOp.getResult());
    } else {
      llvm_unreachable("Unsupported unranked memref to unranked memref cast");
    }
  }
};

/// Extracts allocated, aligned pointers and offset from a ranked or unranked
/// memref type. In the unranked case, the fields are extracted from the
/// underlying ranked descriptor.
static void extractPointersAndOffset(Location loc,
                                     ConversionPatternRewriter &rewriter,
                                     LLVMTypeConverter &typeConverter,
                                     Value originalOperand,
                                     Value convertedOperand,
                                     Value *allocatedPtr, Value *alignedPtr,
                                     Value *offset = nullptr) {
  Type operandType = originalOperand.getType();
  if (operandType.isa<MemRefType>()) {
    MemRefDescriptor desc(convertedOperand);
    *allocatedPtr = desc.allocatedPtr(rewriter, loc);
    *alignedPtr = desc.alignedPtr(rewriter, loc);
    if (offset != nullptr)
      *offset = desc.offset(rewriter, loc);
    return;
  }

  unsigned memorySpace =
      operandType.cast<UnrankedMemRefType>().getMemorySpace();
  Type elementType = operandType.cast<UnrankedMemRefType>().getElementType();
  LLVM::LLVMType llvmElementType =
      unwrap(typeConverter.convertType(elementType));
  LLVM::LLVMType elementPtrPtrType =
      llvmElementType.getPointerTo(memorySpace).getPointerTo();

  // Extract the pointer to the underlying ranked memref descriptor and cast
  // it to ElemType**.
UnrankedMemRefDescriptor unrankedDesc(convertedOperand); Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); *allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr( rewriter, loc, underlyingDescPtr, elementPtrPtrType); *alignedPtr = UnrankedMemRefDescriptor::alignedPtr( rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); if (offset != nullptr) { *offset = UnrankedMemRefDescriptor::offset( rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); } } struct MemRefReinterpretCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary()); Type srcType = castOp.source().getType(); Value descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, castOp, adaptor, &descriptor))) return failure(); rewriter.replaceOp(op, {descriptor}); return success(); } private: LogicalResult convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, Type srcType, MemRefReinterpretCastOp castOp, MemRefReinterpretCastOp::Adaptor adaptor, Value *descriptor) const { MemRefType targetMemRefType = castOp.getResult().getType().cast(); auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) .dyn_cast_or_null(); if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) return failure(); // Create descriptor. Location loc = castOp.getLoc(); auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); // Set allocated and aligned pointers. Value allocatedPtr, alignedPtr; extractPointersAndOffset(loc, rewriter, typeConverter, castOp.source(), adaptor.source(), &allocatedPtr, &alignedPtr); desc.setAllocatedPtr(rewriter, loc, allocatedPtr); desc.setAlignedPtr(rewriter, loc, alignedPtr); // Set offset. if (castOp.isDynamicOffset(0)) desc.setOffset(rewriter, loc, adaptor.offsets()[0]); else desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); // Set sizes and strides. unsigned dynSizeId = 0; unsigned dynStrideId = 0; for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { if (castOp.isDynamicSize(i)) desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]); else desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); if (castOp.isDynamicStride(i)) desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); else desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); } *descriptor = desc; return success(); } }; struct MemRefReshapeOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto reshapeOp = cast(op); MemRefReshapeOp::Adaptor adaptor(operands, op->getAttrDictionary()); Type srcType = reshapeOp.source().getType(); Value descriptor; if (failed(convertSourceMemRefToDescriptor(rewriter, srcType, reshapeOp, adaptor, &descriptor))) return failure(); rewriter.replaceOp(op, {descriptor}); return success(); } private: LogicalResult convertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, Type srcType, MemRefReshapeOp reshapeOp, MemRefReshapeOp::Adaptor adaptor, Value *descriptor) const { // Conversion for statically-known shape args is performed via // `memref_reinterpret_cast`. 
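    // Accordingly, the code below bails out when the shape operand has a
    // static shape (e.g. memref<3xindex>); only the fully dynamic
    // memref<?xindex> case is handled here.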
auto shapeMemRefType = reshapeOp.shape().getType().cast(); if (shapeMemRefType.hasStaticShape()) return failure(); // The shape is a rank-1 tensor with unknown length. Location loc = reshapeOp.getLoc(); MemRefDescriptor shapeDesc(adaptor.shape()); Value resultRank = shapeDesc.size(rewriter, loc, 0); // Extract address space and element type. auto targetType = reshapeOp.getResult().getType().cast(); unsigned addressSpace = targetType.getMemorySpace(); Type elementType = targetType.getElementType(); // Create the unranked memref descriptor that holds the ranked one. The // inner descriptor is allocated on stack. auto targetDesc = UnrankedMemRefDescriptor::undef( rewriter, loc, unwrap(typeConverter.convertType(targetType))); targetDesc.setRank(rewriter, loc, resultRank); SmallVector sizes; UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter, targetDesc, sizes); Value underlyingDescPtr = rewriter.create( loc, getVoidPtrType(), sizes.front(), llvm::None); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref. Value allocatedPtr, alignedPtr, offset; extractPointersAndOffset(loc, rewriter, typeConverter, reshapeOp.source(), adaptor.source(), &allocatedPtr, &alignedPtr, &offset); // Set pointers and offset. LLVM::LLVMType llvmElementType = unwrap(typeConverter.convertType(elementType)); LLVM::LLVMType elementPtrPtrType = llvmElementType.getPointerTo(addressSpace).getPointerTo(); UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr, elementPtrPtrType, allocatedPtr); UnrankedMemRefDescriptor::setAlignedPtr(rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType, alignedPtr); UnrankedMemRefDescriptor::setOffset(rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType, offset); // Use the offset pointer as base for further addressing. Copy over the new // shape and compute strides. For this, we create a loop from rank-1 to 0. Value targetSizesBase = UnrankedMemRefDescriptor::sizeBasePtr( rewriter, loc, typeConverter, underlyingDescPtr, elementPtrPtrType); Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( rewriter, loc, typeConverter, targetSizesBase, resultRank); Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); Value oneIndex = createIndexConstant(rewriter, loc, 1); Value resultRankMinusOne = rewriter.create(loc, resultRank, oneIndex); Block *initBlock = rewriter.getInsertionBlock(); LLVM::LLVMType indexType = typeConverter.getIndexType(); Block::iterator remainingOpsIt = std::next(rewriter.getInsertionPoint()); Block *condBlock = rewriter.createBlock(initBlock->getParent(), {}, {indexType, indexType}); // Iterate over the remaining ops in initBlock and move them to condBlock. BlockAndValueMapping map; for (auto it = remainingOpsIt, e = initBlock->end(); it != e; ++it) { rewriter.clone(*it, map); rewriter.eraseOp(&*it); } rewriter.setInsertionPointToEnd(initBlock); rewriter.create(loc, ValueRange({resultRankMinusOne, oneIndex}), condBlock); rewriter.setInsertionPointToStart(condBlock); Value indexArg = condBlock->getArgument(0); Value strideArg = condBlock->getArgument(1); Value zeroIndex = createIndexConstant(rewriter, loc, 0); Value pred = rewriter.create( loc, LLVM::LLVMType::getInt1Ty(rewriter.getContext()), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); Block *bodyBlock = rewriter.splitBlock(condBlock, rewriter.getInsertionPoint()); rewriter.setInsertionPointToStart(bodyBlock); // Copy size from shape to descriptor. 
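    // (Each iteration of the loop built here loads sizes[indexArg] from the
    // shape operand, stores it into the target descriptor, and accumulates
    // the stride as a running product, with indexArg counting down from
    // rank - 1 to 0.)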
LLVM::LLVMType llvmIndexPtrType = indexType.getPointerTo(); Value sizeLoadGep = rewriter.create( loc, llvmIndexPtrType, shapeOperandPtr, ValueRange{indexArg}); Value size = rewriter.create(loc, sizeLoadGep); UnrankedMemRefDescriptor::setSize(rewriter, loc, typeConverter, targetSizesBase, indexArg, size); // Write stride value and compute next one. UnrankedMemRefDescriptor::setStride(rewriter, loc, typeConverter, targetStridesBase, indexArg, strideArg); Value nextStride = rewriter.create(loc, strideArg, size); // Decrement loop counter and branch back. Value decrement = rewriter.create(loc, indexArg, oneIndex); rewriter.create(loc, ValueRange({decrement, nextStride}), condBlock); Block *remainder = rewriter.splitBlock(bodyBlock, rewriter.getInsertionPoint()); // Hook up the cond exit to the remainder. rewriter.setInsertionPointToEnd(condBlock); rewriter.create(loc, pred, bodyBlock, llvm::None, remainder, llvm::None); // Reset position to beginning of new remainder block. rewriter.setInsertionPointToStart(remainder); *descriptor = targetDesc; return success(); } }; struct DialectCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != typeConverter.convertType(castOp.getType())) { return failure(); } rewriter.replaceOp(op, transformed.in()); return success(); } }; // A `dim` is converted to a constant for static sizes and to an access to the // size stored in the memref descriptor for dynamic sizes. struct DimOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); Type operandType = dimOp.memrefOrTensor().getType(); if (operandType.isa()) { rewriter.replaceOp(op, {extractSizeOfUnrankedMemRef(operandType, dimOp, operands, rewriter)}); return success(); } if (operandType.isa()) { rewriter.replaceOp(op, {extractSizeOfRankedMemRef(operandType, dimOp, operands, rewriter)}); return success(); } return failure(); } private: Value extractSizeOfUnrankedMemRef(Type operandType, DimOp dimOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { Location loc = dimOp.getLoc(); DimOp::Adaptor transformed(operands); auto unrankedMemRefType = operandType.cast(); auto scalarMemRefType = MemRefType::get({}, unrankedMemRefType.getElementType()); unsigned addressSpace = unrankedMemRefType.getMemorySpace(); // Extract pointer to the underlying ranked descriptor and bitcast it to a // memref descriptor pointer to minimize the number of GEP // operations. UnrankedMemRefDescriptor unrankedDesc(transformed.memrefOrTensor()); Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc); Value scalarMemRefDescPtr = rewriter.create( loc, typeConverter.convertType(scalarMemRefType) .cast() .getPointerTo(addressSpace), underlyingRankedDesc); // Get pointer to offset field of memref descriptor. 
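    // The converted descriptor is the struct {allocatedPtr, alignedPtr,
    // offset, sizes, strides}, so the offset lives at struct index 2; the
    // constant `two` below feeds the GEP accordingly.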
    Type indexPtrTy = typeConverter.getIndexType().getPointerTo(addressSpace);
    Value two = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(rewriter.getI32Type()),
        rewriter.getI32IntegerAttr(2));
    Value offsetPtr = rewriter.create<LLVM::GEPOp>(
        loc, indexPtrTy, scalarMemRefDescPtr,
        ValueRange({createIndexConstant(rewriter, loc, 0), two}));

    // The size value that we have to extract can be obtained using a GEP with
    // `dimOp.index() + 1` as the index argument.
    Value idxPlusOne = rewriter.create<LLVM::AddOp>(
        loc, createIndexConstant(rewriter, loc, 1), transformed.index());
    Value sizePtr = rewriter.create<LLVM::GEPOp>(loc, indexPtrTy, offsetPtr,
                                                 ValueRange({idxPlusOne}));
    return rewriter.create<LLVM::LoadOp>(loc, sizePtr);
  }

  Value extractSizeOfRankedMemRef(Type operandType, DimOp dimOp,
                                  ArrayRef<Value> operands,
                                  ConversionPatternRewriter &rewriter) const {
    Location loc = dimOp.getLoc();
    DimOp::Adaptor transformed(operands);
    // Take advantage of a constant index, if any.
    MemRefType memRefType = operandType.cast<MemRefType>();
    if (Optional<int64_t> index = dimOp.getConstantIndex()) {
      int64_t i = index.getValue();
      if (memRefType.isDynamicDim(i)) {
        // Extract the dynamic size from the memref descriptor.
        MemRefDescriptor descriptor(transformed.memrefOrTensor());
        return descriptor.size(rewriter, loc, i);
      }
      // Use a constant for the static size.
      int64_t dimSize = memRefType.getDimSize(i);
      return createIndexConstant(rewriter, loc, dimSize);
    }
    Value index = dimOp.index();
    int64_t rank = memRefType.getRank();
    MemRefDescriptor memrefDescriptor(transformed.memrefOrTensor());
    return memrefDescriptor.size(rewriter, loc, index, rank);
  }
};

struct RankOpLowering : public ConvertOpToLLVMPattern<RankOp> {
  using ConvertOpToLLVMPattern<RankOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    Type operandType = cast<RankOp>(op).memrefOrTensor().getType();
    if (auto unrankedMemRefType = operandType.dyn_cast<UnrankedMemRefType>()) {
      UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor());
      rewriter.replaceOp(op, {desc.rank(rewriter, loc)});
      return success();
    }
    if (auto rankedMemRefType = operandType.dyn_cast<MemRefType>()) {
      rewriter.replaceOp(
          op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())});
      return success();
    }
    return failure();
  }
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isSupportedMemRefType;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Operation *op) const override {
    MemRefType type = cast<Derived>(op).getMemRefType();
    return isSupportedMemRefType(type) ? success() : failure();
  }
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
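// For example (a sketch, descriptor type abbreviated):
//   %0 = load %m[%i] : memref<?xf32>
// becomes an address computation off the aligned pointer followed by
//   %0 = llvm.load %addr : !llvm<"float*">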
struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); LoadOp::Adaptor transformed(operands); auto type = loadOp.getMemRefType(); - Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = + getStridedElementPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, dataPtr); return success(); } }; // Store operation is lowered to obtaining a pointer to the indexed element, // and storing the given value to it. struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); StoreOp::Adaptor transformed(operands); - Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = + getStridedElementPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); return success(); } }; // The prefetch operation is lowered in a way similar to the load operation // except that the llvm.prefetch operation is used for replacement. struct PrefetchOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); PrefetchOp::Adaptor transformed(operands); auto type = prefetchOp.getMemRefType(); - Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), - transformed.indices(), rewriter); + Value dataPtr = + getStridedElementPtr(op->getLoc(), type, transformed.memref(), + transformed.indices(), rewriter); // Replace with llvm.prefetch. auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); auto isWrite = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isWrite())); auto localityHint = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.localityHint())); auto isData = rewriter.create( op->getLoc(), llvmI32Type, rewriter.getI32IntegerAttr(prefetchOp.isDataCache())); rewriter.replaceOpWithNewOp(op, dataPtr, isWrite, localityHint, isData); return success(); } }; // The lowering of index_cast becomes an integer conversion since index becomes // an integer. If the bit width of the source and target integer types is the // same, just erase the cast. If the target type is wider, sign-extend the // value, otherwise truncate it. 
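// E.g. with a 64-bit index type (a sketch):
//   index_cast %i : index to i32 lowers to llvm.trunc,
//   index_cast %j : i32 to index lowers to llvm.sext.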
struct IndexCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { IndexCastOpAdaptor transformed(operands); auto indexCastOp = cast(op); auto targetType = this->typeConverter.convertType(indexCastOp.getResult().getType()) .cast(); auto sourceType = transformed.in().getType().cast(); unsigned targetBits = targetType.getIntegerBitWidth(); unsigned sourceBits = sourceType.getIntegerBitWidth(); if (targetBits == sourceBits) rewriter.replaceOp(op, transformed.in()); else if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); else rewriter.replaceOpWithNewOp(op, targetType, transformed.in()); return success(); } }; // Convert std.cmp predicate into the LLVM dialect CmpPredicate. The two // enums share the numerical values so just cast. template static LLVMPredType convertCmpPredicate(StdPredType pred) { return static_cast(pred); } struct CmpIOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpiOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpiOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return success(); } }; struct CmpFOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); CmpFOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpfOp.getResult().getType()), rewriter.getI64IntegerAttr(static_cast( convertCmpPredicate(cmpfOp.getPredicate()))), transformed.lhs(), transformed.rhs()); return success(); } }; struct SIToFPLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct UIToFPLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPExtLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPToSILowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPToUILowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct FPTruncLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct SignExtendIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct TruncateIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; struct ZeroExtendIOpLowering : public OneToOneConvertToLLVMPattern { using Super::Super; }; // Base class for LLVM IR lowering terminator operations with successors. template struct OneToOneLLVMTerminatorLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; using Super = OneToOneLLVMTerminatorLowering; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), op->getAttrs()); return success(); } }; // Special lowering pattern for `ReturnOps`. Unlike all other operations, // `ReturnOp` interacts with the function signature and must have as many // operands as the function has return values. 
// Because in LLVM IR, functions can only return 0 or 1 value, we pack multiple
// values into a structure type. Emit `UndefOp` followed by `InsertValueOp`s to
// create such a structure if necessary before returning it.
struct ReturnOpLowering : public ConvertOpToLLVMPattern<ReturnOp> {
  using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    unsigned numArguments = op->getNumOperands();
    SmallVector<Value, 4> updatedOperands;

    if (typeConverter.getOptions().useBarePtrCallConv) {
      // For the bare-ptr calling convention, extract the aligned pointer to
      // be returned from the memref descriptor.
      for (auto it : llvm::zip(op->getOperands(), operands)) {
        Type oldTy = std::get<0>(it).getType();
        Value newOperand = std::get<1>(it);
        if (oldTy.isa<MemRefType>()) {
          MemRefDescriptor memrefDesc(newOperand);
          newOperand = memrefDesc.alignedPtr(rewriter, loc);
        } else if (oldTy.isa<UnrankedMemRefType>()) {
          // Unranked memref is not supported in the bare pointer calling
          // convention.
          return failure();
        }
        updatedOperands.push_back(newOperand);
      }
    } else {
      updatedOperands = llvm::to_vector<4>(operands);
      copyUnrankedDescriptors(rewriter, loc, typeConverter,
                              op->getOperands().getTypes(), updatedOperands,
                              /*toDynamic=*/true);
    }

    // If ReturnOp has 0 or 1 operand, create it and return immediately.
    if (numArguments == 0) {
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
                                                  op->getAttrs());
      return success();
    }
    if (numArguments == 1) {
      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
          op, TypeRange(), updatedOperands, op->getAttrs());
      return success();
    }

    // Otherwise, we need to pack the arguments into an LLVM struct type before
    // returning.
    auto packedType = typeConverter.packFunctionResults(
        llvm::to_vector<4>(op->getOperandTypes()));

    Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
    for (unsigned i = 0; i < numArguments; ++i) {
      packed = rewriter.create<LLVM::InsertValueOp>(
          loc, packedType, packed, updatedOperands[i],
          rewriter.getI64ArrayAttr(i));
    }
    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
                                                op->getAttrs());
    return success();
  }
};

// FIXME: this should be tablegen'ed as well.
struct BranchOpLowering
    : public OneToOneLLVMTerminatorLowering<BranchOp, LLVM::BrOp> {
  using Super::Super;
};
struct CondBranchOpLowering
    : public OneToOneLLVMTerminatorLowering<CondBranchOp, LLVM::CondBrOp> {
  using Super::Super;
};

// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Only splats of 1-D vector result types are lowered here.
struct SplatOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto splatOp = cast<SplatOp>(op);
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() != 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto vectorType = typeConverter.convertType(splatOp.getType());
    Value undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        op->getLoc(), typeConverter.convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));

    auto v = rewriter.create<LLVM::InsertElementOp>(
        op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);

    int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
    SmallVector<int32_t, 4> zeroValues(width, 0);

    // Shuffle the value across the desired number of elements.
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
    return success();
  }
};

// The Splat operation is lowered to an insertelement + a shufflevector
// operation. Splats of 2-D and higher-rank vector result types are lowered by
// SplatNdOpLowering; the 1-D case is handled by SplatOpLowering.
struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
  using ConvertOpToLLVMPattern<SplatOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto splatOp = cast<SplatOp>(op);
    SplatOp::Adaptor adaptor(operands);
    VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
    if (!resultType || resultType.getRank() == 1)
      return failure();

    // First insert it into an undef vector so we can shuffle it.
    auto loc = op->getLoc();
    auto vectorTypeInfo = extractNDVectorTypeInfo(resultType, typeConverter);
    auto llvmArrayTy = vectorTypeInfo.llvmArrayTy;
    auto llvmVectorTy = vectorTypeInfo.llvmVectorTy;
    if (!llvmArrayTy || !llvmVectorTy)
      return failure();

    // Construct returned value.
    Value desc = rewriter.create<LLVM::UndefOp>(loc, llvmArrayTy);

    // Construct a 1-D vector with the splatted value that we insert in all the
    // places within the returned descriptor.
    Value vdesc = rewriter.create<LLVM::UndefOp>(loc, llvmVectorTy);
    auto zero = rewriter.create<LLVM::ConstantOp>(
        loc, typeConverter.convertType(rewriter.getIntegerType(32)),
        rewriter.getZeroAttr(rewriter.getIntegerType(32)));
    Value v = rewriter.create<LLVM::InsertElementOp>(loc, llvmVectorTy, vdesc,
                                                     adaptor.input(), zero);

    // Shuffle the value across the desired number of elements.
    int64_t width = resultType.getDimSize(resultType.getRank() - 1);
    SmallVector<int32_t, 4> zeroValues(width, 0);
    ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
    v = rewriter.create<LLVM::ShuffleVectorOp>(loc, v, v, zeroAttrs);

    // Iterate over the linear index, convert it to the coordinate space, and
    // insert the splatted 1-D vector at each position.
    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayAttr position) {
      desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc, v,
                                                  position);
    });
    rewriter.replaceOp(op, desc);
    return success();
  }
};

/// Helper that extracts the int64_t values from an ArrayAttr assumed to hold
/// IntegerAttrs.
static SmallVector<int64_t, 4> extractFromI64ArrayAttr(Attribute attr) {
  return llvm::to_vector<4>(
      llvm::map_range(attr.cast<ArrayAttr>(), [](Attribute a) -> int64_t {
        return a.cast<IntegerAttr>().getInt();
      }));
}

/// Conversion pattern that transforms a subview op into:
///   1. An `llvm.mlir.undef` operation to create a memref descriptor
///   2. Updates to the descriptor to introduce the data ptr, offset, size
///      and stride.
/// The subview op is replaced by the descriptor.
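/// For instance (a sketch; the layout map in the result type is elided), a
/// subview
///   %1 = subview %0[%o][%sz][%st] : memref<?xf32> to memref<?xf32, ...>
/// fills a fresh descriptor whose offset is srcOffset + %o * srcStride, and
/// whose size and stride come from %sz and %st.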
struct SubViewOpLowering : public ConvertOpToLLVMPattern<SubViewOp> {
  using ConvertOpToLLVMPattern<SubViewOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    auto subViewOp = cast<SubViewOp>(op);

    auto sourceMemRefType = subViewOp.source().getType().cast<MemRefType>();
    auto sourceElementTy =
        typeConverter.convertType(sourceMemRefType.getElementType())
            .dyn_cast_or_null<LLVM::LLVMType>();

    auto viewMemRefType = subViewOp.getType();
    auto inferredType =
        SubViewOp::inferResultType(
            subViewOp.getSourceType(),
            extractFromI64ArrayAttr(subViewOp.static_offsets()),
            extractFromI64ArrayAttr(subViewOp.static_sizes()),
            extractFromI64ArrayAttr(subViewOp.static_strides()))
            .cast<MemRefType>();
    auto targetElementTy =
        typeConverter.convertType(viewMemRefType.getElementType())
            .dyn_cast<LLVM::LLVMType>();
    auto targetDescTy = typeConverter.convertType(viewMemRefType)
                            .dyn_cast_or_null<LLVM::LLVMType>();
    if (!sourceElementTy || !targetDescTy)
      return failure();

    // Extract the offset and strides from the type.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    auto successStrides = getStridesAndOffset(inferredType, strides, offset);
    if (failed(successStrides))
      return failure();

    // Create the descriptor.
    if (!operands.front().getType().isa<LLVM::LLVMType>())
      return failure();
    MemRefDescriptor sourceMemRef(operands.front());
    auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy);

    // Copy the buffer pointer from the old descriptor to the new one.
    Value extracted = sourceMemRef.allocatedPtr(rewriter, loc);
    Value bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
        extracted);
    targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);

    // Copy the aligned pointer from the old descriptor to the new one.
    extracted = sourceMemRef.alignedPtr(rewriter, loc);
    bitcastPtr = rewriter.create<LLVM::BitcastOp>(
        loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()),
        extracted);
    targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);

    auto shape = viewMemRefType.getShape();
    auto inferredShape = inferredType.getShape();
    size_t inferredShapeRank = inferredShape.size();
    size_t resultShapeRank = shape.size();
    SmallVector<bool, 4> mask =
        computeRankReductionMask(inferredShape, shape).getValue();

    // Extract strides needed to compute offset.
    SmallVector<Value, 4> strideValues;
    strideValues.reserve(inferredShapeRank);
    for (unsigned i = 0; i < inferredShapeRank; ++i)
      strideValues.push_back(sourceMemRef.stride(rewriter, loc, i));

    // Offset.
    auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType());
    if (!ShapedType::isDynamicStrideOrOffset(offset)) {
      targetMemRef.setConstantOffset(rewriter, loc, offset);
    } else {
      Value baseOffset = sourceMemRef.offset(rewriter, loc);
      for (unsigned i = 0; i < inferredShapeRank; ++i) {
        Value offset =
            subViewOp.isDynamicOffset(i)
                ? operands[subViewOp.getIndexOfDynamicOffset(i)]
                : rewriter.create<LLVM::ConstantOp>(
                      loc, llvmIndexType,
                      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
        Value mul = rewriter.create<LLVM::MulOp>(loc, offset, strideValues[i]);
        baseOffset = rewriter.create<LLVM::AddOp>(loc, baseOffset, mul);
      }
      targetMemRef.setOffset(rewriter, loc, baseOffset);
    }

    // Update sizes and strides.
    for (int i = inferredShapeRank - 1, j = resultShapeRank - 1;
         i >= 0 && j >= 0; --i) {
      if (!mask[i])
        continue;

      Value size = subViewOp.isDynamicSize(i) ?
operands[subViewOp.getIndexOfDynamicSize(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); targetMemRef.setSize(rewriter, loc, j, size); Value stride; if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { stride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); } else { stride = subViewOp.isDynamicStride(i) ? operands[subViewOp.getIndexOfDynamicStride(i)] : rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i))); stride = rewriter.create(loc, stride, strideValues[i]); } targetMemRef.setStride(rewriter, loc, j, stride); j--; } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms a transpose op into: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. /// 3. Updates to the ViewDescriptor to introduce the data ptr, offset, size /// and stride. Size and stride are permutations of the original values. /// 4. A store of the resulting ViewDescriptor to the alloca'ed pointer. /// The transpose op is replaced by the alloca'ed pointer. class TransposeOpLowering : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); TransposeOpAdaptor adaptor(operands); MemRefDescriptor viewMemRef(adaptor.in()); auto transposeOp = cast(op); // No permutation, early exit. if (transposeOp.permutation().isIdentity()) return rewriter.replaceOp(op, {viewMemRef}), success(); auto targetMemRef = MemRefDescriptor::undef( rewriter, loc, typeConverter.convertType(transposeOp.getShapedType())); // Copy the base and aligned pointers from the old descriptor to the new // one. targetMemRef.setAllocatedPtr(rewriter, loc, viewMemRef.allocatedPtr(rewriter, loc)); targetMemRef.setAlignedPtr(rewriter, loc, viewMemRef.alignedPtr(rewriter, loc)); // Copy the offset pointer from the old descriptor to the new one. targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc)); // Iterate over the dimensions and apply size/stride permutation. for (auto en : llvm::enumerate(transposeOp.permutation().getResults())) { int sourcePos = en.index(); int targetPos = en.value().cast().getPosition(); targetMemRef.setSize(rewriter, loc, targetPos, viewMemRef.size(rewriter, loc, sourcePos)); targetMemRef.setStride(rewriter, loc, targetPos, viewMemRef.stride(rewriter, loc, sourcePos)); } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; /// Conversion pattern that transforms an op into: /// 1. An `llvm.mlir.undef` operation to create a memref descriptor /// 2. Updates to the descriptor to introduce the data ptr, offset, size /// and stride. /// The view op is replaced by the descriptor. struct ViewOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. 
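  // E.g. (a sketch) for shape [2, ?, 4, ?] and idx == 3, one dynamic
  // dimension precedes idx, so the result is dynamicSizes[1].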
Value getSize(ConversionPatternRewriter &rewriter, Location loc, ArrayRef shape, ValueRange dynamicSizes, unsigned idx) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) return createIndexConstant(rewriter, loc, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = llvm::count_if(shape.take_front(idx), [](int64_t v) { return ShapedType::isDynamic(v); }); return dynamicSizes[nDynamic]; } // Build and return the idx^th stride, either by returning the constant stride // or by computing the dynamic stride from the current `runningStride` and // `nextSize`. The caller should keep a running stride and update it with the // result returned by this function. Value getStride(ConversionPatternRewriter &rewriter, Location loc, ArrayRef strides, Value nextSize, Value runningStride, unsigned idx) const { assert(idx < strides.size()); if (strides[idx] != MemRefType::getDynamicStrideOrOffset()) return createIndexConstant(rewriter, loc, strides[idx]); if (nextSize) return runningStride ? rewriter.create(loc, runningStride, nextSize) : nextSize; assert(!runningStride); return createIndexConstant(rewriter, loc, 1); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); ViewOpAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); auto targetDescTy = typeConverter.convertType(viewMemRefType).dyn_cast(); if (!targetDescTy) return op->emitWarning("Target descriptor type not converted to LLVM"), failure(); int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) return op->emitWarning("cannot cast to non-strided shape"), failure(); assert(offset == 0 && "expected offset to be 0"); // Create the descriptor. MemRefDescriptor sourceMemRef(adaptor.source()); auto targetMemRef = MemRefDescriptor::undef(rewriter, loc, targetDescTy); // Field 1: Copy the allocated pointer, used for malloc/free. Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc); auto srcMemRefType = viewOp.source().getType().cast(); Value bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), allocatedPtr); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); // Field 2: Copy the actual aligned pointer to payload. Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc); alignedPtr = rewriter.create(loc, alignedPtr.getType(), alignedPtr, adaptor.byte_shift()); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(srcMemRefType.getMemorySpace()), alignedPtr); targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); // Field 3: The offset in the resulting type must be 0. This is because of // the type change: an offset on srcType* may not be expressible as an // offset on dstType*. targetMemRef.setOffset(rewriter, loc, createIndexConstant(rewriter, loc, offset)); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) return rewriter.replaceOp(op, {targetMemRef}), success(); // Fields 4 and 5: Update sizes and strides. if (strides.back() != 1) return op->emitWarning("cannot cast to non-contiguous shape"), failure(); Value stride = nullptr, nextSize = nullptr; for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. 
Value size = getSize(rewriter, loc, viewMemRefType.getShape(), adaptor.sizes(), i); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. stride = getStride(rewriter, loc, strides, nextSize, stride, i); targetMemRef.setStride(rewriter, loc, i, stride); nextSize = size; } rewriter.replaceOp(op, {targetMemRef}); return success(); } }; struct AssumeAlignmentOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { AssumeAlignmentOp::Adaptor transformed(operands); Value memref = transformed.memref(); unsigned alignment = cast(op).alignment(); MemRefDescriptor memRefDescriptor(memref); Value ptr = memRefDescriptor.alignedPtr(rewriter, memref.getLoc()); // Emit llvm.assume(memref.alignedPtr & (alignment - 1) == 0). Notice that // the asserted memref.alignedPtr isn't used anywhere else, as the real // users like load/store/views always re-extract memref.alignedPtr as they // get lowered. // // This relies on LLVM's CSE optimization (potentially after SROA), since // after CSE all memref.alignedPtr instances get de-duplicated into the same // pointer SSA value. auto intPtrType = getIntPtrType(memRefDescriptor.getElementPtrType().getAddressSpace()); Value zero = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0); Value mask = createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, alignment - 1); Value ptrValue = rewriter.create(op->getLoc(), intPtrType, ptr); rewriter.create( op->getLoc(), rewriter.create( op->getLoc(), LLVM::ICmpPredicate::eq, rewriter.create(op->getLoc(), ptrValue, mask), zero)); rewriter.eraseOp(op); return success(); } }; } // namespace /// Try to match the kind of a std.atomic_rmw to determine whether to use a /// lowering to llvm.atomicrmw or fallback to llvm.cmpxchg. static Optional matchSimpleAtomicOp(AtomicRMWOp atomicOp) { switch (atomicOp.kind()) { case AtomicRMWKind::addf: return LLVM::AtomicBinOp::fadd; case AtomicRMWKind::addi: return LLVM::AtomicBinOp::add; case AtomicRMWKind::assign: return LLVM::AtomicBinOp::xchg; case AtomicRMWKind::maxs: return LLVM::AtomicBinOp::max; case AtomicRMWKind::maxu: return LLVM::AtomicBinOp::umax; case AtomicRMWKind::mins: return LLVM::AtomicBinOp::min; case AtomicRMWKind::minu: return LLVM::AtomicBinOp::umin; default: return llvm::None; } llvm_unreachable("Invalid AtomicRMWKind"); } namespace { struct AtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); - auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + auto dataPtr = + getStridedElementPtr(op->getLoc(), memRefType, adaptor.memref(), + adaptor.indices(), rewriter); rewriter.replaceOpWithNewOp( op, resultType, *maybeKind, dataPtr, adaptor.value(), LLVM::AtomicOrdering::acq_rel); return success(); } }; /// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be /// retried until it succeeds in atomically storing a new value into memory. 
/// /// +---------------------------------+ /// | | /// | | /// | br loop(%loaded) | /// +---------------------------------+ /// | /// -------| | /// | v v /// | +--------------------------------+ /// | | loop(%loaded): | /// | | | /// | | %pair = cmpxchg | /// | | %ok = %pair[0] | /// | | %new = %pair[1] | /// | | cond_br %ok, end, loop(%new) | /// | +--------------------------------+ /// | | | /// |----------- | /// v /// +--------------------------------+ /// | end: | /// | | /// +--------------------------------+ /// struct GenericAtomicRMWOpLowering : public LoadStoreOpLowering { using Base::Base; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto atomicOp = cast(op); auto loc = op->getLoc(); GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = typeConverter.convertType(atomicOp.getResult().getType()) .cast(); // Split the block into initial, loop, and ending parts. auto *initBlock = rewriter.getInsertionBlock(); auto *loopBlock = rewriter.createBlock(initBlock->getParent(), std::next(Region::iterator(initBlock)), valueType); auto *endBlock = rewriter.createBlock( loopBlock->getParent(), std::next(Region::iterator(loopBlock))); // Operations range to be moved to `endBlock`. auto opsToMoveStart = atomicOp.getOperation()->getIterator(); auto opsToMoveEnd = initBlock->back().getIterator(); // Compute the loaded value and branch to the loop block. rewriter.setInsertionPointToEnd(initBlock); auto memRefType = atomicOp.memref().getType().cast(); - auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); Value init = rewriter.create(loc, dataPtr); rewriter.create(loc, init, loopBlock); // Prepare the body of the loop block. rewriter.setInsertionPointToStart(loopBlock); // Clone the GenericAtomicRMWOp region and extract the result. auto loopArgument = loopBlock->getArgument(0); BlockAndValueMapping mapping; mapping.map(atomicOp.getCurrentValue(), loopArgument); Block &entryBlock = atomicOp.body().front(); for (auto &nestedOp : entryBlock.without_terminator()) { Operation *clone = rewriter.clone(nestedOp, mapping); mapping.map(nestedOp.getResults(), clone->getResults()); } Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0)); // Prepare the epilog of the loop block. // Append the cmpxchg op to the end of the loop block. auto successOrdering = LLVM::AtomicOrdering::acq_rel; auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto boolType = LLVM::LLVMType::getInt1Ty(rewriter.getContext()); auto pairType = LLVM::LLVMType::getStructTy(valueType, boolType); auto cmpxchg = rewriter.create( loc, pairType, dataPtr, loopArgument, result, successOrdering, failureOrdering); // Extract the %new_loaded and %ok values from the pair. Value newLoaded = rewriter.create( loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0})); Value ok = rewriter.create( loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1})); // Conditionally branch to the end or back to the loop depending on %ok. rewriter.create(loc, ok, endBlock, ArrayRef(), loopBlock, newLoaded); rewriter.setInsertionPointToEnd(endBlock); moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart), std::next(opsToMoveEnd), rewriter); // The 'result' of the atomic_rmw op is the newly loaded value. 
rewriter.replaceOp(op, {newLoaded}); return success(); } private: // Clones a segment of ops [start, end) and erases the original. void moveOpsRange(ValueRange oldResult, ValueRange newResult, Block::iterator start, Block::iterator end, ConversionPatternRewriter &rewriter) const { BlockAndValueMapping mapping; mapping.map(oldResult, newResult); SmallVector opsToErase; for (auto it = start; it != end; ++it) { rewriter.clone(*it, mapping); opsToErase.push_back(&*it); } for (auto *it : opsToErase) rewriter.eraseOp(it); } }; } // namespace /// Collect a set of patterns to convert from the Standard dialect to LLVM. void mlir::populateStdToLLVMNonMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // FIXME: this should be tablegen'ed // clang-format off patterns.insert< AbsFOpLowering, AddCFOpLowering, AddFOpLowering, AddIOpLowering, AllocaOpLowering, AndOpLowering, AssertOpLowering, AtomicRMWOpLowering, BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CeilFOpLowering, CmpFOpLowering, CmpIOpLowering, CondBranchOpLowering, CopySignOpLowering, CosOpLowering, ConstantOpLowering, CreateComplexOpLowering, DialectCastOpLowering, DivFOpLowering, ExpOpLowering, Exp2OpLowering, FloorFOpLowering, GenericAtomicRMWOpLowering, LogOpLowering, Log10OpLowering, Log2OpLowering, FPExtLowering, FPToSILowering, FPToUILowering, FPTruncLowering, ImOpLowering, IndexCastOpLowering, MulFOpLowering, MulIOpLowering, NegFOpLowering, OrOpLowering, PrefetchOpLowering, ReOpLowering, RemFOpLowering, ReturnOpLowering, RsqrtOpLowering, SIToFPLowering, SelectOpLowering, ShiftLeftOpLowering, SignExtendIOpLowering, SignedDivIOpLowering, SignedRemIOpLowering, SignedShiftRightOpLowering, SinOpLowering, SplatOpLowering, SplatNdOpLowering, SqrtOpLowering, SubCFOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering, UIToFPLowering, UnsignedDivIOpLowering, UnsignedRemIOpLowering, UnsignedShiftRightOpLowering, XOrOpLowering, ZeroExtendIOpLowering>(converter); // clang-format on } void mlir::populateStdToLLVMMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, DeallocOpLowering, DimOpLowering, GlobalMemrefOpLowering, GetGlobalMemrefOpLowering, LoadOpLowering, MemRefCastOpLowering, MemRefReinterpretCastOpLowering, MemRefReshapeOpLowering, RankOpLowering, StoreOpLowering, SubViewOpLowering, TransposeOpLowering, ViewOpLowering>(converter); // clang-format on if (converter.getOptions().useAlignedAlloc) patterns.insert(converter); else patterns.insert(converter); } void mlir::populateStdToLLVMFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { if (converter.getOptions().useBarePtrCallConv) patterns.insert(converter); else patterns.insert(converter); } void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { populateStdToLLVMFuncOpConversionPattern(converter, patterns); populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); populateStdToLLVMMemoryConversionPatterns(converter, patterns); } /// Convert a non-empty list of types to be returned from a function into a /// supported LLVM IR type. In particular, if more than one value is returned, /// create an LLVM IR structure type with elements that correspond to each of /// the MLIR types converted with `convertType`. 
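/// For example (a sketch): packing (i32, f64) yields !llvm<"{ i32, double }">,
/// while a single result type is converted directly, without wrapping.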
Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); if (types.size() == 1) return convertCallingConventionType(types.front()); SmallVector resultTypes; resultTypes.reserve(types.size()); for (auto t : types) { auto converted = convertCallingConventionType(t).dyn_cast_or_null(); if (!converted) return {}; resultTypes.push_back(converted); } return LLVM::LLVMType::getStructTy(&getContext(), resultTypes); } Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, OpBuilder &builder) { auto *context = builder.getContext(); auto int64Ty = LLVM::LLVMType::getInt64Ty(builder.getContext()); auto indexType = IndexType::get(context); // Alloca with proper alignment. We do not expect optimizations of this // alloca op and so we omit allocating at the entry block. auto ptrType = operand.getType().cast().getPointerTo(); Value one = builder.create(loc, int64Ty, IntegerAttr::get(indexType, 1)); Value allocated = builder.create(loc, ptrType, one, /*alignment=*/0); // Store into the alloca'ed descriptor. builder.create(loc, operand, allocated); return allocated; } SmallVector LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, ValueRange operands, OpBuilder &builder) { SmallVector promotedOperands; promotedOperands.reserve(operands.size()); for (auto it : llvm::zip(opOperands, operands)) { auto operand = std::get<0>(it); auto llvmOperand = std::get<1>(it); if (options.useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor desc(llvmOperand); llvmOperand = desc.alignedPtr(builder, loc); } else if (operand.getType().isa()) { llvm_unreachable("Unranked memrefs are not supported"); } } else { if (operand.getType().isa()) { UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, promotedOperands); continue; } if (auto memrefType = operand.getType().dyn_cast()) { MemRefDescriptor::unpack(builder, loc, llvmOperand, operand.getType().cast(), promotedOperands); continue; } } promotedOperands.push_back(llvmOperand); } return promotedOperands; } namespace { /// A pass converting MLIR operations into the LLVM IR dialect. struct LLVMLoweringPass : public ConvertStandardToLLVMBase { LLVMLoweringPass() = default; LLVMLoweringPass(bool useBarePtrCallConv, bool emitCWrappers, unsigned indexBitwidth, bool useAlignedAlloc, const llvm::DataLayout &dataLayout) { this->useBarePtrCallConv = useBarePtrCallConv; this->emitCWrappers = emitCWrappers; this->indexBitwidth = indexBitwidth; this->useAlignedAlloc = useAlignedAlloc; this->dataLayout = dataLayout.getStringRepresentation(); } /// Run the dialect converter on the module. 
void runOnOperation() override { if (useBarePtrCallConv && emitCWrappers) { getOperation().emitError() << "incompatible conversion options: bare-pointer calling convention " "and C wrapper emission"; signalPassFailure(); return; } if (failed(LLVM::LLVMDialect::verifyDataLayoutString( this->dataLayout, [this](const Twine &message) { getOperation().emitError() << message.str(); }))) { signalPassFailure(); return; } ModuleOp m = getOperation(); LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers, indexBitwidth, useAlignedAlloc, llvm::DataLayout(this->dataLayout)}; LLVMTypeConverter typeConverter(&getContext(), options); OwningRewritePatternList patterns; populateStdToLLVMConversionPatterns(typeConverter, patterns); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), StringAttr::get(this->dataLayout, m.getContext())); } }; } // end namespace mlir::LLVMConversionTarget::LLVMConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) { this->addLegalDialect(); this->addIllegalOp(); this->addIllegalOp(); } std::unique_ptr> mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) { return std::make_unique( options.useBarePtrCallConv, options.emitCWrappers, options.indexBitwidth, options.useAlignedAlloc, options.dataLayout); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 6fe801fd0d94..42c66bcab9ab 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1,1643 +1,1643 @@ //===- VectorToLLVM.cpp - Conversion from Vector to the LLVM dialect ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "../PassDetail.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Target/LLVMIR/TypeTranslation.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/ErrorHandling.h" using namespace mlir; using namespace mlir::vector; // Helper to reduce vector type by one rank at front. static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); return VectorType::get(tp.getShape().drop_front(), tp.getElementType()); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); return VectorType::get(tp.getShape().take_back(), tp.getElementType()); } // Helper that picks the proper sequence for inserting. 
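// For rank 1 the variant below emits llvm.insertelement with a constant index;
// for higher ranks it emits llvm.insertvalue with the position as an
// I64ArrayAttr.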
static Value insertOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val1, Value val2, Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val1, val2, constant); } return rewriter.create(loc, llvmType, val1, val2, rewriter.getI64ArrayAttr(pos)); } // Helper that picks the proper sequence for inserting. static Value insertOne(PatternRewriter &rewriter, Location loc, Value from, Value into, int64_t offset) { auto vectorType = into.getType().cast(); if (vectorType.getRank() > 1) return rewriter.create(loc, from, into, offset); return rewriter.create( loc, vectorType, from, into, rewriter.create(loc, offset)); } // Helper that picks the proper sequence for extracting. static Value extractOne(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, Value val, Type llvmType, int64_t rank, int64_t pos) { if (rank == 1) { auto idxType = rewriter.getIndexType(); auto constant = rewriter.create( loc, typeConverter.convertType(idxType), rewriter.getIntegerAttr(idxType, pos)); return rewriter.create(loc, llvmType, val, constant); } return rewriter.create(loc, llvmType, val, rewriter.getI64ArrayAttr(pos)); } // Helper that picks the proper sequence for extracting. static Value extractOne(PatternRewriter &rewriter, Location loc, Value vector, int64_t offset) { auto vectorType = vector.getType().cast(); if (vectorType.getRank() > 1) return rewriter.create(loc, vector, offset); return rewriter.create( loc, vectorType.getElementType(), vector, rewriter.create(loc, offset)); } // Helper that returns a subset of `arrayAttr` as a vector of int64_t. // TODO: Better support for attribute subtype forwarding + slicing. static SmallVector getI64SubArray(ArrayAttr arrayAttr, unsigned dropFront = 0, unsigned dropBack = 0) { assert(arrayAttr.size() > dropFront + dropBack && "Out of bounds"); auto range = arrayAttr.getAsRange(); SmallVector res; res.reserve(arrayAttr.size() - dropFront - dropBack); for (auto it = range.begin() + dropFront, eit = range.end() - dropBack; it != eit; ++it) res.push_back((*it).getValue().getSExtValue()); return res; } // Helper that returns a vector comparison that constructs a mask: // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] // // NOTE: The LLVM::GetActiveLaneMaskOp intrinsic would provide an alternative, // much more compact, IR for this operation, but LLVM eventually // generates more elaborate instructions for this intrinsic since it // is very conservative on the boundary conditions. static Value buildVectorComparison(ConversionPatternRewriter &rewriter, Operation *op, bool enableIndexOptimizations, int64_t dim, Value b, Value *off = nullptr) { auto loc = op->getLoc(); // If we can assume all indices fit in 32-bit, we perform the vector // comparison in 32-bit to get a higher degree of SIMD parallelism. // Otherwise we perform the vector comparison using 64-bit indices. Value indices; Type idxType; if (enableIndexOptimizations) { indices = rewriter.create( loc, rewriter.getI32VectorAttr( llvm::to_vector<4>(llvm::seq(0, dim)))); idxType = rewriter.getI32Type(); } else { indices = rewriter.create( loc, rewriter.getI64VectorAttr( llvm::to_vector<4>(llvm::seq(0, dim)))); idxType = rewriter.getI64Type(); } // Add in an offset if requested. 
if (off) { Value o = rewriter.create(loc, idxType, *off); Value ov = rewriter.create(loc, indices.getType(), o); indices = rewriter.create(loc, ov, indices); } // Construct the vector comparison. Value bound = rewriter.create(loc, idxType, b); Value bounds = rewriter.create(loc, indices.getType(), bound); return rewriter.create(loc, CmpIPredicate::slt, indices, bounds); } // Helper that returns data layout alignment of an operation with memref. template LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter, T op, unsigned &align) { Type elementTy = typeConverter.convertType(op.getMemRefType().getElementType()); if (!elementTy) return failure(); // TODO: this should use the MLIR data layout when it becomes available and // stop depending on translation. llvm::LLVMContext llvmContext; align = LLVM::TypeToLLVMIRTranslator(llvmContext) .getPreferredAlignment(elementTy.cast(), typeConverter.getDataLayout()); return success(); } // Helper that returns the base address of a memref. static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc, Value memref, MemRefType memRefType, Value &base) { // Inspect stride and offset structure. // // TODO: flat memory only for now, generalize // int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(memRefType, strides, offset); if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 || offset != 0 || memRefType.getMemorySpace() != 0) return failure(); base = MemRefDescriptor(memref).alignedPtr(rewriter, loc); return success(); } // Helper that returns a pointer given a memref base. static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc, Value memref, MemRefType memRefType, Value &ptr) { Value base; if (failed(getBase(rewriter, loc, memref, memRefType, base))) return failure(); auto pType = MemRefDescriptor(memref).getElementPtrType(); ptr = rewriter.create(loc, pType, base); return success(); } // Helper that returns a bit-casted pointer given a memref base. static LogicalResult getBasePtr(ConversionPatternRewriter &rewriter, Location loc, Value memref, MemRefType memRefType, Type type, Value &ptr) { Value base; if (failed(getBase(rewriter, loc, memref, memRefType, base))) return failure(); auto pType = type.template cast().getPointerTo(); base = rewriter.create(loc, pType, base); ptr = rewriter.create(loc, pType, base); return success(); } // Helper that returns vector of pointers given a memref base and an index // vector. 
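// For illustration (a sketch; the exact type syntax is assumed): given a base
// !llvm.ptr<float> and one index per lane, a single llvm.getelementptr
// produces a vector of per-lane pointers, ready for the masked
// gather/scatter intrinsics:
//   %ptrs = llvm.getelementptr %base[%indices]
//       : (!llvm.ptr<float>, !llvm.vec<8 x i32>) -> !llvm.vec<8 x ptr<float>>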
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter, Location loc, Value memref, Value indices, MemRefType memRefType, VectorType vType, Type iType, Value &ptrs) { Value base; if (failed(getBase(rewriter, loc, memref, memRefType, base))) return failure(); auto pType = MemRefDescriptor(memref).getElementPtrType(); auto ptrsType = LLVM::LLVMType::getVectorTy(pType, vType.getDimSize(0)); ptrs = rewriter.create(loc, ptrsType, base, indices); return success(); } static LogicalResult replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, ArrayRef operands, Value dataPtr) { unsigned align; if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); rewriter.replaceOpWithNewOp(xferOp, dataPtr, align); return success(); } static LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, ArrayRef operands, Value dataPtr, Value mask) { auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); }; VectorType fillType = xferOp.getVectorType(); Value fill = rewriter.create(loc, fillType, xferOp.padding()); fill = rewriter.create(loc, toLLVMTy(fillType), fill); Type vecTy = typeConverter.convertType(xferOp.getVectorType()); if (!vecTy) return failure(); unsigned align; if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); rewriter.replaceOpWithNewOp( xferOp, vecTy, dataPtr, mask, ValueRange{fill}, rewriter.getI32IntegerAttr(align)); return success(); } static LogicalResult replaceTransferOpWithLoadOrStore(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, ArrayRef operands, Value dataPtr) { unsigned align; if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr, align); return success(); } static LogicalResult replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, ArrayRef operands, Value dataPtr, Value mask) { unsigned align; if (failed(getMemRefAlignment(typeConverter, xferOp, align))) return failure(); auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp( xferOp, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); return success(); } static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, ArrayRef operands) { return TransferReadOpAdaptor(operands); } static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef operands) { return TransferWriteOpAdaptor(operands); } namespace { /// Conversion pattern for a vector.matrix_multiply. /// This is lowered directly to the proper llvm.intr.matrix.multiply. 
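/// Example (illustrative, adapted from the op's documentation):
/// ```
/// %C = vector.matrix_multiply %A, %B
///   { lhs_rows = 4: i32, lhs_columns = 16: i32, rhs_columns = 3: i32 }
///   : (vector<64xf64>, vector<48xf64>) -> vector<12xf64>
/// ```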
class VectorMatmulOpConversion : public ConvertToLLVMPattern { public: explicit VectorMatmulOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto matmulOp = cast(op); auto adaptor = vector::MatmulOpAdaptor(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), matmulOp.rhs_columns()); return success(); } }; /// Conversion pattern for a vector.flat_transpose. /// This is lowered directly to the proper llvm.intr.matrix.transpose. class VectorFlatTransposeOpConversion : public ConvertToLLVMPattern { public: explicit VectorFlatTransposeOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::FlatTransposeOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto transOp = cast(op); auto adaptor = vector::FlatTransposeOpAdaptor(operands); rewriter.replaceOpWithNewOp( transOp, typeConverter.convertType(transOp.res().getType()), adaptor.matrix(), transOp.rows(), transOp.columns()); return success(); } }; /// Conversion pattern for a vector.maskedload. class VectorMaskedLoadOpConversion : public ConvertToLLVMPattern { public: explicit VectorMaskedLoadOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::MaskedLoadOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto load = cast(op); auto adaptor = vector::MaskedLoadOpAdaptor(operands); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(typeConverter, load, align))) return failure(); auto vtype = typeConverter.convertType(load.getResultVectorType()); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), load.getMemRefType(), vtype, ptr))) return failure(); rewriter.replaceOpWithNewOp( load, vtype, ptr, adaptor.mask(), adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); return success(); } }; /// Conversion pattern for a vector.maskedstore. class VectorMaskedStoreOpConversion : public ConvertToLLVMPattern { public: explicit VectorMaskedStoreOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::MaskedStoreOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto store = cast(op); auto adaptor = vector::MaskedStoreOpAdaptor(operands); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(typeConverter, store, align))) return failure(); auto vtype = typeConverter.convertType(store.getValueVectorType()); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), store.getMemRefType(), vtype, ptr))) return failure(); rewriter.replaceOpWithNewOp( store, adaptor.value(), ptr, adaptor.mask(), rewriter.getI32IntegerAttr(align)); return success(); } }; /// Conversion pattern for a vector.gather. 
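/// Example (an illustrative sketch; the operand syntax is assumed):
/// ```
/// %g = vector.gather %base, %indices, %mask, %pass_thru
///    : (memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>)
///      -> vector<16xf32>
/// ```
/// lowers to a single llvm.intr.masked.gather on the vector of per-lane
/// pointers computed by getIndexedPtrs above.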
class VectorGatherOpConversion : public ConvertToLLVMPattern { public: explicit VectorGatherOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::GatherOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto gather = cast(op); auto adaptor = vector::GatherOpAdaptor(operands); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(typeConverter, gather, align))) return failure(); // Get index ptrs. VectorType vType = gather.getResultVectorType(); Type iType = gather.getIndicesVectorType().getElementType(); Value ptrs; if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), gather.getMemRefType(), vType, iType, ptrs))) return failure(); // Replace with the gather intrinsic. rewriter.replaceOpWithNewOp( gather, typeConverter.convertType(vType), ptrs, adaptor.mask(), adaptor.pass_thru(), rewriter.getI32IntegerAttr(align)); return success(); } }; /// Conversion pattern for a vector.scatter. class VectorScatterOpConversion : public ConvertToLLVMPattern { public: explicit VectorScatterOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::ScatterOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto scatter = cast(op); auto adaptor = vector::ScatterOpAdaptor(operands); // Resolve alignment. unsigned align; if (failed(getMemRefAlignment(typeConverter, scatter, align))) return failure(); // Get index ptrs. VectorType vType = scatter.getValueVectorType(); Type iType = scatter.getIndicesVectorType().getElementType(); Value ptrs; if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(), scatter.getMemRefType(), vType, iType, ptrs))) return failure(); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp( scatter, adaptor.value(), ptrs, adaptor.mask(), rewriter.getI32IntegerAttr(align)); return success(); } }; /// Conversion pattern for a vector.expandload. class VectorExpandLoadOpConversion : public ConvertToLLVMPattern { public: explicit VectorExpandLoadOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::ExpandLoadOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto expand = cast(op); auto adaptor = vector::ExpandLoadOpAdaptor(operands); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), expand.getMemRefType(), ptr))) return failure(); auto vType = expand.getResultVectorType(); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(vType), ptr, adaptor.mask(), adaptor.pass_thru()); return success(); } }; /// Conversion pattern for a vector.compressstore. 
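/// Example (an illustrative sketch; the operand syntax is assumed):
/// ```
/// vector.compressstore %base, %mask, %value
///   : memref<?xf32>, vector<8xi1>, vector<8xf32>
/// ```
/// lowers to llvm.intr.masked.compressstore on the base pointer.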
class VectorCompressStoreOpConversion : public ConvertToLLVMPattern { public: explicit VectorCompressStoreOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::CompressStoreOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto compress = cast(op); auto adaptor = vector::CompressStoreOpAdaptor(operands); Value ptr; if (failed(getBasePtr(rewriter, loc, adaptor.base(), compress.getMemRefType(), ptr))) return failure(); rewriter.replaceOpWithNewOp( op, adaptor.value(), ptr, adaptor.mask()); return success(); } }; /// Conversion pattern for all vector reductions. class VectorReductionOpConversion : public ConvertToLLVMPattern { public: explicit VectorReductionOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter, bool reassociateFPRed) : ConvertToLLVMPattern(vector::ReductionOp::getOperationName(), context, typeConverter), reassociateFPReductions(reassociateFPRed) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto reductionOp = cast(op); auto kind = reductionOp.kind(); Type eltType = reductionOp.dest().getType(); Type llvmType = typeConverter.convertType(eltType); if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. if (kind == "add") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "mul") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "min" && (eltType.isIndex() || eltType.isUnsignedInteger())) rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "min") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "max" && (eltType.isIndex() || eltType.isUnsignedInteger())) rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "max") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "and") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "or") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "xor") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else return failure(); return success(); } else if (eltType.isa()) { // Floating-point reductions: add/mul/min/max if (kind == "add") { // Optional accumulator (or zero). Value acc = operands.size() > 1 ? operands[1] : rewriter.create( op->getLoc(), llvmType, rewriter.getZeroAttr(eltType)); rewriter.replaceOpWithNewOp( op, llvmType, acc, operands[0], rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "mul") { // Optional accumulator (or one). Value acc = operands.size() > 1 ? operands[1] : rewriter.create( op->getLoc(), llvmType, rewriter.getFloatAttr(eltType, 1.0)); rewriter.replaceOpWithNewOp( op, llvmType, acc, operands[0], rewriter.getBoolAttr(reassociateFPReductions)); } else if (kind == "min") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else if (kind == "max") rewriter.replaceOpWithNewOp( op, llvmType, operands[0]); else return failure(); return success(); } return failure(); } private: const bool reassociateFPReductions; }; /// Conversion pattern for a vector.create_mask (1-D only). 
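/// Example (illustrative): `vector.create_mask %b : vector<4xi1>` becomes the
/// comparison [0, 1, 2, 3] < [%b, %b, %b, %b] built by buildVectorComparison
/// above, yielding a vector<4xi1> mask whose first %b lanes are set.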
class VectorCreateMaskOpConversion : public ConvertToLLVMPattern { public: explicit VectorCreateMaskOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter, bool enableIndexOpt) : ConvertToLLVMPattern(vector::CreateMaskOp::getOperationName(), context, typeConverter), enableIndexOptimizations(enableIndexOpt) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dstType = op->getResult(0).getType().cast(); int64_t rank = dstType.getRank(); if (rank == 1) { rewriter.replaceOp( op, buildVectorComparison(rewriter, op, enableIndexOptimizations, dstType.getDimSize(0), operands[0])); return success(); } return failure(); } private: const bool enableIndexOptimizations; }; class VectorShuffleOpConversion : public ConvertToLLVMPattern { public: explicit VectorShuffleOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::ShuffleOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ShuffleOpAdaptor(operands); auto shuffleOp = cast(op); auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); auto vectorType = shuffleOp.getVectorType(); Type llvmType = typeConverter.convertType(vectorType); auto maskArrayAttr = shuffleOp.mask(); // Bail if result type cannot be lowered. if (!llvmType) return failure(); // Get rank and dimension sizes. int64_t rank = vectorType.getRank(); assert(v1Type.getRank() == rank); assert(v2Type.getRank() == rank); int64_t v1Dim = v1Type.getDimSize(0); // For rank 1, where both operands have *exactly* the same vector type, // there is direct shuffle support in LLVM. Use it! if (rank == 1 && v1Type == v2Type) { Value shuffle = rewriter.create( loc, adaptor.v1(), adaptor.v2(), maskArrayAttr); rewriter.replaceOp(op, shuffle); return success(); } // For all other cases, insert the individual values individually. Value insert = rewriter.create(loc, llvmType); int64_t insPos = 0; for (auto en : llvm::enumerate(maskArrayAttr)) { int64_t extPos = en.value().cast().getInt(); Value value = adaptor.v1(); if (extPos >= v1Dim) { extPos -= v1Dim; value = adaptor.v2(); } Value extract = extractOne(rewriter, typeConverter, loc, value, llvmType, rank, extPos); insert = insertOne(rewriter, typeConverter, loc, insert, extract, llvmType, rank, insPos++); } rewriter.replaceOp(op, insert); return success(); } }; class VectorExtractElementOpConversion : public ConvertToLLVMPattern { public: explicit VectorExtractElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::ExtractElementOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::ExtractElementOpAdaptor(operands); auto extractEltOp = cast(op); auto vectorType = extractEltOp.getVectorType(); auto llvmType = typeConverter.convertType(vectorType.getElementType()); // Bail if result type cannot be lowered. 
if (!llvmType) return failure(); rewriter.replaceOpWithNewOp( op, llvmType, adaptor.vector(), adaptor.position()); return success(); } }; class VectorExtractOpConversion : public ConvertToLLVMPattern { public: explicit VectorExtractOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::ExtractOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::ExtractOpAdaptor(operands); auto extractOp = cast(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); auto llvmResultType = typeConverter.convertType(resultType); auto positionArrayAttr = extractOp.position(); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // One-shot extraction of vector from array (only requires extractvalue). if (resultType.isa()) { Value extracted = rewriter.create( loc, llvmResultType, adaptor.vector(), positionArrayAttr); rewriter.replaceOp(op, extracted); return success(); } // Potential extraction of 1-D vector from array. auto *context = op->getContext(); Value extracted = adaptor.vector(); auto positionAttrs = positionArrayAttr.getValue(); if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( loc, typeConverter.convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } // Remaining extraction of element from 1-D LLVM vector auto position = positionAttrs.back().cast(); auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); auto constant = rewriter.create(loc, i64Type, position); extracted = rewriter.create(loc, extracted, constant); rewriter.replaceOp(op, extracted); return success(); } }; /// Conversion pattern that turns a vector.fma on a 1-D vector /// into an llvm.intr.fmuladd. This is a trivial 1-1 conversion. /// This does not match vectors of n >= 2 rank. 
/// /// Example: /// ``` /// vector.fma %a, %a, %a : vector<8xf32> /// ``` /// is converted to: /// ``` /// llvm.intr.fmuladd %va, %va, %va: /// (!llvm<"<8 x float>">, !llvm<"<8 x float>">, !llvm<"<8 x float>">) /// -> !llvm<"<8 x float>"> /// ``` class VectorFMAOp1DConversion : public ConvertToLLVMPattern { public: explicit VectorFMAOp1DConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::FMAOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::FMAOpAdaptor(operands); vector::FMAOp fmaOp = cast(op); VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) return failure(); rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); } }; class VectorInsertElementOpConversion : public ConvertToLLVMPattern { public: explicit VectorInsertElementOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::InsertElementOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto adaptor = vector::InsertElementOpAdaptor(operands); auto insertEltOp = cast(op); auto vectorType = insertEltOp.getDestVectorType(); auto llvmType = typeConverter.convertType(vectorType); // Bail if result type cannot be lowered. if (!llvmType) return failure(); rewriter.replaceOpWithNewOp( op, llvmType, adaptor.dest(), adaptor.source(), adaptor.position()); return success(); } }; class VectorInsertOpConversion : public ConvertToLLVMPattern { public: explicit VectorInsertOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::InsertOp::getOperationName(), context, typeConverter) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto adaptor = vector::InsertOpAdaptor(operands); auto insertOp = cast(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); auto llvmResultType = typeConverter.convertType(destVectorType); auto positionArrayAttr = insertOp.position(); // Bail if result type cannot be lowered. if (!llvmResultType) return failure(); // One-shot insertion of a vector into an array (only requires insertvalue). if (sourceType.isa()) { Value inserted = rewriter.create( loc, llvmResultType, adaptor.dest(), adaptor.source(), positionArrayAttr); rewriter.replaceOp(op, inserted); return success(); } // Potential extraction of 1-D vector from array. auto *context = op->getContext(); Value extracted = adaptor.dest(); auto positionAttrs = positionArrayAttr.getValue(); auto position = positionAttrs.back().cast(); auto oneDVectorType = destVectorType; if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); extracted = rewriter.create( loc, typeConverter.convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); } // Insertion of an element into a 1-D LLVM vector. 
auto i64Type = LLVM::LLVMType::getInt64Ty(rewriter.getContext()); auto constant = rewriter.create(loc, i64Type, position); Value inserted = rewriter.create( loc, typeConverter.convertType(oneDVectorType), extracted, adaptor.source(), constant); // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = ArrayAttr::get(positionAttrs.drop_back(), context); inserted = rewriter.create(loc, llvmResultType, adaptor.dest(), inserted, nMinusOnePositionAttrs); } rewriter.replaceOp(op, inserted); return success(); } }; /// Rank reducing rewrite for n-D FMA into (n-1)-D FMA where n > 1. /// /// Example: /// ``` /// %d = vector.fma %a, %b, %c : vector<2x4xf32> /// ``` /// is rewritten into: /// ``` /// %r = splat %f0: vector<2x4xf32> /// %va = vector.extractvalue %a[0] : vector<2x4xf32> /// %vb = vector.extractvalue %b[0] : vector<2x4xf32> /// %vc = vector.extractvalue %c[0] : vector<2x4xf32> /// %vd = vector.fma %va, %vb, %vc : vector<4xf32> /// %r2 = vector.insertvalue %vd, %r[0] : vector<4xf32> into vector<2x4xf32> /// %va2 = vector.extractvalue %a2[1] : vector<2x4xf32> /// %vb2 = vector.extractvalue %b2[1] : vector<2x4xf32> /// %vc2 = vector.extractvalue %c2[1] : vector<2x4xf32> /// %vd2 = vector.fma %va2, %vb2, %vc2 : vector<4xf32> /// %r3 = vector.insertvalue %vd2, %r2[1] : vector<4xf32> into vector<2x4xf32> /// // %r3 holds the final value. /// ``` class VectorFMAOpNDRewritePattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(FMAOp op, PatternRewriter &rewriter) const override { auto vType = op.getVectorType(); if (vType.getRank() < 2) return failure(); auto loc = op.getLoc(); auto elemType = vType.getElementType(); Value zero = rewriter.create(loc, elemType, rewriter.getZeroAttr(elemType)); Value desc = rewriter.create(loc, vType, zero); for (int64_t i = 0, e = vType.getShape().front(); i != e; ++i) { Value extrLHS = rewriter.create(loc, op.lhs(), i); Value extrRHS = rewriter.create(loc, op.rhs(), i); Value extrACC = rewriter.create(loc, op.acc(), i); Value fma = rewriter.create(loc, extrLHS, extrRHS, extrACC); desc = rewriter.create(loc, fma, desc, i); } rewriter.replaceOp(op, desc); return success(); } }; // When ranks are different, InsertStridedSlice needs to extract a properly // ranked vector from the destination vector into which to insert. This pattern // only takes care of this part and forwards the rest of the conversion to // another pattern that converts InsertStridedSlice for operands of the same // rank. // // RewritePattern for InsertStridedSliceOp where source and destination vectors // have different ranks. In this case: // 1. the proper subvector is extracted from the destination vector // 2. a new InsertStridedSlice op is created to insert the source in the // destination subvector // 3. the destination subvector is inserted back in the proper place // 4. the op is replaced by the result of step 3. // The new InsertStridedSlice from step 2. will be picked up by a // `VectorInsertStridedSliceOpSameRankRewritePattern`. 
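// Example (illustrative): inserting a vector<4xf32> into a vector<2x4xf32>
// at offsets [1, 0] first extracts the rank-1 subvector at offset [1] from
// the destination, creates a same-rank vector.insert_strided_slice into it,
// and re-inserts the result at [1]; the inner op is then handled by the
// same-rank pattern below.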
class VectorInsertStridedSliceOpDifferentRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    auto loc = op.getLoc();
    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff == 0)
      return failure();

    int64_t rankRest = dstType.getRank() - rankDiff;
    // Extract / insert the subvector of matching rank and InsertStridedSlice
    // on it.
    Value extracted = rewriter.create<ExtractOp>(
        loc, op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    // A different pattern will kick in for InsertStridedSlice with matching
    // ranks.
    auto stridedSliceInnerOp = rewriter.create<InsertStridedSliceOp>(
        loc, op.source(), extracted,
        getI64SubArray(op.offsets(), /*dropFront=*/rankDiff),
        getI64SubArray(op.strides(), /*dropFront=*/0));
    rewriter.replaceOpWithNewOp<InsertOp>(
        op, stridedSliceInnerOp.getResult(), op.dest(),
        getI64SubArray(op.offsets(), /*dropFront=*/0,
                       /*dropBack=*/rankRest));
    return success();
  }
};

// RewritePattern for InsertStridedSliceOp where source and destination vectors
// have the same rank. For each slice of the source vector along the most major
// dimension:
//   1. the proper subvector (or element) is extracted from the source vector
//   2. if it is a subvector, the problem is reduced to a new, lower-rank
//      InsertStridedSlice op on the corresponding destination subvector
//   3. the result is inserted into the proper position of the result vector
// The lower-rank InsertStridedSlice ops created in step 2. are picked up again
// by this same pattern; the recursion is bounded as the rank is strictly
// decreasing.
class VectorInsertStridedSliceOpSameRankRewritePattern
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  VectorInsertStridedSliceOpSameRankRewritePattern(MLIRContext *ctx)
      : OpRewritePattern<InsertStridedSliceOp>(ctx) {
    // This pattern creates recursive InsertStridedSliceOp, but the recursion
    // is bounded as the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    auto srcType = op.getSourceVectorType();
    auto dstType = op.getDestVectorType();

    if (op.offsets().getValue().empty())
      return failure();

    int64_t rankDiff = dstType.getRank() - srcType.getRank();
    assert(rankDiff >= 0);
    if (rankDiff != 0)
      return failure();

    if (srcType == dstType) {
      rewriter.replaceOp(op, op.source());
      return success();
    }

    int64_t offset =
        op.offsets().getValue().front().cast<IntegerAttr>().getInt();
    int64_t size = srcType.getShape().front();
    int64_t stride =
        op.strides().getValue().front().cast<IntegerAttr>().getInt();

    auto loc = op.getLoc();
    Value res = op.dest();
    // For each slice of the source vector along the most major dimension.
    for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
         off += stride, ++idx) {
      // 1. extract the proper subvector (or element) from source
      Value extractedSource = extractOne(rewriter, loc, op.source(), idx);
      if (extractedSource.getType().isa<VectorType>()) {
        // 2. If we have a vector, extract the proper subvector from
        //    destination. Otherwise we are at the element level and there is
        //    no need to recurse.
        Value extractedDest = extractOne(rewriter, loc, op.dest(), off);
        // 3. Reduce the problem to lowering a new InsertStridedSlice op with
        //    smaller rank.
        extractedSource = rewriter.create<InsertStridedSliceOp>(
            loc, extractedSource, extractedDest,
            getI64SubArray(op.offsets(), /*dropFront=*/1),
            getI64SubArray(op.strides(), /*dropFront=*/1));
      }
      // 4. Insert the extractedSource into the res vector.
      res = insertOne(rewriter, loc, extractedSource, res, off);
    }

    rewriter.replaceOp(op, res);
    return success();
  }
};

/// Returns the strides if the memory underlying `memRefType` has a contiguous
/// static layout.
static llvm::Optional<SmallVector<int64_t, 4>>
computeContiguousStrides(MemRefType memRefType) {
  int64_t offset;
  SmallVector<int64_t, 4> strides;
  if (failed(getStridesAndOffset(memRefType, strides, offset)))
    return None;
  if (!strides.empty() && strides.back() != 1)
    return None;
  // If no layout or identity layout, this is contiguous by definition.
  if (memRefType.getAffineMaps().empty() ||
      memRefType.getAffineMaps().front().isIdentity())
    return strides;

  // Otherwise, we must determine contiguity from the shapes. This can only
  // ever work in static cases because MemRefType is underspecified to
  // represent contiguous dynamic shapes in other ways than with just an
  // empty/identity layout.
  auto sizes = memRefType.getShape();
  for (int index = 0, e = strides.size() - 2; index < e; ++index) {
    if (ShapedType::isDynamic(sizes[index + 1]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index]) ||
        ShapedType::isDynamicStrideOrOffset(strides[index + 1]))
      return None;
    if (strides[index] != strides[index + 1] * sizes[index + 1])
      return None;
  }
  return strides;
}

class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTypeCastOpConversion(MLIRContext *context,
                                      LLVMTypeConverter &typeConverter)
      : ConvertToLLVMPattern(vector::TypeCastOp::getOperationName(), context,
                             typeConverter) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    vector::TypeCastOp castOp = cast<vector::TypeCastOp>(op);
    MemRefType sourceMemRefType =
        castOp.getOperand().getType().cast<MemRefType>();
    MemRefType targetMemRefType =
        castOp.getResult().getType().cast<MemRefType>();

    // Only static shape casts supported atm.
    if (!sourceMemRefType.hasStaticShape() ||
        !targetMemRefType.hasStaticShape())
      return failure();

    auto llvmSourceDescriptorTy =
        operands[0].getType().dyn_cast<LLVM::LLVMType>();
    if (!llvmSourceDescriptorTy || !llvmSourceDescriptorTy.isStructTy())
      return failure();
    MemRefDescriptor sourceMemRef(operands[0]);

    auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType)
                                      .dyn_cast_or_null<LLVM::LLVMType>();
    if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy())
      return failure();

    // Only contiguous source buffers supported atm.
    auto sourceStrides = computeContiguousStrides(sourceMemRefType);
    if (!sourceStrides)
      return failure();
    auto targetStrides = computeContiguousStrides(targetMemRefType);
    if (!targetStrides)
      return failure();
    // Only support static strides for now, regardless of contiguity.
    if (llvm::any_of(*targetStrides, [](int64_t stride) {
          return ShapedType::isDynamicStrideOrOffset(stride);
        }))
      return failure();

    auto int64Ty = LLVM::LLVMType::getInt64Ty(rewriter.getContext());

    // Create descriptor.
    auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy);
    Type llvmTargetElementTy = desc.getElementPtrType();
    // Set allocated ptr.
    Value allocated = sourceMemRef.allocatedPtr(rewriter, loc);
    allocated =
        rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, allocated);
    desc.setAllocatedPtr(rewriter, loc, allocated);
    // Set aligned ptr.
    Value ptr = sourceMemRef.alignedPtr(rewriter, loc);
    ptr = rewriter.create<LLVM::BitcastOp>(loc, llvmTargetElementTy, ptr);
    desc.setAlignedPtr(rewriter, loc, ptr);
    // Fill offset 0.
    auto attr = rewriter.getIntegerAttr(rewriter.getIndexType(), 0);
    auto zero = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, attr);
    desc.setOffset(rewriter, loc, zero);

    // Fill size and stride descriptors in memref.
    for (auto indexedSize : llvm::enumerate(targetMemRefType.getShape())) {
      int64_t index = indexedSize.index();
      auto sizeAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                              indexedSize.value());
      auto size = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, sizeAttr);
      desc.setSize(rewriter, loc, index, size);
      auto strideAttr = rewriter.getIntegerAttr(rewriter.getIndexType(),
                                                (*targetStrides)[index]);
      auto stride = rewriter.create<LLVM::ConstantOp>(loc, int64Ty, strideAttr);
      desc.setStride(rewriter, loc, index, stride);
    }

    rewriter.replaceOp(op, {desc});
    return success();
  }
};

/// Conversion pattern that converts a 1-D vector transfer read/write op into a
/// sequence of:
/// 1. Get the source/dst address as an LLVM vector pointer.
/// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
/// 3. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
/// 4. Create a mask where offsetVector is compared against memref upper bound.
/// 5. Rewrite op as a masked read or write.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv,
                                    bool enableIndexOpt)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv),
        enableIndexOptimizations(enableIndexOpt) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    auto adaptor = getTransferOpAdapter(xferOp, operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();
    if (xferOp.permutation_map() !=
        AffineMap::getMinorIdentityMap(xferOp.permutation_map().getNumInputs(),
                                       xferOp.getVectorType().getRank(),
                                       op->getContext()))
      return failure();
    // Only contiguous source tensors supported atm.
    auto strides = computeContiguousStrides(xferOp.getMemRefType());
    if (!strides)
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };

    Location loc = op->getLoc();
    MemRefType memRefType = xferOp.getMemRefType();

    if (auto memrefVectorElementType =
            memRefType.getElementType().dyn_cast<VectorType>()) {
      // Memref has vector element type.
      if (memrefVectorElementType.getElementType() !=
          xferOp.getVectorType().getElementType())
        return failure();
#ifndef NDEBUG
      // Check that the memref vector type is a suffix of 'vectorType'.
      unsigned memrefVecEltRank = memrefVectorElementType.getRank();
      unsigned resultVecRank = xferOp.getVectorType().getRank();
      assert(memrefVecEltRank <= resultVecRank);
      // TODO: Move this to isSuffix in Vector/Utils.h.
      unsigned rankOffset = resultVecRank - memrefVecEltRank;
      auto memrefVecEltShape = memrefVectorElementType.getShape();
      auto resultVecShape = xferOp.getVectorType().getShape();
      for (unsigned i = 0; i < memrefVecEltRank; ++i)
        assert(memrefVecEltShape[i] == resultVecShape[rankOffset + i] &&
               "memref vector element shape should match suffix of vector "
               "result shape.");
#endif // ifndef NDEBUG
    }

    // 1. Get the source/dst address as an LLVM vector pointer.
    //    The vector pointer is always in address space 0, therefore an
    //    addrspacecast is needed when the source/dst memrefs are not in
    //    address space 0.
    // TODO: support alignment when possible.
- Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast(); Value vectorDataPtr; if (memRefType.getMemorySpace() == 0) vectorDataPtr = rewriter.create(loc, vecTy.getPointerTo(), dataPtr); else vectorDataPtr = rewriter.create( loc, vecTy.getPointerTo(), dataPtr); if (!xferOp.isMaskedDim(0)) return replaceTransferOpWithLoadOrStore(rewriter, typeConverter, loc, xferOp, operands, vectorDataPtr); // 2. Create a vector with linear indices [ 0 .. vector_length - 1 ]. // 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ]. // 4. Let dim the memref dimension, compute the vector comparison mask: // [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ] // // TODO: when the leaf transfer rank is k > 1, we need the last `k` // dimensions here. unsigned vecWidth = vecTy.getVectorNumElements(); unsigned lastIndex = llvm::size(xferOp.indices()) - 1; Value off = xferOp.indices()[lastIndex]; Value dim = rewriter.create(loc, xferOp.memref(), lastIndex); Value mask = buildVectorComparison(rewriter, op, enableIndexOptimizations, vecWidth, dim, &off); // 5. Rewrite as a masked read / write. return replaceTransferOpWithMasked(rewriter, typeConverter, loc, xferOp, operands, vectorDataPtr, mask); } private: const bool enableIndexOptimizations; }; class VectorPrintOpConversion : public ConvertToLLVMPattern { public: explicit VectorPrintOpConversion(MLIRContext *context, LLVMTypeConverter &typeConverter) : ConvertToLLVMPattern(vector::PrintOp::getOperationName(), context, typeConverter) {} // Proof-of-concept lowering implementation that relies on a small // runtime support library, which only needs to provide a few // printing methods (single value for all data types, opening/closing // bracket, comma, newline). The lowering fully unrolls a vector // in terms of these elementary printing operations. The advantage // of this approach is that the library can remain unaware of all // low-level implementation details of vectors while still supporting // output of any shaped and dimensioned vector. Due to full unrolling, // this approach is less suited for very large vectors though. // // TODO: rely solely on libc in future? something else? // LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast(op); auto adaptor = vector::PrintOpAdaptor(operands); Type printType = printOp.getPrintType(); if (typeConverter.convertType(printType) == nullptr) return failure(); // Make sure element type has runtime support. PrintConversion conversion = PrintConversion::None; VectorType vectorType = printType.dyn_cast(); Type eltType = vectorType ? vectorType.getElementType() : printType; Operation *printer; if (eltType.isF32()) { printer = getPrintFloat(op); } else if (eltType.isF64()) { printer = getPrintDouble(op); } else if (eltType.isIndex()) { printer = getPrintU64(op); } else if (auto intTy = eltType.dyn_cast()) { // Integers need a zero or sign extension on the operand // (depending on the source type) as well as a signed or // unsigned print method. Up to 64-bit is supported. 
unsigned width = intTy.getWidth(); if (intTy.isUnsigned()) { if (width <= 64) { if (width < 64) conversion = PrintConversion::ZeroExt64; printer = getPrintU64(op); } else { return failure(); } } else { assert(intTy.isSignless() || intTy.isSigned()); if (width <= 64) { // Note that we *always* zero extend booleans (1-bit integers), // so that true/false is printed as 1/0 rather than -1/0. if (width == 1) conversion = PrintConversion::ZeroExt64; else if (width < 64) conversion = PrintConversion::SignExt64; printer = getPrintI64(op); } else { return failure(); } } } else { return failure(); } // Unroll vector into elementary print calls. int64_t rank = vectorType ? vectorType.getRank() : 0; emitRanks(rewriter, op, adaptor.source(), vectorType, printer, rank, conversion); emitCall(rewriter, op->getLoc(), getPrintNewline(op)); rewriter.eraseOp(op); return success(); } private: enum class PrintConversion { // clang-format off None, ZeroExt64, SignExt64 // clang-format on }; void emitRanks(ConversionPatternRewriter &rewriter, Operation *op, Value value, VectorType vectorType, Operation *printer, int64_t rank, PrintConversion conversion) const { Location loc = op->getLoc(); if (rank == 0) { switch (conversion) { case PrintConversion::ZeroExt64: value = rewriter.create( loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); break; case PrintConversion::SignExt64: value = rewriter.create( loc, value, LLVM::LLVMType::getInt64Ty(rewriter.getContext())); break; case PrintConversion::None: break; } emitCall(rewriter, loc, printer, value); return; } emitCall(rewriter, loc, getPrintOpen(op)); Operation *printComma = getPrintComma(op); int64_t dim = vectorType.getDimSize(0); for (int64_t d = 0; d < dim; ++d) { auto reducedType = rank > 1 ? reducedVectorTypeFront(vectorType) : nullptr; auto llvmType = typeConverter.convertType( rank > 1 ? reducedType : vectorType.getElementType()); Value nestedVal = extractOne(rewriter, typeConverter, loc, value, llvmType, rank, d); emitRanks(rewriter, op, nestedVal, reducedType, printer, rank - 1, conversion); if (d != dim - 1) emitCall(rewriter, loc, printComma); } emitCall(rewriter, loc, getPrintClose(op)); } // Helper to emit a call. static void emitCall(ConversionPatternRewriter &rewriter, Location loc, Operation *ref, ValueRange params = ValueRange()) { rewriter.create(loc, TypeRange(), rewriter.getSymbolRefAttr(ref), params); } // Helper for printer method declaration (first hit) and lookup. static Operation *getPrint(Operation *op, StringRef name, ArrayRef params) { auto module = op->getParentOfType(); auto func = module.lookupSymbol(name); if (func) return func; OpBuilder moduleBuilder(module.getBodyRegion()); return moduleBuilder.create( op->getLoc(), name, LLVM::LLVMType::getFunctionTy( LLVM::LLVMType::getVoidTy(op->getContext()), params, /*isVarArg=*/false)); } // Helpers for method names. 
Operation *getPrintI64(Operation *op) const { return getPrint(op, "printI64", LLVM::LLVMType::getInt64Ty(op->getContext())); } Operation *getPrintU64(Operation *op) const { return getPrint(op, "printU64", LLVM::LLVMType::getInt64Ty(op->getContext())); } Operation *getPrintFloat(Operation *op) const { return getPrint(op, "printF32", LLVM::LLVMType::getFloatTy(op->getContext())); } Operation *getPrintDouble(Operation *op) const { return getPrint(op, "printF64", LLVM::LLVMType::getDoubleTy(op->getContext())); } Operation *getPrintOpen(Operation *op) const { return getPrint(op, "printOpen", {}); } Operation *getPrintClose(Operation *op) const { return getPrint(op, "printClose", {}); } Operation *getPrintComma(Operation *op) const { return getPrint(op, "printComma", {}); } Operation *getPrintNewline(Operation *op) const { return getPrint(op, "printNewline", {}); } }; /// Progressive lowering of ExtractStridedSliceOp to either: /// 1. express single offset extract as a direct shuffle. /// 2. extract + lower rank strided_slice + insert for the n-D case. class VectorExtractStridedSliceOpConversion : public OpRewritePattern { public: VectorExtractStridedSliceOpConversion(MLIRContext *ctx) : OpRewritePattern(ctx) { // This pattern creates recursive ExtractStridedSliceOp, but the recursion // is bounded as the rank is strictly decreasing. setHasBoundedRewriteRecursion(); } LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { auto dstType = op.getResult().getType().cast(); assert(!op.offsets().getValue().empty() && "Unexpected empty offsets"); int64_t offset = op.offsets().getValue().front().cast().getInt(); int64_t size = op.sizes().getValue().front().cast().getInt(); int64_t stride = op.strides().getValue().front().cast().getInt(); auto loc = op.getLoc(); auto elemType = dstType.getElementType(); assert(elemType.isSignlessIntOrIndexOrFloat()); // Single offset can be more efficiently shuffled. if (op.offsets().getValue().size() == 1) { SmallVector offsets; offsets.reserve(size); for (int64_t off = offset, e = offset + size * stride; off < e; off += stride) offsets.push_back(off); rewriter.replaceOpWithNewOp(op, dstType, op.vector(), op.vector(), rewriter.getI64ArrayAttr(offsets)); return success(); } // Extract/insert on a lower ranked extract strided slice op. Value zero = rewriter.create(loc, elemType, rewriter.getZeroAttr(elemType)); Value res = rewriter.create(loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { Value one = extractOne(rewriter, loc, op.vector(), off); Value extracted = rewriter.create( loc, one, getI64SubArray(op.offsets(), /* dropFront=*/1), getI64SubArray(op.sizes(), /* dropFront=*/1), getI64SubArray(op.strides(), /* dropFront=*/1)); res = insertOne(rewriter, loc, extracted, res, idx); } rewriter.replaceOp(op, res); return success(); } }; } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. 
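/// Typical usage (a sketch mirroring LowerVectorToLLVMPass below):
/// ```
/// LLVMTypeConverter converter(ctx);
/// OwningRewritePatternList patterns;
/// populateVectorToLLVMConversionPatterns(
///     converter, patterns, /*reassociateFPReductions=*/false,
///     /*enableIndexOptimizations=*/true);
/// ```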
void mlir::populateVectorToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    bool reassociateFPReductions, bool enableIndexOptimizations) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  // clang-format off
  patterns.insert<VectorFMAOpNDRewritePattern,
                  VectorInsertStridedSliceOpDifferentRankRewritePattern,
                  VectorInsertStridedSliceOpSameRankRewritePattern,
                  VectorExtractStridedSliceOpConversion>(ctx);
  patterns.insert<VectorReductionOpConversion>(
      ctx, converter, reassociateFPReductions);
  patterns.insert<VectorCreateMaskOpConversion,
                  VectorTransferConversion<TransferReadOp>,
                  VectorTransferConversion<TransferWriteOp>>(
      ctx, converter, enableIndexOptimizations);
  patterns
      .insert<VectorShuffleOpConversion,
              VectorExtractElementOpConversion,
              VectorExtractOpConversion,
              VectorFMAOp1DConversion,
              VectorInsertElementOpConversion,
              VectorInsertOpConversion,
              VectorPrintOpConversion,
              VectorTypeCastOpConversion,
              VectorMaskedLoadOpConversion,
              VectorMaskedStoreOpConversion,
              VectorGatherOpConversion,
              VectorScatterOpConversion,
              VectorExpandLoadOpConversion,
              VectorCompressStoreOpConversion>(ctx, converter);
  // clang-format on
}

void mlir::populateVectorToLLVMMatrixConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  MLIRContext *ctx = converter.getDialect()->getContext();
  patterns.insert<VectorMatmulOpConversion>(ctx, converter);
  patterns.insert<VectorFlatTransposeOpConversion>(ctx, converter);
}

namespace {
struct LowerVectorToLLVMPass
    : public ConvertVectorToLLVMBase<LowerVectorToLLVMPass> {
  LowerVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
    this->reassociateFPReductions = options.reassociateFPReductions;
    this->enableIndexOptimizations = options.enableIndexOptimizations;
  }
  void runOnOperation() override;
};
} // namespace

void LowerVectorToLLVMPass::runOnOperation() {
  // Perform progressive lowering of operations on slices and
  // all contraction operations. Also applies folding and DCE.
  {
    OwningRewritePatternList patterns;
    populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
    populateVectorSlicesLoweringPatterns(patterns, &getContext());
    populateVectorContractLoweringPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
  }

  // Convert to the LLVM IR dialect.
  LLVMTypeConverter converter(&getContext());
  OwningRewritePatternList patterns;
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(
      converter, patterns, reassociateFPReductions, enableIndexOptimizations);
  populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(getContext());
  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

std::unique_ptr<OperationPass<ModuleOp>>
mlir::createConvertVectorToLLVMPass(const LowerVectorToLLVMOptions &options) {
  return std::make_unique<LowerVectorToLLVMPass>(options);
}
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 317cf322e3e0..26b8bec1f3fc 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -1,184 +1,184 @@
//===- VectorToROCDL.cpp - Vector to ROCDL lowering passes ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to generate ROCDLIR operations for higher-level
// Vector operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToROCDL/VectorToROCDL.h"

#include "../PassDetail.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::vector;

static LogicalResult replaceTransferOpWithMubuf(
    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
    LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp,
    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
    Value &offsetSizeInBytes, Value &glc, Value &slc) {
  rewriter.replaceOpWithNewOp<ROCDL::MubufLoadOp>(
      xferOp, vecTy, dwordConfig, vindex, offsetSizeInBytes, glc, slc);
  return success();
}

static LogicalResult replaceTransferOpWithMubuf(
    ConversionPatternRewriter &rewriter, ArrayRef<Value> operands,
    LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp,
    LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex,
    Value &offsetSizeInBytes, Value &glc, Value &slc) {
  auto adaptor = TransferWriteOpAdaptor(operands);
  rewriter.replaceOpWithNewOp<ROCDL::MubufStoreOp>(
      xferOp, adaptor.vector(), dwordConfig, vindex, offsetSizeInBytes, glc,
      slc);
  return success();
}

namespace {
/// Conversion pattern that converts a 1-D vector transfer read/write.
/// Note that this conversion pass only converts 2- or 4-element f32 vector
/// types; unsupported cases fall back to the vector-to-LLVM conversion
/// patterns.
template <typename ConcreteOp>
class VectorTransferConversion : public ConvertToLLVMPattern {
public:
  explicit VectorTransferConversion(MLIRContext *context,
                                    LLVMTypeConverter &typeConv)
      : ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
                             typeConv) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto xferOp = cast<ConcreteOp>(op);
    typename ConcreteOp::Adaptor adaptor(operands);

    if (xferOp.getVectorType().getRank() > 1 ||
        llvm::size(xferOp.indices()) == 0)
      return failure();

    if (!xferOp.permutation_map().isMinorIdentity())
      return failure();

    // Have it handled in vector->llvm conversion pass.
    if (!xferOp.isMaskedDim(0))
      return failure();

    auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
    LLVM::LLVMType vecTy =
        toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
    unsigned vecWidth = vecTy.getVectorNumElements();
    Location loc = op->getLoc();

    // The backend has trouble scalarizing <1 x ty> vector results, so exclude
    // the x1 width from this lowering.
    if (vecWidth != 2 && vecWidth != 4)
      return failure();

    // Obtain dataPtr and elementType from the memref.
    MemRefType memRefType = xferOp.getMemRefType();
    // MUBUF instructions operate only on address space 0 (unified) or
    // 1 (global). For 3 (LDS), fall back to the vector->llvm pass; for
    // 5 (VGPR) this lowering would be wrong.
    if ((memRefType.getMemorySpace() != 0) &&
        (memRefType.getMemorySpace() != 1))
      return failure();

    // Note that the dataPtr starts at the offset address specified by
    // indices, so there is no need to compute the offset size in bytes again
    // in the MUBUF instruction.
- Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(), - adaptor.indices(), rewriter); + Value dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(), + adaptor.indices(), rewriter); // 1. Create and fill a <4 x i32> dwordConfig with: // 1st two elements holding the address of dataPtr. // 3rd element: -1. // 4th element: 0x27000. SmallVector constConfigAttr{0, 0, -1, 0x27000}; Type i32Ty = rewriter.getIntegerType(32); VectorType i32Vecx4 = VectorType::get(4, i32Ty); Value constConfig = rewriter.create( loc, toLLVMTy(i32Vecx4), DenseElementsAttr::get(i32Vecx4, ArrayRef(constConfigAttr))); // Treat first two element of <4 x i32> as i64, and save the dataPtr // to it. Type i64Ty = rewriter.getIntegerType(64); Value i64x2Ty = rewriter.create( loc, LLVM::LLVMType::getVectorTy( toLLVMTy(i64Ty).template cast(), 2), constConfig); Value dataPtrAsI64 = rewriter.create( loc, toLLVMTy(i64Ty).template cast(), dataPtr); Value zero = createIndexConstant(rewriter, loc, 0); Value dwordConfig = rewriter.create( loc, LLVM::LLVMType::getVectorTy( toLLVMTy(i64Ty).template cast(), 2), i64x2Ty, dataPtrAsI64, zero); dwordConfig = rewriter.create(loc, toLLVMTy(i32Vecx4), dwordConfig); // 2. Rewrite op as a buffer read or write. Value int1False = rewriter.create( loc, toLLVMTy(rewriter.getIntegerType(1)), rewriter.getIntegerAttr(rewriter.getIntegerType(1), 0)); Value int32Zero = rewriter.create( loc, toLLVMTy(i32Ty), rewriter.getIntegerAttr(rewriter.getIntegerType(32), 0)); return replaceTransferOpWithMubuf(rewriter, operands, typeConverter, loc, xferOp, vecTy, dwordConfig, int32Zero, int32Zero, int1False, int1False); } }; } // end anonymous namespace void mlir::populateVectorToROCDLConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { MLIRContext *ctx = converter.getDialect()->getContext(); patterns.insert, VectorTransferConversion>(ctx, converter); } namespace { struct LowerVectorToROCDLPass : public ConvertVectorToROCDLBase { void runOnOperation() override; }; } // namespace void LowerVectorToROCDLPass::runOnOperation() { LLVMTypeConverter converter(&getContext()); OwningRewritePatternList patterns; populateVectorToROCDLConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns); LLVMConversionTarget target(getContext()); target.addLegalDialect(); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); } std::unique_ptr> mlir::createConvertVectorToROCDLPass() { return std::make_unique(); } diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir index 7db6ea568b51..b2708a562eab 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir @@ -1,551 +1,531 @@ // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s // RUN: mlir-opt -convert-std-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC // CHECK-LABEL: func @check_strided_memref_arguments( // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 func @check_strided_memref_arguments(%static: memref<10x20xf32, affine_map<(i,j)->(20 * i + j + 1)>>, %dynamic : memref(M * i + j + 1)>>, %mixed : memref<10x?xf32, affine_map<(i,j)[M]->(M * i + j + 1)>>) { return } // CHECK-LABEL: func @check_arguments // 
CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 func @check_arguments(%static: memref<10x20xf32>, %dynamic : memref, %mixed : memref<10x?xf32>) { return } // CHECK-LABEL: func @mixed_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> { func @mixed_alloc(%arg0: index, %arg1: index) -> memref { // CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mul %[[N]], %[[c42]] : !llvm.i64 // CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[st0]], %[[M]] : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[sz]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[c42]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st0]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[one]], %{{.*}}[4, 2] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> %0 = alloc(%arg0, %arg1) : memref // CHECK-NEXT: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> return %0 : memref } // CHECK-LABEL: func @mixed_dealloc func @mixed_dealloc(%arg0: memref) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> // CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm.ptr) -> () dealloc %arg0 : memref // CHECK-NEXT: llvm.return return } // CHECK-LABEL: func @dynamic_alloc( // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { func @dynamic_alloc(%arg0: index, %arg1: index) -> memref { // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[N]], %[[M]] : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[sz]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[sz_bytes]]) 
: (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[one]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %0 = alloc(%arg0, %arg1) : memref // CHECK-NEXT: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> return %0 : memref } // ----- // CHECK-LABEL: func @dynamic_alloca // CHECK: %[[M:.*]]: !llvm.i64, %[[N:.*]]: !llvm.i64) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { func @dynamic_alloca(%arg0: index, %arg1: index) -> memref { // CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[sz_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[sz_bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[allocated]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[off]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[M]], %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[N]], %{{.*}}[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[st1]], %{{.*}}[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> %0 = alloca(%arg0, %arg1) : memref // Test with explicitly specified alignment. llvm.alloca takes care of the // alignment. The same pointer is thus used for allocation and aligned // accesses. 
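// The descriptor that the insertvalue checks in these tests populate has a
// fixed layout; as a C++ analogue (an illustrative sketch only, the
// authoritative form is the !llvm.struct type spelled out in the CHECK
// lines):
//
//   template <typename T, int Rank>
//   struct MemRefDescriptor {
//     T *allocated;          // slot [0]: pointer returned by the allocation
//     T *aligned;            // slot [1]: == allocated for llvm.alloca
//     int64_t offset;        // slot [2]
//     int64_t sizes[Rank];   // slots [3, d]
//     int64_t strides[Rank]; // slots [4, d]
//   };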
// CHECK: %[[alloca_aligned:.*]] = llvm.alloca %{{.*}} x !llvm.float {alignment = 32 : i64} : (!llvm.i64) -> !llvm.ptr // CHECK: %[[desc:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[desc1:.*]] = llvm.insertvalue %[[alloca_aligned]], %[[desc]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[alloca_aligned]], %[[desc1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> alloca(%arg0, %arg1) {alignment = 32} : memref return %0 : memref } // CHECK-LABEL: func @dynamic_dealloc func @dynamic_dealloc(%arg0: memref) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[ptri8:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.call @free(%[[ptri8]]) : (!llvm.ptr) -> () dealloc %arg0 : memref return } // CHECK-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { // ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc({{.*}}) -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> { // ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // ALIGNED-ALLOC-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // ALIGNED-ALLOC-NEXT: %[[bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[alignment:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: %[[allocated:.*]] = llvm.call @aligned_alloc(%[[alignment]], %[[bytes]]) : (!llvm.i64, !llvm.i64) -> !llvm.ptr // ALIGNED-ALLOC-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr to !llvm.ptr %0 = alloc() {alignment = 32} : memref<32x18xf32> // Do another alloc just to test that we have a unique declaration for // aligned_alloc. // ALIGNED-ALLOC: llvm.call @aligned_alloc %1 = alloc() {alignment = 64} : memref<4096xf32> // Alignment is to element type boundaries (minimum 16 bytes). // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]] %2 = alloc() : memref<4096xvector<8xf32>> // The minimum alignment is 16 bytes unless explicitly specified. // ALIGNED-ALLOC: %[[c16:.*]] = llvm.mlir.constant(16 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c16]], %3 = alloc() : memref<4096xvector<2xf32>> // ALIGNED-ALLOC: %[[c8:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c8]], %4 = alloc() {alignment = 8} : memref<1024xvector<4xf32>> // Bump the memref allocation size if its size is not a multiple of alignment. 
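// The sub/add/urem/sub sequence checked right after this comment rounds the
// allocation size up to the next multiple of the alignment; the same
// computation as a plain C++ sketch (illustrative names, not from the pass):
//
//   int64_t roundUpToAlignment(int64_t sizeBytes, int64_t alignment) {
//     int64_t bumped = sizeBytes + alignment - 1; // add after sub by one
//     return bumped - bumped % alignment;         // sub of the urem
//   }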
// ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.mlir.constant(1 : index) : !llvm.i64 // ALIGNED-ALLOC-NEXT: llvm.sub // ALIGNED-ALLOC-NEXT: llvm.add // ALIGNED-ALLOC-NEXT: llvm.urem // ALIGNED-ALLOC-NEXT: %[[SIZE_ALIGNED:.*]] = llvm.sub // ALIGNED-ALLOC-NEXT: llvm.call @aligned_alloc(%[[c32]], %[[SIZE_ALIGNED]]) %5 = alloc() {alignment = 32} : memref<100xf32> // Bump alignment to the next power of two if it isn't. // ALIGNED-ALLOC: %[[c128:.*]] = llvm.mlir.constant(128 : index) : !llvm.i64 // ALIGNED-ALLOC: llvm.call @aligned_alloc(%[[c128]] %6 = alloc(%N) : memref> return %0 : memref<32x18xf32> } // CHECK-LABEL: func @mixed_load( // CHECK-COUNT-2: !llvm.ptr, // CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64 // CHECK: %[[I:.*]]: !llvm.i64, // CHECK: %[[J:.*]]: !llvm.i64) func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr %0 = load %mixed[%i, %j] : memref<42x?xf32> return } // CHECK-LABEL: func @dynamic_load( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @dynamic_load(%dynamic : memref, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr %0 = load %dynamic[%i, %j] : memref return } // CHECK-LABEL: func @prefetch // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64 
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @prefetch(%A : memref, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: [[C1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK-NEXT: [[C3:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32 // CHECK-NEXT: [[C1_1:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK-NEXT: "llvm.intr.prefetch"(%[[addr]], [[C1]], [[C3]], [[C1_1]]) : (!llvm.ptr, !llvm.i32, !llvm.i32, !llvm.i32) -> () prefetch %A[%i, %j], write, locality<3>, data : memref // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: [[C0_1:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: [[C1_2:%.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: "llvm.intr.prefetch"(%{{.*}}, [[C0]], [[C0_1]], [[C1_2]]) : (!llvm.ptr, !llvm.i32, !llvm.i32, !llvm.i32) -> () prefetch %A[%i, %j], read, locality<0>, data : memref // CHECK: [[C0_2:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : i32) : !llvm.i32 // CHECK: [[C0_3:%.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: "llvm.intr.prefetch"(%{{.*}}, [[C0_2]], [[C2]], [[C0_3]]) : (!llvm.ptr, !llvm.i32, !llvm.i32, !llvm.i32) -> () prefetch %A[%i, %j], read, locality<2>, instr : memref return } // CHECK-LABEL: func @dynamic_store // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @dynamic_store(%dynamic : memref, %i : index, %j : index, %val : f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : 
!llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr store %val, %dynamic[%i, %j] : memref return } // CHECK-LABEL: func @mixed_store // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr store %val, %mixed[%i, %j] : memref<42x?xf32> return } // CHECK-LABEL: func @memref_cast_static_to_dynamic func @memref_cast_static_to_dynamic(%static : memref<10x42xf32>) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %static : memref<10x42xf32> to memref return } // CHECK-LABEL: func @memref_cast_static_to_mixed func @memref_cast_static_to_mixed(%static : memref<10x42xf32>) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %static : memref<10x42xf32> to memref return } // CHECK-LABEL: func @memref_cast_dynamic_to_static func @memref_cast_dynamic_to_static(%dynamic : memref) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %dynamic : memref to memref<10x12xf32> return } // CHECK-LABEL: func @memref_cast_dynamic_to_mixed func @memref_cast_dynamic_to_mixed(%dynamic : memref) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %dynamic : memref to memref return } // CHECK-LABEL: func @memref_cast_mixed_to_dynamic func @memref_cast_mixed_to_dynamic(%mixed : memref<42x?xf32>) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %mixed : memref<42x?xf32> to memref return } // CHECK-LABEL: func @memref_cast_mixed_to_static func @memref_cast_mixed_to_static(%mixed : memref<42x?xf32>) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %mixed : memref<42x?xf32> to memref<42x1xf32> return } // CHECK-LABEL: func @memref_cast_mixed_to_mixed func @memref_cast_mixed_to_mixed(%mixed : memref<42x?xf32>) { // CHECK-NOT: llvm.bitcast %0 = memref_cast %mixed : memref<42x?xf32> to memref return } // CHECK-LABEL: func @memref_cast_ranked_to_unranked func @memref_cast_ranked_to_unranked(%arg : memref<42x2x?xf32>) { // CHECK-DAG: %[[c:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-DAG: %[[p:.*]] = llvm.alloca %[[c]] x !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> : (!llvm.i64) -> !llvm.ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>> // CHECK-DAG: llvm.store 
%{{.*}}, %[[p]] : !llvm.ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>> // CHECK-DAG: %[[p2:.*]] = llvm.bitcast %[[p]] : !llvm.ptr, ptr, i64, array<3 x i64>, array<3 x i64>)>> to !llvm.ptr // CHECK-DAG: %[[r:.*]] = llvm.mlir.constant(3 : i64) : !llvm.i64 // CHECK : llvm.mlir.undef : !llvm.struct<(i64, ptr)> // CHECK-DAG: llvm.insertvalue %[[r]], %{{.*}}[0] : !llvm.struct<(i64, ptr)> // CHECK-DAG: llvm.insertvalue %[[p2]], %{{.*}}[1] : !llvm.struct<(i64, ptr)> %0 = memref_cast %arg : memref<42x2x?xf32> to memref<*xf32> return } // CHECK-LABEL: func @memref_cast_unranked_to_ranked func @memref_cast_unranked_to_ranked(%arg : memref<*xf32>) { // CHECK: %[[p:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(i64, ptr)> // CHECK-NEXT: llvm.bitcast %[[p]] : !llvm.ptr to !llvm.ptr, ptr, i64, array<4 x i64>, array<4 x i64>)>> %0 = memref_cast %arg : memref<*xf32> to memref return } // CHECK-LABEL: func @mixed_memref_dim func @mixed_memref_dim(%mixed : memref<42x?x?x13x?xf32>) { // CHECK: llvm.mlir.constant(42 : index) : !llvm.i64 %c0 = constant 0 : index %0 = dim %mixed, %c0 : memref<42x?x?x13x?xf32> // CHECK: llvm.extractvalue %[[ld:.*]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> %c1 = constant 1 : index %1 = dim %mixed, %c1 : memref<42x?x?x13x?xf32> // CHECK: llvm.extractvalue %[[ld]][3, 2] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> %c2 = constant 2 : index %2 = dim %mixed, %c2 : memref<42x?x?x13x?xf32> // CHECK: llvm.mlir.constant(13 : index) : !llvm.i64 %c3 = constant 3 : index %3 = dim %mixed, %c3 : memref<42x?x?x13x?xf32> // CHECK: llvm.extractvalue %[[ld]][3, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> %c4 = constant 4 : index %4 = dim %mixed, %c4 : memref<42x?x?x13x?xf32> return } // CHECK-LABEL: @memref_dim_with_dyn_index // CHECK-SAME: %[[ALLOC_PTR:.*]]: !llvm.ptr, %[[ALIGN_PTR:.*]]: !llvm.ptr, %[[OFFSET:.*]]: !llvm.i64, %[[SIZE0:.*]]: !llvm.i64, %[[SIZE1:.*]]: !llvm.i64, %[[STRIDE0:.*]]: !llvm.i64, %[[STRIDE1:.*]]: !llvm.i64, %[[IDX:.*]]: !llvm.i64) -> !llvm.i64 func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index { // CHECK-NEXT: %[[DESCR0:.*]] = llvm.mlir.undef : [[DESCR_TY:!llvm.struct<\(ptr, ptr, i64, array<2 x i64>, array<2 x i64>\)>]] // CHECK-NEXT: %[[DESCR1:.*]] = llvm.insertvalue %[[ALLOC_PTR]], %[[DESCR0]][0] : [[DESCR_TY]] // CHECK-NEXT: %[[DESCR2:.*]] = llvm.insertvalue %[[ALIGN_PTR]], %[[DESCR1]][1] : [[DESCR_TY]] // CHECK-NEXT: %[[DESCR3:.*]] = llvm.insertvalue %[[OFFSET]], %[[DESCR2]][2] : [[DESCR_TY]] // CHECK-NEXT: %[[DESCR4:.*]] = llvm.insertvalue %[[SIZE0]], %[[DESCR3]][3, 0] : [[DESCR_TY]] // CHECK-NEXT: %[[DESCR5:.*]] = llvm.insertvalue %[[STRIDE0]], %[[DESCR4]][4, 0] : [[DESCR_TY]] // CHECK-NEXT: %[[DESCR6:.*]] = llvm.insertvalue %[[SIZE1]], %[[DESCR5]][3, 1] : [[DESCR_TY]] // CHECK-NEXT: %[[DESCR7:.*]] = llvm.insertvalue %[[STRIDE1]], %[[DESCR6]][4, 1] : [[DESCR_TY]] // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %[[DESCR7]][3] : [[DESCR_TY]] // CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.array<2 x i64> : (!llvm.i64) -> !llvm.ptr> // CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm.ptr> // CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm.ptr>, !llvm.i64, !llvm.i64) -> !llvm.ptr // CHECK-DAG: %[[RESULT:.*]] = llvm.load %[[RESULT_PTR]] : !llvm.ptr // CHECK-DAG: 
llvm.return %[[RESULT]] : !llvm.i64 %result = dim %arg, %idx : memref<3x?xf32> return %result : index } // CHECK-LABEL: @memref_reinterpret_cast_ranked_to_static_shape func @memref_reinterpret_cast_ranked_to_static_shape(%input : memref<2x3xf32>) { %output = memref_reinterpret_cast %input to offset: [0], sizes: [6, 1], strides: [1, 1] : memref<2x3xf32> to memref<6x1xf32> return } // CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[TY:!.*]] // CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY]] // CHECK: [[BASE_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[TY]] // CHECK: [[ALIGNED_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[TY]] // CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]] // CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]] // CHECK: [[OFFSET:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]] // CHECK: [[SIZE_0:%.*]] = llvm.mlir.constant(6 : index) : !llvm.i64 // CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]] // CHECK: [[SIZE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_4]][4, 0] : [[TY]] // CHECK: [[STRIDE_0:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_5]][3, 1] : [[TY]] // CHECK: [[STRIDE_1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]] // CHECK-LABEL: @memref_reinterpret_cast_unranked_to_dynamic_shape func @memref_reinterpret_cast_unranked_to_dynamic_shape(%offset: index, %size_0 : index, %size_1 : index, %stride_0 : index, %stride_1 : index, %input : memref<*xf32>) { %output = memref_reinterpret_cast %input to offset: [%offset], sizes: [%size_0, %size_1], strides: [%stride_0, %stride_1] : memref<*xf32> to memref return } // CHECK-SAME: ([[OFFSET:%[a-z,0-9]+]]: !llvm.i64, // CHECK-SAME: [[SIZE_0:%[a-z,0-9]+]]: !llvm.i64, [[SIZE_1:%[a-z,0-9]+]]: !llvm.i64, // CHECK-SAME: [[STRIDE_0:%[a-z,0-9]+]]: !llvm.i64, [[STRIDE_1:%[a-z,0-9]+]]: !llvm.i64, // CHECK: [[INPUT:%.*]] = llvm.insertvalue {{.*}}[1] : !llvm.struct<(i64, ptr)> // CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]] // CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr)> // CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr to !llvm.ptr> // CHECK: [[BASE_PTR:%.*]] = llvm.load [[BASE_PTR_PTR]] : !llvm.ptr> // CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr to !llvm.ptr> // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]] // CHECK-SAME: : (!llvm.ptr>, !llvm.i64) -> !llvm.ptr> // CHECK: [[ALIGNED_PTR:%.*]] = llvm.load [[ALIGNED_PTR_PTR]] : !llvm.ptr> // CHECK: [[OUT_1:%.*]] = llvm.insertvalue [[BASE_PTR]], [[OUT_0]][0] : [[TY]] // CHECK: [[OUT_2:%.*]] = llvm.insertvalue [[ALIGNED_PTR]], [[OUT_1]][1] : [[TY]] // CHECK: [[OUT_3:%.*]] = llvm.insertvalue [[OFFSET]], [[OUT_2]][2] : [[TY]] // CHECK: [[OUT_4:%.*]] = llvm.insertvalue [[SIZE_0]], [[OUT_3]][3, 0] : [[TY]] // CHECK: [[OUT_5:%.*]] = llvm.insertvalue [[STRIDE_0]], [[OUT_4]][4, 0] : [[TY]] // CHECK: [[OUT_6:%.*]] = llvm.insertvalue [[SIZE_1]], [[OUT_5]][3, 1] : [[TY]] // CHECK: [[OUT_7:%.*]] = llvm.insertvalue [[STRIDE_1]], [[OUT_6]][4, 1] : [[TY]] // CHECK-LABEL: @memref_reshape func @memref_reshape(%input 
: memref<2x3xf32>, %shape : memref) { %output = memref_reshape %input(%shape) : (memref<2x3xf32>, memref) -> memref<*xf32> return } // CHECK: [[INPUT:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] : [[INPUT_TY:!.*]] // CHECK: [[SHAPE:%.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] : [[SHAPE_TY:!.*]] // CHECK: [[RANK:%.*]] = llvm.extractvalue [[SHAPE]][3, 0] : [[SHAPE_TY]] // CHECK: [[UNRANKED_OUT_O:%.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> // CHECK: [[UNRANKED_OUT_1:%.*]] = llvm.insertvalue [[RANK]], [[UNRANKED_OUT_O]][0] : !llvm.struct<(i64, ptr)> // Compute size in bytes to allocate result ranked descriptor // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 // CHECK: [[PTR_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // CHECK: [[INDEX_SIZE:%.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // CHECK: [[DOUBLE_PTR_SIZE:%.*]] = llvm.mul [[C2]], [[PTR_SIZE]] : !llvm.i64 // CHECK: [[DESC_ALLOC_SIZE:%.*]] = llvm.add [[DOUBLE_PTR_SIZE]], %{{.*}} // CHECK: [[UNDERLYING_DESC:%.*]] = llvm.alloca [[DESC_ALLOC_SIZE]] x !llvm.i8 // CHECK: llvm.insertvalue [[UNDERLYING_DESC]], [[UNRANKED_OUT_1]][1] // Set allocated, aligned pointers and offset. // CHECK: [[ALLOC_PTR:%.*]] = llvm.extractvalue [[INPUT]][0] : [[INPUT_TY]] // CHECK: [[ALIGN_PTR:%.*]] = llvm.extractvalue [[INPUT]][1] : [[INPUT_TY]] // CHECK: [[OFFSET:%.*]] = llvm.extractvalue [[INPUT]][2] : [[INPUT_TY]] // CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] // CHECK-SAME: !llvm.ptr to !llvm.ptr> // CHECK: llvm.store [[ALLOC_PTR]], [[BASE_PTR_PTR]] : !llvm.ptr> // CHECK: [[BASE_PTR_PTR_:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr to !llvm.ptr> // CHECK: [[C1:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[ALIGNED_PTR_PTR:%.*]] = llvm.getelementptr [[BASE_PTR_PTR_]]{{\[}}[[C1]]] // CHECK: llvm.store [[ALIGN_PTR]], [[ALIGNED_PTR_PTR]] : !llvm.ptr> // CHECK: [[BASE_PTR_PTR__:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] : !llvm.ptr to !llvm.ptr> // CHECK: [[C2:%.*]] = llvm.mlir.constant(2 : index) : !llvm.i64 // CHECK: [[OFFSET_PTR_:%.*]] = llvm.getelementptr [[BASE_PTR_PTR__]]{{\[}}[[C2]]] // CHECK: [[OFFSET_PTR:%.*]] = llvm.bitcast [[OFFSET_PTR_]] // CHECK: llvm.store [[OFFSET]], [[OFFSET_PTR]] : !llvm.ptr // Iterate over shape operand in reverse order and set sizes and strides. 
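// The loop verified below walks the shape from the innermost dimension
// outward, giving each dimension the running product of the sizes to its
// right as its stride; the same computation as a C++ sketch (illustrative
// names only, assuming row-major contiguous layout):
//
//   int64_t stride = 1;
//   for (int64_t dim = rank - 1; dim >= 0; --dim) {
//     sizes[dim] = shape[dim];
//     strides[dim] = stride;
//     stride *= shape[dim];
//   }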
// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]] // CHECK-SAME: !llvm.ptr to !llvm.ptr, ptr, i64, i64)>> // CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32 // CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]] // CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]] // CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]] // CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK: [[RANK_MIN_1:%.*]] = llvm.sub [[RANK]], [[C1_]] : !llvm.i64 // CHECK: llvm.br ^bb1([[RANK_MIN_1]], [[C1_]] : !llvm.i64, !llvm.i64) // CHECK: ^bb1([[DIM:%.*]]: !llvm.i64, [[CUR_STRIDE:%.*]]: !llvm.i64): // CHECK: [[C0_:%.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK: [[COND:%.*]] = llvm.icmp "sge" [[DIM]], [[C0_]] : !llvm.i64 // CHECK: llvm.cond_br [[COND]], ^bb2, ^bb3 // CHECK: ^bb2: // CHECK: [[SIZE_PTR:%.*]] = llvm.getelementptr [[SHAPE_IN_PTR]]{{\[}}[[DIM]]] // CHECK: [[SIZE:%.*]] = llvm.load [[SIZE_PTR]] : !llvm.ptr // CHECK: [[TARGET_SIZE_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[DIM]]] // CHECK: llvm.store [[SIZE]], [[TARGET_SIZE_PTR]] : !llvm.ptr // CHECK: [[TARGET_STRIDE_PTR:%.*]] = llvm.getelementptr [[STRIDES_PTR]]{{\[}}[[DIM]]] // CHECK: llvm.store [[CUR_STRIDE]], [[TARGET_STRIDE_PTR]] : !llvm.ptr // CHECK: [[UPDATE_STRIDE:%.*]] = llvm.mul [[CUR_STRIDE]], [[SIZE]] : !llvm.i64 // CHECK: [[STRIDE_COND:%.*]] = llvm.sub [[DIM]], [[C1_]] : !llvm.i64 // CHECK: llvm.br ^bb1([[STRIDE_COND]], [[UPDATE_STRIDE]] : !llvm.i64, !llvm.i64) // CHECK: ^bb3: // CHECK: llvm.return diff --git a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir index 9c32abc39f14..158fdcba7c92 100644 --- a/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-static-memref-ops.mlir @@ -1,428 +1,404 @@ // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s // RUN: mlir-opt -convert-std-to-llvm='use-bare-ptr-memref-call-conv=1' -split-input-file %s | FileCheck %s --check-prefix=BAREPTR // BAREPTR-LABEL: func @check_noalias // BAREPTR-SAME: %{{.*}}: !llvm.ptr {llvm.noalias = true}, %{{.*}}: !llvm.ptr {llvm.noalias = true} func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}, %other : memref<2xf32> {llvm.noalias = true}) { return } // ----- // CHECK-LABEL: func @check_static_return // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return // BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { func @check_static_return(%static : memref<32x18xf32>) -> memref<32x18xf32> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // 
BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr return %static : memref<32x18xf32> } // ----- // CHECK-LABEL: func @check_static_return_with_offset // CHECK-COUNT-2: !llvm.ptr // CHECK-COUNT-5: !llvm.i64 // CHECK-SAME: -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-LABEL: func @check_static_return_with_offset // BAREPTR-SAME: (%[[arg:.*]]: !llvm.ptr) -> !llvm.ptr { func @check_static_return_with_offset(%static : memref<32x18xf32, offset:7, strides:[22,1]>) -> memref<32x18xf32, offset:7, strides:[22,1]> { // CHECK: llvm.return %{{.*}} : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR: %[[udf:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[base0:.*]] = llvm.insertvalue %[[arg]], %[[udf]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[aligned:.*]] = llvm.insertvalue %[[arg]], %[[base0]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val0:.*]] = llvm.mlir.constant(7 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins0:.*]] = llvm.insertvalue %[[val0]], %[[aligned]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins1:.*]] = llvm.insertvalue %[[val1]], %[[ins0]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val2:.*]] = llvm.mlir.constant(22 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins2:.*]] = llvm.insertvalue %[[val2]], %[[ins1]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val3:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins3:.*]] = llvm.insertvalue %[[val3]], %[[ins2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[val4:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[ins4:.*]] = llvm.insertvalue %[[val4]], %[[ins3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[base1:.*]] = llvm.extractvalue %[[ins4]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: llvm.return %[[base1]] : !llvm.ptr return %static : memref<32x18xf32, offset:7, strides:[22,1]> } // ----- // CHECK-LABEL: func @zero_d_alloc() -> !llvm.struct<(ptr, ptr, i64)> { // BAREPTR-LABEL: func @zero_d_alloc() -> !llvm.ptr { func @zero_d_alloc() -> memref { // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : 
!llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: llvm.call @malloc(%[[size_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> // CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> // CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> // BAREPTR-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[one]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: llvm.call @malloc(%[[size_bytes]]) : (!llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64)> // BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> // BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64)> %0 = alloc() : memref return %0 : memref } // ----- // CHECK-LABEL: func @zero_d_dealloc // BAREPTR-LABEL: func @zero_d_dealloc(%{{.*}}: !llvm.ptr) { func @zero_d_dealloc(%arg0: memref) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> // CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm.ptr) -> () // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64)> // BAREPTR-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // BAREPTR-NEXT: llvm.call @free(%[[bc]]) : (!llvm.ptr) -> () dealloc %arg0 : memref return } // ----- // CHECK-LABEL: func @aligned_1d_alloc( // BAREPTR-LABEL: func @aligned_1d_alloc( func @aligned_1d_alloc() -> memref<42xf32> { // CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[sz1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // CHECK-NEXT: %[[allocsize:.*]] = llvm.add %[[size_bytes]], %[[alignment]] : !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // CHECK-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[ptr]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[one_1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[bump:.*]] = llvm.sub %[[alignment]], %[[one_1]] : !llvm.i64 // CHECK-NEXT: %[[bumped:.*]] = llvm.add %[[allocatedAsInt]], %[[bump]] : !llvm.i64 // CHECK-NEXT: %[[mod:.*]] = 
llvm.urem %[[bumped]], %[[alignment]] : !llvm.i64 // CHECK-NEXT: %[[aligned:.*]] = llvm.sub %[[bumped]], %[[mod]] : !llvm.i64 // CHECK-NEXT: %[[alignedBitCast:.*]] = llvm.inttoptr %[[aligned]] : !llvm.i64 to !llvm.ptr // CHECK-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[sz1:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // BAREPTR-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[sz1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: %[[alignment:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64 // BAREPTR-NEXT: %[[allocsize:.*]] = llvm.add %[[size_bytes]], %[[alignment]] : !llvm.i64 // BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[allocsize]]) : (!llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[ptr:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr // BAREPTR-NEXT: %[[allocatedAsInt:.*]] = llvm.ptrtoint %[[ptr]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: %[[one_2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[bump:.*]] = llvm.sub %[[alignment]], %[[one_2]] : !llvm.i64 // BAREPTR-NEXT: %[[bumped:.*]] = llvm.add %[[allocatedAsInt]], %[[bump]] : !llvm.i64 // BAREPTR-NEXT: %[[mod:.*]] = llvm.urem %[[bumped]], %[[alignment]] : !llvm.i64 // BAREPTR-NEXT: %[[aligned:.*]] = llvm.sub %[[bumped]], %[[mod]] : !llvm.i64 // BAREPTR-NEXT: %[[alignedBitCast:.*]] = llvm.inttoptr %[[aligned]] : !llvm.i64 to !llvm.ptr // BAREPTR-NEXT: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: llvm.insertvalue %[[ptr]], %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: llvm.insertvalue %[[alignedBitCast]], %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: llvm.insertvalue %[[c0]], %{{.*}}[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %0 = alloc() {alignment = 8} : memref<42xf32> return %0 : memref<42xf32> } // ----- // CHECK-LABEL: func @static_alloc() -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { // BAREPTR-LABEL: func @static_alloc() -> !llvm.ptr { func @static_alloc() -> memref<32x18xf32> { // CHECK: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[size_bytes]]) : (!llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr to !llvm.ptr // BAREPTR: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : !llvm.i64 // BAREPTR-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // BAREPTR-NEXT: %[[gep:.*]] = 
llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // BAREPTR-NEXT: %[[allocated:.*]] = llvm.call @malloc(%[[size_bytes]]) : (!llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: llvm.bitcast %[[allocated]] : !llvm.ptr to !llvm.ptr %0 = alloc() : memref<32x18xf32> return %0 : memref<32x18xf32> } // ----- // CHECK-LABEL: func @static_alloca() -> !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> { func @static_alloca() -> memref<32x18xf32> { // CHECK-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : !llvm.i64 // CHECK-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : !llvm.i64 // CHECK-NEXT: %[[st2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // CHECK-NEXT: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : !llvm.i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[gep:.*]] = llvm.getelementptr %[[null]][%[[num_elems]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: %[[size_bytes:.*]] = llvm.ptrtoint %[[gep]] : !llvm.ptr to !llvm.i64 // CHECK-NEXT: %[[allocated:.*]] = llvm.alloca %[[size_bytes]] x !llvm.float : (!llvm.i64) -> !llvm.ptr %0 = alloca() : memref<32x18xf32> // Test with explicitly specified alignment. llvm.alloca takes care of the // alignment. The same pointer is thus used for allocation and aligned // accesses. // CHECK: %[[alloca_aligned:.*]] = llvm.alloca %{{.*}} x !llvm.float {alignment = 32 : i64} : (!llvm.i64) -> !llvm.ptr // CHECK: %[[desc:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[desc1:.*]] = llvm.insertvalue %[[alloca_aligned]], %[[desc]][0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[alloca_aligned]], %[[desc1]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> alloca() {alignment = 32} : memref<32x18xf32> return %0 : memref<32x18xf32> } // ----- // CHECK-LABEL: func @static_dealloc // BAREPTR-LABEL: func @static_dealloc(%{{.*}}: !llvm.ptr) { func @static_dealloc(%static: memref<10x8xf32>) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // CHECK-NEXT: llvm.call @free(%[[bc]]) : (!llvm.ptr) -> () // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // BAREPTR-NEXT: %[[bc:.*]] = llvm.bitcast %[[ptr]] : !llvm.ptr to !llvm.ptr // BAREPTR-NEXT: llvm.call @free(%[[bc]]) : (!llvm.ptr) -> () dealloc %static : memref<10x8xf32> return } // ----- // CHECK-LABEL: func @zero_d_load // BAREPTR-LABEL: func @zero_d_load(%{{.*}}: !llvm.ptr) -> !llvm.float func @zero_d_load(%arg0: memref) -> f32 { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// CHECK-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// CHECK-NEXT: %{{.*}} = llvm.load %[[addr]] : !llvm.ptr +// CHECK-NEXT: %{{.*}} = llvm.load %[[ptr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[c0]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// BAREPTR-NEXT: llvm.load %[[addr:.*]] : !llvm.ptr +// 
BAREPTR-NEXT: llvm.load %[[ptr:.*]] : !llvm.ptr %0 = load %arg0[] : memref return %0 : f32 } // ----- // CHECK-LABEL: func @static_load( // CHECK-COUNT-2: !llvm.ptr, // CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64 // CHECK: %[[I:.*]]: !llvm.i64, // CHECK: %[[J:.*]]: !llvm.i64) // BAREPTR-LABEL: func @static_load // BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr, %[[I:.*]]: !llvm.i64, %[[J:.*]]: !llvm.i64) { func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: llvm.load %[[addr]] : !llvm.ptr %0 = load %static[%i, %j] : memref<10x42xf32> return } // ----- // CHECK-LABEL: func @zero_d_store // BAREPTR-LABEL: func @zero_d_store // BAREPTR-SAME: (%[[A:.*]]: !llvm.ptr, %[[val:.*]]: !llvm.float) func @zero_d_store(%arg0: memref, %arg1: f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr +// CHECK-NEXT: llvm.store %{{.*}}, %[[ptr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64)> -// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr -// BAREPTR-NEXT: llvm.store %[[val]], %[[addr]] : !llvm.ptr +// BAREPTR-NEXT: llvm.store %[[val]], %[[ptr]] : !llvm.ptr store %arg1, %arg0[] : memref return } // ----- // CHECK-LABEL: func @static_store // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.ptr // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG3:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG4:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: 
%[[ARG5:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[ARG6:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64 // CHECK-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 // BAREPTR-LABEL: func @static_store // BAREPTR-SAME: %[[A:.*]]: !llvm.ptr // BAREPTR-SAME: %[[I:[a-zA-Z0-9]*]]: !llvm.i64 // BAREPTR-SAME: %[[J:[a-zA-Z0-9]*]]: !llvm.i64 func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %val : f32) { // CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// CHECK-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// CHECK-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr // BAREPTR: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// BAREPTR-NEXT: %[[off:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[st0:.*]] = llvm.mlir.constant(42 : index) : !llvm.i64 // BAREPTR-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : !llvm.i64 -// BAREPTR-NEXT: %[[off0:.*]] = llvm.add %[[off]], %[[offI]] : !llvm.i64 -// BAREPTR-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 -// BAREPTR-NEXT: %[[offJ:.*]] = llvm.mul %[[J]], %[[st1]] : !llvm.i64 -// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[off0]], %[[offJ]] : !llvm.i64 +// BAREPTR-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : !llvm.i64 // BAREPTR-NEXT: %[[addr:.*]] = llvm.getelementptr %[[ptr]][%[[off1]]] : (!llvm.ptr, !llvm.i64) -> !llvm.ptr // BAREPTR-NEXT: llvm.store %{{.*}}, %[[addr]] : !llvm.ptr store %val, %static[%i, %j] : memref<10x42xf32> return } // ----- // CHECK-LABEL: func @static_memref_dim // BAREPTR-LABEL: func @static_memref_dim(%{{.*}}: !llvm.ptr) { func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { // CHECK: llvm.mlir.constant(42 : index) : !llvm.i64 // BAREPTR: llvm.insertvalue %{{.*}}, %{{.*}}[4, 4] : !llvm.struct<(ptr, ptr, i64, array<5 x i64>, array<5 x i64>)> // BAREPTR: llvm.mlir.constant(42 : index) : !llvm.i64 %c0 = constant 0 : index %0 = dim %static, %c0 : memref<42x32x15x13x27xf32> // CHECK: llvm.mlir.constant(32 : index) : !llvm.i64 // BAREPTR: llvm.mlir.constant(32 : index) : !llvm.i64 %c1 = constant 1 : index %1 = dim %static, %c1 : memref<42x32x15x13x27xf32> // CHECK: llvm.mlir.constant(15 : index) : !llvm.i64 // BAREPTR: llvm.mlir.constant(15 : index) : !llvm.i64 %c2 = constant 2 : index %2 = dim %static, %c2 : memref<42x32x15x13x27xf32> // CHECK: llvm.mlir.constant(13 : index) : !llvm.i64 // BAREPTR: llvm.mlir.constant(13 : index) : !llvm.i64 %c3 = constant 3 : index %3 = dim %static, %c3 : memref<42x32x15x13x27xf32> // CHECK: llvm.mlir.constant(27 : index) : !llvm.i64 // BAREPTR: llvm.mlir.constant(27 : index) : !llvm.i64 %c4 = constant 4 : index %4 = dim %static, %c4 : memref<42x32x15x13x27xf32> return } // ----- // BAREPTR: llvm.func @foo(!llvm.ptr) -> !llvm.ptr func private @foo(memref<10xi8>) -> memref<20xi8> // 
BAREPTR-LABEL: func @check_memref_func_call // BAREPTR-SAME: %[[in:.*]]: !llvm.ptr) -> !llvm.ptr func @check_memref_func_call(%in : memref<10xi8>) -> memref<20xi8> { // BAREPTR: %[[inDesc:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 0] // BAREPTR-NEXT: %[[barePtr:.*]] = llvm.extractvalue %[[inDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[call:.*]] = llvm.call @foo(%[[barePtr]]) : (!llvm.ptr) -> !llvm.ptr // BAREPTR-NEXT: %[[desc0:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[desc1:.*]] = llvm.insertvalue %[[call]], %[[desc0]][0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[desc2:.*]] = llvm.insertvalue %[[call]], %[[desc1]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[c0:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64 // BAREPTR-NEXT: %[[desc4:.*]] = llvm.insertvalue %[[c0]], %[[desc2]][2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[c20:.*]] = llvm.mlir.constant(20 : index) : !llvm.i64 // BAREPTR-NEXT: %[[desc6:.*]] = llvm.insertvalue %[[c20]], %[[desc4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: %[[c1:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64 // BAREPTR-NEXT: %[[outDesc:.*]] = llvm.insertvalue %[[c1]], %[[desc6]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %res = call @foo(%in) : (memref<10xi8>) -> (memref<20xi8>) // BAREPTR-NEXT: %[[res:.*]] = llvm.extractvalue %[[outDesc]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // BAREPTR-NEXT: llvm.return %[[res]] : !llvm.ptr return %res : memref<20xi8> } // ----- // BAREPTR: llvm.func @goo(!llvm.float) -> !llvm.float func private @goo(f32) -> f32 // BAREPTR-LABEL: func @check_scalar_func_call // BAREPTR-SAME: %[[in:.*]]: !llvm.float) func @check_scalar_func_call(%in : f32) { // BAREPTR-NEXT: %[[call:.*]] = llvm.call @goo(%[[in]]) : (!llvm.float) -> !llvm.float %res = call @goo(%in) : (f32) -> (f32) return } // ----- // Unranked memrefs are currently not supported in the bare-ptr calling // convention. Check that the conversion to the LLVM-IR dialect doesn't happen // in the presence of unranked memrefs when using such a calling convention. // BAREPTR: func private @hoo(memref<*xi8>) -> memref<*xi8> func private @hoo(memref<*xi8>) -> memref<*xi8> // BAREPTR-LABEL: func @check_unranked_memref_func_call(%{{.*}}: memref<*xi8>) -> memref<*xi8> func @check_unranked_memref_func_call(%in: memref<*xi8>) -> memref<*xi8> { // BAREPTR-NEXT: call @hoo(%{{.*}}) : (memref<*xi8>) -> memref<*xi8> %res = call @hoo(%in) : (memref<*xi8>) -> memref<*xi8> // BAREPTR-NEXT: return %{{.*}} : memref<*xi8> return %res : memref<*xi8> }
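// For reference, the BAREPTR checks in @check_memref_func_call above rebuild
// a full descriptor from the returned bare pointer: the allocated and aligned
// slots both take the callee's result, the offset is zero, and the static
// sizes and strides are materialized as constants. A C++ sketch against the
// descriptor analogue sketched earlier (hypothetical helper):
//
//   MemRefDescriptor<int8_t, 1> fromBarePtr(int8_t *p) {
//     return {/*allocated=*/p, /*aligned=*/p, /*offset=*/0,
//             /*sizes=*/{20}, /*strides=*/{1}};
//   }
//
// No such fixed form exists for an unranked memref, whose rank is only known
// at run time, which is why @hoo and its call are left unconverted.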