Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp
===================================================================
--- mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2842,6 +2842,24 @@
   return ops;
 }
 
+/// Converts TransferRead op used by ExtractMap op into a smaller dimension
+/// TransferRead. Each ExtractMap id, multiplied by the slice size along the
+/// vector dimension it distributes, is added to the memref index that
+/// dimension reads through the permutation map; ids of broadcast dimensions
+/// are unused.
+/// Example:
+/// ```
+/// %a = vector.transfer_read %A[%c0, %c0, %c0], %cf0:
+///   memref<64x64x64xf32>, vector<64x4x32xf32>
+/// %e = vector.extract_map %a[%id0, %id1] :
+///   vector<64x4x32xf32> to vector<2x4x1xf32>
+/// ```
+/// to:
+/// ```
+/// %id2 = affine.apply affine_map<()[s0] -> (s0 * 2)> (%id0)
+/// %e = vector.transfer_read %A[%id2, %c0, %id1], %cf0 :
+///   memref<64x64x64xf32>, vector<2x4x1xf32>
+/// ```
 struct TransferReadExtractPattern
     : public OpRewritePattern<vector::TransferReadOp> {
   TransferReadExtractPattern(MLIRContext *context)
@@ -2858,18 +2876,28 @@
       return failure();
     SmallVector<Value, 4> indices(read.indices().begin(),
                                   read.indices().end());
-    AffineMap map = extract.map();
+    // Map each distributed vector dimension to the memref dimension it reads.
+    // A broadcast vector dimension composes to a constant result.
+    AffineMap indexMap = extract.map().compose(read.permutation_map());
     unsigned idCount = 0;
     ImplicitLocOpBuilder lb(read.getLoc(), rewriter);
-    for (auto expr : map.getResults()) {
+    for (auto it :
+         llvm::zip(indexMap.getResults(), extract.map().getResults())) {
+      // ExtractMap ids match the results of its map one-to-one, so consume
+      // the id even for a broadcast dimension; otherwise the ids of all the
+      // following dimensions would be shifted.
+      Value id = extract.ids()[idCount++];
       AffineExpr d0, d1;
       bindDims(read.getContext(), d0, d1);
-      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
       auto scale = getAffineConstantExpr(
-          extract.getResultType().getDimSize(pos), read.getContext());
-      indices[pos] =
-          makeComposedAffineApply(rewriter, read.getLoc(), d0 + scale * d1,
-                                  {indices[pos], extract.ids()[idCount++]});
+          extract.getResultType().getDimSize(vectorPos), read.getContext());
+      indices[indexPos] = makeComposedAffineApply(
+          rewriter, read.getLoc(), d0 + scale * d1, {indices[indexPos], id});
     }
     Value newRead = lb.create<vector::TransferReadOp>(
         extract.getType(), read.source(), indices, read.permutation_map(),
@@ -2895,18 +2923,29 @@
       return failure();
     SmallVector<Value, 4> indices(write.indices().begin(),
                                   write.indices().end());
-    AffineMap map = insert.map();
+    // Map each distributed vector dimension to the memref dimension it writes.
+    // A broadcast vector dimension composes to a constant result.
+    AffineMap indexMap = insert.map().compose(write.permutation_map());
     unsigned idCount = 0;
     Location loc = write.getLoc();
-    for (auto expr : map.getResults()) {
+    for (auto it :
+         llvm::zip(indexMap.getResults(), insert.map().getResults())) {
+      // InsertMap ids match the results of its map one-to-one, so consume the
+      // id even for a broadcast dimension; otherwise the ids of all the
+      // following dimensions would be shifted.
+      Value id = insert.ids()[idCount++];
       AffineExpr d0, d1;
       bindDims(write.getContext(), d0, d1);
-      unsigned pos = expr.cast<AffineDimExpr>().getPosition();
+      auto indexExpr = std::get<0>(it).dyn_cast<AffineDimExpr>();
+      if (!indexExpr)
+        continue;
+      unsigned indexPos = indexExpr.getPosition();
+      unsigned vectorPos = std::get<1>(it).cast<AffineDimExpr>().getPosition();
       auto scale = getAffineConstantExpr(
-          insert.getSourceVectorType().getDimSize(pos), write.getContext());
-      indices[pos] =
-          makeComposedAffineApply(rewriter, loc, d0 + scale * d1,
-                                  {indices[pos], insert.ids()[idCount++]});
+          insert.getSourceVectorType().getDimSize(vectorPos),
+          write.getContext());
+      indices[indexPos] = makeComposedAffineApply(
+          rewriter, loc, d0 + scale * d1, {indices[indexPos], id});
     }
     rewriter.create<vector::TransferWriteOp>(
         loc, insert.vector(), write.source(), indices, write.permutation_map(),
Index: mlir/test/Dialect/Vector/vector-distribution.mlir
===================================================================
--- mlir/test/Dialect/Vector/vector-distribution.mlir
+++ mlir/test/Dialect/Vector/vector-distribution.mlir
@@ -123,4 +123,34 @@
   return
 }
 
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 2)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, 0, 0)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (0, d3, d0)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2, d3) -> (d3, d2, d1)>
+// CHECK: func @vector_add_transfer_permutation
+// CHECK-SAME: (%[[ID_0:.*]]: index, %[[ID_1:.*]]: index
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[ID2:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[ID2]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
+// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID_1]], %[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP2]]} : memref<?x?x?x?xf32>, vector<2x4x1xf32>
+// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2x4x1xf32>
+// CHECK-NEXT: %[[ID3:.*]] = affine.apply #[[MAP0]]()[%[[ID_0]]]
+// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]], %[[ID_1]], %[[C0]], %[[ID3]]] {permutation_map = #[[MAP3]]} : vector<2x4x1xf32>, memref<?x?x?x?xf32>
+// CHECK-NEXT: return
+func @vector_add_transfer_permutation(%id0 : index, %id1 : index, %A: memref<?x?x?x?xf32>,
+    %B: memref<?x?x?x?xf32>, %C: memref<?x?x?x?xf32>) {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.0 : f32
+  %a = vector.transfer_read %A[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map0} : memref<?x?x?x?xf32>, vector<64x4x32xf32>
+  %b = vector.transfer_read %B[%c0, %c0, %c0, %c0], %cf0 {permutation_map = #map1}: memref<?x?x?x?xf32>, vector<64x4x32xf32>
+  %acc = addf %a, %b: vector<64x4x32xf32>
+  vector.transfer_write %acc, %C[%c0, %c0, %c0, %c0] {permutation_map = #map2}: vector<64x4x32xf32>, memref<?x?x?x?xf32>
+  return
+}
 