diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -52,9 +52,17 @@ auto pattern2 = m_Op(m_Op(c, m_Op(a, b))); auto pattern3 = m_Op(m_Op(m_Op(b, a), c)); auto pattern4 = m_Op(m_Op(c, m_Op(b, a))); + auto pattern5 = m_Op(m_Op(m_Op(a, b), c)); + auto pattern6 = m_Op(m_Op(c, m_Op(a, b))); + auto pattern7 = m_Op(m_Op(m_Op(b, a), c)); + auto pattern8 = m_Op(m_Op(c, m_Op(b, a))); return pattern1.match(&r.front().back()) || pattern2.match(&r.front().back()) || - pattern3.match(&r.front().back()) || pattern4.match(&r.front().back()); + pattern3.match(&r.front().back()) || + pattern4.match(&r.front().back()) || + pattern5.match(&r.front().back()) || + pattern6.match(&r.front().back()) || + pattern7.match(&r.front().back()) || pattern8.match(&r.front().back()); } // TODO: Should be Tablegen'd from a single source that generates the op itself. diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -118,6 +118,23 @@ // CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xf32>, vector<16x32xf32> into vector<8x32xf32> // CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xf32>, memref<8x32xf32> +func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>, + %C: memref<8x32xi32>) { + linalg.generic #matmul_trait %A, %B, %C { + ^bb(%a: i32, %b: i32, %c: i32) : + %d = muli %a, %b: i32 + %e = addi %c, %d: i32 + linalg.yield %e : i32 + } : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32> + return +} +// CHECK-LABEL: func @vectorization_test_integer +// CHECK: vector.transfer_read %{{.*}} : memref<8x16xi32>, vector<8x16xi32> +// CHECK: vector.transfer_read %{{.*}} : memref<16x32xi32>, vector<16x32xi32> +// CHECK: vector.transfer_read %{{.*}} : memref<8x32xi32>, vector<8x32xi32> +// CHECK: vector.contract {indexing_maps = [#[[$mk]], #[[$kn]], #[[$mn]]], iterator_types = ["parallel", "parallel", "reduction"]} %{{.*}}, %{{.*}}, %{{.*}} : vector<8x16xi32>, vector<16x32xi32> into vector<8x32xi32> +// CHECK: vector.transfer_write %{{.*}}, %{{.*}} : vector<8x32xi32>, memref<8x32xi32> + func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, %C: memref<8x32xf32>) { linalg.matmul %A, %B, %C { __internal_linalg_transform__ = "VECTORIZE"} :