diff --git a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
@@ -71,8 +71,22 @@
 /// Given a vector transfer op, calculate which dimension of the `source`
 /// memref should be unpacked in the next application of TransferOpConversion.
 template <typename OpTy>
-static int64_t unpackedDim(OpTy xferOp) {
-  return xferOp.getShapedType().getRank() - xferOp.getVectorType().getRank();
+static unsigned unpackedDim(OpTy xferOp) {
+  auto map = xferOp.permutation_map();
+  // TODO: Handle broadcast
+  auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>();
+  assert(expr && "Expected AffineDimExpr in permutation map result");
+  return expr.getPosition();
+}
+
+/// Compute the permutation map for the new (N-1)-D vector transfer op. This
+/// map is identical to the current permutation map, but the first result is
+/// omitted.
+template <typename OpTy>
+static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) {
+  auto map = xferOp.permutation_map();
+  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
+                        builder.getContext());
 }
 
 /// Calculate the indices for the new vector transfer op.
@@ -124,7 +138,7 @@
 /// `resultTypes`.
 template <typename OpTy>
 static Value generateInBoundsCheck(
-    OpTy xferOp, Value iv, OpBuilder &builder, int64_t dim,
+    OpTy xferOp, Value iv, OpBuilder &builder, unsigned dim,
     TypeRange resultTypes,
     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
@@ -253,11 +267,13 @@
 
     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
-    auto map = getTransferMinorIdentityMap(xferOp.getShapedType(), vecType);
     auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
-    auto newXfer = vector_transfer_read(
-        vecType, xferOp.source(), xferIndices, AffineMapAttr::get(map),
-        xferOp.padding(), Value(), inBoundsAttr).value;
+    auto newXfer =
+        vector_transfer_read(
+            vecType, xferOp.source(), xferIndices,
+            AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)),
+            xferOp.padding(), Value(), inBoundsAttr)
+            .value;
 
     if (vecType.getRank() > kTargetRank)
       newXfer.getDefiningOp()->setAttr(kPassLabel, rewriter.getUnitAttr());
@@ -331,11 +347,11 @@
 
     auto vec = memref_load(buffer, loadIndices);
     auto vecType = vec.value.getType().dyn_cast<VectorType>();
-    auto map = getTransferMinorIdentityMap(xferOp.getShapedType(), vecType);
     auto inBoundsAttr = dropFirstElem(rewriter, xferOp.in_boundsAttr());
-    auto newXfer = vector_transfer_write(Type(), vec, xferOp.source(),
-                                         xferIndices, AffineMapAttr::get(map),
-                                         Value(), inBoundsAttr);
+    auto newXfer = vector_transfer_write(
+        Type(), vec, xferOp.source(), xferIndices,
+        AffineMapAttr::get(unpackedPermutationMap(xferOp, rewriter)), Value(),
+        inBoundsAttr);
 
     if (vecType.getRank() > kTargetRank)
       newXfer.op->setAttr(kPassLabel, rewriter.getUnitAttr());
@@ -360,8 +376,10 @@
     return failure();
   if (xferOp.mask())
     return failure();
-  if (!xferOp.permutation_map().isMinorIdentity())
+  // Broadcast dimensions (non-AffineDimExpr results in the permutation map)
+  // are not supported yet: `unpackedDim` asserts on them.
+  if (!xferOp.permutation_map().isProjectedPermutation())
     return failure();
   return success();
 }
 
@@ -649,13 +667,15 @@
 
 void populateProgressiveVectorToSCFConversionPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<PrepareTransferReadConversion,
-               PrepareTransferWriteConversion,
-               TransferOpConversion<TransferReadOp>,
-               TransferOpConversion<TransferWriteOp>,
-               Strided1dTransferOpConversion<TransferReadOp>,
-               Strided1dTransferOpConversion<TransferWriteOp>>(
-      patterns.getContext());
+  patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
+               TransferOpConversion<TransferReadOp>,
+               TransferOpConversion<TransferWriteOp>>(patterns.getContext());
+
+  if (kTargetRank == 1) {
+    patterns.add<Strided1dTransferOpConversion<TransferReadOp>,
+                 Strided1dTransferOpConversion<TransferWriteOp>>(
+        patterns.getContext());
+  }
 }
 
 struct ConvertProgressiveVectorToSCFPass
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-2d.mlir
@@ -17,6 +17,16 @@
   return
 }
 
+func @transfer_read_2d_transposed(
+    %A : memref<?x?xf32>, %base1: index, %base2: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%base1, %base2], %fm42
+      {permutation_map = affine_map<(d0, d1) -> (d1, d0)>} :
+    memref<?x?xf32>, vector<4x9xf32>
+  vector.print %f: vector<4x9xf32>
+  return
+}
+
 func @transfer_write_2d(%A : memref<?x?xf32>, %base1: index, %base2: index) {
   %fn1 = constant -1.0 : f32
   %vf0 = splat %fn1 : vector<1x4xf32>
@@ -53,12 +63,20 @@
   // On input, memory contains [[ 0, 1, 2, ...], [10, 11, 12, ...], ...]
   // Read shifted by 2 and pad with -42:
   call @transfer_read_2d(%A, %c1, %c2) : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but transposed
+  call @transfer_read_2d_transposed(%A, %c1, %c2)
+      : (memref<?x?xf32>, index, index) -> ()
   // Write into memory shifted by 3
   call @transfer_write_2d(%A, %c3, %c1) : (memref<?x?xf32>, index, index) -> ()
   // Read shifted by 0 and pad with -42:
   call @transfer_read_2d(%A, %c0, %c0) : (memref<?x?xf32>, index, index) -> ()
+  // Same as above, but transposed
+  call @transfer_read_2d_transposed(%A, %c0, %c0)
+      : (memref<?x?xf32>, index, index) -> ()
   return
 }
 
 // CHECK: ( ( 12, 13, -42, -42, -42, -42, -42, -42, -42 ), ( 22, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 12, 22, -42, -42, -42, -42, -42, -42, -42 ), ( 13, 23, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
 // CHECK: ( ( 0, 1, 2, 3, -42, -42, -42, -42, -42 ), ( 10, 11, 12, 13, -42, -42, -42, -42, -42 ), ( 20, 21, 22, 23, -42, -42, -42, -42, -42 ), ( -42, -42, -42, -42, -42, -42, -42, -42, -42 ) )
+// CHECK: ( ( 0, 10, 20, -42, -42, -42, -42, -42, -42 ), ( 1, 11, 21, -42, -42, -42, -42, -42, -42 ), ( 2, 12, 22, -42, -42, -42, -42, -42, -42 ), ( 3, 13, 23, -42, -42, -42, -42, -42, -42 ) )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-read-3d.mlir
@@ -19,6 +19,16 @@
   return
 }
 
+func @transfer_read_3d_transposed(%A : memref<?x?x?x?xf32>,
+                                  %o: index, %a: index, %b: index, %c: index) {
+  %fm42 = constant -42.0: f32
+  %f = vector.transfer_read %A[%o, %a, %b, %c], %fm42
+      {permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1)>}
+      : memref<?x?x?x?xf32>, vector<3x5x3xf32>
+  vector.print %f: vector<3x5x3xf32>
+  return
+}
+
 func @transfer_write_3d(%A : memref<?x?x?x?xf32>,
                         %o: index, %a: index, %b: index, %c: index) {
   %fn1 = constant -1.0 : f32
@@ -66,8 +76,11 @@
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   call @transfer_read_3d(%A, %c0, %c0, %c0, %c0)
       : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
+  call @transfer_read_3d_transposed(%A, %c0, %c0, %c0, %c0)
+      : (memref<?x?x?x?xf32>, index, index, index, index) -> ()
   return
 }
 
 // CHECK: ( ( ( 0, 0, -42 ), ( 2, 3, -42 ), ( 4, 6, -42 ), ( 6, 9, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, 33, -42 ), ( 24, 36, -42 ), ( 26, 39, -42 ), ( -42, -42, -42 ) ) )
 // CHECK: ( ( ( 0, 0, -42 ), ( 2, -1, -42 ), ( 4, -1, -42 ), ( 6, -1, -42 ), ( -42, -42, -42 ) ), ( ( 20, 30, -42 ), ( 22, -1, -42 ), ( 24, -1, -42 ), ( 26, -1, -42 ), ( -42, -42, -42 ) ) )
+// CHECK: ( ( ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ), ( 0, 20, 40 ) ), ( ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ), ( 0, 30, 60 ) ), ( ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ), ( -42, -42, -42 ) ) )
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-transfer-to-loops.mlir
@@ -3,6 +3,11 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
 // RUN: FileCheck %s
+
+// RUN: mlir-opt %s -test-progressive-convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm -convert-std-to-llvm | \
+// RUN: mlir-cpu-runner -e main -entry-point-result=void \
+// RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext,%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
 
 #map0 = affine_map<(d0, d1) -> (d1, d0)>
 #map1 = affine_map<(d0, d1) -> (d1)>
 