Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -106,17 +106,6 @@ StringRef dbgPref = "\n[" DEBUG_TYPE "]: "; (void)dbgPref; edsc::ScopedContext scope(builder, op->getLoc()); - if (auto fillOp = dyn_cast(op)) { - // Vectorize fill as a vector.broadcast. - LLVM_DEBUG(dbgs() << dbgPref - << "Rewrite linalg.fill as vector.broadcast: " << *op); - Value memref = vector_type_cast(fillOp.getOutputBuffer(0)); - Value dst = std_load(memref); - Value res = vector_broadcast(dst.getType(), fillOp.value()); - std_store(res, memref); - return; - } - // In the case of 0-D memrefs, return null and special case to scalar load or // store later. auto extractVectorTypeFromScalarView = [](Value v) { @@ -125,7 +114,22 @@ ? VectorType() : VectorType::get(mt.getShape(), mt.getElementType()); }; - + if (auto fillOp = dyn_cast(op)) { + // Vectorize fill as a vector.broadcast. + LLVM_DEBUG(dbgs() << dbgPref + << "Rewrite linalg.fill as vector.broadcast: " << *op); + Value viewOutput = fillOp.output(); + if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) { + // outputType already matches the output view's shape and element type. + Value vector = vector_broadcast(outputType, fillOp.value()); + Value zero = std_constant_index(0); + SmallVector indicesOutput(outputType.getRank(), zero); + vector_transfer_write(vector, viewOutput, indicesOutput); + } else { + std_store(fillOp.value(), viewOutput); + } + return; + } if (auto copyOp = dyn_cast(op)) { // Vectorize copy as a vector.transfer_read+vector.transfer_write. 
LLVM_DEBUG(dbgs() << dbgPref Index: mlir/test/Dialect/Linalg/transform-patterns.mlir =================================================================== --- mlir/test/Dialect/Linalg/transform-patterns.mlir +++ mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -157,7 +157,16 @@ return } // CHECK-LABEL: func @test_vectorize_fill -// CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> +// CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32> +// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + +func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { + linalg.fill(%A, %arg0) { __internal_linalg_transform__ = "VECTORIZE"} : memref, f32 + return +} +// CHECK-LABEL: func @test_vectorize_fill_scalar +// CHECK-SAME: (%[[M:.*]]: memref, %[[V:.*]]: f32) +// CHECK: store %[[V]], %[[M]][] : memref func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32>