diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -63,6 +63,7 @@ std::vector<Operation *> opToErase; }; +} // namespace /// Return true if there is a path from start operation to dest operation, /// otherwise return false. The operations have to be in the same region. bool TransferOptimization::isReachable(Operation *start, Operation *dest) { @@ -288,14 +289,25 @@ return llvm::count_if(shape, [](int64_t dimSize) { return dimSize != 1; }); } +/// Returns a copy of `shape` without unit dims. +static SmallVector<int64_t> getReducedShape(ArrayRef<int64_t> shape) { + SmallVector<int64_t> reducedShape; + llvm::copy_if(shape, std::back_inserter(reducedShape), + [](int64_t dimSize) { return dimSize != 1; }); + return reducedShape; +} + /// Returns true if all values are `arith.constant 0 : index` static bool isZero(Value v) { auto cst = v.getDefiningOp<arith::ConstantIndexOp>(); return cst && cst.value() == 0; } -/// Rewrites vector.transfer_read ops where the source has unit dims, by -/// inserting a memref.subview dropping those unit dims. +namespace { + +/// Rewrites `vector.transfer_read` ops where the source has unit dims, by +/// inserting a memref.subview dropping those unit dims. The vector shapes are +/// also reduced accordingly. class TransferReadDropUnitDimsPattern : public OpRewritePattern<vector::TransferReadOp> { using OpRewritePattern::OpRewritePattern; @@ -317,12 +329,15 @@ return failure(); if (!transferReadOp.getPermutationMap().isMinorIdentity()) return failure(); + // Check if the source shape can be further reduced. int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) - return failure(); // The source shape can't be further reduced. 
- if (reducedRank != vectorType.getRank()) - return failure(); // This pattern requires the vector shape to match the - // reduced source shape. + return failure(); + // Check if the reduced vector shape matches the reduced source shape. + // Otherwise, this case is not supported yet. + int vectorReducedRank = getReducedRank(vectorType.getShape()); + if (reducedRank != vectorReducedRank) + return failure(); if (llvm::any_of(transferReadOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); @@ -331,14 +346,22 @@ Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); SmallVector<Value> zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); - rewriter.replaceOpWithNewOp<vector::TransferReadOp>( - transferReadOp, vectorType, reducedShapeSource, zeros, identityMap); + auto reducedVectorType = VectorType::get( + getReducedShape(vectorType.getShape()), vectorType.getElementType()); + + auto newTransferReadOp = rewriter.create<vector::TransferReadOp>( + loc, reducedVectorType, reducedShapeSource, zeros, identityMap); + auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>( + loc, vectorType, newTransferReadOp); + rewriter.replaceOp(transferReadOp, shapeCast); + return success(); } }; -/// Rewrites vector.transfer_write ops where the "source" (i.e. destination) has -/// unit dims, by inserting a memref.subview dropping those unit dims. +/// Rewrites `vector.transfer_write` ops where the "source" (i.e. destination) +/// has unit dims, by inserting a `memref.subview` dropping those unit dims. The +/// vector shapes are also reduced accordingly. class TransferWriteDropUnitDimsPattern : public OpRewritePattern<vector::TransferWriteOp> { using OpRewritePattern::OpRewritePattern; @@ -360,12 +383,15 @@ return failure(); if (!transferWriteOp.getPermutationMap().isMinorIdentity()) return failure(); + // Check if the destination shape can be further reduced. int reducedRank = getReducedRank(sourceType.getShape()); if (reducedRank == sourceType.getRank()) - return failure(); // The source shape can't be further reduced. 
- if (reducedRank != vectorType.getRank()) - return failure(); // This pattern requires the vector shape to match the - // reduced source shape. + return failure(); + // Check if the reduced vector shape matches the reduced destination shape. + // Otherwise, this case is not supported yet. + int vectorReducedRank = getReducedRank(vectorType.getShape()); + if (reducedRank != vectorReducedRank) + return failure(); if (llvm::any_of(transferWriteOp.getIndices(), [](Value v) { return !isZero(v); })) return failure(); @@ -374,12 +400,20 @@ Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); SmallVector<Value> zeros(reducedRank, c0); auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank); + VectorType reducedVectorType = VectorType::get( + getReducedShape(vectorType.getShape()), vectorType.getElementType()); + + auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>( + loc, reducedVectorType, vector); rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( - transferWriteOp, vector, reducedShapeSource, zeros, identityMap); + transferWriteOp, shapeCast, reducedShapeSource, zeros, identityMap); + return success(); } }; +} // namespace + /// Return true if the memref type has its inner dimension matching the given /// shape. Otherwise return false. static int64_t hasMatchingInnerContigousShape(MemRefType memrefType, @@ -439,6 +473,8 @@ return success(); } +namespace { + /// Rewrites contiguous row-major vector.transfer_read ops by inserting /// memref.collapse_shape on the source so that the resulting /// vector.transfer_read has a 1D source. 
Requires the source shape to be @@ -732,6 +768,7 @@ return success(); } }; + } // namespace void mlir::vector::transferOpflowOpt(RewriterBase &rewriter, diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -15,6 +15,14 @@ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_read %[[SUBVIEW]] +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + transform.vector.apply_rank_reducing_subview_patterns %module_op + : (!transform.any_op) -> !transform.any_op +} + +// ----- + func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, 2, 1], offset: ?>>, %vec : vector<3x2xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0] : @@ -28,6 +36,97 @@ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}> // CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]] +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + transform.vector.apply_rank_reducing_subview_patterns %module_op + : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @transfer_read_and_vector_rank_reducing( + %arg : memref<1x1x3x2x1xf32>) -> vector<3x2x1xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst : + memref<1x1x3x2x1xf32>, vector<3x2x1xf32> + return %v : vector<3x2x1xf32> +} + +// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32> +// CHECK: vector.transfer_read %[[SUBVIEW]]{{.*}} {in_bounds = 
[true, true]} : memref<3x2xf32>, vector<3x2xf32> + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + transform.vector.apply_rank_reducing_subview_patterns %module_op + : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @transfer_write_and_vector_rank_reducing( + %arg : memref<1x1x3x2x1xf32>, + %vec : vector<3x2x1xf32>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] : + vector<3x2x1xf32>, memref<1x1x3x2x1xf32> + return +} + +// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing +// CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2x1xf32> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0, 0] [1, 1, 3, 2, 1] [1, 1, 1, 1, 1] +// CHECK-SAME: memref<1x1x3x2x1xf32> to memref<3x2xf32> +// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]{{.*}} {in_bounds = [true, true]} : vector<3x2xf32>, memref<3x2xf32> + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + transform.vector.apply_rank_reducing_subview_patterns %module_op + : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @transfer_read_and_vector_rank_reducing_to_0d( + %arg : memref<1x1x1x1x1xf32>) -> vector<1x1x1xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant 0.0 : f32 + %v = vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst : + memref<1x1x1x1x1xf32>, vector<1x1x1xf32> + return %v : vector<1x1x1xf32> +} + +// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d +// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32> +// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32> +// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32> + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + transform.vector.apply_rank_reducing_subview_patterns 
%module_op + : (!transform.any_op) -> !transform.any_op +} + +// ----- + +func.func @transfer_write_and_vector_rank_reducing_to_0d( + %arg : memref<1x1x1x1x1xf32>, + %vec : vector<1x1x1xf32>) { + %c0 = arith.constant 0 : index + vector.transfer_write %vec, %arg [%c0, %c0, %c0, %c0, %c0] : + vector<1x1x1xf32>, memref<1x1x1x1x1xf32> + return +} + +// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d +// CHECK-SAME: %[[MEMREF:.+]]: memref<1x1x1x1x1xf32>, %[[VECTOR:.+]]: vector<1x1x1xf32> +// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[MEMREF]][0, 0, 0, 0, 0] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : memref<1x1x1x1x1xf32> to memref<f32> +// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32> +// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32> transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op):