diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1920,20 +1920,14 @@ if (!canFoldIntoConsumerOp(castOp)) return failure(); - /// Deduce the type of the result to use for the canonicalized operation. + // Create folded extract. Location loc = sliceOp.getLoc(); - auto sliceOpType = sliceOp.getType(); - RankedTensorType resultType = - ExtractSliceOp::inferCanonicalRankReducedResultType( - sliceOpType.getRank(), sliceOp.getSourceType(), - sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(), - sliceOp.getMixedStrides()); Value newResult = rewriter.create( - loc, resultType, castOp.getSource(), sliceOp.getOffsets(), + loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); - if (newResult.getType() != sliceOpType) - newResult = rewriter.create(loc, sliceOpType, newResult); + if (newResult.getType() != sliceOp.getType()) + newResult = rewriter.create(loc, sliceOp.getType(), newResult); rewriter.replaceOp(sliceOp, newResult); return success(); }