Affine LICM: recurse into affine.parallel bodies and reject unknown
region-holding ops.

The invariance check now treats an affine.parallel like an affine.for: the op
is only considered invariant if every op in its body is invariant. Any other
op that holds regions is conservatively treated as not hoistable. The test
covers three nestings: affine.parallel as the outermost op, scf.parallel as
the outermost op, and an scf.for (a region-holding op the pass does not know
about) wrapping the store.

diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -84,8 +84,17 @@
     if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
                                           opsWithUsers, opsToHoist))
       return false;
+  } else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
+    if (!areAllOpsInTheBlockListInvariant(parOp.getLoopBody(), indVar, iterArgs,
+                                          opsWithUsers, opsToHoist))
+      return false;
   } else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
     // TODO: Support DMA ops.
+    // FIXME: This should be fixed to not special-case these affine DMA ops but
+    // instead rely on side effects.
+    return false;
+  } else if (op.getNumRegions() > 0) {
+    // We can't handle region-holding ops we don't know about.
     return false;
   } else if (!matchPattern(&op, m_Constant())) {
     // Register op in the set of ops that have users.
diff --git a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
--- a/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
+++ b/mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir
@@ -752,3 +752,59 @@
 // CHECK-NEXT: affine.for
 // CHECK-NEXT: arith.addi
 // CHECK-NEXT: affine.yield
+
+#map = affine_map<(d0) -> (64, d0 * -64 + 1020)>
+// CHECK-LABEL: func.func @affine_parallel
+func.func @affine_parallel(%memref_8: memref<4090x2040xf32>, %x: index) {
+  %cst = arith.constant 0.000000e+00 : f32
+  affine.parallel (%arg3) = (0) to (32) {
+    affine.for %arg4 = 0 to 16 {
+      affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %arg3 * -64 + 2040)) {
+        affine.for %arg7 = 0 to min #map(%arg4) {
+          affine.store %cst, %memref_8[%arg5 + 3968, %arg6 + %arg3 * 64] : memref<4090x2040xf32>
+        }
+      }
+    }
+  }
+  // CHECK: affine.parallel
+  // CHECK-NEXT: affine.for
+  // CHECK-NEXT: affine.parallel
+  // CHECK-NEXT: affine.store
+  // CHECK-NEXT: affine.for
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  scf.parallel (%arg3) = (%c0) to (%c32) step (%c1) {
+    affine.for %arg4 = 0 to 16 {
+      affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %x * -64 + 2040)) {
+        affine.for %arg7 = 0 to min #map(%arg4) {
+          affine.store %cst, %memref_8[%arg5 + 3968, %arg6] : memref<4090x2040xf32>
+        }
+      }
+    }
+  }
+  // CHECK: scf.parallel
+  // CHECK-NEXT: affine.for
+  // CHECK-NEXT: affine.parallel
+  // CHECK-NEXT: affine.store
+  // CHECK-NEXT: affine.for
+
+  affine.for %arg3 = 0 to 32 {
+    affine.for %arg4 = 0 to 16 {
+      affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %arg3 * -64 + 2040)) {
+        // Unknown region-holding op for this pass.
+        scf.for %arg7 = %c0 to %x step %c1 {
+          affine.store %cst, %memref_8[%arg5 + 3968, %arg6 + %arg3 * 64] : memref<4090x2040xf32>
+        }
+      }
+    }
+  }
+  // CHECK: affine.for
+  // CHECK-NEXT: affine.for
+  // CHECK-NEXT: affine.parallel
+  // CHECK-NEXT: scf.for
+  // CHECK-NEXT: affine.store
+
+  return
+}