diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -981,6 +981,92 @@
   }
 };

+/// Reorders elementwise(broadcast) to broadcast(elementwise).
+///
+/// Matches a single-operand op of type `ElemOpTy` whose operand is produced by
+/// a vector.broadcast, re-creates the `ElemOpTy` op directly on the broadcast
+/// source, and broadcasts its result instead. For example:
+///
+///   %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+///   %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+///
+/// becomes:
+///
+///   %e = arith.extsi %a : vector<4xi8> to vector<4xi32>
+///   %r = vector.broadcast %e : vector<4xi32> to vector<2x4xi32>
+///
+/// A scalar broadcast source is handled too: the new `ElemOpTy` op is then
+/// created with the bare result element type.
+template <typename ElemOpTy>
+struct ReorderElemwiseOnBroadcast : public OpRewritePattern<ElemOpTy> {
+  using OpRewritePattern<ElemOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElemOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return failure();
+    auto bcastOp =
+        op->getOperand(0).template getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // The new elementwise op yields the original result element type, shaped
+    // like the broadcast source (a bare element type for a scalar source).
+    Type elemResTy = getElementTypeOrSelf(op.getResult());
+    if (auto srcVecTy = bcastOp.getSourceType().template dyn_cast<VectorType>())
+      elemResTy = VectorType::get(srcVecTy.getShape(), elemResTy);
+    auto elemOp =
+        rewriter.create<ElemOpTy>(op.getLoc(), elemResTy, bcastOp.source());
+
+    // The new broadcast has the same type as the op being replaced, so reuse
+    // that type instead of rebuilding it from shape and element type.
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, op.getResult().getType(), elemOp.getResult());
+    return success();
+  }
+};
+
+/// Reorders elementwise(transpose) to transpose(elementwise).
+///
+/// Matches a single-operand op of type `ElemOpTy` whose operand is produced by
+/// a vector.transpose, re-creates the `ElemOpTy` op on the untransposed vector,
+/// and applies the same permutation to its result. For example:
+///
+///   %t = vector.transpose %a, [1, 0] : vector<4x2xi8> to vector<2x4xi8>
+///   %r = arith.extsi %t : vector<2x4xi8> to vector<2x4xi32>
+///
+/// becomes:
+///
+///   %e = arith.extsi %a : vector<4x2xi8> to vector<4x2xi32>
+///   %r = vector.transpose %e, [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+template <typename ElemOpTy>
+struct ReorderElemwiseOnTranspose : public OpRewritePattern<ElemOpTy> {
+  using OpRewritePattern<ElemOpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ElemOpTy op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1)
+      return failure();
+    auto transpOp =
+        op->getOperand(0).template getDefiningOp<vector::TransposeOp>();
+    if (!transpOp)
+      return failure();
+
+    // Apply the elementwise op to the untransposed vector: same shape as the
+    // transpose input, element type of the original result.
+    auto elemResTy = VectorType::get(transpOp.getVectorType().getShape(),
+                                     getElementTypeOrSelf(op.getResult()));
+    auto elemOp =
+        rewriter.create<ElemOpTy>(op.getLoc(), elemResTy, transpOp.vector());
+
+    // The transposed result has exactly the type of the op being replaced, so
+    // reuse that type instead of rebuilding it.
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
+        op, op.getResult().getType(), elemOp.getResult(), transpOp.getTransp());
+    return success();
+  }
+};
+
 } // namespace

 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -2556,8 +2614,28 @@

 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
-  patterns.add(patterns.getContext());
+  // NOTE(review): the pattern list below lost its `<...>` template arguments
+  // in this copy of the change; restore them before applying.
+  patterns
+      .add,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnBroadcast,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose,
+           ReorderElemwiseOnTranspose>(patterns.getContext());
 }

 void mlir::vector::
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -85,3 +85,38 @@
     kind = #vector.kind} %0, %arg1, %cst : vector<8x32x16xf32>,
vector<8x32x16xf32> into vector<8x32xf32>
   return %1 : vector<8x32xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// Reorder casting ops and vector ops. The casting ops share an almost identical
+// pattern, so only arith.extsi is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+// CHECK-LABEL: func @broadcast_vector_extsi
+func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+// CHECK-LABEL: func @broadcast_scalar_extsi
+func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+  %b = vector.broadcast %a : i8 to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+// CHECK-LABEL: func @transpose_extsi
+func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
+  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+  %b = vector.transpose %a, [1, 0]: vector<4x2xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}