diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2385,7 +2385,21 @@
     // Clone op.
     Operation *newOp =
         linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands);
-    rewriter.replaceOp(op, newOp->getResults());
+    // The clone uses `newResultTypes`, which may differ from `op`'s result
+    // types; cast those results back so every replacement keeps its type.
+    SmallVector<Value, 4> replacements;
+    replacements.reserve(newOp->getNumResults());
+    for (auto result : enumerate(newOp->getResults())) {
+      Value newResult = result.value();
+      Value oldResult = op->getResult(result.index());
+      if (newResult.getType() != oldResult.getType()) {
+        replacements.push_back(rewriter.create<tensor::CastOp>(
+            op->getLoc(), oldResult.getType(), newResult));
+      } else {
+        replacements.push_back(newResult);
+      }
+    }
+    rewriter.replaceOp(op, replacements);
     return success();
   }
 
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -3303,7 +3303,11 @@
 
 static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op,
                              SubTensorOp newOp) {
-  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
+  Value replacement = newOp.getResult();
+  if (replacement.getType() != op.getType())
+    replacement =
+        rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(), replacement);
+  rewriter.replaceOp(op, replacement);
 }
 
 /// Pattern to rewrite a subview op with constant arguments.
@@ -3787,11 +3791,12 @@
 }
 
 OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
-  if (getSourceType() == getType() &&
+  // Equal types with dynamic dimensions do not imply equal runtime sizes, so
+  // only fold to `source` when both shapes are fully static.
+  if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
+      getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
-  if (succeeded(tensor::foldTensorCast(*this)))
-    return this->source();
   return OpFoldResult();
 }
 
@@ -3848,9 +3853,9 @@
     : public OpRewritePattern<SubTensorInsertOp> {
   using OpRewritePattern<SubTensorInsertOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorOp,
+  LogicalResult matchAndRewrite(SubTensorInsertOp subTensorInsertOp,
                                 PatternRewriter &rewriter) const override {
-    if (llvm::any_of(subTensorOp.getOperands(), [](Value operand) {
+    if (llvm::any_of(subTensorInsertOp.getOperands(), [](Value operand) {
           return matchPattern(operand, m_ConstantIndex());
         }))
       return failure();
@@ -3861,21 +3866,27 @@
         return llvm::None;
       return castOp.source();
     };
-    Optional<Value> sourceCastSource = getSourceOfCastOp(subTensorOp.source());
-    Optional<Value> destCastSource = getSourceOfCastOp(subTensorOp.dest());
-    if (!sourceCastSource && !destCastSource &&
-        subTensorOp.dest().getType() == subTensorOp.getResult().getType())
+    Optional<Value> sourceCastSource =
+        getSourceOfCastOp(subTensorInsertOp.source());
+    Optional<Value> destCastSource =
+        getSourceOfCastOp(subTensorInsertOp.dest());
+    if (!sourceCastSource && !destCastSource)
       return failure();
 
-    auto newOp = rewriter.create<SubTensorInsertOp>(
-        subTensorOp.getLoc(),
-        (sourceCastSource ? *sourceCastSource : subTensorOp.source()),
-        (destCastSource ? *destCastSource : subTensorOp.dest()),
-        subTensorOp.getMixedOffsets(), subTensorOp.getMixedSizes(),
-        subTensorOp.getMixedStrides());
+    Value replacement = rewriter.create<SubTensorInsertOp>(
+        subTensorInsertOp.getLoc(),
+        (sourceCastSource ? *sourceCastSource : subTensorInsertOp.source()),
+        (destCastSource ? *destCastSource : subTensorInsertOp.dest()),
+        subTensorInsertOp.getMixedOffsets(), subTensorInsertOp.getMixedSizes(),
+        subTensorInsertOp.getMixedStrides());
 
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(subTensorOp,
-                                                subTensorOp.getType(), newOp);
+    // The new op may have a different result type than the original (e.g.
+    // after folding a cast on `dest`); cast back only when they differ.
+    if (replacement.getType() != subTensorInsertOp.getType()) {
+      replacement = rewriter.create<tensor::CastOp>(
+          subTensorInsertOp.getLoc(), subTensorInsertOp.getType(), replacement);
+    }
+    rewriter.replaceOp(subTensorInsertOp, replacement);
     return success();
   }
 };