diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -592,7 +592,7 @@
 
 def AffineParallelOp : Affine_Op<"parallel",
     [ImplicitAffineTerminator, RecursiveSideEffects,
-     DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+     DeclareOpInterfaceMethods<LoopLikeOpInterface>, MemRefsNormalizable]> {
   let summary = "multi-index parallel band operation";
   let description = [{
     The "affine.parallel" operation represents a hyper-rectangular affine
@@ -842,7 +842,8 @@
   let hasFolder = 1;
 }
 
-def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike]> {
+def AffineYieldOp : Affine_Op<"yield", [NoSideEffect, Terminator, ReturnLike,
+                                        MemRefsNormalizable]> {
   let summary = "Yield values to parent operation";
   let description = [{
     "affine.yield" yields zero or more SSA values from an affine op region and
diff --git a/mlir/lib/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Transforms/NormalizeMemRefs.cpp
--- a/mlir/lib/Transforms/NormalizeMemRefs.cpp
+++ b/mlir/lib/Transforms/NormalizeMemRefs.cpp
@@ -512,6 +512,12 @@
   // affine map, `oldOp` is returned without modification.
   if (resultTypeNormalized) {
     OpBuilder bb(oldOp);
+    // Move (not clone) every region of `oldOp`, e.g. the body of an
+    // affine.parallel, into the new op; `oldOp` is left with empty regions.
+    for (auto &oldRegion : oldOp->getRegions()) {
+      Region *newRegion = result.addRegion();
+      newRegion->takeBody(oldRegion);
+    }
     return bb.createOperation(result);
   } else
     return oldOp;
diff --git a/mlir/test/Transforms/normalize-memrefs.mlir b/mlir/test/Transforms/normalize-memrefs.mlir
--- a/mlir/test/Transforms/normalize-memrefs.mlir
+++ b/mlir/test/Transforms/normalize-memrefs.mlir
@@ -319,3 +319,16 @@
 }
 // CHECK: %[[res:[0-9]+]] = call @external_func_B(%[[A]], %[[B]]) : (memref<4x4xf64>, f64) -> memref<2x4xf64>
 // CHECK: return %{{.*}} : memref<2x4xf64>
+
+// CHECK-LABEL: func @affine_parallel_norm
+func @affine_parallel_norm() -> memref<8xf32, #tile> {
+  %c = constant 23.0 : f32
+  %a = alloc() : memref<8xf32, #tile>
+  // CHECK: affine.parallel (%{{.*}}) = (0) to (8) reduce ("assign") -> (memref<2x4xf32>)
+  %1 = affine.parallel (%i) = (0) to (8) reduce ("assign") -> memref<8xf32, #tile> {
+    affine.store %c, %a[%i] : memref<8xf32, #tile>
+    // CHECK: affine.yield %{{.*}} : memref<2x4xf32>
+    affine.yield %a : memref<8xf32, #tile>
+  }
+  return %1 : memref<8xf32, #tile>
+}