diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -676,10 +676,10 @@
 
 /// Returns the producer Value of the same type as 'consumerValue', by tracking
 /// the tuple index and offsets of the consumer vector value through the
-/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp)
-/// from consumer to producer. Each operation in the chain is structured, and
-/// so the tuple index and offsets can be mapped from result to input, while
-/// visiting each operation in the chain.
+/// chain of operations (TupleGetOp, InsertSlicesOp, ExtractSlicesOp, TupleOp,
+/// and ShapeCastOp) from consumer to producer. Each operation in the chain is
+/// structured, and so the tuple index and offsets can be mapped from result to
+/// input, while visiting each operation in the chain.
 /// Returns nullptr on failure.
 static Value getProducerValue(Value consumerValue) {
   auto consumerVectorType = consumerValue.getType().cast<VectorType>();
@@ -760,8 +760,57 @@
       // Update 'tupleIndex' and next defining 'op' to visit.
       tupleIndex = -1;
       op = value.getDefiningOp();
+    } else if (auto shapeCastOp = dyn_cast<vector::ShapeCastOp>(op)) {
+      if (shapeCastOp.source().getType().isa<TupleType>())
+        return nullptr;
+      assert(tupleIndex == -1);
+      auto sourceVectorType = shapeCastOp.getSourceVectorType();
+      auto sourceVectorShape = sourceVectorType.getShape();
+      unsigned sourceVectorRank = sourceVectorType.getRank();
+      auto resultVectorType = shapeCastOp.getResultVectorType();
+      auto resultVectorShape = resultVectorType.getShape();
+      unsigned resultVectorRank = resultVectorType.getRank();
+
+      int i = sourceVectorRank - 1;
+      int j = resultVectorRank - 1;
+
+      // Check that the trailing dimensions of the source/result vector shapes
+      // match (every zipped pair must agree) while updating 'newOffsets'.
+      bool canShapeCastFold = true;
+      SmallVector<int64_t, 4> newOffsets(sourceVectorRank, 0);
+
+      auto apply = [&](int64_t sourceSize, int64_t resultSize) {
+        canShapeCastFold &= sourceSize == resultSize;
+        newOffsets[i--] = offsets[j--];
+      };
+      functional::zipApply(apply, llvm::reverse(sourceVectorShape),
+                           llvm::reverse(resultVectorShape));
+      if (!canShapeCastFold)
+        return nullptr;
+
+      // Check that the remaining leading dims (indices [0, i] of the source or
+      // [0, j] of the result) are all 1s: producer/consumer tracking is only
+      // supported through trivial shape cast ops. Examples:
+      //   %1 = vector.shape_cast %0 : vector<1x1x2x4xf32> to vector<2x4xf32>
+      //   %3 = vector.shape_cast %2 : vector<16x8xf32> to vector<1x16x8xf32>
+      assert(i == -1 || j == -1);
+      if (i >= 0 && !std::all_of(sourceVectorShape.begin(),
+                                 sourceVectorShape.begin() + i + 1,
+                                 [](int64_t v) { return v == 1; }))
+        return nullptr;
+      if (j >= 0 && !std::all_of(resultVectorShape.begin(),
+                                 resultVectorShape.begin() + j + 1,
+                                 [](int64_t v) { return v == 1; }))
+        return nullptr;
+
+      offsets.swap(newOffsets);
+      op = shapeCastOp.source().getDefiningOp();
     } else {
-      break;
+      // Check if 'op' produces a Value with the same type as 'consumerValue'.
+      if (op->getNumResults() == 1 &&
+          op->getResult(0).getType() == consumerVectorType)
+        return op->getResult(0);
+      return nullptr;
     }
   }
   return nullptr;
@@ -788,6 +837,12 @@
 
   LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
                                 PatternRewriter &rewriter) const override {
+    // Check if we can replace 'shapeCastOp' result with its producer.
+    if (auto producer = getProducerValue(shapeCastOp.getResult())) {
+      rewriter.replaceOp(shapeCastOp, producer);
+      return success();
+    }
+
     // Check if 'shapeCastOp' has vector source/result type.
     auto sourceVectorType =
         shapeCastOp.source().getType().dyn_cast_or_null<VectorType>();
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -341,16 +341,21 @@
   %2 = vector.extract_slices %1, [4, 8], [1, 1]
     : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
   // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
-  %3 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %3 at tupleIndex = -1, offsets = [2, 4]
-  %4 = vector.extract_slices %3, [2, 4], [1, 1]
+  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
+  %4 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %4 at tupleIndex = -1, offsets = [0, 0, 2, 4]
+  %5 = vector.shape_cast %4 : vector<1x1x4x8xf32> to vector<4x8xf32>
+  // %arg7 == %5 at tupleIndex = -1, offsets = [2, 4]
+  %6 = vector.extract_slices %5, [2, 4], [1, 1]
     : vector<4x8xf32> into
       tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %4 at tupleIndex = 3, offsets = [0, 0]
-  %5 = vector.tuple_get %4, 3
+  // %arg7 == %6 at tupleIndex = 3, offsets = [0, 0]
+  %7 = vector.tuple_get %6, 3
     : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %5
-  return %5 : vector<2x4xf32>
+  // %arg7 == %7
+  return %7 : vector<2x4xf32>
 }
 
 // CHECK-LABEL: func @tuple_get_producer_consumer_swizzle
@@ -381,25 +386,40 @@
   %2 = vector.extract_slices %1, [4, 8], [1, 1]
     : vector<4x16xf32> into tuple<vector<4x8xf32>, vector<4x8xf32>>
   // %arg7 == %2 at tupleIndex = 1, offsets = [2, 4]
+  %3 = vector.shape_cast %2 : tuple<vector<4x8xf32>, vector<4x8xf32>> to
+                              tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %3 at tupleIndex = 1, offsets = [0, 0, 2, 4]
 
   // Extract tuple elements.
-  %3 = vector.tuple_get %2, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  %4 = vector.tuple_get %2, 1 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %4 at tupleIndex = -1, offsets = [2, 4]
+  %4 = vector.tuple_get %3, 0 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  %5 = vector.tuple_get %3, 1 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>>
+  // %arg7 == %5 at tupleIndex = -1, offsets = [0, 0, 2, 4]
 
   // Swizzle tuple elements.
-  %5 = vector.tuple %4, %3 : vector<4x8xf32>, vector<4x8xf32>
-  // %arg7 == %5 at tupleIndex = 0, offsets = [2, 4]
-  %6 = vector.tuple_get %5, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
-  // %arg7 == %6 at tupleIndex = -1, offsets = [2, 4]
-  %7 = vector.extract_slices %6, [2, 4], [1, 1]
+  %6 = vector.tuple %5, %4 : vector<1x1x4x8xf32>, vector<1x1x4x8xf32>
+  // %arg7 == %6 at tupleIndex = 0, offsets = [0, 0, 2, 4]
+  %7 = vector.shape_cast %6 : tuple<vector<1x1x4x8xf32>, vector<1x1x4x8xf32>> to
+                              tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %7 at tupleIndex = 0, offsets = [2, 4]
+  %8 = vector.tuple_get %7, 0 : tuple<vector<4x8xf32>, vector<4x8xf32>>
+  // %arg7 == %8 at tupleIndex = -1, offsets = [2, 4]
+  %9 = vector.extract_slices %8, [2, 4], [1, 1]
     : vector<4x8xf32> into
       tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %7 at tupleIndex = 3, offsets = [0, 0]
-  %8 = vector.tuple_get %7, 3
+  // %arg7 == %9 at tupleIndex = 3, offsets = [0, 0]
+  %10 = vector.tuple_get %9, 3
     : tuple<vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>, vector<2x4xf32>>
-  // %arg7 == %8
-  return %8 : vector<2x4xf32>
+  // %arg7 == %10
+  return %10 : vector<2x4xf32>
+}
+
+// CHECK-LABEL: func @cancelling_shape_cast_ops
+// CHECK-SAME: %[[A0:.*0]]: vector<2x4xf32>
+// CHECK: return %[[A0]] : vector<2x4xf32>
+func @cancelling_shape_cast_ops(%arg0 : vector<2x4xf32>) -> vector<2x4xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x4xf32> to vector<8xf32>
+  %1 = vector.shape_cast %0 : vector<8xf32> to vector<2x4xf32>
+  return %1 : vector<2x4xf32>
 }
 
 // CHECK-LABEL: func @vector_transfers_vector_element_type