[mlir][spirv] Lower unsigned arith.cmpi predicates on boolean operands

eq/ne on i1 (scalars or vectors) keep lowering directly to spv.LogicalEqual /
spv.LogicalNotEqual. SPIR-V has no direct instruction for uge/ugt/ule/ult on
booleans, so for those predicates both operands are now zero-extended to i32
(taking the vector shape from the type-converted operand type) and arith.cmpi
is re-emitted on the extended values; the tests expect this to end up as
spv.UGreaterThanEqual / spv.ULessThan preceded by two spv.Select ops.
The boolean cmpi tests are split into *_equality and *_unsigned variants.

diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
--- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "arith-to-spirv-pattern"
@@ -665,23 +666,44 @@
 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
     arith::CmpIOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-  Type operandType = op.getLhs().getType();
-  if (!isBoolScalarOrVector(operandType))
+  Type srcType = op.getLhs().getType();
+  if (!isBoolScalarOrVector(srcType))
+    return failure();
+  Type dstType = getTypeConverter()->convertType(srcType);
+  if (!dstType)
     return failure();
 
   switch (op.getPredicate()) {
-#define DISPATCH(cmpPredicate, spirvOp)                                        \
-  case cmpPredicate: {                                                         \
-    rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(),                 \
-                                         adaptor.getRhs());                    \
-    return success();                                                          \
+  case arith::CmpIPredicate::eq: {
+    rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
+                                                       adaptor.getRhs());
+    return success();
   }
-
-    DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
-    DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
-
-#undef DISPATCH
-  default:;
+  case arith::CmpIPredicate::ne: {
+    rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
+                                                          adaptor.getRhs());
+    return success();
+  }
+  case arith::CmpIPredicate::uge:
+  case arith::CmpIPredicate::ugt:
+  case arith::CmpIPredicate::ule:
+  case arith::CmpIPredicate::ult: {
+    // There are no direct corresponding instructions in SPIR-V for such cases.
+    // Zero-extend the operands to 32-bit integers and compare those instead.
+    Type type = rewriter.getI32Type();
+    if (auto vectorType = dstType.dyn_cast<VectorType>())
+      type = VectorType::get(vectorType.getShape(), type);
+    auto extLhs =
+        rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
+    auto extRhs =
+        rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
+
+    rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
+                                               extRhs);
+    return success();
+  }
+  default:
+    break;
   }
   return failure();
 }
diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
--- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir
@@ -401,8 +401,8 @@
   return
 }
 
-// CHECK-LABEL: @boolcmpi
-func.func @boolcmpi(%arg0 : i1, %arg1 : i1) {
+// CHECK-LABEL: @boolcmpi_equality
+func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) {
   // CHECK: spv.LogicalEqual
   %0 = arith.cmpi eq, %arg0, %arg1 : i1
   // CHECK: spv.LogicalNotEqual
@@ -410,8 +410,19 @@
   return
 }
 
-// CHECK-LABEL: @vec1boolcmpi
-func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+// CHECK-LABEL: @boolcmpi_unsigned
+func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) {
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.UGreaterThanEqual
+  %0 = arith.cmpi uge, %arg0, %arg1 : i1
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.ULessThan
+  %1 = arith.cmpi ult, %arg0, %arg1 : i1
+  return
+}
+
+// CHECK-LABEL: @vec1boolcmpi_equality
+func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
   // CHECK: spv.LogicalEqual
   %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1>
   // CHECK: spv.LogicalNotEqual
@@ -419,8 +430,19 @@
   return
 }
 
-// CHECK-LABEL: @vecboolcmpi
-func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
+// CHECK-LABEL: @vec1boolcmpi_unsigned
+func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) {
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.UGreaterThanEqual
+  %0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1>
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.ULessThan
+  %1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1>
+  return
+}
+
+// CHECK-LABEL: @vecboolcmpi_equality
+func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) {
   // CHECK: spv.LogicalEqual
   %0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1>
   // CHECK: spv.LogicalNotEqual
@@ -428,6 +450,18 @@
   return
 }
 
+// CHECK-LABEL: @vecboolcmpi_unsigned
+func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) {
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.UGreaterThanEqual
+  %0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1>
+  // CHECK-COUNT-2: spv.Select
+  // CHECK: spv.ULessThan
+  %1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1>
+  return
+}
+
+
 } // end module
 
 // -----