diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -36,9 +36,9 @@
                    ConversionPatternRewriter &rewriter) const override {
 
     auto loc = op.getLoc();
-    auto sourceType = cast<MemRefType>(adaptor.getBase().getType());
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
-    Type newElementType = sourceType.getElementType();
+    Type newElementType = convertedType.getElementType();
     int srcBits = oldElementType.getIntOrFloatBitWidth();
     int dstBits = newElementType.getIntOrFloatBitWidth();
 
@@ -81,16 +81,73 @@
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto srcElementType = sourceType.getElementType();
-    auto numElements =
-        static_cast<int>(std::ceil(static_cast<double>(origElements) / scale));
+    auto numElements = (origElements + scale - 1) / scale;
     auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, srcElementType), adaptor.getBase(),
+        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
 
-    numElements *= scale;
-    auto castType = VectorType::get(numElements, oldElementType);
-    auto bitCast = rewriter.create<vector::BitCastOp>(loc, castType, newLoad);
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newLoad);
+
+    rewriter.replaceOp(op, bitCast->getResult(0));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertVectorTransferRead
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorTransferRead final
+    : OpConversionPattern<vector::TransferReadOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
+    Type oldElementType = op.getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    auto origElements = op.getVectorType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
+                                                      adaptor.getPadding());
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getSource());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = (origElements + scale - 1) / scale;
+    auto newReadType = VectorType::get(numElements, newElementType);
+
+    auto newRead = rewriter.create<vector::TransferReadOp>(
+        loc, newReadType, adaptor.getSource(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
+        newPadding);
+
+    auto bitCast =
+        rewriter.create<vector::BitCastOp>(loc, op.getType(), newRead);
 
     rewriter.replaceOp(op, bitCast->getResult(0));
     return success();
@@ -107,5 +164,6 @@
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad>(typeConverter, patterns.getContext());
+  patterns.add<ConvertVectorLoad, ConvertVectorTransferRead>(
+      typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
--- a/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir
@@ -79,3 +79,32 @@
 // CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
 // CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref, vector<1xi32>
 // CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
+
+// -----
+
+func.func @vector_transfer_read_i4(%arg1: index, %arg2: index) -> vector<8xi4> {
+    %c0 = arith.constant 0 : i4
+    %0 = memref.alloc() : memref<3x8xi4>
+    %1 = vector.transfer_read %0[%arg1, %arg2], %c0 {in_bounds = [true]} :
+      memref<3x8xi4>, vector<8xi4>
+    return %1 : vector<8xi4>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_transfer_read_i4
+// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK: %[[CONST:.+]] = arith.constant 0 : i4
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<12xi8>
+// CHECK: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i8
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<12xi8>, vector<4xi8>
+// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_transfer_read_i4
+// CHECK32-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: index, %[[ARG1:[a-zA-Z0-9]+]]: index)
+// CHECK32: %[[CONST:.+]] = arith.constant 0 : i4
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<3xi32>
+// CHECK32: %[[PAD:.+]] = arith.extui %[[CONST]] : i4 to i32
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]], %[[ARG1]]]
+// CHECK32: %[[VEC:.+]] = vector.transfer_read %[[ALLOC]][%[[INDEX]]], %[[PAD]] : memref<3xi32>, vector<1xi32>
+// CHECK32: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<1xi32> to vector<8xi4>
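
Index arithmetic behind the test above (an illustration of the pattern, not part of the patch): for @vector_transfer_read_i4, srcBits = 4 and dstBits = 8, so scale = 2 and numElements = (8 + 2 - 1) / 2 = 4; the vector<8xi4> read becomes a vector<4xi8> read followed by a vector.bitcast back to vector<8xi4>, with the i4 padding widened by arith.extui. Element (s0, s1) of memref<3x8xi4> sits at linear i4 offset s0 * 8 + s1, which in i8 units is (s0 * 8 + s1) floordiv 2 = s0 * 4 + s1 floordiv 2, the affine map in the CHECK lines; with i32 containers scale = 8, giving s0 + s1 floordiv 8 and a vector<1xi32> read in the CHECK32 lines.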