diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1048,7 +1048,7 @@ } }; -/// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and +/// Reorders cast(transpose) to transpose(cast). This makes transpose ops and /// contraction ops closer, which kicks in CombineContractTranspose pattern when /// casting ops are around these operations. /// Ex: @@ -1089,6 +1089,85 @@ } }; +/// Reorders elementwise(transpose) to transpose(elementwise). This makes +/// transpose ops and contraction ops closer, which kicks in +/// CombineContractTranspose pattern when elementwise ops are between these +/// operations. Ex: +/// ``` +/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> +/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> +/// %r = arith.addf %at, %bt : vector<2x4xf32> +/// ``` +/// Gets converted to: +/// ``` +/// %0 = arith.addf %a, %b : vector<4x2xf32> +/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32> +/// ``` +struct ReorderElementwiseOpsOnTranspose final + : public OpTraitRewritePattern { + using OpTraitRewritePattern::OpTraitRewritePattern; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumResults() != 1 || op->getNumRegions() != 0) + return failure(); + // Cast ops have their own pattern---ReorderCastOpsOnTranspose. + if (isa(op)) + return failure(); + + // Make sure all operands are transpose/constant ops and collect their + // transposition maps. 
+ SmallVector transposeMaps; + transposeMaps.reserve(op->getNumOperands()); + Type srcType; + for (Value operand : op->getOperands()) { + auto transposeOp = operand.getDefiningOp(); + if (transposeOp) { + transposeMaps.push_back(transposeOp.getTransp()); + srcType = transposeOp.getVectorType(); + } else if (!matchPattern(operand, m_Constant())) { + return failure(); + } + } + if (transposeMaps.empty()) + return failure(); + // This is an elementwise op, so all transposed operands should have the + // same type. We need to additionally check that all transposes use the + // same map. NOTE(review): srcType doubles as result type below; verify. + if (!llvm::is_splat(transposeMaps)) + return rewriter.notifyMatchFailure(op, "different transpose map"); + + SmallVector srcValues; + srcValues.reserve(op->getNumOperands()); + + // If there are constant operands, we need to insert inverse transposes for + // them. Calculate the inverse order first. + auto order = extractVector(transposeMaps.front()); + SmallVector invOrder(order.size()); + for (int i = 0, e = order.size(); i < e; ++i) + invOrder[order[i]] = i; + + for (Value operand : op->getOperands()) { + auto transposeOp = operand.getDefiningOp(); + if (transposeOp) { + srcValues.push_back(transposeOp.getVector()); + } else { + // This is a constant. Create a reverse transpose op for it. 
+ srcValues.push_back(rewriter.create( + operand.getLoc(), srcType, operand, + rewriter.getI64ArrayAttr(invOrder))); + } + } + + Operation *elementwiseOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, + srcType, op->getAttrs()); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes()[0], elementwiseOp->getResult(0), + transposeMaps.front()); + return success(); + } +}; + } // namespace /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using @@ -2647,7 +2726,8 @@ RewritePatternSet &patterns) { patterns.add(patterns.getContext()); + ReorderCastOpsOnTranspose, ReorderElementwiseOpsOnTranspose>( + patterns.getContext()); } void mlir::vector:: diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -120,3 +120,51 @@ %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> return %r : vector<2x4xi32> } + +//===----------------------------------------------------------------------===// +// Reorder elementwise ops and vector ops. 
+//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @transpose_add +// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0] +// CHECK: return %[[T]] + +func @transpose_add(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.addf %at, %bt : vector<2x4xf32> + return %r : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_add_splat_constant +// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>) +// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32> +// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32> +// CHECK: return %[[T]] : vector<6x4x2x3xf32> + +func @transpose_add_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> { + %b = arith.constant dense<5.0> : vector<6x4x2x3xf32> + %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> + %r = arith.addf %at, %b : vector<6x4x2x3xf32> + return %r : vector<6x4x2x3xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_add_diff_map +// CHECK: vector.transpose +// CHECK: vector.transpose +// CHECK: arith.addf +func @transpose_add_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> { + %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> + %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32> + %r = arith.addf %at, %bt : vector<6x4x2x3xf32> + return %r : vector<6x4x2x3xf32> +}