diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -7,11 +7,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/Support/Debug.h" #define DEBUG_TYPE "vector-drop-unit-dim" @@ -37,7 +39,7 @@ namespace { // Casts away leading one dimensions in vector.extract_strided_slice's vector -// input by inserting vector.shape_cast. +// input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -84,8 +86,8 @@ } }; -// Casts away leading one dimensions in vector.extract_strided_slice's vector -// inputs by inserting vector.shape_cast. +// Casts away leading one dimensions in vector.insert_strided_slice's vector +// inputs by inserting vector.broadcast. struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -125,6 +127,61 @@ } }; +// Casts away leading one dimensions in vector.insert's vector inputs by +// inserting vector.broadcast. 
+//
+// The source (when it is a vector) and the destination are each reduced with
+// vector.extract, a vector.insert is created on the reduced types, and the
+// result is expanded back to the original destination type with
+// vector.broadcast.
+struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  // Returns failure() when neither the source nor the destination has a
+  // leading unit dimension to drop; otherwise replaces `insertOp` and returns
+  // success().
+  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    // The source may be a scalar: in that case both ranks stay 0 and the
+    // source type is left untouched.
+    Type oldSrcType = insertOp.getSourceType();
+    Type newSrcType = oldSrcType;
+    int64_t oldSrcRank = 0, newSrcRank = 0;
+    if (auto type = oldSrcType.dyn_cast<VectorType>()) {
+      newSrcType = trimLeadingOneDims(type);
+      oldSrcRank = type.getRank();
+      newSrcRank = newSrcType.cast<VectorType>().getRank();
+    }
+
+    VectorType oldDstType = insertOp.getDestVectorType();
+    VectorType newDstType = trimLeadingOneDims(oldDstType);
+
+    // Number of leading dimensions removed from each operand.
+    int64_t srcDropCount = oldSrcRank - newSrcRank;
+    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+    if (srcDropCount == 0 && dstDropCount == 0)
+      return failure();
+
+    // Trim leading one dimensions from both operands.
+    Location loc = insertOp.getLoc();
+
+    // A scalar source is used as-is; a vector source is extracted at the
+    // position built by splatZero(srcDropCount).
+    Value newSrcVector = insertOp.getSource();
+    if (oldSrcRank != 0) {
+      newSrcVector = rewriter.create<vector::ExtractOp>(
+          loc, insertOp.getSource(), splatZero(srcDropCount));
+    }
+    Value newDstVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.getDest(), splatZero(dstDropCount));
+
+    // The new insert needs one index per destination dimension that the new
+    // source does not cover. Keep the trailing indices of the old position;
+    // if more indices are needed than the old op had (the source lost leading
+    // unit dims that the destination kept), append zeros at the end.
+    unsigned oldPosRank = insertOp.getPosition().getValue().size();
+    unsigned newPosRank = newDstType.getRank() - newSrcRank;
+    SmallVector<Attribute> newPositions = llvm::to_vector(
+        insertOp.getPosition().getValue().take_back(newPosRank));
+    if (newPosRank > oldPosRank) {
+      auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
+      newPositions.resize(newPosRank, zeroAttr);
+    }
+
+    auto newInsertOp = rewriter.create<vector::InsertOp>(
+        loc, newDstType, newSrcVector, newDstVector,
+        rewriter.getArrayAttr(newPositions));
+
+    // Restore the original destination type for existing users.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
+                                                     newInsertOp);
+
+    return success();
+  }
+};
+
 // Turns vector.transfer_read on vector with leading 1 dimensions into
 // vector.shape_cast followed by vector.transfer_read on vector without leading
 // 1 dimensions.
@@ -383,7 +440,7 @@
     RewritePatternSet &patterns) {
   patterns
       .add(patterns.getContext());
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -265,3 +265,42 @@
   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 }
 
+// Scalar source: the two leading unit dims of the destination are removed with
+// vector.extract, the scalar is inserted at [0] into the remaining
+// vector<4xf32>, and the result is broadcast back to vector<1x1x4xf32>.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_scalar
+//  CHECK-SAME: (%[[S:.+]]: f32, %[[V:.+]]: vector<1x1x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[V]][0, 0] : vector<1x1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[S]], %[[EXTRACT]] [0] : f32 into vector<4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_scalar(%s: f32, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = vector.insert %s, %v [0, 0, 0] : f32 into vector<1x1x4xf32>
+  return %0: vector<1x1x4xf32>
+}
+
+// Rank-1 source that covers the whole trimmed destination: the expected output
+// is just a broadcast of the source to vector<1x1x4xf32>.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank1
+//  CHECK-SAME: (%[[S:.+]]: vector<4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[S]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_rank1(%s: vector<4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = vector.insert %s, %v [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0: vector<1x1x4xf32>
+}
+
+// Rank-2 source with a leading unit dim: the source is reduced to vector<4xf32>
+// with vector.extract and then broadcast to vector<1x1x4xf32>.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2
+//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x1x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+//       CHECK:   %[[BCAST:.+]] = vector.broadcast %[[EXTRACT]] : vector<4xf32> to vector<1x1x4xf32>
+//       CHECK:   return %[[BCAST]]
+func @cast_away_insert_leading_one_dims_rank2(%s: vector<1x4xf32>, %v: vector<1x1x4xf32>) -> vector<1x1x4xf32> {
+  %0 = vector.insert %s, %v [0] : vector<1x4xf32> into vector<1x1x4xf32>
+  return %0: vector<1x1x4xf32>
+}
+
+// Destination without a leading unit dim: only the source is trimmed, and the
+// insert position grows from [5] to [5, 0] to address the narrower source.
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
+//  CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
+//       CHECK:   %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+//       CHECK:   %[[INSERT:.+]] = vector.insert %[[EXTRACT]], %[[V]] [5, 0] : vector<4xf32> into vector<8x1x4xf32>
+//       CHECK:   return %[[INSERT]]
+func @cast_away_insert_leading_one_dims_non_one_dest(%s: vector<1x4xf32>, %v: vector<8x1x4xf32>) -> vector<8x1x4xf32> {
+  %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32>
+  return %0: vector<8x1x4xf32>
+}