diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1094,21 +1094,20 @@
 }
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
-  // Check to see if this operation was replaced or its parent ignored.
-  return replacements.count(op) || ignoredOps.count(op->getParentOp());
+  // Check to see if this operation or a parent was replaced or ignored.
+  do {
+    if (replacements.count(op))
+      return true;
+    if (ignoredOps.count(op))
+      return true;
+  } while ((op = op->getParentOp()));
+  return false;
 }
 
 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
-  // Walk this operation and collect nested operations that define non-empty
-  // regions. We mark such operations as 'ignored' so that we know we don't have
-  // to convert them, or their nested ops.
-  if (op->getNumRegions() == 0)
-    return;
-  op->walk([&](Operation *op) {
-    if (llvm::any_of(op->getRegions(),
-                     [](Region &region) { return !region.empty(); }))
-      ignoredOps.insert(op);
-  });
+  // Only `op` itself is recorded: isOpIgnored walks the parent chain, so every
+  // operation nested under `op` is implicitly treated as ignored as well.
+  ignoredOps.insert(op);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1184,10 +1183,6 @@
   if (currentConversionPattern)
     converter = currentConversionPattern->getTypeConverter();
   replacements.insert(std::make_pair(op, OpReplacement(converter)));
-
-  // Mark this operation as recursively ignored so that we don't need to
-  // convert any nested operations.
-  markNestedOpsIgnored(op);
 }
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {