diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -481,16 +481,32 @@ auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); Location loc = op.getLoc(); - Attribute zeroAttr, oneAttr; - if (auto vectorType = dstType.dyn_cast()) { - zeroAttr = DenseElementsAttr::get(vectorType, 0); - oneAttr = DenseElementsAttr::get(vectorType, 1); - } else { - zeroAttr = IntegerAttr::get(dstType, 0); - oneAttr = IntegerAttr::get(dstType, 1); - } - Value zero = rewriter.create(loc, zeroAttr); - Value one = rewriter.create(loc, oneAttr); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + rewriter.template replaceOpWithNewOp( + op, dstType, operands.front(), one, zero); + return success(); + } +}; + +/// Converts std.uitofp to spv.Select if the type of source is i1 or vector of +/// i1. 
+class UIToFPI1Pattern final : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(UIToFPOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto srcType = operands.front().getType(); + if (!isBoolScalarOrVector(srcType)) + return failure(); + + auto dstType = + this->getTypeConverter()->convertType(op.getResult().getType()); + Location loc = op.getLoc(); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( op, dstType, operands.front(), one, zero); return success(); @@ -1096,8 +1112,10 @@ ReturnOpPattern, SelectOpPattern, // Type cast patterns - ZeroExtendI1Pattern, TypeCastingOpPattern, + UIToFPI1Pattern, ZeroExtendI1Pattern, + TypeCastingOpPattern, TypeCastingOpPattern, + TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, TypeCastingOpPattern, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -25,6 +25,8 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/CallInterfaces.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" @@ -1580,6 +1582,22 @@ builder.getBoolAttr(false)); return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 0))); + } else if (auto floatType = type.dyn_cast()) { + return builder.create( + loc, type, builder.getFloatAttr(floatType, 0.0)); + } else if (auto vectorType = type.dyn_cast()) { + Type elemType = vectorType.getElementType(); + if (elemType.isa()) { + return builder.create( + loc, type, + DenseElementsAttr::get(vectorType, + IntegerAttr::get(elemType, 0).getValue())); + } else if (elemType.isa()) { + return builder.create( + loc, 
type, + DenseFPElementsAttr::get(vectorType, + FloatAttr::get(elemType, 0.0).getValue())); + } } llvm_unreachable("unimplemented types for ConstantOp::getZero()"); @@ -1594,6 +1612,22 @@ builder.getBoolAttr(true)); return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 1))); + } else if (auto floatType = type.dyn_cast()) { + return builder.create( + loc, type, builder.getFloatAttr(floatType, 1.0)); + } else if (auto vectorType = type.dyn_cast()) { + Type elemType = vectorType.getElementType(); + if (elemType.isa()) { + return builder.create( + loc, type, + DenseElementsAttr::get(vectorType, + IntegerAttr::get(elemType, 1).getValue())); + } else if (elemType.isa()) { + return builder.create( + loc, type, + DenseFPElementsAttr::get(vectorType, + FloatAttr::get(elemType, 1.0).getValue())); + } } llvm_unreachable("unimplemented types for ConstantOp::getOne()"); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -564,6 +564,58 @@ return %0 : f64 } +// CHECK-LABEL: @uitofp_i16_f32 +func @uitofp_i16_f32(%arg0: i16) -> f32 { + // CHECK: spv.ConvertUToF %{{.*}} : i16 to f32 + %0 = std.uitofp %arg0 : i16 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i32_f32 +func @uitofp_i32_f32(%arg0 : i32) -> f32 { + // CHECK: spv.ConvertUToF %{{.*}} : i32 to f32 + %0 = std.uitofp %arg0 : i32 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i1_f32 +func @uitofp_i1_f32(%arg0 : i1) -> f32 { + // CHECK: %[[ZERO:.+]] = spv.constant 0.000000e+00 : f32 + // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f32 + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f32 + %0 = std.uitofp %arg0 : i1 to f32 + return %0 : f32 +} + +// CHECK-LABEL: @uitofp_i1_f64 +func @uitofp_i1_f64(%arg0 : i1) -> f64 { + // CHECK: %[[ZERO:.+]] = spv.constant 
0.000000e+00 : f64 + // CHECK: %[[ONE:.+]] = spv.constant 1.000000e+00 : f64 + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, f64 + %0 = std.uitofp %arg0 : i1 to f64 + return %0 : f64 +} + +// CHECK-LABEL: @uitofp_vec_i1_f32 +func @uitofp_vec_i1_f32(%arg0 : vector<4xi1>) -> vector<4xf32> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf32> + // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf32> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf32> + %0 = std.uitofp %arg0 : vector<4xi1> to vector<4xf32> + return %0 : vector<4xf32> +} + +// CHECK-LABEL: @uitofp_vec_i1_f64 +spv.func @uitofp_vec_i1_f64(%arg0: vector<4xi1>) -> vector<4xf64> "None" { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0.000000e+00> : vector<4xf64> + // CHECK: %[[ONE:.+]] = spv.constant dense<1.000000e+00> : vector<4xf64> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xf64> + %0 = spv.constant dense<0.000000e+00> : vector<4xf64> + %1 = spv.constant dense<1.000000e+00> : vector<4xf64> + %2 = spv.Select %arg0, %1, %0 : vector<4xi1>, vector<4xf64> + spv.ReturnValue %2 : vector<4xf64> +} + // CHECK-LABEL: @zexti1 func @zexti1(%arg0: i16) -> i64 { // CHECK: spv.UConvert %{{.*}} : i16 to i64 @@ -596,6 +648,15 @@ return %0 : vector<4xi32> } +// CHECK-LABEL: @zexti5 +func @zexti5(%arg0 : vector<4xi1>) -> vector<4xi64> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi64> + // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi64> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi64> + %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi64> + return %0 : vector<4xi64> +} + // CHECK-LABEL: @trunci1 func @trunci1(%arg0 : i64) -> i16 { // CHECK: spv.SConvert %{{.*}} : i64 to i16