diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -1346,11 +1346,25 @@
   }
 };
 
+// Fold broadcast1(broadcast2(x)) into broadcast1(x).
+struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcBroadcast = broadcastOp.source().getDefiningOp<BroadcastOp>();
+    if (!srcBroadcast)
+      return failure();
+    rewriter.replaceOpWithNewOp<BroadcastOp>(
+        broadcastOp, broadcastOp.getVectorType(), srcBroadcast.source());
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
-  results.add<BroadcastToShapeCast>(context);
+  results.add<BroadcastToShapeCast, BroadcastFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2510,32 +2510,39 @@
       return failure();
     if (read.mask())
       return failure();
-    Operation *loadOp;
-    if (!broadcastedDims.empty() &&
-        unbroadcastedVectorType.getNumElements() == 1) {
-      // If broadcasting is required and the number of loaded elements is 1 then
-      // we can create `memref.load` instead of `vector.load`.
-      loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
-                                               read.indices());
-    } else {
-      // Otherwise create `vector.load`.
-      loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
-                                               unbroadcastedVectorType,
-                                               read.source(), read.indices());
-    }
+    auto loadOp = rewriter.create<vector::LoadOp>(
+        read.getLoc(), unbroadcastedVectorType, read.source(), read.indices());
     // Insert a broadcasting op if required.
if (!broadcastedDims.empty()) { rewriter.replaceOpWithNewOp( - read, read.getVectorType(), loadOp->getResult(0)); + read, read.getVectorType(), loadOp.result()); } else { - rewriter.replaceOp(read, loadOp->getResult(0)); + rewriter.replaceOp(read, loadOp.result()); } return success(); } }; +/// Replace a scalar vector.load with a memref.load. +struct VectorLoadToMemrefLoadLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto vecType = loadOp.getVectorType(); + if (vecType.getNumElements() != 1) + return failure(); + auto memrefLoad = rewriter.create( + loadOp.getLoc(), loadOp.base(), loadOp.indices()); + rewriter.replaceOpWithNewOp( + loadOp, VectorType::get({1}, vecType.getElementType()), memrefLoad); + return success(); + } +}; + /// Progressive lowering of transfer_write. This pattern supports lowering of /// `vector.transfer_write` to `vector.store` if all of the following hold: /// - The op writes to a memref with the default layout. 
@@ -3674,8 +3681,9 @@
 
 void mlir::vector::populateVectorTransferLoweringPatterns(
     RewritePatternSet &patterns) {
-  patterns.add<TransferReadToVectorLoadLowering,
-               TransferWriteToVectorStoreLowering>(patterns.getContext());
+  patterns
+      .add<TransferReadToVectorLoadLowering, TransferWriteToVectorStoreLowering,
+           VectorLoadToMemrefLoadLowering>(patterns.getContext());
   populateVectorTransferPermutationMapLoweringPatterns(patterns);
 }
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -613,6 +613,18 @@
 
 // -----
 
+// CHECK-LABEL: @fold_consecutive_broadcasts(
+// CHECK-SAME: %[[ARG0:.*]]: i32
+// CHECK: %[[RESULT:.*]] = vector.broadcast %[[ARG0]] : i32 to vector<4x16xi32>
+// CHECK: return %[[RESULT]]
+func @fold_consecutive_broadcasts(%a : i32) -> vector<4x16xi32> {
+  %1 = vector.broadcast %a : i32 to vector<16xi32>
+  %2 = vector.broadcast %1 : vector<16xi32> to vector<4x16xi32>
+  return %2 : vector<4x16xi32>
+}
+
+// -----
+
 // CHECK-LABEL: shape_cast_constant
 // CHECK-DAG: %[[CST1:.*]] = constant dense<1> : vector<3x4x2xi32>
 // CHECK-DAG: %[[CST0:.*]] = constant dense<2.000000e+00> : vector<20x2xf32>
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL: func @transfer_to_load(
@@ -174,6 +174,21 @@
 
 // -----
 
+// CHECK-LABEL: func @transfer_scalar(
+// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
+// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT: return %[[RES]] : vector<1xf32>
+// CHECK-NEXT: }
+func @transfer_scalar(%mem : memref<?x?xf32>, %i : index) -> vector<1xf32> {
+  %cf0 = constant 0.0 : f32
+  %res = vector.transfer_read %mem[%i, %i], %cf0 {in_bounds = [true]} : memref<?x?xf32>, vector<1xf32>
+  return %res : vector<1xf32>
+}
+
+// -----
+
 // An example with two broadcasted dimensions.
 // CHECK-LABEL: func @transfer_broadcasting_2D(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,