diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -168,7 +168,8 @@ for (Type outputTensorType : linalgOp.getOutputTensorTypes()) if (!outputTensorType.cast().hasStaticShape()) return failure(); - if (isa(op) || isa(op)) + if (isa(op) || isa(op) || + isa(op)) return success(); auto genericOp = dyn_cast(op); @@ -210,6 +211,17 @@ auto dstVec = std_load(dstMemrefVec); auto resVec = vector_broadcast(dstVec, fillOp.value()); std_store(resVec, dstMemrefVec); + } else if (auto copyOp = dyn_cast(op)) { + // Vectorize copy as a vector.broadcast of the loaded source vector. + LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE + "]: Rewrite linalg.copy as vector.broadcast: " + << *op << ":\n"); + auto dstMemrefVec = vector_type_cast(copyOp.getOutputBuffer(0)); + auto dstVec = std_load(dstMemrefVec); + auto srcMemrefVec = vector_type_cast(copyOp.getInput(0)); + auto srcVec = std_load(srcMemrefVec); + auto resVec = vector_broadcast(dstVec, srcVec); + std_store(resVec, dstMemrefVec); } else { // Vectorize other ops as vector contraction (currently only matmul). 
LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -212,6 +212,14 @@ // CHECK-LABEL: func @test_vectorize_fill // CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> +func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"}: memref<8x16xf32>, + memref<8x16xf32> + return +} +// CHECK-LABEL: func @test_vectorize_copy +// CHECK: vector.broadcast {{.*}} : vector<8x16xf32> to vector<8x16xf32> + func @fma(%a: f32, %b: f32, %c: f32) -> f32 { %d = mulf %a, %b: f32 %e = addf %c, %d: f32 diff --git a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td --- a/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td +++ b/mlir/test/lib/DeclarativeTransforms/TestLinalgTransformPatterns.td @@ -111,6 +111,12 @@ HasLinalgTransformMarker<"VECTORIZE">, PreconditionVectorizeLinalgOp ]>>)]>; +def : Pattern<(CopyOp:$op $_, $_, $_, $_), + [(VectorizeLinalgOp)], + [(Constraint, + PreconditionVectorizeLinalgOp + ]>>)]>; def : Pattern<(GenericOp:$op $_, $_, $_, $_, $_, $_, $_, $_), [(VectorizeLinalgOp)], [(Constraint