Index: mlir/include/mlir/Dialect/Vector/VectorTransforms.h =================================================================== --- mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -69,6 +69,22 @@ Operation *op, ArrayRef targetShape); +/// Unroll a transfer_write op. Break up the vector source into a tuple of +/// vectors matching the given shape. Then store each element with its own +/// transfer_write. +/// +/// Example: +/// vector.transfer_write %A, %M[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> +/// -> +/// %0 = vector.extract_slices %A, [2, 4], [1, 1] : +/// vector<4x4xf32> into tuple, vector<2x4xf32>> +/// %1 = vector.tuple_get %0, 0 : tuple, vector<2x4xf32>> +/// vector.transfer_write %1, %M[%c0, %c0] : vector<2x4xf32>, memref<4x4xf32> +/// %2 = vector.tuple_get %0, 1 : tuple, vector<2x4xf32>> +/// vector.transfer_write %2, %M[%c2, %c0] : vector<2x4xf32>, memref<4x4xf32> +LogicalResult unrollTransferWriteOp(OpBuilder &builder, Operation *op, + ArrayRef targetShape); + /// Pattern to apply `unrollSingleResultVectorOp` to a `targetShape` /// declaratively. template @@ -95,6 +111,12 @@ if (!maybeShapeRatio || llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) return failure(); + if (std::is_same::value) { + if (failed(unrollTransferWriteOp(rewriter, op, targetShape))) + return failure(); + rewriter.eraseOp(op); + return success(); + } if (op.getOperation()->getNumResults() != 1) return failure(); auto resultVector = unrollSingleResultVectorOp(rewriter, op, targetShape); Index: mlir/lib/Dialect/Vector/VectorTransforms.cpp =================================================================== --- mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -510,35 +510,6 @@ resultIndex = numVectors - 1; } -// Entry point for unrolling declarative pattern rewrites. -SmallVector -mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op, - ArrayRef targetShape) { - assert(op->getNumResults() == 1 && "Expected single result operation"); - - // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'. - SmallVector iterationBounds; - auto unrollableVectorOp = cast(op); - auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); - assert(maybeUnrollShape && "Trying to unroll an incorrect vector op"); - - std::vector vectors; - unsigned resultIndex; - - if (auto contractionOp = dyn_cast(op)) { - // Populate state for vector ContractionOp. - getVectorContractionOpUnrollState(contractionOp, targetShape, vectors, - resultIndex); - } else { - // Populate state for vector elementwise op. - getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex); - } - - // Unroll 'op' with 'iterationBounds' to 'targetShape'. - return SmallVector{unrollSingleResultStructuredOp( - op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)}; -} - /// Generates slices of 'vectorType' according to 'sizes' and 'strides, and /// calls 'fn' with linear index and indices for each slice. static void generateTransferOpSlices( @@ -614,6 +585,114 @@ return true; } +/// Unroll transfer_read ops to the given shape and create an aggregate with all +/// the chunks. +static Value unrollTransferReadOp(vector::TransferReadOp readOp, + ArrayRef targetShape, + OpBuilder &builder) { + if (!isIdentitySuffix(readOp.permutation_map())) + return nullptr; + auto sourceVectorType = readOp.getVectorType(); + SmallVector strides(targetShape.size(), 1); + + Location loc = readOp.getLoc(); + auto memrefElementType = + readOp.memref().getType().cast().getElementType(); + auto tupleType = generateExtractSlicesOpResultType( + sourceVectorType, targetShape, strides, builder); + int64_t numSlices = tupleType.size(); + + SmallVector vectorTupleValues(numSlices); + SmallVector indices(readOp.indices().begin(), + readOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { + // Get VectorType for slice 'i'. + auto sliceVectorType = tupleType.getType(index); + // Create split TransferReadOp for 'sliceUser'. + // `masked` attribute propagates conservatively: if the coarse op didn't + // need masking, the fine op doesn't either. + vectorTupleValues[index] = builder.create( + loc, sliceVectorType, readOp.memref(), sliceIndices, + readOp.permutation_map(), readOp.padding(), + readOp.masked() ? *readOp.masked() : ArrayAttr()); + }; + generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType, + targetShape, strides, indices, builder, createSlice); + + // Create tuple of splice transfer read operations. + Value tupleOp = + builder.create(loc, tupleType, vectorTupleValues); + // Replace 'readOp' with result 'insertSlicesResult'. + Value newVec = builder.create( + loc, sourceVectorType, tupleOp, builder.getI64ArrayAttr(targetShape), + builder.getI64ArrayAttr(strides)); + return newVec; +} + +// Entry point for unrolling declarative pattern rewrite for transfer_write op. +LogicalResult +mlir::vector::unrollTransferWriteOp(OpBuilder &builder, Operation *op, + ArrayRef targetShape) { + auto writeOp = cast(op); + if (!isIdentitySuffix(writeOp.permutation_map())) + return failure(); + VectorType sourceVectorType = writeOp.getVectorType(); + SmallVector strides(targetShape.size(), 1); + TupleType tupleType = generateExtractSlicesOpResultType( + sourceVectorType, targetShape, strides, builder); + Location loc = writeOp.getLoc(); + Value tuple = builder.create( + loc, tupleType, writeOp.vector(), targetShape, strides); + auto memrefElementType = + writeOp.memref().getType().cast().getElementType(); + SmallVector indices(writeOp.indices().begin(), + writeOp.indices().end()); + auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { + auto element = builder.create( + loc, tupleType.getType(index), tuple, builder.getI64IntegerAttr(index)); + builder.create( + loc, element.getResult(), writeOp.memref(), sliceIndices, + writeOp.permutation_map(), + writeOp.masked() ? *writeOp.masked() : ArrayAttr()); + }; + generateTransferOpSlices(memrefElementType, sourceVectorType, tupleType, + targetShape, strides, indices, builder, createSlice); + return success(); +} + +// Entry point for unrolling declarative pattern rewrites. +SmallVector +mlir::vector::unrollSingleResultVectorOp(OpBuilder &builder, Operation *op, + ArrayRef targetShape) { + assert(op->getNumResults() == 1 && "Expected single result operation"); + + // Populate 'iterationBounds', 'vectors' and 'resultIndex' to unroll 'op'. + SmallVector iterationBounds; + auto unrollableVectorOp = cast(op); + auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); + assert(maybeUnrollShape && "Trying to unroll an incorrect vector op"); + + std::vector vectors; + unsigned resultIndex; + + if (auto readOp = dyn_cast(op)) + return SmallVector{ + unrollTransferReadOp(readOp, targetShape, builder)}; + + if (auto contractionOp = dyn_cast(op)) { + // Populate state for vector ContractionOp. + getVectorContractionOpUnrollState(contractionOp, targetShape, vectors, + resultIndex); + } else { + // Populate state for vector elementwise op. + getVectorElementwiseOpUnrollState(op, targetShape, vectors, resultIndex); + } + + // Unroll 'op' with 'iterationBounds' to 'targetShape'. + return SmallVector{unrollSingleResultStructuredOp( + op, *maybeUnrollShape, vectors, resultIndex, targetShape, builder)}; +} + namespace { // Splits vector TransferReadOp into smaller TransferReadOps based on slicing @@ -635,43 +714,16 @@ return failure(); // Get 'sizes' and 'strides' parameters from ExtractSlicesOp user. - auto sourceVectorType = extractSlicesOp.getSourceVectorType(); - auto resultTupleType = extractSlicesOp.getResultTupleType(); SmallVector sizes; extractSlicesOp.getSizes(sizes); SmallVector strides; extractSlicesOp.getStrides(strides); assert(llvm::all_of(strides, [](int64_t s) { return s == 1; })); - Location loc = xferReadOp.getLoc(); - auto memrefElementType = - xferReadOp.memref().getType().cast().getElementType(); - int64_t numSlices = resultTupleType.size(); - SmallVector vectorTupleValues(numSlices); - SmallVector indices(xferReadOp.indices().begin(), - xferReadOp.indices().end()); - auto createSlice = [&](unsigned index, ArrayRef sliceIndices) { - // Get VectorType for slice 'i'. - auto sliceVectorType = resultTupleType.getType(index); - // Create split TransferReadOp for 'sliceUser'. - // `masked` attribute propagates conservatively: if the coarse op didn't - // need masking, the fine op doesn't either. - vectorTupleValues[index] = rewriter.create( - loc, sliceVectorType, xferReadOp.memref(), sliceIndices, - xferReadOp.permutation_map(), xferReadOp.padding(), - xferReadOp.masked() ? *xferReadOp.masked() : ArrayAttr()); - }; - generateTransferOpSlices(memrefElementType, sourceVectorType, - resultTupleType, sizes, strides, indices, rewriter, - createSlice); - - // Create tuple of splice xfer read operations. - Value tupleOp = rewriter.create(loc, resultTupleType, - vectorTupleValues); - // Replace 'xferReadOp' with result 'insertSlicesResult'. - rewriter.replaceOpWithNewOp( - xferReadOp, sourceVectorType, tupleOp, extractSlicesOp.sizes(), - extractSlicesOp.strides()); + Value newVec = unrollTransferReadOp(xferReadOp, sizes, rewriter); + if (!newVec) + return failure(); + rewriter.replaceOp(xferReadOp, newVec); return success(); } }; Index: mlir/test/Dialect/Vector/vector-transfer-unroll.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Vector/vector-transfer-unroll.mlir @@ -0,0 +1,60 @@ +// RUN: mlir-opt %s -test-vector-transfer-unrolling-patterns | FileCheck %s + +// CHECK-LABEL: func @transfer_read_unroll +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[TUPL:.*]] = vector.tuple %[[VTR0]], %[[VTR1]], %[[VTR2]], %[[VTR3]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VEC:.*]] = vector.insert_slices %[[TUPL]], [2, 2], [1, 1] : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32> +// CHECK-NEXT: return %[[VEC]] : vector<4x4xf32> + +func @transfer_read_unroll(%arg0 : memref<4x4xf32>) -> vector<4x4xf32> { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + return %0 : vector<4x4xf32> +} + +// CHECK-LABEL: func @transfer_write_unroll +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[TUPL:.*]] = vector.extract_slices {{.*}}, [2, 2], [1, 1] : vector<4x4xf32> into tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: %[[T0:.*]] = vector.tuple_get %[[TUPL]], 0 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: vector.transfer_write %[[T0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: %[[T1:.*]] = vector.tuple_get %[[TUPL]], 1 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: vector.transfer_write %[[T1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: %[[T2:.*]] = vector.tuple_get %[[TUPL]], 2 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: vector.transfer_write %[[T2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: %[[T3:.*]] = vector.tuple_get %[[TUPL]], 3 : tuple, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> +// CHECK-NEXT: vector.transfer_write %[[T3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: return + +func @transfer_write_unroll(%arg0 : memref<4x4xf32>, %arg1 : vector<4x4xf32>) { + %c0 = constant 0 : index + vector.transfer_write %arg1, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> + return +} + +// CHECK-LABEL: func @transfer_readwrite_unroll +// CHECK: %[[C0:.*]] = constant 0 : index +// CHECK: %[[C2:.*]] = constant 2 : index +// CHECK: %[[VTR0:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR1:.*]] = vector.transfer_read {{.*}}[%[[C0]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR2:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: %[[VTR3:.*]] = vector.transfer_read {{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR0]], {{.*}}[%[[C0]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR1]], {{.*}}[%[[C0]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR2]], {{.*}}[%[[C2]], %[[C0]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: vector.transfer_write %[[VTR3]], {{.*}}[%[[C2]], %[[C2]]] {{.*}} : vector<2x2xf32>, memref<4x4xf32> +// CHECK-NEXT: return + +func @transfer_readwrite_unroll(%arg0 : memref<4x4xf32>) { + %c0 = constant 0 : index + %cf0 = constant 0.0 : f32 + %0 = vector.transfer_read %arg0[%c0, %c0], %cf0 : memref<4x4xf32>, vector<4x4xf32> + vector.transfer_write %0, %arg0[%c0, %c0] : vector<4x4xf32>, memref<4x4xf32> + return +} Index: mlir/test/lib/Transforms/TestVectorTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -125,6 +125,24 @@ } }; +struct TestVectorTransferUnrollingPatterns + : public PassWrapper { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnFunction() override { + MLIRContext *ctx = &getContext(); + OwningRewritePatternList patterns; + patterns.insert>( + ArrayRef{2, 2}, ctx); + patterns.insert>( + ArrayRef{2, 2}, ctx); + populateVectorToVectorCanonicalizationPatterns(patterns, ctx); + populateVectorToVectorTransformationPatterns(patterns, ctx); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + struct TestVectorTransferFullPartialSplitPatterns : public PassWrapper { @@ -174,6 +192,10 @@ "test-vector-unrolling-patterns", "Test conversion patterns to unroll contract ops in the vector dialect"); + PassRegistration transferOpUnrollingPass( + "test-vector-transfer-unrolling-patterns", + "Test conversion patterns to unroll transfer ops in the vector dialect"); + PassRegistration vectorTransformFullPartialPass("test-vector-transfer-full-partial-split", "Test conversion patterns to split "