diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -65,10 +65,6 @@ static Value createIndexAttrConstant(OpBuilder &builder, Location loc, Type resultType, int64_t value); - /// Create an LLVM dialect operation defining the given index constant. - Value createIndexConstant(ConversionPatternRewriter &builder, Location loc, - uint64_t value) const; - // This is a strided getElementPtr variant that linearizes subscripts as: // `base_offset + index_0 * stride_0 + ... + index_n * stride_n`. Value getStridedElementPtr(Location loc, MemRefType type, Value memRefDesc, @@ -155,9 +151,9 @@ ConversionPatternRewriter &rewriter) const final { if constexpr (SourceOp::hasProperties()) return rewrite(cast(op), - OpAdaptor(operands, op->getDiscardableAttrDictionary(), - cast(op).getProperties()), - rewriter); + OpAdaptor(operands, op->getDiscardableAttrDictionary(), + cast(op).getProperties()), + rewriter); rewrite(cast(op), OpAdaptor(operands, op->getDiscardableAttrDictionary()), rewriter); } diff --git a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h --- a/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h +++ b/mlir/include/mlir/Conversion/MemRefToLLVM/AllocLikeConversion.h @@ -15,7 +15,7 @@ /// Lowering for memory allocation ops. struct AllocationOpLLVMLowering : public ConvertToLLVMPattern { - using ConvertToLLVMPattern::createIndexConstant; + using ConvertToLLVMPattern::createIndexAttrConstant; using ConvertToLLVMPattern::getIndexType; using ConvertToLLVMPattern::getVoidPtrType; @@ -43,7 +43,9 @@ MemRefType memRefType = op.getType(); Value alignment; if (auto alignmentAttr = op.getAlignment()) { - alignment = createIndexConstant(rewriter, loc, *alignmentAttr); + Type indexType = getIndexType(); + alignment = + createIndexAttrConstant(rewriter, loc, indexType, *alignmentAttr); } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { // In the case where no alignment is specified, we may want to override // `malloc's` behavior. `malloc` typically aligns at the size of the diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -168,7 +168,7 @@ Value lowHalf = rewriter.create(loc, llvmI32, ptrAsInt); resource = rewriter.create( loc, llvm4xI32, resource, lowHalf, - this->createIndexConstant(rewriter, loc, 0)); + this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 0)); // Bits 48-63 are used both for the stride of the buffer and (on gfx10) for // enabling swizzling. Prevent the high bits of pointers from accidentally @@ -180,7 +180,7 @@ createI32Constant(rewriter, loc, 0x0000ffff)); resource = rewriter.create( loc, llvm4xI32, resource, highHalfTruncated, - this->createIndexConstant(rewriter, loc, 1)); + this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 1)); Value numRecords; if (memrefType.hasStaticShape()) { @@ -202,7 +202,7 @@ } resource = rewriter.create( loc, llvm4xI32, resource, numRecords, - this->createIndexConstant(rewriter, loc, 2)); + this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 2)); // Final word: // bits 0-11: dst sel, ignored by these intrinsics @@ -227,7 +227,7 @@ Value word3Const = createI32Constant(rewriter, loc, word3); resource = rewriter.create( loc, llvm4xI32, resource, word3Const, - this->createIndexConstant(rewriter, loc, 3)); + this->createIndexAttrConstant(rewriter, loc, this->getIndexType(), 3)); args.push_back(resource); // Indexing (voffset) diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -67,9 +67,10 @@ protected: Value getNumElements(ConversionPatternRewriter &rewriter, Location loc, MemRefType type, MemRefDescriptor desc) const { + Type indexType = ConvertToLLVMPattern::getIndexType(); return type.hasStaticShape() - ? ConvertToLLVMPattern::createIndexConstant( - rewriter, loc, type.getNumElements()) + ? ConvertToLLVMPattern::createIndexAttrConstant( + rewriter, loc, indexType, type.getNumElements()) // For identity maps (verified by caller), the number of // elements is stride[0] * size[0]. : rewriter.create(loc, diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -60,11 +60,6 @@ builder.getIndexAttr(value)); } -Value ConvertToLLVMPattern::createIndexConstant( - ConversionPatternRewriter &builder, Location loc, uint64_t value) const { - return createIndexAttrConstant(builder, loc, getIndexType(), value); -} - Value ConvertToLLVMPattern::getStridedElementPtr( Location loc, MemRefType type, Value memRefDesc, ValueRange indices, ConversionPatternRewriter &rewriter) const { @@ -79,13 +74,15 @@ Value base = memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type); + Type indexType = getIndexType(); Value index; for (int i = 0, e = indices.size(); i < e; ++i) { Value increment = indices[i]; if (strides[i] != 1) { // Skip if stride is 1. - Value stride = ShapedType::isDynamic(strides[i]) - ? memRefDescriptor.stride(rewriter, loc, i) - : createIndexConstant(rewriter, loc, strides[i]); + Value stride = + ShapedType::isDynamic(strides[i]) + ? memRefDescriptor.stride(rewriter, loc, i) + : createIndexAttrConstant(rewriter, loc, indexType, strides[i]); increment = rewriter.create(loc, increment, stride); } index = @@ -130,15 +127,17 @@ sizes.reserve(memRefType.getRank()); unsigned dynamicIndex = 0; + Type indexType = getIndexType(); for (int64_t size : memRefType.getShape()) { - sizes.push_back(size == ShapedType::kDynamic - ? dynamicSizes[dynamicIndex++] - : createIndexConstant(rewriter, loc, size)); + sizes.push_back( + size == ShapedType::kDynamic + ? dynamicSizes[dynamicIndex++] + : createIndexAttrConstant(rewriter, loc, indexType, size)); } // Strides: iterate sizes in reverse order and multiply. int64_t stride = 1; - Value runningStride = createIndexConstant(rewriter, loc, 1); + Value runningStride = createIndexAttrConstant(rewriter, loc, indexType, 1); strides.resize(memRefType.getRank()); for (auto i = memRefType.getRank(); i-- > 0;) { strides[i] = runningStride; @@ -158,7 +157,7 @@ runningStride = rewriter.create(loc, runningStride, sizes[i]); else - runningStride = createIndexConstant(rewriter, loc, stride); + runningStride = createIndexAttrConstant(rewriter, loc, indexType, stride); } if (sizeInBytes) { // Buffer size in bytes. @@ -195,22 +194,25 @@ static_cast(dynamicSizes.size()) && "dynamicSizes size doesn't match dynamic sizes count in memref shape"); + Type indexType = getIndexType(); Value numElements = memRefType.getRank() == 0 - ? createIndexConstant(rewriter, loc, 1) + ? createIndexAttrConstant(rewriter, loc, indexType, 1) : nullptr; unsigned dynamicIndex = 0; // Compute the total number of memref elements. for (int64_t staticSize : memRefType.getShape()) { if (numElements) { - Value size = staticSize == ShapedType::kDynamic - ? dynamicSizes[dynamicIndex++] - : createIndexConstant(rewriter, loc, staticSize); + Value size = + staticSize == ShapedType::kDynamic + ? dynamicSizes[dynamicIndex++] + : createIndexAttrConstant(rewriter, loc, indexType, staticSize); numElements = rewriter.create(loc, numElements, size); } else { - numElements = staticSize == ShapedType::kDynamic - ? dynamicSizes[dynamicIndex++] - : createIndexConstant(rewriter, loc, staticSize); + numElements = + staticSize == ShapedType::kDynamic + ? dynamicSizes[dynamicIndex++] + : createIndexAttrConstant(rewriter, loc, indexType, staticSize); } } return numElements; @@ -231,8 +233,9 @@ memRefDescriptor.setAlignedPtr(rewriter, loc, alignedPtr); // Field 3: Offset in aligned pointer. - memRefDescriptor.setOffset(rewriter, loc, - createIndexConstant(rewriter, loc, 0)); + Type indexType = getIndexType(); + memRefDescriptor.setOffset( + rewriter, loc, createIndexAttrConstant(rewriter, loc, indexType, 0)); // Fields 4: Sizes. for (const auto &en : llvm::enumerate(sizes)) diff --git a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/AllocLikeConversion.cpp @@ -138,7 +138,8 @@ Value AllocationOpLLVMLowering::allocateBufferAutoAlign( ConversionPatternRewriter &rewriter, Location loc, Value sizeBytes, Operation *op, const DataLayout *defaultLayout, int64_t alignment) const { - Value allocAlignment = createIndexConstant(rewriter, loc, alignment); + Value allocAlignment = + createIndexAttrConstant(rewriter, loc, getIndexType(), alignment); MemRefType memRefType = getMemRefResultType(op); // Function aligned_alloc requires size to be a multiple of alignment; we pad diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -160,11 +160,12 @@ auto computeNumElements = [&](MemRefType type, function_ref getDynamicSize) -> Value { // Compute number of elements. + Type indexType = ConvertToLLVMPattern::getIndexType(); Value numElements = type.isDynamicDim(0) ? getDynamicSize() - : createIndexConstant(rewriter, loc, type.getDimSize(0)); - Type indexType = getIndexType(); + : createIndexAttrConstant(rewriter, loc, indexType, + type.getDimSize(0)); if (numElements.getType() != indexType) numElements = typeConverter->materializeTargetConversion( rewriter, loc, indexType, numElements); @@ -482,7 +483,8 @@ // The size value that we have to extract can be obtained using GEPop with // `dimOp.index() + 1` index argument. Value idxPlusOne = rewriter.create( - loc, createIndexConstant(rewriter, loc, 1), adaptor.getIndex()); + loc, createIndexAttrConstant(rewriter, loc, getIndexType(), 1), + adaptor.getIndex()); Value sizePtr = rewriter.create( loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); @@ -508,6 +510,7 @@ // Take advantage if index is constant. MemRefType memRefType = cast(operandType); + Type indexType = getIndexType(); if (std::optional index = getConstantDimIndex(dimOp)) { int64_t i = *index; if (i >= 0 && i < memRefType.getRank()) { @@ -518,7 +521,7 @@ } // Use constant for static size. int64_t dimSize = memRefType.getDimSize(i); - return createIndexConstant(rewriter, loc, dimSize); + return createIndexAttrConstant(rewriter, loc, indexType, dimSize); } } Value index = adaptor.getIndex(); @@ -717,7 +720,11 @@ // This is called after a type conversion, which would have failed if this // call fails. - unsigned memSpace = *getTypeConverter()->getMemRefAddressSpace(type); + std::optional maybeAddressSpace = + getTypeConverter()->getMemRefAddressSpace(type); + if (!maybeAddressSpace) + return std::make_tuple(Value(), Value()); + unsigned memSpace = *maybeAddressSpace; Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter()); Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace); @@ -826,8 +833,10 @@ return success(); } if (auto rankedMemRefType = dyn_cast(operandType)) { - rewriter.replaceOp( - op, {createIndexConstant(rewriter, loc, rankedMemRefType.getRank())}); + Type indexType = getIndexType(); + rewriter.replaceOp(op, + {createIndexAttrConstant(rewriter, loc, indexType, + rankedMemRefType.getRank())}); return success(); } return failure(); @@ -1351,29 +1360,31 @@ assert(targetMemRefType.getLayout().isIdentity() && "Identity layout map is a precondition of a valid reshape op"); + Type indexType = getIndexType(); Value stride = nullptr; int64_t targetRank = targetMemRefType.getRank(); for (auto i : llvm::reverse(llvm::seq(0, targetRank))) { if (!ShapedType::isDynamic(strides[i])) { // If the stride for this dimension is dynamic, then use the product // of the sizes of the inner dimensions. - stride = createIndexConstant(rewriter, loc, strides[i]); + stride = + createIndexAttrConstant(rewriter, loc, indexType, strides[i]); } else if (!stride) { // `stride` is null only in the first iteration of the loop. However, // since the target memref has an identity layout, we can safely set // the innermost stride to 1. - stride = createIndexConstant(rewriter, loc, 1); + stride = createIndexAttrConstant(rewriter, loc, indexType, 1); } Value dimSize; // If the size of this dimension is dynamic, then load it at runtime // from the shape operand. if (!targetMemRefType.isDynamicDim(i)) { - dimSize = createIndexConstant(rewriter, loc, - targetMemRefType.getDimSize(i)); + dimSize = createIndexAttrConstant(rewriter, loc, indexType, + targetMemRefType.getDimSize(i)); } else { Value shapeOp = reshapeOp.getShape(); - Value index = createIndexConstant(rewriter, loc, i); + Value index = createIndexAttrConstant(rewriter, loc, indexType, i); dimSize = rewriter.create(loc, shapeOp, index); Type indexType = getIndexType(); if (dimSize.getType() != indexType) @@ -1444,7 +1455,7 @@ Value targetStridesBase = UnrankedMemRefDescriptor::strideBasePtr( rewriter, loc, *getTypeConverter(), targetSizesBase, resultRank); Value shapeOperandPtr = shapeDesc.alignedPtr(rewriter, loc); - Value oneIndex = createIndexConstant(rewriter, loc, 1); + Value oneIndex = createIndexAttrConstant(rewriter, loc, getIndexType(), 1); Value resultRankMinusOne = rewriter.create(loc, resultRank, oneIndex); @@ -1466,7 +1477,7 @@ Value indexArg = condBlock->getArgument(0); Value strideArg = condBlock->getArgument(1); - Value zeroIndex = createIndexConstant(rewriter, loc, 0); + Value zeroIndex = createIndexAttrConstant(rewriter, loc, indexType, 0); Value pred = rewriter.create( loc, IntegerType::get(rewriter.getContext(), 1), LLVM::ICmpPredicate::sge, indexArg, zeroIndex); @@ -1604,11 +1615,11 @@ // Build and return the value for the idx^th shape dimension, either by // returning the constant shape dimension or counting the proper dynamic size. Value getSize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef shape, ValueRange dynamicSizes, - unsigned idx) const { + ArrayRef shape, ValueRange dynamicSizes, unsigned idx, + Type indexType) const { assert(idx < shape.size()); if (!ShapedType::isDynamic(shape[idx])) - return createIndexConstant(rewriter, loc, shape[idx]); + return createIndexAttrConstant(rewriter, loc, indexType, shape[idx]); // Count the number of dynamic dims in range [0, idx] unsigned nDynamic = llvm::count_if(shape.take_front(idx), ShapedType::isDynamic); @@ -1621,16 +1632,16 @@ // result returned by this function. Value getStride(ConversionPatternRewriter &rewriter, Location loc, ArrayRef strides, Value nextSize, - Value runningStride, unsigned idx) const { + Value runningStride, unsigned idx, Type indexType) const { assert(idx < strides.size()); if (!ShapedType::isDynamic(strides[idx])) - return createIndexConstant(rewriter, loc, strides[idx]); + return createIndexAttrConstant(rewriter, loc, indexType, strides[idx]); if (nextSize) return runningStride ? rewriter.create(loc, runningStride, nextSize) : nextSize; assert(!runningStride); - return createIndexConstant(rewriter, loc, 1); + return createIndexAttrConstant(rewriter, loc, indexType, 1); } LogicalResult @@ -1697,11 +1708,13 @@ targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr); - // Field 3: The offset in the resulting type must be 0. This is because of - // the type change: an offset on srcType* may not be expressible as an - // offset on dstType*. - targetMemRef.setOffset(rewriter, loc, - createIndexConstant(rewriter, loc, offset)); + Type indexType = getIndexType(); + // Field 3: The offset in the resulting type must be 0. This is + // because of the type change: an offset on srcType* may not be + // expressible as an offset on dstType*. + targetMemRef.setOffset( + rewriter, loc, + createIndexAttrConstant(rewriter, loc, indexType, offset)); // Early exit for 0-D corner case. if (viewMemRefType.getRank() == 0) @@ -1712,10 +1725,11 @@ for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { // Update size. Value size = getSize(rewriter, loc, viewMemRefType.getShape(), - adaptor.getSizes(), i); + adaptor.getSizes(), i, indexType); targetMemRef.setSize(rewriter, loc, i, size); // Update stride. - stride = getStride(rewriter, loc, strides, nextSize, stride, i); + stride = + getStride(rewriter, loc, strides, nextSize, stride, i, indexType); targetMemRef.setStride(rewriter, loc, i, stride); nextSize = size; }