diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -2,6 +2,7 @@ set(LLVM_TARGET_DEFINITIONS LinalgStructuredOps.td) mlir_tablegen(LinalgStructuredOps.h.inc -gen-op-decls) mlir_tablegen(LinalgStructuredOps.cpp.inc -gen-op-defs) +mlir_tablegen(LinalgNamedStructuredOpsInterface.cpp.inc -gen-linalg-named-structured-ops-defs) add_public_tablegen_target(MLIRLinalgStructuredOpsIncGen) set(LLVM_TARGET_DEFINITIONS LinalgStructuredOpsInterface.td) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -691,4 +691,36 @@ let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// Named Linalg ops, implemented as declarative configurations of generic ops. +//===----------------------------------------------------------------------===// + +def NamedStructuredOpTraits : NativeOpTrait<"linalg::NamedStructuredOpTraits">; + +class LinalgNamedStructured_Op props> + : Op { + string spec = ?; + let assemblyFormat = "`(` operands `)` attr-dict `:` " + "functional-type(operands, results)"; +} + +def PointwiseAddOp : LinalgNamedStructured_Op<"pointwise_add", [ + NInputs<2>, + NOutputs<1>, + NamedStructuredOpTraits]> { + let arguments = (ins Variadic:$views); + let results = (outs Variadic:$output_tensors); + let spec = "A(...) = B(...) + C(...)"; +} + +def PointwiseMulOp : LinalgNamedStructured_Op<"pointwise_mul", [ + NInputs<2>, + NOutputs<1>, + NamedStructuredOpTraits]> { + let arguments = (ins Variadic:$views); + let results = (outs Variadic:$output_tensors); + let spec = "A(...) = B(...)
* C(...)"; +} + #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h @@ -340,6 +340,17 @@ } }; +/// This class provides the API for named Linalg StructuredOps. +template +class NamedStructuredOpTraits + : public OpTrait::TraitBase { +public: + llvm::Optional> referenceIterators(); + llvm::Optional> referenceIndexingMaps(); + std::function)> + emitScalarImplementation(); +}; + } // namespace linalg } // namespace OpTrait } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -860,6 +860,17 @@ return success(); } +namespace mlir { +namespace OpTrait { +namespace linalg { + +// Linalg "named" traits live in the mlir::OpTrait::linalg namespace.
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsInterface.cpp.inc" + +} // namespace linalg +} // namespace OpTrait +} // namespace mlir + namespace mlir { namespace linalg { @@ -879,8 +890,6 @@ MLIRContext *context) { if (maybeMap) return maybeMap.getValue(); - if (rank == 0) - return AffineMap::get(context); return AffineMap::getMultiDimIdentityMap(rank, context); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -83,7 +83,17 @@ namespace { template -class LinalgScopedEmitter {}; +class LinalgScopedEmitter { +public: + static void emitScalarImplementation(ArrayRef allIvs, + LinalgOpType linalgOp) { + assert(linalgOp.hasBufferSemantics() && + "expected linalg op with buffer semantics"); + // Calls the op's generated scalar implementation (currently an NYI stub). + linalgOp.emitScalarImplementation()(ScopedContext::getBuilder(), + ScopedContext::getLocation(), allIvs); + } +}; template class LinalgScopedEmitter { diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -155,7 +155,9 @@ dimExprs.reserve(numDims); for (unsigned i = 0; i < numDims; ++i) dimExprs.push_back(mlir::getAffineDimExpr(i, context)); - return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs); + return dimExprs.empty() ?
+ get(/*dimCount=*/0, /*symbolCount=*/0, context) : + get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs); } MLIRContext *AffineMap::getContext() const { return map->context; } diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -544,16 +544,48 @@ memref return } // CHECK-LABEL: func @reshape // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref // CHECK: linalg.reshape {{.*}} [#[[reshapeD01]], #[[reshapeD2]]] // CHECK-SAME: memref into memref + +func @pointwise1d(%a: memref, %b: memref, %c: memref, + %ta: tensor, %tb: tensor, %tc: tensor) { + linalg.pointwise_add(%a, %b, %c): (memref, memref, memref) -> () + %dadd = linalg.pointwise_add(%a, %b): (memref, memref) -> tensor + %eadd = linalg.pointwise_add(%ta, %b): (tensor, memref) -> tensor + linalg.pointwise_add(%ta, %tb, %c): (tensor, tensor, memref) -> () + + linalg.pointwise_mul(%a, %b, %c): (memref, memref, memref) -> () + %dmul = linalg.pointwise_mul(%a, %b): (memref, memref) -> tensor + %emul = linalg.pointwise_mul(%ta, %b): (tensor, memref) -> tensor + linalg.pointwise_mul(%ta, %tb, %c): (tensor, tensor, memref) -> () + + return +} +// CHECK-LABEL: func @pointwise1d +// CHECK: linalg.pointwise_add +// CHECK-SAME: (memref, memref, memref) -> () +// CHECK: linalg.pointwise_add +// CHECK-SAME: (memref, memref) -> tensor +// CHECK: linalg.pointwise_add +// CHECK-SAME: (tensor, memref) -> tensor +// CHECK:
linalg.pointwise_add
+// CHECK-SAME: (tensor, tensor, memref) -> ()
+// CHECK: linalg.pointwise_mul
+// CHECK-SAME: (memref, memref, memref) -> ()
+// CHECK: linalg.pointwise_mul
+// CHECK-SAME: (memref, memref) -> tensor
+// CHECK: linalg.pointwise_mul
+// CHECK-SAME: (tensor, memref) -> tensor
+// CHECK: linalg.pointwise_mul
+// CHECK-SAME: (tensor, tensor, memref) -> ()
diff --git a/mlir/tools/mlir-tblgen/CMakeLists.txt b/mlir/tools/mlir-tblgen/CMakeLists.txt
--- a/mlir/tools/mlir-tblgen/CMakeLists.txt
+++ b/mlir/tools/mlir-tblgen/CMakeLists.txt
@@ -6,6 +6,7 @@
 add_tablegen(mlir-tblgen MLIR
   DialectGen.cpp
   EnumsGen.cpp
+  LinalgNamedOpsGen.cpp
   LLVMIRConversionGen.cpp
   LLVMIRIntrinsicGen.cpp
   mlir-tblgen.cpp
diff --git a/mlir/tools/mlir-tblgen/LinalgNamedOpsGen.cpp b/mlir/tools/mlir-tblgen/LinalgNamedOpsGen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-tblgen/LinalgNamedOpsGen.cpp
@@ -0,0 +1,152 @@
+//===- LinalgNamedOpsGen.cpp - MLIR Linalg op generator -------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// LinalgNamedOpsGen reads the `spec` field of every LinalgNamedStructured_Op
+// definition of the Linalg dialect and generates the per-op specializations of
+// the NamedStructuredOpTraits methods (referenceIterators,
+// referenceIndexingMaps and emitScalarImplementation).
+//
+// The generated file contains only explicit specializations; it is meant to be
+// included inside the mlir::OpTrait::linalg namespace, once the op and trait
+// declarations are visible.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Support/STLExtras.h"
+#include "mlir/TableGen/GenInfo.h"
+#include "mlir/TableGen/OpClass.h"
+#include "mlir/TableGen/Operator.h"
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/Signals.h"
+#include "llvm/TableGen/Error.h"
+#include "llvm/TableGen/Record.h"
+#include "llvm/TableGen/TableGenBackend.h"
+
+using namespace llvm;
+using namespace mlir;
+using namespace mlir::tblgen;
+
+using mlir::tblgen::Operator;
+
+// Returns true if `spec` is one of the rank-polymorphic pointwise
+// specifications this generator currently understands.
+// TODO: lang + parser; exact string matching is a stopgap.
+static bool isPointwiseSpec(StringRef spec) {
+  return spec == "A(...) = B(...) + C(...)" ||
+         spec == "A(...) = B(...) * C(...)";
+}
+
+// Emits the body of the instance method:
+//   `llvm::Optional<SmallVector<StringRef, 8>> referenceIterators()`
+// This is an instance method as opposed to a class method (i.e. a static
+// method) so that it can work in a rank-agnostic fashion: the generated code
+// asks the op instance for its number of parallel loops, so one
+// specialization serves every rank.
+// Precondition: the spec of `r` is accepted by isPointwiseSpec.
+static void iteratorTypesBody(const Record &r, StringRef opName,
+                              llvm::raw_ostream &os) {
+  assert(isPointwiseSpec(r.getValueAsString("spec")) && "unsupported spec");
+  os << "  " << opName << " concreteOp = cast<" << opName
+     << ">(getOperation());\n"
+     << "  unsigned nPar = concreteOp.getNumParallelLoops();\n"
+     << "  return SmallVector<StringRef, 8>(nPar, "
+        "getParallelIteratorTypeName());";
+}
+
+// Emits the body of the instance method:
+//   `llvm::Optional<SmallVector<AffineMap, 8>> referenceIndexingMaps()`
+// Pointwise ops access each of their three operands (two inputs, one output)
+// through the same identity map over the op's parallel loops.
+// Precondition: the spec of `r` is accepted by isPointwiseSpec.
+static void indexingMapsBody(const Record &r, StringRef opName,
+                             llvm::raw_ostream &os) {
+  assert(isPointwiseSpec(r.getValueAsString("spec")) && "unsupported spec");
+  os << "  " << opName << " concreteOp = cast<" << opName
+     << ">(getOperation());\n";
+  os << R"SPEC(  MLIRContext *context = getOperation()->getContext();
+  auto id = AffineMap::getMultiDimIdentityMap(
+      concreteOp.getNumParallelLoops(), context);
+  return SmallVector<AffineMap, 8>{id, id, id};)SPEC";
+}
+
+// Emits the body of the instance method:
+//   `std::function<void(OpBuilder &, Location, ArrayRef<Value>)>
+//   emitScalarImplementation()`
+// Scalar code generation is not implemented yet: the returned callback only
+// calls llvm_unreachable.
+// Precondition: the spec of `r` is accepted by isPointwiseSpec.
+static void emitScalarImplementationBody(const Record &r, StringRef,
+                                         llvm::raw_ostream &os) {
+  assert(isPointwiseSpec(r.getValueAsString("spec")) && "unsupported spec");
+  os << R"SPEC(  return [](OpBuilder &, Location, ArrayRef<Value>) {
+    llvm_unreachable("NYI in LinalgNamedOpsGen.cpp");
+  };)SPEC";
+}
+
+// Emits the explicit specialization
+//   template <> resultType traitName<opName>::funName(operands) { ... }
+// whose body is written by `emitBody(r, opName, os)`.
+template <typename BodyEmitter>
+static void emitMethodWithBody(StringRef traitName, StringRef opName,
+                               StringRef resultType, StringRef funName,
+                               StringRef operands, const Record &r,
+                               llvm::raw_ostream &os, BodyEmitter emitBody) {
+  os << "template <>\n"
+     << resultType << " " << traitName << "<" << opName << ">::" << funName
+     << "(" << operands << ") {\n";
+  emitBody(r, opName, os);
+  os << "\n}\n\n";
+}
+
+// Backend entry point: for every LinalgNamedStructured_Op definition, emits
+// the specializations of the three NamedStructuredOpTraits methods. A `spec`
+// this generator does not support is reported as a TableGen error instead of
+// producing methods that fall off their end without returning a value.
+// Returns false, i.e. "no error" under the TableGen backend convention.
+static bool
+emitLinalgNamedOpsInterfaceFunctions(const llvm::RecordKeeper &records,
+                                     llvm::raw_ostream &os) {
+  llvm::emitSourceFileHeader("Operations for Linalg Named Structured Ops", os);
+  // No #include is emitted: the output is included from inside the
+  // mlir::OpTrait::linalg namespace, where a nested #include would at best be
+  // a no-op and at worst declare the header's contents in that namespace.
+
+  const StringRef traitName = "NamedStructuredOpTraits";
+  for (const llvm::Record *r :
+       records.getAllDerivedDefinitions("LinalgNamedStructured_Op")) {
+    StringRef opName = r->getName();
+    StringRef spec = r->getValueAsString("spec");
+    if (!isPointwiseSpec(spec))
+      PrintFatalError(r->getLoc(), "unsupported Linalg named op spec '" +
+                                       spec + "' in " + opName);
+
+    emitMethodWithBody(traitName, opName,
+                       "llvm::Optional<SmallVector<StringRef, 8>>",
+                       "referenceIterators", /*operands=*/"", *r, os,
+                       iteratorTypesBody);
+    emitMethodWithBody(traitName, opName,
+                       "llvm::Optional<SmallVector<AffineMap, 8>>",
+                       "referenceIndexingMaps", /*operands=*/"", *r, os,
+                       indexingMapsBody);
+    emitMethodWithBody(
+        traitName, opName,
+        "std::function<void(OpBuilder &, Location, ArrayRef<Value>)>",
+        "emitScalarImplementation", /*operands=*/"", *r, os,
+        emitScalarImplementationBody);
+  }
+
+  return false;
+}
+
+static mlir::GenRegistration
+    genLinalgNamedOps("gen-linalg-named-structured-ops-defs",
+                      "Generate Linalg named StructuredOps",
+                      emitLinalgNamedOpsInterfaceFunctions);