diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -54,6 +54,10 @@
                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
+
+    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+      return failure();
+
     if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
       return failure();
 
@@ -87,6 +91,10 @@
                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
+
+    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+      return failure();
+
     if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
       return failure();
 
@@ -106,6 +114,20 @@
   }
 };
 
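+// Increments the multi-dimensional index `idx` (into the shape of `tp`) by
+// `initialStep` positions along dimension `dimIdx`, carrying over into the
+// outer dimensions in row-major order. E.g. for tp = vector<2x3x[4]xi32> and
+// initialStep = 4, {0, 2, 0} steps to {1, 0, 0}.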
+static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
+                   int dimIdx, int initialStep = 1) {
+  int step = initialStep;
+  for (int d = dimIdx; d >= 0; d--) {
+    idx[d] += step;
+    if (idx[d] >= tp.getDimSize(d)) {
+      idx[d] = 0;
+      step = 1;
+    } else {
+      break;
+    }
+  }
+}
+
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
@@ -121,6 +143,9 @@
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
 
+    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
+      return failure();
+
     // Special case 2D / 1D lowerings with better implementations.
     // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
     int64_t srcRank = sourceVectorType.getRank();
@@ -175,21 +200,159 @@
     rewriter.replaceOp(op, result);
     return success();
   }
+};
 
-private:
-  static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) {
-    assert(0 <= r && r < tp.getRank());
-    if (++idx[r] == tp.getDimSize(r)) {
-      idx[r] = 0;
-      incIdx(idx, tp, r - 1);
-    }
-  }
-};
+/// A shape_cast lowering for scalable vectors with a single trailing scalable
+/// dimension. This is similar to the general shape_cast lowering but makes use
+/// of vector.scalable.insert and vector.scalable.extract to move elements a
+/// subvector at a time.
+///
+/// E.g.:
+/// ```
+/// // Flatten scalable vector
+/// %0 = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
+/// ```
+/// is rewritten to:
+/// ```
+/// // Flatten scalable vector
+/// %c = arith.constant dense<0> : vector<[8]xi32>
+/// %0 = vector.extract %arg0[0, 0] : vector<2x1x[4]xi32>
+/// %1 = vector.scalable.insert %0, %c[0] : vector<[4]xi32> into vector<[8]xi32>
+/// %2 = vector.extract %arg0[1, 0] : vector<2x1x[4]xi32>
+/// %3 = vector.scalable.insert %2, %1[4] : vector<[4]xi32> into vector<[8]xi32>
+/// ```
+/// or:
+/// ```
+/// // Un-flatten scalable vector
+/// %0 = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
+/// ```
+/// is rewritten to:
+/// ```
+/// // Un-flatten scalable vector
+/// %c = arith.constant dense<0> : vector<2x1x[4]xi32>
+/// %0 = vector.scalable.extract %arg0[0] : vector<[4]xi32> from vector<[8]xi32>
+/// %1 = vector.insert %0, %c [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+/// %2 = vector.scalable.extract %arg0[4] : vector<[4]xi32> from vector<[8]xi32>
+/// %3 = vector.insert %2, %1 [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+/// ```
+class ScalableShapeCastOpRewritePattern
+    : public OpRewritePattern<vector::ShapeCastOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
+                                PatternRewriter &rewriter) const override {
+
+    Location loc = op.getLoc();
+    auto sourceVectorType = op.getSourceVectorType();
+    auto resultVectorType = op.getResultVectorType();
+    auto srcRank = sourceVectorType.getRank();
+    auto resRank = resultVectorType.getRank();
+
+    // This pattern can only lower shape_casts where both the source and result
+    // types have a single trailing scalable dimension. This is because there
+    // is no legal representation of other scalable types in LLVM (and likely
+    // won't be for some time). There are also (currently) no operations that
+    // can index or extract from >= 2D scalable vectors or scalable vectors of
+    // fixed vectors.
+    if (!isTrailingDimScalable(sourceVectorType) ||
+        !isTrailingDimScalable(resultVectorType)) {
+      return failure();
+    }
+
+    // The sizes of the trailing dimension of the source and result vectors,
+    // the size of the subvector to move, and the number of elements in the
+    // vectors. These are "min" sizes, i.e. the sizes when vscale == 1.
+    auto minSourceTrailingSize = sourceVectorType.getShape().back();
+    auto minResultTrailingSize = resultVectorType.getShape().back();
+    auto minExtractionSize =
+        std::min(minSourceTrailingSize, minResultTrailingSize);
+    int64_t minNumElts = 1;
+    for (auto size : sourceVectorType.getShape())
+      minNumElts *= size;
+
+    // The subvector type to move from the source to the result. Note that this
+    // is a scalable vector. This rewrite will generate code in terms of the
+    // "min" size (the vscale == 1 case), which scales to any vscale.
+    auto extractionVectorType = VectorType::get(
+        {minExtractionSize}, sourceVectorType.getElementType(), {true});
+
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
+
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank);
+
+    Value currentResultScalableVector;
+    Value currentSourceScalableVector;
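+    // Each iteration of the loop below moves one minExtractionSize-element
+    // subvector from the source to the result, with srcIdx/resIdx tracking
+    // the positions in terms of the "min" (vscale == 1) shapes. E.g. for
+    // vector<4x[2]xf32> -> vector<2x[4]xf32>, minNumElts == 8 and
+    // minExtractionSize == 2, so four vector<[2]xf32> subvectors are moved.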
+    for (int64_t i = 0; i < minNumElts; i += minExtractionSize) {
+      // Extract a scalable subvector from the source vector.
+      if (!currentSourceScalableVector) {
+        if (srcRank != 1) {
+          currentSourceScalableVector = rewriter.create<vector::ExtractOp>(
+              loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back());
+        } else {
+          currentSourceScalableVector = op.getSource();
+        }
+      }
+      Value sourceSubVector = currentSourceScalableVector;
+      if (minExtractionSize < minSourceTrailingSize) {
+        sourceSubVector = rewriter.create<vector::ScalableExtractOp>(
+            loc, extractionVectorType, sourceSubVector, srcIdx.back());
+      }
+
+      // Insert the scalable subvector into the result vector.
+      if (!currentResultScalableVector) {
+        if (minExtractionSize == minResultTrailingSize) {
+          currentResultScalableVector = sourceSubVector;
+        } else if (resRank != 1) {
+          currentResultScalableVector = rewriter.create<vector::ExtractOp>(
+              loc, result, llvm::ArrayRef(resIdx).drop_back());
+        } else {
+          currentResultScalableVector = result;
+        }
+      }
+      if (minExtractionSize < minResultTrailingSize) {
+        currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>(
+            loc, sourceSubVector, currentResultScalableVector, resIdx.back());
+      }
+
+      if (resIdx.back() + minExtractionSize >= minResultTrailingSize &&
+          currentResultScalableVector != result) {
+        // Finished row of result. Insert complete scalable vector into result
+        // (n-D) vector.
+        result = rewriter.create<vector::InsertOp>(
+            loc, currentResultScalableVector, result,
+            llvm::ArrayRef(resIdx).drop_back());
+        currentResultScalableVector = {};
+      }
+
+      if (srcIdx.back() + minExtractionSize >= minSourceTrailingSize) {
+        // Finished row of source.
+        currentSourceScalableVector = {};
+      }
+
+      // Increment the insert/extract indices, stepping by minExtractionSize
+      // for the trailing dimensions.
+      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
+      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
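+  /// Returns true if the only scalable dimension of `type` is its trailing
+  /// dimension (e.g. vector<2x1x[4]xi32>), which is the only kind of scalable
+  /// vector this lowering supports.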
+  static bool isTrailingDimScalable(VectorType type) {
+    return type.getRank() >= 1 && type.getScalableDims().back() &&
+           !llvm::is_contained(type.getScalableDims().drop_back(), true);
+  }
+};
 
 } // namespace
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>(
-      patterns.getContext(), benefit);
+               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
+                                                  benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir
@@ -0,0 +1,214 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s
+
+/// This tests that shape casts of scalable vectors (with one trailing scalable dim)
+/// can be correctly lowered to vector.scalable.insert/extract.
+
+// CHECK-LABEL: i32_3d_to_1d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32>
+func.func @i32_3d_to_1d_last_dim_scalable(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][4] : vector<[4]xi32> into vector<[8]xi32>
+  %flat = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32>
+  // CHECK-NEXT: return %[[res1]] : vector<[8]xi32>
+  return %flat : vector<[8]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: i32_1d_to_3d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32>
+func.func @i32_1d_to_3d_last_dim_scalable(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32>
+  %unflat = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32>
+  // CHECK-NEXT: return %[[res1]] : vector<2x1x[4]xi32>
+  return %unflat : vector<2x1x[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: i8_2d_to_1d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8>
+func.func @i8_2d_to_1d_last_dim_scalable(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][8] : vector<[8]xi8> into vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][2] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res2:.*]] = vector.scalable.insert %[[subvec2]], %[[res1]][16] : vector<[8]xi8> into vector<[32]xi8>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][3] : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[res3:.*]] = vector.scalable.insert %[[subvec3]], %[[res2]][24] : vector<[8]xi8> into vector<[32]xi8>
+  %flat = vector.shape_cast %arg0 : vector<4x[8]xi8> to vector<[32]xi8>
+  // CHECK-NEXT: return %[[res3]] : vector<[32]xi8>
+  return %flat : vector<[32]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: i8_1d_to_2d_last_dim_scalable
+// CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8>
+func.func @i8_1d_to_2d_last_dim_scalable(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[8]xi8> into vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[arg0]][16] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[8]xi8> into vector<4x[8]xi8>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[arg0]][24] : vector<[8]xi8> from vector<[32]xi8>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[8]xi8> into vector<4x[8]xi8>
+  %unflat = vector.shape_cast %arg0 : vector<[32]xi8> to vector<4x[8]xi8>
+  // CHECK-NEXT: return %[[res3]] : vector<4x[8]xi8>
+  return %unflat : vector<4x[8]xi8>
+}
+
+// -----
+
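+/// The next two casts only rearrange the leading fixed-size dims, so whole
+/// scalable rows are moved with plain vector.extract/vector.insert pairs and
+/// no vector.scalable.insert/extract is needed.
+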
+// CHECK-LABEL: f32_permute_leading_non_scalable_dims
+// CHECK-SAME: %[[arg0:.*]]: vector<2x3x[4]xf32>
+func.func @f32_permute_leading_non_scalable_dims(%arg0: vector<2x3x[4]xf32>) -> vector<3x2x[4]xf32> {
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [0, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][0, 2] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [1, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [1, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][1, 1] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [2, 0] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  // CHECK-NEXT: %[[subvec5:.*]] = vector.extract %[[arg0]][1, 2] : vector<2x3x[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [2, 1] : vector<[4]xf32> into vector<3x2x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<2x3x[4]xf32> to vector<3x2x[4]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<3x2x[4]xf32>
+  return %res : vector<3x2x[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: f64_flatten_leading_non_scalable_dims
+// CHECK-SAME: %[[arg0:.*]]: vector<2x2x[2]xf64>
+func.func @f64_flatten_leading_non_scalable_dims(%arg0: vector<2x2x[2]xf64>) -> vector<4x[2]xf64>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf64> into vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][0, 1] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf64> into vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf64> into vector<4x[2]xf64>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][1, 1] : vector<2x2x[2]xf64>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf64> into vector<4x[2]xf64>
+  %res = vector.shape_cast %arg0: vector<2x2x[2]xf64> to vector<4x[2]xf64>
+  // CHECK-NEXT: return %[[res3]] : vector<4x[2]xf64>
+  return %res : vector<4x[2]xf64>
+}
+
+// -----
+
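+/// When the trailing scalable dim of the result is smaller than that of the
+/// source ([4] -> [2] here), each source row is split into subvectors with
+/// vector.scalable.extract before being inserted into the result rows.
+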
+// CHECK-LABEL: f32_reduce_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<3x[4]xf32>
+func.func @f32_reduce_trailing_scalable_dim(%arg0: vector<3x[4]xf32>) -> vector<6x[2]xf32>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<6x[2]xf32>
+  // CHECK-NEXT: %[[srcvec0:.*]] = vector.extract %[[arg0]][0] : vector<3x[4]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[srcvec0]][0] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[srcvec0]][2] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[srcvec1:.*]] = vector.extract %[[arg0]][1] : vector<3x[4]xf32>
+  // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[srcvec1]][0] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[srcvec1]][2] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[srcvec2:.*]] = vector.extract %[[arg0]][2] : vector<3x[4]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.scalable.extract %[[srcvec2]][0] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res4:.*]] = vector.insert %[[subvec4]], %[[res3]] [4] : vector<[2]xf32> into vector<6x[2]xf32>
+  // CHECK-NEXT: %[[subvec5:.*]] = vector.scalable.extract %[[srcvec2]][2] : vector<[2]xf32> from vector<[4]xf32>
+  // CHECK-NEXT: %[[res5:.*]] = vector.insert %[[subvec5]], %[[res4]] [5] : vector<[2]xf32> into vector<6x[2]xf32>
+  %res = vector.shape_cast %arg0: vector<3x[4]xf32> to vector<6x[2]xf32>
+  // CHECK-NEXT: return %[[res5]] : vector<6x[2]xf32>
+  return %res: vector<6x[2]xf32>
+}
+
+// -----
+
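+/// When the trailing scalable dim of the result is larger than that of the
+/// source ([2] -> [4] here), source rows are combined into each result row
+/// with vector.scalable.insert before the row is inserted into the result.
+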
+// CHECK-LABEL: f32_increase_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<4x[2]xf32>
+func.func @f32_increase_trailing_scalable_dim(%arg0: vector<4x[2]xf32>) -> vector<2x[4]xf32>
+{
+  // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0.000000e+00> : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec0:.*]] = vector.extract %[[cst]][0] : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[resvec1:.*]] = vector.scalable.insert %[[subvec0]], %[[resvec0]][0] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec2:.*]] = vector.scalable.insert %[[subvec1]], %[[resvec1]][2] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[resvec2]], %[[cst]] [0] : vector<[4]xf32> into vector<2x[4]xf32>
+  // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][2] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec3:.*]] = vector.extract %[[cst]][1] : vector<2x[4]xf32>
+  // CHECK-NEXT: %[[resvec4:.*]] = vector.scalable.insert %[[subvec3]], %[[resvec3]][0] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[subvec4:.*]] = vector.extract %[[arg0]][3] : vector<4x[2]xf32>
+  // CHECK-NEXT: %[[resvec5:.*]] = vector.scalable.insert %[[subvec4]], %[[resvec4]][2] : vector<[2]xf32> into vector<[4]xf32>
+  // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[resvec5]], %[[res0]] [1] : vector<[4]xf32> into vector<2x[4]xf32>
+  %res = vector.shape_cast %arg0: vector<4x[2]xf32> to vector<2x[4]xf32>
+  // CHECK-NEXT: return %[[res1]] : vector<2x[4]xf32>
+  return %res: vector<2x[4]xf32>
+}
+
+// -----
+
+/// The following shape_casts are not supported as the types cannot be
+/// represented in LLVM (and likely won't be supported soon), and currently
+/// there are no ops that could perform the required extracts/inserts.
+
+// -----
+
+// CHECK-LABEL: cannot_cast_to_non_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<[4]xf32>
+func.func @cannot_cast_to_non_trailing_scalable_dim(%arg0: vector<[4]xf32>) -> vector<[2]x2xf32> {
+  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[4]xf32> to vector<[2]x2xf32>
+  %res = vector.shape_cast %arg0 : vector<[4]xf32> to vector<[2]x2xf32>
+  // CHECK-NEXT: return %[[res]] : vector<[2]x2xf32>
+  return %res: vector<[2]x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: cannot_shape_cast_from_non_trailing_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<[2]x2xf32>
+func.func @cannot_shape_cast_from_non_trailing_scalable_dim(%arg0: vector<[2]x2xf32>) -> vector<[4]xf32> {
+  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[2]x2xf32> to vector<[4]xf32>
+  %res = vector.shape_cast %arg0 : vector<[2]x2xf32> to vector<[4]xf32>
+  // CHECK-NEXT: return %[[res]] : vector<[4]xf32>
+  return %res: vector<[4]xf32>
+}
+
+// -----
+
+// CHECK-LABEL: cannot_shape_cast_more_than_one_scalable_dim
+// CHECK-SAME: %[[arg0:.*]]: vector<[4]x[4]xf32>
+func.func @cannot_shape_cast_more_than_one_scalable_dim(%arg0: vector<[4]x[4]xf32>) -> vector<2x[2]x[4]xf32> {
+  // CHECK-NEXT: %[[res:.*]] = vector.shape_cast %[[arg0]] : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
+  %res = vector.shape_cast %arg0 : vector<[4]x[4]xf32> to vector<2x[2]x[4]xf32>
+  // CHECK-NEXT: return %[[res]] : vector<2x[2]x[4]xf32>
+  return %res: vector<2x[2]x[4]xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !transform.any_op):
+  %f = transform.structured.match ops{["func.func"]} in %module_op
+    : (!transform.any_op) -> !transform.any_op
+
+  transform.apply_patterns to %f {
+    transform.apply_patterns.vector.lower_shape_cast
+  } : !transform.any_op
+}