diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -894,8 +894,7 @@ /// %2 = call @do(%iter_t0) : (tensor) -> tensor /// scf.yield %2 : tensor /// } -/// %2 = tensor.cast %1 : tensor to tensor<32x1024xf32> -/// use_of(%2) +/// use_of(%1) /// ``` /// /// folds into: @@ -908,7 +907,8 @@ /// %4 = tensor.cast %3 : tensor to tensor<32x1024xf32> /// scf.yield %4 : tensor<32x1024xf32> /// } -/// use_of(%0) +/// %1 = tensor.cast %0 : tensor<32x1024xf32> to tensor +/// use_of(%1) /// ``` struct ForOpTensorCastFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -920,17 +920,13 @@ auto incomingCast = iterOpOperand.get().getDefiningOp(); if (!incomingCast) continue; + // If the dest type of the cast does not preserve static information in + // the source type. + if (!tensor::preservesStaticInformation(incomingCast.getDest().getType(), + incomingCast.getSource().getType())) + continue; if (!std::get<1>(it).hasOneUse()) continue; - auto outgoingCastOp = - dyn_cast(*std::get<1>(it).user_begin()); - if (!outgoingCastOp) - continue; - - // Must be a tensor.cast op pair with matching types. - if (outgoingCastOp.getResult().getType() != - incomingCast.getSource().getType()) - continue; // Create a new ForOp with that iter operand replaced. auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,