diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -137,6 +137,10 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Patterns that remove redundant vector broadcasts. +void populateSinkVectorBroadcastPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + /// Populate `patterns` with the following patterns. /// /// [DecomposeDifferentRankInsertStridedSlice] diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3066,6 +3066,8 @@ if (!getDisableMultiReductionToContractPatterns()) vector::populateVectorReductionToContractPatterns(patterns); + vector::populateSinkVectorBroadcastPatterns(patterns); + patterns.add(ctx, /*benefit=*/2); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -885,6 +885,66 @@ std::function controlFn; }; +/// Reorders elementwise(broadcast) to broadcast(elementwise). 
Ex: +/// ``` +/// %a = vector.broadcast %arg1 : index to vector<1x4xindex> +/// %b = vector.broadcast %arg2 : index to vector<1x4xindex> +/// %r = arith.addi %a, %b : vector<1x4xindex> +/// ``` +/// Gets converted to: +/// ``` +/// %r = arith.addi %arg1, %arg2 : index +/// %b = vector.broadcast %r : index to vector<1x4xindex> +/// ``` +struct ReorderElementwiseOpsOnBroadcast final + : public OpTraitRewritePattern { + using OpTraitRewritePattern::OpTraitRewritePattern; + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getNumResults() != 1) + return failure(); + if (!llvm::isa(op->getResults()[0].getType())) + return failure(); + if (!OpTrait::hasElementwiseMappableTraits(op)) + return failure(); + + // Get the type of the first operand + auto firstBcast = op->getOperand(0).getDefiningOp(); + if (!firstBcast) + return failure(); + auto firstOpType = firstBcast.getOperand().getType(); + + // Make sure that operands are "broadcast"ed from identical (scalar or + // vector) types. That indicates that it's safe to skip the broadcasting of + // operands. 
+ if (!llvm::all_of(op->getOperands(), [&firstOpType](Value val) { + auto bcast = val.getDefiningOp(); + return (bcast && (bcast.getOperand().getType() == firstOpType)); + })) { + return failure(); + } + + // Collect the source values + SmallVector srcValues; + srcValues.reserve(op->getNumOperands()); + + for (Value operand : op->getOperands()) { + srcValues.push_back( + operand.getDefiningOp().getOperand()); + } + + Operation *elementwiseOp = + rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, + firstOpType, op->getAttrs()); + + auto vectorType = op->getResultTypes()[0]; + rewriter.replaceOpWithNewOp( + op, vectorType, elementwiseOp->getResults()); + + return success(); + } +}; + // Helper that returns a vector comparison that constructs a mask: // mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b] // @@ -1311,6 +1371,12 @@ patterns.add(patterns.getContext(), benefit); } +void mlir::vector::populateSinkVectorBroadcastPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} + //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir --- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir +++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir @@ -130,27 +130,29 @@ return %25 : tensor<1x4xf32> } -// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex + +// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex( // CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>, -// CHECK-SAME: {{.*}}: index, +// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, // CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> { // CHECK: 
%[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> // CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32 // CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_9:.*]] = arith.constant 0 : index // CHECK: %[[VAL_10:.*]] = arith.constant 79 : index -// CHECK: %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex> -// CHECK: %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex> -// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex> -// CHECK: %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex> -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex> -// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex> -// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex> -// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex> -// CHECK: %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex> -// CHECK: %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32> -// CHECK: %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> +// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index +// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex> +// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex> +// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex> +// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex> +// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex> +// CHECK: 
%[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex> +// CHECK: %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex> +// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex> +// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32> +// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32> +// CHECK: return %[[VAL_21]] : tensor<1x4xf32> +// CHECK: } transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): @@ -317,43 +319,16 @@ } // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x20xi32>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<257x24xf32>, -// CHECK-SAME: %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index) -> tensor<1x1x4xf32> { -// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0> : vector<1x1x4xindex> -// CHECK-DAG: %[[VAL_7:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> -// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : i32 -// CHECK-DAG: %[[VAL_9:.*]] = arith.constant dense<256> : vector<1x1x4xindex> -// CHECK-DAG: %[[VAL_10:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[VAL_11:.*]] = arith.constant 0 : index -// CHECK: %[[VAL_12:.*]] = tensor.empty() : tensor<1x1x4xf32> -// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_2]] : index to vector<1x1x4xindex> -// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_4]] : index to vector<1x1x4xindex> -// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_13]], %[[VAL_14]] : vector<1x1x4xindex> -// CHECK: %[[VAL_16:.*]] = vector.broadcast %[[VAL_3]] : index to vector<1x1x4xindex> -// CHECK: %[[VAL_17:.*]] = vector.broadcast %[[VAL_7]] : 
vector<4xindex> to vector<1x1x4xindex> -// CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_16]], %[[VAL_17]] : vector<1x1x4xindex> -// CHECK: %[[VAL_19:.*]] = vector.broadcast %[[VAL_5]] : index to vector<1x1x4xindex> -// CHECK: %[[VAL_20:.*]] = arith.addi %[[VAL_18]], %[[VAL_19]] : vector<1x1x4xindex> -// CHECK: %[[VAL_21:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_22:.*]] = vector.extractelement %[[VAL_21]][%[[VAL_8]] : i32] : vector<4xindex> +// CHECK-SAME: %[[INPUT_1:.*]]: tensor<1x20xi32>, +// CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>, +// CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index +// CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex> // First `tensor.extract` from the generic Op - loop invariant scalar load. -// CHECK: %[[VAL_23:.*]] = tensor.extract %[[VAL_0]][%[[VAL_11]], %[[VAL_22]]] : tensor<1x20xi32> -// CHECK: %[[VAL_24:.*]] = arith.index_cast %[[VAL_23]] : i32 to index -// CHECK: %[[VAL_25:.*]] = vector.broadcast %[[VAL_24]] : index to vector<1x1x4xindex> -// CHECK: %[[VAL_26:.*]] = arith.maxsi %[[VAL_25]], %[[VAL_6]] : vector<1x1x4xindex> -// CHECK: %[[VAL_27:.*]] = arith.minsi %[[VAL_26]], %[[VAL_9]] : vector<1x1x4xindex> -// CHECK: %[[VAL_28:.*]] = vector.shape_cast %[[VAL_27]] : vector<1x1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_29:.*]] = vector.extractelement %[[VAL_28]][%[[VAL_8]] : i32] : vector<4xindex> -// CHECK: %[[VAL_30:.*]] = vector.shape_cast %[[VAL_20]] : vector<1x1x4xindex> to vector<4xindex> -// CHECK: %[[VAL_31:.*]] = vector.extractelement %[[VAL_30]][%[[VAL_8]] : i32] : vector<4xindex> +// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32> // The following `tensor.extract` from the generic Op s a contiguous load (all Ops used // for address calculation also satisfy the required conditions). 
-// CHECK: %[[VAL_32:.*]] = vector.transfer_read %[[VAL_1]][%[[VAL_29]], %[[VAL_31]]], %[[VAL_10]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32> -// CHECK: %[[VAL_33:.*]] = vector.broadcast %[[VAL_32]] : vector<1x4xf32> to vector<1x1x4xf32> -// CHECK: %[[VAL_34:.*]] = vector.transfer_write %[[VAL_33]], %[[VAL_12]][%[[VAL_11]], %[[VAL_11]], %[[VAL_11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32> -// CHECK: return %[[VAL_34]] : tensor<1x1x4xf32> -// CHECK: } +// CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}, %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32> + transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): diff --git a/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/sink-vector-broadcast.mlir @@ -0,0 +1,78 @@ +// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @broadcast_scalar( +// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> { +// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex> +// CHECK: return %[[BCAST]] : vector<1x4xindex> +// CHECK: } + +func.func @broadcast_scalar( %arg1: index, %arg2: index) -> vector<1x4xindex> { + %0 = vector.broadcast %arg1 : index to vector<1x4xindex> + %1 = vector.broadcast %arg2 : index to vector<1x4xindex> + %2 = arith.addi %0, %1 : vector<1x4xindex> + return %2 : vector<1x4xindex> +} + +// ----- + +// CHECK-LABEL: func.func @broadcast_vector( +// CHECK-SAME: %[[ARG_0:.*]]: vector<4xf32>, +// CHECK-SAME: %[[ARG_1:.*]]: vector<4xf32>) -> vector<3x4xf32> { +// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<4xf32> +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32> +// 
CHECK: return %[[BCAST]] : vector<3x4xf32> +// CHECK: } + +func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> { + %arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32> + %arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32> + %2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32> + return %2 : vector<3x4xf32> +} +// ----- + +// CHECK-LABEL: func.func @broadcast_vector_and_scalar( +// CHECK-SAME: %[[ARG_0:.*]]: i32, +// CHECK-SAME: %[[ARG_1:.*]]: vector<4xi32>) -> vector<4xi32> { +// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<4xi32> +// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32> +// CHECK: return %[[ADD]] : vector<4xi32> +// CHECK: } + +func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> { + %arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32> + %2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32> + return %2 : vector<4xi32> +} + +// ----- + +#matmat_accesses = [ + affine_map<(i, j, k) -> (i, k)>, + affine_map<(i, j, k) -> (k, j)>, + affine_map<(i, j, k) -> (i, j)> +] +#matmat_trait = { + indexing_maps = #matmat_accesses, + iterator_types = ["parallel", "parallel", "reduction"] +} + +// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> { +// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32> +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32> +// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +func.func @broadcast_not_elementwise() -> vector<2x2xf32> { + %f1 = arith.constant 1.0: f32 + %f2 = arith.constant 2.0: f32 + %f3 = 
arith.constant 3.0: f32 + + %A = vector.broadcast %f1 : f32 to vector<2x2xf32> + %B = vector.broadcast %f2 : f32 to vector<2x2xf32> + %C = vector.broadcast %f3 : f32 to vector<2x2xf32> + %mm1 = vector.contract #matmat_trait %A, %B, %C + : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> + + return %mm1 : vector<2x2xf32> +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -374,6 +374,31 @@ } }; +struct TestSinkVectorBroadcast + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestSinkVectorBroadcast) + + TestSinkVectorBroadcast() = default; + TestSinkVectorBroadcast(const TestSinkVectorBroadcast &pass) = default; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + StringRef getArgument() const final { return "test-sink-vector-broadcast"; } + + StringRef getDescription() const final { + return "Test lowering patterns that eliminate redundant broadcast " + "operations."; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateSinkVectorBroadcastPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct TestVectorReduceToContractPatternsPatterns + : public PassWrapper> { @@ -735,6 +760,8 @@ PassRegistration(); + PassRegistration(); + + PassRegistration(); PassRegistration();