diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1010,7 +1010,7 @@
   }
 };
 
-/// Merge TransposeOp into ContractionOp user.
+/// Merge LHS/RHS (A/B) TransposeOp into ContractionOp user.
 /// Ex:
 /// ```
 ///   %0 = vector.transpose %arg0, [2, 0, 1]
@@ -1033,7 +1033,7 @@
 ///    kind = add} %arg0, %arg1, %cst_f0
 ///    : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
 /// ```
-struct CombineContractTranspose
+struct CombineContractABTranspose final
     : public OpRewritePattern<vector::ContractionOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -1050,8 +1050,6 @@
       auto transposeOp = operand->getDefiningOp<vector::TransposeOp>();
       if (!transposeOp)
         continue;
-      SmallVector<int64_t> perm;
-      transposeOp.getTransp(perm);
       AffineMap permutationMap = AffineMap::getPermutationMap(
           extractVector<unsigned>(transposeOp.getTransp()),
           contractOp.getContext());
@@ -1068,6 +1066,81 @@
   }
 };
 
+/// Merges accumulator and result transposes into contract.
+///
+/// For example:
+/// ```mlir
+/// %accT = vector.transpose %acc, [0, 2, 1]
+///   : vector<2x8x4xf32> to vector<2x4x8xf32>
+/// %contract = vector.contract {
+///   indexing_maps = [
+///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+///   ],
+///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+///   kind = #vector.kind<add>
+/// } %lhs, %rhs, %accT
+///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+/// %0 = vector.transpose %contract, [0, 2, 1]
+///   : vector<2x4x8xf32> to vector<2x8x4xf32>
+/// ```
+/// Becomes:
+/// ```mlir
+/// %0 = vector.contract {
+///   indexing_maps = [
+///     affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+///     affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+///     affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+///   ],
+///   iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+///   kind = #vector.kind<add>
+/// } %lhs, %rhs, %acc
+///   : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x8x4xf32>
+/// ```
+struct CombineContractResultTranspose final
+    : public OpRewritePattern<vector::TransposeOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
+                                PatternRewriter &rewriter) const override {
+    auto contractOp = resTOp.getVector().getDefiningOp<vector::ContractionOp>();
+    if (!contractOp || !contractOp->hasOneUse())
+      return failure();
+
+    auto accTOp = contractOp.getAcc().getDefiningOp<vector::TransposeOp>();
+    if (!accTOp)
+      return failure();
+
+    MLIRContext *context = contractOp.getContext();
+    auto maps = llvm::to_vector<3>(contractOp.getIndexingMapsArray());
+    AffineMap contractMap = maps.back();
+
+    // Accumulator transpose performs f(A) -> B. Contract performs g(C) -> B.
+    // To index into A in contract, we need inverse(f)(g(C)) -> A.
+    auto accTMap = AffineMap::getPermutationMap(
+        extractVector<unsigned>(accTOp.getTransp()), context);
+
+    // Contract performs g(C) -> D. Result transpose performs h(D) -> E.
+    // To index into E in contract, we need h(g(C)) -> E.
+    auto resTMap = AffineMap::getPermutationMap(
+        extractVector<unsigned>(resTOp.getTransp()), context);
+
+    // The accumulator and result share one indexing map, so the merge is only
+    // valid if the result transpose undoes the accumulator transpose exactly.
+    if (inversePermutation(accTMap) != resTMap)
+      return failure();
+    maps.back() = resTMap.compose(contractMap);
+
+    // Forward the combining kind; the builder would otherwise default to add.
+    rewriter.replaceOpWithNewOp<vector::ContractionOp>(
+        resTOp, contractOp.getLhs(), contractOp.getRhs(), accTOp.getVector(),
+        rewriter.getAffineMapArrayAttr(maps), contractOp.getIteratorTypes(),
+        contractOp.getKind());
+    return success();
+  }
+};
+
 /// Merge BroadcastOp into ContractionOp user.
 /// Ex:
 /// ```
@@ -1233,7 +1306,7 @@
 
 /// Reorders elementwise(transpose) to transpose(elementwise). This makes
 /// transpose ops and contraction ops closer, which kicks in
-/// CombineContractTranspose pattern when elementwise ops are between these
+/// CombineContractABTranspose pattern when elementwise ops are between these
 /// operations. Ex:
 /// ```
 /// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32>
@@ -2939,9 +3012,9 @@
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractTranspose, ReorderCastOpsOnBroadcast,
-               ReorderElementwiseOpsOnTranspose>(patterns.getContext(),
-                                                 benefit);
+               CombineContractABTranspose, CombineContractResultTranspose,
+               ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
+      patterns.getContext(), benefit);
 }
 
 void mlir::vector::
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -356,3 +356,56 @@
   %r = arith.addf %at, %bt : vector<6x4x2x3xf32>
   return %r : vector<6x4x2x3xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[$LHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>
+// CHECK-DAG: #[[$RHS_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
+// CHECK-DAG: #[[$ACC_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
+
+// CHECK-LABEL: func.func @contract_result_transpose
+// CHECK-SAME: (%[[LHS:.+]]: vector<2x4x4xf32>, %[[RHS:.+]]: vector<4x8xf32>, %[[ACC:.+]]: vector<2x8x4xf32>)
+// CHECK: %[[CONTRACT:.+]] = vector.contract
+// CHECK-SAME: indexing_maps = [#[[$LHS_MAP]], #[[$RHS_MAP]], #[[$ACC_MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
+// CHECK-SAME: kind = #vector.kind<add>
+// CHECK-SAME: %[[LHS]], %[[RHS]], %[[ACC]]
+// CHECK: return %[[CONTRACT]]
+func.func @contract_result_transpose(%lhs : vector<2x4x4xf32>, %rhs: vector<4x8xf32>, %acc: vector<2x8x4xf32>) -> vector<2x8x4xf32> {
+  %accT = vector.transpose %acc, [0, 2, 1] : vector<2x8x4xf32> to vector<2x4x8xf32>
+  %contract = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+      affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %lhs, %rhs, %accT : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+  %resT = vector.transpose %contract, [0, 2, 1] : vector<2x4x8xf32> to vector<2x8x4xf32>
+  return %resT : vector<2x8x4xf32>
+}
+
+// -----
+
+// The result transpose is not the inverse of the accumulator transpose, so
+// the transposes cannot be folded into the contraction.
+
+// CHECK-LABEL: func.func @negative_contract_result_transpose
+// CHECK: vector.transpose %{{.+}}, [0, 2, 1]
+// CHECK: vector.contract
+// CHECK: vector.transpose %{{.+}}, [1, 0, 2]
+func.func @negative_contract_result_transpose(%lhs : vector<2x4x4xf32>, %rhs: vector<4x8xf32>, %acc: vector<2x8x4xf32>) -> vector<4x2x8xf32> {
+  %accT = vector.transpose %acc, [0, 2, 1] : vector<2x8x4xf32> to vector<2x4x8xf32>
+  %contract = vector.contract {
+    indexing_maps = [
+      affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>,
+      affine_map<(d0, d1, d2, d3) -> (d3, d2)>,
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+    ],
+    iterator_types = ["parallel", "parallel", "parallel", "reduction"],
+    kind = #vector.kind<add>
+  } %lhs, %rhs, %accT : vector<2x4x4xf32>, vector<4x8xf32> into vector<2x4x8xf32>
+  %resT = vector.transpose %contract, [1, 0, 2] : vector<2x4x8xf32> to vector<4x2x8xf32>
+  return %resT : vector<4x2x8xf32>
+}