diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1199,38 +1199,43 @@
   // a. Get the first max ranked shape.
   VectorType firstMaxRankedType;
   for (Value operand : op->getOperands()) {
-    auto vecType = dyn_cast<VectorType>(bvm.lookup(operand).getType());
+    auto vecOperand = bvm.lookup(operand);
+    assert(vecOperand && "Vector operand couldn't be found");
+
+    auto vecType = dyn_cast<VectorType>(vecOperand.getType());
     if (vecType && (!firstMaxRankedType ||
                     firstMaxRankedType.getRank() < vecType.getRank()))
       firstMaxRankedType = vecType;
   }
   // b. Broadcast each op if needed.
-  SmallVector<Value> vectorizedOperands;
+  SmallVector<Value> vecOperands;
   for (Value scalarOperand : op->getOperands()) {
-    Value vectorizedOperand = bvm.lookup(scalarOperand);
-    auto vecType =
-        VectorType::get(firstMaxRankedType.getShape(),
-                        getElementTypeOrSelf(vectorizedOperand.getType()),
-                        firstMaxRankedType.getNumScalableDims());
-    vectorizedOperands.push_back(
-        !firstMaxRankedType
-            ? vectorizedOperand
-            : broadcastIfNeeded(rewriter, vectorizedOperand, vecType));
+    Value vecOperand = bvm.lookup(scalarOperand);
+    assert(vecOperand && "Vector operand couldn't be found");
+
+    if (firstMaxRankedType) {
+      auto vecType = VectorType::get(firstMaxRankedType.getShape(),
+                                     getElementTypeOrSelf(vecOperand.getType()),
+                                     firstMaxRankedType.getNumScalableDims());
+      vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
+    } else {
+      vecOperands.push_back(vecOperand);
+    }
   }
   // c. for elementwise, the result is the vector with the firstMaxRankedShape
   SmallVector<Type> resultTypes;
   for (Type resultType : op->getResultTypes()) {
     resultTypes.push_back(
-        !firstMaxRankedType
-            ? resultType
-            : VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getNumScalableDims()));
+        firstMaxRankedType
+            ? VectorType::get(firstMaxRankedType.getShape(), resultType,
+                              firstMaxRankedType.getNumScalableDims())
+            : resultType);
   }
   // d. Build and return the new op.
   return VectorizationResult{
       VectorizationStatus::NewOp,
-      rewriter.create(op->getLoc(), op->getName().getIdentifier(),
-                      vectorizedOperands, resultTypes, op->getAttrs())};
+      rewriter.create(op->getLoc(), op->getName().getIdentifier(), vecOperands,
+                      resultTypes, op->getAttrs())};
 }
 
 /// Generic vectorization function that rewrites the body of a `linalgOp` into
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1719,3 +1719,35 @@
   %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op
   %2 = transform.structured.vectorize %1 { vectorize_padding } : (!transform.any_op) -> !transform.any_op
 }
+
+// -----
+
+func.func @zero_dim_tensor(%input: tensor<f32>, %output: tensor<f32>) -> tensor<f32>
+{
+  %0 = linalg.generic { indexing_maps = [ affine_map<() -> ()>, affine_map<() -> ()> ],
+                        iterator_types = [] }
+                        ins(%input : tensor<f32>)
+                        outs(%output : tensor<f32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %2 = arith.addf %arg0, %arg1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %4 = get_closest_isolated_parent %3 : (!transform.any_op) -> !transform.any_op
+  %5 = transform.structured.vectorize %4 : (!transform.any_op) -> !transform.any_op
+}
+
+// CHECK-LABEL: func @zero_dim_tensor
+// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+// CHECK: vector.extractelement
+// CHECK: vector.transfer_read {{.*}} : tensor<f32>, vector<f32>
+// CHECK: vector.extractelement
+// CHECK: arith.addf {{.*}} : f32
+// CHECK: vector.broadcast %{{.*}} : f32 to vector<f32>
+// CHECK: vector.transfer_write {{.*}} : vector<f32>, tensor<f32>
+