diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1125,9 +1125,6 @@
                               b.getI64ArrayAttr(extractPos));
     return extractOp.getResult();
   }
-  // TODO: In case the rank of the broadcast source is greater than the rank of
-  // the extract result this can be combined into a new broadcast op. This needs
-  // to be added a canonicalization pattern if needed.
   return Value();
 }
 
@@ -1208,12 +1205,63 @@
 
 namespace {
 
+// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
+class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *defOp = extractOp.vector().getDefiningOp();
+    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+      return failure();
+    Value source = defOp->getOperand(0);
+    if (extractOp.getType() == source.getType())
+      return failure();
+    auto getRank = [](Type type) {
+      return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
+    };
+    unsigned broadcastSrcRank = getRank(source.getType());
+    unsigned extractResultRank = getRank(extractOp.getType());
+    // We only consider the case where the rank of the source is smaller than
+    // the rank of the extract dst. The other cases are handled in the folding
+    // patterns.
+    if (extractResultRank <= broadcastSrcRank)
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        extractOp, extractOp.getType(), source);
+    return success();
+  }
+};
+
+// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
+class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Return if the 'extractOp' operand is not defined by a
+    // ConstantOp.
+    auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
+    if (!constantOp)
+      return failure();
+    auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
+    if (!dense)
+      return failure();
+    Attribute newAttr = dense.getSplatValue();
+    if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
+      newAttr = DenseElementsAttr::get(vecDstType, newAttr);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  // ExtractToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
 }
 
 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
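Note on the new extract canonicalizations: ExtractOpFromBroadcast covers the case the fold above leaves alone, where the extract result is still a vector of higher rank than the broadcast source. A hypothetical example (shapes made up for illustration, not taken from the tests in this patch):

  %b = vector.broadcast %v : vector<4xf32> to vector<3x2x4xf32>
  %e = vector.extract %b[1] : vector<3x2x4xf32>

canonicalizes to a single, narrower broadcast:

  %e = vector.broadcast %v : vector<4xf32> to vector<2x4xf32>

ExtractOpConstantFolder handles extracts of splat constants and is exercised by the extract_constant test added to canonicalize.mlir below.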
@@ -1555,10 +1603,31 @@
   return success();
 }
 
+namespace {
+
+// If insertOp is only inserting unit dimensions it can be transformed to a
+// broadcast.
+class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
+    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
+                           srcVecType.getNumElements())
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        insertOp, insertOp.getDestVectorType(), insertOp.source());
+    return success();
+  }
+};
+
+} // namespace
+
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  // InsertToShapeCast is not a default canonicalization, it is opt-in by
-  // calling `populateCastAwayVectorLeadingOneDimPatterns`
+  results.add<InsertToBroadcast>(context);
 }
 
 // Eliminates insert operations that produce values identical to their source
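InsertToBroadcast fires when the inserted source already supplies every element of the destination, i.e. the destination only adds unit dimensions around it. A sketch with hypothetical shapes:

  %r = vector.insert %v, %dst [0] : vector<2x3xf32> into vector<1x2x3xf32>

becomes

  %r = vector.broadcast %v : vector<2x3xf32> to vector<1x2x3xf32>

since the insert overwrites the whole destination; the insert_extract_to_broadcast test added below checks the same rewrite on vector<1x1x4xf32>.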
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2943,6 +2943,11 @@
   return VectorType::get(newShape, oldType.getElementType());
 }
 
+/// Return a SmallVector of size `rank` containing all zeros.
+static SmallVector<int64_t> splatZero(int64_t rank) {
+  return SmallVector<int64_t>(rank, 0);
+}
+
 // Casts away leading one dimensions in vector.extract_strided_slice's vector
 // input by inserting vector.shape_cast.
 struct CastAwayExtractStridedSliceLeadingOneDim
@@ -2969,8 +2974,8 @@
 
     Location loc = extractOp.getLoc();
 
-    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
-        loc, newSrcType, extractOp.vector());
+    Value newSrcVector = rewriter.create<vector::ExtractOp>(
+        loc, extractOp.vector(), splatZero(dropCount));
 
     // The offsets/sizes/strides attribute can have a less number of elements
     // than the input vector's rank: it is meant for the leading dimensions.
@@ -2984,7 +2989,7 @@
     auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
         loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
                                                      newExtractOp);
 
     return success();
@@ -3004,17 +3009,18 @@
 
    VectorType oldDstType = insertOp.getDestVectorType();
    VectorType newDstType = trimLeadingOneDims(oldDstType);
 
-    if (newSrcType.getRank() == oldSrcType.getRank() &&
-        newDstType.getRank() == oldDstType.getRank())
+    int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
+    int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
+    if (srcDropCount == 0 && dstDropCount == 0)
       return failure();
 
     // Trim leading one dimensions from both operands.
     Location loc = insertOp.getLoc();
 
-    Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
-        loc, newSrcType, insertOp.source());
-    Value newDstVector =
-        rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
+    Value newSrcVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.source(), splatZero(srcDropCount));
+    Value newDstVector = rewriter.create<vector::ExtractOp>(
+        loc, insertOp.dest(), splatZero(dstDropCount));
 
     auto newOffsets = rewriter.getArrayAttr(
         insertOp.offsets().getValue().take_back(newDstType.getRank()));
@@ -3024,7 +3030,7 @@
     auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
         loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
 
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
                                                      newInsertOp);
 
     return success();
@@ -3068,7 +3074,7 @@
     auto newRead = rewriter.create<vector::TransferReadOp>(
         read.getLoc(), newType, read.source(), read.indices(), newMap,
         read.padding(), inBounds);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
 
     return success();
   }
@@ -3092,9 +3098,9 @@
 
     VectorType oldType = write.getVectorType();
     VectorType newType = trimLeadingOneDims(oldType);
-
     if (newType == oldType)
       return failure();
+    int64_t dropDim = oldType.getRank() - newType.getRank();
 
     AffineMap oldMap = write.permutation_map();
     ArrayRef<AffineExpr> newResults =
@@ -3108,8 +3114,8 @@
       inBounds = rewriter.getArrayAttr(
           write.in_boundsAttr().getValue().take_back(newType.getRank()));
 
-    auto newVector = rewriter.create<vector::ShapeCastOp>(
-        write.getLoc(), newType, write.vector());
+    auto newVector = rewriter.create<vector::ExtractOp>(
+        write.getLoc(), write.vector(), splatZero(dropDim));
 
     rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
        write, newVector, write.source(), write.indices(), newMap, inBounds);
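With splatZero, the cast-away patterns updated above now drop and restore leading unit dimensions through plain extract/broadcast pairs instead of vector.shape_cast. Roughly, for one leading unit dimension (types chosen to match the updated vector-transforms.mlir checks below):

  %lo = vector.extract %v[0] : vector<1x4xf16>
  %hi = vector.broadcast %lo : vector<4xf16> to vector<1x4xf16>

where %lo has type vector<4xf16>; both ops are shape-only no-ops on the data, but they compose with the new extract/insert canonicalizations added above.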
@@ -3117,35 +3123,6 @@
   }
 };
 
-template <typename BroadCastType>
-struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
-  using OpRewritePattern<BroadCastType>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadCastType broadcastOp,
-                                PatternRewriter &rewriter) const override {
-    VectorType dstType =
-        broadcastOp.getResult().getType().template dyn_cast<VectorType>();
-    if (!dstType)
-      return failure();
-    VectorType newDstType = trimLeadingOneDims(dstType);
-    if (newDstType == dstType)
-      return failure();
-    Location loc = broadcastOp.getLoc();
-    Value source = broadcastOp->getOperand(0);
-    VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
-    if (srcVecType)
-      srcVecType = trimLeadingOneDims(srcVecType);
-    if (srcVecType && srcVecType != source.getType()) {
-      source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
-    }
-    Value newBroadcastOp =
-        rewriter.create<vector::BroadcastOp>(loc, newDstType, source);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
-                                                     newBroadcastOp);
-    return success();
-  }
-};
-
 class CastAwayElementwiseLeadingOneDim : public RewritePattern {
 public:
   CastAwayElementwiseLeadingOneDim(MLIRContext *context)
@@ -3161,14 +3138,12 @@
     VectorType newVecType = trimLeadingOneDims(vecType);
     if (newVecType == vecType)
       return failure();
-
+    int64_t dropDim = vecType.getRank() - newVecType.getRank();
     SmallVector<Value, 4> newOperands;
     for (Value operand : op->getOperands()) {
       if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
-        auto newType =
-            VectorType::get(newVecType.getShape(), opVecType.getElementType());
-        newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
-            op->getLoc(), newType, operand));
+        newOperands.push_back(rewriter.create<vector::ExtractOp>(
+            op->getLoc(), operand, splatZero(dropDim)));
       } else {
         newOperands.push_back(operand);
       }
@@ -3178,69 +3153,12 @@
     state.addOperands(newOperands);
     state.addTypes(newVecType);
     Operation *newOp = rewriter.createOperation(state);
-    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
                                                      newOp->getResult(0));
     return success();
   }
 };
 
-// If extractOp is only removing unit dimensions it can be transformed to a
-// shapecast.
-class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ExtractOp extractOp,
-                                PatternRewriter &rewriter) const override {
-    auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
-    if (!dstVecType || extractOp.getVectorType().getNumElements() !=
-                           dstVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
-                                             extractOp.vector());
-    return success();
-  }
-};
-
-// If insertOp is only inserting unit dimensions it can be transformed to a
-// shapecast.
-class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(InsertOp insertOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
-    if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
-                           srcVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        insertOp, insertOp.getDestVectorType(), insertOp.source());
-    return success();
-  }
-};
-
-// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
-// the degenerated case where the broadcast only adds dimensions of size 1 it
-// can be replaced by a ShapeCastOp. This canonicalization checks if the total
-// number of elements is the same before and after the broadcast to detect if
-// the only change in the vector type are new dimensions of size 1.
-class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
-                                PatternRewriter &rewriter) const override {
-    auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
-    if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
-                           srcVecType.getNumElements())
-      return failure();
-    rewriter.replaceOpWithNewOp<ShapeCastOp>(
-        broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
-    return success();
-  }
-};
-
 // Returns the values in `arrayAttr` as an integer vector.
 static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
   return llvm::to_vector<4>(
@@ -3722,13 +3640,11 @@
 
 void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<
-      BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
-      CastAwayInsertStridedSliceLeadingOneDim,
-      CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
-      CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
-      CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
-      ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
+  patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
+               CastAwayInsertStridedSliceLeadingOneDim,
+               CastAwayTransferReadLeadingOneDim,
+               CastAwayTransferWriteLeadingOneDim,
+               CastAwayElementwiseLeadingOneDim>(patterns.getContext());
   populateShapeCastFoldingPatterns(patterns);
 }
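For the elementwise case, the rewrite now produces roughly the following (illustrative sketch with hypothetical value names, mirroring the updated cast_away_elementwise_leading_one_dims checks below):

  %lhs = vector.extract %a[0, 0] : vector<1x1x8xf32>
  %rhs = vector.extract %b[0, 0] : vector<1x1x8xf32>
  %sum = arith.addf %lhs, %rhs : vector<8xf32>
  %res = vector.broadcast %sum : vector<8xf32> to vector<1x1x8xf32>

in place of an arith.addf directly on vector<1x1x8xf32>.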
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -496,13 +496,10 @@
 
 // -----
 
-// Negative test for extract_op folding when the type of broadcast source
-// doesn't match the type of vector.extract.
-// CHECK-LABEL: fold_extract_broadcast_negative
-// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<1x2x4xf32>
-// CHECK: %[[R:.*]] = vector.extract %[[B]][0, 1] : vector<1x2x4xf32>
-// CHECK: return %[[R]] : vector<4xf32>
-func @fold_extract_broadcast_negative(%a : f32) -> vector<4xf32> {
+// CHECK-LABEL: fold_extract_broadcast
+// CHECK: %[[B:.*]] = vector.broadcast %{{.*}} : f32 to vector<4xf32>
+// CHECK: return %[[B]] : vector<4xf32>
+func @fold_extract_broadcast(%a : f32) -> vector<4xf32> {
   %b = vector.broadcast %a : f32 to vector<1x2x4xf32>
   %r = vector.extract %b[0, 1] : vector<1x2x4xf32>
   return %r : vector<4xf32>
@@ -1058,3 +1055,31 @@
   vector<16x4xf16> to vector<2x4xf16>
   return %1 : vector<2x4xf16>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_extract_to_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
+// CHECK: %[[V0:.*]] = vector.extract %[[ARG0]][0, 0] : vector<1x1x4xf32>
+// CHECK: %[[V1:.*]] = vector.broadcast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
+// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
+func @insert_extract_to_broadcast(%arg0 : vector<1x1x4xf32>,
+  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
+  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
+  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
+  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: extract_constant
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1 : i32
+// CHECK-DAG: %[[CST0:.*]] = arith.constant dense<2.000000e+00> : vector<7xf32>
+// CHECK: return %[[CST0]], %[[CST1]] : vector<7xf32>, i32
+func @extract_constant() -> (vector<7xf32>, i32) {
+  %cst = arith.constant dense<2.000000e+00> : vector<29x7xf32>
+  %cst_1 = arith.constant dense<1> : vector<4x37x9xi32>
+  %0 = vector.extract %cst[2] : vector<29x7xf32>
+  %1 = vector.extract %cst_1[1, 4, 5] : vector<4x37x9xi32>
+  return %0, %1 : vector<7xf32>, i32
+}
diff --git a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir b/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
deleted file mode 100644
--- a/mlir/test/Dialect/Vector/vector-dim-one-shape-cast.mlir
+++ /dev/null
@@ -1,23 +0,0 @@
-// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
-
-// CHECK-LABEL: broadcast_to_shapecast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<4x4xf16> to vector<1x4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<1x4x4xf16>
-func @broadcast_to_shapecast(%arg0: vector<4x4xf16>) -> vector<1x4x4xf16> {
-  %0 = vector.broadcast %arg0 : vector<4x4xf16> to vector<1x4x4xf16>
-  return %0 : vector<1x4x4xf16>
-}
-
-// -----
-
-// CHECK-LABEL: func @insert_extract_to_shapecast
-// CHECK-SAME: (%[[ARG0:.*]]: vector<1x1x4xf32>, %[[ARG1:.*]]: vector<4xf32>)
-// CHECK: %[[V0:.*]] = vector.shape_cast %[[ARG0]] : vector<1x1x4xf32> to vector<4xf32>
-// CHECK: %[[V1:.*]] = vector.shape_cast %[[ARG1]] : vector<4xf32> to vector<1x1x4xf32>
-// CHECK: return %[[V0]], %[[V1]] : vector<4xf32>, vector<1x1x4xf32>
-func @insert_extract_to_shapecast(%arg0 : vector<1x1x4xf32>,
-  %arg1 : vector<4xf32>) -> (vector<4xf32>, vector<1x1x4xf32>) {
-  %0 = vector.extract %arg0[0, 0] : vector<1x1x4xf32>
-  %1 = vector.insert %arg1, %arg0 [0, 0] : vector<4xf32> into vector<1x1x4xf32>
-  return %0, %1 : vector<4xf32>, vector<1x1x4xf32>
-}
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -421,21 +421,21 @@
 
 // CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims
 func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8xf16>) -> vector<1x1x8xf16> {
-  // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
   // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x8xf16> to vector<1x8xf16>
   %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x8xf16> to vector<1x1x8xf16>
-  // CHECK: %[[RET:.+]] = vector.shape_cast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
+  // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x8xf16> to vector<1x1x8xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x1x8xf16>
 }
 
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims
 func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> {
-  // CHECK: %[[SRC:.+]] = vector.shape_cast %{{.*}} : vector<1x8xf16> to vector<8xf16>
-  // CHECK: %[[DST:.+]] = vector.shape_cast %{{.*}} : vector<1x8x8xf16> to vector<8x8xf16>
+  // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<1x8xf16>
+  // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<1x8x8xf16>
   // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<8xf16> into vector<8x8xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x8xf16> into vector<1x8x8xf16>
-  // CHECK: %[[RET:.+]] = vector.shape_cast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
+  // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x8xf16> to vector<1x8x8xf16>
   // CHECK: return %[[RET]]
   return %0: vector<1x8x8xf16>
 }
@@ -443,9 +443,10 @@
 // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element
 // CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16>
 func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[ARG0]] : vector<1x1xf16> to vector<1x1x1xf16>
+  // CHECK: %[[EXT:.+]] = vector.extract %{{.*}}[0] : vector<1x1xf16>
+  // CHECK: %[[B:.+]] = vector.broadcast %[[EXT]] : vector<1xf16> to vector<1x1x1xf16>
   %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x1xf16> into vector<1x1x1xf16>
-  // CHECK: return %[[CAST]]
+  // CHECK: return %[[B]]
   return %0: vector<1x1x1xf16>
 }
@@ -456,7 +457,7 @@
   // CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f16
   %f0 = arith.constant 0. : f16
   // CHECK: %[[READ:.+]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[F0]] {in_bounds = [true]} : memref<1x4x8x16xf16>, vector<4xf16>
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %[[READ]] : vector<4xf16> to vector<1x4xf16>
+  // CHECK: %[[CAST:.+]] = vector.broadcast %[[READ]] : vector<4xf16> to vector<1x4xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x4x8x16xf16>, vector<1x4xf16>
   // CHECK: return %[[CAST]]
   return %0: vector<1x4xf16>
@@ -466,7 +467,7 @@
 func @cast_away_transfer_read_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>) -> vector<1x1xf16> {
   %c0 = arith.constant 0 : index
   %f0 = arith.constant 0. : f16
-  // CHECK: vector.shape_cast %{{.+}} : vector<1xf16> to vector<1x1xf16>
+  // CHECK: vector.broadcast %{{.+}} : vector<1xf16> to vector<1x1xf16>
   %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %f0 {in_bounds = [true, true]} : memref<1x1x1x1xf16>, vector<1x1xf16>
   return %0: vector<1x1xf16>
 }
@@ -475,7 +476,7 @@
 func @cast_away_transfer_write_leading_one_dims(%arg0: memref<1x4x8x16xf16>, %arg1: vector<1x4xf16>) {
   // CHECK: %[[C0:.+]] = arith.constant 0 : index
   %c0 = arith.constant 0 : index
-  // CHECK: %[[CAST:.+]] = vector.shape_cast %{{.*}} : vector<1x4xf16> to vector<4xf16>
+  // CHECK: %[[CAST:.+]] = vector.extract %{{.*}}[0] : vector<1x4xf16>
   // CHECK: vector.transfer_write %[[CAST]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xf16>, memref<1x4x8x16xf16>
 
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x4xf16>, memref<1x4x8x16xf16>
@@ -485,54 +486,35 @@
 // CHECK-LABEL: func @cast_away_transfer_write_leading_one_dims_one_element
 func @cast_away_transfer_write_leading_one_dims_one_element(%arg0: memref<1x1x1x1xf16>, %arg1: vector<1x1xf16>) {
   %c0 = arith.constant 0 : index
-  // CHECK: vector.shape_cast %{{.+}} : vector<1x1xf16> to vector<1xf16>
+  // CHECK: vector.extract %{{.+}}[0] : vector<1x1xf16>
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {in_bounds = [true, true]} : vector<1x1xf16>, memref<1x1x1x1xf16>
   return
 }
 
-// CHECK-LABEL: func @cast_away_broadcast_leading_one_dims
-func @cast_away_broadcast_leading_one_dims(
-  %arg0: vector<8xf32>, %arg1: f32, %arg2: vector<1x4xf32>) ->
-  (vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>) {
-  // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
-  %0 = vector.broadcast %arg0 : vector<8xf32> to vector<1x1x8xf32>
-  // CHECK: vector.broadcast %{{.*}} : f32 to vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
-  %1 = vector.broadcast %arg1 : f32 to vector<1x1x4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<3x4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<3x4xf32> to vector<1x3x4xf32>
-  %2 = vector.broadcast %arg2 : vector<1x4xf32> to vector<1x3x4xf32>
-  // CHECK: splat %{{.*}} : vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x1x4xf32>
-  %3 = splat %arg1 : vector<1x1x4xf32>
-  return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32>, vector<1x1x4xf32>
-}
-
 // CHECK-LABEL: func @cast_away_elementwise_leading_one_dims
 func @cast_away_elementwise_leading_one_dims(
   %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>,
   %arg3: vector<1x4xf32>, %arg4: i1) ->
   (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) {
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32>
+  // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
+  // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x8xf32>
   // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<8xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
+  // CHECK: vector.broadcast %{{.*}} : vector<8xf32> to vector<1x1x8xf32>
   %0 = arith.addf %arg0, %arg0 : vector<1x1x8xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
   // CHECK: arith.cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1>
+  // CHECK: vector.broadcast %{{.*}} : vector<4xi1> to vector<1x4xi1>
   %1 = arith.cmpf ogt, %arg2, %arg3 : vector<1x4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
   // CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
   %2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32>
+  // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
+  // CHECK: vector.extract %{{.*}}[0] : vector<1x4xf32>
   // CHECK: select %arg4, %12, %{{.*}} : vector<4xf32>
-  // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32>
+  // CHECK: vector.broadcast %{{.*}} : vector<4xf32> to vector<1x4xf32>
   %3 = select %arg4, %arg3, %arg2 : vector<1x4xf32>
   return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>
 }