diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3040,6 +3040,18 @@
   }
 };
 
+/// Transpose a vector transfer op's `in_bounds` attribute according to given
+/// indices: entry `i` of the result is `attr[permutation[i]]`.
+static ArrayAttr
+transposeInBoundsAttr(OpBuilder &builder, ArrayAttr attr,
+                      const SmallVector<unsigned> &permutation) {
+  SmallVector<bool> newInBoundsValues;
+  for (unsigned pos : permutation)
+    newInBoundsValues.push_back(
+        attr.getValue()[pos].cast<BoolAttr>().getValue());
+  return builder.getBoolArrayAttr(newInBoundsValues);
+}
+
 /// Lower transfer_read op with permutation into a transfer_read with a
 /// permutation map composed of leading zeros followed by a minor identiy +
 /// vector.transpose op.
@@ -3084,6 +3096,7 @@
       newVectorShape[pos.value()] = originalShape[pos.index()];
     }
 
+    // Transpose mask operand.
     Value newMask;
     if (op.mask()) {
       // Remove unused dims from the permutation map. E.g.:
@@ -3103,12 +3116,20 @@
                                                      maskTransposeIndices);
     }
 
+    // Transpose in_bounds attribute.
+    ArrayAttr newInBounds =
+        op.in_bounds() ? transposeInBoundsAttr(
+                             rewriter, op.in_bounds().getValue(), permutation)
+                       : ArrayAttr();
+
+    // Generate new transfer_read operation.
     VectorType newReadType =
         VectorType::get(newVectorShape, op.getVectorType().getElementType());
     Value newRead = rewriter.create<vector::TransferReadOp>(
         op.getLoc(), newReadType, op.source(), op.indices(), newMap,
-        op.padding(), newMask, op.in_bounds() ? *op.in_bounds() : ArrayAttr());
+        op.padding(), newMask, newInBounds);
 
+    // Transpose result of transfer_read.
     SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
     rewriter.replaceOpWithNewOp<vector::TransposeOp>(op, newRead,
                                                      transposePerm);
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -231,9 +231,9 @@
   %m = constant 1 : i1
 
   %mask0 = splat %m : vector<7x14xi1>
-  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+  %0 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask0 {in_bounds = [true, false, true, true], permutation_map = #map0} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK0:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
-// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
+// CHECK: vector.transfer_read {{.*}} %[[MASK0]] {in_bounds = [false, true, true, true], permutation_map = #[[$MAP0]]} : memref<?x?x?x?xf32>, vector<14x7x8x16xf32>
 // CHECK: vector.transpose %{{.*}}, [1, 0, 2, 3] : vector<14x7x8x16xf32> to vector<7x14x8x16xf32>
 
   %mask1 = splat %m : vector<14x16xi1>
@@ -243,9 +243,9 @@
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %mask2 = splat %m : vector<7x14xi1>
-  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, true, false, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
+  %2 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst, %mask2 {in_bounds = [true, false, true, true], permutation_map = #map2} : memref<?x?x?x?xf32>, vector<7x14x8x16xf32>
 // CHECK: %[[MASK2:.*]] = vector.transpose {{.*}} : vector<7x14xi1> to vector<14x7xi1>
-// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [true, false, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
+// CHECK: vector.transfer_read {{.*}} %[[MASK2]] {in_bounds = [false, true, true], permutation_map = #[[$MAP1]]} : memref<?x?x?x?xf32>, vector<14x16x7xf32>
 // CHECK: vector.broadcast %{{.*}} : vector<14x16x7xf32> to vector<8x14x16x7xf32>
 // CHECK: vector.transpose %{{.*}}, [3, 1, 0, 2] : vector<8x14x16x7xf32> to vector<7x14x8x16xf32>
 
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-permutation-lowering.mlir
@@ -0,0 +1,41 @@
+// Run test with and without test-vector-transfer-lowering-patterns.
+
+// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+
+memref.global "private" @gv : memref<3x4xf32> = dense<[[0. , 1. , 2. , 3. ],
+                                                      [10., 11., 12., 13.],
+                                                      [20., 21., 22., 23.]]>
+
+// Vector load with transpose.
+func @transfer_read_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {in_bounds = [true, false], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<3x9xf32>
+  vector.print %f: vector<3x9xf32>
+  return
+}
+
+func @entry() {
+  %c0 = constant 0: index
+  %c1 = constant 1: index
+  %c2 = constant 2: index
+  %c3 = constant 3: index
+  %0 = memref.get_global @gv : memref<3x4xf32>
+  %A = memref.cast %0 : memref<3x4xf32> to memref<?x?xf32>
+
+  // 1. Read 2D vector from 2D memref with transpose.
+  call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( 20, 0, -42, -42, -42, -42, -42, -42, -42 ) )
+
+  return
+}