Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -71,7 +71,8 @@ /// Unroll a transfer_write op. Break up the vector source into a tuple of /// vectors matching the given shape. Then store each element with its own -/// transfer_write. +/// transfer_write. If the transfer_write takes a tensor source, return the +/// unrolled Value in result. /// /// Example: /// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> @@ -83,7 +84,8 @@ /// %2 = vector.tuple_get %0, 1 : tuple, vector<2x4xf32>> /// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32> LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op, - ArrayRef targetShape); + ArrayRef targetShape, + SmallVector &result); /// Options that control the vector unrolling. struct UnrollVectorOptions { @@ -143,9 +145,10 @@ llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) return failure(); if (isa(op)) { - if (failed(unrollTransferWriteOp(rewriter, op, *targetShape))) + SmallVector result; + if (failed(unrollTransferWriteOp(rewriter, op, *targetShape, result))) return failure(); - rewriter.eraseOp(op); + rewriter.replaceOp(op, result); return success(); } if (op->getNumResults() != 1) Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -515,7 +515,7 @@ /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and /// calls 'fn' with linear index and indices for each slice. static void generateTransferOpSlices( - Type memrefElementType, VectorType vectorType, TupleType tupleType, + Type shapedElementType, VectorType vectorType, TupleType tupleType, ArrayRef sizes, ArrayRef strides, ArrayRef indices, OpBuilder &builder, function_ref)> fn) { // Compute strides w.r.t. to slice counts in each dimension. @@ -539,9 +539,9 @@ // vector rank is 4 - 2 = 2, and so 'indexOffset' = 3 - 2 = 1. // unsigned vectorRank = vectorType.getRank(); - if (auto memrefVectorElementType = memrefElementType.dyn_cast()) { - assert(vectorRank >= memrefVectorElementType.getRank()); - vectorRank -= memrefVectorElementType.getRank(); + if (auto sourceVectorElementType = shapedElementType.dyn_cast()) { + assert(vectorRank >= sourceVectorElementType.getRank()); + vectorRank -= sourceVectorElementType.getRank(); } unsigned indexOffset = numSliceIndices - vectorRank; @@ -598,8 +598,8 @@ SmallVector strides(targetShape.size(), 1); Location loc = readOp.getLoc(); - auto memrefElementType = - readOp.source().getType().cast().getElementType(); + auto shapedElementType = + readOp.source().getType().cast().getElementType(); auto tupleType = generateExtractSlicesOpResultType( sourceVectorType, targetShape, strides, builder); int64_t numSlices = tupleType.size(); @@ -618,7 +618,7 @@ readOp.permutation_map(), readOp.padding(), readOp.masked() ? *readOp.masked() : ArrayAttr()); }; - generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType, + generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType, targetShape, strides, indices, builder, createSlice); // Create tuple of splice transfer read operations. @@ -634,7 +634,8 @@ // Entry point for unrolling declarative pattern rewrite for transfer_write op. LogicalResult mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op, - ArrayRef targetShape) { + ArrayRef targetShape, + SmallVector &result) { auto writeOp = cast(op); if (!isIdentitySuffix(writeOp.permutation_map())) return failure(); @@ -645,20 +646,28 @@ Location loc = writeOp.getLoc(); Value tuple = builder.create( loc, tupleType, writeOp.vector(), targetShape, strides); - auto memrefElementType = - writeOp.source().getType().cast().getElementType(); + auto shapedElementType = + writeOp.source().getType().cast().getElementType(); SmallVector indices(writeOp.indices().begin(), writeOp.indices().end()); + // If the TransferWrite returns a tensor, keep track of the last tensor + // created. + Value resultTensor; auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { auto element = builder.create( loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index)); - builder.create( - loc, element.getResult(), writeOp.source(), sliceIndices, + Operation *write = builder.create( + loc, element.getResult(), + resultTensor ? resultTensor : writeOp.source(), sliceIndices, writeOp.permutation_map(), writeOp.masked() ? *writeOp.masked() : ArrayAttr()); + if (!write->getResults().empty()) + resultTensor = write->getResult(0); }; - generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType, + generateTransferOpSlices(shapedElementType, sourceVectorType, tupleType, targetShape, strides, indices, builder, createSlice); + if (resultTensor) + result.push_back(resultTensor); return success(); } @@ -761,25 +770,32 @@ insertSlicesOp.getStrides(strides); Location loc = xferWriteOp.getLoc(); - auto memrefElementType = - xferWriteOp.source().getType().cast().getElementType(); + auto shapedElementType = + xferWriteOp.source().getType().cast().getElementType(); SmallVector indices(xferWriteOp.indices().begin(), xferWriteOp.indices().end()); + Value resultTensor; auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { // Create split TransferWriteOp for source vector 'tupleOp.operand[i]'. // `masked` attribute propagates conservatively: if the coarse op didn't // need masking, the fine op doesn't either. - rewriter.create( - loc, tupleOp.getOperand(index), xferWriteOp.source(), sliceIndices, + Operation *write = rewriter.create( + loc, tupleOp.getOperand(index), + resultTensor ? resultTensor : xferWriteOp.source(), sliceIndices, xferWriteOp.permutation_map(), xferWriteOp.masked() ? *xferWriteOp.masked() : ArrayAttr()); + if (!write->getResults().empty()) + resultTensor = write->getResult(0); }; - generateTransferOpSlices(memrefElementType, resultVectorType, + generateTransferOpSlices(shapedElementType, resultVectorType, sourceTupleType, sizes, strides, indices, rewriter, createSlice); // Erase old 'xferWriteOp'. - rewriter.eraseOp(xferWriteOp); + if (resultTensor) + rewriter.replaceOp(xferWriteOp, ArrayRef(resultTensor)); + else + rewriter.eraseOp(xferWriteOp); return success(); } }; Index: mlir/test/Dialect/Vector/vector-transfer-unroll.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-transfer-unroll.mlir +++ mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -58,3 +58,65 @@ vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> return } + +// CHECK-LABEL: func @transfer_read_unroll_tensor +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> +// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32> + +func @transfer_read_unroll_tensor(%arg0 : tensor<4x4xf32>) -> vector<4x4xf32> { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// CHECK-LABEL: func @transfer_write_unroll_tensor +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[T1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[T2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[T3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32> + +func @transfer_write_unroll_tensor(%arg0 : tensor<4x4xf32>, + %arg1 : vector<4x4xf32>) -> tensor<4x4xf32> { + %c0 = constant 0 : index + %r = vector.transfer_write %arg1, %arg0[%c0, %c0] : + vector<4x4xf32>, tensor<4x4xf32> + return %r: tensor<4x4xf32> +} + +// CHECK-LABEL: func @transfer_readwrite_unroll_tensor +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[VTR1]], %[[VTW0]][%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[VTR2]], %[[VTW1]][%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[VTR3]], %[[VTW2]][%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32> + +func @transfer_readwrite_unroll_tensor(%arg0 : tensor<4x4xf32>) -> + tensor<4x4xf32> { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : tensor<4x4xf32>, vector<4x4xf32> + %r = vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, tensor<4x4xf32> + return %r: tensor<4x4xf32> +} Index: mlir/test/Dialect/Vector/vector-transforms.mlir =================================================================== --- mlir/test/Dialect/Vector/vector-transforms.mlir +++ mlir/test/Dialect/Vector/vector-transforms.mlir @@ -530,6 +530,14 @@ // CHECK: %[[CMP1:.*]] = cmpf "ult", %[[VT1]], %[[VT5]] : vector<2x2xf32> // CHECK: %[[CMP2:.*]] = cmpf "ult", %[[VT2]], %[[VT6]] : vector<2x2xf32> // CHECK: %[[CMP3:.*]] = cmpf "ult", %[[VT3]], %[[VT7]] : vector<2x2xf32> +// CHECK: %[[VT0:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT1:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT2:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT3:.*]] = vector.transfer_read %[[ARG0]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT4:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT5:.*]] = vector.transfer_read %[[ARG1]][%[[C0]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT6:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C0]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK: %[[VT7:.*]] = vector.transfer_read %[[ARG1]][%[[C2]], %[[C2]]], {{.*}} : memref<4x4xf32>, vector<2x2xf32> // CHECK: %[[SEL0:.*]] = select %[[CMP0]], %[[VT0]], %[[VT4]] : vector<2x2xi1>, vector<2x2xf32> // CHECK: %[[SEL1:.*]] = select %[[CMP1]], %[[VT1]], %[[VT5]] : vector<2x2xi1>, vector<2x2xf32> // CHECK: %[[SEL2:.*]] = select %[[CMP2]], %[[VT2]], %[[VT6]] : vector<2x2xi1>, vector<2x2xf32> @@ -544,7 +552,52 @@ %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> %cond = cmpf "ult", %0, %1 : vector<4x4xf32> - %2 = select %cond, %0, %1 : vector<4x4xi1>, vector<4x4xf32> - vector.transfer_write %2, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> + // Vector transfer split pattern only support single user right now. + %2 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + %3 = vector.transfer_read %arg1[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + %4 = select %cond, %2, %3 : vector<4x4xi1>, vector<4x4xf32> + vector.transfer_write %4, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> return } + +// Check that vector.transfer read/write are split based on contract unrolling. +// CHECK: %[[VTR0:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x2xf32>, vector<2x2xf32> + +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<2x4xf32>, vector<2x2xf32> + +// CHECK-NEXT: %[[VTR4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR5:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32> + +// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> + +// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[VTW2:.*]] = vector.transfer_write %[[R2]], %[[VTW1]][%[[C2]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: %[[VTW3:.*]] = vector.transfer_write %[[R3]], %[[VTW2]][%[[C2]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32> +// CHECK-NEXT: return %[[VTW3]] : tensor<4x4xf32> + +func @contraction4x4_ikj_xfer_read_tensor(%arg0 : tensor<4x2xf32>, + %arg1 : tensor<2x4xf32>, + %arg2 : tensor<4x4xf32>) -> + tensor<4x4xf32> { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : + tensor<4x2xf32>, vector<4x2xf32> + %1 = vector.transfer_read %arg1[%c0, %c0], %cf0 : + tensor<2x4xf32>, vector<2x4xf32> + %2 = vector.transfer_read %arg2[%c0, %c0], %cf0 : + tensor<4x4xf32>, vector<4x4xf32> + %3 = vector.contract #contraction_trait1 %0, %1, %2 + : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32> + %r = vector.transfer_write %3, %arg2[%c0, %c0] + : vector<4x4xf32>, tensor<4x4xf32> + return %r : tensor<4x4xf32> +} Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -28,7 +28,9 @@ OwningRewritePatternList patterns; auto *ctx = &getContext(); patterns.insert( - ctx, UnrollVectorOptions().setNativeShapeFn(getShape)); + ctx, + UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint( + filter)); populateVectorToVectorCanonicalizationPatterns(patterns, ctx); populateVectorToVectorTransformationPatterns(patterns, ctx); applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); @@ -39,13 +41,14 @@ static Optional> getShape(Operation *op) { if (isa(op)) return SmallVector(2, 2); - if (auto transferOp = dyn_cast(op)) { - return SmallVector(transferOp.getVectorType().getRank(), 2); - } if (isa(op)) return SmallVector(3, 2); return llvm::None; } + + static LogicalResult filter(Operation *op) { + return success(isa(op)); + } }; struct TestVectorSlicesConversion