diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -120,14 +120,33 @@
   // Vectorize other ops as vector contraction (currently only matmul).
   LLVM_DEBUG(dbgs() << dbgPref << "Rewrite linalg op as vector.contract: "
                     << *op);
+  // Vector type matching a full read of the view: the memref shape becomes
+  // the vector shape (views are statically shaped here per the precondition).
+  auto extractVectorTypeFromScalarView = [](Value v) {
+    MemRefType mt = v.getType().cast<MemRefType>();
+    return VectorType::get(mt.getShape(), mt.getElementType());
+  };
   auto linalgOp = cast<LinalgOp>(op);
-  Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
-  Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
-  Value memref = vector_type_cast(linalgOp.getOutputBuffer(0));
-  Value c = std_load(memref);
+  Value viewA = linalgOp.getInput(0);
+  Value viewB = linalgOp.getInput(1);
+  Value viewC = linalgOp.getOutputBuffer(0);
+  // Read/write each view in full, starting at offset zero in every dimension.
+  Value zero = std_constant_index(0);
+  SmallVector<Value, 4> indicesA(linalgOp.getInputShapedType(0).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesB(linalgOp.getInputShapedType(1).getRank(),
+                                 zero);
+  SmallVector<Value, 4> indicesC(linalgOp.getOutputShapedType(0).getRank(),
+                                 zero);
+  Value a = vector_transfer_read(extractVectorTypeFromScalarView(viewA), viewA,
+                                 indicesA);
+  Value b = vector_transfer_read(extractVectorTypeFromScalarView(viewB), viewB,
+                                 indicesB);
+  Value c = vector_transfer_read(extractVectorTypeFromScalarView(viewC), viewC,
+                                 indicesC);
   Value res = vector_contract(a, b, c, linalgOp.indexing_maps(),
                               linalgOp.iterator_types());
-  std_store(res, memref);
+  vector_transfer_write(res, viewC, indicesC);
 }
 
 /// Check whether there is any interleaved use of any `values` between `firstOp`
diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir
--- a/mlir/test/Dialect/Linalg/transform-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -106,14 +106,11 @@
   return
 }
 // CHECK-LABEL: func @vectorization_test
-// CHECK: vector.type_cast %{{.*}} : memref<8x16xf32> to memref<vector<8x16xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x16xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<16x32xf32> to memref<vector<16x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<16x32xf32>>
-// CHECK: vector.type_cast %{{.*}} : memref<8x32xf32> to memref<vector<8x32xf32>>
-// CHECK: load %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x16xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<16x32xf32>
+// CHECK: vector.transfer_read %{{.*}} : memref<8x32xf32>, vector<8x32xf32>
 // CHECK: vector.contract {indexing_maps = [#[[mk]], #[[kn]], #[[mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32>
-// CHECK: store %{{.*}}, %{{.*}}[] : memref<vector<8x32xf32>>
+// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32>
 
 func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
                            %C: memref<8x32xf32>) {