diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2538,6 +2538,22 @@
   }
 };
 
+/// Replaces an affine.min/affine.max whose map has exactly one result with an
+/// affine.apply of the same map and operands (min/max of one value is itself).
+template <typename T>
+struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T affineOp,
+                                PatternRewriter &rewriter) const override {
+    if (affineOp.map().getNumResults() != 1)
+      return failure();
+    rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.map(),
+                                               affineOp.getOperands());
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // AffineMinOp
 //===----------------------------------------------------------------------===//
@@ -2551,7 +2567,8 @@
 
 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<DeduplicateAffineMinMaxExpressions<AffineMinOp>,
+  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
+               DeduplicateAffineMinMaxExpressions<AffineMinOp>,
                MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
       context);
 }
@@ -2569,7 +2586,8 @@
 
 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
+  patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
+               DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
                MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
       context);
 }
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -870,7 +870,6 @@
   return %1: index
 }
 
-
 // -----
 
 // CHECK-LABEL: func @dont_merge_affine_max_if_not_single_sym
@@ -936,3 +935,21 @@
   return
 }
 
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 + 16)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 4)>
+
+// CHECK: func @canonicalize_single_min_max
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index)
+func @canonicalize_single_min_max(%i0: index, %i1: index) -> (index, index) {
+  // CHECK-NOT: affine.min
+  // CHECK-NEXT: affine.apply #[[$MAP0]]()[%[[I0]]]
+  %0 = affine.min affine_map<()[s0] -> (s0 + 16)> ()[%i0]
+
+  // CHECK-NOT: affine.max
+  // CHECK-NEXT: affine.apply #[[$MAP1]]()[%[[I1]]]
+  %1 = affine.max affine_map<()[s0] -> (s0 * 4)> ()[%i1]
+
+  return %0, %1: index, index
+}