diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2049,10 +2049,10 @@ // Must have some masked dimension to be a candidate for splitting. if (!xferOp.hasMaskedDim()) return failure(); - // Don't split transfer operations under IfOp, this avoids applying the - // pattern recursively. - // TODO: improve the condition to make it more applicable. - if (xferOp.getParentOfType()) + // Don't split transfer operations directly under IfOp, this avoids applying + // the pattern recursively. + // TODO: improve the filtering condition to make it more applicable. + if (isa(xferOp.getOperation()->getParentOp())) return failure(); return success(); }