diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -810,6 +810,8 @@
 
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_ARITHMETIC_OPS
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -114,6 +114,67 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// Input:
+//    %0 = spirv.UMod %arg0, %const32 : i32
+//    %1 = spirv.UMod %0, %const4 : i32
+// Output:
+//    %0 = spirv.UMod %arg0, %const32 : i32
+//    %1 = spirv.UMod %arg0, %const4 : i32
+
+// The transformation is only applied if the previous (inner) divisor is a
+// multiple of the current (outer) divisor: with c1 = k * c2,
+// (x mod c1) mod c2 == x mod c2. The opposite direction is not valid, e.g.
+// (5 mod 4) mod 32 == 1, whereas 5 mod 32 == 5.
+
+// TODO(https://github.com/llvm/llvm-project/issues/63174): Add support for
+// vector constants.
+struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
+    if (!prevUMod)
+      return failure();
+
+    // Only divisors that match as IntegerAttr constants are handled here.
+    IntegerAttr prevValue;
+    IntegerAttr currValue;
+    if (!matchPattern(prevUMod.getOperand(1), m_Constant(&prevValue)) ||
+        !matchPattern(umodOp.getOperand(1), m_Constant(&currValue)))
+      return failure();
+
+    APInt prevConstValue = prevValue.getValue();
+    APInt currConstValue = currValue.getValue();
+
+    // APInt::urem asserts on a zero divisor; leave a modulo by a constant
+    // zero untouched instead of crashing the compiler.
+    if (currConstValue.isZero())
+      return failure();
+
+    // Ensure that the previous divisor is a multiple of the current one. If
+    // not, fail the transformation.
+    if (prevConstValue.urem(currConstValue) != 0)
+      return failure();
+
+    // The transformation is safe. Replace the existing UMod operation with a
+    // new UMod operation, using the original dividend and the current divisor.
+    rewriter.replaceOpWithNewOp<spirv::UModOp>(
+        umodOp, umodOp.getType(), prevUMod.getOperand(0), umodOp.getOperand(1));
+
+    return success();
+  }
+};
+
+void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<UModSimplification>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
--- a/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ b/mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -453,6 +453,88 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umod_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST4:.*]] = spirv.Constant 4
+  // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 4 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_vector_fold
+// CHECK-SAME: (%[[ARG:.*]]: vector<4xi32>)
+func.func @umod_fail_vector_fold(%arg0: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+  // CHECK: %[[CONST4:.*]] = spirv.Constant dense<4> : vector<4xi32>
+  // CHECK: %[[CONST32:.*]] = spirv.Constant dense<32> : vector<4xi32>
+  %const1 = spirv.Constant dense<32> : vector<4xi32>
+  %0 = spirv.UMod %arg0, %const1 : vector<4xi32>
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  %const2 = spirv.Constant dense<4> : vector<4xi32>
+  %1 = spirv.UMod %0, %const2 : vector<4xi32>
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST4]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: vector<4xi32>, vector<4xi32>
+}
+
+// CHECK-LABEL: @umod_fold_same_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST1:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fail_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fail_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST5:.*]] = spirv.Constant 5
+  // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  %const2 = spirv.Constant 5 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST5]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// The outer divisor is a multiple of the inner one: (x mod 4) mod 32 is not
+// x mod 32 in general, so both operations must be left as they are.
+// CHECK-LABEL: @umod_larger_outer_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_larger_outer_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK-DAG: %[[CONST4:.*]] = spirv.Constant 4
+  // CHECK-DAG: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 4 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST32]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//