diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -3175,7 +3175,7 @@ } }; -struct CastAwayBrodcastLeadingOneDim +struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -3200,6 +3200,52 @@ } }; +/// Rewrites a single-result op with elementwise-mappable traits whose vector +/// result type is changed by `trimLeadingOneDims`: every vector operand is +/// shape_cast to the trimmed shape (keeping the operand's own element type), +/// the op is re-created with the same name and attributes on the trimmed +/// result type, and the new result is shape_cast back to the original type. +/// Non-vector operands (e.g. a scalar `select` condition) are forwarded as-is. +class CastAwayElementwiseLeadingOneDim : public RewritePattern { +public: + CastAwayElementwiseLeadingOneDim(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1) + return failure(); + auto vecType = op->getResultTypes()[0].dyn_cast(); + if (!vecType) + return failure(); + VectorType newVecType = trimLeadingOneDims(vecType); + if (newVecType == vecType) + return failure(); + + SmallVector newOperands; + for (Value operand : op->getOperands()) { + if (auto opVecType = operand.getType().dyn_cast()) { + // Operands may have a different element type than the result (e.g. + // the f32 operands of a `cmpf` producing i1), so only the shape is shared. + auto newType = + VectorType::get(newVecType.getShape(), opVecType.getElementType()); + newOperands.push_back(rewriter.create( + op->getLoc(), newType, operand)); + } else { + newOperands.push_back(operand); + } + } + OperationState state(op->getLoc(), op->getName()); + state.addAttributes(op->getAttrs()); + state.addOperands(newOperands); + state.addTypes(newVecType); + Operation *newOp = rewriter.createOperation(state); + rewriter.replaceOpWithNewOp(op, vecType, + newOp->getResult(0)); + return success(); + } +}; + // Returns the values in `arrayAttr` as an integer vector.
static SmallVector getIntValueVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>( @@ -3795,12 +3833,13 @@ void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns( RewritePatternSet &patterns) { - patterns.add( - patterns.getContext()); + patterns + .add( + patterns.getContext()); } void mlir::vector::populateBubbleVectorBitCastOpPatterns( diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-transforms.mlir @@ -689,6 +689,35 @@ return %0, %1, %2: vector<1x1x8xf32>, vector<1x1x4xf32>, vector<1x3x4xf32> } +// CHECK-LABEL: func @cast_away_elementwise_leading_one_dims +func @cast_away_elementwise_leading_one_dims( + %arg0: vector<1x1x8xf32>, %arg1: f32, %arg2: vector<1x4xf32>, + %arg3: vector<1x4xf32>, %arg4: i1) -> + (vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32>) { + // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<1x1x8xf32> to vector<8xf32> + // CHECK: addf %{{.*}}, %{{.*}} : vector<8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<8xf32> to vector<1x1x8xf32> + %0 = addf %arg0, %arg0 : vector<1x1x8xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: cmpf ogt, %{{.*}}, %{{.*}} : vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<4xi1> to vector<1x4xi1> + %1 = cmpf ogt, %arg2, %arg3 : vector<1x4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4xf32> + // CHECK: select %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi1>, vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32> + %2 = select %1, %arg3, %arg2 : vector<1x4xi1>, vector<1x4xf32> + // The scalar i1 condition is not a vector, so it is forwarded unchanged and only the two vector operands are shape_cast. + // CHECK: %[[TRUE_VAL:.+]] = vector.shape_cast %arg3 :
vector<1x4xf32> to vector<4xf32> + // CHECK: %[[FALSE_VAL:.+]] = vector.shape_cast %arg2 : vector<1x4xf32> to vector<4xf32> + // CHECK: select %arg4, %[[TRUE_VAL]], %[[FALSE_VAL]] : vector<4xf32> + // CHECK: vector.shape_cast %{{.*}} : vector<4xf32> to vector<1x4xf32> + %3 = select %arg4, %arg3, %arg2 : vector<1x4xf32> + return %0, %1, %2, %3: vector<1x1x8xf32>, vector<1x4xi1>, vector<1x4xf32>, vector<1x4xf32> +} + // CHECK-LABEL: func @bubble_down_bitcast_in_extract // CHECK-SAME: %[[SRC:.+]]: vector<4xf32> func @bubble_down_bitcast_in_extract(%src: vector<4xf32>) -> (f16, f16) {