diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -108,6 +108,25 @@
   return false;
 }
 
+/// Creates a scalar or vector integer constant of `type` with every element
+/// set to `value`. Returns a null Value if `type` is neither a vector type
+/// nor an integer type; vector element types are assumed to be integers.
+static Value getScalarOrVectorConstInt(Type type, uint64_t value,
+                                       OpBuilder &builder, Location loc) {
+  if (auto vectorType = dyn_cast<VectorType>(type)) {
+    Attribute element = IntegerAttr::get(vectorType.getElementType(), value);
+    SmallVector<Attribute> values(vectorType.getNumElements(), element);
+    auto attr = DenseElementsAttr::get(vectorType, values);
+    return builder.create<spirv::ConstantOp>(loc, vectorType, attr);
+  }
+
+  if (isa<IntegerType>(type))
+    return builder.create<spirv::ConstantOp>(
+        loc, type, builder.getIntegerAttr(type, value));
+
+  return nullptr;
+}
+
 /// Returns true if scalar/vector type `a` and `b` have the same number of
 /// bitwidth.
 static bool hasSameBitwidth(Type a, Type b) {
@@ -525,6 +544,67 @@
   }
 };
 
+/// Converts arith.extsi to spirv.SConvert if the type of source is neither i1
+/// nor vector of i1. When type emulation maps the source and result to the
+/// same SPIR-V type, sign-extends in place with a shift-left / arithmetic
+/// shift-right pair instead.
+struct ExtSIPattern final : public OpConversionPattern<arith::ExtSIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getIn().getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (dstType == srcType) {
+      // We can have the same source and destination type due to type emulation.
+      // Perform bit shifting to make sure we have the proper leading set bits.
+      // The shift amount must be relative to the emulated (converted) type's
+      // bitwidth so the original sign bit lands in its top bit; using the
+      // original result bitwidth would replicate the wrong bit whenever the
+      // result is itself emulated in a wider type (e.g. i8 -> i16 with both
+      // emulated as i32).
+      unsigned srcBW =
+          getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+      unsigned emulatedBW =
+          getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+      if (srcBW >= emulatedBW) {
+        // The source already fills the emulated type (e.g. i32 -> i64 when
+        // i64 is emulated as i32), so there is nothing to extend. This also
+        // keeps `emulatedBW - srcBW` below from underflowing.
+        rewriter.replaceOp(op, adaptor.getIn());
+        return success();
+      }
+
+      Value shiftSize = getScalarOrVectorConstInt(dstType, emulatedBW - srcBW,
+                                                  rewriter, op.getLoc());
+      if (!shiftSize)
+        return rewriter.notifyMatchFailure(op, "unsupported emulated type");
+
+      // First shift left to squeeze out all leading bits beyond the original
+      // source bitwidth.
+      auto shiftLOp = rewriter.create<spirv::ShiftLeftLogicalOp>(
+          op.getLoc(), dstType, adaptor.getIn(), shiftSize);
+
+      // Then perform an arithmetic right shift to make sure we have the right
+      // leading set bits for negative values.
+      rewriter.replaceOpWithNewOp<spirv::ShiftRightArithmeticOp>(
+          op, dstType, shiftLOp, shiftSize);
+    } else {
+      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
+                                                     adaptor.getOperands());
+    }
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ExtUIOp
 //===----------------------------------------------------------------------===//
@@ -554,6 +634,57 @@
   }
 };
 
+/// Converts arith.extui to spirv.UConvert if the type of source is neither i1
+/// nor vector of i1. When type emulation maps the source and result to the
+/// same SPIR-V type, zero-extends in place by masking off the bits above the
+/// original source bitwidth instead.
+struct ExtUIPattern final : public OpConversionPattern<arith::ExtUIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getIn().getType();
+    if (isBoolScalarOrVector(srcType))
+      return failure();
+
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (dstType == srcType) {
+      // We can have the same source and destination type due to type emulation.
+      // Perform bit masking to make sure we don't pollute downstream consumers
+      // with unwanted bits. Here we need to use the original source type's
+      // bitwidth.
+      unsigned bitwidth =
+          getElementTypeOrSelf(op.getIn().getType()).getIntOrFloatBitWidth();
+      unsigned emulatedBW =
+          getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+      if (bitwidth >= emulatedBW) {
+        // The mask would keep every bit of the emulated type (e.g. i32 -> i64
+        // when i64 is emulated as i32), so the input already is the result.
+        rewriter.replaceOp(op, adaptor.getIn());
+        return success();
+      }
+
+      // Compute the mask in 64 bits: a 32-bit `1u << bitwidth` would be
+      // undefined behavior for bitwidth >= 32.
+      Value mask = getScalarOrVectorConstInt(
+          dstType, llvm::maskTrailingOnes<uint64_t>(bitwidth), rewriter,
+          op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported emulated type");
+      rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
+                                                       adaptor.getIn(), mask);
+    } else {
+      rewriter.replaceOpWithNewOp<spirv::UConvertOp>(op, dstType,
+                                                     adaptor.getOperands());
+    }
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // TruncIOp
 //===----------------------------------------------------------------------===//
@@ -588,6 +719,57 @@
   }
 };
 
+/// Converts arith.trunci to spirv.SConvert if the type of result is neither i1
+/// nor vector of i1. When type emulation maps the source and result to the
+/// same SPIR-V type, truncates in place by masking off the bits above the
+/// original result bitwidth instead.
+struct TruncIPattern final : public OpConversionPattern<arith::TruncIOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type srcType = adaptor.getIn().getType();
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    if (isBoolScalarOrVector(dstType))
+      return failure();
+
+    if (dstType == srcType) {
+      // We can have the same source and destination type due to type emulation.
+      // Perform bit masking to make sure we don't pollute downstream consumers
+      // with unwanted bits. Here we need to use the original result type's
+      // bitwidth.
+      unsigned bw = getElementTypeOrSelf(op.getType()).getIntOrFloatBitWidth();
+      unsigned emulatedBW =
+          getElementTypeOrSelf(dstType).getIntOrFloatBitWidth();
+      if (bw >= emulatedBW) {
+        // The mask would keep every bit of the emulated type (e.g. i64 -> i32
+        // when i64 is emulated as i32), so the input already is the result.
+        rewriter.replaceOp(op, adaptor.getIn());
+        return success();
+      }
+
+      // Compute the mask in 64 bits: a 32-bit `1u << bw` would be undefined
+      // behavior for bw >= 32.
+      Value mask = getScalarOrVectorConstInt(
+          dstType, llvm::maskTrailingOnes<uint64_t>(bw), rewriter,
+          op.getLoc());
+      if (!mask)
+        return rewriter.notifyMatchFailure(op, "unsupported emulated type");
+      rewriter.replaceOpWithNewOp<spirv::BitwiseAndOp>(op, dstType,
+                                                       adaptor.getIn(), mask);
+    } else {
+      // Given this is truncation, either SConvertOp or UConvertOp works.
+      rewriter.replaceOpWithNewOp<spirv::SConvertOp>(op, dstType,
+                                                     adaptor.getOperands());
+    }
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // TypeCastingOp
 //===----------------------------------------------------------------------===//
@@ -983,10 +1165,10 @@
     spirv::ElementwiseOpPattern,
     spirv::ElementwiseOpPattern,
     spirv::ElementwiseOpPattern,
-    TypeCastingOpPattern, ExtUII1Pattern,
-    TypeCastingOpPattern, ExtSII1Pattern,
+    ExtUIPattern, ExtUII1Pattern,
+    ExtSIPattern, ExtSII1Pattern,
     TypeCastingOpPattern,
-    TypeCastingOpPattern, TruncII1Pattern,
+    TruncIPattern, TruncII1Pattern,
     TypeCastingOpPattern,
     TypeCastingOpPattern, UIToFPI1Pattern,
     TypeCastingOpPattern,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1000,6 +1000,38 @@
   return %0: f64
 }
 
+// CHECK-LABEL: @trunci4
+// CHECK-SAME: %[[ARG:.*]]: i32
+func.func @trunci4(%arg0 : i32) -> i4 {
+  // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[ARG]], %[[MASK]] : i32
+  %0 = arith.trunci %arg0 : i32 to i4
+  // CHECK: %[[RET:.+]] = builtin.unrealized_conversion_cast %[[AND]] : i32 to i4
+  // CHECK: return %[[RET]] : i4
+  return %0 : i4
+}
+
+// CHECK-LABEL: @zexti4
+func.func @zexti4(%arg0: i4) -> i32 {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
+  // CHECK: %[[MASK:.+]] = spirv.Constant 15 : i32
+  // CHECK: %[[AND:.+]] = spirv.BitwiseAnd %[[INPUT]], %[[MASK]] : i32
+  %0 = arith.extui %arg0 : i4 to i32
+  // CHECK: return %[[AND]] : i32
+  return %0 : i32
+}
+
+// CHECK-LABEL: @sexti4
+func.func @sexti4(%arg0: i4) -> i32 {
+  // CHECK: %[[INPUT:.+]] = builtin.unrealized_conversion_cast %{{.+}} : i4 to i32
+  // CHECK: %[[SIZE:.+]] = spirv.Constant 28 : i32
+  // CHECK: %[[SL:.+]] = spirv.ShiftLeftLogical %[[INPUT]], %[[SIZE]] : i32, i32
+  // CHECK: %[[SR:.+]] = spirv.ShiftRightArithmetic %[[SL]], %[[SIZE]] : i32, i32
+  %0 = arith.extsi %arg0 : i4 to i32
+  // CHECK: return %[[SR]] : i32
+  return %0 : i32
+}
+
 } // end module
 
 // -----