diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp
--- a/mlir/lib/Transforms/CSE.cpp
+++ b/mlir/lib/Transforms/CSE.cpp
@@ -307,27 +307,26 @@ void CSE::simplifyBlock(ScopedMapTy &knownValues, Block *bb,
                         bool hasSSADominance) {
   for (auto &op : *bb) {
-    // If the operation is simplified, we don't process any held regions.
-    if (succeeded(simplifyOperation(knownValues, &op, hasSSADominance)))
-      continue;
-
     // Most operations don't have regions, so fast path that case.
-    if (op.getNumRegions() == 0)
-      continue;
-
-    // If this operation is isolated above, we can't process nested regions with
-    // the given 'knownValues' map. This would cause the insertion of implicit
-    // captures in explicit capture only regions.
-    if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
-      ScopedMapTy nestedKnownValues;
-      for (auto &region : op.getRegions())
-        simplifyRegion(nestedKnownValues, region);
-      continue;
+    if (op.getNumRegions() != 0) {
+      // If this operation is isolated above, we can't process nested regions
+      // with the given 'knownValues' map. This would cause the insertion of
+      // implicit captures in explicit capture only regions.
+      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+        ScopedMapTy nestedKnownValues;
+        for (auto &region : op.getRegions())
+          simplifyRegion(nestedKnownValues, region);
+      } else {
+        // Otherwise, process nested regions normally.
+        for (auto &region : op.getRegions())
+          simplifyRegion(knownValues, region);
+      }
     }
 
-    // Otherwise, process nested regions normally.
-    for (auto &region : op.getRegions())
-      simplifyRegion(knownValues, region);
+    // Simplify the operation itself only after its nested regions, so that
+    // operations whose regions become identical once their bodies are CSE'd
+    // are recognized as equivalent (see issue #59135).
+    (void)simplifyOperation(knownValues, &op, hasSSADominance);
   }
 
   // Clear the MemoryEffects cache since its usage is by block only.
   memEffectsCache.clear();
diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir
--- a/mlir/test/Transforms/cse.mlir
+++ b/mlir/test/Transforms/cse.mlir
@@ -446,3 +446,26 @@
 // CHECK: %[[OP:.+]] = test.cse_of_single_block_op
 // CHECK-NOT: test.cse_of_single_block_op
 // CHECK: return %[[OP]], %[[OP]]
+
+func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) {
+  %false_2 = arith.constant false
+  %true_5 = arith.constant true
+  %9 = test.cse_of_single_block_op inputs(%arg2) {
+  ^bb0(%out: i1):
+    %true_144 = arith.constant true
+    test.region_yield %true_144 : i1
+  } : tensor<2xi1> -> tensor<2xi1>
+  %15 = test.cse_of_single_block_op inputs(%arg2) {
+  ^bb0(%out: i1):
+    %true_144 = arith.constant true
+    test.region_yield %true_144 : i1
+  } : tensor<2xi1> -> tensor<2xi1>
+  %93 = arith.maxsi %false_2, %true_5 : i1
+  return %9, %15 : tensor<2xi1>, tensor<2xi1>
+}
+// CHECK-LABEL: func @failing_issue_59135
+// CHECK: %[[TRUE:.+]] = arith.constant true
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK: test.region_yield %[[TRUE]]
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]