diff --git a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
--- a/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/VectorOps/VectorTransforms.cpp
@@ -460,12 +460,13 @@
       op, iterationBounds, vectors, resultIndex, targetShape, builder)};
 }
 
-// Generates slices of 'vectorType' according to 'sizes' and 'strides, and
-// calls 'fn' with linear index and indices for each slice.
+/// Generates slices of 'vectorType' according to 'sizes' and 'strides', and
+/// calls 'fn' with linear index and indices for each slice.
 static void
-generateTransferOpSlices(VectorType vectorType, TupleType tupleType,
-                         ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides,
-                         ArrayRef<Value> indices, PatternRewriter &rewriter,
+generateTransferOpSlices(Type memrefElementType, VectorType vectorType,
+                         TupleType tupleType, ArrayRef<int64_t> sizes,
+                         ArrayRef<int64_t> strides, ArrayRef<Value> indices,
+                         PatternRewriter &rewriter,
                          function_ref<void(unsigned, ArrayRef<Value>)> fn) {
   // Compute strides w.r.t. to slice counts in each dimension.
   auto maybeDimSliceCounts = shapeRatio(vectorType.getShape(), sizes);
@@ -475,6 +476,25 @@
   int64_t numSlices = tupleType.size();
   unsigned numSliceIndices = indices.size();
 
+  // Compute 'indexOffset' at which to update 'indices', which is equal
+  // to the memref rank (indices.size) minus the effective 'vectorRank'.
+  // The effective 'vectorRank' is equal to the rank of the vector type
+  // minus the rank of the memref vector element type (if it has one).
+  //
+  // For example:
+  //
+  //   Given memref type 'memref<6x2x1xvector<2x4xf32>>' and vector
+  //   transfer_read/write ops which read/write vectors of type
+  //   'vector<2x1x2x4xf32>'. The memref rank is 3, and the effective
+  //   vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1.
+  //
+  unsigned vectorRank = vectorType.getRank();
+  if (auto memrefVectorElementType = memrefElementType.dyn_cast<VectorType>()) {
+    assert(vectorRank >= memrefVectorElementType.getRank());
+    vectorRank -= memrefVectorElementType.getRank();
+  }
+  unsigned indexOffset = numSliceIndices - vectorRank;
+
   auto *ctx = rewriter.getContext();
   for (unsigned i = 0; i < numSlices; ++i) {
     auto vectorOffsets = delinearize(sliceStrides, i);
@@ -482,18 +502,41 @@
         computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
     // Compute 'sliceIndices' by adding 'sliceOffsets[i]' to 'indices[i]'.
     SmallVector<Value, 4> sliceIndices(numSliceIndices);
-    for (auto it : llvm::enumerate(indices)) {
-      auto expr = getAffineDimExpr(0, ctx) +
-                  getAffineConstantExpr(elementOffsets[it.index()], ctx);
-      auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
-      sliceIndices[it.index()] = rewriter.create<AffineApplyOp>(
-          it.value().getLoc(), map, ArrayRef<Value>(it.value()));
+    for (unsigned j = 0; j < numSliceIndices; ++j) {
+      if (j < indexOffset) {
+        sliceIndices[j] = indices[j];
+      } else {
+        auto expr = getAffineDimExpr(0, ctx) +
+                    getAffineConstantExpr(elementOffsets[j - indexOffset], ctx);
+        auto map = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, expr);
+        sliceIndices[j] = rewriter.create<AffineApplyOp>(
+            indices[j].getLoc(), map, ArrayRef<Value>(indices[j]));
+      }
     }
     // Call 'fn' to generate slice 'i' at 'sliceIndices'.
     fn(i, sliceIndices);
   }
 }
 
+/// Returns true if 'map' is a suffix of an identity affine map, false
+/// otherwise. Example: affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+static bool isIdentitySuffix(AffineMap map) {
+  if (map.getNumDims() < map.getNumResults())
+    return false;
+  ArrayRef<AffineExpr> results = map.getResults();
+  Optional<int> lastPos;
+  for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
+    auto expr = results[i].dyn_cast<AffineDimExpr>();
+    if (!expr)
+      return false;
+    int currPos = static_cast<int>(expr.getPosition());
+    if (lastPos.hasValue() && currPos != lastPos.getValue() + 1)
+      return false;
+    lastPos = currPos;
+  }
+  return true;
+}
+
 namespace {
 // Splits vector TransferReadOp into smaller TransferReadOps based on slicing
 // scheme of its unique ExtractSlicesOp user.
@@ -504,7 +547,7 @@
                      PatternRewriter &rewriter) const override {
     // TODO(andydavis, ntv) Support splitting TransferReadOp with non-identity
     // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!xferReadOp.permutation_map().isIdentity())
+    if (!isIdentitySuffix(xferReadOp.permutation_map()))
       return matchFailure();
     // Return unless the unique 'xferReadOp' user is an ExtractSlicesOp.
     Value xferReadResult = xferReadOp.getResult();
@@ -523,6 +566,8 @@
     assert(llvm::all_of(strides, [](int64_t s) { return s == 1; }));
 
     Location loc = xferReadOp.getLoc();
+    auto memrefElementType =
+        xferReadOp.memref().getType().cast<MemRefType>().getElementType();
     int64_t numSlices = resultTupleType.size();
     SmallVector<Value, 4> vectorTupleValues(numSlices);
     SmallVector<Value, 4> indices(xferReadOp.indices().begin(),
@@ -535,8 +580,9 @@
           loc, sliceVectorType, xferReadOp.memref(), sliceIndices,
           xferReadOp.permutation_map(), xferReadOp.padding());
     };
-    generateTransferOpSlices(sourceVectorType, resultTupleType, sizes, strides,
-                             indices, rewriter, createSlice);
+    generateTransferOpSlices(memrefElementType, sourceVectorType,
+                             resultTupleType, sizes, strides, indices, rewriter,
+                             createSlice);
 
     // Create tuple of splice xfer read operations.
     Value tupleOp = rewriter.create<vector::TupleOp>(loc, resultTupleType,
@@ -557,7 +603,7 @@
                      PatternRewriter &rewriter) const override {
     // TODO(andydavis, ntv) Support splitting TransferWriteOp with non-identity
     // permutation maps. Repurpose code from MaterializeVectors transformation.
-    if (!xferWriteOp.permutation_map().isIdentity())
+    if (!isIdentitySuffix(xferWriteOp.permutation_map()))
       return matchFailure();
     // Return unless the 'xferWriteOp' 'vector' operand is an 'InsertSlicesOp'.
     auto *vectorDefOp = xferWriteOp.vector().getDefiningOp();
@@ -580,6 +626,8 @@
     insertSlicesOp.getStrides(strides);
 
     Location loc = xferWriteOp.getLoc();
+    auto memrefElementType =
+        xferWriteOp.memref().getType().cast<MemRefType>().getElementType();
     SmallVector<Value, 4> indices(xferWriteOp.indices().begin(),
                                   xferWriteOp.indices().end());
     auto createSlice = [&](unsigned index, ArrayRef<Value> sliceIndices) {
@@ -588,8 +636,9 @@
           loc, tupleOp.getOperand(index), xferWriteOp.memref(), sliceIndices,
           xferWriteOp.permutation_map());
     };
-    generateTransferOpSlices(resultVectorType, sourceTupleType, sizes, strides,
-                             indices, rewriter, createSlice);
+    generateTransferOpSlices(memrefElementType, resultVectorType,
+                             sourceTupleType, sizes, strides, indices, rewriter,
+                             createSlice);
 
     // Erase old 'xferWriteOp'.
     rewriter.eraseOp(xferWriteOp);
diff --git a/mlir/test/Dialect/VectorOps/vector-transforms.mlir b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
--- a/mlir/test/Dialect/VectorOps/vector-transforms.mlir
+++ b/mlir/test/Dialect/VectorOps/vector-transforms.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -test-vector-to-vector-conversion | FileCheck %s
 
 // CHECK-DAG: #[[MAP0:map[0-9]+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
 
 // CHECK-LABEL: func @add4x2
 //      CHECK: %[[ES1:.*]] = vector.extract_slices %{{.*}}, [2, 2], [1, 1] : vector<4x2xf32> into tuple<vector<2x2xf32>, vector<2x2xf32>>
@@ -311,3 +312,37 @@
   %1 = vector.tuple_get %0, 1 : tuple<vector<8xf32>, vector<8xf32>>
   return %1 : vector<8xf32>
 }
+
+// CHECK-LABEL: func @vector_transfers_vector_element_type
+//      CHECK: %[[C0:.*]] = constant 0 : index
+//      CHECK: %[[C1:.*]] = constant 1 : index
+//      CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C1]], %[[C0]]], %{{.*}} {permutation_map = #[[MAP1]]} : memref<6x2x1xvector<2x4xf32>>, vector<1x1x2x4xf32>
+// CHECK-NEXT: vector.transfer_write %[[VTR0]], %{{.*}}[%[[C0]], %[[C0]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+// CHECK-NEXT: vector.transfer_write %[[VTR1]], %{{.*}}[%[[C0]], %[[C1]], %[[C0]]] {permutation_map = #[[MAP1]]} : vector<1x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+
+func @vector_transfers_vector_element_type() {
+  %c0 = constant 0 : index
+  %cf0 = constant 0.000000e+00 : f32
+  %vf0 = splat %cf0 : vector<2x4xf32>
+
+  %0 = alloc() : memref<6x2x1xvector<2x4xf32>>
+
+  %1 = vector.transfer_read %0[%c0, %c0, %c0], %vf0
+    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
+      : memref<6x2x1xvector<2x4xf32>>, vector<2x1x2x4xf32>
+
+  %2 = vector.extract_slices %1, [1, 1, 2, 4], [1, 1, 1, 1]
+    : vector<2x1x2x4xf32> into tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %3 = vector.tuple_get %2, 0 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>>
+  %5 = vector.tuple %3, %4 : vector<1x1x2x4xf32>, vector<1x1x2x4xf32>
+  %6 = vector.insert_slices %5, [1, 1, 2, 4], [1, 1, 1, 1]
+    : tuple<vector<1x1x2x4xf32>, vector<1x1x2x4xf32>> into vector<2x1x2x4xf32>
+
+  vector.transfer_write %6, %0[%c0, %c0, %c0]
+    {permutation_map = affine_map<(d0, d1, d2) -> (d1, d2)>}
+      : vector<2x1x2x4xf32>, memref<6x2x1xvector<2x4xf32>>
+
+  return
+}