diff --git a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
--- a/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
+++ b/mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -332,6 +332,17 @@
 /// conditions are satisfied.
 bool isMemoryEffectFree(Operation *op);
 
+/// Returns the side effects of an operation. If the operation has
+/// RecursiveMemoryEffects, include all side effects of child operations.
+///
+/// std::nullopt indicates that an operation did not have a memory effect
+/// interface and so no result could be obtained. An empty vector indicates
+/// that there were no memory effects found (but every operation implemented
+/// the memory effect interface or has RecursiveMemoryEffects). If the vector
+/// contains multiple effects, these effects may be duplicates.
+std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
+getEffectsRecursively(Operation *rootOp);
+
 /// Returns true if the given operation is speculatable, i.e. has no undefined
 /// behavior or other side effects.
 ///
diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
--- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp
+++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp
@@ -182,6 +182,37 @@
   return true;
 }
 
+// The returned vector may contain duplicate effects.
+std::optional<llvm::SmallVector<MemoryEffects::EffectInstance>>
+mlir::getEffectsRecursively(Operation *rootOp) {
+  SmallVector<MemoryEffects::EffectInstance> effects;
+  SmallVector<Operation *> effectingOps(1, rootOp);
+  while (!effectingOps.empty()) {
+    Operation *op = effectingOps.pop_back_val();
+
+    // If the operation has recursive effects, push all of the nested
+    // operations on to the stack to consider.
+    bool hasRecursiveEffects =
+        op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
+    if (hasRecursiveEffects) {
+      for (Region &region : op->getRegions()) {
+        for (auto &block : region) {
+          for (auto &nestedOp : block)
+            effectingOps.push_back(&nestedOp);
+        }
+      }
+    }
+
+    if (auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op))
+      effectInterface.getEffects(effects);
+    else if (!hasRecursiveEffects)
+      // The operation does not have recursive memory effects or implement
+      // the memory effect op interface. Its effects are unknown.
+      return std::nullopt;
+  }
+  return effects;
+}
+
 bool mlir::isSpeculatable(Operation *op) {
   auto conditionallySpeculatable = dyn_cast<ConditionallySpeculatable>(op);
   if (!conditionallySpeculatable)
diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -199,17 +199,24 @@
     }
   }
   while (nextOp && nextOp != toOp) {
-    auto nextOpMemEffects = dyn_cast<MemoryEffectOpInterface>(nextOp);
-    // TODO: Do we need to handle other effects generically?
-    // If the operation does not implement the MemoryEffectOpInterface we
-    // conservatively assumes it writes.
-    if ((nextOpMemEffects &&
-         nextOpMemEffects.hasEffect<MemoryEffects::Write>()) ||
-        !nextOpMemEffects) {
+    std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
+        getEffectsRecursively(nextOp);
+    if (!effects) {
+      // TODO: Do we need to handle other effects generically?
+      // If the operation does not implement the MemoryEffectOpInterface we
+      // conservatively assume it writes.
       result.first->second =
           std::make_pair(nextOp, MemoryEffects::Write::get());
       return true;
     }
+
+    for (const MemoryEffects::EffectInstance &effect : *effects) {
+      if (isa<MemoryEffects::Write>(effect.getEffect())) {
+        result.first->second =
+            std::make_pair(nextOp, MemoryEffects::Write::get());
+        return true;
+      }
+    }
     nextOp = nextOp->getNextNode();
   }
   result.first->second = std::make_pair(toOp, nullptr);
diff --git a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
--- a/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
+++ b/mlir/test/Dialect/SparseTensor/buffer_rewriting.mlir
@@ -332,8 +332,7 @@
 // CHECK: scf.yield %[[VAL_145]]
 // CHECK: }
 // CHECK: %[[VAL_146:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_147:.*]]]
-// CHECK: %[[VAL_148:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_127]]]
-// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_148]]
+// CHECK: %[[VAL_149:.*]] = arith.cmpi eq, %[[VAL_146]], %[[VAL_137]]
 // CHECK: %[[VAL_150:.*]] = arith.cmpi ult, %[[VAL_136]], %[[VAL_147]]
 // CHECK: %[[VAL_151:.*]]:3 = scf.if %[[VAL_150]]
 // CHECK: %[[VAL_152:.*]] = memref.load %[[VAL_2]]{{\[}}%[[VAL_136]]]
@@ -529,4 +528,4 @@
 func.func @sparse_sort_coo_heap(%arg0: index, %arg1: memref<100xindex>, %arg2: memref<?xf32>, %arg3: memref<10xi32>) -> (memref<100xindex>, memref<?xf32>, memref<10xi32>) {
   sparse_tensor.sort_coo heap_sort %arg0, %arg1 jointly %arg2, %arg3 {nx = 2 : index, ny = 1: index} : memref<100xindex> jointly memref<?xf32>, memref<10xi32>
   return %arg1, %arg2, %arg3 : memref<100xindex>, memref<?xf32>, memref<10xi32>
-}
\ No newline at end of file
+}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
--- a/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_conv_2d_slice_based.mlir
@@ -142,9 +142,7 @@
 // CHECK: scf.yield %[[VAL_132]], %[[VAL_131]] : index, i32
 // CHECK: }
 // CHECK: %[[VAL_133:.*]] = arith.addi %[[VAL_105]], %[[VAL_7]] : index
-// CHECK: %[[VAL_134:.*]] = memref.load %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
-// CHECK: %[[VAL_135:.*]] = arith.addi %[[VAL_134]], %[[VAL_5]] : index
-// CHECK: memref.store %[[VAL_135]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
+// CHECK: memref.store %[[VAL_112]], %[[VAL_18]]{{\[}}%[[VAL_7]]] : memref<11xindex>
 // CHECK: scf.yield %[[VAL_133]], %[[VAL_136:.*]]#1, %[[VAL_2]] : index, i32, i1
 // CHECK: }
 // CHECK: %[[VAL_137:.*]] = scf.if %[[VAL_138:.*]]#2 -> (tensor<6x6xi32, #sparse_tensor.encoding<{{.*}}>>) {
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -459,3 +459,64 @@
 // CHECK: }
 // CHECK-NOT: scf.if
 // CHECK: return %[[if]], %[[if]]
+
+// CHECK-LABEL: @cse_recursive_effects_success
+func.func @cse_recursive_effects_success() -> (i32, i32, i32) {
+  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+
+  // do something with recursive effects, containing no side effects
+  %true = arith.constant true
+  // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+  // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
+  %1 = scf.if %true -> (i32) {
+    %c42 = arith.constant 42 : i32
+    scf.yield %c42 : i32
+    // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
+    // CHECK-NEXT: scf.yield %[[C42]]
+    // CHECK-NEXT: } else {
+  } else {
+    %c24 = arith.constant 24 : i32
+    scf.yield %c24 : i32
+    // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
+    // CHECK-NEXT: scf.yield %[[C24]]
+    // CHECK-NEXT: }
+  }
+
+  // %2 can be removed
+  // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE]], %[[IF]] : i32, i32, i32
+  %2 = "test.op_with_memread"() : () -> (i32)
+  return %0, %2, %1 : i32, i32, i32
+}
+
+// CHECK-LABEL: @cse_recursive_effects_failure
+func.func @cse_recursive_effects_failure() -> (i32, i32, i32) {
+  // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32
+  %0 = "test.op_with_memread"() : () -> (i32)
+
+  // do something with recursive effects, containing a write effect
+  %true = arith.constant true
+  // CHECK-NEXT: %[[TRUE:.+]] = arith.constant true
+  // CHECK-NEXT: %[[IF:.+]] = scf.if %[[TRUE]] -> (i32) {
+  %1 = scf.if %true -> (i32) {
+    "test.op_with_memwrite"() : () -> ()
+    // CHECK-NEXT: "test.op_with_memwrite"() : () -> ()
+    %c42 = arith.constant 42 : i32
+    scf.yield %c42 : i32
+    // CHECK-NEXT: %[[C42:.+]] = arith.constant 42 : i32
+    // CHECK-NEXT: scf.yield %[[C42]]
+    // CHECK-NEXT: } else {
+  } else {
+    %c24 = arith.constant 24 : i32
+    scf.yield %c24 : i32
+    // CHECK-NEXT: %[[C24:.+]] = arith.constant 24 : i32
+    // CHECK-NEXT: scf.yield %[[C24]]
+    // CHECK-NEXT: }
+  }
+
+  // %2 cannot be removed because of the write
+  // CHECK-NEXT: %[[READ_VALUE2:.*]] = "test.op_with_memread"() : () -> i32
+  // CHECK-NEXT: return %[[READ_VALUE]], %[[READ_VALUE2]], %[[IF]] : i32, i32, i32
+  %2 = "test.op_with_memread"() : () -> (i32)
+  return %0, %2, %1 : i32, i32, i32
+}
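
Note for reviewers (not part of the patch): a minimal usage sketch of the new
getEffectsRecursively helper, mirroring the CSE change above. The wrapper name
mayWrite is hypothetical; std::nullopt is treated as "effects unknown, so
conservatively assume a write", and an empty vector means no effects were found.

    // Sketch only: a conservative write check built on getEffectsRecursively.
    #include "mlir/Interfaces/SideEffectInterfaces.h"
    #include "llvm/ADT/STLExtras.h"

    using namespace mlir;

    /// Returns true if `op` (and, via RecursiveMemoryEffects, its nested ops)
    /// may write to memory. Unknown effects are treated as writes.
    static bool mayWrite(Operation *op) {
      std::optional<SmallVector<MemoryEffects::EffectInstance>> effects =
          getEffectsRecursively(op);
      if (!effects)
        return true; // No effect information available; assume the worst.
      // Any Write effect (possibly duplicated in the vector) makes op a writer.
      return llvm::any_of(*effects, [](const MemoryEffects::EffectInstance &it) {
        return isa<MemoryEffects::Write>(it.getEffect());
      });
    }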