diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -871,12 +871,12 @@
     if (!hasSameTensorSize(padOp.source(), trimPadding))
       return failure();
 
     SmallVector<bool> inBounds(xferOp.getVectorType().getRank(), false);
+    rewriter.setInsertionPoint(xferOp);
     auto newXferOp = rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
         xferOp, padOp.source().getType(), xferOp.vector(), padOp.source(),
         xferOp.indices(), xferOp.permutation_mapAttr(), xferOp.mask(),
         rewriter.getBoolArrayAttr(inBounds));
     rewriter.replaceOp(trimPadding, newXferOp->getResult(0));
-
     return success();
   }
@@ -1017,6 +1017,7 @@
     // Generate TransferReadOp: Read entire source tensor and add high padding.
     SmallVector<Value> readIndices(
         vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
+    rewriter.setInsertionPoint(insertOp);
     auto read = rewriter.create<vector::TransferReadOp>(
         padOp.getLoc(), vecType, padOp.source(), readIndices, padValue);
 
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -621,38 +621,44 @@
 
 // -----
 
+func private @make_vector() -> vector<7x9xf32>
+
 // CHECK-LABEL: func @pad_and_transfer_write_static
-//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: vector<7x9xf32>
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
 //   CHECK-NOT:   linalg.pad_tensor
 //       CHECK:   %[[C0:.*]] = constant 0 : index
-//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
+//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
+//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[ARG0]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<5x6xf32>
 //       CHECK:   return %[[RESULT]]
 func @pad_and_transfer_write_static(
-    %arg0: tensor<5x6xf32>, %arg1: vector<7x9xf32>) -> tensor<5x6xf32> {
+    %arg0: tensor<5x6xf32>) -> tensor<5x6xf32> {
   %c0 = constant 0 : index
   %c5 = constant 5.0 : f32
   %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
     ^bb0(%arg2: index, %arg3: index):
       linalg.yield %c5 : f32
   } : tensor<5x6xf32> to tensor<10x13xf32>
-  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+  %1 = call @make_vector() : () -> vector<7x9xf32>
+  %2 = vector.transfer_write %1, %0[%c0, %c0]
       : vector<7x9xf32>, tensor<10x13xf32>
-  %2 = tensor.extract_slice %1[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
-  return %2 : tensor<5x6xf32>
+  %3 = tensor.extract_slice %2[0, 0] [5, 6] [1, 1] : tensor<10x13xf32> to tensor<5x6xf32>
+  return %3 : tensor<5x6xf32>
 }
 
 // -----
 
+func private @make_vector() -> vector<7x9xf32>
+
 // CHECK-LABEL: func @pad_and_transfer_write_dynamic_static
-//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[ARG1:.*]]: vector<7x9xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<?x?xf32>, %[[SIZE:.*]]: index, %[[PADDING:.*]]: index
 //   CHECK-NOT:   linalg.pad_tensor
 //       CHECK:   %[[C0:.*]] = constant 0 : index
 //       CHECK:   %[[SUB:.*]] = tensor.extract_slice %[[ARG0]][0, 0] [%[[SIZE]], 6] [1, 1] : tensor<?x?xf32> to tensor<?x6xf32>
-//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[ARG1]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
+//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> vector<7x9xf32>
+//       CHECK:   %[[RESULT:.*]] = vector.transfer_write %[[VEC0]], %[[SUB]][%[[C0]], %[[C0]]] : vector<7x9xf32>, tensor<?x6xf32>
 //       CHECK:   return %[[RESULT]]
 func @pad_and_transfer_write_dynamic_static(
-    %arg0: tensor<?x?xf32>, %arg1: vector<7x9xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
+    %arg0: tensor<?x?xf32>, %size: index, %padding: index) -> tensor<?x6xf32> {
   %c0 = constant 0 : index
   %c5 = constant 5.0 : f32
   %s = tensor.extract_slice %arg0[0, 0] [%size, 6] [1, 1]
@@ -661,31 +667,36 @@
     ^bb0(%arg2: index, %arg3: index):
       linalg.yield %c5 : f32
   } : tensor<?x6xf32> to tensor<?x13xf32>
-  %1 = vector.transfer_write %arg1, %0[%c0, %c0]
+  %1 = call @make_vector() : () -> vector<7x9xf32>
+  %2 = vector.transfer_write %1, %0[%c0, %c0]
       : vector<7x9xf32>, tensor<?x13xf32>
-  %2 = tensor.extract_slice %1[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
-  return %2 : tensor<?x6xf32>
+  %3 = tensor.extract_slice %2[0, 0] [%size, 6] [1, 1] : tensor<?x13xf32> to tensor<?x6xf32>
+  return %3 : tensor<?x6xf32>
 }
 
 // -----
 
+func private @make_vector() -> tensor<12x13xf32>
+
 // CHECK-LABEL: func @pad_and_insert_slice
-//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
+//  CHECK-SAME:     %[[ARG0:.*]]: tensor<5x6xf32>
 //   CHECK-NOT:   linalg.pad_tensor
 //   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
 //   CHECK-DAG:   %[[C5:.*]] = constant 5.0
+//       CHECK:   %[[VEC0:.*]] = call @make_vector() : () -> tensor<12x13xf32>
 //       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
-//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
+//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[VEC0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
 //       CHECK:   return %[[WRITE]]
 func @pad_and_insert_slice(
-    %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
+    %arg0: tensor<5x6xf32>) -> tensor<12x13xf32> {
   %c0 = constant 0 : index
   %c5 = constant 5.0 : f32
   %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
     ^bb0(%arg2: index, %arg3: index):
       linalg.yield %c5 : f32
   } : tensor<5x6xf32> to tensor<7x9xf32>
-  %r = tensor.insert_slice %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
+  %1 = call @make_vector() : () -> tensor<12x13xf32>
+  %r = tensor.insert_slice %0 into %1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
   return %r : tensor<12x13xf32>
 }
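Note (reviewer sketch, not part of the patch): `replaceOpWithNewOp` and `create` materialize new ops at the rewriter's current insertion point, and these patterns are rooted at the matched `linalg.pad_tensor`, so without the added `setInsertionPoint` calls the replacement `vector.transfer_write` / `vector.transfer_read` would presumably be created next to the pad. The reworked tests make that observable by producing the write's operand with a `@make_vector` call that sits after the pad. Roughly the IR shape being exercised (names taken from the tests above):

  %0 = linalg.pad_tensor %arg0 low[0, 0] high[5, 7] {
    ...
  } : tensor<5x6xf32> to tensor<10x13xf32>
  // A rewrite anchored at %0 would create the new transfer_write here,
  // before %1 is defined, i.e. invalid use-before-def IR.
  %1 = call @make_vector() : () -> vector<7x9xf32>
  %2 = vector.transfer_write %1, %0[%c0, %c0]
      : vector<7x9xf32>, tensor<10x13xf32>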