diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -1048,43 +1049,86 @@ } }; -/// Reorders cast(transpose) to transpose(cast). This makes broadcast ops and -/// contraction ops closer, which kicks in CombineContractTranspose pattern when -/// casting ops are around these operations. -/// Ex: +/// Reorders elementwise(transpose) to transpose(elementwise). This makes +/// transpose ops and contraction ops closer, which kicks in +/// CombineContractTranspose pattern when elementwise ops are between these +/// operations. 
Ex: /// ``` -/// %0 = vector.transpose %arg0, [2, 0, 1] -/// : vector<32x16x8xi8> to vector<8x32x16xi8> -/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32> +/// %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> +/// %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> +/// %r = arith.addf %at, %bt : vector<2x4xf32> /// ``` /// Gets converted to: /// ``` -/// %0 = arith.extsi %0 : vector<32x16x8xi8> to vector<32x16x8xi32> -/// %1 = vector.transpose %arg0, [2, 0, 1] -/// : vector<32x16x8xi32> to vector<8x32x16xi32> +/// %0 = arith.addf %a, %b : vector<4x2xf32> +/// %r = vector.transpose %0, [1, 0] : vector<4x2xf32> to vector<2x4xf32> /// ``` -struct ReorderCastOpsOnTranspose - : public OpInterfaceRewritePattern { - - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - LogicalResult matchAndRewrite(CastOpInterface op, +struct ReorderElementwiseOpsOnTranspose final + : public OpTraitRewritePattern { + using OpTraitRewritePattern::OpTraitRewritePattern; + LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - if (op->getNumOperands() != 1) + if (op->getNumResults() != 1 || op->getNumRegions() != 0) return failure(); - auto transpOp = op->getOperand(0).getDefiningOp(); - if (!transpOp) + + // Make sure all operands are transpose/constant ops and collect their + // transposition maps. + SmallVector transposeMaps; + transposeMaps.reserve(op->getNumOperands()); + // Record the initial type before transposition. We'll use its shape later. + // Any type will do here as we will check all transpose maps are the same. 
+ VectorType srcType; + for (Value operand : op->getOperands()) { + auto transposeOp = operand.getDefiningOp(); + if (transposeOp) { + transposeMaps.push_back(transposeOp.getTransp()); + srcType = transposeOp.getVectorType(); + } else if (!matchPattern(operand, m_Constant())) { + return failure(); + } + } + if (transposeMaps.empty()) return failure(); + // This is an elementwise op, so all transposed operands should have the + // same type. We need to additionally check that all transposes use the + // same map. + if (!llvm::is_splat(transposeMaps)) + return rewriter.notifyMatchFailure(op, "different transpose map"); + + SmallVector srcValues; + srcValues.reserve(op->getNumOperands()); + + // If there are constant operands, we need to insert inverse transposes for + // them. Calculate the inverse order first. + auto order = extractVector(transposeMaps.front()); + SmallVector invOrder(order.size()); + for (int i = 0, e = order.size(); i < e; ++i) + invOrder[order[i]] = i; + + for (Value operand : op->getOperands()) { + auto transposeOp = operand.getDefiningOp(); + if (transposeOp) { + srcValues.push_back(transposeOp.getVector()); + } else { + // This is a constant. Create a reverse transpose op for it. NOTE(review): the operand type is cast below without a visible check that the constant is a shaped (vector) value - confirm scalar constant operands cannot reach this branch. 
+ auto vectorType = VectorType::get( + srcType.getShape(), + operand.getType().cast().getElementType()); + srcValues.push_back(rewriter.create( + operand.getLoc(), vectorType, operand, + rewriter.getI64ArrayAttr(invOrder))); + } + } - auto castResTy = transpOp.getVectorType(); - castResTy = VectorType::get(castResTy.getShape(), - getElementTypeOrSelf(op->getResult(0))); - auto *castOp = - rewriter.create(op->getLoc(), op->getName().getIdentifier(), - transpOp.getVector(), castResTy, op->getAttrs()); + auto vectorType = VectorType::get( + srcType.getShape(), + op->getResultTypes()[0].cast().getElementType()); + Operation *elementwiseOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, + vectorType, op->getAttrs()); rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType(), castOp->getResult(0), - transpOp.getTransp()); + op, op->getResultTypes()[0], elementwiseOp->getResult(0), + transposeMaps.front()); return success(); } }; @@ -2647,7 +2691,7 @@ RewritePatternSet &patterns) { patterns.add(patterns.getContext()); + ReorderElementwiseOpsOnTranspose>(patterns.getContext()); } void mlir::vector:: diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir --- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir +++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir @@ -120,3 +120,80 @@ %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32> return %r : vector<2x4xi32> } + +//===----------------------------------------------------------------------===// +// Reorder elementwise ops and vector ops. 
+//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_same_type +// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0] +// CHECK: return %[[T]] + +func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.addf %at, %bt : vector<2x4xf32> + return %r : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_diff_operand_types +// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<4x2xf32> to vector<2x4xf32> +// CHECK: return %[[T]] +func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a : vector<4x2xf32>, %b : vector<4x2xf32>) -> vector<2x4xf32> { + %condt = vector.transpose %cond, [1, 0]: vector<4x2xi1> to vector<2x4xi1> + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.select %condt, %at, %bt : vector<2x4xi1>, vector<2x4xf32> + return %r : vector<2x4xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type +// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>) +// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<4x2xi1> to vector<2x4xi1> +// CHECK: return %[[T]] +func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>, %b : 
 vector<4x2xf32>) -> vector<2x4xi1> { + %at = vector.transpose %a, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %bt = vector.transpose %b, [1, 0]: vector<4x2xf32> to vector<2x4xf32> + %r = arith.cmpf olt, %at, %bt : vector<2x4xf32> + return %r : vector<2x4xi1> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_splat_constant +// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>) +// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32> +// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x6x3x2xf32> +// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<4x6x3x2xf32> to vector<6x4x2x3xf32> +// CHECK: return %[[T]] : vector<6x4x2x3xf32> + +func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vector<6x4x2x3xf32> { + %b = arith.constant dense<5.0> : vector<6x4x2x3xf32> + %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> + %r = arith.addf %at, %b : vector<6x4x2x3xf32> + return %r : vector<6x4x2x3xf32> +} + +// ----- + +// CHECK-LABEL: func @transpose_elementwise_diff_map +// CHECK: vector.transpose +// CHECK: vector.transpose +// CHECK: arith.addf +func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6x2x4x3xf32>) -> vector<6x4x2x3xf32> { + %at = vector.transpose %a, [1, 0, 3, 2]: vector<4x6x3x2xf32> to vector<6x4x2x3xf32> + %bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x4x3xf32> to vector<6x4x2x3xf32> + %r = arith.addf %at, %bt : vector<6x4x2x3xf32> + return %r : vector<6x4x2x3xf32> +}