diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1009,6 +1009,103 @@
   }
 };
 
+/// Reorders cast(broadcast) to broadcast(cast). This makes broadcast ops and
+/// contraction ops closer, which kicks in CombineContractBroadcast pattern when
+/// casting ops are around these operations.
+/// Ex:
+/// ```
+/// %0 = vector.broadcast %arg0 : vector<32x16xi8> to vector<8x32x16xi8>
+/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %0 = arith.extsi %arg0 : vector<32x16xi8> to vector<32x16xi32>
+/// %1 = vector.broadcast %0 : vector<32x16xi32> to vector<8x32x16xi32>
+/// ```
+/// Only single-operand, single-result casts that keep the vector shape
+/// (i.e. elementwise casts) are rewritten.
+struct ReorderCastOpsOnBroadcast
+    : public OpInterfaceRewritePattern<CastOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1 || op->getNumResults() != 1)
+      return failure();
+    auto bcastOp = op->getOperand(0).getDefiningOp<vector::BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+    // Reject casts that are not elementwise (e.g. a shape-changing
+    // `builtin.unrealized_conversion_cast`): they cannot be moved across the
+    // broadcast without producing ill-formed IR.
+    auto inVecTy = op->getOperand(0).getType().cast<VectorType>();
+    auto outVecTy = op->getResult(0).getType().dyn_cast<VectorType>();
+    if (!outVecTy || outVecTy.getShape() != inVecTy.getShape())
+      return failure();
+
+    // Cast the broadcast source instead: keep its shape (scalar or vector)
+    // and take the element type of the original cast result.
+    Type castResTy = outVecTy.getElementType();
+    if (auto vecTy = bcastOp.getSourceType().dyn_cast<VectorType>())
+      castResTy = VectorType::get(vecTy.getShape(), castResTy);
+    OperationState state(op->getLoc(), op->getName(), bcastOp.source(),
+                         castResTy, op->getAttrs());
+    Operation *castOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, outVecTy, castOp->getResult(0));
+    return success();
+  }
+};
+
+/// Reorders cast(transpose) to transpose(cast). This makes transpose ops and
+/// contraction ops closer, which kicks in CombineContractTranspose pattern when
+/// casting ops are around these operations.
+/// Ex:
+/// ```
+/// %0 = vector.transpose %arg0, [2, 0, 1]
+///   : vector<32x16x8xi8> to vector<8x32x16xi8>
+/// %1 = arith.extsi %0 : vector<8x32x16xi8> to vector<8x32x16xi32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %0 = arith.extsi %arg0 : vector<32x16x8xi8> to vector<32x16x8xi32>
+/// %1 = vector.transpose %0, [2, 0, 1]
+///   : vector<32x16x8xi32> to vector<8x32x16xi32>
+/// ```
+/// Only single-operand, single-result casts that keep the vector shape
+/// (i.e. elementwise casts) are rewritten.
+struct ReorderCastOpsOnTranspose
+    : public OpInterfaceRewritePattern<CastOpInterface> {
+  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOpInterface op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumOperands() != 1 || op->getNumResults() != 1)
+      return failure();
+    auto transpOp = op->getOperand(0).getDefiningOp<vector::TransposeOp>();
+    if (!transpOp)
+      return failure();
+    // Reject casts that are not elementwise (e.g. a shape-changing
+    // `builtin.unrealized_conversion_cast`): they cannot be moved across the
+    // transpose without producing ill-formed IR.
+    auto inVecTy = op->getOperand(0).getType().cast<VectorType>();
+    auto outVecTy = op->getResult(0).getType().dyn_cast<VectorType>();
+    if (!outVecTy || outVecTy.getShape() != inVecTy.getShape())
+      return failure();
+
+    // Cast the transpose input instead: keep its (pre-transpose) shape and
+    // take the element type of the original cast result.
+    auto castResTy = VectorType::get(transpOp.getVectorType().getShape(),
+                                     outVecTy.getElementType());
+    OperationState state(op->getLoc(), op->getName(), transpOp.vector(),
+                         castResTy, op->getAttrs());
+    Operation *castOp = rewriter.createOperation(state);
+    rewriter.replaceOpWithNewOp<vector::TransposeOp>(
+        op, outVecTy, castOp->getResult(0), transpOp.getTransp());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using
@@ -2585,7 +2682,8 @@
 void mlir::vector::populateVectorReductionToContractPatterns(
     RewritePatternSet &patterns) {
   patterns.add<MultiReduceToContract, CombineContractBroadcast,
-               CombineContractTranspose>(patterns.getContext());
+               CombineContractTranspose, ReorderCastOpsOnBroadcast,
+               ReorderCastOpsOnTranspose>(patterns.getContext());
 }
 
 void mlir::vector::
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -85,3 +85,38 @@
     kind = #vector.kind<add>} %0, %arg1, %cst : vector<8x32x16xf32>, vector<8x32x16xf32> into vector<8x32xf32>
   return %1 : vector<8x32xf32>
 }
+
+//===----------------------------------------------------------------------===//
+// Reorder casting ops and vector ops. All cast ops go through the same
+// CastOpInterface-based patterns, so only arith.extsi is tested.
+//===----------------------------------------------------------------------===//
+
+// -----
+// CHECK-LABEL: func @broadcast_vector_extsi
+func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4xi8> to vector<4xi32>
+  // CHECK: vector.broadcast %[[EXT]] : vector<4xi32> to vector<2x4xi32>
+  %b = vector.broadcast %a : vector<4xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+// CHECK-LABEL: func @broadcast_scalar_extsi
+func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
+  // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
+  %b = vector.broadcast %a : i8 to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}
+
+// -----
+// CHECK-LABEL: func @transpose_extsi
+func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
+  // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
+  // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
+  %b = vector.transpose %a, [1, 0] : vector<4x2xi8> to vector<2x4xi8>
+  %r = arith.extsi %b : vector<2x4xi8> to vector<2x4xi32>
+  return %r : vector<2x4xi32>
+}