diff --git a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h --- a/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/LoweringOptions.h @@ -14,6 +14,7 @@ #ifndef MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H #define MLIR_CONVERSION_LLVMCOMMON_LOWERINGOPTIONS_H +#include "mlir/IR/BuiltinTypes.h" #include "llvm/IR/DataLayout.h" namespace mlir { @@ -66,6 +67,9 @@ /// Get the index bitwidth. unsigned getIndexBitwidth() const { return indexBitwidth; } + /// Hook to customize the conversion of MemRefType to LLVMType. + llvm::function_ref memrefIndexTypeConverter = nullptr; + private: unsigned indexBitwidth; }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -50,6 +50,11 @@ /// defined by the used type converter. Type getIndexType() const; + /// Gets the MLIR type wrapping the LLVM integer type whose bit width is + /// defined by the used type converter and matching the index type needed for + /// MemRefType `t`. + Type getIndexTypeMatchingMemRef(MemRefType t) const; + /// Gets the MLIR type wrapping the LLVM integer type whose bit width /// corresponds to that of a LLVM pointer type. Type getIntPtrType(unsigned addressSpace = 0) const; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h --- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h @@ -132,6 +132,11 @@ /// integer type with the size configured for this type converter. Type getIndexType(); + /// Gets the LLVM representation of the index type that matches the MemRefType + /// `t`. The returned type is an integer type with the size configured for + /// this type converter. + Type getIndexTypeMatchingMemRef(MemRefType t); + /// Returns true if using opaque pointers was enabled in the lowering options. bool useOpaquePointers() const { return getOptions().useOpaquePointers; } diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h --- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h @@ -43,7 +43,8 @@ MemRefType memRefType = op.getType(); Value alignment; if (auto alignmentAttr = op.getAlignment()) { - Type indexType = getIndexType(); + Type indexType = + ConvertToLLVMPattern::getIndexTypeMatchingMemRef(memRefType); alignment = createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td --- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td @@ -52,6 +52,14 @@ /// Returns the numeric value used to identify the private memory address /// space. static AddressSpace getPrivateAddressSpace() { return AddressSpace::Private; } + + /// Return true if the given MemRefType has an address space that is a + /// gpu::AddressSpaceAttr attribute with value 'workgroup`. 
+ static bool hasSharedMemoryAddressSpace(MemRefType type); + + /// Return true if the given Attribute has matches is a gpu::AddressSpaceAttr + /// attribute with value 'workgroup`. + static bool isSharedMemoryAddressSpace(Attribute type); }]; let dependentDialects = ["arith::ArithDialect"]; diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -67,7 +67,7 @@ protected: Value getNumElements(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, MemRefDescriptor desc) const { - Type indexType = ConvertToLLVMPattern::getIndexType(); + Type indexType = ConvertToLLVMPattern::getIndexTypeMatchingMemRef(type); return type.hasStaticShape() ? ConvertToLLVMPattern::createIndexAttrConstant( rewriter, loc, indexType, type.getNumElements()) @@ -654,10 +654,16 @@ } // namespace +static IntegerType getIndexTypeForMemRef(MemRefType t) { + int64_t numBits = gpu::GPUDialect::hasSharedMemoryAddressSpace(t) ? 32 : 64; + return IntegerType::get(t.getContext(), numBits); +} + void GpuToLLVMConversionPass::runOnOperation() { LowerToLLVMOptions options(&getContext()); options.useOpaquePointers = useOpaquePointers; options.useBarePtrCallConv = hostBarePtrCallConv; + options.memrefIndexTypeConverter = getIndexTypeForMemRef; LLVMTypeConverter converter(&getContext(), options); RewritePatternSet patterns(&getContext()); diff --git a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt --- a/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt +++ b/mlir/lib/Conversion/GPUToNVVM/CMakeLists.txt @@ -18,7 +18,9 @@ MLIRLLVMCommonConversion MLIRLLVMDialect MLIRMemRefToLLVM + MLIRNVGPUDialect MLIRNVVMDialect MLIRPass MLIRTransformUtils + MLIRVectorToLLVM ) diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -16,10 +16,12 @@ #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -27,6 +29,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -202,6 +205,14 @@ /// Import the GPU Ops to NVVM Patterns. #include "GPUToNVVM.cpp.inc" +static IntegerType getIndexTypeForMemRef(MemRefType t) { + int64_t numBits = (gpu::GPUDialect::hasSharedMemoryAddressSpace(t) || + nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(t)) + ? 
32 + : 64; + return IntegerType::get(t.getContext(), numBits); +} + /// A pass that replaces all occurrences of GPU device operations with their /// corresponding NVVM equivalent. /// @@ -232,6 +243,7 @@ options.overrideIndexBitwidth(indexBitwidth); options.useOpaquePointers = useOpaquePointers; options.useBarePtrCallConv = useBarePtrCallConv; + options.memrefIndexTypeConverter = getIndexTypeForMemRef; // Apply in-dialect lowering. In-dialect lowering will replace // ops which need to be lowered further, which is not supported by a @@ -271,6 +283,7 @@ arith::populateArithToLLVMConversionPatterns(converter, llvmPatterns); cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns); + populateVectorToLLVMConversionPatterns(converter, llvmPatterns); populateFuncToLLVMConversionPatterns(converter, llvmPatterns); populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns); populateGpuToNVVMConversionPatterns(converter, llvmPatterns); diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -19,6 +19,20 @@ // ConvertToLLVMPattern //===----------------------------------------------------------------------===// +static Value convertToDesiredIndexType(OpBuilder &b, Location loc, Value src, + Type desiredIndexType) { + assert(src.getType().isIntOrIndex() && !src.getType().isIndex() && + "expected int type"); + assert(desiredIndexType.isIntOrIndex() && !desiredIndexType.isIndex() && + "expected int type"); + if (src.getType() == desiredIndexType) + return src; + if (src.getType().getIntOrFloatBitWidth() < + desiredIndexType.getIntOrFloatBitWidth()) + return b.create(loc, desiredIndexType, src); + return b.create(loc, desiredIndexType, src); +} + ConvertToLLVMPattern::ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &typeConverter, @@ -38,6 +52,10 @@ return getTypeConverter()->getIndexType(); } +Type ConvertToLLVMPattern::getIndexTypeMatchingMemRef(MemRefType t) const { + return getTypeConverter()->getIndexTypeMatchingMemRef(t); +} + Type ConvertToLLVMPattern::getIntPtrType(unsigned addressSpace) const { return IntegerType::get(&getTypeConverter()->getContext(), getTypeConverter()->getPointerBitwidth(addressSpace)); @@ -74,7 +92,7 @@ Value base = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type); - Type indexType = getIndexType(); + Type indexType = getIndexTypeMatchingMemRef(type); Value index; for (int i = 0, e = indices.size(); i < e; ++i) { Value increment = indices[i]; @@ -83,8 +101,11 @@ ShapedType::isDynamic(strides[i]) ? memRefDescriptor.stride(rewriter, loc, i) : createIndexAttrConstant(rewriter, loc, indexType, strides[i]); + increment = + convertToDesiredIndexType(rewriter, loc, increment, indexType); increment = rewriter.create(loc, increment, stride); } + increment = convertToDesiredIndexType(rewriter, loc, increment, indexType); index = index ? 
rewriter.create(loc, index, increment) : increment; } @@ -127,7 +148,7 @@ sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; - Type indexType = getIndexType(); + Type indexType = getIndexTypeMatchingMemRef(memRefType); for (int64_t size : memRefType.getShape()) { sizes.push_back( size == ShapedType::kDynamic @@ -194,7 +215,7 @@ static_cast(dynamicSizes.size()) && "dynamicSizes size doesn't match dynamic sizes count in memref shape"); - Type indexType = getIndexType(); + Type indexType = getIndexTypeMatchingMemRef(memRefType); Value numElements = memRefType.getRank() == 0 ? createIndexAttrConstant(rewriter, loc, indexType, 1) : nullptr; @@ -233,7 +254,7 @@ memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); // Field 3: Offset in aligned pointer. - Type indexType = getIndexType(); + Type indexType = getIndexTypeMatchingMemRef(memRefType); memRefDescriptor.setOffset( rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0)); diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -174,6 +174,11 @@ return IntegerType::get(&getContext(), getIndexTypeBitwidth()); } +Type LLVMTypeConverter::getIndexTypeMatchingMemRef(MemRefType t) { + return options.memrefIndexTypeConverter ? options.memrefIndexTypeConverter(t) + : getIndexType(); +} + LLVM::LLVMPointerType LLVMTypeConverter::getPointerType(Type elementType, unsigned int addressSpace) { if (useOpaquePointers()) @@ -339,7 +344,7 @@ } auto ptrTy = getPointerType(elementType, *addressSpace); - auto indexTy = getIndexType(); + Type indexTy = getIndexTypeMatchingMemRef(type); SmallVector results = {ptrTy, ptrTy, indexTy}; auto rank = type.getRank(); @@ -358,7 +363,8 @@ // Compute the descriptor size given that of its components indicated above. unsigned space = *getMemRefAddressSpace(type); return 2 * llvm::divideCeil(getPointerBitwidth(space), 8) + - (1 + 2 * type.getRank()) * layout.getTypeSize(getIndexType()); + (1 + 2 * type.getRank()) * + layout.getTypeSize(getIndexTypeMatchingMemRef(type)); } /// Converts MemRefType to LLVMType. A MemRefType is converted to a struct that diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -160,7 +160,7 @@ auto computeNumElements = [&](MemRefType type, function_ref getDynamicSize) -> Value { // Compute number of elements. - Type indexType = ConvertToLLVMPattern::getIndexType(); + Type indexType = ConvertToLLVMPattern::getIndexTypeMatchingMemRef(type); Value numElements = type.isDynamicDim(0) ? getDynamicSize() @@ -483,7 +483,8 @@ // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. Value idxPlusOne = rewriter.create( - loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), + loc, + createIndexAttrConstant(rewriter, loc, adaptor.getIndex().getType(), 1), adaptor.getIndex()); Value sizePtr = rewriter.create( loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, @@ -510,7 +511,7 @@ // Take advantage if index is constant. 
MemRefType memRefType = cast(operandType); - Type indexType = getIndexType(); + Type indexType = getIndexTypeMatchingMemRef(memRefType); if (std::optional index = getConstantDimIndex(dimOp)) { int64_t i = *index; if (i >= 0 && i < memRefType.getRank()) { @@ -1360,7 +1361,7 @@ assert(targetMemRefType.getLayout().isIdentity() && "Identity layout map is a precondition of a valid reshape op"); - Type indexType = getIndexType(); + Type indexType = getIndexTypeMatchingMemRef(targetMemRefType); Value stride = nullptr; int64_t targetRank = targetMemRefType.getRank(); for (auto i : llvm::reverse(llvm::seq(0, targetRank))) { @@ -1455,7 +1456,8 @@ Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); - Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); + Value oneIndex = + createIndexAttrConstant(rewriter, loc, resultRank.getType(), 1); Value resultRankMinusOne = rewriter.create(loc, resultRank, oneIndex); @@ -1708,7 +1710,7 @@ targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); - Type indexType = getIndexType(); + auto indexType = targetMemRef.getIndexType(); // Field 3: The offset in the resulting type must be 0. This is // because of the type change: an offset on srcType* may not be // expressible as an offset on dstType*. diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -35,6 +35,22 @@ #include "mlir/Dialect/GPU/IR/GPUOpsDialect.cpp.inc" +/// Return true if the given MemRefType has an address space that is a +/// gpu::AddressSpaceAttr attribute with value 'workgroup`. +bool gpu::GPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { + return isSharedMemoryAddressSpace(type.getMemorySpace()); +} + +/// Return true if the given Attribute has matches is a gpu::AddressSpaceAttr +/// attribute with value 'workgroup`. +bool gpu::GPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { + if (!memorySpace) + return false; + if (auto gpuAttr = llvm::dyn_cast(memorySpace)) + return gpuAttr.getValue() == gpu::AddressSpace::Workgroup; + return false; +} + //===----------------------------------------------------------------------===// // GPU Device Mapping Attributes //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir b/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir --- a/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir +++ b/mlir/test/Conversion/GPUCommon/memory-attrbution.mlir @@ -76,14 +76,14 @@ // ROCDL-SAME: !llvm.ptr<3> // Populate the memref descriptor. 
- // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> + // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i32, array<1 x i32>, array<1 x i32>)> // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0] // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1] - // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 + // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i32 // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2] - // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i64 + // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i32 // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0] - // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 + // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i32 // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c1]], %[[descr5]][4, 0] // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<1 x i64>, array<1 x i64>)> @@ -137,22 +137,22 @@ // ROCDL-SAME: !llvm.ptr<3> // Populate the memref descriptor. - // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)> + // NVVM: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i32, array<3 x i32>, array<3 x i32>)> // NVVM: %[[descr2:.*]] = llvm.insertvalue %[[raw]], %[[descr1]][0] // NVVM: %[[descr3:.*]] = llvm.insertvalue %[[raw]], %[[descr2]][1] - // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i64 + // NVVM: %[[c0:.*]] = llvm.mlir.constant(0 : index) : i32 // NVVM: %[[descr4:.*]] = llvm.insertvalue %[[c0]], %[[descr3]][2] - // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i64 + // NVVM: %[[c4:.*]] = llvm.mlir.constant(4 : index) : i32 // NVVM: %[[descr5:.*]] = llvm.insertvalue %[[c4]], %[[descr4]][3, 0] - // NVVM: %[[c12:.*]] = llvm.mlir.constant(12 : index) : i64 + // NVVM: %[[c12:.*]] = llvm.mlir.constant(12 : index) : i32 // NVVM: %[[descr6:.*]] = llvm.insertvalue %[[c12]], %[[descr5]][4, 0] - // NVVM: %[[c2:.*]] = llvm.mlir.constant(2 : index) : i64 + // NVVM: %[[c2:.*]] = llvm.mlir.constant(2 : index) : i32 // NVVM: %[[descr7:.*]] = llvm.insertvalue %[[c2]], %[[descr6]][3, 1] - // NVVM: %[[c6:.*]] = llvm.mlir.constant(6 : index) : i64 + // NVVM: %[[c6:.*]] = llvm.mlir.constant(6 : index) : i32 // NVVM: %[[descr8:.*]] = llvm.insertvalue %[[c6]], %[[descr7]][4, 1] - // NVVM: %[[c6:.*]] = llvm.mlir.constant(6 : index) : i64 + // NVVM: %[[c6:.*]] = llvm.mlir.constant(6 : index) : i32 // NVVM: %[[descr9:.*]] = llvm.insertvalue %[[c6]], %[[descr8]][3, 2] - // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i64 + // NVVM: %[[c1:.*]] = llvm.mlir.constant(1 : index) : i32 // NVVM: %[[descr10:.*]] = llvm.insertvalue %[[c1]], %[[descr9]][4, 2] // ROCDL: %[[descr1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<3>, ptr<3>, i64, array<3 x i64>, array<3 x i64>)> diff --git a/mlir/test/Conversion/GPUToNVVM/typed-pointers.mlir b/mlir/test/Conversion/GPUToNVVM/typed-pointers.mlir --- a/mlir/test/Conversion/GPUToNVVM/typed-pointers.mlir +++ b/mlir/test/Conversion/GPUToNVVM/typed-pointers.mlir @@ -11,13 +11,15 @@ %i = arith.constant 16 : index %j = arith.constant 16 : index %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> - // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK: %[[INX64:.*]] = llvm.mlir.constant(16 : index) : i64 // CHECK: %{{.*}} = 
llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] - // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64 - // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64 - // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64 - // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[INX:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32 + // CHECK: %[[INX2:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX2]] : i32 + // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr, i32) -> !llvm.ptr // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] // CHECK-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> diff --git a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir --- a/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir +++ b/mlir/test/Conversion/GPUToNVVM/wmma-ops-to-nvvm.mlir @@ -11,13 +11,15 @@ %i = arith.constant 16 : index %j = arith.constant 16 : index %0 = gpu.subgroup_mma_load_matrix %wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xf16, 3> -> !gpu.mma_matrix<16x16xf16, "AOp"> - // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK: %[[INX64:.*]] = llvm.mlir.constant(16 : index) : i64 // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] - // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64 - // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64 - // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64 - // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16 + // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)> + // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[INX:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32 + // CHECK: %[[INX2:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX2]] : i32 + // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] // CHECK-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> @@ -50,13 +52,15 @@ %i = arith.constant 16 : index %j = arith.constant 16 : index %0 = gpu.subgroup_mma_load_matrix 
%wg[%i, %j] {leadDimension = 32 : index, transpose} : memref<32x32xi8, 3> -> !gpu.mma_matrix<16x16xsi8, "AOp"> - // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK: %[[INX64:.*]] = llvm.mlir.constant(16 : index) : i64 // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] - // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64 - // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64 - // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64 - // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, i8 + // CHECK: %[[BASE:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)> + // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[INX:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32 + // CHECK: %[[INX2:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX2]] : i32 + // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8 // CHECK: %[[LDM32:.*]] = llvm.mlir.constant(32 : index) : i32 // CHECK: %[[FRAG:.*]] = nvvm.wmma.load %[[ADDRESS]], %[[LDM32]] // CHECK-SAME: {eltype = #nvvm.mma_type, frag = #nvvm.mma_frag, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : (!llvm.ptr<3>) -> !llvm.struct<(i32, i32)> @@ -90,7 +94,7 @@ %i = arith.constant 16 : index %j = arith.constant 16 : index gpu.subgroup_mma_store_matrix %arg0, %sg[%i,%j] {leadDimension= 32 : index, transpose} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16, 3> - // CHECK: %[[INX:.*]] = llvm.mlir.constant(16 : index) : i64 + // CHECK: %[[INX64:.*]] = llvm.mlir.constant(16 : index) : i64 // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] // CHECK: %{{.*}} = llvm.insertvalue %{{.*}}, %{{.*}}[{{.*}}, {{.*}}] @@ -99,11 +103,13 @@ // CHECK: %[[EL2:.*]] = llvm.extractvalue %[[D]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[EL3:.*]] = llvm.extractvalue %[[D]][2] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> // CHECK: %[[EL4:.*]] = llvm.extractvalue %[[D]][3] : !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)> - // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i64 - // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i64 - // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX]] : i64 - // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i64) -> !llvm.ptr<3>, f16 + // CHECK: %[[BASE:.*]] = llvm.extractvalue %[[MEMREF]][1] : !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)> + // CHECK: %[[LDM:.*]] = llvm.mlir.constant(32 : index) : i32 + // CHECK: %[[INX:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LI:.*]] = llvm.mul %[[INX]], %[[LDM]] : i32 + // CHECK: %[[INX2:.*]] = llvm.trunc %[[INX64]] : i64 to i32 + // CHECK: %[[LIJ:.*]] = llvm.add %[[LI]], %[[INX2]] : i32 + // CHECK: %[[ADDRESS:.*]] = llvm.getelementptr %[[BASE]][%[[LIJ]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 // CHECK: %[[LDM32:.*]] = 
llvm.mlir.constant(32 : index) : i32 // CHECK: nvvm.wmma.store %[[ADDRESS]], %[[LDM32]], %[[EL1]], %[[EL2]], %[[EL3]], %[[EL4]] // CHECK-SAME: {eltype = #nvvm.mma_type, k = 16 : i32, layout = #nvvm.mma_layout, m = 16 : i32, n = 16 : i32} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16> diff --git a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir --- a/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir +++ b/mlir/test/Integration/Dialect/Vector/GPU/CUDA/test-warp-distribute.mlir @@ -1,7 +1,7 @@ // Run the test cases without distributing ops to test default lowering. Run // everything on the same thread. // RUN: mlir-opt %s -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \ -// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ +// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-arith-to-llvm \ // RUN: -gpu-kernel-outlining |\ // RUN: mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin))' |\ // RUN: mlir-opt -gpu-to-llvm -reconcile-unrealized-casts |\ @@ -14,7 +14,7 @@ // Run the same test cases with distribution and propagation. // RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write" \ // RUN: -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \ -// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ +// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-arith-to-llvm \ // RUN: -gpu-kernel-outlining |\ // RUN: mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin))' |\ // RUN: mlir-opt -gpu-to-llvm -reconcile-unrealized-casts |\ @@ -26,7 +26,7 @@ // RUN: mlir-opt %s -test-vector-warp-distribute="hoist-uniform distribute-transfer-write propagate-distribution" \ // RUN: -test-vector-warp-distribute=rewrite-warp-ops-to-scf-if -canonicalize | \ -// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-vector-to-llvm -convert-arith-to-llvm \ +// RUN: mlir-opt -convert-scf-to-cf -convert-cf-to-llvm -convert-arith-to-llvm \ // RUN: -gpu-kernel-outlining |\ // RUN: mlir-opt -pass-pipeline='builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,reconcile-unrealized-casts,gpu-to-cubin))' |\ // RUN: mlir-opt -gpu-to-llvm -reconcile-unrealized-casts |\ diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -4954,6 +4954,7 @@ ":MathDialect", ":MemRefDialect", ":MemRefToLLVM", + ":NVGPUDialect", ":NVVMDialect", ":Pass", ":Transforms",
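
Note (illustrative, not part of the patch): the new LowerToLLVMOptions::memrefIndexTypeConverter hook lets a lowering pipeline choose a narrower integer type for the index fields of a memref descriptor on a per-MemRefType basis, with LLVMTypeConverter::getIndexTypeMatchingMemRef falling back to the usual index type when no hook is installed. Below is a minimal sketch of how a downstream conversion pass might wire this up, mirroring the getIndexTypeForMemRef helpers added in this patch; the helper and pass-body names are hypothetical, and the hook's function_ref signature is assumed to be IntegerType(MemRefType), matching the callbacks this patch registers.

    #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
    #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
    #include "mlir/Dialect/GPU/IR/GPUDialect.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/IR/BuiltinTypes.h"

    using namespace mlir;

    // Use a 32-bit index type for memrefs placed in GPU shared (workgroup)
    // memory, whose extents are small enough to fit in 32 bits; keep the
    // default 64-bit index type for everything else. This mirrors the
    // getIndexTypeForMemRef callback in GpuToLLVMConversionPass above.
    static IntegerType chooseMemRefIndexType(MemRefType t) {
      int64_t numBits =
          gpu::GPUDialect::hasSharedMemoryAddressSpace(t) ? 32 : 64;
      return IntegerType::get(t.getContext(), numBits);
    }

    // Hypothetical helper showing where the hook is installed: once set on the
    // options, descriptor sizes/strides/offsets of shared-memory memrefs are
    // lowered to i32, and index computations feeding GEPs are truncated or
    // extended to match (see convertToDesiredIndexType in Pattern.cpp).
    static void configureGpuLowering(ModuleOp module) {
      MLIRContext *ctx = module.getContext();
      LowerToLLVMOptions options(ctx);
      options.memrefIndexTypeConverter = chooseMemRefIndexType;
      LLVMTypeConverter converter(ctx, options);
      // ... populate conversion patterns and run applyPartialConversion as in
      // the passes touched by this patch ...
      (void)converter;
    }

The effect on lowered IR is visible in the updated FileCheck tests above: shared-memory memref descriptors become !llvm.struct<(ptr<3>, ptr<3>, i32, array<2 x i32>, array<2 x i32>)> and 64-bit index operands are llvm.trunc'ed to i32 before address arithmetic, while global-memory memrefs keep i64 descriptor fields.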