Index: mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
+++ mlir/include/mlir/Dialect/SPIRV/IR/SPIRVArithmeticOps.td
@@ -810,6 +810,8 @@
 
     ```
   }];
+
+  let hasCanonicalizer = 1;
 }
 
 #endif // MLIR_DIALECT_SPIRV_IR_ARITHMETIC_OPS
Index: mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
+++ mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
@@ -114,6 +114,85 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Simplifies a chain of two unsigned remainders with constant divisors:
+///
+///   %0 = spirv.UMod %x, %a
+///   %1 = spirv.UMod %0, %b
+///
+/// * If `a` is a multiple of `b`, then `(x % a) % b == x % b`, so %1 is
+///   rewritten to use %x directly. For example
+///     %0 = spirv.UMod %arg0, %const32 : i32
+///     %1 = spirv.UMod %0, %const4 : i32
+///   becomes
+///     %0 = spirv.UMod %arg0, %const32 : i32
+///     %1 = spirv.UMod %arg0, %const4 : i32
+///
+/// * Otherwise, if `b` is a multiple of `a`, then `x % a < a <= b`, so the
+///   outer remainder is a no-op and %1 is replaced by %0.
+///
+/// Any other pair of divisors is left alone: e.g. `(x % 4) % 32` must NOT be
+/// rewritten to `x % 32`. Zero divisors (undefined behavior in SPIR-V) and
+/// non-scalar constants (e.g. vector splats) are also left untouched.
+struct UModSimplification final : OpRewritePattern<spirv::UModOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(spirv::UModOp umodOp,
+                                PatternRewriter &rewriter) const override {
+    auto prevUMod = umodOp.getOperand(0).getDefiningOp<spirv::UModOp>();
+    if (!prevUMod)
+      return failure();
+
+    // Both divisors must be produced by spirv.Constant ops.
+    auto prevDivisor =
+        prevUMod.getOperand(1).getDefiningOp<spirv::ConstantOp>();
+    auto currDivisor = umodOp.getOperand(1).getDefiningOp<spirv::ConstantOp>();
+    if (!prevDivisor || !currDivisor)
+      return failure();
+
+    // spirv.UMod also accepts vectors, whose constants are not IntegerAttrs;
+    // bail out instead of asserting in a cast.
+    auto prevAttr = llvm::dyn_cast<IntegerAttr>(prevDivisor.getValue());
+    auto currAttr = llvm::dyn_cast<IntegerAttr>(currDivisor.getValue());
+    if (!prevAttr || !currAttr)
+      return failure();
+
+    APInt prevValue = prevAttr.getValue();
+    APInt currValue = currAttr.getValue();
+
+    // Division by zero is undefined in SPIR-V, and APInt::urem asserts on a
+    // zero divisor.
+    if (prevValue.isZero() || currValue.isZero())
+      return failure();
+
+    // (x % a) % b == x % b when `a` is a multiple of `b`.
+    if (prevValue.urem(currValue) == 0) {
+      rewriter.replaceOpWithNewOp<spirv::UModOp>(
+          umodOp, umodOp.getType(), prevUMod.getOperand(0),
+          umodOp.getOperand(1));
+      return success();
+    }
+
+    // (x % a) % b == x % a when `b` is a multiple of `a`, because x % a < a.
+    if (currValue.urem(prevValue) == 0) {
+      rewriter.replaceOp(umodOp, prevUMod.getResult());
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void spirv::UModOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                MLIRContext *context) {
+  patterns.add<UModSimplification>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // spirv.BitcastOp
 //===----------------------------------------------------------------------===//
Index: mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
+++ mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
@@ -453,6 +453,67 @@
 
 // -----
 
+//===----------------------------------------------------------------------===//
+// spirv.UMod
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @umod_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST4:.*]] = spirv.Constant 4
+  // CHECK: %[[CONST32:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 4 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fold_same_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold_same_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST1:.*]] = spirv.Constant 32
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[ARG]], %[[CONST1]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_fold_larger_divisor
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_fold_larger_divisor(%arg0: i32) -> (i32, i32) {
+  // CHECK: %[[CONST4:.*]] = spirv.Constant 4
+  %const1 = spirv.Constant 4 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 32 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST4]]
+  // CHECK: return %[[UMOD0]], %[[UMOD0]]
+  return %0, %1: i32, i32
+}
+
+// CHECK-LABEL: @umod_no_fold
+// CHECK-SAME: (%[[ARG:.*]]: i32)
+func.func @umod_no_fold(%arg0: i32) -> (i32, i32) {
+  // CHECK-DAG: %[[CONST32:.*]] = spirv.Constant 32
+  // CHECK-DAG: %[[CONST5:.*]] = spirv.Constant 5
+  %const1 = spirv.Constant 32 : i32
+  %0 = spirv.UMod %arg0, %const1 : i32
+  %const2 = spirv.Constant 5 : i32
+  %1 = spirv.UMod %0, %const2 : i32
+  // CHECK: %[[UMOD0:.*]] = spirv.UMod %[[ARG]], %[[CONST32]]
+  // CHECK: %[[UMOD1:.*]] = spirv.UMod %[[UMOD0]], %[[CONST5]]
+  // CHECK: return %[[UMOD0]], %[[UMOD1]]
+  return %0, %1: i32, i32
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.LogicalAnd
 //===----------------------------------------------------------------------===//