diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -53,15 +53,12 @@ return elementType.getIntOrFloatBitWidth(); } -/// Returns the bit width of integer or vector value of LLVM or SPIR-V type -static unsigned getValueBitWidth(Value value) { - if (auto llvmType = value.getType().dyn_cast()) - return llvmType.isVectorTy() - ? llvmType.getVectorElementType() - .getUnderlyingType() - ->getIntegerBitWidth() - : llvmType.getUnderlyingType()->getIntegerBitWidth(); - return getBitWidth(value.getType()); +/// Returns the bit width of LLVMType integer or vector. +static unsigned getLLVMTypeBitWidth(LLVM::LLVMType type) { + return type.isVectorTy() ? type.getVectorElementType() + .getUnderlyingType() + ->getIntegerBitWidth() + : type.getUnderlyingType()->getIntegerBitWidth(); } /// Creates `IntegerAttribute` with all bits set for given type @@ -74,7 +71,7 @@ return builder.getIntegerAttr(integerType, -1); } -/// Creates `llvm.mlir.constant` with all bits set for the given type +/// Creates `llvm.mlir.constant` with all bits set for the given type. static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, ConversionPatternRewriter &rewriter) { if (srcType.isa()) @@ -86,19 +83,17 @@ loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); } -/// This is a utility function for bit manipulations ops (`BitFieldInsert`) -/// and operates on their `Count` or `Offset` values. It casts the given -/// value to match the target type. -static Value optionallyCast(Location loc, Value value, Type dstType, - ConversionPatternRewriter &rewriter) { +/// Truncates or extends the value. If the bitwidth of the value is the same +/// as `dstType` bitwidth, the value remains unchanged. +static Value optionallyTruncateOrExtend(Location loc, Value value, Type dstType, + ConversionPatternRewriter &rewriter) { + auto srcType = value.getType(); auto llvmType = dstType.cast(); - unsigned targetBitWidth = - llvmType.isVectorTy() - ? llvmType.getVectorElementType() - .getUnderlyingType() - ->getIntegerBitWidth() - : llvmType.getUnderlyingType()->getIntegerBitWidth(); - unsigned valueBitWidth = getValueBitWidth(value); + unsigned targetBitWidth = getLLVMTypeBitWidth(llvmType); + unsigned valueBitWidth = + srcType.isa() + ? getLLVMTypeBitWidth(srcType.cast()) + : getBitWidth(srcType); if (valueBitWidth < targetBitWidth) return rewriter.create(loc, llvmType, value); @@ -113,15 +108,15 @@ /// Broadcasts the value to vector with `numElements` number of elements static void broadcast(Location loc, Value toBroadcast, Value &broadcasted, - int64_t numElements, LLVMTypeConverter &typeConverter, + unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { auto vectorType = VectorType::get(numElements, toBroadcast.getType()); auto llvmVectorType = typeConverter.convertType(vectorType); + auto llvmI32Type = typeConverter.convertType(rewriter.getIntegerType(32)); broadcasted = rewriter.create(loc, llvmVectorType); - for (int32_t i = 0; i < vectorType.getNumElements(); ++i) { + for (unsigned i = 0; i < numElements; ++i) { auto index = rewriter.create( - loc, typeConverter.convertType(rewriter.getIntegerType(32)), - rewriter.getI32IntegerAttr(i)); + loc, llvmI32Type, rewriter.getI32IntegerAttr(i)); broadcasted = rewriter.create( loc, llvmVectorType, broadcasted, toBroadcast, index); } @@ -156,7 +151,7 @@ Value offset; Value count; if (auto vectorType = srcType.dyn_cast()) { - int64_t numElements = static_cast(vectorType.getNumElements()); + unsigned numElements = vectorType.getNumElements(); broadcast(loc, op.offset(), offset, numElements, typeConverter, rewriter); broadcast(loc, op.count(), count, numElements, typeConverter, rewriter); } else { @@ -169,9 +164,10 @@ // Need to cast `Offset` and `Count` if their bit width is different // from `Base` bit width. - Value optionallyCastedCount = optionallyCast(loc, count, dstType, rewriter); + Value optionallyCastedCount = + optionallyTruncateOrExtend(loc, count, dstType, rewriter); Value optionallyCastedOffset = - optionallyCast(loc, offset, dstType, rewriter); + optionallyTruncateOrExtend(loc, offset, dstType, rewriter); // Create a mask with bits set outside [Offset, Offset + Count - 1]. Value maskShiftedByCount = rewriter.create( @@ -189,9 +185,8 @@ rewriter.create(loc, dstType, op.base(), mask); Value insertShiftedByOffset = rewriter.create( loc, dstType, op.insert(), optionallyCastedOffset); - rewriter.create(loc, dstType, baseAndMask, - insertShiftedByOffset); - rewriter.eraseOp(op); + rewriter.replaceOpWithNewOp(op, dstType, baseAndMask, + insertShiftedByOffset); return success(); } };