diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -44,6 +44,14 @@ void populateCastAwayVectorLeadingOneDimPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context);
 
+/// Collect a set of patterns that bubble up/down bitcast ops.
+///
+/// These patterns move vector.bitcast ops to be before insert ops or after
+/// extract ops where suitable. With them, bitcast will happen on smaller
+/// vectors and there are more chances to share extract/insert ops.
+void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
+                                           MLIRContext *context);
+
 /// Collect a set of vector slices transformation patterns:
 ///     ExtractSlicesOpLowering, InsertSlicesOpLowering
 /// Useful for clients that want to express all vector "slices"
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2787,6 +2787,244 @@
   }
 };
 
+// Returns the values in `arrayAttr` as an integer vector.
+static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
+  return llvm::to_vector<4>(
+      llvm::map_range(arrayAttr.getAsRange<IntegerAttr>(),
+                      [](IntegerAttr attr) { return attr.getInt(); }));
+}
+
+// Shuffles vector.bitcast op after vector.extract op.
+//
+// This transforms IR like:
+//   %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+//   %1 = vector.extract %0[3] : vector<8xf16>
+// Into:
+//   %0 = vector.extract %src[1] : vector<4xf32>
+//   %1 = vector.bitcast %0 : vector<1xf32> to vector<2xf16>
+//   %2 = vector.extract %1[1] : vector<2xf16>
+struct BubbleDownVectorBitCastForExtract
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only support extracting scalars for now.
+    if (extractOp.getVectorType().getRank() != 1)
+      return failure();
+
+    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+    if (!castOp)
+      return failure();
+
+    VectorType castSrcType = castOp.getSourceVectorType();
+    VectorType castDstType = castOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    // Fail to match if we only have one element in the cast op source.
+    // This is to avoid infinite loop given that this pattern can generate
+    // such cases.
+    if (castSrcType.getNumElements() == 1)
+      return failure();
+
+    // Only support casting to a larger number of elements for now.
+    // E.g., vector<4xf32> -> vector<8xf16>.
+    if (castSrcType.getNumElements() > castDstType.getNumElements())
+      return failure();
+
+    unsigned expandRatio =
+        castDstType.getNumElements() / castSrcType.getNumElements();
+
+    auto getFirstIntValue = [](ArrayAttr attr) -> uint64_t {
+      return attr.getAsValueRange<IntegerAttr>().begin()->getZExtValue();
+    };
+
+    uint64_t index = getFirstIntValue(extractOp.position());
+
+    // Get the single scalar (as a vector) in the source value that packs the
+    // desired scalar. E.g. extract vector<1xf32> from vector<4xf32>.
+    VectorType oneScalarType =
+        VectorType::get({1}, castSrcType.getElementType());
+    Value packedValue = rewriter.create<vector::ExtractOp>(
+        extractOp.getLoc(), oneScalarType, castOp.source(),
+        rewriter.getI64ArrayAttr(index / expandRatio));
+
+    // Cast it to a vector with the desired scalar's type.
+    // E.g. f32 -> vector<2xf16>.
+    VectorType packedType =
+        VectorType::get({expandRatio}, castDstType.getElementType());
+    Value castedValue = rewriter.create<vector::BitCastOp>(
+        extractOp.getLoc(), packedType, packedValue);
+
+    // Finally extract the desired scalar.
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, extractOp.getType(), castedValue,
+        rewriter.getI64ArrayAttr(index % expandRatio));
+
+    return success();
+  }
+};
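+
+// Worked numbers for the IR example above: expandRatio = 8 / 4 = 2, so
+// extracting f16 element 3 turns into extracting f32 element 3 / 2 = 1,
+// bitcasting that single f32 into vector<2xf16>, and extracting f16 element
+// 3 % 2 = 1 from the result.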
+
+// Shuffles vector.bitcast op after vector.extract_strided_slice op.
+//
+// This transforms IR like:
+//   %cast = vector.bitcast %arg0 : vector<4xf32> to vector<8xf16>
+//   %0 = vector.extract_strided_slice %cast {
+//          offsets = [4], sizes = [4], strides = [1]
+//        } : vector<8xf16> to vector<4xf16>
+// Into:
+//   %0 = vector.extract_strided_slice %arg0 {
+//          offsets = [2], sizes = [2], strides = [1]
+//        } : vector<4xf32> to vector<2xf32>
+//   %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
+struct BubbleDownBitCastForStridedSliceExtract
+    : public OpRewritePattern<vector::ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto castOp = extractOp.vector().getDefiningOp<vector::BitCastOp>();
+    if (!castOp)
+      return failure();
+
+    VectorType castSrcType = castOp.getSourceVectorType();
+    VectorType castDstType = castOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to more elements for now; other cases to be implemented.
+    if (castSrcLastDim > castDstLastDim)
+      return failure();
+
+    // Only accept all-one strides for now.
+    if (llvm::any_of(extractOp.strides().getAsValueRange<IntegerAttr>(),
+                     [](const APInt &val) { return !val.isOneValue(); }))
+      return failure();
+
+    unsigned rank = extractOp.getVectorType().getRank();
+    assert(castDstLastDim % castSrcLastDim == 0);
+    int64_t expandRatio = castDstLastDim / castSrcLastDim;
+
+    // If we have fewer offsets than the rank, then implicitly we are selecting
+    // the full range for the last bitcasted dimension; other dimensions aren't
+    // affected. Otherwise, we need to scale down the last dimension's offset
+    // given we are extracting from fewer elements now.
+    ArrayAttr newOffsets = extractOp.offsets();
+    if (newOffsets.size() == rank) {
+      SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+      if (offsets.back() % expandRatio != 0)
+        return failure();
+      offsets.back() = offsets.back() / expandRatio;
+      newOffsets = rewriter.getI64ArrayAttr(offsets);
+    }
+
+    // Similarly for sizes.
+    ArrayAttr newSizes = extractOp.sizes();
+    if (newSizes.size() == rank) {
+      SmallVector<int64_t, 4> sizes = getIntValueVector(newSizes);
+      if (sizes.back() % expandRatio != 0)
+        return failure();
+      sizes.back() = sizes.back() / expandRatio;
+      newSizes = rewriter.getI64ArrayAttr(sizes);
+    }
+
+    SmallVector<int64_t, 4> dims =
+        llvm::to_vector<4>(extractOp.getType().cast<VectorType>().getShape());
+    dims.back() = dims.back() / expandRatio;
+    VectorType newExtractType =
+        VectorType::get(dims, castSrcType.getElementType());
+
+    auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
+        extractOp.getLoc(), newExtractType, castOp.source(), newOffsets,
+        newSizes, extractOp.strides());
+
+    rewriter.replaceOpWithNewOp<vector::BitCastOp>(
+        extractOp, extractOp.getType(), newExtractOp);
+
+    return success();
+  }
+};
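+
+// Worked numbers for the IR example above: expandRatio = 8 / 4 = 2, so
+// offsets = [4] and sizes = [4] on the f16 side scale down to offsets = [2]
+// and sizes = [2] on the f32 side. Odd offsets or sizes (e.g. offsets = [3])
+// have no f32 counterpart, hence the modulo bail-outs in matchAndRewrite
+// above.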
+
+// Shuffles vector.bitcast op before vector.insert_strided_slice op.
+//
+// This transforms IR like:
+//   %0 = vector.insert_strided_slice %src, %dst {
+//          offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
+//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
+// Into:
+//   %0 = vector.bitcast %src : vector<4xf16> to vector<2xf32>
+//   %1 = vector.bitcast %dst : vector<8xf16> to vector<4xf32>
+//   %2 = vector.insert_strided_slice %0, %1 {
+//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+struct BubbleUpBitCastForStridedSliceInsert
+    : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to fewer elements for now; other cases to be
+    // implemented.
+    if (castSrcLastDim < castDstLastDim)
+      return failure();
+
+    assert(castSrcLastDim % castDstLastDim == 0);
+    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
+
+    auto insertOp =
+        bitcastOp.source().getDefiningOp<vector::InsertStridedSliceOp>();
+    if (!insertOp)
+      return failure();
+
+    // Only accept all-one strides for now.
+    if (llvm::any_of(insertOp.strides().getAsValueRange<IntegerAttr>(),
+                     [](const APInt &val) { return !val.isOneValue(); }))
+      return failure();
+
+    unsigned rank = insertOp.getSourceVectorType().getRank();
+    // Require insert op to have the same rank for the source and destination
+    // vector; other cases to be implemented.
+    if (rank != insertOp.getDestVectorType().getRank())
+      return failure();
+
+    ArrayAttr newOffsets = insertOp.offsets();
+    assert(newOffsets.size() == rank);
+    SmallVector<int64_t, 4> offsets = getIntValueVector(newOffsets);
+    if (offsets.back() % shrinkRatio != 0)
+      return failure();
+    offsets.back() = offsets.back() / shrinkRatio;
+    newOffsets = rewriter.getI64ArrayAttr(offsets);
+
+    SmallVector<int64_t, 4> srcDims =
+        llvm::to_vector<4>(insertOp.getSourceVectorType().getShape());
+    srcDims.back() = srcDims.back() / shrinkRatio;
+    VectorType newCastSrcType =
+        VectorType::get(srcDims, castDstType.getElementType());
+
+    auto newCastSrcOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastSrcType, insertOp.source());
+
+    SmallVector<int64_t, 4> dstDims =
+        llvm::to_vector<4>(insertOp.getDestVectorType().getShape());
+    dstDims.back() = dstDims.back() / shrinkRatio;
+    VectorType newCastDstType =
+        VectorType::get(dstDims, castDstType.getElementType());
+
+    auto newCastDstOp = rewriter.create<vector::BitCastOp>(
+        bitcastOp.getLoc(), newCastDstType, insertOp.dest());
+
+    rewriter.replaceOpWithNewOp<vector::InsertStridedSliceOp>(
+        bitcastOp, bitcastOp.getType(), newCastSrcOp, newCastDstOp, newOffsets,
+        insertOp.strides());
+
+    return success();
+  }
+};
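+
+// Note: this pattern composes with itself across chains of inserts. After one
+// application, the bitcast it creates for the destination vector can itself
+// match against a preceding insert_strided_slice, so the greedy driver
+// bubbles the cast up through all inserts (see the two chained inserts in
+// the bubble_up_bitcast_in_strided_slice_insert test).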
+
 // TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
 // TODO: Add this as DRR pattern.
 void mlir::vector::populateVectorToVectorTransformationPatterns(
@@ -2811,6 +3049,13 @@
       context);
 }
 
+void mlir::vector::populateBubbleVectorBitCastOpPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<BubbleDownVectorBitCastForExtract,
+                  BubbleDownBitCastForStridedSliceExtract,
+                  BubbleUpBitCastForStridedSliceInsert>(context);
+}
+
 void mlir::vector::populateVectorSlicesLoweringPatterns(
     OwningRewritePatternList &patterns, MLIRContext *context) {
   patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -671,3 +671,92 @@
   vector.transfer_write %arg1, %arg0[%c0, %c0, %c0, %c0] {masked = [false, false]} : vector<1x1xf16>, memref<1x1x1x1xf16>
   return
 }
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_extract
+//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
+func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {
+  %0 = vector.bitcast %src : vector<4xf32> to vector<8xf16>
+  // CHECK: %[[EXTRACT1:.+]] = vector.extract %[[SRC]][1] : vector<4xf32>
+  // CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<1xf32> to vector<2xf16>
+  // CHECK: %[[EXTRACT2:.+]] = vector.extract %[[CAST1]][1] : vector<2xf16>
+  %1 = vector.extract %0[3] : vector<8xf16>
+  // CHECK: %[[EXTRACT3:.+]] = vector.extract %[[SRC]][2] : vector<4xf32>
+  // CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT3]] : vector<1xf32> to vector<2xf16>
+  // CHECK: %[[EXTRACT4:.+]] = vector.extract %[[CAST2]][0] : vector<2xf16>
+  %2 = vector.extract %0[4] : vector<8xf16>
+  // CHECK: return %[[EXTRACT2]], %[[EXTRACT4]]
+  return %1, %2: f16, f16
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract
+//  CHECK-SAME: %[[SRC:.+]]: vector<4xf32>
+func @bubble_down_bitcast_in_strided_slice_extract(%arg0: vector<4xf32>) -> vector<4xf16> {
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2xf32> to vector<4xf16>
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<4xf16>
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim
+//  CHECK-SAME: %[[SRC:.+]]: vector<4x2xf32>
+func @bubble_down_bitcast_in_strided_slice_extract_full_last_dim(%arg0: vector<4x2xf32>) -> vector<2x4xf16> {
+  // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [1], sizes = [2], strides = [1]} : vector<4x2xf32> to vector<2x2xf32>
+  // CHECK: %[[CAST:.+]] = vector.bitcast %[[EXTRACT]] : vector<2x2xf32> to vector<2x4xf16>
+  %cast = vector.bitcast %arg0: vector<4x2xf32> to vector<4x4xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [1], sizes = [2], strides = [1]} : vector<4x4xf16> to vector<2x4xf16>
+  // CHECK: return %[[CAST]]
+  return %0: vector<2x4xf16>
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_offset
+func @bubble_down_bitcast_in_strided_slice_extract_odd_offset(%arg0: vector<4xf32>) -> vector<4xf16> {
+  // CHECK: vector.bitcast
+  // CHECK-NEXT: vector.extract_strided_slice
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [3], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+  return %0: vector<4xf16>
+}
+
+// CHECK-LABEL: func @bubble_down_bitcast_in_strided_slice_extract_odd_size
+func @bubble_down_bitcast_in_strided_slice_extract_odd_size(%arg0: vector<4xf32>) -> vector<3xf16> {
+  // CHECK: vector.bitcast
+  // CHECK-NEXT: vector.extract_strided_slice
+  %cast = vector.bitcast %arg0: vector<4xf32> to vector<8xf16>
+  %0 = vector.extract_strided_slice %cast {offsets = [0], sizes = [3], strides = [1]} : vector<8xf16> to vector<3xf16>
+  return %0: vector<3xf16>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert
+//  CHECK-SAME: (%[[DST:.+]]: vector<8xf16>, %[[SRC1:.+]]: vector<4xf16>, %[[SRC2:.+]]: vector<4xf16>)
+func @bubble_up_bitcast_in_strided_slice_insert(%dst: vector<8xf16>, %src1: vector<4xf16>, %src2: vector<4xf16>) -> vector<4xf32> {
+  // CHECK-DAG: %[[CAST_SRC1:.+]] = vector.bitcast %[[SRC1]] : vector<4xf16> to vector<2xf32>
+  // CHECK-DAG: %[[CAST_SRC2:.+]] = vector.bitcast %[[SRC2]] : vector<4xf16> to vector<2xf32>
+  // CHECK-DAG: %[[CAST_DST:.+]] = vector.bitcast %[[DST]] : vector<8xf16> to vector<4xf32>
+  // CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST_SRC1]], %[[CAST_DST]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+  // CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST_SRC2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+  %0 = vector.insert_strided_slice %src1, %dst {offsets = [0], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %1 = vector.insert_strided_slice %src2, %0 {offsets = [4], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %cast = vector.bitcast %1: vector<8xf16> to vector<4xf32>
+  // CHECK: return %[[INSERT2]]
+  return %cast: vector<4xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_odd_offset
+func @bubble_up_bitcast_in_strided_slice_insert_odd_offset(%dst: vector<8xf16>, %src: vector<4xf16>) -> vector<4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [3], strides = [1]} : vector<4xf16> into vector<8xf16>
+  %cast = vector.bitcast %0: vector<8xf16> to vector<4xf32>
+  return %cast: vector<4xf32>
+}
+
+// CHECK-LABEL: func @bubble_up_bitcast_in_strided_slice_insert_different_rank
+func @bubble_up_bitcast_in_strided_slice_insert_different_rank(%dst: vector<16x4x8xf16>, %src: vector<2x4xf16>) -> vector<16x4x4xf32> {
+  // CHECK: vector.insert_strided_slice
+  // CHECK-NEXT: vector.bitcast
+  %0 = vector.insert_strided_slice %src, %dst {offsets = [0, 0, 2], strides = [1, 1]} : vector<2x4xf16> into vector<16x4x8xf16>
+  %cast = vector.bitcast %0: vector<16x4x8xf16> to vector<16x4x4xf32>
+  return %cast: vector<16x4x4xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -46,6 +46,7 @@
     populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
     populateVectorToVectorTransformationPatterns(patterns, ctx);
     populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
+    populateBubbleVectorBitCastOpPatterns(patterns, ctx);
     applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
   }
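Usage sketch: the new entry point is opt-in, so clients can pick up just the
bitcast bubbling patterns without the rest of the test pass. A minimal pass
body, mirroring TestVectorTransforms.cpp above and assuming the same greedy
rewrite driver:

  void runOnFunction() {
    OwningRewritePatternList patterns;
    // Collect only the bitcast bubbling patterns added in this change.
    populateBubbleVectorBitCastOpPatterns(patterns, &getContext());
    applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
  }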