diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -663,6 +663,17 @@
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Converts std.xor to SPIR-V operations if the type of source is i1 or vector
+/// of i1.
+class BoolXOrOpPattern final : public OpConversionPattern<XOrOp> {
+public:
+  using OpConversionPattern<XOrOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1250,6 +1261,22 @@
   return success();
 }
 
+LogicalResult
+BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
+                                  ConversionPatternRewriter &rewriter) const {
+  assert(operands.size() == 2);
+
+  if (!isBoolScalarOrVector(operands.front().getType()))
+    return failure();
+
+  auto dstType = getTypeConverter()->convertType(xorOp.getType());
+  if (!dstType)
+    return failure();
+  rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(xorOp, dstType,
+                                                        operands);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -1293,7 +1320,7 @@
       UnaryAndBinaryOpPattern,
       UnaryAndBinaryOpPattern,
       UnaryAndBinaryOpPattern,
-      SignedRemIOpPattern, XOrOpPattern,
+      SignedRemIOpPattern, XOrOpPattern, BoolXOrOpPattern,
 
       // Comparison patterns
       BoolCmpIOpPattern, CmpFOpPattern, CmpFOpNanNonePattern, CmpIOpPattern,
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -224,6 +224,8 @@
   %0 = and %arg0, %arg1 : i1
   // CHECK: spv.LogicalOr
   %1 = or %arg0, %arg1 : i1
+  // CHECK: spv.LogicalNotEqual
+  %2 = xor %arg0, %arg1 : i1
   return
 }
 
@@ -233,6 +235,8 @@
   %0 = and %arg0, %arg1 : vector<4xi1>
   // CHECK: spv.LogicalOr
   %1 = or %arg0, %arg1 : vector<4xi1>
+  // CHECK: spv.LogicalNotEqual
+  %2 = xor %arg0, %arg1 : vector<4xi1>
   return
 }
 