diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -3,6 +3,11 @@ C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); } +ods_def: +def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { + x(m) = std_addf(std_mulf(A(m, n), y(n))); +} + ods_def: def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n))); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -197,34 +197,6 @@ let hasFolder = 1; } -def MatvecOp : LinalgStructured_Op<"matvec", [NInputs<2>, NOutputs<1>]> { - - let arguments = (ins AnyStridedMemRefOfRank<2>, - AnyStridedMemRefOfRank<1>, - AnyStridedMemRefOfRank<1>); - - let extraClassDeclaration = libraryCallName # [{ - llvm::Optional> referenceIterators() { - return SmallVector{ - getParallelIteratorTypeName(), getReductionIteratorTypeName()}; - } - - // A(i, r_j) * B(r_j) -> C(i) - llvm::Optional> referenceIndexingMaps() { - MLIRContext *context = getContext(); - AffineExpr i, r_j; - bindDims(context, i, r_j); - return SmallVector{ - AffineMap::get(2, 0, {i, r_j}, context), - AffineMap::get(2, 0, {r_j}, context), - AffineMap::get(2, 0, {i}, context) - }; - } - }]; - - let hasFolder = 1; -} - /// A base class for pooling operation such as conv. The arguments must contain /// optional arguments `strides`, `dilations` and `padding` with following type: /// OptionalAttr:$strides diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -240,11 +240,11 @@ LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion>(ctx); + LinalgOpConversion>(ctx); // TODO: collect all auto-generated named ops with a tblgen directive. patterns.insert< LinalgOpConversion, + LinalgOpConversion, LinalgOpConversion>(ctx); // clang-format on } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1124,10 +1124,6 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } -LogicalResult MatvecOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} OpFoldResult ReshapeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); @@ -1242,3 +1238,7 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult MatvecOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -242,17 +242,6 @@ // Emit scalar form. C() = C() + A(r_i) * B(r_i); } -template -void emitScalarImplementation(ArrayRef allIvs, MatvecOp matvecOp) { - assert(matvecOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(allIvs.size() == 2); - Value i(allIvs[0]), r_j(allIvs[1]); - IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), - C(matvecOp.getOutputBuffer(0)); - // Emit scalar form. - C(i) = C(i) + A(i, r_j) * B(r_j); -} template Value getConvOpInput(ConvOp convOp, StdIndexedValue im, @@ -624,8 +613,6 @@ return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); - if (isa(op)) - return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); if (isa(op)) @@ -642,6 +629,8 @@ return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); llvm_unreachable("Unexpected op in linalgOpToLoopsImpl"); diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -77,7 +77,7 @@ %2 = view %arg0[%c0][%M, %N] : memref to memref %3 = view %arg0[%c0][%M] : memref to memref %4 = view %arg0[%c0][%N] : memref to memref - linalg.matvec(%2, %3, %4) : memref, memref, memref + linalg.matvec %2, %3, %4 : (memref, memref, memref) return } // CHECKLOOP-LABEL: func @matvec(%{{.*}}: memref, diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -86,9 +86,9 @@ linalg.matmul %arg0, %arg0, %arg0 : (memref, memref, memref) - linalg.matvec(%arg0, %arg1, %arg2) : memref, + linalg.matvec %arg0, %arg1, %arg2 : (memref, memref, - memref + memref) linalg.dot(%arg1, %arg2, %arg3) : memref, memref, memref @@ -99,10 +99,10 @@ // CHECK-SAME: (memref, // CHECK-SAME: memref, // CHECK-SAME: memref) -// CHECK-NEXT: linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : -// CHECK-SAME: memref, +// CHECK-NEXT: linalg.matvec %{{.*}}, %{{.*}}, %{{.*}} : +// CHECK-SAME: (memref, // CHECK-SAME: memref, -// CHECK-SAME: memref +// CHECK-SAME: memref) // CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : // CHECK-SAME: memref, // CHECK-SAME: memref, diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -199,7 +199,10 @@ // TILE-234: memref) func @matvec(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.matvec(%arg0, %arg1, %arg2) : memref, memref, memref + linalg.matvec %arg0, %arg1, %arg2 : ( + memref, + memref, + memref) return } // TILE-2-LABEL: func @matvec( @@ -217,7 +220,7 @@ // TILE-2: %[[localN:.*]] = dim %{{.*}}, %c0 // TILE-2: %[[szN:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localN]]] // TILE-2: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szN]]] [1] : memref to memref -// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref, memref, memref +// TILE-2: linalg.matvec %[[sAi]], %{{.*}}, %[[sCi]] : (memref, memref, memref) // TILE-02-LABEL: func @matvec( // TILE-02-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref @@ -234,7 +237,7 @@ // TILE-02: %[[localN:.*]] = dim %{{.*}}, %c0 // TILE-02: %[[szN:.*]] = affine.min #[[$bound_map]](%[[J]])[%[[localN]]] // TILE-02: %[[sBj:.*]] = subview %{{.*}}[%[[J]]] [%[[szN]]] [1] : memref to memref -// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref, memref, memref +// TILE-02: linalg.matvec %[[sAj]], %[[sBj]], %{{.*}} : (memref, memref, memref) // TILE-002-LABEL: func @matvec( // TILE-002-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref @@ -265,7 +268,7 @@ // TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[sCi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref // -// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref, memref, memref +// TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref, memref, memref) func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { linalg.dot(%arg0, %arg1, %arg2) : memref, memref, memref diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -36,10 +36,10 @@ func @matvec(%A: memref, %x: memref, %y: memref) { - linalg.matvec(%A, %x, %y) : - memref, - memref, - memref + linalg.matvec %A, %x, %y : + (memref, + memref, + memref) return } // CHECK-LABEL: func @matvec @@ -48,7 +48,7 @@ // CHECK-DAG: %[[c6:.*]] = constant 6 : index // CHECK: scf.parallel {{.*}} step (%[[c5]]) // CHECK: scf.for {{.*}} step %[[c6]] -// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref +// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref, memref, memref) func @matmul(%A: memref, %B: memref, @@ -202,10 +202,10 @@ func @matvec_perm(%A: memref, %x: memref, %y: memref) { - linalg.matvec(%A, %x, %y) {__internal_linalg_transform__ = "__with_perm__"} : - memref, - memref, - memref + linalg.matvec %A, %x, %y {__internal_linalg_transform__ = "__with_perm__"} : + (memref, + memref, + memref) return } // CHECK-LABEL: func @matvec_perm @@ -214,7 +214,7 @@ // CHECK-DAG: %[[c6:.*]] = constant 6 : index // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c6]] // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c5]] -// CHECK: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : memref, memref, memref +// CHECK: linalg.matvec {{.*}}, {{.*}}, {{.*}} : (memref, memref, memref) func @matmul_perm(%A: memref, %B: memref,