Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -88,7 +88,7 @@ if (!outputTensorType.cast().hasStaticShape()) return failure(); - if (isa(op)) + if (isa(op)) return success(); return isContraction(op); @@ -111,12 +111,6 @@ return; } - assert(succeeded(isContraction(op)) && "Expected contraction"); - - // Vectorize other ops as vector contraction. - // TODO: interface. - LLVM_DEBUG(dbgs() << dbgPref - << "Rewrite linalg op as vector.contract: " << *op); // In the case of 0-D memrefs, return null and special case to scalar load or // store later. auto extractVectorTypeFromScalarView = [](Value v) { @@ -125,6 +119,42 @@ ? VectorType() : VectorType::get(mt.getShape(), mt.getElementType()); }; + + if (auto copyOp = dyn_cast(op)) { + // Vectorize copy as a vector.transfer_read+vector.transfer_write. + LLVM_DEBUG(dbgs() << dbgPref + << "Rewrite linalg.copy as vector.transfer_read + " + "vector.transfer_write: " + << *op); + Value zero = std_constant_index(0); + SmallVector indicesInput(copyOp.getInputShapedType(0).getRank(), + zero); + SmallVector indicesOutput(copyOp.getOutputShapedType(0).getRank(), + zero); + Value viewInput = copyOp.input(); + Value viewOutput = copyOp.output(); + Value vector; + if (copyOp.inputPermutation()) + vector = vector_transfer_read(extractVectorTypeFromScalarView(viewInput), + viewInput, indicesInput, + copyOp.inputPermutation().getValue()); + else + vector = vector_transfer_read(extractVectorTypeFromScalarView(viewInput), + viewInput, indicesInput); + if (copyOp.outputPermutation()) + vector_transfer_write(vector, viewOutput, indicesOutput, + copyOp.outputPermutation().getValue()); + else + vector_transfer_write(vector, viewOutput, indicesOutput); + return; + } + + assert(succeeded(isContraction(op)) && "Expected contraction"); + + // Vectorize other ops as vector contraction. + // TODO: interface. + LLVM_DEBUG(dbgs() << dbgPref + << "Rewrite linalg op as vector.contract: " << *op); auto linalgOp = cast(op); Value viewA = linalgOp.getInput(0); Value viewB = linalgOp.getInput(1); Index: mlir/test/Dialect/Linalg/transform-patterns.mlir =================================================================== --- mlir/test/Dialect/Linalg/transform-patterns.mlir +++ mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -135,6 +135,14 @@ // CHECK-LABEL: func @test_vectorize_fill // CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32> +func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { + linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32> + return +} +// CHECK-LABEL: func @test_vectorize_copy +// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32> +// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32> + #matmul_accesses = [ affine_map<(m, n, k) -> (m, k)>, affine_map<(m, n, k) -> (k, n)>, Index: mlir/test/lib/Transforms/TestLinalgTransforms.cpp =================================================================== --- mlir/test/lib/Transforms/TestLinalgTransforms.cpp +++ mlir/test/lib/Transforms/TestLinalgTransforms.cpp @@ -144,6 +144,7 @@ //===--------------------------------------------------------------------===// patterns.insert, LinalgVectorizationPattern, + LinalgVectorizationPattern, LinalgVectorizationPattern>( ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));