diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -181,6 +181,22 @@
     std::function controlFn = nullptr, PatternBenefit benefit = 1);
 
+/// Populate `patterns` with a pattern to break down 1-D vector.bitcast ops
+/// based on the destination vector shape. Bitcasts from a lower-bitwidth
+/// element type to a higher-bitwidth one are handled by extracting slices of
+/// the lower-bitwidth source based on the destination vector shape,
+/// bitcasting them, and inserting them based on the ratio of the bitwidths.
+///
+/// This acts as a last-resort way to break down vector.bitcast ops to smaller
+/// vector sizes. Because this pattern composes until it is bitcasting to a
+/// single element of the higher bitwidth, there is an optional control
+/// function: if `controlFn` is not nullptr, the pattern only applies to ops
+/// for which `controlFn` returns true; otherwise it applies to all bitcasts.
+void populateBreakDownVectorBitCastOpPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(vector::BitCastOp)> controlFn = nullptr,
+    PatternBenefit benefit = 1);
+
 /// Populate `patterns` with the following patterns.
 ///
 /// Patterns in populateVectorInsertExtractStridedSliceDecompositionPatterns();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -800,6 +800,91 @@
   }
 };
 
+// Breaks down vector.bitcast op.
+//
+// This transforms IR like:
+//   %1 = vector.bitcast %0 : vector<8xf16> to vector<4xf32>
+// Into:
+//   %cst = vector.splat %c0_f32 : vector<4xf32>
+//   %1 = vector.extract_strided_slice %0 {
+//          offsets = [0], sizes = [4], strides = [1]
+//        } : vector<8xf16> to vector<4xf16>
+//   %2 = vector.bitcast %1 : vector<4xf16> to vector<2xf32>
+//   %4 = vector.insert_strided_slice %2, %cst {
+//          offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+//   %5 = vector.extract_strided_slice %0 {
+//          offsets = [4], sizes = [4], strides = [1]
+//        } : vector<8xf16> to vector<4xf16>
+//   %6 = vector.bitcast %5 : vector<4xf16> to vector<2xf32>
+//   %7 = vector.insert_strided_slice %6, %4 {
+//          offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  BreakDownVectorBitCast(MLIRContext *context,
+                         std::function<bool(vector::BitCastOp)> controlFn,
+                         PatternBenefit benefit)
+      : OpRewritePattern(context, benefit), controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
+                                PatternRewriter &rewriter) const override {
+
+    if (controlFn && !controlFn(bitcastOp))
+      return failure();
+
+    VectorType castSrcType = bitcastOp.getSourceVectorType();
+    VectorType castDstType = bitcastOp.getResultVectorType();
+    assert(castSrcType.getRank() == castDstType.getRank());
+
+    // Only support the rank-1 case for now.
+    if (castSrcType.getRank() != 1)
+      return failure();
+
+    int64_t castSrcLastDim = castSrcType.getShape().back();
+    int64_t castDstLastDim = castDstType.getShape().back();
+    // Require casting to fewer elements for now; other cases to be implemented.
+    if (castSrcLastDim < castDstLastDim)
+      return failure();
+
+    assert(castSrcLastDim % castDstLastDim == 0);
+    int64_t shrinkRatio = castSrcLastDim / castDstLastDim;
+    // Nothing to do if it is already bitcasting to a single element.
+    if (castSrcLastDim == shrinkRatio)
+      return failure();
+
+    Location loc = bitcastOp.getLoc();
+    Type elemType = castDstType.getElementType();
+    assert(elemType.isSignlessIntOrIndexOrFloat());
+
+    Value zero = rewriter.create<arith::ConstantOp>(
+        loc, elemType, rewriter.getZeroAttr(elemType));
+    Value res = rewriter.create<vector::SplatOp>(loc, castDstType, zero);
+
+    SmallVector<int64_t> sliceShape{castDstLastDim};
+    SmallVector<int64_t> strides{1};
+    VectorType newCastDstType =
+        VectorType::get(SmallVector<int64_t>{castDstLastDim / shrinkRatio},
+                        castDstType.getElementType());
+
+    for (int i = 0, e = shrinkRatio; i < e; ++i) {
+      Value extracted = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, bitcastOp.getSource(), ArrayRef<int64_t>{i * castDstLastDim},
+          sliceShape, strides);
+      Value bitcast =
+          rewriter.create<vector::BitCastOp>(loc, newCastDstType, extracted);
+      res = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, bitcast, res,
+          ArrayRef<int64_t>{i * castDstLastDim / shrinkRatio}, strides);
+    }
+    rewriter.replaceOp(bitcastOp, res);
+    return success();
+  }
+
+private:
+  std::function<bool(vector::BitCastOp)> controlFn;
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -1151,6 +1236,13 @@
       benefit);
 }
 
+void mlir::vector::populateBreakDownVectorBitCastOpPatterns(
+    RewritePatternSet &patterns,
+    std::function<bool(vector::BitCastOp)> controlFn, PatternBenefit benefit) {
+  patterns.add<BreakDownVectorBitCast>(patterns.getContext(),
+                                       std::move(controlFn), benefit);
+}
+
 void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT(
     RewritePatternSet &patterns,
     std::function constraint,
diff --git a/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir b/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-break-down-bitcast.mlir
@@ -0,0 +1,41 @@
+// RUN: mlir-opt -split-input-file -test-vector-break-down-bitcast %s | FileCheck %s
+
+// CHECK-LABEL: func.func @bitcast_f16_to_f32
+// CHECK-SAME: (%[[INPUT:.+]]: vector<8xf16>)
+func.func @bitcast_f16_to_f32(%input: vector<8xf16>) -> vector<4xf32> {
+  %0 = vector.bitcast %input : vector<8xf16> to vector<4xf32>
+  return %0: vector<4xf32>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+// CHECK: %[[EXTRACT0:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[CAST0:.+]] = vector.bitcast %[[EXTRACT0]] : vector<4xf16> to vector<2xf32>
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[CAST0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xf16> to vector<4xf16>
+// CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<4xf16> to vector<2xf32>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST1]], %[[INSERT0]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[INSERT1]]
+
+// -----
+
+// CHECK-LABEL: func.func @bitcast_i8_to_i32
+// CHECK-SAME: (%[[INPUT:.+]]: vector<16xi8>)
+func.func @bitcast_i8_to_i32(%input: vector<16xi8>) -> vector<4xi32> {
+  %0 = vector.bitcast %input : vector<16xi8> to vector<4xi32>
+  return %0: vector<4xi32>
+}
+
+// CHECK: %[[INIT:.+]] = arith.constant dense<0> : vector<4xi32>
+// CHECK: %[[EXTRACT0:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [0], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST0:.+]] = vector.bitcast %[[EXTRACT0]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT0:.+]] = vector.insert_strided_slice %[[CAST0]], %[[INIT]] {offsets = [0], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [4], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST1:.+]] = vector.bitcast %[[EXTRACT1]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT1:.+]] = vector.insert_strided_slice %[[CAST1]], %[[INSERT0]] {offsets = [1], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [8], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST2:.+]] = vector.bitcast %[[EXTRACT2]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[CAST2]], %[[INSERT1]] {offsets = [2], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: %[[EXTRACT3:.+]] = vector.extract_strided_slice %[[INPUT]] {offsets = [12], sizes = [4], strides = [1]} : vector<16xi8> to vector<4xi8>
+// CHECK: %[[CAST3:.+]] = vector.bitcast %[[EXTRACT3]] : vector<4xi8> to vector<1xi32>
+// CHECK: %[[INSERT3:.+]] = vector.insert_strided_slice %[[CAST3]], %[[INSERT2]] {offsets = [3], strides = [1]} : vector<1xi32> into vector<4xi32>
+// CHECK: return %[[INSERT3]]
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -604,6 +604,26 @@
   }
 };
 
+struct TestVectorBreakDownBitCast
+    : public PassWrapper<TestVectorBreakDownBitCast,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBreakDownBitCast)
+
+  StringRef getArgument() const final {
+    return "test-vector-break-down-bitcast";
+  }
+  StringRef getDescription() const final {
+    return "Test pattern that breaks down vector.bitcast ops";
+  }
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateBreakDownVectorBitCastOpPatterns(patterns, [](BitCastOp op) {
+      return op.getSourceVectorType().getShape().back() > 4;
+    });
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestCreateVectorBroadcast
     : public PassWrapper<TestCreateVectorBroadcast,
                          OperationPass<func::FuncOp>> {
@@ -688,6 +708,8 @@
   PassRegistration();
+  PassRegistration<TestVectorBreakDownBitCast>();
+
   PassRegistration();
   PassRegistration();
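
Usage note (editor's sketch, not part of the patch): the new entry point is meant to be driven with a control function so the pattern stops composing once bitcasts are small enough for the target. The helper below is a minimal, hypothetical example of wiring it into a pattern set; the helper name and the 128-bit threshold are illustrative assumptions, mirroring the style of the test pass above.

// Illustrative only: keep breaking down 1-D vector.bitcast ops while the
// source vector is wider than 128 bits. Once every piece fits, the control
// function returns false and the greedy rewrite stops applying this pattern.
// Assumes `using namespace mlir;` and the usual Vector dialect headers.
static void addBreakDownWideBitCasts(RewritePatternSet &patterns) {
  auto controlFn = [](vector::BitCastOp op) {
    VectorType srcType = op.getSourceVectorType();
    int64_t srcBits =
        srcType.getNumElements() * srcType.getElementTypeBitWidth();
    return srcBits > 128; // Threshold chosen purely for illustration.
  };
  vector::populateBreakDownVectorBitCastOpPatterns(patterns,
                                                   std::move(controlFn));
}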