diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -271,20 +271,44 @@
                                  builder);
 }
 
+// Emit instructions that correspond to computing the minimum value among the
+// values of a (potentially) multi-output affine map applied to `operands`.
+// Returns nullptr if the affine map could not be expanded.
+static Value lowerAffineMapMin(OpBuilder &builder, Location loc, AffineMap map,
+                               ValueRange operands) {
+  if (auto values =
+          expandAffineMap(builder, loc, map, llvm::to_vector<4>(operands)))
+    return buildMinMaxReductionSeq(loc, CmpIPredicate::slt, *values, builder);
+  return nullptr;
+}
+
 // Emit instructions that correspond to the affine map in the upper bound
 // applied to the respective operands, and compute the minimum value across
 // the results.
 Value mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) {
-  SmallVector<Value, 8> boundOperands(op.getUpperBoundOperands());
-  auto ubValues = expandAffineMap(builder, op.getLoc(), op.getUpperBoundMap(),
-                                  boundOperands);
-  if (!ubValues)
-    return nullptr;
-  return buildMinMaxReductionSeq(op.getLoc(), CmpIPredicate::slt, *ubValues,
-                                 builder);
+  return lowerAffineMapMin(builder, op.getLoc(), op.getUpperBoundMap(),
+                           op.getUpperBoundOperands());
 }
 
 namespace {
+// Lowers `affine.min` to standard-dialect arithmetic computing the minimum of
+// all results of its affine map; fails to match if the map can't be expanded.
+class AffineMinLowering : public OpRewritePattern<AffineMinOp> {
+public:
+  using OpRewritePattern<AffineMinOp>::OpRewritePattern;
+
+  PatternMatchResult matchAndRewrite(AffineMinOp op,
+                                     PatternRewriter &rewriter) const override {
+    Value reduced =
+        lowerAffineMapMin(rewriter, op.getLoc(), op.map(), op.operands());
+    if (!reduced)
+      return matchFailure();
+
+    rewriter.replaceOp(op, reduced);
+    return matchSuccess();
+  }
+};
+
 // Affine terminators are removed.
 class AffineTerminatorLowering : public OpRewritePattern<AffineTerminatorOp> {
 public:
@@ -520,10 +544,19 @@
 
 void mlir::populateAffineToStdConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
+  // clang-format off
   patterns.insert<
-      AffineApplyLowering, AffineDmaStartLowering, AffineDmaWaitLowering,
-      AffineLoadLowering, AffinePrefetchLowering, AffineStoreLowering,
-      AffineForLowering, AffineIfLowering, AffineTerminatorLowering>(ctx);
+      AffineApplyLowering,
+      AffineDmaStartLowering,
+      AffineDmaWaitLowering,
+      AffineLoadLowering,
+      AffineMinLowering,
+      AffinePrefetchLowering,
+      AffineStoreLowering,
+      AffineForLowering,
+      AffineIfLowering,
+      AffineTerminatorLowering>(ctx);
+  // clang-format on
 }
 
 namespace {
diff --git a/mlir/test/Transforms/lower-affine.mlir b/mlir/test/Transforms/lower-affine.mlir
--- a/mlir/test/Transforms/lower-affine.mlir
+++ b/mlir/test/Transforms/lower-affine.mlir
@@ -590,3 +590,18 @@
   // CHECK-NEXT: dma_wait %0[%[[b]]], %c64 : memref<1xi32>
   return
 }
+
+// CHECK-LABEL: func @affine_min
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
+func @affine_min(%arg0: index, %arg1: index) -> index {
+  // CHECK: %[[Cm1:.*]] = constant -1
+  // CHECK: %[[neg1:.*]] = muli %[[ARG1]], %[[Cm1]]
+  // CHECK: %[[first:.*]] = addi %[[ARG0]], %[[neg1]]
+  // CHECK: %[[Cm2:.*]] = constant -1
+  // CHECK: %[[neg2:.*]] = muli %[[ARG0]], %[[Cm2]]
+  // CHECK: %[[second:.*]] = addi %[[ARG1]], %[[neg2]]
+  // CHECK: %[[cmp:.*]] = cmpi "slt", %[[first]], %[[second]]
+  // CHECK: select %[[cmp]], %[[first]], %[[second]]
+  %0 = affine.min affine_map<(d0,d1) -> (d0 - d1, d1 - d0)>(%arg0, %arg1)
+  return %0 : index
+}