[MLIR][SPIRVToLLVM] Add a conversion pattern for spv.BitFieldInsert (mask creation, Count/Offset casting and vector broadcasting helpers) plus FileCheck tests.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp @@ -53,6 +53,17 @@ return elementType.getIntOrFloatBitWidth(); } +/// Returns the bit width of integer or vector value of LLVM or SPIR-V type +static unsigned getValueBitWidth(Value value) { + if (auto llvmType = value.getType().dyn_cast()) + return llvmType.isVectorTy() + ? llvmType.getVectorElementType() + .getUnderlyingType() + ->getIntegerBitWidth() + : llvmType.getUnderlyingType()->getIntegerBitWidth(); + return getBitWidth(value.getType()); +} + /// Creates `IntegerAttribute` with all bits set for given type IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) { if (auto vecType = type.dyn_cast()) { @@ -63,12 +74,128 @@ return builder.getIntegerAttr(integerType, -1); } +/// Creates `llvm.mlir.constant` with all bits set for the given type +static Value createConstantAllBitsSet(Location loc, Type srcType, Type dstType, + ConversionPatternRewriter &rewriter) { + if (srcType.isa()) + return rewriter.create( + loc, dstType, + SplatElementsAttr::get(srcType.cast(), + minusOneIntegerAttribute(srcType, rewriter))); + return rewriter.create( + loc, dstType, minusOneIntegerAttribute(srcType, rewriter)); +} + +/// This is a utility function for bit manipulation ops (`BitFieldInsert`) +/// and operates on their `Count` or `Offset` values. It casts the given +/// value to match the target type. +static Value optionallyCast(Location loc, Value value, Type dstType, + ConversionPatternRewriter &rewriter) { + auto llvmType = dstType.cast(); + unsigned targetBitWidth = + llvmType.isVectorTy() + ? llvmType.getVectorElementType() + .getUnderlyingType() + ->getIntegerBitWidth() + : llvmType.getUnderlyingType()->getIntegerBitWidth(); + unsigned valueBitWidth = getValueBitWidth(value); + + if (valueBitWidth < targetBitWidth) + return rewriter.create(loc, llvmType, value); + // If the bit widths of `Count` and `Offset` are greater than the bit width + // of the target type, they are truncated. Truncation is lossless whenever + // the op's behavior is defined: SPIR-V requires `Count` and `Offset` not to + // exceed the result bit width, so both values fit in the target type. + if (valueBitWidth > targetBitWidth) + return rewriter.create(loc, llvmType, value); + return value; +} + +/// Broadcasts `toBroadcast` to a `numElements`-element vector in `broadcasted` +static void broadcast(Location loc, Value toBroadcast, Value &broadcasted, + int64_t numElements, LLVMTypeConverter &typeConverter, + ConversionPatternRewriter &rewriter) { + auto vectorType = VectorType::get(numElements, toBroadcast.getType()); + auto llvmVectorType = typeConverter.convertType(vectorType); + broadcasted = rewriter.create(loc, llvmVectorType); + for (int32_t i = 0; i < vectorType.getNumElements(); ++i) { + auto index = rewriter.create( + loc, typeConverter.convertType(rewriter.getIntegerType(32)), + rewriter.getI32IntegerAttr(i)); + broadcasted = rewriter.create( + loc, llvmVectorType, broadcasted, toBroadcast, index); + } +} + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// namespace { +class BitFieldInsertPattern + : public SPIRVToLLVMConversion { +public: + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + + LogicalResult + matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto srcType = op.getType(); + auto dstType = this->typeConverter.convertType(srcType); + if (!dstType) + return failure(); + Location loc = 
op.getLoc(); + + // Broadcast `Offset` and `Count` to match the type of `Base` and `Insert`. + // If `Base` is of a vector type, construct a vector that has: + // - same number of elements as `Base` + // - each element has the type that is the same as the type of `Offset` or + // `Count` + // - each element has the same value as `Offset` or `Count` + Value offset; + Value count; + if (auto vectorType = srcType.dyn_cast()) { + int64_t numElements = static_cast(vectorType.getNumElements()); + broadcast(loc, op.offset(), offset, numElements, typeConverter, rewriter); + broadcast(loc, op.count(), count, numElements, typeConverter, rewriter); + } else { + offset = op.offset(); + count = op.count(); + } + + // Create a mask with all bits set of the same type as `srcType` + Value minusOne = createConstantAllBitsSet(loc, srcType, dstType, rewriter); + + // Need to cast `Offset` and `Count` if their bit width is different + // from `Base` bit width. + Value optionallyCastedCount = optionallyCast(loc, count, dstType, rewriter); + Value optionallyCastedOffset = + optionallyCast(loc, offset, dstType, rewriter); + + // Create a mask with bits set outside [Offset, Offset + Count - 1]. + // NOTE(review): Count == bit width makes this `shl` poison in LLVM. + Value maskShiftedByCount = rewriter.create( + loc, dstType, minusOne, optionallyCastedCount); + Value negated = rewriter.create(loc, dstType, + maskShiftedByCount, minusOne); + Value maskShiftedByCountAndOffset = rewriter.create( + loc, dstType, negated, optionallyCastedOffset); + Value mask = rewriter.create( + loc, dstType, maskShiftedByCountAndOffset, minusOne); + + // Extract unchanged bits from the `Base` that are outside of + // [Offset, Offset + Count - 1]. Then `or` with shifted `Insert`. 
+ Value baseAndMask = + rewriter.create(loc, dstType, op.base(), mask); + Value insertShiftedByOffset = rewriter.create( + loc, dstType, op.insert(), optionallyCastedOffset); + rewriter.replaceOpWithNewOp<LLVM::OrOp>(op, dstType, baseAndMask, + insertShiftedByOffset); + return success(); + } +}; + /// Converts SPIR-V operations that have straightforward LLVM equivalent /// into LLVM dialect operations. template @@ -379,6 +506,7 @@ DirectConversionPattern, // Bitwise ops + BitFieldInsertPattern, DirectConversionPattern, DirectConversionPattern, DirectConversionPattern, diff --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir --- a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir +++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir @@ -1,5 +1,137 @@ // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spv.BitFieldInsert +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @bitfield_insert_scalar_same_bit_width +// CHECK-SAME: %[[BASE:.*]]: !llvm.i32, %[[INSERT:.*]]: !llvm.i32, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32 +func @bitfield_insert_scalar_same_bit_width(%base: i32, %insert: i32, %offset: i32, %count: i32) { + // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32 + // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT]] : !llvm.i32 + // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i32 + // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[OFFSET]] : !llvm.i32 + // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i32 + // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i32 + // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[OFFSET]] : !llvm.i32 + // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i32 + %0 = spv.BitFieldInsert 
%base, %insert, %offset, %count : i32, i32, i32 + return +} + +// CHECK-LABEL: func @bitfield_insert_scalar_smaller_bit_width +// CHECK-SAME: %[[BASE:.*]]: !llvm.i64, %[[INSERT:.*]]: !llvm.i64, %[[OFFSET:.*]]: !llvm.i8, %[[COUNT:.*]]: !llvm.i8 +func @bitfield_insert_scalar_smaller_bit_width(%base: i64, %insert: i64, %offset: i8, %count: i8) { + // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i64) : !llvm.i64 + // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT]] : !llvm.i8 to !llvm.i64 + // CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET]] : !llvm.i8 to !llvm.i64 + // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[EXT_COUNT]] : !llvm.i64 + // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i64 + // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[EXT_OFFSET]] : !llvm.i64 + // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm.i64 + // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i64 + // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[EXT_OFFSET]] : !llvm.i64 + // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i64 + %0 = spv.BitFieldInsert %base, %insert, %offset, %count : i64, i8, i8 + return +} + +// CHECK-LABEL: func @bitfield_insert_scalar_greater_bit_width +// CHECK-SAME: %[[BASE:.*]]: !llvm.i16, %[[INSERT:.*]]: !llvm.i16, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i64 +func @bitfield_insert_scalar_greater_bit_width(%base: i16, %insert: i16, %offset: i32, %count: i64) { + // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(-1 : i16) : !llvm.i16 + // CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT]] : !llvm.i64 to !llvm.i16 + // CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET]] : !llvm.i32 to !llvm.i16 + // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[TRUNC_COUNT]] : !llvm.i16 + // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm.i16 + // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[TRUNC_OFFSET]] : !llvm.i16 + // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : 
!llvm.i16 + // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm.i16 + // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[TRUNC_OFFSET]] : !llvm.i16 + // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm.i16 + %0 = spv.BitFieldInsert %base, %insert, %offset, %count : i16, i32, i64 + return +} + +// CHECK-LABEL: func @bitfield_insert_vector_same_bit_width +// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[INSERT:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i32, %[[COUNT:.*]]: !llvm.i32 +func @bitfield_insert_vector_same_bit_width(%base: vector<2xi32>, %insert: vector<2xi32>, %offset: i32, %count: i32) { + // CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>"> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>"> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + // CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>"> + // CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i32>"> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i32>"> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + // CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i32>"> + // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi32>) : !llvm<"<2 x i32>"> + // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[COUNT_V2]] : !llvm<"<2 x i32>"> + // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm<"<2 x i32>"> + // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[OFFSET_V2]] : !llvm<"<2 x i32>"> + // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm<"<2 x i32>"> + // CHECK: %[[NEW_BASE:.*]] = 
llvm.and %[[BASE]], %[[MASK]] : !llvm<"<2 x i32>"> + // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[OFFSET_V2]] : !llvm<"<2 x i32>"> + // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm<"<2 x i32>"> + %0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<2xi32>, i32, i32 + return +} + +// CHECK-LABEL: func @bitfield_insert_vector_smaller_bit_width +// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[INSERT:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i8, %[[COUNT:.*]]: !llvm.i8 +func @bitfield_insert_vector_smaller_bit_width(%base: vector<2xi32>, %insert: vector<2xi32>, %offset: i8, %count: i8) { + // CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i8>"> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i8>"> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + // CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i8>"> + // CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i8>"> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i8>"> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + // CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i8>"> + // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi32>) : !llvm<"<2 x i32>"> + // CHECK: %[[EXT_COUNT:.*]] = llvm.zext %[[COUNT_V2]] : !llvm<"<2 x i8>"> to !llvm<"<2 x i32>"> + // CHECK: %[[EXT_OFFSET:.*]] = llvm.zext %[[OFFSET_V2]] : !llvm<"<2 x i8>"> to !llvm<"<2 x i32>"> + // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[EXT_COUNT]] : !llvm<"<2 x i32>"> + // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm<"<2 x i32>"> + // 
CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[EXT_OFFSET]] : !llvm<"<2 x i32>"> + // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm<"<2 x i32>"> + // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm<"<2 x i32>"> + // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[EXT_OFFSET]] : !llvm<"<2 x i32>"> + // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm<"<2 x i32>"> + %0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<2xi32>, i8, i8 + return +} + +// CHECK-LABEL: func @bitfield_insert_vector_greater_bit_width +// CHECK-SAME: %[[BASE:.*]]: !llvm<"<2 x i32>">, %[[INSERT:.*]]: !llvm<"<2 x i32>">, %[[OFFSET:.*]]: !llvm.i64, %[[COUNT:.*]]: !llvm.i64 +func @bitfield_insert_vector_greater_bit_width(%base: vector<2xi32>, %insert: vector<2xi32>, %offset: i64, %count: i64) { + // CHECK: %[[OFFSET_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i64>"> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: %[[OFFSET_V1:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i64>"> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + // CHECK: %[[OFFSET_V2:.*]] = llvm.insertelement %[[OFFSET]], %[[OFFSET_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i64>"> + // CHECK: %[[COUNT_V0:.*]] = llvm.mlir.undef : !llvm<"<2 x i64>"> + // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 + // CHECK: %[[COUNT_V1:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V0]][%[[ZERO]] : !llvm.i32] : !llvm<"<2 x i64>"> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 + // CHECK: %[[COUNT_V2:.*]] = llvm.insertelement %[[COUNT]], %[[COUNT_V1]][%[[ONE]] : !llvm.i32] : !llvm<"<2 x i64>"> + // CHECK: %[[MINUS_ONE:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi32>) : !llvm<"<2 x i32>"> + // CHECK: %[[TRUNC_COUNT:.*]] = llvm.trunc %[[COUNT_V2]] : !llvm<"<2 x i64>"> to !llvm<"<2 x i32>"> + // CHECK: %[[TRUNC_OFFSET:.*]] = llvm.trunc %[[OFFSET_V2]] : 
!llvm<"<2 x i64>"> to !llvm<"<2 x i32>"> + // CHECK: %[[T0:.*]] = llvm.shl %[[MINUS_ONE]], %[[TRUNC_COUNT]] : !llvm<"<2 x i32>"> + // CHECK: %[[T1:.*]] = llvm.xor %[[T0]], %[[MINUS_ONE]] : !llvm<"<2 x i32>"> + // CHECK: %[[T2:.*]] = llvm.shl %[[T1]], %[[TRUNC_OFFSET]] : !llvm<"<2 x i32>"> + // CHECK: %[[MASK:.*]] = llvm.xor %[[T2]], %[[MINUS_ONE]] : !llvm<"<2 x i32>"> + // CHECK: %[[NEW_BASE:.*]] = llvm.and %[[BASE]], %[[MASK]] : !llvm<"<2 x i32>"> + // CHECK: %[[SHIFTED_INSERT:.*]] = llvm.shl %[[INSERT]], %[[TRUNC_OFFSET]] : !llvm<"<2 x i32>"> + // CHECK: %{{.*}} = llvm.or %[[NEW_BASE]], %[[SHIFTED_INSERT]] : !llvm<"<2 x i32>"> + %0 = spv.BitFieldInsert %base, %insert, %offset, %count : vector<2xi32>, i64, i64 + return +} + //===----------------------------------------------------------------------===// // spv.BitwiseAnd //===----------------------------------------------------------------------===//