diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp
@@ -33,6 +33,19 @@
   return builder.getBoolArrayAttr(newInBoundsValues);
 }
 
+/// Extend the rank of a vector Value by `addedRank` by adding outer unit
+/// dimensions, materialized as a vector.broadcast.
+static Value extendVectorRank(OpBuilder &builder, Location loc, Value vec,
+                              int64_t addedRank) {
+  auto originalVecType = vec.getType().cast<VectorType>();
+  SmallVector<int64_t> newShape(addedRank, 1);
+  newShape.append(originalVecType.getShape().begin(),
+                  originalVecType.getShape().end());
+  VectorType newVecType =
+      VectorType::get(newShape, originalVecType.getElementType());
+  return builder.create<vector::BroadcastOp>(loc, newVecType, vec);
+}
+
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identiy +
 /// vector.transpose op.
@@ -170,6 +183,82 @@
   }
 };
 
+/// Convert a transfer_write op with a map which isn't the permutation of a
+/// minor identity into a vector.broadcast + transfer_write with permutation of
+/// minor identity map by adding unit dim on inner dimension. Ex:
+/// ```
+///   vector.transfer_write %v
+///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} :
+///     vector<8x16xf32>
+/// ```
+/// into:
+/// ```
+///   %v1 = vector.broadcast %v : vector<8x16xf32> to vector<1x8x16xf32>
+///   vector.transfer_write %v1
+///     {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1, d2)>} :
+///     vector<1x8x16xf32>
+/// ```
+struct TransferWriteNonPermutationLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op.getTransferRank() == 0)
+      return failure();
+    // TODO: support masked writes. The mask would need the same unit dims
+    // inserted as the vector; bail out instead of building an op whose mask
+    // type no longer matches its (higher-rank) vector type.
+    if (op.getMask())
+      return failure();
+    SmallVector<unsigned> permutation;
+    AffineMap map = op.getPermutationMap();
+    if (map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
+      return failure();
+
+    // Missing outer dimensions are allowed, find the most outer existing
+    // dimension then deduce the missing inner dimensions.
+    SmallVector<bool> foundDim(map.getNumDims(), false);
+    for (AffineExpr exp : map.getResults()) {
+      foundDim[exp.cast<AffineDimExpr>().getPosition()] = true;
+    }
+    SmallVector<AffineExpr> exprs;
+    bool foundFirstDim = false;
+    SmallVector<int64_t> missingInnerDim;
+    for (size_t i = 0; i < foundDim.size(); i++) {
+      if (foundDim[i]) {
+        foundFirstDim = true;
+        continue;
+      }
+      if (!foundFirstDim)
+        continue;
+      // Once we found one outer dimension existing in the map keep track of all
+      // the missing dimensions after that.
+      missingInnerDim.push_back(i);
+      exprs.push_back(rewriter.getAffineDimExpr(i));
+    }
+    // Add unit dims at the beginning of the shape.
+    Value newVec = extendVectorRank(rewriter, op.getLoc(), op.getVector(),
+                                    missingInnerDim.size());
+    exprs.append(map.getResults().begin(), map.getResults().end());
+    AffineMap newMap =
+        AffineMap::get(map.getNumDims(), 0, exprs, op.getContext());
+    ArrayAttr newInBoundsAttr;
+    if (op.getInBounds()) {
+      // All the new dimensions added are inbound.
+      SmallVector<bool> newInBoundsValues(missingInnerDim.size(), true);
+      for (Attribute attr : op.getInBounds().value().getValue()) {
+        newInBoundsValues.push_back(attr.cast<BoolAttr>().getValue());
+      }
+      newInBoundsAttr = rewriter.getBoolArrayAttr(newInBoundsValues);
+    }
+    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
+        op, newVec, op.getSource(), op.getIndices(), AffineMapAttr::get(newMap),
+        op.getMask(), newInBoundsAttr);
+    return success();
+  }
+};
+
 /// Lower transfer_read op with broadcast in the leading dimensions into
 /// transfer_read of lower rank + vector.broadcast.
 /// Ex: vector.transfer_read ...
@@ -250,7 +339,8 @@
 
 void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<TransferReadPermutationLowering,
-               TransferWritePermutationLowering, TransferOpReduceRank>(
-      patterns.getContext(), benefit);
+  patterns
+      .add<TransferReadPermutationLowering, TransferWritePermutationLowering,
+           TransferOpReduceRank, TransferWriteNonPermutationLowering>(
+          patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -149,34 +149,32 @@
 
 // CHECK-LABEL:func @materialize_write(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 func.func @materialize_write(%M: index, %N: index, %O: index, %P: index) {
-  // CHECK-DAG:  %{{.*}} = arith.constant dense<1.000000e+00> : vector<5x4x3xf32>
+  // CHECK-DAG:  %{{.*}} = arith.constant dense<1.000000e+00> : vector<3x4x1x5xf32>
   // CHECK-DAG:  %[[C0:.*]] = arith.constant 0 : index
   // CHECK-DAG:  %[[C1:.*]] = arith.constant 1 : index
   // CHECK-DAG:  %[[C3:.*]] = arith.constant 3 : index
   // CHECK-DAG:  %[[C4:.*]] = arith.constant 4 : index
-  // CHECK-DAG:  %[[C5:.*]] = arith.constant 5 : index
   // CHECK:      %{{.*}} = memref.alloc(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?x?x?xf32>
   // CHECK-NEXT:  affine.for %[[I0:.*]] = 0 to %{{.*}} step 3 {
   // CHECK-NEXT:    affine.for %[[I1:.*]] = 0 to %{{.*}} step 4 {
   // CHECK-NEXT:      affine.for %[[I2:.*]] = 0 to %{{.*}} {
   // CHECK-NEXT:        affine.for %[[I3:.*]] = 0 to %{{.*}} step 5 {
-  // CHECK:               %[[ALLOC:.*]] = memref.alloca() : memref<vector<5x4x3xf32>>
-  // CHECK:               memref.store %{{.*}}, %[[ALLOC]][] : memref<vector<5x4x3xf32>>
-  // CHECK:               %[[VECTOR_VIEW1:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<5x4x3xf32>> to memref<5xvector<4x3xf32>>
-  // CHECK:               scf.for %[[I4:.*]] = %[[C0]] to %[[C5]] step %[[C1]] {
+  // CHECK:               %[[ALLOC:.*]] = memref.alloca() : memref<vector<3x4x1x5xf32>>
+  // CHECK:               memref.store %{{.*}}, %[[ALLOC]][] : memref<vector<3x4x1x5xf32>>
+  // CHECK:               %[[VECTOR_VIEW1:.*]] = vector.type_cast %[[ALLOC]] : memref<vector<3x4x1x5xf32>> to memref<3xvector<4x1x5xf32>>
+  // CHECK:               scf.for %[[I4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
   // CHECK:                 scf.if
-  // CHECK:                   %[[S3:.*]] = affine.apply #[[$ADD]](%[[I3]], %[[I4]])
-  // CHECK:                   %[[VECTOR_VIEW2:.*]] = vector.type_cast %[[VECTOR_VIEW1]] : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
+  // CHECK:                   %[[S3:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I4]])
+  // CHECK:                   %[[VECTOR_VIEW2:.*]] = vector.type_cast %[[VECTOR_VIEW1]] : memref<3xvector<4x1x5xf32>> to memref<3x4xvector<1x5xf32>>
   // CHECK:                   scf.for %[[I5:.*]] = %[[C0]] to %[[C4]] step %[[C1]] {
   // CHECK:                     scf.if
   // CHECK:                       %[[S1:.*]] = affine.apply #[[$ADD]](%[[I1]], %[[I5]])
-  // CHECK:                       %[[VEC:.*]] = memref.load %[[VECTOR_VIEW2]][%[[I4]], %[[I5]]] : memref<5x4xvector<3xf32>>
-  // CHECK:                       scf.for %[[I6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] {
-  // CHECK:                         %[[S0:.*]] = affine.apply #[[$ADD]](%[[I0]], %[[I6]])
+  // CHECK:                       %[[VECTOR_VIEW3:.*]] = vector.type_cast %[[VECTOR_VIEW2]] : memref<3x4xvector<1x5xf32>> to memref<3x4x1xvector<5xf32>>
+  // CHECK:                       scf.for %[[I6:.*]] = %[[C0]] to %[[C1]] step %[[C1]] {
   // CHECK:                         scf.if
-  // CHECK:                           %[[SCAL:.*]] = vector.extractelement %[[VEC]][%[[I6]] : index] : vector<3xf32>
-  // CHECK:                           memref.store %[[SCAL]], {{.*}}[%[[S0]], %[[S1]], %[[I2]], %[[S3]]] : memref<?x?x?x?xf32>
-  // CHECK:                         }
+  // CHECK:                           %[[S0:.*]] = affine.apply #[[$ADD]](%[[I2]], %[[I6]])
+  // CHECK:                           %[[VEC:.*]] = memref.load %[[VECTOR_VIEW3]][%[[I4]], %[[I5]], %[[I6]]] : memref<3x4x1xvector<5xf32>>
+  // CHECK:                           vector.transfer_write %[[VEC]], %{{.*}}[%[[S3]], %[[S1]], %[[S0]], %[[I3]]] : vector<5xf32>, memref<?x?x?x?xf32>
   // CHECK:                       }
   // CHECK:                     }
   // CHECK:                   }
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -178,14 +178,12 @@
 // CHECK-SAME:      %[[IDX:.*]]: index) -> vector<4xf32> {
 // CHECK-NEXT:      %[[CF0:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK-NEXT:      %[[RES:.*]] = vector.transfer_read %[[MEM]][%[[IDX]], %[[IDX]]], %[[CF0]] {in_bounds = [true], permutation_map = #{{.*}}} : memref<8x8xf32>, vector<4xf32>
-// CHECK-NEXT:      vector.transfer_write %[[RES]], %[[MEM]][%[[IDX]], %[[IDX]]] {in_bounds = [true], permutation_map = #{{.*}}} : vector<4xf32>, memref<8x8xf32>
 // CHECK-NEXT:      return %[[RES]] : vector<4xf32>
 // CHECK-NEXT:    }
 
 func.func @transfer_perm_map(%mem : memref<8x8xf32>, %i : index) -> vector<4xf32> {
   %cf0 = arith.constant 0.0 : f32
   %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : memref<8x8xf32>, vector<4xf32>
-  vector.transfer_write %res, %mem[%i, %i] {in_bounds = [true], permutation_map = affine_map<(d0, d1) -> (d0)>} : vector<4xf32>, memref<8x8xf32>
   return %res : vector<4xf32>
 }
 
@@ -349,3 +347,43 @@
 
   return %0 : tensor<?x?x?x?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @transfer_write_broadcast_unit_dim
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+//  CHECK-SAME:   %[[ARG2:.*]]: vector<14x8x16xf32>
+//  CHECK-SAME:   %[[ARG3:.*]]: vector<8x16xf32>
+//  CHECK-SAME:   %[[M:.*]]: i1
+func.func @transfer_write_broadcast_unit_dim(
+    %arg0 : memref<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
+    %v1 : vector<14x8x16xf32>, %v2 : vector<8x16xf32>, %m: i1) -> tensor<?x?x?x?xf32> {
+  // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+
+  %0 = vector.transfer_write %v1, %arg1[%c0, %c0, %c0, %c0] {in_bounds = [false, false, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>} : vector<14x8x16xf32>, tensor<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC0:.*]] = vector.broadcast %{{.*}} : vector<14x8x16xf32> to vector<1x14x8x16xf32>
+  // CHECK: %[[NEW_VEC1:.*]] = vector.transpose %[[NEW_VEC0]], [1, 2, 0, 3] : vector<1x14x8x16xf32> to vector<14x8x1x16xf32>
+  // CHECK: %[[NEW_RES0:.*]] = vector.transfer_write %[[NEW_VEC1]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [false, false, true, true]} : vector<14x8x1x16xf32>, tensor<?x?x?x?xf32>
+
+  vector.transfer_write %v2, %arg0[%c0, %c0, %c0, %c0] {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+  // CHECK: %[[NEW_VEC2:.*]] = vector.broadcast %{{.*}} : vector<8x16xf32> to vector<1x8x16xf32>
+  // CHECK: %[[NEW_VEC3:.*]] = vector.transpose %[[NEW_VEC2]], [1, 2, 0] : vector<1x8x16xf32> to vector<8x16x1xf32>
+  // CHECK: vector.transfer_write %[[NEW_VEC3]], %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] : vector<8x16x1xf32>, memref<?x?x?x?xf32>
+
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// -----
+
+// Masked writes are not rewritten by TransferWriteNonPermutationLowering.
+// CHECK-LABEL: func @transfer_write_non_permutation_masked
+//   CHECK-NOT:   vector.broadcast
+//       CHECK:   vector.transfer_write {{.*}} {permutation_map = #{{.*}}} : vector<8x16xf32>, memref<?x?x?x?xf32>
+func.func @transfer_write_non_permutation_masked(
+    %arg0 : memref<?x?x?x?xf32>, %v : vector<8x16xf32>, %mask : vector<8x16xi1>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %v, %arg0[%c0, %c0, %c0, %c0], %mask {permutation_map = affine_map<(d0, d1, d2, d3) -> (d1, d2)>} : vector<8x16xf32>, memref<?x?x?x?xf32>
+  return
+}