diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -17,8 +17,10 @@ #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Debug.h" +#include namespace mlir { #define GEN_PASS_DEF_CONVERTARITHTOSPIRV @@ -118,6 +120,16 @@ ConversionPatternRewriter &rewriter) const override; }; +/// Converts arith.extsi to spirv.Select(src, all-ones, zero) when the source +/// is i1 or a vector of i1, so that a true bit sign-extends to -1. +struct ExtSII1Pattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + /// Converts arith.extui to spirv.Select if the type of source is i1 or vector /// of i1.
struct ExtUII1Pattern final : public OpConversionPattern { @@ -615,6 +627,46 @@ return success(); } +//===----------------------------------------------------------------------===// +// ExtSII1Pattern +//===----------------------------------------------------------------------===// + +LogicalResult +ExtSII1Pattern::matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value operand = adaptor.getIn(); + if (!isBoolScalarOrVector(operand.getType())) + return failure(); + + Location loc = op.getLoc(); + Type dstType = getTypeConverter()->convertType(op.getResult().getType()); + // convertType yields a null Type on conversion failure; querying a null + // Type with dyn_cast is invalid, so report a match failure instead. + if (!dstType) + return rewriter.notifyMatchFailure(loc, "failed to convert result type"); + + Value allOnes; + if (auto intTy = dstType.dyn_cast()) { + unsigned componentBitwidth = intTy.getWidth(); + allOnes = rewriter.create( + loc, intTy, + rewriter.getIntegerAttr(intTy, APInt::getAllOnes(componentBitwidth))); + } else if (auto vectorTy = dstType.dyn_cast()) { + unsigned componentBitwidth = vectorTy.getElementTypeBitWidth(); + allOnes = rewriter.create( + loc, vectorTy, + SplatElementsAttr::get(vectorTy, APInt::getAllOnes(componentBitwidth))); + } else { + return rewriter.notifyMatchFailure( + loc, llvm::formatv("unhandled type: {0}", dstType)); + } + + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + rewriter.replaceOpWithNewOp(op, dstType, operand, allOnes, + zero); + return success(); +} + //===----------------------------------------------------------------------===// // ExtUII1Pattern //===----------------------------------------------------------------------===// @@ -982,7 +1034,7 @@ spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, TypeCastingOpPattern, ExtUII1Pattern, - TypeCastingOpPattern, + TypeCastingOpPattern, ExtSII1Pattern, TypeCastingOpPattern, TypeCastingOpPattern, TruncII1Pattern, TypeCastingOpPattern, diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@
-2032,7 +2032,7 @@ return builder.create( loc, type, DenseElementsAttr::get(vectorType, - IntegerAttr::get(elemType, 0.0).getValue())); + IntegerAttr::get(elemType, 0).getValue())); } if (elemType.isa()) { return builder.create( @@ -2065,7 +2065,7 @@ return builder.create( loc, type, DenseElementsAttr::get(vectorType, - IntegerAttr::get(elemType, 1.0).getValue())); + IntegerAttr::get(elemType, 1).getValue())); } if (elemType.isa()) { return builder.create( diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -756,6 +756,28 @@ return %0 : i64 } +// CHECK-LABEL: @sext_bool_scalar +// CHECK-SAME: ([[ARG:%.+]]: i1) -> i32 +func.func @sext_bool_scalar(%arg0 : i1) -> i32 { + // CHECK-DAG: [[ONES:%.+]] = spirv.Constant -1 : i32 + // CHECK-DAG: [[ZERO:%.+]] = spirv.Constant 0 : i32 + // CHECK: [[SEL:%.+]] = spirv.Select [[ARG]], [[ONES]], [[ZERO]] : i1, i32 + // CHECK-NEXT: return [[SEL]] : i32 + %0 = arith.extsi %arg0 : i1 to i32 // a true i1 sign-extends to -1 (all ones), false to 0 + return %0 : i32 +} + +// CHECK-LABEL: @sext_bool_vector +// CHECK-SAME: ([[ARG:%.+]]: vector<3xi1>) -> vector<3xi32> +func.func @sext_bool_vector(%arg0 : vector<3xi1>) -> vector<3xi32> { + // CHECK-DAG: [[ONES:%.+]] = spirv.Constant dense<-1> : vector<3xi32> + // CHECK-DAG: [[ZERO:%.+]] = spirv.Constant dense<0> : vector<3xi32> + // CHECK: [[SEL:%.+]] = spirv.Select [[ARG]], [[ONES]], [[ZERO]] : vector<3xi1>, vector<3xi32> + // CHECK-NEXT: return [[SEL]] : vector<3xi32> + %0 = arith.extsi %arg0 : vector<3xi1> to vector<3xi32> // element-wise: true lanes become -1, false lanes 0 + return %0 : vector<3xi32> +} + // CHECK-LABEL: @zexti1 func.func @zexti1(%arg0: i16) -> i64 { // CHECK: spirv.UConvert %{{.*}} : i16 to i64