diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -45,11 +45,20 @@ ConversionPatternRewriter &rewriter) const override; }; -/// Converts arith.remsi to SPIR-V ops. +/// Converts arith.remsi to GLSL SPIR-V ops. /// /// This cannot be merged into the template unary/binary pattern due to Vulkan /// restrictions over spv.SRem and spv.SMod. -struct RemSIOpPattern final : public OpConversionPattern { +struct RemSIOpGLSLPattern final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +/// Converts arith.remsi to OpenCL SPIR-V ops. +struct RemSIOpOCLPattern final : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -396,7 +405,7 @@ } //===----------------------------------------------------------------------===// -// RemSIOpPattern +// RemSIOpGLSLPattern //===----------------------------------------------------------------------===// /// Returns signed remainder for `lhs` and `rhs` and lets the result follow @@ -406,6 +415,7 @@ /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod /// if either operand can be negative. Emulate it via spv.UMod. +template static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs, Value signOperand, OpBuilder &builder) { assert(lhs.getType() == rhs.getType()); @@ -414,8 +424,8 @@ Type type = lhs.getType(); // Calculate the remainder with spv.UMod. - Value lhsAbs = builder.create(loc, type, lhs); - Value rhsAbs = builder.create(loc, type, rhs); + Value lhsAbs = builder.create(loc, type, lhs); + Value rhsAbs = builder.create(loc, type, rhs); Value abs = builder.create(loc, lhsAbs, rhsAbs); // Fix the sign. @@ -429,11 +439,26 @@ } LogicalResult -RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0], - adaptor.getOperands()[1], - adaptor.getOperands()[0], rewriter); +RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value result = emulateSignedRemainder( + op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); + rewriter.replaceOp(op, result); + + return success(); +} + +//===----------------------------------------------------------------------===// +// RemSIOpOCLPattern +//===----------------------------------------------------------------------===// + +LogicalResult +RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + Value result = emulateSignedRemainder( + op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); rewriter.replaceOp(op, result); return success(); @@ -762,7 +787,7 @@ spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, spirv::UnaryAndBinaryOpPattern, - RemSIOpPattern, + RemSIOpGLSLPattern, RemSIOpOCLPattern, BitwiseOpPattern, BitwiseOpPattern, XOrIOpLogicalPattern, XOrIOpBooleanPattern, diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -43,14 +43,8 @@ // Check float unary operation conversions. // CHECK-LABEL: @float32_unary_scalar func @float32_unary_scalar(%arg0: f32) { - // CHECK: spv.GLSL.FAbs %{{.*}}: f32 - %0 = math.abs %arg0 : f32 - // CHECK: spv.GLSL.Ceil %{{.*}}: f32 - %1 = math.ceil %arg0 : f32 // CHECK: spv.FNegate %{{.*}}: f32 - %5 = arith.negf %arg0 : f32 - // CHECK: spv.GLSL.Floor %{{.*}}: f32 - %10 = math.floor %arg0 : f32 + %0 = arith.negf %arg0 : f32 return } @@ -842,3 +836,39 @@ } } // end module + +// ----- + +// Check OpenCL lowering of arith.remsi +module attributes { + spv.target_env = #spv.target_env< + #spv.vce, {}> +} { + +// CHECK-LABEL: @scalar_srem +// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) +func @scalar_srem(%lhs: i32, %rhs: i32) { + // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : i32 + // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : i32 + // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : i32 + // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : i32 + // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : i32 + // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32 + %0 = arith.remsi %lhs, %rhs: i32 + return +} + +// CHECK-LABEL: @vector_srem +// CHECK-SAME: (%[[LHS:.+]]: vector<3xi16>, %[[RHS:.+]]: vector<3xi16>) +func @vector_srem(%arg0: vector<3xi16>, %arg1: vector<3xi16>) { + // CHECK: %[[LABS:.+]] = spv.OCL.s_abs %[[LHS]] : vector<3xi16> + // CHECK: %[[RABS:.+]] = spv.OCL.s_abs %[[RHS]] : vector<3xi16> + // CHECK: %[[ABS:.+]] = spv.UMod %[[LABS]], %[[RABS]] : vector<3xi16> + // CHECK: %[[POS:.+]] = spv.IEqual %[[LHS]], %[[LABS]] : vector<3xi16> + // CHECK: %[[NEG:.+]] = spv.SNegate %[[ABS]] : vector<3xi16> + // CHECK: %{{.+}} = spv.Select %[[POS]], %[[ABS]], %[[NEG]] : vector<3xi1>, vector<3xi16> + %0 = arith.remsi %arg0, %arg1: vector<3xi16> + return +} + +} // end module diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -22,6 +22,12 @@ %6 = math.tanh %arg0 : f32 // CHECK: spv.GLSL.Sin %{{.*}}: f32 %7 = math.sin %arg0 : f32 + // CHECK: spv.GLSL.FAbs %{{.*}}: f32 + %8 = math.abs %arg0 : f32 + // CHECK: spv.GLSL.Ceil %{{.*}}: f32 + %9 = math.ceil %arg0 : f32 + // CHECK: spv.GLSL.Floor %{{.*}}: f32 + %10 = math.floor %arg0 : f32 return } diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir --- a/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir @@ -22,6 +22,12 @@ %6 = math.tanh %arg0 : f32 // CHECK: spv.OCL.sin %{{.*}}: f32 %7 = math.sin %arg0 : f32 + // CHECK: spv.OCL.fabs %{{.*}}: f32 + %8 = math.abs %arg0 : f32 + // CHECK: spv.OCL.ceil %{{.*}}: f32 + %9 = math.ceil %arg0 : f32 + // CHECK: spv.OCL.floor %{{.*}}: f32 + %10 = math.floor %arg0 : f32 return }