Make the reduction accumulator an explicit operand in TC comprehensions.

Each comprehension in LinalgNamedStructuredOpsSpec.tc and in the
mlir-linalg-ods-gen lit test now names its output as an operand of the
combining op, e.g. `C(m, n) = std_addf(C(m, n), std_mulf(...))`, instead of
mlir-linalg-ods-gen implicitly prepending the `currentDefinition` tensor use
for ops that have reduction dimensions. The new test8
(`std_subf(std_mulf(...), C(m))`) checks that the output operand may appear
in a non-leading position of the expression.

NOTE(review): in the *_i8_i8_i32 and *_i16_i16_i32 definitions,
`std_sexti32` wraps `std_muli`, so the product appears to be formed at the
narrow input width before widening; the existing TODO in matmul_i8_i8_i32
suggests cast-then-multiply instead. Confirm whether narrow-width overflow
is intended (left unchanged by this patch).

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc @@ -1,13 +1,13 @@ ods_def implements_interface : def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { - C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); + C(m, n) = std_addf(C(m, n), std_mulf(A(m, k), B(k, n))); } ods_def implements_interface : def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) { - C(n, m) = std_addf(std_mulf(A(k, m), B(n, k))); + C(n, m) = std_addf(C(n, m), std_mulf(A(k, m), B(n, k))); } ods_def @@ -15,169 +15,170 @@ def matmul_i8_i8_i32(A: i8(M, K), B: i8(K, N)) -> (C: i32(M, N)) { // TODO: ideally something closer to // C(m, n) += cast(A(m, k)) * cast(B(k, n)) - C(m, n) = std_addi(std_sexti32(std_muli(A(m, k), B(k, n)))); + C(m, n) = std_addi(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n)))); } ods_def implements_interface : def matmul_i16_i16_i32(A: i16(M, K), B: i16(K, N)) -> (C: i32(M, N)) { - C(m, n) = std_addi(std_sexti32(std_muli(A(m, k), B(k, n)))); + C(m, n) = std_addi(C(m, n), std_sexti32(std_muli(A(m, k), B(k, n)))); } ods_def implements_interface : def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) { - C(m, n) = std_addi(std_muli(A(m, k), B(k, n))); + C(m, n) = std_addi(C(m, n), std_muli(A(m, k), B(k, n))); } ods_def implements_interface : def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) { - x(m) = std_addf(std_mulf(A(m, n), y(n))); + x(m) = std_addf(x(m), std_mulf(A(m, n), y(n))); } ods_def implements_interface : def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) { - x(m) = std_addi(std_sexti32(std_muli(A(m, n), y(n)))); + x(m) = std_addi(x(m), std_sexti32(std_muli(A(m, n), y(n)))); } ods_def implements_interface : def matvec_i16_i16_i32(A: i16(M, N), y: i16(N)) -> (x: i32(M)) { - x(m) = 
std_addi(std_sexti32(std_muli(A(m, n), y(n)))); + x(m) = std_addi(x(m), std_sexti32(std_muli(A(m, n), y(n)))); } ods_def implements_interface : def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) { - x(m) = std_addi(std_muli(A(m, n), y(n))); + x(m) = std_addi(x(m), std_muli(A(m, n), y(n))); } ods_def implements_interface : def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) { - x(n) = std_addf(std_mulf(y(m), A(m, n))); + x(n) = std_addf(x(n), std_mulf(y(m), A(m, n))); } ods_def implements_interface : def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) { - x(n) = std_addi(std_sexti32(std_muli(y(m), A(m, n)))); + x(n) = std_addi(x(n), std_sexti32(std_muli(y(m), A(m, n)))); } ods_def implements_interface : def vecmat_i16_i16_i32(y: i16(M), A: i16(M, N)) -> (x: i32(N)) { - x(n) = std_addi(std_sexti32(std_muli(y(m), A(m, n)))); + x(n) = std_addi(x(n), std_sexti32(std_muli(y(m), A(m, n)))); } - ods_def implements_interface : def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) { - x(n) = std_addi(std_muli(y(m), A(m, n))); + x(n) = std_addi(x(n), std_muli(y(m), A(m, n))); } ods_def implements_interface : def dot(A: f32(M), B: f32(M)) -> (C: f32()) { - C() = std_addf(std_mulf(A(m), B(m))); + C() = std_addf(C(), std_mulf(A(m), B(m))); } ods_def implements_interface : def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) { - C() = std_addi(std_sexti32(std_muli(A(m), B(m)))); + C() = std_addi(C(), std_sexti32(std_muli(A(m), B(m)))); } ods_def implements_interface : def dot_i16_i16_i32(A: i16(M), B: i16(M)) -> (C: i32()) { - C() = std_addi(std_sexti32(std_muli(A(m), B(m)))); + C() = std_addi(C(), std_sexti32(std_muli(A(m), B(m)))); } - ods_def implements_interface : def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) { - C() = std_addi(std_muli(A(m), B(m))); + C() = std_addi(C(), std_muli(A(m), B(m))); } - ods_def implements_interface : def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) { - C(b, m, n) = 
std_addf(std_mulf(A(b, m, k), B(b, k, n))); + C(b, m, n) = std_addf(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n))); } ods_def implements_interface : def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) { - C(b, m, n) = std_addi(std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); + C(b, m, n) = + std_addi(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); } ods_def implements_interface : def batch_matmul_i16_i16_i32(A: i16(Batch, M, K), B: i16(Batch, K, N)) -> (C: i32(Batch, M, N)) { - C(b, m, n) = std_addi(std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); + C(b, m, n) = + std_addi(C(b, m, n), std_sexti32(std_muli(A(b, m, k), B(b, k, n)))); } ods_def implements_interface : def batch_matmul_i32_i32_i32(A: i32(Batch, M, K), B: i32(Batch, K, N)) -> (C: i32(Batch, M, N)) { - C(b, m, n) = std_addi(std_muli(A(b, m, k), B(b, k, n))); + C(b, m, n) = std_addi(C(b, m, n), std_muli(A(b, m, k), B(b, k, n))); } ods_def: def conv_1d(I: f32(W), K: f32(KW)) -> (O: f32(W)) { - O(w) = std_addf(std_mulf(I(w + kw), K(kw))); + O(w) = std_addf(O(w), std_mulf(I(w + kw), K(kw))); } ods_def: def conv_1d_nwc(I: f32(N, W, C), K: f32(F, KW, C)) -> (O: f32(N, W, F)) { - O(n, w, f) = std_addf(std_mulf(I(n, w + kw, c), K(f, kw, c))); + O(n, w, f) = std_addf(O(n, w, f), std_mulf(I(n, w + kw, c), K(f, kw, c))); } ods_def: def conv_1d_ncw(I: f32(N, C, W), K: f32(F, C, KW)) -> (O: f32(N, F, W)) { - O(n, f, w) = std_addf(std_mulf(I(n, c, w + kw), K(f, c, kw))); + O(n, f, w) = std_addf(O(n, f, w), std_mulf(I(n, c, w + kw), K(f, c, kw))); } ods_def: def conv_2d(I: f32(H, W), K: f32(KH, KW)) -> (O: f32(H, W)) { - O(h, w) = std_addf(std_mulf(I(h + kh, w + kw), K(kh, kw))); + O(h, w) = std_addf(O(h, w), std_mulf(I(h + kh, w + kw), K(kh, kw))); } ods_def: def conv_2d_nhwc(I: f32(N, H, W, C), K: f32(F, KH, KW, C)) -> (O: f32(N, H, W, F)) { - O(n, h, w, f) = std_addf(std_mulf( - I(n, h + kh, w + kw, c), K(f, kh, kw, c))); + O(n, h, w, f) = std_addf( + O(n, h, w, f),
std_mulf(I(n, h + kh, w + kw, c), K(f, kh, kw, c))); } ods_def: def conv_2d_nchw(I: f32(N, C, H, W), K: f32(F, C, KH, KW)) -> (O: f32(N, F, H, W)) { - O(n, f, h, w) = std_addf(std_mulf( - I(n, c, h + kh, w + kw), K(f, c, kh, kw))); + O(n, f, h, w) = std_addf( + O(n, f, h, w), std_mulf(I(n, c, h + kh, w + kw), K(f, c, kh, kw))); } ods_def: def conv_3d(I: f32(D, H, W), K: f32(KD, KH, KW)) -> (O: f32(D, H, W)) { - O(d, h, w) = std_addf(std_mulf( - I(d + kd, h + kh, w + kw), K(kd, kh, kw))); + O(d, h, w) = std_addf( + O(d, h, w), std_mulf(I(d + kd, h + kh, w + kw), K(kd, kh, kw))); } ods_def: def conv_3d_ndhwc(I: f32(N, D, H, W, C), K: f32(F, KD, KH, KW, C)) -> (O: f32(N, D, H, W, F)) { - O(n, d, h, w, f) = std_addf(std_mulf( - I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); + O(n, d, h, w, f) = std_addf( + O(n, d, h, w, f), + std_mulf(I(n, d + kd, h + kh, w + kw, c), K(f, kd, kh, kw, c))); } ods_def: def conv_3d_ncdhw(I: f32(N, C, D, H, W), K: f32(F, C, KD, KH, KW)) -> (O: f32(N, F, D, H, W)) { - O(n, f, d, h, w) = std_addf(std_mulf( - I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); + O(n, f, d, h, w) = std_addf( + O(n, f, d, h, w), + std_mulf(I(n, c, d + kd, h + kh, w + kw), K(f, c, kd, kh, kw))); } ods_def: @@ -209,8 +210,10 @@ Note: this op only supports channel multiplier == 1. """ { - O(n, oh, ow, c) = std_addf(std_mulf( - I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), K(kh, kw, c))); + O(n, oh, ow, c) = std_addf( + O(n, oh, ow, c), + std_mulf(I(n, oh * strides[0] + kh, ow * strides[1] + kw, c), + K(kh, kw, c))); } ods_def: @@ -226,6 +229,7 @@ """ { O(n, w, f) = std_addf( + O(n, w, f), std_mulf(I(n, w * strides[0] + kw * dilations[0], c), K(kw, c, f))); } @@ -242,6 +246,7 @@ """ { O(n, f, w) = std_addf( + O(n, f, w), std_mulf(I(n, c, w * strides[0] + kw * dilations[0]), K(kw, c, f))); } @@ -257,10 +262,10 @@ order of (`N`, `H`, `W`, `F`, `KH`, `KW`, `C`).
""" { - O(n, h, w, f) = - std_addf(std_mulf(I(n, h * strides[0] + kh * dilations[0], - w * strides[1] + kw * dilations[1], c), - K(kh, kw, c, f))); + O(n, h, w, f) = std_addf( + O(n, h, w, f), std_mulf(I(n, h * strides[0] + kh * dilations[0], + w * strides[1] + kw * dilations[1], c), + K(kh, kw, c, f))); } ods_def: @@ -277,10 +282,10 @@ order of (`N`, `F`, `H`, `W`, `KH`, `KW`, `C`). """ { - O(n, f, h, w) = - std_addf(std_mulf(I(n, c, h * strides[0] + kh * dilations[0], - w * strides[1] + kw * dilations[1]), - K(kh, kw, c, f))); + O(n, f, h, w) = std_addf( + O(n, f, h, w), std_mulf(I(n, c, h * strides[0] + kh * dilations[0], + w * strides[1] + kw * dilations[1]), + K(kh, kw, c, f))); } ods_def: @@ -297,11 +302,11 @@ order of (`N`, `D`, `H`, `W`, `F`, `KD`, `KH`, `KW`, `C`). """ { - O(n, d, h, w, f) = - std_addf(std_mulf(I(n, d * strides[0] + kd * dilations[0], - h * strides[1] + kh * dilations[1], - w * strides[2] + kw * dilations[2], c), - K(kd, kh, kw, c, f))); + O(n, d, h, w, f) = std_addf( + O(n, d, h, w, f), std_mulf(I(n, d * strides[0] + kd * dilations[0], + h * strides[1] + kh * dilations[1], + w * strides[2] + kw * dilations[2], c), + K(kd, kh, kw, c, f))); } ods_def: @@ -318,8 +323,9 @@ order of (`N`, `F`, `D`, `H`, `W`, `KD`, `KH`, `KW`, `C`).
""" { - O(n, f, d, h, w) = std_addf(std_mulf( - I(n, c, d * strides[0] + kd * dilations[0], - h * strides[1] + kh * dilations[1], w * strides[2] + kw * dilations[2]), - K(kd, kh, kw, c, f))); + O(n, f, d, h, w) = std_addf( + O(n, f, d, h, w), std_mulf(I(n, c, d * strides[0] + kd * dilations[0], + h * strides[1] + kh * dilations[1], + w * strides[2] + kw * dilations[2]), + K(kd, kh, kw, c, f))); } diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -31,7 +31,7 @@ // ods_def : def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) { - C(m) = std_addf(std_mulf(A(m, k), B(k))); + C(m) = std_addf(C(m), std_mulf(A(m, k), B(k))); } // ODS-LABEL: def Test2Op : LinalgStructuredBase_Op<"test2", [ @@ -55,7 +55,7 @@ // ods_def : def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) { - C(m, n) = std_addf(std_mulf(A(m, k), B(k, n))); + C(m, n) = std_addf(C(m, n), std_mulf(A(m, k), B(k, n))); } // ODS-LABEL: def Test3Op : LinalgStructuredBase_Op<"test3", [ @@ -79,7 +79,7 @@ // ods_def : def test3(A: f32(Batch, M, K), B: f32(K, N)) -> (C: f32(Batch, M, N)) { - C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); + C(b, m, n) = std_addf(C(b, m, n), std_mulf(A(b, m, k), B(k, n))); } // Test attribute definitions @@ -115,7 +115,7 @@ array_attr : f32[], optional_attr? : f32 ) { - C(b, m, n) = std_addf(std_mulf(A(b, m, k), B(k, n))); + C(b, m, n) = std_addf(C(b, m, n), std_mulf(A(b, m, k), B(k, n))); } // Test attribute usage in affine expressions @@ -157,7 +157,7 @@ It has one output.
""" { - C(m) = std_addf(std_mulf(A(m, k), B(k))); + C(m) = std_addf(C(m), std_mulf(A(m, k), B(k))); } // Test attribute builder @@ -172,5 +172,17 @@ def test7(A: f32(M, K), B: f32(K)) -> (C: f32(M)) attr(attr_a: f32, attr_b: 4xi32) { - C(m) = std_addf(std_mulf(A(m, k), B(k))); + C(m) = std_addf(C(m), std_mulf(A(m, k), B(k))); +} + +// Test output arg order. +// IMPL-LABEL: void Test8Op::regionBuilder(Block &block, ValueRange captures) { +// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: Value [[e:.*]] = std_subf([[d]], [[c]]); +// IMPL: (linalg_yield(ValueRange{ [[e]] })); +ods_def: +def test8(A: f32(M, K), B: f32(K)) -> (C: f32(M)) +{ + C(m) = std_subf(std_mulf(A(m, k), B(k)), C(m)); } diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1087,6 +1087,9 @@ /// Parses a tensor use. struct ComprehensionParsingState { + /// Number of operands (inputs and outputs) of the comprehension. May differ + /// from orderedTensorArgs.size(), which can hold several uses per operand. + size_t numArgs = 0; AffineDimList dims; SmallVector, 4> expressions; llvm::DenseMap orderedTensorArgs; @@ -1510,11 +1513,6 @@ reductionDims.push_back(iter.cast().getPosition()); } - // If this op is a reduction, it's first argument is the `currentDefinition` - // tensor use.
- if (!reductionDims.empty()) - expressions.push_back(std::make_unique(currentDefinition)); - LLVM_DEBUG(llvm::dbgs() << "op: " << opOrTensor << "\n"); auto parseExpr = [&]() -> LogicalResult { std::unique_ptr e; @@ -1619,7 +1617,8 @@ auto tensorIter = registeredTensors.find(use.tensorId); assert(tensorIter != registeredTensors.end() && "unregistered tensor"); auto &tensor = tensorIter->getValue(); - if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0) { + if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 && + tensor.indexingMap.getResults() != use.indexingMap.getResults()) { LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap); (void)parser.emitError( "Unexpected multi-read of a tensor with different accesses"); @@ -1630,6 +1629,7 @@ tensor.indexingMap = use.indexingMap; state.orderedTensorArgs[use] = tensor.index; }); + state.numArgs = seenDefs.size(); if (failed) return failure(); @@ -2004,8 +2004,7 @@ // Finally put everything together. os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc, - attrList, state.orderedTensorArgs.size(), attrBuilder, - attrMethods); + attrList, state.numArgs, attrBuilder, attrMethods); } /// Print the C++ StructuredOpsInterface impl of `iterator_types`. @@ -2203,7 +2202,7 @@ std::string mapsStr; llvm::raw_string_ostream mapsStringStream(mapsStr); - SmallVector orderedUses(state.orderedTensorArgs.size()); + SmallVector orderedUses(state.numArgs); for (const auto &it : state.orderedTensorArgs) orderedUses[it.second] = it.first; @@ -2286,7 +2285,7 @@ /// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName, ComprehensionParsingState &state) { - unsigned count = state.orderedTensorArgs.size(); + unsigned count = state.numArgs; llvm::DenseMap subExprsMap; std::function printExpr; printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void { @@ -2326,7 +2325,7 @@ std::string valueHandleStr; llvm::raw_string_ostream valueHandleStringStream(valueHandleStr); llvm::interleaveComma( - state.orderedTensorArgs, valueHandleStringStream, [&](auto) { + llvm::seq(size_t{0}, state.numArgs), valueHandleStringStream, [&](auto) { valueHandleStringStream << "_" << idx << "(args[" << idx << "])"; idx++; });