diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -710,17 +710,28 @@ } /// A memref escapes the function if either: -/// 1. it is a function argument, or -/// 2. it is used by a non-affine op (e.g., std load/store, std -/// call, etc.) -/// FIXME: Support alias creating ops like memref view ops. +/// 1. it (or its alias) is a block argument, or +/// 2. it is created by an op not known to guarantee alias freedom, or +/// 3. it (or its alias) is used by a non-affine op (e.g., call op, memref +/// load/store ops, alias creating ops, unknown ops, etc.); such ops +/// do not dereference the memref in an affine way. static bool isEscapingMemref(Value memref) { - // Check if 'memref' escapes because it's a block argument. - if (memref.isa<BlockArgument>()) + Operation *defOp = memref.getDefiningOp(); + // Check if 'memref' is a block argument. + if (!defOp) return true; - // Check if 'memref' escapes through a non-affine op (e.g., std load/store, - // call op, etc.). This already covers aliases created from this. + // Check if this is defined to be an alias of another memref. + if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp)) + if (isEscapingMemref(viewOp.getViewSource())) + return true; + + // Any op besides allocating ops wouldn't guarantee alias freedom. + if (!hasSingleEffect<mlir::MemoryEffects::Allocate>(defOp, memref)) + return true; + + // Check if 'memref' is used by a non-affine op (including unknown ones) + // (e.g., call ops, alias creating ops, etc.). for (Operation *user : memref.getUsers()) if (!isa<AffineMapAccessInterface>(*user)) return true; @@ -728,7 +739,7 @@ } /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' -/// that escape the function. +/// that escape the function or are accessed by non-affine ops.
void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet<Value> &escapingMemRefs) { auto *node = mdg->getNode(id); diff --git a/mlir/test/Transforms/loop-fusion-3.mlir b/mlir/test/Transforms/loop-fusion-3.mlir --- a/mlir/test/Transforms/loop-fusion-3.mlir +++ b/mlir/test/Transforms/loop-fusion-3.mlir @@ -1076,4 +1076,52 @@ // CHECK: affine.for // CHECK-NOT: affine.for +// CHECK-LABEL: func @alias_escaping_memref +func @alias_escaping_memref(%a : memref<2x5xf32>) { + %cst = arith.constant 0.000000e+00 : f32 + %alias = memref.reinterpret_cast %a to offset: [0], sizes: [10], strides: [1] : memref<2x5xf32> to memref<10xf32> + affine.for %i0 = 0 to 10 { + affine.store %cst, %alias[%i0] : memref<10xf32> + } + + affine.for %i1 = 0 to 10 { + %0 = affine.load %alias[%i1] : memref<10xf32> + } + // Fusion happens, but memref isn't privatized since %alias is an alias of a + // function argument. + // CHECK: memref.reinterpret_cast + // CHECK-NEXT: affine.for + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NOT: affine.for + + return +} + +// CHECK-LABEL: func @unknown_memref_def_op +func @unknown_memref_def_op() { + %cst = arith.constant 0.000000e+00 : f32 + %may_alias = call @bar() : () -> memref<10xf32> + affine.for %i0 = 0 to 10 { + affine.store %cst, %may_alias[%i0] : memref<10xf32> + } + + affine.for %i1 = 0 to 10 { + %0 = affine.load %may_alias[%i1] : memref<10xf32> + } + // Fusion happens, but memref isn't privatized since %may_alias's origin is + // unknown. + // CHECK: call @bar + // CHECK-NEXT: affine.for + // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32> + // CHECK-NEXT: } + // CHECK-NOT: affine.for + + return +} +func private @bar() -> memref<10xf32> + + // Add further tests in mlir/test/Transforms/loop-fusion-4.mlir