Map `dimExpr` to its constraint-system column before querying its bound.

detectAsMod now also receives `offset` and `num`. `num` columns starting at
`offset` correspond to previously unknown variables, so a `dimExpr` position
at or past `offset` is shifted by `num` before its constant upper bound is
looked up. A loop-fusion test checks that the higher-dimensional producer nest
is fused into the innermost loop of the consumer.

diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -224,9 +224,9 @@
 // Returns true if the above mod or floordiv are detected, updating 'memo' with
 // these new expressions. Returns false otherwise.
 static bool detectAsMod(const FlatLinearConstraints &cst, unsigned pos,
-                        int64_t lbConst, int64_t ubConst,
-                        SmallVectorImpl<AffineExpr> &memo,
-                        MLIRContext *context) {
+                        unsigned offset, unsigned num, int64_t lbConst,
+                        int64_t ubConst, MLIRContext *context,
+                        SmallVectorImpl<AffineExpr> &memo) {
   assert(pos < cst.getNumVars() && "invalid position");
 
   // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
@@ -308,7 +308,17 @@
 
     // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
     if (quotientCount >= 1) {
-      auto ub = cst.getConstantBound64(BoundType::UB, dimExpr.getPosition());
+      // Find the column corresponding to `dimExpr`. `num` columns starting at
+      // `offset` correspond to previously unknown variables. The column
+      // corresponding to the trivially known `dimExpr` can be on either side
+      // of these.
+      unsigned dimExprCol;
+      unsigned dimExprPos = dimExpr.getPosition();
+      if (dimExprPos < offset)
+        dimExprCol = dimExprPos;
+      else
+        dimExprCol = dimExprPos + num;
+      auto ub = cst.getConstantBound64(BoundType::UB, dimExprCol);
       // If `var_n` has an upperbound that is less than the divisor, mod can be
       // eliminated altogether.
       if (ub && *ub < divisor)
@@ -499,7 +509,8 @@
 
         // Detect a variable as modulo of another variable w.r.t a
         // constant.
-        if (detectAsMod(*this, pos, *lbConst, *ubConst, memo, context)) {
+        if (detectAsMod(*this, pos, offset, num, *lbConst, *ubConst, context,
+                        memo)) {
           changed = true;
           continue;
         }
diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir
--- a/mlir/test/Transforms/loop-fusion-4.mlir
+++ b/mlir/test/Transforms/loop-fusion-4.mlir
@@ -190,3 +190,39 @@
   // PRODUCER-CONSUMER-NEXT: }
   return
 }
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: @fuse_higher_dim_nest_into_lower_dim_nest
+func.func @fuse_higher_dim_nest_into_lower_dim_nest() {
+  %A = memref.alloc() : memref<8x12x128x64xf32>
+  %B = memref.alloc() : memref<8x128x12x64xf32>
+  affine.for %arg205 = 0 to 8 {
+    affine.for %arg206 = 0 to 128 {
+      affine.for %arg207 = 0 to 12 {
+        affine.for %arg208 = 0 to 64 {
+          %a = affine.load %A[%arg205, %arg207, %arg206, %arg208] : memref<8x12x128x64xf32>
+          affine.store %a, %B[%arg205, %arg206, %arg207, %arg208] : memref<8x128x12x64xf32>
+        }
+      }
+    }
+  }
+  %C = memref.alloc() : memref<8x128x768xf16>
+  affine.for %arg205 = 0 to 8 {
+    affine.for %arg206 = 0 to 128 {
+      affine.for %arg207 = 0 to 768 {
+        %b = affine.load %B[%arg205, %arg206, %arg207 floordiv 64, %arg207 mod 64] : memref<8x128x12x64xf32>
+        %c = arith.truncf %b : f32 to f16
+        affine.store %c, %C[%arg205, %arg206, %arg207] : memref<8x128x768xf16>
+      }
+    }
+  }
+
+  // Check that fusion happens into the innermost loop of the consumer.
+  // PRODUCER-CONSUMER: affine.for
+  // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 128
+  // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 768
+  // PRODUCER-CONSUMER-NOT: affine.for
+  // PRODUCER-CONSUMER: return
+  return
+}