Index: mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h =================================================================== --- mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h +++ mlir/include/mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h @@ -8,6 +8,7 @@ #ifndef MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ #define MLIR_CONVERSION_GPUTONVVM_GPUTONVVMPASS_H_ +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include namespace mlir { @@ -24,9 +25,11 @@ void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); -/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. -std::unique_ptr> -createLowerGpuOpsToNVVMOpsPass(); +/// Creates a pass that lowers GPU dialect operations to NVVM counterparts. The +/// index bitwidth used for the lowering of the device side index computations +/// is configurable. +std::unique_ptr> createLowerGpuOpsToNVVMOpsPass( + unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout); } // namespace mlir Index: mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h =================================================================== --- mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h +++ mlir/include/mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h @@ -8,6 +8,7 @@ #ifndef MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ #define MLIR_CONVERSION_GPUTOROCDL_GPUTOROCDLPASS_H_ +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include namespace mlir { @@ -25,9 +26,12 @@ void populateGpuToROCDLConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns); -/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. +/// Creates a pass that lowers GPU dialect operations to ROCDL counterparts. The +/// index bitwidth used for the lowering of the device side index computations +/// is configurable. 
std::unique_ptr> -createLowerGpuOpsToROCDLOpsPass(); +createLowerGpuOpsToROCDLOpsPass( + unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout); } // namespace mlir Index: mlir/include/mlir/Conversion/Passes.td =================================================================== --- mlir/include/mlir/Conversion/Passes.td +++ mlir/include/mlir/Conversion/Passes.td @@ -100,6 +100,11 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> { let summary = "Generate NVVM operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()"; + let options = [ + Option<"indexBitwidth", "index-bitwidth", "unsigned", + /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + "Bitwidth of the index type, 0 to use size of machine word"> + ]; } //===----------------------------------------------------------------------===// @@ -109,6 +114,11 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> { let summary = "Generate ROCDL operations for gpu operations"; let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()"; + let options = [ + Option<"indexBitwidth", "index-bitwidth", "unsigned", + /*default=kDeriveIndexBitwidthFromDataLayout*/"0", + "Bitwidth of the index type, 0 to use size of machine word"> + ]; } //===----------------------------------------------------------------------===// Index: mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h =================================================================== --- mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -15,6 +15,7 @@ #ifndef MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H #define MLIR_CONVERSION_STANDARDTOLLVM_CONVERTSTANDARDTOLLVM_H +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Transforms/DialectConversion.h" namespace llvm { @@ -35,22 +36,6 @@ class LLVMType; } // namespace LLVM -/// 
Set of callbacks that allows the customization of LLVMTypeConverter. -struct LLVMTypeConverterCustomization { - using CustomCallback = std::function &)>; - - /// Customize the type conversion of function arguments. - CustomCallback funcArgConverter; - - /// Used to determine the bitwidth of the LLVM integer type that the index - /// type gets lowered to. Defaults to deriving the size from the data layout. - unsigned indexBitwidth; - - /// Initialize customization to default callbacks. - LLVMTypeConverterCustomization(); -}; - /// Callback to convert function argument types. It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing @@ -75,13 +60,11 @@ public: using TypeConverter::convertType; - /// Create an LLVMTypeConverter using the default - /// LLVMTypeConverterCustomization. + /// Create an LLVMTypeConverter using the default LowerToLLVMOptions. LLVMTypeConverter(MLIRContext *ctx); - /// Create an LLVMTypeConverter using 'custom' customizations. - LLVMTypeConverter(MLIRContext *ctx, - const LLVMTypeConverterCustomization &custom); + /// Create an LLVMTypeConverter using custom LowerToLLVMOptions. + LLVMTypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options); /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` @@ -127,7 +110,7 @@ LLVM::LLVMType getIndexType(); /// Gets the bitwidth of the index type when converted to LLVM. - unsigned getIndexTypeBitwidth() { return customizations.indexBitwidth; } + unsigned getIndexTypeBitwidth() { return options.indexBitwidth; } protected: /// LLVM IR module used to parse/create types. @@ -193,8 +176,8 @@ // Convert a 1D vector type into an LLVM vector type. Type convertVectorType(VectorType type); - /// Callbacks for customizing the type conversion. 
- LLVMTypeConverterCustomization customizations; + /// Options for customizing the llvm lowering. + LowerToLLVMOptions options; }; /// Helper class to produce LLVM dialect operations extracting or inserting @@ -389,11 +372,17 @@ }; /// Base class for operation conversions targeting the LLVM IR dialect. Provides -/// conversion patterns with access to an LLVMTypeConverter. +/// conversion patterns with access to an LLVMTypeConverter and the +/// LowerToLLVMOptions. class ConvertToLLVMPattern : public ConversionPattern { public: ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, + const LowerToLLVMOptions &options = { + /*useBarePtrCallConv=*/false, + /*emitCWrappers=*/false, + /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout, + /*useAlignedAlloc=*/false}, PatternBenefit benefit = 1); /// Returns the LLVM dialect. @@ -445,6 +434,9 @@ protected: /// Reference to the type converter, with potential extensions. LLVMTypeConverter &typeConverter; + + /// Lowering options, copied: the ctor's default argument is a temporary.
+ const LowerToLLVMOptions options; }; /// Utility class for operation conversions targeting the LLVM dialect that @@ -453,10 +445,11 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, + const LowerToLLVMOptions &options, PatternBenefit benefit = 1) : ConvertToLLVMPattern(OpTy::getOperationName(), &typeConverter.getContext(), typeConverter, - benefit) {} + options, benefit) {} }; namespace LLVM { Index: mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h =================================================================== --- mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h +++ mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h @@ -14,54 +14,50 @@ namespace mlir { class LLVMTypeConverter; class ModuleOp; -template class OperationPass; +template +class OperationPass; class OwningRewritePatternList; +/// Value to pass as bitwidth for the index type when the converter is expected +/// to derive the bitwidth from the LLVM data layout. +static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0; + +struct LowerToLLVMOptions { + bool useBarePtrCallConv = false; + bool emitCWrappers = false; + unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout; + /// Use aligned_alloc for heap allocations. + bool useAlignedAlloc = false; +}; + /// Collect a set of patterns to convert memory-related operations from the /// Standard dialect to the LLVM dialect, excluding non-memory-related /// operations and FuncOp. void populateStdToLLVMMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlignedAlloc); + const LowerToLLVMOptions &options); /// Collect a set of patterns to convert from the Standard dialect to the LLVM /// dialect, excluding the memory-related operations.
void populateStdToLLVMNonMemoryConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + const LowerToLLVMOptions &options); /// Collect the default pattern to convert a FuncOp to the LLVM dialect. If /// `emitCWrappers` is set, the pattern will also produce functions /// that pass memref descriptors by pointer-to-structure in addition to the /// default unpacked form. -void populateStdToLLVMDefaultFuncOpConversionPattern( +void populateStdToLLVMFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool emitCWrappers = false); + const LowerToLLVMOptions &options); -/// Collect a set of default patterns to convert from the Standard dialect to -/// LLVM. -void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter, - OwningRewritePatternList &patterns, - bool emitCWrappers = false, - bool useAlignedAlloc = false); - -/// Collect a set of patterns to convert from the Standard dialect to -/// LLVM using the bare pointer calling convention for MemRef function -/// arguments. -void populateStdToLLVMBarePtrConversionPatterns( +/// Collect the patterns to convert from the Standard dialect to LLVM. +void populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlignedAlloc); - -/// Value to pass as bitwidth for the index type when the converter is expected -/// to derive the bitwidth from the LLVM data layout. -static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0; - -struct LowerToLLVMOptions { - bool useBarePtrCallConv = false; - bool emitCWrappers = false; - unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout; - /// Use aligned_alloc for heap allocations. 
- bool useAlignedAlloc = false; -}; + const LowerToLLVMOptions &options = { + /*useBarePtrCallConv=*/false, /*emitCWrappers=*/false, + /*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout, + /*useAlignedAlloc=*/false}); /// Creates a pass to convert the Standard dialect into the LLVMIR dialect. /// stdlib malloc/free is used by default for allocating memrefs allocated with Index: mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp =================================================================== --- mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -30,7 +30,6 @@ namespace { - struct GPUShuffleOpLowering : public ConvertToLLVMPattern { explicit GPUShuffleOpLowering(LLVMTypeConverter &lowering_) : ConvertToLLVMPattern(gpu::ShuffleOp::getOperationName(), @@ -97,17 +96,27 @@ /// /// This pass only handles device code and is not meant to be run on GPU host /// code. -class LowerGpuOpsToNVVMOpsPass +struct LowerGpuOpsToNVVMOpsPass : public ConvertGpuOpsToNVVMOpsBase { -public: + LowerGpuOpsToNVVMOpsPass() = default; + LowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) { + this->indexBitwidth = indexBitwidth; + } + void runOnOperation() override { gpu::GPUModuleOp m = getOperation(); + /// Customize the bitwidth used for the device side index computations. + LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false, + /*emitCWrappers = */ true, + /*indexBitwidth =*/indexBitwidth, + /*useAlignedAlloc =*/false}; + /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory /// space 5 for private memory attributions, but NVVM represents private /// memory allocations as local `alloca`s in the default address space. This /// converter drops the private memory space to support the use case above. 
- LLVMTypeConverter converter(m.getContext()); + LLVMTypeConverter converter(m.getContext(), options); converter.addConversion([&](MemRefType type) -> Optional { if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace()) return llvm::None; @@ -176,6 +185,6 @@ } std::unique_ptr> -mlir::createLowerGpuOpsToNVVMOpsPass() { - return std::make_unique(); +mlir::createLowerGpuOpsToNVVMOpsPass(unsigned indexBitwidth) { + return std::make_unique(indexBitwidth); } Index: mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp =================================================================== --- mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -41,13 +41,22 @@ // // This pass only handles device code and is not meant to be run on GPU host // code. -class LowerGpuOpsToROCDLOpsPass +struct LowerGpuOpsToROCDLOpsPass : public ConvertGpuOpsToROCDLOpsBase { -public: + LowerGpuOpsToROCDLOpsPass() = default; + LowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) { + this->indexBitwidth = indexBitwidth; + } + void runOnOperation() override { gpu::GPUModuleOp m = getOperation(); - LLVMTypeConverter converter(m.getContext()); + /// Customize the bitwidth used for the device side index computations. 
+ LowerToLLVMOptions options = {/*useBarePtrCallConv =*/false, + /*emitCWrappers = */ true, + /*indexBitwidth =*/indexBitwidth, + /*useAlignedAlloc =*/false}; + LLVMTypeConverter converter(m.getContext(), options); OwningRewritePatternList patterns; @@ -106,6 +115,6 @@ } std::unique_ptr> -mlir::createLowerGpuOpsToROCDLOpsPass() { - return std::make_unique(); +mlir::createLowerGpuOpsToROCDLOpsPass(unsigned indexBitwidth) { + return std::make_unique(indexBitwidth); } Index: mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp =================================================================== --- mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -51,11 +51,6 @@ return wrappedLLVMType; } -/// Initialize customization to default callbacks. -LLVMTypeConverterCustomization::LLVMTypeConverterCustomization() - : funcArgConverter(structFuncArgTypeConverter), - indexBitwidth(kDeriveIndexBitwidthFromDataLayout) {} - /// Callback to convert function argument types. It converts a MemRef function /// argument to a list of non-aggregate types containing descriptor /// information, and an UnrankedmemRef function argument to a list containing @@ -122,20 +117,19 @@ return success(); } -/// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization. +/// Create an LLVMTypeConverter using default LowerToLLVMOptions. LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) - : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {} + : LLVMTypeConverter(ctx, LowerToLLVMOptions()) {} -/// Create an LLVMTypeConverter using 'custom' customizations. -LLVMTypeConverter::LLVMTypeConverter( - MLIRContext *ctx, const LLVMTypeConverterCustomization &customs) +/// Create an LLVMTypeConverter using custom LowerToLLVMOptions. 
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, + const LowerToLLVMOptions &options_) : llvmDialect(ctx->getRegisteredDialect()), - customizations(customs) { + options(options_) { assert(llvmDialect && "LLVM IR dialect is not registered"); module = &llvmDialect->getLLVMModule(); - if (customizations.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) - customizations.indexBitwidth = - module->getDataLayout().getPointerSizeInBits(); + if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout) + options.indexBitwidth = module->getDataLayout().getPointerSizeInBits(); // Register conversions for the standard types. addConversion([&](ComplexType type) { return convertComplexType(type); }); @@ -262,11 +256,15 @@ LLVM::LLVMType LLVMTypeConverter::convertFunctionSignature( FunctionType type, bool isVariadic, LLVMTypeConverter::SignatureConversion &result) { + // Select the argument converter depending on the calling convetion. + auto funcArgConverter = options.useBarePtrCallConv + ? barePtrFuncArgTypeConverter + : structFuncArgTypeConverter; // Convert argument types one by one and check for errors. for (auto &en : llvm::enumerate(type.getInputs())) { Type type = en.value(); SmallVector converted; - if (failed(customizations.funcArgConverter(*this, type, converted))) + if (failed(funcArgConverter(*this, type, converted))) return {}; result.addInputs(en.index(), converted); } @@ -397,9 +395,10 @@ ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter_, + const LowerToLLVMOptions &options_, PatternBenefit benefit) : ConversionPattern(rootOpName, benefit, typeConverter_, context), - typeConverter(typeConverter_) {} + typeConverter(typeConverter_), options(options_) {} /*============================================================================*/ /* StructBuilder implementation */ @@ -1051,8 +1050,10 @@ /// information. 
static constexpr StringRef kEmitIfaceAttrName = "llvm.emit_c_interface"; struct FuncOpConversion : public FuncOpConversionBase { - FuncOpConversion(LLVMTypeConverter &converter, bool emitCWrappers) - : FuncOpConversionBase(converter), emitWrappers(emitCWrappers) {} + FuncOpConversion(LLVMTypeConverter &converter, + const LowerToLLVMOptions &options) + : FuncOpConversionBase(converter, options) {} + using ConvertOpToLLVMPattern::options; LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -1063,7 +1064,7 @@ if (!newFuncOp) return failure(); - if (emitWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { + if (options.emitCWrappers || funcOp.getAttrOfType(kEmitIfaceAttrName)) { if (newFuncOp.isExternal()) wrapExternalFunction(rewriter, op->getLoc(), typeConverter, funcOp, newFuncOp); @@ -1075,11 +1076,6 @@ rewriter.eraseOp(op); return success(); } - -private: - /// If true, also create the adaptor functions having signatures compatible - /// with those produced by clang. - const bool emitWrappers; }; /// FuncOp legalization pattern that converts MemRef arguments to bare pointers @@ -1506,11 +1502,11 @@ using ConvertOpToLLVMPattern::getIndexType; using ConvertOpToLLVMPattern::typeConverter; using ConvertOpToLLVMPattern::getVoidPtrType; + using ConvertOpToLLVMPattern::options; explicit AllocLikeOpLowering(LLVMTypeConverter &converter, - bool useAlignedAlloc = false) - : ConvertOpToLLVMPattern(converter), - useAlignedAlloc(useAlignedAlloc) {} + const LowerToLLVMOptions &options) + : ConvertOpToLLVMPattern(converter, options) {} LogicalResult match(Operation *op) const override { MemRefType memRefType = cast(op).getType(); @@ -1677,7 +1673,7 @@ /// allocation size to be a multiple of alignment, Optional getAllocationAlignment(AllocOp allocOp) const { // No alignment can be used for the 'malloc' call itself. 
- if (!useAlignedAlloc) + if (!options.useAlignedAlloc) return None; if (allocOp.alignment()) @@ -1849,16 +1845,14 @@ } protected: - /// Use aligned_alloc instead of malloc for all heap allocations. - bool useAlignedAlloc; /// The minimum alignment to use with aligned_alloc (has to be a power of 2). uint64_t kMinAlignedAllocAlignment = 16UL; }; struct AllocOpLowering : public AllocLikeOpLowering { explicit AllocOpLowering(LLVMTypeConverter &converter, - bool useAlignedAlloc = false) - : AllocLikeOpLowering(converter, useAlignedAlloc) {} + const LowerToLLVMOptions &options) + : AllocLikeOpLowering(converter, options) {} }; using AllocaOpLowering = AllocLikeOpLowering; @@ -1939,8 +1933,9 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - explicit DeallocOpLowering(LLVMTypeConverter &converter) - : ConvertOpToLLVMPattern(converter) {} + explicit DeallocOpLowering(LLVMTypeConverter &converter, + const LowerToLLVMOptions &options) + : ConvertOpToLLVMPattern(converter, options) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, @@ -2960,7 +2955,8 @@ /// Collect a set of patterns to convert from the Standard dialect to LLVM. 
void mlir::populateStdToLLVMNonMemoryConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, + const LowerToLLVMOptions &options) { // FIXME: this should be tablegen'ed // clang-format off patterns.insert< @@ -3023,13 +3019,13 @@ UnsignedRemIOpLowering, UnsignedShiftRightOpLowering, XOrOpLowering, - ZeroExtendIOpLowering>(converter); + ZeroExtendIOpLowering>(converter, options); // clang-format on } void mlir::populateStdToLLVMMemoryConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlignedAlloc) { + const LowerToLLVMOptions &options) { // clang-format off patterns.insert< AssumeAlignmentOpLowering, @@ -3039,41 +3035,26 @@ MemRefCastOpLowering, StoreOpLowering, SubViewOpLowering, - ViewOpLowering>(converter); - patterns.insert< - AllocOpLowering - >(converter, useAlignedAlloc); + ViewOpLowering, + AllocOpLowering>(converter, options); // clang-format on } -void mlir::populateStdToLLVMDefaultFuncOpConversionPattern( +void mlir::populateStdToLLVMFuncOpConversionPattern( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool emitCWrappers) { - patterns.insert(converter, emitCWrappers); + const LowerToLLVMOptions &options) { + if (options.useBarePtrCallConv) + patterns.insert(converter, options); + else + patterns.insert(converter, options); } void mlir::populateStdToLLVMConversionPatterns( LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool emitCWrappers, bool useAlignedAlloc) { - populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns, - emitCWrappers); - populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); - populateStdToLLVMMemoryConversionPatterns(converter, patterns, - useAlignedAlloc); -} - -static void populateStdToLLVMBarePtrFuncOpConversionPattern( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { - patterns.insert(converter); -} - 
-void mlir::populateStdToLLVMBarePtrConversionPatterns( - LLVMTypeConverter &converter, OwningRewritePatternList &patterns, - bool useAlignedAlloc) { - populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns); - populateStdToLLVMNonMemoryConversionPatterns(converter, patterns); - populateStdToLLVMMemoryConversionPatterns(converter, patterns, - useAlignedAlloc); + const LowerToLLVMOptions &options) { + populateStdToLLVMFuncOpConversionPattern(converter, patterns, options); + populateStdToLLVMNonMemoryConversionPatterns(converter, patterns, options); + populateStdToLLVMMemoryConversionPatterns(converter, patterns, options); } // Create an LLVM IR structure type if there is more than one result. @@ -3163,19 +3144,12 @@ ModuleOp m = getOperation(); - LLVMTypeConverterCustomization customs; - customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter - : structFuncArgTypeConverter; - customs.indexBitwidth = indexBitwidth; - LLVMTypeConverter typeConverter(&getContext(), customs); + LowerToLLVMOptions options = {useBarePtrCallConv, emitCWrappers, + indexBitwidth, useAlignedAlloc}; + LLVMTypeConverter typeConverter(&getContext(), options); OwningRewritePatternList patterns; - if (useBarePtrCallConv) - populateStdToLLVMBarePtrConversionPatterns(typeConverter, patterns, - useAlignedAlloc); - else - populateStdToLLVMConversionPatterns(typeConverter, patterns, - emitCWrappers, useAlignedAlloc); + populateStdToLLVMConversionPatterns(typeConverter, patterns, options); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, patterns))) Index: mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir =================================================================== --- mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir +++ mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir @@ -1,36 +1,52 @@ // RUN: mlir-opt %s -convert-gpu-to-nvvm -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-nvvm='index-bitwidth=32' 
-split-input-file | FileCheck --check-prefix=CHECK32 %s gpu.module @test_module { // CHECK-LABEL: func @gpu_index_ops() + // CHECK32-LABEL: func @gpu_index_ops() func @gpu_index_ops() -> (index, index, index, index, index, index, index, index, index, index, index, index) { + // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 + // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %gDimX = "gpu.grid_dim"() {dimension = "x"} : () 
-> (index) // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index) // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ, @@ -42,6 +58,21 @@ // ----- +gpu.module @test_module { + // CHECK-LABEL: func @gpu_index_comp + // CHECK32-LABEL: func @gpu_index_comp + func @gpu_index_comp(%idx : index) -> index { + // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64 + // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32 + %0 = addi %idx, %idx : index + // CHECK: llvm.return %{{.*}} : !llvm.i64 + // CHECK32: llvm.return %{{.*}} : !llvm.i32 + std.return %0 : index + } +} + +// ----- + gpu.module @test_module { // CHECK-LABEL: func @gpu_all_reduce_op() gpu.func @gpu_all_reduce_op() { Index: mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir =================================================================== --- mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -1,36 +1,52 @@ // RUN: mlir-opt %s -convert-gpu-to-rocdl -split-input-file | FileCheck %s +// RUN: mlir-opt %s -convert-gpu-to-rocdl='index-bitwidth=32' -split-input-file | FileCheck --check-prefix=CHECK32 %s gpu.module @test_module { // CHECK-LABEL: func @gpu_index_ops() + // CHECK32-LABEL: func @gpu_index_ops() func @gpu_index_ops() -> (index, index, index, index, index, index, index, index, index, index, index, index) { + // CHECK32-NOT: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 + // CHECK: rocdl.workitem.id.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index) // CHECK: rocdl.workitem.id.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> 
(index) // CHECK: rocdl.workitem.id.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index) // CHECK: rocdl.workgroup.dim.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index) // CHECK: rocdl.workgroup.dim.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index) // CHECK: rocdl.workgroup.dim.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index) // CHECK: rocdl.workgroup.id.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index) // CHECK: rocdl.workgroup.id.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index) // CHECK: rocdl.workgroup.id.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index) // CHECK: rocdl.grid.dim.x : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index) // CHECK: rocdl.grid.dim.y : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index) // CHECK: rocdl.grid.dim.z : !llvm.i32 + // CHECK: = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64 %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index) std.return %tIdX, %tIdY, %tIdZ, %bDimX, %bDimY, %bDimZ, @@ -42,6 +58,21 @@ // ----- +gpu.module @test_module { + // CHECK-LABEL: func @gpu_index_comp + // CHECK32-LABEL: func @gpu_index_comp + func @gpu_index_comp(%idx : index) -> index { + // CHECK: = llvm.add %{{.*}}, %{{.*}} : !llvm.i64 + // CHECK32: = llvm.add %{{.*}}, %{{.*}} : !llvm.i32 + %0 = addi %idx, %idx : index + // CHECK: llvm.return 
%{{.*}} : !llvm.i64 + // CHECK32: llvm.return %{{.*}} : !llvm.i32 + std.return %0 : index + } +} + +// ----- + gpu.module @test_module { // CHECK-LABEL: func @gpu_sync() func @gpu_sync() {