diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -38,10 +38,10 @@ void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) { if (containsVar(val)) return; - + // Caller is expected to fully compose map/operands if necessary. - assert((isTopLevelValue(val) || isAffineInductionVar(val)) && - "non-terminal symbol / loop IV expected"); + assert((isValidSymbol(val) || isAffineInductionVar(val)) && + "valid symbol / loop IV expected"); // Outer loop IVs could be used in forOp's bounds. if (auto loop = getForInductionVarOwner(val)) { appendDimVar(val); @@ -58,7 +58,7 @@ return; } - // Add top level symbol. + // Add valid symbol. appendSymbolVar(val); // Check if the symbol is a constant. if (std::optional<int64_t> constOp = getConstantIntValue(val)) diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -378,7 +378,7 @@ return false; } -/// A value can be used as a symbol for `region` iff it meets one of the +/// A value can be used as a symbol for `region` if it meets one of the /// following conditions: /// *) It is a constant. /// *) It is the result of an affine apply operation with symbol arguments. 
diff --git a/mlir/test/Dialect/Affine/parallelize.mlir b/mlir/test/Dialect/Affine/parallelize.mlir --- a/mlir/test/Dialect/Affine/parallelize.mlir +++ b/mlir/test/Dialect/Affine/parallelize.mlir @@ -323,3 +323,32 @@ } return } + +// CHECK-LABEL: @no_toplevel_block_parallelize +func.func @no_toplevel_block_parallelize(%arg0: i1) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %alloc_0 = memref.alloc() : memref<2x3x6xi1> + + scf.if %arg0 { + %alloc_1 = memref.alloc() : memref<2x3x6x9xi1> + %dim_0 = memref.dim %alloc_0, %c0 : memref<2x3x6xi1> + %dim_1 = memref.dim %alloc_0, %c1 : memref<2x3x6xi1> + affine.for %arg2 = 0 to %dim_0 { + affine.for %arg3 = 0 to %dim_1 { + affine.for %arg4 = 0 to 6 { + affine.for %arg5 = 0 to 9 { + %load_0 = affine.load %alloc_0[%arg2, %arg3, %arg4] : memref<2x3x6xi1> + affine.store %load_0, %alloc_1[%arg2, %arg3, %arg4, %arg5] : memref<2x3x6x9xi1> + } + } + } + } + } + // CHECK: scf.if %{{.*}} { + // CHECK: affine.parallel (%{{.*}}) = (0) to (symbol(%{{.*}})) { + // CHECK: affine.parallel (%{{.*}}) = (0) to (symbol(%{{.*}})) { + // CHECK: affine.parallel (%{{.*}}) = (0) to (6) { + // CHECK: affine.parallel (%{{.*}}) = (0) to (9) { + return +}