diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1808,12 +1808,49 @@
   }
 };
 
+/// Canonicalizes `vector.extract(vector.shape_cast(%src))` into a single
+/// `vector.shape_cast(%src)` producing the extract's result type.
+///
+/// The rewrite fires only when the extract result is a vector with the same
+/// total number of elements as the shape_cast source, i.e. the extract does
+/// not reduce the element count.
+///
+/// Returns failure() without modifying the IR when the extracted vector is not
+/// defined by a shape_cast, when the extract result is not a VectorType, or
+/// when the element counts differ; otherwise replaces `extractOp` and returns
+/// success().
+LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp,
+                                                  PatternRewriter &rewriter) {
+  // The extracted vector must be the direct result of a shape_cast.
+  auto castOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
+  if (!castOp)
+    return failure();
+
+  VectorType sourceType = castOp.getSourceVectorType();
+  // The replacement shape_cast is built with this type as its result type, so
+  // an extract that yields a non-vector (e.g. scalar) value is not handled.
+  auto targetType = dyn_cast<VectorType>(extractOp.getResult().getType());
+  if (!targetType)
+    return failure();
+
+  // Only rewrite when no elements are dropped between the shape_cast source
+  // and the extract result.
+  if (sourceType.getNumElements() != targetType.getNumElements())
+    return failure();
+
+  // Cast the original source straight to the extract's result type.
+  rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, targetType,
+                                                   castOp.getSource());
+  return success();
+}
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
   results.add(context);
+  results.add(foldExtractFromShapeCastToShapeCast);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -669,6 +669,18 @@
 
 // -----
 
+// CHECK-LABEL: fold_extract_shapecast_to_shapecast
+//  CHECK-SAME: (%[[ARG:.+]]: vector<3x4xf32>)
+//       CHECK: %[[R:.+]] = vector.shape_cast %[[ARG]] : vector<3x4xf32> to vector<12xf32>
+//       CHECK: return %[[R]]
+func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vector<12xf32> {
+  %0 = vector.shape_cast %arg0 : vector<3x4xf32> to vector<1x12xf32>
+  %r = vector.extract %0[0] : vector<1x12xf32>
+  return %r : vector<12xf32>
+}
+
+// -----
+
 // CHECK-LABEL: dont_fold_expand_collapse
 //       CHECK:   %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
 //       CHECK:   %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>