diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2327,6 +2327,37 @@
   return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
 }
 
+/// Remove duplicated expressions in affine min/max ops.
+///
+/// min and max are idempotent (min(a, a) == max(a, a) == a), so dropping
+/// repeated result expressions does not change the value computed by the op.
+template <typename T>
+struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
+  using OpRewritePattern<T>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(T affineOp,
+                                PatternRewriter &rewriter) const override {
+    AffineMap oldMap = affineOp.getAffineMap();
+
+    SmallVector<AffineExpr, 4> newExprs;
+    for (AffineExpr expr : oldMap.getResults()) {
+      // This is a linear scan over newExprs, but it should be fine given that
+      // we typically just have a few expressions per op.
+      if (!llvm::is_contained(newExprs, expr))
+        newExprs.push_back(expr);
+    }
+
+    if (newExprs.size() == oldMap.getNumResults())
+      return failure();
+
+    auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
+                                 newExprs, rewriter.getContext());
+    rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
+
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // AffineMinOp
 //===----------------------------------------------------------------------===//
@@ -2340,7 +2371,8 @@
 
 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<SimplifyAffineOp<AffineMinOp>>(context);
+  patterns.add<DeduplicateAffineMinMaxExpressions<AffineMinOp>,
+               SimplifyAffineOp<AffineMinOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2356,7 +2388,8 @@
 
 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
-  patterns.add<SimplifyAffineOp<AffineMaxOp>>(context);
+  patterns.add<DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
+               SimplifyAffineOp<AffineMaxOp>>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Affine/canonicalize.mlir b/mlir/test/Dialect/Affine/canonicalize.mlir
--- a/mlir/test/Dialect/Affine/canonicalize.mlir
+++ b/mlir/test/Dialect/Affine/canonicalize.mlir
@@ -694,3 +694,27 @@
   }
   return
 }
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1, s0 * s1)>
+
+// CHECK: func @deduplicate_affine_min_expressions
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index)
+func @deduplicate_affine_min_expressions(%i0: index, %i1: index) -> index {
+  // CHECK: affine.min #[[MAP]]()[%[[I0]], %[[I1]]]
+  %0 = affine.min affine_map<()[s0, s1] -> (s0 + s1, s0 * s1, s1 + s0, s0 * s1)> ()[%i0, %i1]
+  return %0: index
+}
+
+// -----
+
+// CHECK: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1, s0 * s1)>
+
+// CHECK: func @deduplicate_affine_max_expressions
+// CHECK-SAME: (%[[I0:.+]]: index, %[[I1:.+]]: index)
+func @deduplicate_affine_max_expressions(%i0: index, %i1: index) -> index {
+  // CHECK: affine.max #[[MAP]]()[%[[I0]], %[[I1]]]
+  %0 = affine.max affine_map<()[s0, s1] -> (s0 + s1, s0 * s1, s1 + s0, s0 * s1)> ()[%i0, %i1]
+  return %0: index
+}