diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -306,4 +306,16 @@ }]; } +def ApplyFoldTensorSliceIntoTransferPatternsOp : Op]> { + let description = [{ + Indicates that tensor.extract_slice -> vector.transfer_read and + vector.transfer_write -> tensor.insert_slice op chains should be folded into + vector transfer read and write ops. + }]; + + let assemblyFormat = "attr-dict"; +} + #endif // VECTOR_TRANSFORM_OPS diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -138,6 +138,11 @@ populateVectorToSCFConversionPatterns(patterns, vectorTransferToSCFOptions); } +void transform::ApplyFoldTensorSliceIntoTransferPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + populateVectorTransferTensorSliceTransforms(patterns); +} + //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-tensor-slice-transforms.mlir @@ -1,4 +1,11 @@ -// RUN: mlir-opt -split-input-file -test-vector-transfer-tensor-slice-patterns %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-transform-dialect-interpreter %s | FileCheck %s + +transform.sequence failures(propagate) { +^bb1(%module_op: !transform.any_op): + 
transform.apply_patterns to %module_op { + transform.apply_patterns.vector.fold_tensor_slice_into_transfer + } : !transform.any_op +} // CHECK-LABEL: func @transfer_read_of_extract_slice( // CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index @@ -16,16 +23,14 @@ return %1 : vector<5x6xf32> } -// ----- - -// CHECK-LABEL: func @transfer_read_of_extract_slice( +// CHECK-LABEL: func @transfer_read_of_extract_slice_1d( // CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index // CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index // CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index // CHECK: %[[add:.*]] = arith.addi %[[s1]], %[[c4]] // CHECK: %[[r:.*]] = vector.transfer_read %[[t]][%[[c8]], %[[add]]], %{{.*}} {in_bounds = [true]} : tensor, vector<6xf32> // CHECK: return %[[r]] -func.func @transfer_read_of_extract_slice(%t : tensor, %s1 : index, %s2 : index) -> vector<6xf32> { +func.func @transfer_read_of_extract_slice_1d(%t : tensor, %s1 : index, %s2 : index) -> vector<6xf32> { %c3 = arith.constant 3 : index %c4 = arith.constant 4 : index %cst = arith.constant 0.0 : f32 @@ -34,8 +39,6 @@ return %1 : vector<6xf32> } -// ----- - // CHECK-LABEL: func @transfer_read_of_extract_slice_rank_reducing( // CHECK-SAME: %[[t:.*]]: tensor, %[[s1:.*]]: index, %[[s2:.*]]: index // CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index @@ -53,8 +56,6 @@ return %1 : vector<5x6xf32> } -// ----- - // CHECK-LABEL: func @transfer_read_of_extract_slice_illegal_rank_reducing( // CHECK: extract_slice // CHECK: vector.transfer_read @@ -67,8 +68,6 @@ return %1 : vector<5x6xf32> } -// ----- - // CHECK-LABEL: func @insert_slice_of_transfer_write( // CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index // CHECK: %[[c3:.*]] = arith.constant 3 : index @@ -81,8 +80,6 @@ return %1 : tensor } -// ----- - // CHECK-LABEL: func @insert_slice_of_transfer_write_illegal_rank_extending( // CHECK: vector.transfer_write // CHECK: insert_slice @@ -93,8 +90,6 @@ return %1 
: tensor } -// ----- - // CHECK-LABEL: func @insert_slice_of_transfer_write_rank_extending( // CHECK-SAME: %[[t1:.*]]: tensor, %[[v:.*]]: vector<5x6xf32>, %[[s:.*]]: index // CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -690,26 +690,6 @@ } }; -struct TestVectorTransferTensorSlicePatterns - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( - TestVectorTransferTensorSlicePatterns) - - StringRef getArgument() const final { - return "test-vector-transfer-tensor-slice-patterns"; - } - StringRef getDescription() const final { - return "Test patterns that fold vector transfer and tensor slice ops"; - } - - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateVectorTransferTensorSliceTransforms(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - struct TestFoldArithExtensionIntoVectorContractPatterns : public PassWrapper> { @@ -771,8 +751,6 @@ PassRegistration(); - PassRegistration(); - PassRegistration(); } } // namespace test