diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h copy from mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h copy to mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -1,47 +1,31 @@ -//===- ConvertStandardToLLVM.h - Convert to the LLVM dialect ----*- C++ -*-===// +//===- Pattern.h - Pattern for conversion to the LLVM dialect ---*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// -// -// Provides a dialect conversion targeting the LLVM IR dialect. By default, it -// converts Standard ops and types and provides hooks for dialect-specific -// extensions to the conversion. 
-// -//===----------------------------------------------------------------------===// -#ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H -#define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H +#ifndef MLIR_CONVERSION_LLVMCOMMON_PATTERN_H +#define MLIR_CONVERSION_LLVMCOMMON_PATTERN_H #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Transforms/DialectConversion.h" -namespace llvm { -class IntegerType; -class LLVMContext; -class Module; -class Type; -} // namespace llvm - namespace mlir { -class BaseMemRefType; -class ComplexType; -class DataLayoutAnalysis; -class LLVMTypeConverter; -class UnrankedMemRefType; - namespace LLVM { -class LLVMDialect; -class LLVMPointerType; +namespace detail { +/// Replaces the given operation "op" with a new operation of type "targetOp" +/// and given operands. +LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); +} // namespace detail } // namespace LLVM -// ------------------ - /// Base class for operation conversions targeting the LLVM IR dialect. It /// provides the conversion patterns with access to the LLVMTypeConverter and /// the LowerToLLVMOptions. The class captures the LLVMTypeConverter and the @@ -73,6 +57,11 @@ /// Get the MLIR type wrapping the LLVM i8* type. Type getVoidPtrType() const; + /// Create a constant Op producing a value of `resultType` from an index-typed + /// integer attribute. + static Value createIndexAttrConstant(OpBuilder &builder, Location loc, + Type resultType, int64_t value); + /// Create an LLVM dialect operation defining the given index constant. 
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const; @@ -177,71 +166,6 @@ using ConvertToLLVMPattern::matchAndRewrite; }; -/// Lowering for AllocOp and AllocaOp. -struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern { - using ConvertToLLVMPattern::createIndexConstant; - using ConvertToLLVMPattern::getIndexType; - using ConvertToLLVMPattern::getVoidPtrType; - - explicit AllocLikeOpLLVMLowering(StringRef opName, - LLVMTypeConverter &converter) - : ConvertToLLVMPattern(opName, &converter.getContext(), converter) {} - -protected: - // Returns 'input' aligned up to 'alignment'. Computes - // bumped = input + alignement - 1 - // aligned = bumped - bumped % alignment - static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, - Value input, Value alignment); - - /// Allocates the underlying buffer. Returns the allocated pointer and the - /// aligned pointer. - virtual std::tuple - allocateBuffer(ConversionPatternRewriter &rewriter, Location loc, - Value sizeBytes, Operation *op) const = 0; - -private: - static MemRefType getMemRefResultType(Operation *op) { - return op->getResult(0).getType().cast(); - } - - // An `alloc` is converted into a definition of a memref descriptor value and - // a call to `malloc` to allocate the underlying data buffer. The memref - // descriptor is of the LLVM structure type where: - // 1. the first element is a pointer to the allocated (typed) data buffer, - // 2. the second element is a pointer to the (typed) payload, aligned to the - // specified alignment, - // 3. the remaining elements serve to store all the sizes and strides of the - // memref using LLVM-converted `index` type. - // - // Alignment is performed by allocating `alignment` more bytes than - // requested and shifting the aligned pointer relative to the allocated - // memory. Note: `alignment - ` would actually be - // sufficient. If alignment is unspecified, the two pointers are equal. 
- - // An `alloca` is converted into a definition of a memref descriptor value and - // an llvm.alloca to allocate the underlying data buffer. - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override; -}; - -namespace LLVM { -namespace detail { -/// Replaces the given operation "op" with a new operation of type "targetOp" -/// and given operands. -LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); - -LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); -} // namespace detail -} // namespace LLVM - /// Generic implementation of one-to-one conversion from "SourceOp" to /// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. /// Upholds a convention that multi-result operations get converted into an @@ -264,33 +188,6 @@ } }; -/// Basic lowering implementation to rewrite Ops with just one result to the -/// LLVM Dialect. This supports higher-dimensional vector types. -template -class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Super = VectorConvertToLLVMPattern; - - LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - static_assert( - std::is_base_of, SourceOp>::value, - "expected single result op"); - return LLVM::detail::vectorOneToOneRewrite( - op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), - rewriter); - } -}; - -/// Derived class that automatically populates legalization information for -/// different LLVM ops. 
-class LLVMConversionTarget : public ConversionTarget { -public: - explicit LLVMConversionTarget(MLIRContext &ctx); -}; - } // namespace mlir -#endif // MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H +#endif // MLIR_CONVERSION_LLVMCOMMON_PATTERN_H diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -0,0 +1,85 @@ +//===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H +#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { + +namespace LLVM { +namespace detail { +// Helper struct to "unroll" operations on n-D vectors in terms of operations on +// 1-D LLVM vectors. +struct NDVectorTypeInfo { + // LLVM array struct which encodes n-D vectors. + Type llvmNDVectorTy; + // LLVM vector type which encodes the inner 1-D vector type. + Type llvm1DVectorTy; + // Multiplicity of llvmNDVectorTy to llvm1DVectorTy. + SmallVector arraySizes; +}; + +// For >1-D vector types, extracts the necessary information to iterate over all +// 1-D subvectors in the underlying LLVM representation of the n-D vector. +// Iterates on the llvm array type until we hit a non-array type (which is +// asserted to be an llvm vector type). +NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, + LLVMTypeConverter &converter); + +// Express `linearIndex` in terms of coordinates of `basis`.
+// Returns the empty vector when linearIndex is out of the range [0, P] where +// P is the product of all the basis coordinates. +// +// Prerequisites: +// Basis is an array of nonnegative integers (signed type inherited from +// vector shape type). +SmallVector getCoordinates(ArrayRef basis, + unsigned linearIndex); + +// Iterate over linear index, convert to coords space and insert splatted 1-D +// vector in each position. +void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, + function_ref fun); + +LogicalResult handleMultidimensionalVectors( + Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, + std::function createOperand, + ConversionPatternRewriter &rewriter); + +LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, + ValueRange operands, + LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter); +} // namespace detail +} // namespace LLVM + +/// Basic lowering implementation to rewrite Ops with just one result to the +/// LLVM Dialect. This supports higher-dimensional vector types.
+template +class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { +public: + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Super = VectorConvertToLLVMPattern; + + LogicalResult + matchAndRewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + static_assert( + std::is_base_of, SourceOp>::value, + "expected single result op"); + return LLVM::detail::vectorOneToOneRewrite( + op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), + rewriter); + } +}; +} // namespace mlir + +#endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -15,167 +15,37 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H -#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" -#include "mlir/Transforms/DialectConversion.h" - -namespace llvm { -class IntegerType; -class LLVMContext; -class Module; -class Type; -} // namespace llvm +#include "mlir/Conversion/LLVMCommon/Pattern.h" namespace mlir { -class BaseMemRefType; -class ComplexType; -class DataLayoutAnalysis; class LLVMTypeConverter; -class UnrankedMemRefType; - -namespace LLVM { -class LLVMDialect; -class LLVMPointerType; -} // namespace LLVM - -// ------------------ - -/// Base class for operation conversions targeting the LLVM IR dialect. It -/// provides the conversion patterns with access to the LLVMTypeConverter and -/// the LowerToLLVMOptions. 
The class captures the LLVMTypeConverter and the -/// LowerToLLVMOptions by reference meaning the references have to remain alive -/// during the entire pattern lifetime. -class ConvertToLLVMPattern : public ConversionPattern { -public: - ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, - LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1); - -protected: - /// Returns the LLVM dialect. - LLVM::LLVMDialect &getDialect() const; - - LLVMTypeConverter *getTypeConverter() const; - - /// Gets the MLIR type wrapping the LLVM integer type whose bit width is - /// defined by the used type converter. - Type getIndexType() const; - - /// Gets the MLIR type wrapping the LLVM integer type whose bit width - /// corresponds to that of a LLVM pointer type. - Type getIntPtrType(unsigned addressSpace = 0) const; - - /// Gets the MLIR type wrapping the LLVM void type. - Type getVoidType() const; - - /// Get the MLIR type wrapping the LLVM i8* type. - Type getVoidPtrType() const; - - /// Create an LLVM dialect operation defining the given index constant. - Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, - uint64_t value) const; - - // This is a strided getElementPtr variant that linearizes subscripts as: - // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. - Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, - ValueRange indices, - ConversionPatternRewriter &rewriter) const; - - /// Returns if the given memref has identity maps and the element type is - /// convertible to LLVM. - bool isConvertibleAndHasIdentityMaps(MemRefType type) const; - - /// Returns the type of a pointer to an element of the memref. - Type getElementPtrType(MemRefType type) const; - - /// Computes sizes, strides and buffer size in bytes of `memRefType` with - /// identity layout. Emits constant ops for the static sizes of `memRefType`, - /// and uses `dynamicSizes` for the others. 
Emits instructions to compute - /// strides and buffer size from these sizes. - /// - /// For example, memref<4x?xf32> emits: - /// `sizes[0]` = llvm.mlir.constant(4 : index) : i64 - /// `sizes[1]` = `dynamicSizes[0]` - /// `strides[1]` = llvm.mlir.constant(1 : index) : i64 - /// `strides[0]` = `sizes[0]` - /// %size = llvm.mul `sizes[0]`, `sizes[1]` : i64 - /// %nullptr = llvm.mlir.null : !llvm.ptr - /// %gep = llvm.getelementptr %nullptr[%size] - /// : (!llvm.ptr, i64) -> !llvm.ptr - /// `sizeBytes` = llvm.ptrtoint %gep : !llvm.ptr to i64 - void getMemRefDescriptorSizes(Location loc, MemRefType memRefType, - ValueRange dynamicSizes, - ConversionPatternRewriter &rewriter, - SmallVectorImpl &sizes, - SmallVectorImpl &strides, - Value &sizeBytes) const; - - /// Computes the size of type in bytes. - Value getSizeInBytes(Location loc, Type type, - ConversionPatternRewriter &rewriter) const; - - /// Computes total number of elements for the given shape. - Value getNumElements(Location loc, ArrayRef shape, - ConversionPatternRewriter &rewriter) const; - - /// Creates and populates a canonical memref descriptor struct. - MemRefDescriptor - createMemRefDescriptor(Location loc, MemRefType memRefType, - Value allocatedPtr, Value alignedPtr, - ArrayRef sizes, ArrayRef strides, - ConversionPatternRewriter &rewriter) const; -}; - -/// Utility class for operation conversions targeting the LLVM dialect that -/// match exactly one source operation. -template -class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { -public: - explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertToLLVMPattern(SourceOp::getOperationName(), - &typeConverter.getContext(), typeConverter, - benefit) {} - - /// Wrappers around the RewritePattern methods that pass the derived op type. 
- void rewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); - } - LogicalResult match(Operation *op) const final { - return match(cast(op)); - } - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), operands, rewriter); - } - - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. - virtual void rewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - llvm_unreachable("must override rewrite or matchAndRewrite"); - } - virtual LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } - virtual LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const { - if (succeeded(match(op))) { - rewrite(op, operands, rewriter); - return success(); - } - return failure(); - } - -private: - using ConvertToLLVMPattern::match; - using ConvertToLLVMPattern::matchAndRewrite; -}; +class RewritePatternSet; + +/// Collect a set of patterns to convert memory-related operations from the +/// Standard dialect to the LLVM dialect, excluding non-memory-related +/// operations and FuncOp. +void populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Collect a set of patterns to convert from the Standard dialect to the LLVM +/// dialect, excluding the memory-related operations. +void populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If +/// `emitCWrappers` is set, the pattern will also produce functions +/// that pass memref descriptors by pointer-to-structure in addition to the +/// default unpacked form. 
+void populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Collect the patterns to convert from the Standard dialect to LLVM. The +/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions +/// by reference meaning the references have to remain alive during the entire +/// pattern lifetime. +void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); /// Lowering for AllocOp and AllocaOp. struct AllocLikeOpLLVMLowering : public ConvertToLLVMPattern { @@ -226,64 +96,6 @@ ConversionPatternRewriter &rewriter) const override; }; -namespace LLVM { -namespace detail { -/// Replaces the given operation "op" with a new operation of type "targetOp" -/// and given operands. -LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); - -LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, - ValueRange operands, - LLVMTypeConverter &typeConverter, - ConversionPatternRewriter &rewriter); -} // namespace detail -} // namespace LLVM - -/// Generic implementation of one-to-one conversion from "SourceOp" to -/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent. -/// Upholds a convention that multi-result operations get converted into an -/// operation returning the LLVM IR structure type, in which case individual -/// values must be extracted from using LLVM::ExtractValueOp before being used. -template -class OneToOneConvertToLLVMPattern : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Super = OneToOneConvertToLLVMPattern; - - /// Converts the type of the result to an LLVM type, pass operands as is, - /// preserve attributes. 
- LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - return LLVM::detail::oneToOneRewrite(op, TargetOp::getOperationName(), - operands, *this->getTypeConverter(), - rewriter); - } -}; - -/// Basic lowering implementation to rewrite Ops with just one result to the -/// LLVM Dialect. This supports higher-dimensional vector types. -template -class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { -public: - using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - using Super = VectorConvertToLLVMPattern; - - LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - static_assert( - std::is_base_of, SourceOp>::value, - "expected single result op"); - return LLVM::detail::vectorOneToOneRewrite( - op, TargetOp::getOperationName(), operands, *this->getTypeConverter(), - rewriter); - } -}; - /// Derived class that automatically populates legalization information for /// different LLVM ops. class LLVMConversionTarget : public ConversionTarget { diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -12,38 +12,10 @@ #include namespace mlir { -class LLVMTypeConverter; class LowerToLLVMOptions; class ModuleOp; template class OperationPass; -class RewritePatternSet; -using OwningRewritePatternList = RewritePatternSet; - -/// Collect a set of patterns to convert memory-related operations from the -/// Standard dialect to the LLVM dialect, excluding non-memory-related -/// operations and FuncOp. 
-void populateStdToLLVMMemoryConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); - -/// Collect a set of patterns to convert from the Standard dialect to the LLVM -/// dialect, excluding the memory-related operations. -void populateStdToLLVMNonMemoryConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); - -/// Collect the default pattern to convert a FuncOp to the LLVM dialect. If -/// `emitCWrappers` is set, the pattern will also produce functions -/// that pass memref descriptors by pointer-to-structure in addition to the -/// default unpacked form. -void populateStdToLLVMFuncOpConversionPattern(LLVMTypeConverter &converter, - RewritePatternSet &patterns); - -/// Collect the patterns to convert from the Standard dialect to LLVM. The -/// conversion patterns capture the LLVMTypeConverter and the LowerToLLVMOptions -/// by reference meaning the references have to remain alive during the entire -/// pattern lifetime. -void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. 
/// stdlib malloc/free is used by default for allocating memrefs allocated with diff --git a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt --- a/mlir/lib/Conversion/GPUCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUCommon/CMakeLists.txt @@ -32,6 +32,7 @@ MLIRAsyncToLLVM MLIRGPUTransforms MLIRIR + MLIRLLVMCommonConversion MLIRLLVMIR MLIRPass MLIRSupport diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -17,7 +17,9 @@ #include "../PassDetail.h" #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/GPU/GPUDialect.h" diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt --- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt +++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt @@ -1,8 +1,10 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion LoweringOptions.cpp MemRefBuilder.cpp + Pattern.cpp StructBuilder.cpp TypeConverter.cpp + VectorPattern.cpp LINK_COMPONENTS Core diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -0,0 +1,269 @@ +//===- Pattern.cpp - Conversion pattern to the LLVM dialect ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/AffineMap.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPattern +//===----------------------------------------------------------------------===// + +ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, + MLIRContext *context, + LLVMTypeConverter &typeConverter, + PatternBenefit benefit) + : ConversionPattern(typeConverter, rootOpName, benefit, context) {} + +LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { + return static_cast( + ConversionPattern::getTypeConverter()); +} + +LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { + return *getTypeConverter()->getDialect(); +} + +Type ConvertToLLVMPattern::getIndexType() const { + return getTypeConverter()->getIndexType(); +} + +Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { + return IntegerType::get(&getTypeConverter()->getContext(), + getTypeConverter()->getPointerBitwidth(addressSpace)); +} + +Type ConvertToLLVMPattern::getVoidType() const { + return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); +} + +Type ConvertToLLVMPattern::getVoidPtrType() const { + return LLVM::LLVMPointerType::get( + IntegerType::get(&getTypeConverter()->getContext(), 8)); +} + +Value ConvertToLLVMPattern::createIndexAttrConstant(OpBuilder &builder, + Location loc, + Type resultType, + int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + +Value ConvertToLLVMPattern::createIndexConstant( + ConversionPatternRewriter &builder, Location loc, uint64_t value) const { + return createIndexAttrConstant(builder, loc, 
getIndexType(), value); +} + +Value ConvertToLLVMPattern::getStridedElementPtr( + Location loc, MemRefType type, Value memRefDesc, ValueRange indices, + ConversionPatternRewriter &rewriter) const { + + int64_t offset; + SmallVector strides; + auto successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + + MemRefDescriptor memRefDescriptor(memRefDesc); + Value base = memRefDescriptor.alignedPtr(rewriter, loc); + + Value index; + if (offset != 0) // Skip if offset is zero. + index = MemRefType::isDynamicStrideOrOffset(offset) + ? memRefDescriptor.offset(rewriter, loc) + : createIndexConstant(rewriter, loc, offset); + + for (int i = 0, e = indices.size(); i < e; ++i) { + Value increment = indices[i]; + if (strides[i] != 1) { // Skip if stride is 1. + Value stride = MemRefType::isDynamicStrideOrOffset(strides[i]) + ? memRefDescriptor.stride(rewriter, loc, i) + : createIndexConstant(rewriter, loc, strides[i]); + increment = rewriter.create(loc, increment, stride); + } + index = + index ? rewriter.create(loc, index, increment) : increment; + } + + Type elementPtrType = memRefDescriptor.getElementPtrType(); + return index ? rewriter.create(loc, elementPtrType, base, index) + : base; +} + +// Check if the MemRefType `type` is supported by the lowering. We currently +// only support memrefs with identity maps. 
+bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( + MemRefType type) const { + if (!typeConverter->convertType(type.getElementType())) + return false; + return type.getAffineMaps().empty() || + llvm::all_of(type.getAffineMaps(), + [](AffineMap map) { return map.isIdentity(); }); +} + +Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { + auto elementType = type.getElementType(); + auto structElementType = typeConverter->convertType(elementType); + return LLVM::LLVMPointerType::get(structElementType, + type.getMemorySpaceAsInt()); +} + +void ConvertToLLVMPattern::getMemRefDescriptorSizes( + Location loc, MemRefType memRefType, ValueRange dynamicSizes, + ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, + SmallVectorImpl &strides, Value &sizeBytes) const { + assert(isConvertibleAndHasIdentityMaps(memRefType) && + "layout maps must have been normalized away"); + assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == + static_cast(dynamicSizes.size()) && + "dynamicSizes size doesn't match dynamic sizes count in memref shape"); + + sizes.reserve(memRefType.getRank()); + unsigned dynamicIndex = 0; + for (int64_t size : memRefType.getShape()) { + sizes.push_back(size == ShapedType::kDynamicSize + ? dynamicSizes[dynamicIndex++] + : createIndexConstant(rewriter, loc, size)); + } + + // Strides: iterate sizes in reverse order and multiply. 
+ int64_t stride = 1; + Value runningStride = createIndexConstant(rewriter, loc, 1); + strides.resize(memRefType.getRank()); + for (auto i = memRefType.getRank(); i-- > 0;) { + strides[i] = runningStride; + + int64_t size = memRefType.getShape()[i]; + if (size == 0) + continue; + bool useSizeAsStride = stride == 1; + if (size == ShapedType::kDynamicSize) + stride = ShapedType::kDynamicSize; + if (stride != ShapedType::kDynamicSize) + stride *= size; + + if (useSizeAsStride) + runningStride = sizes[i]; + else if (stride == ShapedType::kDynamicSize) + runningStride = + rewriter.create(loc, runningStride, sizes[i]); + else + runningStride = createIndexConstant(rewriter, loc, stride); + } + + // Buffer size in bytes. + Type elementPtrType = getElementPtrType(memRefType); + Value nullPtr = rewriter.create(loc, elementPtrType); + Value gepPtr = rewriter.create( + loc, elementPtrType, ArrayRef{nullPtr, runningStride}); + sizeBytes = rewriter.create(loc, getIndexType(), gepPtr); +} + +Value ConvertToLLVMPattern::getSizeInBytes( + Location loc, Type type, ConversionPatternRewriter &rewriter) const { + // Compute the size of an individual element. This emits the MLIR equivalent + // of the following sizeof(...) implementation in LLVM IR: + // %0 = getelementptr %elementType* null, %indexType 1 + // %1 = ptrtoint %elementType* %0 to %indexType + // which is a common pattern of getting the size of a type in bytes. + auto convertedPtrType = + LLVM::LLVMPointerType::get(typeConverter->convertType(type)); + auto nullPtr = rewriter.create(loc, convertedPtrType); + auto gep = rewriter.create( + loc, convertedPtrType, + ArrayRef{nullPtr, createIndexConstant(rewriter, loc, 1)}); + return rewriter.create(loc, getIndexType(), gep); +} + +Value ConvertToLLVMPattern::getNumElements( + Location loc, ArrayRef shape, + ConversionPatternRewriter &rewriter) const { + // Compute the total number of memref elements. + Value numElements = + shape.empty() ? 
createIndexConstant(rewriter, loc, 1) : shape.front(); + for (unsigned i = 1, e = shape.size(); i < e; ++i) + numElements = rewriter.create(loc, numElements, shape[i]); + return numElements; +} + +/// Creates and populates the memref descriptor struct given all its fields. +MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( + Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, + ArrayRef sizes, ArrayRef strides, + ConversionPatternRewriter &rewriter) const { + auto structType = typeConverter->convertType(memRefType); + auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); + + // Field 1: Allocated pointer, used for malloc/free. + memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); + + // Field 2: Actual aligned pointer to payload. + memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); + + // Field 3: Offset in aligned pointer. + memRefDescriptor.setOffset(rewriter, loc, + createIndexConstant(rewriter, loc, 0)); + + // Fields 4: Sizes. + for (auto en : llvm::enumerate(sizes)) + memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); + + // Field 5: Strides. + for (auto en : llvm::enumerate(strides)) + memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); + + return memRefDescriptor; +} + +//===----------------------------------------------------------------------===// +// Detail methods +//===----------------------------------------------------------------------===// + +/// Replaces the given operation "op" with a new operation of type "targetOp" +/// and given operands. 
+LogicalResult LLVM::detail::oneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + unsigned numResults = op->getNumResults(); + + Type packedType; + if (numResults != 0) { + packedType = typeConverter.packFunctionResults(op->getResultTypes()); + if (!packedType) + return failure(); + } + + // Create the operation through state since we don't know its C++ type. + OperationState state(op->getLoc(), targetOp); + state.addTypes(packedType); + state.addOperands(operands); + state.addAttributes(op->getAttrs()); + Operation *newOp = rewriter.createOperation(state); + + // If the operation produced 0 or 1 result, return them immediately. + if (numResults == 0) + return rewriter.eraseOp(op), success(); + if (numResults == 1) + return rewriter.replaceOp(op, newOp->getResult(0)), success(); + + // Otherwise, it had been converted to an operation producing a structure. + // Extract individual results from the structure and return them as list. + SmallVector results; + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + auto type = typeConverter.convertType(op->getResult(i).getType()); + results.push_back(rewriter.create( + op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); + } + rewriter.replaceOp(op, results); + return success(); +} diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp @@ -0,0 +1,142 @@ +//===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" + +using namespace mlir; + +// For >1-D vector types, extracts the necessary information to iterate over all +// 1-D subvectors in the underlying LLVM representation of the n-D vector. +// Iterates on the LLVM array type until it hits a non-array type; if that type +// is not an LLVM-compatible vector type, `llvm1DVectorTy` is left null. +LLVM::detail::NDVectorTypeInfo +LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, + LLVMTypeConverter &converter) { + assert(vectorType.getRank() > 1 && "expected >1D vector type"); + NDVectorTypeInfo info; + info.llvmNDVectorTy = converter.convertType(vectorType); + if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { + info.llvmNDVectorTy = nullptr; + return info; + } + info.arraySizes.reserve(vectorType.getRank() - 1); + auto llvmTy = info.llvmNDVectorTy; + while (llvmTy.isa()) { + info.arraySizes.push_back( + llvmTy.cast().getNumElements()); + llvmTy = llvmTy.cast().getElementType(); + } + if (!LLVM::isCompatibleVectorType(llvmTy)) + return info; + info.llvm1DVectorTy = llvmTy; + return info; +} + +// Express `linearIndex` in terms of coordinates of `basis`. +// Returns the empty vector when linearIndex is out of the range [0, P) where +// P is the product of all the basis coordinates. +// +// Prerequisites: +// Basis is an array of positive integers (signed type inherited from +// vector shape type).
+SmallVector LLVM::detail::getCoordinates(ArrayRef basis, + unsigned linearIndex) { + SmallVector res; + res.reserve(basis.size()); + for (unsigned basisElement : llvm::reverse(basis)) { + res.push_back(linearIndex % basisElement); + linearIndex = linearIndex / basisElement; + } + if (linearIndex > 0) + return {}; + std::reverse(res.begin(), res.end()); + return res; +} + +// Iterate over all linear indices, convert each one to coordinates and invoke +// `fun` with the corresponding position attribute. +void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, + OpBuilder &builder, + function_ref fun) { + unsigned ub = 1; + for (auto s : info.arraySizes) + ub *= s; + for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { + auto coords = getCoordinates(info.arraySizes, linearIndex); + // Linear index is out of bounds, we are done. + if (coords.empty()) + break; + assert(coords.size() == info.arraySizes.size()); + auto position = builder.getI64ArrayAttr(coords); + fun(position); + } +} + +LogicalResult LLVM::detail::handleMultidimensionalVectors( + Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, + std::function createOperand, + ConversionPatternRewriter &rewriter) { + auto resultNDVectorType = op->getResult(0).getType().cast(); + + SmallVector operand1DVectorTypes; + for (Value operand : op->getOperands()) { + auto operandNDVectorType = operand.getType().cast(); + auto operandTypeInfo = + extractNDVectorTypeInfo(operandNDVectorType, typeConverter); + operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy); + } + auto resultTypeInfo = + extractNDVectorTypeInfo(resultNDVectorType, typeConverter); + auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; + auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; + auto loc = op->getLoc(); + Value desc = rewriter.create(loc, resultNDVectoryTy); + nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) { + // For this unrolled `position` corresponding to the `linearIndex`^th + //
element, extract operand vectors + SmallVector extractedOperands; + for (auto operand : llvm::enumerate(operands)) { + extractedOperands.push_back(rewriter.create( + loc, operand1DVectorTypes[operand.index()], operand.value(), + position)); + } + Value newVal = createOperand(result1DVectorTy, extractedOperands); + desc = rewriter.create(loc, resultNDVectoryTy, desc, + newVal, position); + }); + rewriter.replaceOp(op, desc); + return success(); +} + +LogicalResult LLVM::detail::vectorOneToOneRewrite( + Operation *op, StringRef targetOp, ValueRange operands, + LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { + assert(!operands.empty()); + + // Cannot convert ops if their operands are not of LLVM type. + if (!llvm::all_of(operands.getTypes(), + [](Type t) { return isCompatibleType(t); })) + return failure(); + + auto llvmNDVectorTy = operands[0].getType(); + if (!llvmNDVectorTy.isa()) + return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); + + auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, + ValueRange operands) { + OperationState state(op->getLoc(), targetOp); + state.addTypes(llvm1DVectorTy); + state.addOperands(operands); + state.addAttributes(op->getAttrs()); + return rewriter.createOperation(state)->getResult(0); + }; + + return handleMultidimensionalVectors(op, operands, typeConverter, callback, + rewriter); +} diff --git a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/OpenMPToLLVM/CMakeLists.txt @@ -13,6 +13,7 @@ LINK_LIBS PUBLIC MLIRIR + MLIRLLVMCommonConversion MLIRLLVMIR MLIROpenMP MLIRStandardToLLVM diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -9,7 +9,9 @@ #include 
"mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" #include "../PassDetail.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -13,6 +13,8 @@ #include "../PassDetail.h" #include "mlir/Analysis/DataLayoutAnalysis.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/VectorPattern.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" @@ -46,214 +48,6 @@ #define PASS_NAME "convert-std-to-llvm" -ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, - MLIRContext *context, - LLVMTypeConverter &typeConverter, - PatternBenefit benefit) - : ConversionPattern(typeConverter, rootOpName, benefit, context) {} - - -LLVMTypeConverter *ConvertToLLVMPattern::getTypeConverter() const { - return static_cast( - ConversionPattern::getTypeConverter()); -} - -LLVM::LLVMDialect &ConvertToLLVMPattern::getDialect() const { - return *getTypeConverter()->getDialect(); -} - -Type ConvertToLLVMPattern::getIndexType() const { - return getTypeConverter()->getIndexType(); -} - -Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { - return IntegerType::get(&getTypeConverter()->getContext(), - getTypeConverter()->getPointerBitwidth(addressSpace)); -} - -Type ConvertToLLVMPattern::getVoidType() const { - return LLVM::LLVMVoidType::get(&getTypeConverter()->getContext()); -} - -Type ConvertToLLVMPattern::getVoidPtrType() const { - 
return LLVM::LLVMPointerType::get( - IntegerType::get(&getTypeConverter()->getContext(), 8)); -} - -// Creates a constant Op producing a value of `resultType` from an index-typed -// integer attribute. -static Value createIndexAttrConstant(OpBuilder &builder, Location loc, - Type resultType, int64_t value) { - return builder.create( - loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); -} - -Value ConvertToLLVMPattern::createIndexConstant( - ConversionPatternRewriter &builder, Location loc, uint64_t value) const { - return createIndexAttrConstant(builder, loc, getIndexType(), value); -} - -Value ConvertToLLVMPattern::getStridedElementPtr( - Location loc, MemRefType type, Value memRefDesc, ValueRange indices, - ConversionPatternRewriter &rewriter) const { - - int64_t offset; - SmallVector strides; - auto successStrides = getStridesAndOffset(type, strides, offset); - assert(succeeded(successStrides) && "unexpected non-strided memref"); - (void)successStrides; - - MemRefDescriptor memRefDescriptor(memRefDesc); - Value base = memRefDescriptor.alignedPtr(rewriter, loc); - - Value index; - if (offset != 0) // Skip if offset is zero. - index = MemRefType::isDynamicStrideOrOffset(offset) - ? memRefDescriptor.offset(rewriter, loc) - : createIndexConstant(rewriter, loc, offset); - - for (int i = 0, e = indices.size(); i < e; ++i) { - Value increment = indices[i]; - if (strides[i] != 1) { // Skip if stride is 1. - Value stride = MemRefType::isDynamicStrideOrOffset(strides[i]) - ? memRefDescriptor.stride(rewriter, loc, i) - : createIndexConstant(rewriter, loc, strides[i]); - increment = rewriter.create(loc, increment, stride); - } - index = - index ? rewriter.create(loc, index, increment) : increment; - } - - Type elementPtrType = memRefDescriptor.getElementPtrType(); - return index ? rewriter.create(loc, elementPtrType, base, index) - : base; -} - -// Check if the MemRefType `type` is supported by the lowering. 
We currently -// only support memrefs with identity maps. -bool ConvertToLLVMPattern::isConvertibleAndHasIdentityMaps( - MemRefType type) const { - if (!typeConverter->convertType(type.getElementType())) - return false; - return type.getAffineMaps().empty() || - llvm::all_of(type.getAffineMaps(), - [](AffineMap map) { return map.isIdentity(); }); -} - -Type ConvertToLLVMPattern::getElementPtrType(MemRefType type) const { - auto elementType = type.getElementType(); - auto structElementType = typeConverter->convertType(elementType); - return LLVM::LLVMPointerType::get(structElementType, - type.getMemorySpaceAsInt()); -} - -void ConvertToLLVMPattern::getMemRefDescriptorSizes( - Location loc, MemRefType memRefType, ValueRange dynamicSizes, - ConversionPatternRewriter &rewriter, SmallVectorImpl &sizes, - SmallVectorImpl &strides, Value &sizeBytes) const { - assert(isConvertibleAndHasIdentityMaps(memRefType) && - "layout maps must have been normalized away"); - assert(count(memRefType.getShape(), ShapedType::kDynamicSize) == - static_cast(dynamicSizes.size()) && - "dynamicSizes size doesn't match dynamic sizes count in memref shape"); - - sizes.reserve(memRefType.getRank()); - unsigned dynamicIndex = 0; - for (int64_t size : memRefType.getShape()) { - sizes.push_back(size == ShapedType::kDynamicSize - ? dynamicSizes[dynamicIndex++] - : createIndexConstant(rewriter, loc, size)); - } - - // Strides: iterate sizes in reverse order and multiply. 
- int64_t stride = 1; - Value runningStride = createIndexConstant(rewriter, loc, 1); - strides.resize(memRefType.getRank()); - for (auto i = memRefType.getRank(); i-- > 0;) { - strides[i] = runningStride; - - int64_t size = memRefType.getShape()[i]; - if (size == 0) - continue; - bool useSizeAsStride = stride == 1; - if (size == ShapedType::kDynamicSize) - stride = ShapedType::kDynamicSize; - if (stride != ShapedType::kDynamicSize) - stride *= size; - - if (useSizeAsStride) - runningStride = sizes[i]; - else if (stride == ShapedType::kDynamicSize) - runningStride = - rewriter.create(loc, runningStride, sizes[i]); - else - runningStride = createIndexConstant(rewriter, loc, stride); - } - - // Buffer size in bytes. - Type elementPtrType = getElementPtrType(memRefType); - Value nullPtr = rewriter.create(loc, elementPtrType); - Value gepPtr = rewriter.create( - loc, elementPtrType, ArrayRef{nullPtr, runningStride}); - sizeBytes = rewriter.create(loc, getIndexType(), gepPtr); -} - -Value ConvertToLLVMPattern::getSizeInBytes( - Location loc, Type type, ConversionPatternRewriter &rewriter) const { - // Compute the size of an individual element. This emits the MLIR equivalent - // of the following sizeof(...) implementation in LLVM IR: - // %0 = getelementptr %elementType* null, %indexType 1 - // %1 = ptrtoint %elementType* %0 to %indexType - // which is a common pattern of getting the size of a type in bytes. - auto convertedPtrType = - LLVM::LLVMPointerType::get(typeConverter->convertType(type)); - auto nullPtr = rewriter.create(loc, convertedPtrType); - auto gep = rewriter.create( - loc, convertedPtrType, - ArrayRef{nullPtr, createIndexConstant(rewriter, loc, 1)}); - return rewriter.create(loc, getIndexType(), gep); -} - -Value ConvertToLLVMPattern::getNumElements( - Location loc, ArrayRef shape, - ConversionPatternRewriter &rewriter) const { - // Compute the total number of memref elements. - Value numElements = - shape.empty() ? 
createIndexConstant(rewriter, loc, 1) : shape.front(); - for (unsigned i = 1, e = shape.size(); i < e; ++i) - numElements = rewriter.create(loc, numElements, shape[i]); - return numElements; -} - -/// Creates and populates the memref descriptor struct given all its fields. -MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( - Location loc, MemRefType memRefType, Value allocatedPtr, Value alignedPtr, - ArrayRef sizes, ArrayRef strides, - ConversionPatternRewriter &rewriter) const { - auto structType = typeConverter->convertType(memRefType); - auto memRefDescriptor = MemRefDescriptor::undef(rewriter, loc, structType); - - // Field 1: Allocated pointer, used for malloc/free. - memRefDescriptor.setAllocatedPtr(rewriter, loc, allocatedPtr); - - // Field 2: Actual aligned pointer to payload. - memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); - - // Field 3: Offset in aligned pointer. - memRefDescriptor.setOffset(rewriter, loc, - createIndexConstant(rewriter, loc, 0)); - - // Fields 4: Sizes. - for (auto en : llvm::enumerate(sizes)) - memRefDescriptor.setSize(rewriter, loc, en.index(), en.value()); - - // Field 5: Strides. - for (auto en : llvm::enumerate(strides)) - memRefDescriptor.setStride(rewriter, loc, en.index(), en.value()); - - return memRefDescriptor; -} - /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. @@ -572,190 +366,6 @@ } }; -//////////////// Support for Lowering operations on n-D vectors //////////////// -// Helper struct to "unroll" operations on n-D vectors in terms of operations on -// 1-D LLVM vectors. -struct NDVectorTypeInfo { - // LLVM array struct which encodes n-D vectors. - Type llvmNDVectorTy; - // LLVM vector type which encodes the inner 1-D vector type. - Type llvm1DVectorTy; - // Multiplicity of llvmNDVectorTy to llvm1DVectorTy. 
- SmallVector arraySizes; -}; -} // namespace - -// For >1-D vector types, extracts the necessary information to iterate over all -// 1-D subvectors in the underlying llrepresentation of the n-D vector -// Iterates on the llvm array type until we hit a non-array type (which is -// asserted to be an llvm vector type). -static NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, - LLVMTypeConverter &converter) { - assert(vectorType.getRank() > 1 && "expected >1D vector type"); - NDVectorTypeInfo info; - info.llvmNDVectorTy = converter.convertType(vectorType); - if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { - info.llvmNDVectorTy = nullptr; - return info; - } - info.arraySizes.reserve(vectorType.getRank() - 1); - auto llvmTy = info.llvmNDVectorTy; - while (llvmTy.isa()) { - info.arraySizes.push_back( - llvmTy.cast().getNumElements()); - llvmTy = llvmTy.cast().getElementType(); - } - if (!LLVM::isCompatibleVectorType(llvmTy)) - return info; - info.llvm1DVectorTy = llvmTy; - return info; -} - -// Express `linearIndex` in terms of coordinates of `basis`. -// Returns the empty vector when linearIndex is out of the range [0, P] where -// P is the product of all the basis coordinates. -// -// Prerequisites: -// Basis is an array of nonnegative integers (signed type inherited from -// vector shape type). -static SmallVector getCoordinates(ArrayRef basis, - unsigned linearIndex) { - SmallVector res; - res.reserve(basis.size()); - for (unsigned basisElement : llvm::reverse(basis)) { - res.push_back(linearIndex % basisElement); - linearIndex = linearIndex / basisElement; - } - if (linearIndex > 0) - return {}; - std::reverse(res.begin(), res.end()); - return res; -} - -// Iterate of linear index, convert to coords space and insert splatted 1-D -// vector in each position. 
-template -void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, - Lambda fun) { - unsigned ub = 1; - for (auto s : info.arraySizes) - ub *= s; - for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { - auto coords = getCoordinates(info.arraySizes, linearIndex); - // Linear index is out of bounds, we are done. - if (coords.empty()) - break; - assert(coords.size() == info.arraySizes.size()); - auto position = builder.getI64ArrayAttr(coords); - fun(position); - } -} -////////////// End Support for Lowering operations on n-D vectors ////////////// - -/// Replaces the given operation "op" with a new operation of type "targetOp" -/// and given operands. -LogicalResult LLVM::detail::oneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - unsigned numResults = op->getNumResults(); - - Type packedType; - if (numResults != 0) { - packedType = typeConverter.packFunctionResults(op->getResultTypes()); - if (!packedType) - return failure(); - } - - // Create the operation through state since we don't know its C++ type. - OperationState state(op->getLoc(), targetOp); - state.addTypes(packedType); - state.addOperands(operands); - state.addAttributes(op->getAttrs()); - Operation *newOp = rewriter.createOperation(state); - - // If the operation produced 0 or 1 result, return them immediately. - if (numResults == 0) - return rewriter.eraseOp(op), success(); - if (numResults == 1) - return rewriter.replaceOp(op, newOp->getResult(0)), success(); - - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. 
- SmallVector results; - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - auto type = typeConverter.convertType(op->getResult(i).getType()); - results.push_back(rewriter.create( - op->getLoc(), type, newOp->getResult(0), rewriter.getI64ArrayAttr(i))); - } - rewriter.replaceOp(op, results); - return success(); -} - -static LogicalResult handleMultidimensionalVectors( - Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, - std::function createOperand, - ConversionPatternRewriter &rewriter) { - auto resultNDVectorType = op->getResult(0).getType().cast(); - - SmallVector operand1DVectorTypes; - for (Value operand : op->getOperands()) { - auto operandNDVectorType = operand.getType().cast(); - auto operandTypeInfo = - extractNDVectorTypeInfo(operandNDVectorType, typeConverter); - operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy); - } - auto resultTypeInfo = - extractNDVectorTypeInfo(resultNDVectorType, typeConverter); - auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; - auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; - auto loc = op->getLoc(); - Value desc = rewriter.create(loc, resultNDVectoryTy); - nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) { - // For this unrolled `position` corresponding to the `linearIndex`^th - // element, extract operand vectors - SmallVector extractedOperands; - for (auto operand : llvm::enumerate(operands)) { - extractedOperands.push_back(rewriter.create( - loc, operand1DVectorTypes[operand.index()], operand.value(), - position)); - } - Value newVal = createOperand(result1DVectorTy, extractedOperands); - desc = rewriter.create(loc, resultNDVectoryTy, desc, - newVal, position); - }); - rewriter.replaceOp(op, desc); - return success(); -} - -LogicalResult LLVM::detail::vectorOneToOneRewrite( - Operation *op, StringRef targetOp, ValueRange operands, - LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { - assert(!operands.empty()); - 
- // Cannot convert ops if their operands are not of LLVM type. - if (!llvm::all_of(operands.getTypes(), - [](Type t) { return isCompatibleType(t); })) - return failure(); - - auto llvmNDVectorTy = operands[0].getType(); - if (!llvmNDVectorTy.isa()) - return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); - - auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, - ValueRange operands) { - OperationState state(op->getLoc(), targetOp); - state.addTypes(llvm1DVectorTy); - state.addOperands(operands); - state.addAttributes(op->getAttrs()); - return rewriter.createOperation(state)->getResult(0); - }; - - return handleMultidimensionalVectors(op, operands, typeConverter, callback, - rewriter); -} - -namespace { // Straightforward lowerings. using AbsFOpLowering = VectorConvertToLLVMPattern; using AddFOpLowering = VectorConvertToLLVMPattern; @@ -1427,7 +1037,7 @@ if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); - return handleMultidimensionalVectors( + return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), operands, *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( @@ -1482,7 +1092,7 @@ if (!vectorType) return rewriter.notifyMatchFailure(op, "expected vector result type"); - return handleMultidimensionalVectors( + return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), operands, *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( @@ -1536,7 +1146,7 @@ if (!vectorType) return failure(); - return handleMultidimensionalVectors( + return LLVM::detail::handleMultidimensionalVectors( op.getOperation(), operands, *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { auto splatAttr = SplatElementsAttr::get( @@ -2244,7 +1854,7 @@ if (!vectorType) return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type"); - return 
handleMultidimensionalVectors( + return LLVM::detail::handleMultidimensionalVectors( cmpiOp.getOperation(), operands, *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { CmpIOpAdaptor transformed(operands); @@ -2282,7 +1892,7 @@ if (!vectorType) return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type"); - return handleMultidimensionalVectors( + return LLVM::detail::handleMultidimensionalVectors( cmpfOp.getOperation(), operands, *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { CmpFOpAdaptor transformed(operands); @@ -2441,7 +2051,7 @@ // First insert it into an undef vector so we can shuffle it. auto loc = splatOp.getLoc(); auto vectorTypeInfo = - extractNDVectorTypeInfo(resultType, *getTypeConverter()); + LLVM::detail::extractNDVectorTypeInfo(resultType, *getTypeConverter()); auto llvmNDVectorTy = vectorTypeInfo.llvmNDVectorTy; auto llvm1DVectorTy = vectorTypeInfo.llvm1DVectorTy; if (!llvmNDVectorTy || !llvm1DVectorTy)