diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -31,6 +31,9 @@
 SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
                                        ArrayRef<int64_t> sizes);
 
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
 /// Given the slice strides together with a linear index in the dimension
 /// space, returns the vector-space offsets in each dimension for a
 /// de-linearized index.
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -69,15 +69,6 @@
   return res;
 }
 
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
-  assert(offsets.size() == basis.size());
-  int64_t linearIndex = 0;
-  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
-    linearIndex += offsets[idx] * basis[idx];
-  return linearIndex;
-}
-
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -683,6 +674,99 @@
   }
 };
 
+/// Returns the producer Value of the same type as 'consumerValue', by tracking
+/// the tuple index and offsets of the consumer vector value through the
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
+/// from consumer to producer. Each operation in the chain is structured, and
+/// so the tuple index and offsets can be mapped from result to input, while
+/// visiting each operation in the chain.
+/// Returns nullptr on failure.
+static Value getProducerValue(Value consumerValue) {
+  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
+  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
+  int64_t tupleIndex = -1;
+  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
+  auto *op = consumerValue.getDefiningOp();
+  while (op != nullptr) {
+    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
+      assert(tupleIndex == -1 && "TupleGetOp must have vector result type");
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = tupleGetOp.getIndex();
+      op = tupleGetOp.vectors().getDefiningOp();
+    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Compute slice strides for 'extractSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      extractSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          extractSlicesOp.getSourceVectorType().getShape(), sizes);
+
+      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
+      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
+      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+
+      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
+      // to the 'extractSlicesOp' input vector type.
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        offsets[i] += elementOffsets[i];
+
+      // Clear 'tupleIndex' and update next defining 'op' to visit.
+      tupleIndex = -1;
+      op = extractSlicesOp.vector().getDefiningOp();
+    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
+      assert(tupleIndex == -1);
+
+      // Compute slice strides for 'insertSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      insertSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          insertSlicesOp.getResultVectorType().getShape(), sizes);
+
+      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
+      // of 'insertSlicesOp' result vector type at 'offsets'.
+      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
+      assert(offsets.size() == sizes.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        vectorOffsets[i] = offsets[i] / sizes[i];
+
+      // Compute the source tuple element index.
+      tupleIndex = linearize(vectorOffsets, sliceStrides);
+
+      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
+      // relative to input tuple element vector type at 'tupleIndex'.
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        offsets[i] -= elementOffsets[i];
+        assert(offsets[i] >= 0);
+      }
+
+      // Update next defining 'op' to visit.
+      op = insertSlicesOp.vectors().getDefiningOp();
+    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Return tuple element 'value' at 'tupleIndex' if it matches type.
+      auto value = tupleOp.getOperand(tupleIndex);
+      if (value.getType() == consumerVectorType)
+        return value;
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = -1;
+      op = value.getDefiningOp();
+    } else {
+      break;
+    }
+  }
+  return nullptr;
+}
+
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
 // Example:
@@ -740,28 +824,11 @@
 
   LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
-    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
-        tupleGetOp.vectors().getDefiningOp());
-    if (!extractSlicesOp)
-      return failure();
-
-    // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
-    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
-        extractSlicesOp.vector().getDefiningOp());
-    if (!insertSlicesOp)
-      return failure();
-
-    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
-    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
-        insertSlicesOp.vectors().getDefiningOp());
-    if (!tupleOp)
-      return failure();
-
-    // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
-    Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
-    rewriter.replaceOp(tupleGetOp, tupleValue);
-    return success();
+    if (auto producer = getProducerValue(tupleGetOp.getResult())) {
+      rewriter.replaceOp(tupleGetOp, producer);
+      return success();
+    }
+    return failure();
   }
 };
 
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -28,10 +28,10 @@
 
 using llvm::SetVector;
 
-namespace mlir {
+using namespace mlir;
 
-SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
-                                       ArrayRef<int64_t> sizes) {
+SmallVector<int64_t, 4> mlir::computeStrides(ArrayRef<int64_t> shape,
+                                             ArrayRef<int64_t> sizes) {
   int64_t rank = shape.size();
   // Compute the count for each dimension.
   SmallVector<int64_t, 4> sliceDimCounts(rank);
@@ -45,8 +45,16 @@
   return sliceStrides;
 }
 
-SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
-                                    int64_t index) {
+int64_t mlir::linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
+SmallVector<int64_t, 4> mlir::delinearize(ArrayRef<int64_t> sliceStrides,
+                                          int64_t index) {
   int64_t rank = sliceStrides.size();
   SmallVector<int64_t, 4> vectorOffsets(rank);
   for (int64_t r = 0; r < rank; ++r) {
@@ -57,16 +65,15 @@
   return vectorOffsets;
 }
 
-SmallVector<int64_t, 4>
-computeElementOffsetsFromVectorSliceOffsets(ArrayRef<int64_t> sizes,
-                                            ArrayRef<int64_t> vectorOffsets) {
+SmallVector<int64_t, 4> mlir::computeElementOffsetsFromVectorSliceOffsets(
+    ArrayRef<int64_t> sizes, ArrayRef<int64_t> vectorOffsets) {
   return functional::zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                             vectorOffsets, sizes);
 }
 
-SmallVector<int64_t, 4> computeSliceSizes(ArrayRef<int64_t> shape,
-                                          ArrayRef<int64_t> sizes,
-                                          ArrayRef<int64_t> elementOffsets) {
+SmallVector<int64_t, 4>
+mlir::computeSliceSizes(ArrayRef<int64_t> shape, ArrayRef<int64_t> sizes,
+                        ArrayRef<int64_t> elementOffsets) {
   int64_t rank = shape.size();
   SmallVector<int64_t, 4> sliceSizes(rank);
   for (unsigned r = 0; r < rank; ++r)
@@ -74,8 +81,8 @@
   return sliceSizes;
 }
 
-Optional<SmallVector<int64_t, 4>> shapeRatio(ArrayRef<int64_t> superShape,
-                                             ArrayRef<int64_t> subShape) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(ArrayRef<int64_t> superShape,
+                                                   ArrayRef<int64_t> subShape) {
   if (superShape.size() < subShape.size()) {
     return Optional<SmallVector<int64_t, 4>>();
   }
@@ -114,8 +121,8 @@
   return SmallVector<int64_t, 4>{result.rbegin(), result.rend()};
 }
 
-Optional<SmallVector<int64_t, 4>> shapeRatio(VectorType superVectorType,
-                                             VectorType subVectorType) {
+Optional<SmallVector<int64_t, 4>> mlir::shapeRatio(VectorType superVectorType,
+                                                   VectorType subVectorType) {
   assert(superVectorType.getElementType() == subVectorType.getElementType() &&
          "vector types must be of the same elemental type");
   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
@@ -201,9 +208,9 @@
   return getParentsOfType<AffineForOp>(op);
 }
 
-AffineMap
-makePermutationMap(Operation *op, ArrayRef<Value> indices,
-                   const DenseMap<Operation *, unsigned> &loopToVectorDim) {
+AffineMap mlir::makePermutationMap(
+    Operation *op, ArrayRef<Value> indices,
+    const DenseMap<Operation *, unsigned> &loopToVectorDim) {
   DenseMap<Operation *, unsigned> enclosingLoopToVectorDim;
   auto enclosingLoops = getEnclosingforOps(op);
   for (auto *forInst : enclosingLoops) {
@@ -212,7 +219,7 @@
       enclosingLoopToVectorDim.insert(*it);
     }
   }
-  return makePermutationMap(indices, enclosingLoopToVectorDim);
+  return ::makePermutationMap(indices, enclosingLoopToVectorDim);
 }
 
 bool matcher::operatesOnSuperVectorsOf(Operation &op,
@@ -275,4 +282,3 @@
   return true;
 }
 
-} // namespace mlir
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -313,6 +313,95 @@
   return %1 : vector<8xf32>
 }
 
+// CHECK-LABEL: func @tuple_get_producer_consumer
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+// CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
+  %5 = vector.tuple_get %4, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %5
+  return %5 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A1:.*1]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A2:.*2]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A3:.*3]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A4:.*4]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A5:.*5]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A6:.*6]]: vector<2x4xf32>,
+// CHECK-SAME: %[[A7:.*7]]: vector<2x4xf32>
+// CHECK: return %[[A7]] : vector<2x4xf32>
+
+func @tuple_get_producer_consumer_swizzle(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>,
+  %arg4 : vector<2x4xf32>, %arg5 : vector<2x4xf32>,
+  %arg6 : vector<2x4xf32>, %arg7 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+      vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg7 == %0 at tupleIndex = 7, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>,
+            vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+      into vector<4x16xf32>
+  // %arg7 == %1 at tupleIndex = -1, offsets = [2, 12]
+  %2 = vector.extract_slices %1, [4, 8], [1, 1]
+    : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+
+  // Extract tuple elements.
+  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+
+  // Swizzle tuple elements.
+  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
+  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
+  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+    : vector<4x8xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
+  %8 = vector.tuple_get %7, 3
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg7 == %8
+  return %8 : vector<2x4xf32>
+}
+
 // CHECK-LABEL: func @vector_transfers_vector_element_type
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index