diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -219,12 +219,27 @@
 int64_t AffineExpr::getLargestKnownDivisor() const {
   AffineBinaryOpExpr binExpr(nullptr);
   switch (getKind()) {
-  case AffineExprKind::CeilDiv:
-    [[fallthrough]];
   case AffineExprKind::DimId:
-  case AffineExprKind::FloorDiv:
+    [[fallthrough]];
   case AffineExprKind::SymbolId:
     return 1;
+  case AffineExprKind::CeilDiv:
+    [[fallthrough]];
+  case AffineExprKind::FloorDiv: {
+    // If the RHS is a constant and divides the known divisor on the LHS, the
+    // quotient is a known divisor of the expression.
+    binExpr = this->cast<AffineBinaryOpExpr>();
+    auto rhs = binExpr.getRHS().dyn_cast<AffineConstantExpr>();
+    // Leave alone undefined expressions.
+    if (rhs && rhs.getValue() != 0) {
+      int64_t lhsDiv = binExpr.getLHS().getLargestKnownDivisor();
+      // A negative RHS would make the quotient negative; report its magnitude
+      // instead, mirroring the std::abs used for the Constant case below.
+      if (lhsDiv % rhs.getValue() == 0)
+        return std::abs(lhsDiv / rhs.getValue());
+    }
+    return 1;
+  }
   case AffineExprKind::Constant:
     return std::abs(this->cast<AffineConstantExpr>().getValue());
   case AffineExprKind::Mul: {
diff --git a/mlir/test/Dialect/Affine/unroll.mlir b/mlir/test/Dialect/Affine/unroll.mlir
--- a/mlir/test/Dialect/Affine/unroll.mlir
+++ b/mlir/test/Dialect/Affine/unroll.mlir
@@ -746,4 +746,25 @@
 // UNROLL-CLEANUP-LOOP-NEXT: %[[V4:.*]] = affine.apply {{.*}}
 // UNROLL-CLEANUP-LOOP-NEXT: {{.*}} = "foo"(%[[V4]]) : (index) -> i32
 // UNROLL-CLEANUP-LOOP-NEXT: return
-}
\ No newline at end of file
+}
+
+// UNROLL-BY-4-LABEL: func @known_multiple_ceildiv
+func.func @known_multiple_ceildiv(%N: index, %S: index) {
+  %cst = arith.constant 0.0 : f32
+  %m = memref.alloc(%S) : memref<?xf32>
+  // This exercises affine expr getLargestKnownDivisor for the ceildiv case.
+  affine.for %i = 0 to affine_map<(d0) -> (32 * d0 + 64)>(%N) step 8 {
+    affine.store %cst, %m[%i] : memref<?xf32>
+  }
+  // UNROLL-BY-4: affine.for %{{.*}} = 0 to {{.*}} step 32
+  // UNROLL-BY-4-NOT: affine.for
+
+  // This exercises affine expr getLargestKnownDivisor for floordiv.
+  affine.for %i = 0 to affine_map<(d0) -> ((32 * d0 + 64) floordiv 8)>(%N) {
+    affine.store %cst, %m[%i] : memref<?xf32>
+  }
+  // UNROLL-BY-4: affine.for %{{.*}} = 0 to {{.*}} step 4
+  // UNROLL-BY-4-NOT: affine.for
+  // UNROLL-BY-4: return
+  return
+}