Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -810,6 +810,8 @@
 
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_ARITHMETIC_OPS
Index: mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -114,6 +114,49 @@
   results.add<CombineChainedAccessChain>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// Rewrites a UMod whose dividend is produced by another UMod with the same
+// divisor so that it reads the inner dividend directly. This is valid because
+// (x umod c) umod c == x umod c.
+//
+// Input:
+//    %0 = spirv.UMod %arg0, %const32 : i32
+//    %1 = spirv.UMod %0, %const32 : i32
+// Output:
+//    %0 = spirv.UMod %arg0, %const32 : i32
+//    %1 = spirv.UMod %arg0, %const32 : i32
+
+struct UModSimplification : public OpRewritePattern<spirv::UModOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
+                                PatternRewriter &rewriter) const override {
+    // Only applies when the dividend is the result of another UMod.
+    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
+    if (!prevUMod)
+      return failure();
+
+    // Check if the divisors of the two UMod operations are the same.
+    if (prevUMod.getOperand(1) != umodOp.getOperand(1))
+      return failure();
+
+    // Build a new UMod operation with the same divisor as the original UModOp,
+    // but with the dividend of the previous UModOp, and use it in place of the
+    // original.
+    rewriter.replaceOpWithNewOp<spirv::UModOp>(
+        umodOp, prevUMod.getOperand(0), umodOp.getOperand(1));
+    return success();
+  }
+};
+
+void spirv::UModOp::getCanonicalizationPatterns(
+    RewritePatternSet &patterns, MLIRContext *context) {
+  patterns.insert<UModSimplification>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -453,6 +453,24 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umod_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST:.*]] = spirv.Constant 32
+  %const32 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const32 : i32
+  %1 = spirv.UMod %0, %const32 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST]]
+  return %0, %1: i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//