diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -84,6 +84,11 @@
     // TODO: Support DMA ops.
     return false;
   } else if (!isa<ConstantOp>(op)) {
+    // Register op in the set of ops defined inside the loop. This set is used
+    // to prevent hoisting ops that depend on other ops defined inside the loop
+    // which are themselves not being hoisted.
+    definedOps.insert(&op);
+
     if (isMemRefDereferencingOp(op)) {
       Value memref = isa<AffineLoadOp>(op)
                          ? cast<AffineLoadOp>(op).getMemRef()
                          : cast<AffineStoreOp>(op).getMemRef();
@@ -111,9 +116,6 @@
       }
     }

-    // Insert this op in the defined ops list.
-    definedOps.insert(&op);
-
     if (op.getNumOperands() == 0 && !isa<AffineYieldOp>(op)) {
       LLVM_DEBUG(llvm::dbgs() << "\nNon-constant op with 0 operands\n");
       return false;
diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
--- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
+++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
@@ -22,6 +22,8 @@
   return
 }

+// -----
+
 // The store-load forwarding can see through affine apply's since it relies on
 // dependence information.
 // CHECK-LABEL: func @store_affine_apply
@@ -36,12 +38,14 @@
 // CHECK: %cst = constant 7.000000e+00 : f32
 // CHECK-NEXT: %0 = alloc() : memref<10xf32>
 // CHECK-NEXT: affine.for %arg0 = 0 to 10 {
-// CHECK-NEXT: %1 = affine.apply #map3(%arg0)
+// CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%arg0)
 // CHECK-NEXT: affine.store %cst, %0[%1] : memref<10xf32>
 // CHECK-NEXT: }
 // CHECK-NEXT: return %0 : memref<10xf32>
 }

+// -----
+
 func @nested_loops_code_invariant_to_both() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
@@ -61,6 +65,8 @@
   return
 }

+// -----
+
 func @single_loop_nothing_invariant() {
   %m1 = alloc() : memref<10xf32>
   %m2 = alloc() : memref<10xf32>
@@ -82,6 +88,8 @@
   return
 }

+// -----
+
 func @invariant_code_inside_affine_if() {
   %m = alloc() : memref<10xf32>
   %cf8 = constant 8.0 : f32
@@ -98,7 +106,7 @@
   // CHECK: %0 = alloc() : memref<10xf32>
   // CHECK-NEXT: %cst = constant 8.000000e+00 : f32
   // CHECK-NEXT: affine.for %arg0 = 0 to 10 {
-  // CHECK-NEXT: %1 = affine.apply #map3(%arg0)
+  // CHECK-NEXT: %1 = affine.apply #map{{[0-9]+}}(%arg0)
   // CHECK-NEXT: affine.if #set0(%arg0, %1) {
   // CHECK-NEXT: %2 = addf %cst, %cst : f32
   // CHECK-NEXT: affine.store %2, %0[%arg0] : memref<10xf32>
@@ -108,6 +116,7 @@
   return
 }

+// -----
 func @dependent_stores() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
@@ -137,6 +146,8 @@
   return
 }

+// -----
+
 func @independent_stores() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
@@ -165,6 +176,8 @@
   return
 }

+// -----
+
 func @load_dependent_store() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
@@ -192,6 +205,8 @@
   return
 }

+// -----
+
 func @load_after_load() {
   %m = alloc() : memref<10xf32>
   %cf7 = constant 7.0 : f32
@@ -219,6 +234,8 @@
   return
 }

+// -----
+
 func @invariant_affine_if() {
   %m = alloc() : memref<10xf32>
   %cf8 = constant 8.0 : f32
@@ -244,6 +261,8 @@
   return
 }

+// -----
+
 func @invariant_affine_if2() {
   %m = alloc() : memref<10xf32>
   %cf8 = constant 8.0 : f32
@@ -271,6 +290,8 @@
   return
 }

+// -----
+
 func @invariant_affine_nested_if() {
   %m = alloc() : memref<10xf32>
   %cf8 = constant 8.0 : f32
@@ -303,6 +324,8 @@
   return
 }

+// -----
+
 func @invariant_affine_nested_if_else() {
   %m = alloc() : memref<10xf32>
   %cf8 = constant 8.0 : f32
@@ -339,6 +362,8 @@
   return
 }

+// -----
+
 func @invariant_affine_nested_if_else2() {
   %m = alloc() : memref<10xf32>
   %m2 = alloc() : memref<10xf32>
@@ -375,6 +400,7 @@
   return
 }

+// -----
 func @invariant_affine_nested_if2() {
   %m = alloc() : memref<10xf32>
@@ -406,6 +432,8 @@
   return
 }

+// -----
+
 func @invariant_affine_for_inside_affine_if() {
   %m = alloc() : memref<10xf32>
   %cf8 = constant 8.0 : f32
@@ -438,6 +466,7 @@
   return
 }

+// -----
 func @invariant_constant_and_load() {
   %m = alloc() : memref<100xf32>
@@ -459,6 +488,7 @@
   return
 }

+// -----
 func @nested_load_store_same_memref() {
   %m = alloc() : memref<10xf32>
@@ -483,6 +513,7 @@
   return
 }

+// -----
 func @nested_load_store_same_memref2() {
   %m = alloc() : memref<10xf32>
@@ -505,3 +536,33 @@
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @do_not_hoist_dependent_side_effect_free_op
+func @do_not_hoist_dependent_side_effect_free_op(%arg0: memref<10x512xf32>) {
+  %0 = alloca() : memref<1xf32>
+  %cst = constant 8.0 : f32
+  affine.for %i = 0 to 512 {
+    affine.for %j = 0 to 10 {
+      %5 = affine.load %arg0[%i, %j] : memref<10x512xf32>
+      %6 = affine.load %0[0] : memref<1xf32>
+      %add = addf %5, %6 : f32
+      affine.store %add, %0[0] : memref<1xf32>
+    }
+    %3 = affine.load %0[0] : memref<1xf32>
+    %4 = mulf %3, %cst : f32 // It shouldn't be hoisted.
+  }
+  return
+}
+
+// CHECK: affine.for
+// CHECK-NEXT: affine.for
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: addf
+// CHECK-NEXT: affine.store
+// CHECK-NEXT: }
+// CHECK-NEXT: affine.load
+// CHECK-NEXT: mulf
+// CHECK-NEXT: }