diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td
--- a/mlir/include/mlir/Interfaces/VectorInterfaces.td
+++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td
@@ -243,6 +243,36 @@
         fun(resultIdx, indicesIdx);
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return an upper-bound shape accessed by the transfer op within the
+        tensor/memref operand.
+        For example:
+        ```
+          vector.transfer %w0[%i, %j] {
+            permutation_map = affine_map<(d0, d1) -> (d1, d0, 0)>} :
+            tensor<?x?xf32>, vector<4x2x6xf32>
+        ```
+        returns a shape [2, 4].
+      }],
+      /*retTy=*/"SmallVector<int64_t>",
+      /*methodName=*/"getTransferChunkAccessed",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        SmallVector<int64_t> dimSizes($_op.getPermutationMap().getNumDims(), 0);
+        for (auto vecDims : llvm::zip($_op.getPermutationMap().getResults(),
+                                      $_op.getVectorType().getShape())) {
+          AffineExpr dim = std::get<0>(vecDims);
+          int64_t size = std::get<1>(vecDims);
+          // Skip broadcast.
+          if (dim.isa<AffineConstantExpr>())
+            continue;
+          dimSizes[dim.cast<AffineDimExpr>().getPosition()] = size;
+        }
+        return dimSizes;
+      }]
+    >,
   ];
 }
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3284,11 +3284,91 @@
     return success();
   }
 };
+
+/// Store-to-load forwarding for transfer operations with permutation maps.
+/// Even if the permutation maps are different, the store can still be
+/// forwarded to the load if the sizes of the dimensions read and written
+/// match. The transfer_read is then replaced by a vector.broadcast and a
+/// vector.transpose of the stored vector.
+/// Example:
+/// ```
+/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
+///   {in_bounds = [true, true],
+///    permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
+///   vector<4x1xf32>, tensor<4x4x4xf32>
+/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
+///   {in_bounds = [true, true, true, true],
+///    permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
+///   tensor<4x4x4xf32>, vector<1x100x4x5xf32>
+/// ```
+/// To:
+/// ```
+/// %0 = vector.broadcast %v0 : vector<4x1xf32> to vector<100x5x4x1xf32>
+/// %r = vector.transpose %0, [3, 0, 2, 1] :
+///   vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+/// ```
+struct TransferReadAfterWriteToBroadcast
+    : public OpRewritePattern<TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    if (readOp.hasOutOfBoundsDim() ||
+        !readOp.getShapedType().isa<RankedTensorType>())
+      return failure();
+    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
+    if (!defWrite)
+      return failure();
+
+    SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
+    Value vec;
+    if (readOp.getIndices() == defWrite.getIndices() &&
+        readOp.getMask() == defWrite.getMask()) {
+      SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
+      // TODO: If writeDims is a superset of readDims we could do an
+      // extract_strided_slice.
+      if (writeDims == readDims)
+        vec = defWrite.getVector();
+    }
+    // TODO: Loop through the chain of transfer_write ops if we can prove that
+    // they don't overlap with the transfer_read. This requires improving the
+    // `isDisjointTransferIndices` helper.
+    if (!vec)
+      return failure();
+    SmallVector<unsigned> permutation;
+    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
+    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
+    AffineMap map = readMap.compose(writeMap);
+    if (map.getNumResults() == 0)
+      return failure();
+    // Calculate the permutation to apply to go from the vector stored to the
+    // vector read.
+    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    Location loc = readOp.getLoc();
+    // Calculate the broadcast shape by applying the reverse permutation to the
+    // final shape we want.
+    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
+    SmallVector<int64_t> broadcastShape(destShape.size());
+    for (const auto &pos : llvm::enumerate(permutation))
+      broadcastShape[pos.value()] = destShape[pos.index()];
+    VectorType broadcastedType = VectorType::get(
+        broadcastShape, defWrite.getVectorType().getElementType());
+    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
+    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
+    rewriter.replaceOpWithNewOp<TransposeOp>(readOp, vec, transposePerm);
+    return success();
+  }
+};
 } // namespace
 
 void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
-  results.add<FoldExtractSliceIntoTransferRead>(context);
+  results
+      .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1058,6 +1058,45 @@
 
 // -----
 
+// CHECK-LABEL: func @store_to_load_tensor_broadcast
+// CHECK-SAME:  (%[[ARG:.*]]: tensor<4x4xf32>, %[[V0:.*]]: vector<4x2xf32>)
+// CHECK:       %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x2xf32> to vector<6x4x2xf32>
+// CHECK:       %[[T:.*]] = vector.transpose %[[B]], [1, 2, 0] : vector<6x4x2xf32> to vector<4x2x6xf32>
+// CHECK:       return %[[T]] : vector<4x2x6xf32>
+func.func @store_to_load_tensor_broadcast(%arg0 : tensor<4x4xf32>,
+  %v0 : vector<4x2xf32>) -> vector<4x2x6xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0] {in_bounds = [true, true]} :
+    vector<4x2xf32>, tensor<4x4xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0], %cf0 {in_bounds = [true, true, true],
+    permutation_map = affine_map<(d0, d1) -> (d0, d1, 0)>} :
+    tensor<4x4xf32>, vector<4x2x6xf32>
+  return %0 : vector<4x2x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @store_to_load_tensor_perm_broadcast
+// CHECK-SAME:  (%[[ARG:.*]]: tensor<4x4x4xf32>, %[[V0:.*]]: vector<4x1xf32>)
+// CHECK:       %[[B:.*]] = vector.broadcast %[[V0]] : vector<4x1xf32> to vector<100x5x4x1xf32>
+// CHECK:       %[[T:.*]] = vector.transpose %[[B]], [3, 0, 2, 1] : vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
+// CHECK:       return %[[T]] : vector<1x100x4x5xf32>
+func.func @store_to_load_tensor_perm_broadcast(%arg0 : tensor<4x4x4xf32>,
+  %v0 : vector<4x1xf32>) -> vector<1x100x4x5xf32> {
+  %c0 = arith.constant 0 : index
+  %cf0 = arith.constant 0.0 : f32
+  %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] {in_bounds = [true, true],
+    permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
+    vector<4x1xf32>, tensor<4x4x4xf32>
+  %0 = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 {in_bounds = [true, true, true, true],
+    permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
+    tensor<4x4x4xf32>, vector<1x100x4x5xf32>
+  return %0 : vector<1x100x4x5xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @dead_store_tensor
 // CHECK-DAG:       %[[C0:.*]] = arith.constant 0 : index
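
Note (not part of the patch): the standalone sketch below mirrors, on plain C++ containers instead of MLIR's `AffineMap` and `VectorType`, the per-dimension chunk computation that the new `getTransferChunkAccessed` default implementation performs. The helper name `chunkAccessed` and the "-1 means broadcast" encoding are illustrative assumptions, not MLIR API.

```cpp
// Minimal sketch of the chunk-shape computation behind getTransferChunkAccessed.
// A permutation map result is modeled as an int: a non-negative value names the
// tensor/memref dimension it reads, and -1 stands for a broadcast (constant 0)
// result that reads nothing.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> chunkAccessed(unsigned numDims,
                                   const std::vector<int> &mapResults,
                                   const std::vector<int64_t> &vectorShape) {
  assert(mapResults.size() == vectorShape.size());
  std::vector<int64_t> dimSizes(numDims, 0);
  for (size_t i = 0; i < mapResults.size(); ++i) {
    // Broadcast results do not access the tensor/memref; skip them.
    if (mapResults[i] < 0)
      continue;
    // The accessed extent of dimension mapResults[i] is the vector dim size.
    dimSizes[mapResults[i]] = vectorShape[i];
  }
  return dimSizes;
}

int main() {
  // Mirrors the interface documentation example:
  //   permutation_map = affine_map<(d0, d1) -> (d1, d0, 0)>, vector<4x2x6xf32>
  // Result 0 reads d1 with extent 4, result 1 reads d0 with extent 2, result 2
  // is a broadcast, so the accessed chunk is [2, 4].
  std::vector<int64_t> chunk =
      chunkAccessed(/*numDims=*/2, /*mapResults=*/{1, 0, -1},
                    /*vectorShape=*/{4, 2, 6});
  assert(chunk == (std::vector<int64_t>{2, 4}));
  return 0;
}
```

The forwarding pattern added in VectorOps.cpp relies on exactly this per-dimension chunk shape: when the indices and masks match, the read is forwarded only if `writeDims == readDims`, i.e. the write covers precisely the chunk the read accesses.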