diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp
@@ -445,6 +445,35 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.zexti to spv.Select if the type of source is i1.
+/// TypeCastingOpPattern bails out on i1 sources, so a zexti from i1 is
+/// lowered here as a select between the constants 1 and 0.
+class ZeroExtendI1Pattern final : public SPIRVOpLowering {
+public:
+  using SPIRVOpLowering::SPIRVOpLowering;
+
+  LogicalResult
+  matchAndRewrite(ZeroExtendIOp op, ArrayRef operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcType = operands.front().getType();
+    if (!srcType.isSignlessInteger() || srcType.getIntOrFloatBitWidth() != 1)
+      return failure();
+
+    auto dstType = this->typeConverter.convertType(op.getResult().getType());
+    // The type converter yields a null type when the result type cannot be
+    // converted; bail out rather than build constants with a null type.
+    if (!dstType)
+      return failure();
+
+    Location loc = op.getLoc();
+    Value zero = rewriter.create(loc, 0, dstType);
+    Value one = rewriter.create(loc, 1, dstType);
+    rewriter.template replaceOpWithNewOp(
+        op, dstType, operands.front(), one, zero);
+    return success();
+  }
+};
+
 /// Converts type-casting standard operations to SPIR-V operations.
 template
 class TypeCastingOpPattern final : public SPIRVOpLowering {
@@ -455,9 +484,14 @@
   matchAndRewrite(StdOp operation, ArrayRef operands,
                   ConversionPatternRewriter &rewriter) const override {
     assert(operands.size() == 1);
+    auto srcType = operands.front().getType();
+    // i1 sources are left to dedicated patterns: zexti from i1 is lowered by
+    // ZeroExtendI1Pattern, while other casts from i1 are not handled here.
+    if (srcType.isSignlessInteger() && srcType.getIntOrFloatBitWidth() == 1)
+      return failure();
     auto dstType =
         this->typeConverter.convertType(operation.getResult().getType());
-    if (dstType == operands.front().getType()) {
+    if (dstType == srcType) {
       // Due to type conversion, we are seeing the same source and target type.
       // Then we can just erase this operation by forwarding its operand.
       rewriter.replaceOp(operation, operands.front());
@@ -1012,7 +1046,7 @@
       BoolCmpIOpPattern, ConstantCompositeOpPattern, ConstantScalarOpPattern,
       CmpFOpPattern, CmpIOpPattern, IntLoadOpPattern, LoadOpPattern,
       ReturnOpPattern, SelectOpPattern, IntStoreOpPattern, StoreOpPattern,
-      TypeCastingOpPattern,
+      ZeroExtendI1Pattern, TypeCastingOpPattern,
       TypeCastingOpPattern,
       TypeCastingOpPattern,
       TypeCastingOpPattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -564,6 +564,15 @@
   return %0 : i64
 }
 
+// CHECK-LABEL: @zexti3
+func @zexti3(%arg0 : i1) -> i32 {
+  // CHECK: %[[ZERO:.+]] = spv.constant 0 : i32
+  // CHECK: %[[ONE:.+]] = spv.constant 1 : i32
+  // CHECK: spv.Select %{{.*}}, %[[ONE]], %[[ZERO]] : i1, i32
+  %0 = std.zexti %arg0 : i1 to i32
+  return %0 : i32
+}
+
 // CHECK-LABEL: @trunci1
 func @trunci1(%arg0 : i64) -> i16 {
   // CHECK: spv.SConvert %{{.*}} : i64 to i16