diff --git a/mlir/include/mlir/Dialect/Vector/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
--- a/mlir/include/mlir/Dialect/Vector/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorUtils.h
@@ -31,6 +31,9 @@
 SmallVector<int64_t, 4> computeStrides(ArrayRef<int64_t> shape,
                                        ArrayRef<int64_t> sizes);
 
+/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis);
+
 /// Given the slice strides together with a linear index in the dimension
 /// space, returns the vector-space offsets in each dimension for a
 /// de-linearized index.
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -69,15 +69,6 @@
   return res;
 }
 
-/// Computes and returns the linearized index of 'offsets' w.r.t. 'basis'.
-static int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
-  assert(offsets.size() == basis.size());
-  int64_t linearIndex = 0;
-  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
-    linearIndex += offsets[idx] * basis[idx];
-  return linearIndex;
-}
-
 // Clones `op` into a new operations that takes `operands` and returns
 // `resultTypes`.
 static Operation *cloneOpWithOperandsAndTypes(PatternRewriter &builder,
@@ -683,6 +674,102 @@
   }
 };
 
+// Returns the producer Value of the same type as 'consumerValue', by tracking
+// the tuple index and offsets of the consumer vector value through the
+// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
+// from consumer to producer. Each operation in the chain is structured, and
+// so the tuple index and offsets can be mapped from result to input, while
+// visiting each operation in the chain. A producer is only returned when it
+// is reached with all 'offsets' equal to zero. Returns nullptr on failure.
+static Value getProducerValue(Value consumerValue) {
+  auto consumerVectorType = consumerValue.getType().cast<VectorType>();
+  // A tupleIndex == -1 indicates that 'offsets' are w.r.t a vector type.
+  int64_t tupleIndex = -1;
+  SmallVector<int64_t, 4> offsets(consumerVectorType.getRank(), 0);
+  auto *op = consumerValue.getDefiningOp();
+  while (op != nullptr) {
+    if (auto tupleGetOp = dyn_cast<vector::TupleGetOp>(op)) {
+      assert(tupleIndex == -1);
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = tupleGetOp.getIndex();
+      op = tupleGetOp.vectors().getDefiningOp();
+    } else if (auto extractSlicesOp = dyn_cast<vector::ExtractSlicesOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Compute slice strides for 'extractSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      extractSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          extractSlicesOp.getSourceVectorType().getShape(), sizes);
+
+      // Compute 'elementOffsets' into 'extractSlicesOp' input vector type,
+      // of 'extractSlicesOp' result vector tuple element at 'tupleIndex'.
+      auto vectorOffsets = delinearize(sliceStrides, tupleIndex);
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+
+      // Add 'elementOffsets' to 'offsets' so that 'offsets' are now relative
+      // to the 'extractSlicesOp' input vector type.
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        offsets[i] += elementOffsets[i];
+      }
+
+      // Clear 'tupleIndex' and update next defining 'op' to visit.
+      tupleIndex = -1;
+      op = extractSlicesOp.vector().getDefiningOp();
+    } else if (auto insertSlicesOp = dyn_cast<vector::InsertSlicesOp>(op)) {
+      assert(tupleIndex == -1);
+
+      // Compute slice strides for 'insertSlicesOp'.
+      SmallVector<int64_t, 4> sizes;
+      insertSlicesOp.getSizes(sizes);
+      auto sliceStrides = computeStrides(
+          insertSlicesOp.getResultVectorType().getShape(), sizes);
+
+      // Compute 'vectorOffsets' of 'insertSlicesOp' input vector slice,
+      // of 'insertSlicesOp' result vector type at 'offsets'.
+      SmallVector<int64_t, 4> vectorOffsets(offsets.size());
+      assert(offsets.size() == sizes.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i)
+        vectorOffsets[i] = offsets[i] / sizes[i];
+
+      // Compute the source tuple element index.
+      tupleIndex = linearize(vectorOffsets, sliceStrides);
+
+      // Subtract 'elementOffsets' from 'offsets' so that 'offsets' are now
+      // relative to input tuple element vector type at 'tupleIndex'.
+      auto elementOffsets =
+          computeElementOffsetsFromVectorSliceOffsets(sizes, vectorOffsets);
+      assert(offsets.size() == elementOffsets.size());
+      for (unsigned i = 0, e = offsets.size(); i < e; ++i) {
+        offsets[i] -= elementOffsets[i];
+        assert(offsets[i] >= 0);
+      }
+
+      // Update next defining 'op' to visit.
+      op = insertSlicesOp.vectors().getDefiningOp();
+    } else if (auto tupleOp = dyn_cast<vector::TupleOp>(op)) {
+      assert(tupleIndex >= 0);
+
+      // Return tuple element 'value' at 'tupleIndex' only on an exact match:
+      // same vector type and no residual 'offsets' into that element.
+      auto value = tupleOp.getOperand(tupleIndex);
+      if (value.getType() == consumerVectorType &&
+          llvm::all_of(offsets, [](int64_t offset) { return offset == 0; }))
+        return value;
+
+      // Update 'tupleIndex' and next defining 'op' to visit.
+      tupleIndex = -1;
+      op = value.getDefiningOp();
+    } else {
+      break;
+    }
+  }
+  return nullptr;
+}
+
 /// ShapeCastOpFolder folds cancelling ShapeCastOps away.
 //
 // Example:
@@ -740,28 +827,11 @@
 
   LogicalResult matchAndRewrite(vector::TupleGetOp tupleGetOp,
                                 PatternRewriter &rewriter) const override {
-    // Return if 'tupleGetOp.vectors' arg was not defined by ExtractSlicesOp.
-    auto extractSlicesOp = dyn_cast_or_null<vector::ExtractSlicesOp>(
-        tupleGetOp.vectors().getDefiningOp());
-    if (!extractSlicesOp)
-      return failure();
-
-    // Return if 'extractSlicesOp.vector' arg was not defined by InsertSlicesOp.
-    auto insertSlicesOp = dyn_cast_or_null<vector::InsertSlicesOp>(
-        extractSlicesOp.vector().getDefiningOp());
-    if (!insertSlicesOp)
-      return failure();
-
-    // Return if 'insertSlicesOp.vectors' arg was not defined by TupleOp.
-    auto tupleOp = dyn_cast_or_null<vector::TupleOp>(
-        insertSlicesOp.vectors().getDefiningOp());
-    if (!tupleOp)
-      return failure();
-
-    // Forward Value from 'tupleOp' at 'tupleGetOp.index'.
-    Value tupleValue = tupleOp.getOperand(tupleGetOp.getIndex());
-    rewriter.replaceOp(tupleGetOp, tupleValue);
-    return success();
+    if (auto producer = getProducerValue(tupleGetOp.getResult())) {
+      rewriter.replaceOp(tupleGetOp, producer);
+      return success();
+    }
+    return failure();
   }
 };
 
diff --git a/mlir/lib/Dialect/Vector/VectorUtils.cpp b/mlir/lib/Dialect/Vector/VectorUtils.cpp
--- a/mlir/lib/Dialect/Vector/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/VectorUtils.cpp
@@ -45,6 +45,14 @@
   return sliceStrides;
 }
 
+int64_t linearize(ArrayRef<int64_t> offsets, ArrayRef<int64_t> basis) {
+  assert(offsets.size() == basis.size());
+  int64_t linearIndex = 0;
+  for (unsigned idx = 0, e = basis.size(); idx < e; ++idx)
+    linearIndex += offsets[idx] * basis[idx];
+  return linearIndex;
+}
+
 SmallVector<int64_t, 4> delinearize(ArrayRef<int64_t> sliceStrides,
                                     int64_t index) {
   int64_t rank = sliceStrides.size();
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -313,6 +313,76 @@
   return %1 : vector<8xf32>
 }
 
+// CHECK-LABEL: func @tuple_get_producer_consumer
+// CHECK: return %arg2 : vector<2x4xf32>
+
+func @tuple_get_producer_consumer(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg2 == %0 at tupleIndex = 2, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+    into vector<8x4xf32>
+  // %arg2 == %1 at tupleIndex = -1, offsets = [4, 0]
+  // (NOTE: tupleIndex = -1 means result is vector type (not tuple))
+  %2 = vector.extract_slices %1, [4, 4], [1, 1]
+    : vector<8x4xf32> into tuple<vector<4x4xf32>, vector<4x4xf32>>
+  // %arg2 == %2 at tupleIndex = 1, offsets = [0, 0]
+  %3 = vector.insert_slices %2, [4, 4], [1, 1]
+    : tuple<vector<4x4xf32>, vector<4x4xf32>> into vector<8x4xf32>
+  // %arg2 == %3 at tupleIndex = -1, offsets = [4, 0]
+  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+    : vector<8x4xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg2 == %4 at tupleIndex = 2, offsets = [0, 0]
+  %5 = vector.tuple_get %4, 2
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg2 == %5
+  return %5 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @tuple_get_producer_consumer_tuple_element_swizzle
+// CHECK: return %arg2 : vector<2x4xf32>
+
+func @tuple_get_producer_consumer_tuple_element_swizzle(
+  %arg0 : vector<2x4xf32>, %arg1 : vector<2x4xf32>,
+  %arg2 : vector<2x4xf32>, %arg3 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.tuple %arg0, %arg1, %arg2, %arg3
+    : vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>
+  // %arg2 == %0 at tupleIndex = 2, offsets = [0, 0]
+  %1 = vector.insert_slices %0, [2, 4], [1, 1]
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+    into vector<8x4xf32>
+  // %arg2 == %1 at tupleIndex = -1, offsets = [4, 0]
+  // (NOTE: tupleIndex = -1 means result is vector type (not tuple))
+  %2 = vector.extract_slices %1, [4, 4], [1, 1]
+    : vector<8x4xf32> into tuple<vector<4x4xf32>, vector<4x4xf32>>
+  // %arg2 == %2 at tupleIndex = 1, offsets = [0, 0]
+
+  // Extract tuple elements.
+  %3 = vector.tuple_get %2, 0 : tuple<vector<4x4xf32>, vector<4x4xf32>>
+  %4 = vector.tuple_get %2, 1 : tuple<vector<4x4xf32>, vector<4x4xf32>>
+  // %arg2 == %4 at tupleIndex = -1, offsets = [0, 0]
+
+  // Swizzle tuple elements.
+  %5 = vector.tuple %4, %3 : vector<4x4xf32>, vector<4x4xf32>
+  // %arg2 == %5 at tupleIndex = 0, offsets = [0, 0]
+
+  %6 = vector.insert_slices %5, [4, 4], [1, 1]
+    : tuple<vector<4x4xf32>, vector<4x4xf32>> into vector<8x4xf32>
+  // %arg2 == %6 at tupleIndex = -1, offsets = [0, 0]
+  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+    : vector<8x4xf32> into
+      tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg2 == %7 at tupleIndex = 0, offsets = [0, 0]
+  %8 = vector.tuple_get %7, 0
+    : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
+  // %arg2 == %8
+  return %8 : vector<2x4xf32>
+}
+
 // CHECK-LABEL: func @vector_transfers_vector_element_type
 // CHECK: %[[C0:.*]] = constant 0 : index
 // CHECK: %[[C1:.*]] = constant 1 : index