Index: mlir/lib/IR/SubElementInterfaces.cpp
===================================================================
--- mlir/lib/IR/SubElementInterfaces.cpp
+++ mlir/lib/IR/SubElementInterfaces.cpp
@@ -8,12 +8,16 @@
 
 #include "mlir/IR/SubElementInterfaces.h"
 
+#include "llvm/ADT/DenseSet.h"
+
 using namespace mlir;
 
 template <typename InterfaceT>
 static void walkSubElementsImpl(InterfaceT interface,
                                 function_ref<void(Attribute)> walkAttrsFn,
-                                function_ref<void(Type)> walkTypesFn) {
+                                function_ref<void(Type)> walkTypesFn,
+                                llvm::DenseSet<Attribute> &visitedAttrs,
+                                llvm::DenseSet<Type> &visitedTypes) {
   interface.walkImmediateSubElements(
       [&](Attribute attr) {
         // Guard against potentially null inputs. This removes the need for the
@@ -21,9 +25,14 @@
         if (!attr)
           return;
 
+        // Avoid infinite recursion when visiting sub attributes later.
+        if (!visitedAttrs.insert(attr).second)
+          return;
+
         // Walk any sub elements first.
         if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+                              visitedTypes);
 
         // Walk this attribute.
         walkAttrsFn(attr);
@@ -34,9 +43,14 @@
         if (!type)
           return;
 
+        // Avoid infinite recursion when visiting sub types later.
+        if (!visitedTypes.insert(type).second)
+          return;
+
         // Walk any sub elements first.
         if (auto interface = type.dyn_cast<SubElementTypeInterface>())
-          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
+          walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn, visitedAttrs,
+                              visitedTypes);
 
         // Walk this type.
         walkTypesFn(type);
@@ -47,14 +61,20 @@
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) {
   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+  llvm::DenseSet<Attribute> visitedAttrs;
+  llvm::DenseSet<Type> visitedTypes;
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+                      visitedTypes);
 }
 
 void SubElementTypeInterface::walkSubElements(
     function_ref<void(Attribute)> walkAttrsFn,
     function_ref<void(Type)> walkTypesFn) {
   assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
-  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
+  llvm::DenseSet<Attribute> visitedAttrs;
+  llvm::DenseSet<Type> visitedTypes;
+  walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn, visitedAttrs,
+                      visitedTypes);
 }
 
 //===----------------------------------------------------------------------===//