diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -37,6 +37,19 @@
   return false;
 }
 
+/// Returns the bit width of an integer or float type. For a vector of
+/// integers or floats, returns the bit width of a single element.
+static unsigned getBitWidth(Type type) {
+  assert((type.isIntOrFloat() || type.isa<VectorType>()) &&
+         "bitwidth is not supported for this type");
+  if (type.isIntOrFloat())
+    return type.getIntOrFloatBitWidth();
+  auto elementType = type.cast<VectorType>().getElementType();
+  assert(elementType.isIntOrFloat() &&
+         "only integers and floats have a bitwidth");
+  return elementType.getIntOrFloatBitWidth();
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -61,6 +74,38 @@
   }
 };
 
+/// Lowers SPIR-V casts lacking a one-to-one LLVM op to an extension when the
+/// bit width grows or a truncation when it shrinks; equal widths don't match.
+template
+class IndirectCastPattern : public SPIRVToLLVMConversion {
+public:
+  using SPIRVToLLVMConversion::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp operation, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    Type fromType = operation.operand().getType();
+    Type toType = operation.getType();
+
+    auto dstType = this->typeConverter.convertType(toType);
+    if (!dstType)
+      return failure(); // The type converter returned a null type.
+
+    if (getBitWidth(fromType) < getBitWidth(toType)) { // Widening cast.
+      rewriter.template replaceOpWithNewOp(operation, dstType,
+                                           operands);
+      return success();
+    }
+    if (getBitWidth(fromType) > getBitWidth(toType)) { // Narrowing cast.
+      rewriter.template replaceOpWithNewOp(operation, dstType,
+                                           operands);
+      return success();
+    }
+    return failure(); // Equal widths: neither extension nor truncation.
+  }
+};
+
 /// Converts SPIR-V floating-point comparisons to llvm.fcmp "predicate"
 template
 class FComparePattern : public SPIRVToLLVMConversion {
@@ -168,12 +213,22 @@
       DirectConversionPattern,
       DirectConversionPattern,
       DirectConversionPattern,
+      DirectConversionPattern,
 
       // Bitwise ops
       DirectConversionPattern,
       DirectConversionPattern,
       DirectConversionPattern,
 
+      // Cast ops: one-to-one conversions, then width-dependent ext/trunc.
+      DirectConversionPattern,
+      DirectConversionPattern,
+      DirectConversionPattern,
+      DirectConversionPattern,
+      IndirectCastPattern,
+      IndirectCastPattern,
+      IndirectCastPattern,
+
       // Comparison ops
       IComparePattern,
       IComparePattern,
@@ -199,6 +254,12 @@
       IComparePattern,
       IComparePattern,
 
+      // Logical ops: And/Or as bitwise and/or, Equal/NotEqual as icmp.
+      DirectConversionPattern,
+      DirectConversionPattern,
+      IComparePattern,
+      IComparePattern,
+
       // Shift ops
       ShiftPattern,
       ShiftPattern,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/arithmetic-ops-to-llvm.mlir
@@ -145,6 +145,22 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.UMod
+//===----------------------------------------------------------------------===//
+
+func @umod_scalar(%arg0: i32, %arg1: i32) {
+  // CHECK: %{{.*}} = llvm.urem %{{.*}}, %{{.*}} : !llvm.i32
+  %0 = spv.UMod %arg0, %arg1 : i32
+  return
+}
+
+func @umod_vector(%arg0: vector<3xi64>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.urem %{{.*}}, %{{.*}} : !llvm<"<3 x i64>">
+  %0 = spv.UMod %arg0, %arg1 : vector<3xi64>
+  return
+}
+
+//===----------------------------------------------------------------------===//
 // spv.SDiv
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/cast-ops-to-llvm.mlir
@@ -0,0 +1,131 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertFToS
+//===----------------------------------------------------------------------===//
+
+func @convert_float_to_signed_scalar(%arg0: f32) {
+  // CHECK: %{{.*}} = llvm.fptosi %{{.*}} : !llvm.float to !llvm.i32
+  %0 = spv.ConvertFToS %arg0: f32 to i32
+  return
+}
+
+func @convert_float_to_signed_vector(%arg0: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fptosi %{{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x i32>">
+  %0 = spv.ConvertFToS %arg0: vector<2xf32> to vector<2xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertFToU
+//===----------------------------------------------------------------------===//
+
+func @convert_float_to_unsigned_scalar(%arg0: f32) {
+  // CHECK: %{{.*}} = llvm.fptoui %{{.*}} : !llvm.float to !llvm.i32
+  %0 = spv.ConvertFToU %arg0: f32 to i32
+  return
+}
+
+func @convert_float_to_unsigned_vector(%arg0: vector<2xf32>) {
+  // CHECK: %{{.*}} = llvm.fptoui %{{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x i32>">
+  %0 = spv.ConvertFToU %arg0: vector<2xf32> to vector<2xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertSToF
+//===----------------------------------------------------------------------===//
+
+func @convert_signed_to_float_scalar(%arg0: i32) {
+  // CHECK: %{{.*}} = llvm.sitofp %{{.*}} : !llvm.i32 to !llvm.float
+  %0 = spv.ConvertSToF %arg0: i32 to f32
+  return
+}
+
+func @convert_signed_to_float_vector(%arg0: vector<3xi32>) {
+  // CHECK: %{{.*}} = llvm.sitofp %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x float>">
+  %0 = spv.ConvertSToF %arg0: vector<3xi32> to vector<3xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.ConvertUToF
+//===----------------------------------------------------------------------===//
+
+func @convert_unsigned_to_float_scalar(%arg0: i32) {
+  // CHECK: %{{.*}} = llvm.uitofp %{{.*}} : !llvm.i32 to !llvm.float
+  %0 = spv.ConvertUToF %arg0: i32 to f32
+  return
+}
+
+func @convert_unsigned_to_float_vector(%arg0: vector<3xi32>) {
+  // CHECK: %{{.*}} = llvm.uitofp %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x float>">
+  %0 = spv.ConvertUToF %arg0: vector<3xi32> to vector<3xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.FConvert
+//===----------------------------------------------------------------------===//
+
+func @fconvert_scalar(%arg0: f32, %arg1: f64) {
+  // CHECK: %{{.*}} = llvm.fpext %{{.*}} : !llvm.float to !llvm.double
+  %0 = spv.FConvert %arg0: f32 to f64
+
+  // CHECK: %{{.*}} = llvm.fptrunc %{{.*}} : !llvm.double to !llvm.float
+  %1 = spv.FConvert %arg1: f64 to f32
+  return
+}
+
+func @fconvert_vector(%arg0: vector<2xf32>, %arg1: vector<2xf64>) {
+  // CHECK: %{{.*}} = llvm.fpext %{{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x double>">
+  %0 = spv.FConvert %arg0: vector<2xf32> to vector<2xf64>
+
+  // CHECK: %{{.*}} = llvm.fptrunc %{{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x float>">
+  %1 = spv.FConvert %arg1: vector<2xf64> to vector<2xf32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.SConvert
+//===----------------------------------------------------------------------===//
+
+func @sconvert_scalar(%arg0: i32, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.sext %{{.*}} : !llvm.i32 to !llvm.i64
+  %0 = spv.SConvert %arg0: i32 to i64
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm.i64 to !llvm.i32
+  %1 = spv.SConvert %arg1: i64 to i32
+  return
+}
+
+func @sconvert_vector(%arg0: vector<3xi32>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.sext %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x i64>">
+  %0 = spv.SConvert %arg0: vector<3xi32> to vector<3xi64>
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm<"<3 x i64>"> to !llvm<"<3 x i32>">
+  %1 = spv.SConvert %arg1: vector<3xi64> to vector<3xi32>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.UConvert
+//===----------------------------------------------------------------------===//
+
+func @uconvert_scalar(%arg0: i32, %arg1: i64) {
+  // CHECK: %{{.*}} = llvm.zext %{{.*}} : !llvm.i32 to !llvm.i64
+  %0 = spv.UConvert %arg0: i32 to i64
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm.i64 to !llvm.i32
+  %1 = spv.UConvert %arg1: i64 to i32
+  return
+}
+
+func @uconvert_vector(%arg0: vector<3xi32>, %arg1: vector<3xi64>) {
+  // CHECK: %{{.*}} = llvm.zext %{{.*}} : !llvm<"<3 x i32>"> to !llvm<"<3 x i64>">
+  %0 = spv.UConvert %arg0: vector<3xi32> to vector<3xi64>
+
+  // CHECK: %{{.*}} = llvm.trunc %{{.*}} : !llvm<"<3 x i64>"> to !llvm<"<3 x i32>">
+  %1 = spv.UConvert %arg1: vector<3xi64> to vector<3xi32>
+  return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
@@ -0,0 +1,65 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalEqual
+//===----------------------------------------------------------------------===//
+
+func @logical_equal_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.icmp "eq" %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalEqual %arg0, %arg1 : i1
+  return
+}
+
+func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.icmp "eq" %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalEqual %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalNotEqual
+//===----------------------------------------------------------------------===//
+
+func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.icmp "ne" %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalNotEqual %arg0, %arg1 : i1
+  return
+}
+
+func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.icmp "ne" %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalAnd
+//===----------------------------------------------------------------------===//
+
+func @logical_and_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalAnd %arg0, %arg1 : i1
+  return
+}
+
+func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.and %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalAnd %arg0, %arg1 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.LogicalOr
+//===----------------------------------------------------------------------===//
+
+func @logical_or_scalar(%arg0: i1, %arg1: i1) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm.i1
+  %0 = spv.LogicalOr %arg0, %arg1 : i1
+  return
+}
+
+func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) {
+  // CHECK: %{{.*}} = llvm.or %{{.*}}, %{{.*}} : !llvm<"<4 x i1>">
+  %0 = spv.LogicalOr %arg0, %arg1 : vector<4xi1>
+  return
+}