diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1829,7 +1829,7 @@ })) return failure(); - auto castOp = sliceOp.getSource().getDefiningOp(); + auto castOp = sliceOp.getSource().getDefiningOp(); if (!castOp) return failure(); @@ -1837,17 +1837,20 @@ return failure(); /// Deduce the type of the result to use for the canonicalized operation. + Location loc = sliceOp.getLoc(); + auto sliceOpType = sliceOp.getType(); RankedTensorType resultType = ExtractSliceOp::inferCanonicalRankReducedResultType( - sliceOp.getType().getRank(), sliceOp.getSourceType(), + sliceOpType.getRank(), sliceOp.getSourceType(), sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), sliceOp.getMixedStrides()); - Value newSlice = rewriter.create( - sliceOp.getLoc(), resultType, castOp.getSource(), sliceOp.getOffsets(), + Value newResult = rewriter.create( + loc, resultType, castOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); - rewriter.replaceOpWithNewOp(sliceOp, sliceOp.getType(), - newSlice); + if (newResult.getType() != sliceOpType) + newResult = rewriter.create(loc, sliceOpType, newResult); + rewriter.replaceOp(sliceOp, newResult); return success(); } };