diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -31,6 +31,15 @@
 // Utility functions
 //===----------------------------------------------------------------------===//
 
+/// Returns true if the given type is a signed integer or vector type.
+static bool isSignedIntegerOrVector(Type type) {
+  if (type.isSignedInteger())
+    return true;
+  if (auto vecType = type.dyn_cast<VectorType>())
+    return vecType.getElementType().isSignedInteger();
+  return false;
+}
+
 /// Returns true if the given type is an unsigned integer or vector type
 static bool isUnsignedIntegerOrVector(Type type) {
   if (type.isUnsignedInteger())
@@ -59,6 +68,67 @@
 
 namespace {
 
+/// Converts SPIR-V ConstantOp with scalar or vector type.
+class ConstantScalarAndVectorPattern
+    : public SPIRVToLLVMConversion<spirv::ConstantOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::ConstantOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::ConstantOp constOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = constOp.getType();
+    if (!srcType.isa<VectorType>() && !srcType.isIntOrFloat())
+      return failure();
+
+    auto dstType = this->typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    // SPIR-V constants may have signed (si) or unsigned (ui) integer types,
+    // whereas the LLVM dialect only has signless integers. The bit pattern is
+    // kept unchanged and only the signedness is dropped, so any signedness
+    // semantics must come from the operations that consume the constant.
+    if (isSignedIntegerOrVector(srcType) ||
+        isUnsignedIntegerOrVector(srcType)) {
+      auto *context = rewriter.getContext();
+      auto signlessType = IntegerType::get(getBitWidth(srcType), context);
+
+      if (srcType.isa<VectorType>()) {
+        // Only dense attributes are handled; bail out instead of
+        // dereferencing a null attribute.
+        auto srcElementsAttr = constOp.value().dyn_cast<DenseElementsAttr>();
+        if (!srcElementsAttr)
+          return failure();
+
+        // Re-create each element with the signless type from its APInt,
+        // which (unlike getSExtValue) preserves the bits for any bit width.
+        SmallVector<Attribute, 8> elements;
+        for (const APInt &value : srcElementsAttr.getIntValues())
+          elements.push_back(rewriter.getIntegerAttr(signlessType, value));
+
+        auto dstElementsAttr = DenseElementsAttr::get(
+            VectorType::get(srcElementsAttr.getType().getShape(), signlessType),
+            elements);
+        rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType,
+                                                      dstElementsAttr);
+        return success();
+      }
+
+      auto srcAttr = constOp.value().cast<IntegerAttr>();
+      auto dstAttr = rewriter.getIntegerAttr(signlessType, srcAttr.getValue());
+      rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, dstAttr);
+      return success();
+    }
+
+    // Signless integers, booleans and floats need no attribute rewriting, so
+    // the original attributes are forwarded unchanged.
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(constOp, dstType, operands,
+                                                  constOp.getAttrs());
+    return success();
+  }
+};
+
 /// Converts SPIR-V operations that have straightforward LLVM equivalent
 /// into LLVM dialect operations.
 template <typename SPIRVOp, typename LLVMOp>
@@ -381,6 +451,9 @@
       IComparePattern<spirv::ULessThanEqualOp, LLVM::ICmpPredicate::ule>,
       IComparePattern<spirv::ULessThanOp, LLVM::ICmpPredicate::ult>,
 
+      // Constant op
+      ConstantScalarAndVectorPattern,
+
       // Logical ops
       DirectConversionPattern<spirv::LogicalAndOp, LLVM::AndOp>,
       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/constant-op-to-llvm.mlir
@@ -0,0 +1,55 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.constant
+//===----------------------------------------------------------------------===//
+
+func @bool_constant_scalar() {
+  // CHECK: {{.*}} = llvm.mlir.constant(true) : !llvm.i1
+  %0 = spv.constant true
+  // CHECK: {{.*}} = llvm.mlir.constant(false) : !llvm.i1
+  %1 = spv.constant false
+  return
+}
+
+func @bool_constant_vector() {
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<[true, false]> : vector<2xi1>) : !llvm<"<2 x i1>">
+  %0 = spv.constant dense<[true, false]> : vector<2xi1>
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<false> : vector<3xi1>) : !llvm<"<3 x i1>">
+  %1 = spv.constant dense<false> : vector<3xi1>
+  return
+}
+
+func @integer_constant_scalar() {
+  // CHECK: {{.*}} = llvm.mlir.constant(0 : i8) : !llvm.i8
+  %0 = spv.constant 0 : i8
+  // CHECK: {{.*}} = llvm.mlir.constant(-5 : i64) : !llvm.i64
+  %1 = spv.constant -5 : si64
+  // CHECK: {{.*}} = llvm.mlir.constant(10 : i16) : !llvm.i16
+  %2 = spv.constant 10 : ui16
+  return
+}
+
+func @integer_constant_vector() {
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3]> : vector<2xi32>) : !llvm<"<2 x i32>">
+  %0 = spv.constant dense<[2, 3]> : vector<2xi32>
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<-4> : vector<2xi32>) : !llvm<"<2 x i32>">
+  %1 = spv.constant dense<-4> : vector<2xsi32>
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<[2, 3, 4]> : vector<3xi32>) : !llvm<"<3 x i32>">
+  %2 = spv.constant dense<[2, 3, 4]> : vector<3xui32>
+  return
+}
+
+func @float_constant_scalar() {
+  // CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f16) : !llvm.half
+  %0 = spv.constant 5.000000e+00 : f16
+  // CHECK: {{.*}} = llvm.mlir.constant(5.000000e+00 : f64) : !llvm.double
+  %1 = spv.constant 5.000000e+00 : f64
+  return
+}
+
+func @float_constant_vector() {
+  // CHECK: {{.*}} = llvm.mlir.constant(dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>) : !llvm<"<2 x float>">
+  %0 = spv.constant dense<[2.000000e+00, 3.000000e+00]> : vector<2xf32>
+  return
+}