diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -885,6 +885,58 @@
   std::function<bool(BitCastOp)> controlFn;
 };
 
+/// Reorders elementwise(broadcast) to broadcast(elementwise). Ex:
+/// ```
+/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
+/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
+/// %r = arith.addi %a, %b : vector<1x4xindex>
+/// ```
+/// Gets converted to:
+/// ```
+/// %r = arith.addi %arg1, %arg2 : index
+/// %b = vector.broadcast %r : index to vector<1x4xindex>
+/// ```
+struct ReorderElementwiseOpsOnBroadcast final
+    : public OpTraitRewritePattern<OpTrait::Elementwise> {
+  using OpTraitRewritePattern::OpTraitRewritePattern;
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return failure();
+    if (!llvm::isa<VectorType>(op->getResults()[0].getType()))
+      return failure();
+
+    // Make sure that operands are "broadcast"s from scalars
+    if (!llvm::all_of(op->getOperands(), [](Value val) {
+          return (
+              llvm::isa_and_present<vector::BroadcastOp>(val.getDefiningOp()) &&
+              !llvm::isa<VectorType>(val.getDefiningOp<vector::BroadcastOp>()
+                                         .getOperand()
+                                         .getType()));
+        })) {
+      return failure();
+    }
+
+    SmallVector<Value> srcValues;
+    srcValues.reserve(op->getNumOperands());
+
+    for (Value operand : op->getOperands()) {
+      srcValues.push_back(
+          operand.getDefiningOp<vector::BroadcastOp>().getOperand());
+    }
+
+    auto vectorType = op->getResultTypes()[0];
+    Operation *elementwiseOp = rewriter.create(
+        op->getLoc(), op->getName().getIdentifier(), srcValues,
+        cast<VectorType>(vectorType).getElementType(), op->getAttrs());
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+        op, vectorType, elementwiseOp->getResults());
+
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -1301,8 +1353,9 @@
 
 void mlir::vector::populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns,
                                                        PatternBenefit benefit) {
-  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose>(
-      patterns.getContext(), benefit);
+  patterns.add<ReorderCastOpsOnBroadcast, ReorderElementwiseOpsOnTranspose,
+               ReorderElementwiseOpsOnBroadcast>(patterns.getContext(),
+                                                 benefit);
 }
 
 void mlir::vector::
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -130,27 +130,29 @@
   return %25 : tensor<1x4xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME: {{.*}}: index,
+// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
 // CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
 // CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
 // CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32
 // CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
 // CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
-// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
-// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
+// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
+// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
+// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
+// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
+// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
+// CHECK: }
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
@@ -317,44 +319,44 @@
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
-// CHECK-SAME:  %[[VAL_0:.*]]: tensor<1x20xi32>,
-// CHECK-SAME:  %[[VAL_1:.*]]: tensor<257x24xf32>,
-// CHECK-SAME:  %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
-// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
-// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
-// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>,
+// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> {
+// CHECK: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
+// CHECK: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
+// CHECK: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_11:.*]] = arith.constant 0 : index
 // CHECK: %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_2]], %[[VAL_4]] : index
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_7]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_19:.*]] = arith.addi %[[VAL_17]], %[[VAL_18]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_14]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]][%[[VAL_8]] : i32] : vector<4xindex>
 // First `tensor.extract` from the generic Op - loop invariant scalar load.
-// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32>
-// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index
-// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex>
-// CHECK: %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex>
-// CHECK: %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_22:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_21]]] : tensor<1x20xi32>
+// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_22]] : i32 to index
+// CHECK: %[[VAL_24:.*]] = vector.broadcast %[[VAL_23]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL_25:.*]] = arith.maxsi %[[VAL_24]], %[[VAL_6]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_26:.*]] = arith.minsi %[[VAL_25]], %[[VAL_9]] : vector<1x1x4xindex>
+// CHECK: %[[VAL_27:.*]] = vector.shape_cast %[[VAL_26]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_28:.*]] = vector.extractelement %[[VAL_27]][%[[VAL_8]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_29:.*]] = vector.shape_cast %[[VAL_19]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_30:.*]] = vector.extractelement %[[VAL_29]][%[[VAL_8]] : i32] : vector<4xindex>
 // The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
 // for address calculation also satisfy the required conditions).
-// CHECK: %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32>
-// CHECK: %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
-// CHECK: return %[[VAL_34]] : tensor<1x1x4xf32>
+// CHECK: %[[VAL_31:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_28]], %[[VAL_30]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_32:.*]] = vector.broadcast %[[VAL_31]] : vector<1x4xf32> to vector<1x1x4xf32>
+// CHECK: %[[VAL_33:.*]] = vector.transfer_write %[[VAL_32]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
+// CHECK: return %[[VAL_33]] : tensor<1x1x4xf32>
 // CHECK: }
 
+
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !transform.any_op):
   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
--- a/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
+++ b/mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir
@@ -385,3 +385,19 @@
   %resT = vector.transpose %contract, [0, 2, 1] : vector<2x4x8xf32> to vector<2x8x4xf32>
   return %resT : vector<2x8x4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_elementwise(
+// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
+// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
+// CHECK: return %[[BCAST]] : vector<1x4xindex>
+// CHECK: }
+
+func.func @broadcast_elementwise( %arg1: index, %arg2: index) -> vector<1x4xindex> {
+  %0 = vector.broadcast %arg1 : index to vector<1x4xindex>
+  %1 = vector.broadcast %arg2 : index to vector<1x4xindex>
+  %2 = arith.addi %0, %1 : vector<1x4xindex>
+  return %2 : vector<1x4xindex>
+}