diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -493,7 +493,8 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts std.zexti to spv.Select if the type of source is i1. +/// Converts std.zexti to spv.Select if the type of source is i1 or vector of +/// i1. class ZeroExtendI1Pattern final : public SPIRVOpLowering { public: using SPIRVOpLowering::SPIRVOpLowering; @@ -502,13 +503,23 @@ matchAndRewrite(ZeroExtendIOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto srcType = operands.front().getType(); - if (!srcType.isSignlessInteger() || srcType.getIntOrFloatBitWidth() != 1) + if (!isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->typeConverter.convertType(op.getResult().getType()); Location loc = op.getLoc(); - Value zero = rewriter.create(loc, 0, dstType); - Value one = rewriter.create(loc, 1, dstType); + // Create element-typed attributes and splat them for vectors: passing a + // C++ int to DenseElementsAttr::get only works for 32-bit elements. + auto vectorType = dstType.dyn_cast<VectorType>(); + Type elementType = vectorType ? vectorType.getElementType() : dstType; + Attribute zeroAttr = IntegerAttr::get(elementType, 0); + Attribute oneAttr = IntegerAttr::get(elementType, 1); + if (vectorType) { + zeroAttr = DenseElementsAttr::get(vectorType, zeroAttr); + oneAttr = DenseElementsAttr::get(vectorType, oneAttr); + } + Value zero = rewriter.create(loc, zeroAttr); + Value one = rewriter.create(loc, oneAttr); rewriter.template replaceOpWithNewOp( op, dstType, operands.front(), one, zero); return success(); @@ -526,7 +537,7 @@ ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1); auto srcType = operands.front().getType(); - if (srcType.isSignlessInteger() && srcType.getIntOrFloatBitWidth() == 1) + if (isBoolScalarOrVector(srcType)) return failure(); auto dstType = this->typeConverter.convertType(operation.getResult().getType());
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -578,6 +578,15 @@ return %0 : i32 } +// CHECK-LABEL: @zexti4 +func @zexti4(%arg0 : vector<4xi1>) -> vector<4xi32> { + // CHECK: %[[ZERO:.+]] = spv.constant dense<0> : vector<4xi32> + // CHECK: %[[ONE:.+]] = spv.constant dense<1> : vector<4xi32> + // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : vector<4xi1>, vector<4xi32> + %0 = std.zexti %arg0 : vector<4xi1> to vector<4xi32> + return %0 : vector<4xi32> +} + // CHECK-LABEL: @trunci1 func @trunci1(%arg0 : i64) -> i16 { // CHECK: spv.SConvert %{{.*}} : i64 to i16