Index: mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -88,7 +88,7 @@
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
 
-  if (isa<linalg::FillOp>(op))
+  if (isa<linalg::FillOp, linalg::CopyOp>(op))
     return success();
 
   return isContraction(op);
@@ -111,12 +111,6 @@
     return;
   }
 
-  assert(succeeded(isContraction(op)) && "Expected contraction");
-
-  // Vectorize other ops as vector contraction.
-  // TODO: interface.
-  LLVM_DEBUG(dbgs() << dbgPref
-                    << "Rewrite linalg op as vector.contract: " << *op);
   // In the case of 0-D memrefs, return null and special case to scalar load or
   // store later.
   auto extractVectorTypeFromScalarView = [](Value v) {
@@ -125,6 +119,51 @@
                ? VectorType()
                : VectorType::get(mt.getShape(), mt.getElementType());
   };
+
+  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.copy as vector.transfer_read + "
+                         "vector.transfer_write: "
+                      << *op);
+    Value zero = std_constant_index(0);
+    Value viewInput = copyOp.input();
+    Value viewOutput = copyOp.output();
+    // NOTE(review): when a permutation is present, the vector type used below
+    // is still the unpermuted memref shape; confirm this is intended for
+    // non-square views.
+    Value vector;
+    if (VectorType inputType = extractVectorTypeFromScalarView(viewInput)) {
+      SmallVector<Value, 4> indicesInput(inputType.getRank(), zero);
+      if (copyOp.inputPermutation())
+        vector = vector_transfer_read(inputType, viewInput, indicesInput,
+                                      copyOp.inputPermutation().getValue());
+      else
+        vector = vector_transfer_read(inputType, viewInput, indicesInput);
+    } else {
+      // 0-D memref: no vector type exists, so fall back to a scalar load.
+      vector = std_load(viewInput).value;
+    }
+    if (VectorType outputType = extractVectorTypeFromScalarView(viewOutput)) {
+      SmallVector<Value, 4> indicesOutput(outputType.getRank(), zero);
+      if (copyOp.outputPermutation())
+        vector_transfer_write(vector, viewOutput, indicesOutput,
+                              copyOp.outputPermutation().getValue());
+      else
+        vector_transfer_write(vector, viewOutput, indicesOutput);
+    } else {
+      // 0-D memref: fall back to a scalar store.
+      std_store(vector, viewOutput);
+    }
+    return;
+  }
+
+  assert(succeeded(isContraction(op)) && "Expected contraction");
+
+  // Vectorize other ops as vector contraction.
+  // TODO: interface.
+  LLVM_DEBUG(dbgs() << dbgPref
+                    << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value viewA = linalgOp.getInput(0);
   Value viewB = linalgOp.getInput(1);
Index: mlir/test/Dialect/Linalg/transform-patterns.mlir
===================================================================
--- mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -135,6 +135,23 @@
 // CHECK-LABEL: func @test_vectorize_fill
 // CHECK: vector.broadcast {{.*}} : f32 to vector<8x16xf32>
 
+func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<8x16xf32>, memref<8x16xf32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy
+// CHECK: %[[V:.*]] = vector.transfer_read {{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_write %[[V]], {{.*}} : vector<8x16xf32>, memref<8x16xf32>
+
+func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+  linalg.copy(%A, %B) { __internal_linalg_transform__ = "VECTORIZE"} : memref<f32>, memref<f32>
+  return
+}
+// CHECK-LABEL: func @test_vectorize_copy_scalar
+// CHECK: %[[V:.*]] = load {{.*}} : memref<f32>
+// CHECK: store %[[V]], {{.*}} : memref<f32>
+
+
 #matmul_accesses = [
   affine_map<(m, n, k) -> (m, k)>,
   affine_map<(m, n, k) -> (k, n)>,
Index: mlir/test/lib/Transforms/TestLinalgTransforms.cpp
===================================================================
--- mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -144,6 +144,7 @@
   //===--------------------------------------------------------------------===//
   patterns.insert<LinalgVectorizationPattern<MatmulOp>,
                   LinalgVectorizationPattern<FillOp>,
+                  LinalgVectorizationPattern<CopyOp>,
                   LinalgVectorizationPattern<GenericOp>>(
       ctx, LinalgMarker(Identifier::get("VECTORIZE", ctx)));
 