diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -53,6 +53,16 @@
   return elementType.getIntOrFloatBitWidth();
 }
 
+/// Returns an all-ones `IntegerAttr` of `type` or of its vector element type.
+static IntegerAttr minusOneIntegerAttribute(Type type, Builder builder) {
+  if (auto vecType = type.dyn_cast<VectorType>()) {
+    auto integerType = vecType.getElementType().cast<IntegerType>();
+    return builder.getIntegerAttr(integerType, -1);
+  }
+  auto integerType = type.cast<IntegerType>();
+  return builder.getIntegerAttr(integerType, -1);
+}
+
 //===----------------------------------------------------------------------===//
 // Operation conversion
 //===----------------------------------------------------------------------===//
@@ -154,6 +164,35 @@
   }
 };
 
+/// Lowers `spv.Not`/`spv.LogicalNot` to `llvm.xor` with an all-ones mask.
+template <typename SPIRVOp>
+class NotPattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp notOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // ~x is emitted as x ^ mask; for i1 the all-ones mask is `true`.
+    auto srcType = notOp.getType();
+    auto dstType = this->typeConverter.convertType(srcType);
+    if (!dstType)
+      return failure();
+
+    Location loc = notOp.getLoc();
+    IntegerAttr minusOne = minusOneIntegerAttribute(srcType, rewriter);
+    auto mask = srcType.template isa<VectorType>()
+                    ? rewriter.create<LLVM::ConstantOp>(
+                          loc, dstType,
+                          SplatElementsAttr::get(
+                              srcType.template cast<VectorType>(), minusOne))
+                    : rewriter.create<LLVM::ConstantOp>(loc, dstType, minusOne);
+    rewriter.template replaceOpWithNewOp<LLVM::XOrOp>(notOp, dstType,
+                                                      operands.front(), mask);
+    return success();
+  }
+};
+
 class ReturnPattern : public SPIRVToLLVMConversion<spirv::ReturnOp> {
 public:
   using SPIRVToLLVMConversion<spirv::ReturnOp>::SPIRVToLLVMConversion;
@@ -346,6 +385,7 @@
       DirectConversionPattern<spirv::BitwiseAndOp, LLVM::AndOp>,
       DirectConversionPattern<spirv::BitwiseOrOp, LLVM::OrOp>,
       DirectConversionPattern<spirv::BitwiseXorOp, LLVM::XOrOp>,
+      NotPattern<spirv::NotOp>,
 
       // Cast ops
       DirectConversionPattern<spirv::BitcastOp, LLVM::BitcastOp>,
@@ -386,6 +426,7 @@
       DirectConversionPattern<spirv::LogicalOrOp, LLVM::OrOp>,
       IComparePattern<spirv::LogicalEqualOp, LLVM::ICmpPredicate::eq>,
       IComparePattern<spirv::LogicalNotEqualOp, LLVM::ICmpPredicate::ne>,
+      NotPattern<spirv::LogicalNotOp>,
 
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/bitwise-ops-to-llvm.mlir
@@ -79,3 +79,21 @@
   %0 = spv.BitwiseXor %arg0, %arg1 : vector<2xi16>
   return
 }
+
+//===----------------------------------------------------------------------===//
+// spv.Not
+//===----------------------------------------------------------------------===//
+
+func @not_scalar(%arg0: i32) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(-1 : i32) : !llvm.i32
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm.i32
+  %0 = spv.Not %arg0 : i32
+  return
+}
+
+func @not_vector(%arg0: vector<2xi16>) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(dense<-1> : vector<2xi16>) : !llvm<"<2 x i16>">
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm<"<2 x i16>">
+  %0 = spv.Not %arg0 : vector<2xi16>
+  return
+}
diff --git a/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/logical-to-llvm.mlir
@@ -33,6 +33,24 @@
 }
 
 //===----------------------------------------------------------------------===//
+// spv.LogicalNot
+//===----------------------------------------------------------------------===//
+
+func @logical_not_scalar(%arg0: i1) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(true) : !llvm.i1
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm.i1
+  %0 = spv.LogicalNot %arg0 : i1
+  return
+}
+
+func @logical_not_vector(%arg0: vector<4xi1>) {
+  // CHECK: %[[CONST:.*]] = llvm.mlir.constant(dense<true> : vector<4xi1>) : !llvm<"<4 x i1>">
+  // CHECK: %{{.*}} = llvm.xor %{{.*}}, %[[CONST]] : !llvm<"<4 x i1>">
+  %0 = spv.LogicalNot %arg0 : vector<4xi1>
+  return
+}
+
+//===----------------------------------------------------------------------===//
 // spv.LogicalAnd
 //===----------------------------------------------------------------------===//