diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp @@ -71,10 +71,11 @@ if (!checkInvarianceOfNestedIfOps(&op, indVar, opsWithUsers, opsToHoist)) { return false; } - } else if (isa<AffineForOp>(op)) { - // If the body of a predicated region has a for loop, we don't hoist the - // 'affine.if'. - return false; + } else if (auto forOp = dyn_cast<AffineForOp>(op)) { + if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, + opsWithUsers, opsToHoist)) { + return false; + } } else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) { // TODO: Support DMA ops. return false; @@ -113,29 +114,29 @@ LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n"); return false; } - for (unsigned int i = 0; i < op.getNumOperands(); ++i) { - auto *operandSrc = op.getOperand(i).getDefiningOp(); + } - LLVM_DEBUG( - op.getOperand(i).print(llvm::dbgs() << "\nIterating on operand\n")); + // Check operands. + for (unsigned int i = 0; i < op.getNumOperands(); ++i) { + auto *operandSrc = op.getOperand(i).getDefiningOp(); - // If the loop IV is the operand, this op isn't loop invariant. - if (indVar == op.getOperand(i)) { - LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n"); - return false; - } + LLVM_DEBUG( + op.getOperand(i).print(llvm::dbgs() << "\nIterating on operand\n")); - if (operandSrc != nullptr) { - LLVM_DEBUG(llvm::dbgs() - << *operandSrc << "\nIterating on operand src\n"); + // If the loop IV is the operand, this op isn't loop invariant. + if (indVar == op.getOperand(i)) { + LLVM_DEBUG(llvm::dbgs() << "\nLoop IV is the operand\n"); + return false; + } - // If the value was defined in the loop (outside of the - // if/else region), and that operation itself wasn't meant to - // be hoisted, then mark this operation loop dependent. 
- if (opsWithUsers.count(operandSrc) && - opsToHoist.count(operandSrc) == 0) { - return false; - } + if (operandSrc != nullptr) { + LLVM_DEBUG(llvm::dbgs() << *operandSrc << "\nIterating on operand src\n"); + + // If the value was defined in the loop (outside of the + // if/else region), and that operation itself wasn't meant to + // be hoisted, then mark this operation loop dependent. + if (opsWithUsers.count(operandSrc) && opsToHoist.count(operandSrc) == 0) { + return false; } } } @@ -198,12 +199,9 @@ // not being hoisted. if (!op.use_empty()) opsWithUsers.insert(&op); - // We don't hoist for loops. - if (!isa<AffineForOp>(op)) { - if (!isa<AffineYieldOp>(op)) { - if (isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) { - opsToMove.push_back(&op); - } + if (!isa<AffineYieldOp>(op)) { + if (isOpLoopInvariant(op, indVar, opsWithUsers, opsToHoist)) { + opsToMove.push_back(&op); } } } diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir --- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir +++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir @@ -17,6 +17,8 @@ // CHECK-NEXT: %cst_0 = constant 8.000000e+00 : f32 // CHECK-NEXT: %1 = addf %cst, %cst_0 : f32 // CHECK-NEXT: affine.for %arg0 = 0 to 10 { + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %1, %0[%arg0] : memref<10xf32> return @@ -67,6 +69,33 @@ // ----- +// CHECK-LABEL: func @nested_loops_inner_loops_invariant_to_outermost_loop +func @nested_loops_inner_loops_invariant_to_outermost_loop(%m : memref<10xindex>) { + affine.for %arg0 = 0 to 20 { + affine.for %arg1 = 0 to 30 { + %v0 = affine.for %arg2 = 0 to 10 iter_args (%prevAccum = %arg1) -> index { + %v1 = affine.load %m[%arg2] : memref<10xindex> + %newAccum = addi %prevAccum, %v1 : index + affine.yield %newAccum : index + } + } + } + + // CHECK: affine.for %{{.*}} = 0 to 30 { + // CHECK-NEXT: %{{.*}} = affine.for %{{.*}} = 0 to 10 
iter_args(%{{.*}} = %{{.*}}) -> (index) { + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xindex> + // CHECK-NEXT: %{{.*}} = addi %{{.*}}, %{{.*}} : index + // CHECK-NEXT: affine.yield %{{.*}} : index + // CHECK-NEXT: } + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 20 { + // CHECK-NEXT: } + + return +} + +// ----- + func @single_loop_nothing_invariant() { %m1 = memref.alloc() : memref<10xf32> %m2 = memref.alloc() : memref<10xf32> @@ -228,8 +257,9 @@ // CHECK-NEXT: %2 = addf %cst, %cst : f32 // CHECK-NEXT: affine.for %arg0 = 0 to 10 { // CHECK-NEXT: %3 = affine.load %0[%arg0] : memref<10xf32> - // CHECK-NEXT: affine.for %arg1 = 0 to 10 { - // CHECK-NEXT: %4 = affine.load %0[%arg1] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { + // CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}}] : memref<10xf32> return } @@ -252,6 +282,8 @@ // CHECK: %0 = memref.alloc() : memref<10xf32> // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: affine.for %arg0 = 0 to 10 { + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.if #set(%arg0, %arg0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 // CHECK-NEXT: affine.store %1, %0[%arg0] : memref<10xf32> @@ -386,6 +418,8 @@ // CHECK-NEXT: %1 = memref.alloc() : memref<10xf32> // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: affine.for %arg0 = 0 to 10 { + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.if #set(%arg0, %arg0) { // CHECK-NEXT: %2 = addf %cst, %cst : f32 // CHECK-NEXT: %3 = affine.load %0[%arg0] : memref<10xf32> @@ -420,6 +454,8 @@ // CHECK: %0 = memref.alloc() : memref<10xf32> // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: affine.for %arg0 = 0 to 10 { + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.if #set(%arg0, %arg0) { // CHECK-NEXT: %1 = addf %cst, %cst : f32 // CHECK-NEXT: %2 = affine.load 
%0[%arg0] : memref<10xf32> @@ -530,6 +566,8 @@ // CHECK-NEXT: %cst = constant 8.000000e+00 : f32 // CHECK-NEXT: %c0 = constant 0 : index // CHECK-NEXT: affine.for %arg0 = 0 to 10 { + // CHECK-NEXT: } + // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %cst, %0[%c0] : memref<10xf32> // CHECK-NEXT: %1 = affine.load %0[%arg0] : memref<10xf32>