diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp @@ -724,35 +724,39 @@ /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32> /// ``` /// -/// Note: A pass label is attached to new TransferReadOps, so that subsequent -/// applications of this pattern do not create an additional %v_init vector. +/// Note: As an optimization, if the result of the original TransferReadOp +/// was directly inserted into another vector, no new %v_init vector is created. +/// Instead, the new TransferReadOp results are inserted into that vector. struct UnrollTransferReadConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - /// Find the result vector %v_init or create a new vector if this the first - /// application of the pattern. + /// Return the vector into which the newly created TransferReadOp results + /// are inserted. Value getResultVector(TransferReadOp xferOp, PatternRewriter &rewriter) const { - if (xferOp->hasAttr(kPassLabel)) { - return getInsertOp(xferOp).dest(); - } + if (auto insertOp = getInsertOp(xferOp)) + return insertOp.dest(); return std_splat(xferOp.getVectorType(), xferOp.padding()).value; } - /// Assuming that this not the first application of the pattern, return the - /// vector.insert op in which the result of this transfer op is used. + /// If the result of the TransferReadOp has exactly one user, which is a + /// vector::InsertOp with that result as its source, return the InsertOp.
vector::InsertOp getInsertOp(TransferReadOp xferOp) const { - Operation *xferOpUser = *xferOp->getUsers().begin(); - return dyn_cast(xferOpUser); + if (xferOp->hasOneUse()) { + Operation *xferOpUser = *xferOp->getUsers().begin(); + auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser); + if (insertOp && insertOp.source() == xferOp.vector()) + return insertOp; + } + return vector::InsertOp(); } - /// Assuming that this not the first application of the pattern, return the - /// indices of the vector.insert op in which the result of this transfer op - /// is used. + /// If the result of the TransferReadOp has exactly one user, which is a + /// vector::InsertOp with that result as its source, return its indices. void getInsertionIndices(TransferReadOp xferOp, SmallVector &indices) const { - if (xferOp->hasAttr(kPassLabel)) { - llvm::for_each(getInsertOp(xferOp).position(), [&](Attribute attr) { + if (auto insertOp = getInsertOp(xferOp)) { + llvm::for_each(insertOp.position(), [&](Attribute attr) { indices.push_back(attr.dyn_cast().getInt()); }); } @@ -766,6 +770,7 @@ return failure(); ScopedContext scope(rewriter, xferOp.getLoc()); + auto insertOp = getInsertOp(xferOp); auto vec = getResultVector(xferOp, rewriter); auto vecType = vec.getType().dyn_cast(); auto xferVecType = xferOp.getVectorType(); @@ -803,7 +808,6 @@ dyn_cast(newXferOpVal.getDefiningOp()); maybeAssignMask(b, xferOp, newXferOp, i); - maybeApplyPassLabel(b, newXferOp); return vector_insert(newXferOp, vec, insertionIndices).value; }, @@ -814,8 +818,9 @@ }); } - if (xferOp->hasAttr(kPassLabel)) { - rewriter.replaceOp(getInsertOp(xferOp), vec); + if (insertOp) { + // Rewrite single user of the old TransferReadOp, which was an InsertOp.
+ rewriter.replaceOp(insertOp, vec); rewriter.eraseOp(xferOp); } else { rewriter.replaceOp(xferOp, vec); @@ -846,32 +851,33 @@ /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> /// ``` /// -/// Note: A pass label is attached to new TransferWriteOps, so that subsequent -/// applications of this pattern can read the indices of previously generated -/// vector.extract ops. +/// Note: As an optimization, if the vector of the original TransferWriteOp +/// was directly extracted from another vector via an ExtractOp `a`, extract +/// the vectors for the newly generated TransferWriteOps from `a`'s input. By +/// doing so, `a` may become dead, and the number of ExtractOps generated during +/// recursive application of this pattern will be minimal. struct UnrollTransferWriteConversion : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - /// If this is not the first application of the pattern, find the original - /// vector %vec that is written by this transfer op. Otherwise, return the - /// vector of this transfer op. + /// Return the vector from which newly generated ExtractOps will extract. Value getDataVector(TransferWriteOp xferOp) const { - if (xferOp->hasAttr(kPassLabel)) - return getExtractOp(xferOp).vector(); + if (auto extractOp = getExtractOp(xferOp)) + return extractOp.vector(); return xferOp.vector(); } - /// Assuming that this is not the first application of the pattern, find the - /// vector.extract op whose result is written by this transfer op. + /// Return the ExtractOp defining the TransferWriteOp's input, or null. vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { - return dyn_cast(xferOp.vector().getDefiningOp()); + return xferOp.vector().getDefiningOp<vector::ExtractOp>(); } + /// If the input of the given TransferWriteOp is an ExtractOp, return its + /// indices.
void getExtractionIndices(TransferWriteOp xferOp, SmallVector &indices) const { - if (xferOp->hasAttr(kPassLabel)) { - llvm::for_each(getExtractOp(xferOp).position(), [&](Attribute attr) { + if (auto extractOp = getExtractOp(xferOp)) { + llvm::for_each(extractOp.position(), [&](Attribute attr) { indices.push_back(attr.dyn_cast().getInt()); }); } @@ -918,7 +924,6 @@ .op; maybeAssignMask(b, xferOp, newXferOp, i); - maybeApplyPassLabel(b, newXferOp); }); } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir --- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir @@ -4,7 +4,7 @@ // RUN: FileCheck %s // RUN: mlir-opt %s -test-unrolled-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \ -// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s @@ -17,6 +17,17 @@ return } +func @transfer_read_3d_and_extract(%A : memref, + %o: index, %a: index, %b: index, %c: index) { + %fm42 = constant -42.0: f32 + %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42 + {in_bounds = [true, true, true]} + : memref, vector<2x5x3xf32> + %sub = vector.extract %f[0] : vector<2x5x3xf32> + vector.print %sub: vector<5x3xf32> + return +} + func @transfer_read_3d_broadcast(%A : memref, %o: index, %a: index, %b: index, %c: index) { %fm42 = constant -42.0: f32 @@ -94,26 +105,31 @@ : (memref, index, index, index, index) -> () // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) ) - // 2. Write 3D vector to 4D memref. + // 2.
Read 3D vector from 4D memref and extract subvector from result. + call @transfer_read_3d_and_extract(%A, %c0, %c0, %c0, %c0) + : (memref, index, index, index, index) -> () + // CHECK: ( ( 0, 0, 2 ), ( 2, 3, 4 ), ( 4, 6, 6 ), ( 6, 9, 20 ), ( 20, 30, 22 ) ) + + // 3. Write 3D vector to 4D memref. call @transfer_write_3d(%A, %c0, %c0, %c1, %c1) : (memref, index, index, index, index) -> () - // 3. Read memref to verify step 2. + // 4. Read memref to verify step 3. call @transfer_read_3d(%A, %c0, %c0, %c0, %c0) : (memref, index, index, index, index) -> () // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) ) - // 4. Read 3D vector from 4D memref and transpose vector. + // 5. Read 3D vector from 4D memref and transpose vector. call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0) : (memref, index, index, index, index) -> () // CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) ) - // 5. Read 1D vector from 4D memref and broadcast vector to 3D. + // 6. Read 1D vector from 4D memref and broadcast vector to 3D. call @transfer_read_3d_broadcast(%A, %c0, %c0, %c0, %c0) : (memref, index, index, index, index) -> () // CHECK: ( ( ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ), ( 0, 0, -42 ) ), ( ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ), ( 20, 30, -42 ) ) ) - // 6. Read 1D vector from 4D memref with mask and broadcast vector to 3D. + // 7. Read 1D vector from 4D memref with mask and broadcast vector to 3D.
call @transfer_read_3d_mask_broadcast(%A, %c0, %c0, %c0, %c0) : (memref, index, index, index, index) -> () // CHECK: ( ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ), ( ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ), ( 20, 20, 20 ) ) )