diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -2711,11 +2711,82 @@
     return success();
   }
 };
+
+// Splits vector.transfer_read into vector.transfer_read + vector.broadcast +
+// vector.transpose if vector.transfer_read reads into a vector of higher rank
+// than the source (memref or tensor) type.
+//
+// For example:
+// %vread = vector.transfer_read %t[%c0, %c0], %cst {
+//     in_bounds = [true, true, true, true],
+//     permutation_map = affine_map<(d0, d1) -> (0, d1, d0, 0)>
+// } : tensor<4x5xf32>, vector<8x5x4x2xf32>
+//
+// is rewritten into:
+// %0 = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]}
+//     : tensor<4x5xf32>, vector<4x5xf32>
+// %1 = vector.broadcast %0
+//     : vector<4x5xf32> to vector<8x2x4x5xf32>
+// %2 = vector.transpose %1, [0, 3, 2, 1]
+//     : vector<8x2x4x5xf32> to vector<8x5x4x2xf32>
+class SplitTransferRead final : public OpRewritePattern<TransferReadOp> {
+public:
+  using OpRewritePattern<TransferReadOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcType = readOp.getShapedType();
+    auto vectorType = readOp.getVectorType();
+    if (vectorType.getRank() <= srcType.getRank())
+      return failure();
+
+    int64_t srcRank = srcType.getRank();
+    int64_t vectorRank = vectorType.getRank();
+    SmallVector<bool> newInbounds(srcRank, false);
+    SmallVector<int64_t> targetBroadcastedShape(vectorRank, 0);
+    SmallVector<int64_t> permutation(vectorRank, 0);
+
+    size_t zeroExprsCount = 0;
+    ArrayAttr inbounds = readOp.in_boundsAttr();
+    for (auto &en : llvm::enumerate(readOp.permutation_map().getResults())) {
+      size_t index = en.index();
+      auto &expr = en.value();
+      if (auto zero = expr.dyn_cast<AffineConstantExpr>()) {
+        permutation[zeroExprsCount] = index;
+        targetBroadcastedShape[zeroExprsCount++] = vectorType.getDimSize(index);
+        continue;
+      }
+      unsigned int dimPos = expr.cast<AffineDimExpr>().getPosition();
+      permutation[srcRank + dimPos] = index;
+      targetBroadcastedShape[srcRank + dimPos] = srcType.getDimSize(dimPos);
+
+      if (!inbounds.empty())
+        newInbounds[dimPos] = inbounds[index].cast<BoolAttr>().getValue();
+    }
+
+    Location loc = readOp.getLoc();
+    Value read = rewriter.create<TransferReadOp>(
+        loc, VectorType::get(srcType.getShape(), srcType.getElementType()),
+        readOp.source(), readOp.indices(),
+        AffineMap::getMultiDimIdentityMap(srcRank, rewriter.getContext()),
+        readOp.padding(), rewriter.getBoolArrayAttr(newInbounds));
+
+    Value result = rewriter.create<vector::BroadcastOp>(
+        loc, VectorType::get(targetBroadcastedShape, srcType.getElementType()),
+        read);
+
+    // Insert TransposeOp if necessary.
+    if (!llvm::is_sorted(permutation))
+      result = rewriter.create<vector::TransposeOp>(loc, result, permutation);
+
+    rewriter.replaceOp(readOp, result);
+    return success();
+  }
+};
+
 } // namespace
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<FoldExtractSliceIntoTransferRead>(context);
+  results.add<FoldExtractSliceIntoTransferRead, SplitTransferRead>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1026,3 +1026,84 @@
   %1 = tensor.insert_slice %0 into %t1[4, 3, %s] [1, 5, 6] [1, 1, 1] : tensor<5x6xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @split_transfer_read_no_transpose
+func @split_transfer_read_no_transpose(%t: tensor<4xf32>) -> vector<2x4xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %vread = vector.transfer_read %t[%c0], %cst {
+      in_bounds = [true, true],
+      permutation_map = affine_map<(d0) -> (0, d0)>
+  } : tensor<4xf32>, vector<2x4xf32>
+  return %vread : vector<2x4xf32>
+}
+// CHECK-SAME: %[[IN:.*]]: tensor<4xf32>) -> vector<2x4xf32>
+// CHECK-NEXT: %[[C0_F32:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+
+// CHECK-NEXT: %[[READ:.*]] = vector.transfer_read %[[IN]][%[[C0]]], %[[C0_F32]]
+// CHECK-SAME: {in_bounds = [true]} : tensor<4xf32>, vector<4xf32>
+
+// CHECK-NEXT: %[[BCAST:.*]] = vector.broadcast %[[READ]]
+// CHECK-SAME: : vector<4xf32> to vector<2x4xf32>
+
+// CHECK-NEXT: return %[[BCAST]] : vector<2x4xf32>
+
+// -----
+
+// CHECK-LABEL: func @split_transfer_read_with_transpose
+func @split_transfer_read_with_transpose(%t: tensor<4xf32>)
+    -> vector<4x8xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %vread = vector.transfer_read %t[%c0], %cst {
+      in_bounds = [true, true],
+      permutation_map = affine_map<(d0) -> (d0, 0)>
+  } : tensor<4xf32>, vector<4x8xf32>
+  return %vread : vector<4x8xf32>
+}
+// CHECK-SAME: %[[IN:.*]]: tensor<4xf32>) -> vector<4x8xf32>
+// CHECK-NEXT: %[[C0_F32:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+
+// CHECK-NEXT: %[[READ:.*]] = vector.transfer_read %[[IN]][%[[C0]]], %[[C0_F32]]
+// CHECK-SAME: {in_bounds = [true]} : tensor<4xf32>, vector<4xf32>
+
+// CHECK-NEXT: %[[BCAST:.*]] = vector.broadcast %[[READ]]
+// CHECK-SAME: : vector<4xf32> to vector<8x4xf32>
+
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [1, 0]
+// CHECK-SAME: : vector<8x4xf32> to vector<4x8xf32>
+
+// CHECK-NEXT: return %[[TRANSPOSE]] : vector<4x8xf32>
+
+// -----
+
+// CHECK-LABEL: func @split_transfer_read_with_transpose_4D
+func @split_transfer_read_with_transpose_4D(%t: tensor<4x5xf32>)
+    -> vector<8x5x4x2xf32> {
+  %c0 = constant 0 : index
+  %cst = constant 0.0 : f32
+  %vread = vector.transfer_read %t[%c0, %c0], %cst {
+      in_bounds = [true, true, true, true],
+      permutation_map = affine_map<(d0, d1) -> (0, d1, d0, 0)>
+  } : tensor<4x5xf32>, vector<8x5x4x2xf32>
+  return %vread : vector<8x5x4x2xf32>
+}
+// CHECK-SAME: %[[IN:.*]]: tensor<4x5xf32>) -> vector<8x5x4x2xf32>
+// CHECK-NEXT: %[[C0_F32:.*]] = constant 0.000000e+00 : f32
+// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
+
+// CHECK-NEXT: %[[READ:.*]] = vector.transfer_read
+// CHECK-SAME: %[[IN]][%[[C0]], %[[C0]]], %[[C0_F32]]
+// CHECK-SAME: {in_bounds = [true, true]} : tensor<4x5xf32>, vector<4x5xf32>
+
+// CHECK-NEXT: %[[BCAST:.*]] = vector.broadcast %[[READ]]
+// CHECK-SAME: : vector<4x5xf32> to vector<8x2x4x5xf32>
+
+// CHECK-NEXT: %[[TRANSPOSE:.*]] = vector.transpose %[[BCAST]], [0, 3, 2, 1]
+// CHECK-SAME: : vector<8x2x4x5xf32> to vector<8x5x4x2xf32>
+
+// CHECK-NEXT: return %[[TRANSPOSE]] : vector<8x5x4x2xf32>
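Editor's note: the shape/permutation bookkeeping in `matchAndRewrite` can be traced by hand. Below is a minimal standalone C++ sketch (not part of the patch, no MLIR dependency) that replicates that bookkeeping for the 4-D test case; the map results are hard-coded, with `-1` standing in for a broadcast (constant 0) result.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Source tensor<4x5xf32>, result vector<8x5x4x2xf32>,
  // permutation_map (d0, d1) -> (0, d1, d0, 0); -1 marks a broadcast dim.
  std::vector<int64_t> srcShape = {4, 5};
  std::vector<int64_t> vecShape = {8, 5, 4, 2};
  std::vector<int64_t> mapResults = {-1, 1, 0, -1};

  int64_t srcRank = srcShape.size(), vecRank = vecShape.size();
  std::vector<int64_t> broadcastShape(vecRank, 0), permutation(vecRank, 0);
  size_t zeroCount = 0;
  for (size_t index = 0; index < mapResults.size(); ++index) {
    if (mapResults[index] < 0) {
      // Broadcast dim: its size goes to the front of the broadcast shape.
      permutation[zeroCount] = index;
      broadcastShape[zeroCount++] = vecShape[index];
      continue;
    }
    // Transferred dim: keeps source order, placed after the broadcast dims.
    int64_t dimPos = mapResults[index];
    permutation[srcRank + dimPos] = index;
    broadcastShape[srcRank + dimPos] = srcShape[dimPos];
  }

  // Prints: broadcastShape = 8 2 4 5, permutation = 0 3 2 1,
  // i.e. broadcast to vector<8x2x4x5xf32>, then transpose with [0, 3, 2, 1].
  printf("broadcastShape:");
  for (int64_t d : broadcastShape) printf(" %lld", (long long)d);
  printf("\npermutation:");
  for (int64_t p : permutation) printf(" %lld", (long long)p);
  printf("\nneeds transpose: %s\n",
         std::is_sorted(permutation.begin(), permutation.end()) ? "no" : "yes");
  return 0;
}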