diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -8,6 +8,11 @@ x(m) = std_addf(std_mulf(A(m, n), y(n))); } +ods_def: +def dot(A: f32(M), B: f32(M)) -> (C: f32()) { + C() = std_addf(std_mulf(A(m), B(m))); +} + ods_def: def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(b, k, n))); diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h @@ -51,9 +51,9 @@ /// 1. linalg.fill(%A, %f) : memref, f32 /// name mangles into `linalg_fill_viewf32_f32_impl` /// -/// 2. linalg.dot(%A, %B, %C) : -/// memref, -/// memref, memref +/// 2. linalg.dot %A, %B, %C : +/// (memref, +/// memref, memref) /// name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl` /// /// 3. linalg.matmul(...) : diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -180,31 +180,6 @@ let hasFolder = 1; } -def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> { - - let arguments = (ins AnyStridedMemRefOfRank<1>, - AnyStridedMemRefOfRank<1>, - AnyStridedMemRefOfRank<0>); - - let extraClassDeclaration = libraryCallName # [{ - llvm::Optional> referenceIterators() { - return SmallVector{getReductionIteratorTypeName()}; - } - - // A(r_i) * B(r_i) -> C() - llvm::Optional> referenceIndexingMaps() { - MLIRContext *context = getContext(); - auto r_i = getAffineDimExpr(0, context); - return SmallVector{ - AffineMap::get(1, 0, {r_i}, context), - AffineMap::get(1, 0, {r_i}, context), - AffineMap::get(1, 0, {}, context)}; - } - }]; - - let hasFolder = 1; -} - /// A base class for pooling operation such as conv. The arguments must contain /// optional arguments `strides`, `dilations` and `padding` with following type: /// OptionalAttr:$strides diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp --- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp +++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp @@ -235,13 +235,13 @@ LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, - LinalgOpConversion, - LinalgOpConversion, + LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion>(ctx); // TODO: collect all auto-generated named ops with a tblgen directive. patterns.insert< + LinalgOpConversion, LinalgOpConversion, LinalgOpConversion, LinalgOpConversion>(ctx); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1173,10 +1173,6 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } -LogicalResult DotOp::fold(ArrayRef, - SmallVectorImpl &) { - return foldMemRefCast(*this); -} LogicalResult FillOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); @@ -1280,6 +1276,17 @@ if (!tensorResultTypes.empty()) result.addTypes(tensorResultTypes); + // The number of parsed arguments must equal + // the number of expected arguments for the current operation. + auto parsedArgs = operandsInfo.size(); + auto expectedArgs = NamedStructuredOpType::getNumInputs() + + NamedStructuredOpType::getNumOutputs(); + if (parsedArgs != expectedArgs) + return parser.emitError(parser.getNameLoc(), + "expects " + std::to_string(expectedArgs) + + " operands, but found " + + std::to_string(parsedArgs)); + buildNamedStructuredOpRegionAndAttributes( parser.getBuilder(), result, operandTypes, tensorResultTypes); @@ -1299,6 +1306,10 @@ SmallVectorImpl &) { return foldMemRefCast(*this); } +LogicalResult DotOp::fold(ArrayRef, + SmallVectorImpl &) { + return foldMemRefCast(*this); +} LogicalResult MatmulOp::fold(ArrayRef, SmallVectorImpl &) { return foldMemRefCast(*this); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -295,18 +295,6 @@ nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); } -template -void emitScalarImplementation(ArrayRef allIvs, DotOp dotOp) { - assert(dotOp.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - assert(allIvs.size() == 1); - Value r_i(allIvs[0]); - IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), - C(dotOp.getOutputBuffer(0)); - // Emit scalar form. - C() = C() + A(r_i) * B(r_i); -} - template Value getConvOpInput(ConvOp convOp, StdIndexedValue im, MutableArrayRef imIdx) { @@ -673,8 +661,6 @@ return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); - if (isa(op)) - return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); if (isa(op)) @@ -693,6 +679,8 @@ return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); + if (isa(op)) + return linalgOpToLoopsImpl(op, builder); if (isa(op)) return linalgOpToLoopsImpl(op, builder); llvm_unreachable("Unexpected op in linalgOpToLoopsImpl"); diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -422,8 +422,8 @@ // ----- func @generic_result_0_element_type(%arg0: memref) { - // expected-error @+1 {{'linalg.dot' op expected 3 operands, but found 2}} - linalg.dot(%arg0, %arg0): memref, memref + // expected-error @+1 {{'linalg.dot' expects 3 operands, but found 2}} + linalg.dot %arg0, %arg0 : (memref, memref) } // ----- diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir --- a/mlir/test/Dialect/Linalg/loops.mlir +++ b/mlir/test/Dialect/Linalg/loops.mlir @@ -123,7 +123,7 @@ %1 = view %arg0[%c0][%M] : memref to memref %2 = view %arg0[%c0][%M] : memref to memref %3 = view %arg0[%c0][] : memref to memref - linalg.dot(%1, %2, %3) : memref, memref, memref + linalg.dot %1, %2, %3 : (memref, memref, memref) return } // CHECKLOOP-LABEL: func @dot(%{{.*}}: memref, @@ -154,7 +154,9 @@ func @dot_view(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot(%arg0, %arg1, %arg2) : memref, memref, memref + linalg.dot %arg0, %arg1, %arg2 : (memref, + memref, + memref) return } // CHECKLOOP-LABEL: func @dot_view( diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -88,10 +88,10 @@ memref) linalg.matvec %arg0, %arg1, %arg2 : (memref, memref, - memref) - linalg.dot(%arg1, %arg2, %arg3) : memref, - memref, - memref + memref) + linalg.dot %arg1, %arg2, %arg3 : (memref, + memref, + memref) return } // CHECK-LABEL: func @ops(% @@ -103,10 +103,10 @@ // CHECK-SAME: (memref, // CHECK-SAME: memref, // CHECK-SAME: memref) -// CHECK-NEXT: linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : -// CHECK-SAME: memref, -// CHECK-SAME: memref, -// CHECK-SAME: memref +// CHECK-NEXT: linalg.dot %{{.*}}, %{{.*}}, %{{.*}} : +// CHECK-SAME: (memref, +// CHECK-SAME: memref, +// CHECK-SAME: memref) // ----- diff --git a/mlir/test/Dialect/Linalg/standard.mlir b/mlir/test/Dialect/Linalg/standard.mlir --- a/mlir/test/Dialect/Linalg/standard.mlir +++ b/mlir/test/Dialect/Linalg/standard.mlir @@ -13,9 +13,9 @@ func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot(%arg0, %arg1, %arg2) : memref, - memref, - memref + linalg.dot %arg0, %arg1, %arg2 : (memref, + memref, + memref) return } // CHECK-LABEL: func @dot( diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir --- a/mlir/test/Dialect/Linalg/tile.mlir +++ b/mlir/test/Dialect/Linalg/tile.mlir @@ -271,7 +271,9 @@ // TILE-234: linalg.matvec %[[sAij]], %[[sBj]], %[[sCi]] : (memref, memref, memref) func @dot(%arg0: memref, %arg1: memref, %arg2: memref) { - linalg.dot(%arg0, %arg1, %arg2) : memref, memref, memref + linalg.dot %arg0, %arg1, %arg2 : (memref, + memref, + memref) return } // TILE-2-LABEL: func @dot( @@ -285,7 +287,7 @@ // TILE-2: %[[localM:.*]] = dim %{{.*}}, %c0 // TILE-2: %[[szM:.*]] = affine.min #[[$bound_map]](%[[I]])[%[[localM]]] // TILE-2: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref -// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : memref, memref, memref +// TILE-2: linalg.dot %[[sAi]], %[[sBi]], {{.*}} : (memref, memref, memref) // TILE-02-LABEL: func @dot( // TILE-02-NOT: scf.for @@ -304,7 +306,7 @@ // TILE-234: %[[localM:.*]] = dim %{{.*}}, %c0 // TILE-234: %[[szM:.*]] = affine.min #[[$bound_map_2]](%[[I]])[%[[localM]]] // TILE-234: %[[sBi:.*]] = subview %{{.*}}[%[[I]]] [%[[szM]]] [1] : memref to memref -// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref, memref, memref +// TILE-234: linalg.dot %[[sAi]], %[[sBi]], %{{.*}} : (memref, memref, memref) func @fill_static(%arg0: memref<127x99xf32>, %arg1: f32) { linalg.fill(%arg0, %arg1) : memref<127x99xf32>, f32 diff --git a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns-matmul-to-vector.mlir @@ -36,7 +36,7 @@ func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref) { // VECTOR-CONTRACTION: vector.contract // VECTOR-CONTRACTION-SAME: vector<1584xf32>, vector<1584xf32> into f32 - linalg.dot(%A, %B, %C) : memref<1584xf32>, memref<1584xf32>, memref + linalg.dot %A, %B, %C : (memref<1584xf32>, memref<1584xf32>, memref) return } diff --git a/mlir/test/Dialect/Linalg/transform-patterns.mlir b/mlir/test/Dialect/Linalg/transform-patterns.mlir --- a/mlir/test/Dialect/Linalg/transform-patterns.mlir +++ b/mlir/test/Dialect/Linalg/transform-patterns.mlir @@ -14,10 +14,10 @@ func @dot(%x: memref, %y: memref, %v: memref) { - linalg.dot(%x, %y, %v) { __internal_linalg_transform__ = "MEM" } : - memref, - memref, - memref + linalg.dot %x, %y, %v { __internal_linalg_transform__ = "MEM" } : + (memref, + memref, + memref) return } // CHECK-LABEL: func @dot @@ -28,8 +28,8 @@ // CHECK: scf.for {{.*}} = %[[c0]] to {{.*}} step %[[c1]] { // CHECK: load // CHECK: load -// CHECK: mulf // CHECK: load +// CHECK: mulf // CHECK: addf // CHECK: store diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir @@ -51,7 +51,7 @@ %B = view %bB[%c0][%c16] : memref to memref %C = view %bC[%c0][] : memref to memref - linalg.dot(%A, %B, %C) : memref, memref, memref + linalg.dot %A, %B, %C : (memref, memref, memref) %res = load %C[] : memref dealloc %bC : memref