diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -497,6 +497,41 @@
   }
 };
 
+/// Lowers std.uitofp from i1 / vector-of-i1 to spv.Select(src, 1.0, 0.0).
+class UIToFPI1Pattern final : public OpConversionPattern<UIToFPOp> {
+public:
+  using OpConversionPattern<UIToFPOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(UIToFPOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value src = operands.front();
+    if (!isBoolScalarOrVector(src.getType()))
+      return failure();
+
+    // convertType yields a null Type if the result type cannot be converted.
+    Type dstType = getTypeConverter()->convertType(op.getResult().getType());
+    if (!dstType)
+      return failure();
+
+    Attribute zeroAttr, oneAttr;
+    if (auto vectorType = dstType.dyn_cast<VectorType>()) {
+      Type elemType = vectorType.getElementType();
+      zeroAttr = DenseFPElementsAttr::get(
+          vectorType, FloatAttr::get(elemType, 0.0).getValue());
+      oneAttr = DenseFPElementsAttr::get(
+          vectorType, FloatAttr::get(elemType, 1.0).getValue());
+    } else {
+      zeroAttr = FloatAttr::get(dstType, 0.0);
+      oneAttr = FloatAttr::get(dstType, 1.0);
+    }
+    Value zero = rewriter.create<ConstantOp>(op.getLoc(), zeroAttr);
+    Value one = rewriter.create<ConstantOp>(op.getLoc(), oneAttr);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, src, one, zero);
+    return success();
+  }
+};
+
 /// Converts type-casting standard operations to SPIR-V operations.
template class TypeCastingOpPattern final : public OpConversionPattern { @@ -1096,8 +1131,10 @@ ReturnOpPattern, SelectOpPattern, // Type cast patterns - ZeroExtendI1Pattern, TypeCastingOpPattern, + UIToFPI1Pattern, ZeroExtendI1Pattern, + TypeCastingOpPattern, TypeCastingOpPattern, + TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -564,6 +564,38 @@ return %0 : f64 } +// CHECK-LABEL: @uitofp1 +func @uitofp1(%arg0: i16) -> f32 { + // CHECK: spv.ConvertUToF %{{.*}} : i16 to f32 + %0 = std.uitofp %arg0 : i16 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp2 +func @uitofp2(%arg0 : i32) -> f32 { + // CHECK: spv.ConvertUToF %{{.*}} : i32 to f32 + %0 = std.uitofp %arg0 : i32 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp3 +func @uitofp3(%arg0 : i1) -> f32 { + // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f32 + // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f32 + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32 + %0 = std.uitofp %arg0 : i1 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp4 +func @uitofp4(%arg0 : vector<4xi1>) -> vector<4xf32> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf32> + %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf32> + return %0 : vector<4xf32> +} + // CHECK-LABEL: @zexti1 func @zexti1(%arg0: i16) -> i64 { // CHECK: spv.UConvert %{{.*}} : i16 to i64