diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -54,6 +54,10 @@ PatternRewriter &rewriter) const override { auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); + + if (sourceVectorType.isScalable() || resultVectorType.isScalable()) + return failure(); + if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) return failure(); @@ -87,6 +91,10 @@ PatternRewriter &rewriter) const override { auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); + + if (sourceVectorType.isScalable() || resultVectorType.isScalable()) + return failure(); + if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) return failure(); @@ -106,6 +114,20 @@ } }; +static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp, + int dimIdx, int initialStep = 1) { + int step = initialStep; + for (int d = dimIdx; d >= 0; d--) { + idx[d] += step; + if (idx[d] >= tp.getDimSize(d)) { + idx[d] = 0; + step = 1; + } else { + break; + } + } +} + // We typically should not lower general shape cast operations into data // movement instructions, since the assumption is that these casts are // optimized away during progressive lowering. For completeness, however, @@ -121,6 +143,9 @@ auto sourceVectorType = op.getSourceVectorType(); auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.isScalable() || resultVectorType.isScalable()) + return failure(); + // Special case 2D / 1D lowerings with better implementations. // TODO: make is ND / 1D to allow generic ND -> 1D -> MD. 
int64_t srcRank = sourceVectorType.getRank(); @@ -175,21 +200,111 @@ rewriter.replaceOp(op, result); return success(); } +}; -private: - static void incIdx(SmallVector<int64_t> &idx, VectorType tp, int64_t r) { - assert(0 <= r && r < tp.getRank()); - if (++idx[r] == tp.getDimSize(r)) { - idx[r] = 0; - incIdx(idx, tp, r - 1); +// A shape_cast lowering for scalable vectors with a single trailing scalable +// dimension. This is similar to the general shape_cast lowering but makes use +// of vector.scalable.insert and vector.scalable.extract to move elements a +// subvector at a time. +class ScalableShapeCastOpRewritePattern + : public OpRewritePattern<vector::ShapeCastOp> { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + auto srcRank = sourceVectorType.getRank(); + auto resRank = resultVectorType.getRank(); + + // vector.scalable.insert/extract only accept 1D vectors, so only a trailing + // scalable dim is supported. 
+ if (!isVectorTypeWithTrailingScalableDim(sourceVectorType) || + !isVectorTypeWithTrailingScalableDim(resultVectorType)) { + return failure(); } + + auto sourceMinScalableSize = sourceVectorType.getShape().back(); + auto resultMinScalableSize = resultVectorType.getShape().back(); + + auto extractionSize = + std::min(sourceMinScalableSize, resultMinScalableSize); + + int64_t minNumElts = 1; + for (auto size : sourceVectorType.getShape()) + minNumElts *= size; + + SmallVector<int64_t> srcIdx(srcRank); + SmallVector<int64_t> resIdx(resRank); + + auto extractionVectorType = VectorType::get( + {extractionSize}, sourceVectorType.getElementType(), {true}); + + Value result = rewriter.create<arith::ConstantOp>( + loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); + + Value currentResultScalableVector; + for (int64_t i = 0; i < minNumElts; i += extractionSize) { + + Value extractedSubVector = op.getSource(); + if (srcRank != 1) { + extractedSubVector = rewriter.create<vector::ExtractOp>( + loc, op.getSource(), llvm::ArrayRef(srcIdx).drop_back()); + } + if (extractionSize < sourceMinScalableSize) { + extractedSubVector = rewriter.create<vector::ScalableExtractOp>( + loc, extractionVectorType, extractedSubVector, srcIdx.back()); + } + + if (!currentResultScalableVector) { + if (extractionSize == resultMinScalableSize) { + currentResultScalableVector = extractedSubVector; + } else if (resRank != 1) { + currentResultScalableVector = rewriter.create<vector::ExtractOp>( + loc, result, llvm::ArrayRef(resIdx).drop_back()); + } else { + currentResultScalableVector = result; + } + } + + if (extractionSize < resultMinScalableSize) { + currentResultScalableVector = rewriter.create<vector::ScalableInsertOp>( + loc, extractedSubVector, currentResultScalableVector, + resIdx.back()); + } + + if (resIdx.back() + extractionSize >= resultMinScalableSize) { + // Will wrap to next scalable vector, insert complete 1D slice into + // result. 
+ result = rewriter.create<vector::InsertOp>( + loc, currentResultScalableVector, result, + llvm::ArrayRef(resIdx).drop_back()); + currentResultScalableVector = {}; + } + + incIdx(srcIdx, sourceVectorType, srcRank - 1, extractionSize); + incIdx(resIdx, resultVectorType, resRank - 1, extractionSize); + } + + rewriter.replaceOp(op, result); + return success(); + } + + static bool isVectorTypeWithTrailingScalableDim(VectorType type) { + return type.getRank() >= 1 && type.getScalableDims().back() && + !llvm::is_contained(type.getScalableDims().drop_back(), true); } }; + } // namespace void mlir::vector::populateVectorShapeCastLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add<ShapeCastOp2DDownCastRewritePattern, - ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern>( - patterns.getContext(), benefit); + ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern, + ScalableShapeCastOpRewritePattern>(patterns.getContext(), + benefit); } diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-scalable-vectors.mlir @@ -0,0 +1,81 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter | FileCheck %s + +/// This tests that shape casts of scalable vectors (with one trailing scalable dim) +/// can be correctly lowered to vector.scalable.insert/extract. 
+ +// CHECK-LABEL: i32_vector_shape_cast_trailing_3d_scalable_to_1d +// CHECK-SAME: %[[arg0:.*]]: vector<2x1x[4]xi32> +func.func @i32_vector_shape_cast_trailing_3d_scalable_to_1d(%arg0: vector<2x1x[4]xi32>) -> vector<[8]xi32> +{ + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[8]xi32> + // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0, 0] : vector<2x1x[4]xi32> + // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[4]xi32> into vector<[8]xi32> + // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1, 0] : vector<2x1x[4]xi32> + // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][4] : vector<[4]xi32> into vector<[8]xi32> + %flat = vector.shape_cast %arg0 : vector<2x1x[4]xi32> to vector<[8]xi32> + // CHECK-NEXT: return %[[res1]] : vector<[8]xi32> + return %flat : vector<[8]xi32> +} + +// ----- + +// CHECK-LABEL: i32_vector_shape_cast_1d_scalable_to_3d +// CHECK-SAME: %[[arg0:.*]]: vector<[8]xi32> +func.func @i32_vector_shape_cast_1d_scalable_to_3d(%arg0: vector<[8]xi32>) -> vector<2x1x[4]xi32> { + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<2x1x[4]xi32> + // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %[[arg0]][0] : vector<[4]xi32> from vector<[8]xi32> + // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> + // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][4] : vector<[4]xi32> from vector<[8]xi32> + // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1, 0] : vector<[4]xi32> into vector<2x1x[4]xi32> + %unflat = vector.shape_cast %arg0 : vector<[8]xi32> to vector<2x1x[4]xi32> + // CHECK-NEXT: return %[[res1]] : vector<2x1x[4]xi32> + return %unflat : vector<2x1x[4]xi32> +} + +// ----- + +// CHECK-LABEL: i8_vector_shape_cast_trailing_2d_scalable_to_1d +// CHECK-SAME: %[[arg0:.*]]: vector<4x[8]xi8> +func.func 
@i8_vector_shape_cast_trailing_2d_scalable_to_1d(%arg0: vector<4x[8]xi8>) -> vector<[32]xi8> { + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<[32]xi8> + // CHECK-NEXT: %[[subvec0:.*]] = vector.extract %[[arg0]][0] : vector<4x[8]xi8> + // CHECK-NEXT: %[[res0:.*]] = vector.scalable.insert %[[subvec0]], %[[cst]][0] : vector<[8]xi8> into vector<[32]xi8> + // CHECK-NEXT: %[[subvec1:.*]] = vector.extract %[[arg0]][1] : vector<4x[8]xi8> + // CHECK-NEXT: %[[res1:.*]] = vector.scalable.insert %[[subvec1]], %[[res0]][8] : vector<[8]xi8> into vector<[32]xi8> + // CHECK-NEXT: %[[subvec2:.*]] = vector.extract %[[arg0]][2] : vector<4x[8]xi8> + // CHECK-NEXT: %[[res2:.*]] = vector.scalable.insert %[[subvec2]], %[[res1]][16] : vector<[8]xi8> into vector<[32]xi8> + // CHECK-NEXT: %[[subvec3:.*]] = vector.extract %[[arg0]][3] : vector<4x[8]xi8> + // CHECK-NEXT: %[[res3:.*]] = vector.scalable.insert %[[subvec3]], %[[res2]][24] : vector<[8]xi8> into vector<[32]xi8> + %flat = vector.shape_cast %arg0 : vector<4x[8]xi8> to vector<[32]xi8> + // CHECK-NEXT: return %[[res3]] : vector<[32]xi8> + return %flat : vector<[32]xi8> +} + +// ----- + +// CHECK-LABEL: i8_vector_shape_cast_1d_scalable_to_2d +// CHECK-SAME: %[[arg0:.*]]: vector<[32]xi8> +func.func @i8_vector_shape_cast_1d_scalable_to_2d(%arg0: vector<[32]xi8>) -> vector<4x[8]xi8> { + // CHECK-NEXT: %[[cst:.*]] = arith.constant dense<0> : vector<4x[8]xi8> + // CHECK-NEXT: %[[subvec0:.*]] = vector.scalable.extract %arg0[0] : vector<[8]xi8> from vector<[32]xi8> + // CHECK-NEXT: %[[res0:.*]] = vector.insert %[[subvec0]], %[[cst]] [0] : vector<[8]xi8> into vector<4x[8]xi8> + // CHECK-NEXT: %[[subvec1:.*]] = vector.scalable.extract %[[arg0]][8] : vector<[8]xi8> from vector<[32]xi8> + // CHECK-NEXT: %[[res1:.*]] = vector.insert %[[subvec1]], %[[res0]] [1] : vector<[8]xi8> into vector<4x[8]xi8> + // CHECK-NEXT: %[[subvec2:.*]] = vector.scalable.extract %[[arg0]][16] : vector<[8]xi8> from vector<[32]xi8> + // CHECK-NEXT: 
%[[res2:.*]] = vector.insert %[[subvec2]], %[[res1]] [2] : vector<[8]xi8> into vector<4x[8]xi8> + // CHECK-NEXT: %[[subvec3:.*]] = vector.scalable.extract %[[arg0]][24] : vector<[8]xi8> from vector<[32]xi8> + // CHECK-NEXT: %[[res3:.*]] = vector.insert %[[subvec3]], %[[res2]] [3] : vector<[8]xi8> into vector<4x[8]xi8> + %unflat = vector.shape_cast %arg0 : vector<[32]xi8> to vector<4x[8]xi8> + // CHECK-NEXT: return %[[res3]] : vector<4x[8]xi8> + return %unflat : vector<4x[8]xi8> +} + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + %f = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + + transform.apply_patterns to %f { + transform.apply_patterns.vector.lower_shape_cast + } : !transform.any_op +}