diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -662,6 +662,18 @@ } ``` +### YAML Based Named Structured Ops + +Linalg provides a declarative generation tool (`mlir-linalg-ods-yaml-gen`) to +automatically produce named ops from a YAML-based op description format +intended to capture the structure of the named ops and be generated from a +higher level "mathy" DSL syntax. This facility is currently in flight and is +intended to subsume the above when ready. See the C++ class to YAML mapping +traits in `mlir-linalg-ods-yaml-gen.cpp` as the source of truth for the schema. + +Most of the above documentation roughly applies to this path and will be ported +as migration continues. + ## Open Issues and Design Alternatives Multiple open issues and design alternatives are in flight and it is time to lay diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -1,8 +1,8 @@ # Declare a function to generate ODS with mlir-linalg-ods-gen -function(add_linalg_ods_gen tc_filename output_file) +function(add_linalg_ods_tc_gen tc_filename output_file) set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename}) - set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.td) - set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.cpp.inc) + set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.td) + set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.cpp.inc) set_source_files_properties( ${GEN_ODS_FILE} PROPERTIES GENERATED TRUE) @@ -20,17 +20,52 @@ ${MLIR_LINALG_ODS_GEN_TARGET} VERBATIM) add_custom_target( - MLIR${output_file}IncGen + MLIR${output_file}TcIncGen DEPENDS ${MLIR_LINALG_ODS_GEN_EXE} ${MLIR_LINALG_ODS_GEN_TARGET} ${GEN_ODS_FILE} ${GEN_CPP_FILE}) endfunction() 
-add_linalg_ods_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps) +# Declare a function to generate ODS with mlir-linalg-ods-yaml-gen +function(add_linalg_ods_yaml_gen yaml_ast_file output_file) + set(YAML_AST_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${yaml_ast_file}) + set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.td) + set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.cpp.inc) + set_source_files_properties( + ${GEN_ODS_FILE} + PROPERTIES GENERATED TRUE) + set_source_files_properties( + ${GEN_CPP_FILE} + PROPERTIES GENERATED TRUE) + add_custom_command( + OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE} + COMMAND ${MLIR_LINALG_ODS_YAML_GEN_EXE} ${YAML_AST_SOURCE} -o-ods-decl=${GEN_ODS_FILE} -o-impl=${GEN_CPP_FILE} + MAIN_DEPENDENCY + ${YAML_AST_SOURCE} + DEPENDS + ${MLIR_LINALG_ODS_YAML_GEN_EXE} + ${MLIR_LINALG_ODS_YAML_GEN_TARGET}) + add_custom_target( + MLIR${output_file}YamlIncGen + DEPENDS + ${MLIR_LINALG_ODS_YAML_GEN_EXE} + ${MLIR_LINALG_ODS_YAML_GEN_TARGET} + ${GEN_ODS_FILE} ${GEN_CPP_FILE}) +endfunction() + +# TODO: Delete tc generation and replace with the YAML variant once all ops are +# ported. 
+add_linalg_ods_tc_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps) +add_linalg_ods_yaml_gen(LinalgNamedStructuredOps.yaml LinalgNamedStructuredOps) + # Provide a short name for all external dependency that needs to # include Linalg in ODS -add_custom_target(LinalgOdsGen DEPENDS MLIRLinalgNamedStructuredOpsIncGen) +add_custom_target(LinalgOdsGen + DEPENDS + MLIRLinalgNamedStructuredOpsTcIncGen + MLIRLinalgNamedStructuredOpsYamlIncGen +) add_dependencies(mlir-headers LinalgOdsGen) add_mlir_dialect(LinalgOps linalg) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -0,0 +1,50 @@ +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: polymorphic_matmul + cpp_op_name: PolymorphicMatmulOp + doc: |- + Type polymorphic matrix multiplication. + + This op is presently here to test a new path for generation and will replace + the existing 'matmul' op when ready. Do not use. +structured_op: !LinalgStructuredOpConfig + args: + - ! + name: A + usage: input + shape: affine_map<()[s0, s1, s2] -> (s0, s2)> + - ! + name: B + usage: input + shape: affine_map<()[s0, s1, s2] -> (s2, s1)> + - ! 
+ name: C + usage: output + shape: affine_map<()[s0, s1, s2] -> (s0, s1)> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)> + - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)> + iterator_types: + - parallel + - parallel + - reduction + assignments: + - !ScalarAssign + arg: C + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: C + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + scalar_arg: A + - !ScalarExpression + scalar_arg: B + diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -343,7 +343,7 @@ // parallelized across; i.e. [zs] in the TF notation above whose number // match `xs` (i.e. 1 window loop per "image" dimension). // This may evolve in the future. - // Conditionally check nPar is large enough for cases of ill-formed op: + // Conditionally check nPar is large enough for cases of ill-formed op: // this avoids overflows before hitting the verifier. assert(nPar > getNumBatchDimensions() + getNumInputFeatureDimensions() && "expected at least one window dimension (i.e. memref ranks greater " @@ -806,6 +806,7 @@ //===----------------------------------------------------------------------===// // This file is auto-generated from a TC def specification. 
-include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td" +include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td" +include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td" #endif // LINALG_STRUCTURED_OPS diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -14,6 +14,7 @@ LINK_LIBS PUBLIC MLIRAffine MLIRIR + MLIRParser MLIRSideEffectInterfaces MLIRViewLikeInterface MLIRStandard diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Parser.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SetVector.h" @@ -121,6 +122,81 @@ return success(folded); } +//===----------------------------------------------------------------------===// +// Region builder helper. +// TODO: Move this to a utility library. +// The public methods on this class are referenced directly from generated code +// and bind by name to math functions in the DSL as: +// `applyfn__{fnName}` +// Examples: +// `applyfn__add` +// `applyfn__mul` +// The naming convention is intentional in order to match snake-cased DSL names. +// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. +// +// Implementations of the math functions must be polymorphic over numeric types, +// internally performing necessary casts. If the function application makes no +// sense, then the only recourse is to assert and return nullptr. This can be +// extended later if it becomes possible to fail construction of the region. The +// invariant should be enforced at a higher level. 
+// +// TODO: These helpers are currently type polymorphic over the class of integer +// and floating point types, but they will not internally cast within bit +// widths of a class (mixed precision such as i8->i32) or across classes +// (i.e. mixed float and integer). Many such combinations are ambiguous or need +// to be handled with care and work is being considered to extend the op +// language to make such cases explicit. In the mean-time, violating this will +// fail verification, which is deemed acceptable. +//===----------------------------------------------------------------------===// + +namespace { + +class RegionBuilderHelper { +public: + RegionBuilderHelper(Block &block) : block(block) {} + + Value applyfn__add(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(lhs); + if (isFloatingPoint(lhs)) + return builder.create<AddFOp>(lhs.getLoc(), lhs, rhs); + else if (isInteger(lhs)) + return builder.create<AddIOp>(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + Value applyfn__mul(Value lhs, Value rhs) { + OpBuilder builder = getBuilder(lhs); + if (isFloatingPoint(lhs)) + return builder.create<MulFOp>(lhs.getLoc(), lhs, rhs); + else if (isInteger(lhs)) + return builder.create<MulIOp>(lhs.getLoc(), lhs, rhs); + llvm_unreachable("unsupported non numeric type"); + } + + void yieldOutputs(ValueRange values) { + assert(!values.empty() && "linalg ops must yield outputs"); + if (values.empty()) + return; + Value first = values.front(); + OpBuilder builder = getBuilder(first); + builder.create<linalg::YieldOp>(first.getLoc(), values); + } + +private: + Block &block; + + bool isFloatingPoint(Value value) { return value.getType().isa<FloatType>(); } + bool isInteger(Value value) { return value.getType().isa<IntegerType>(); } + + OpBuilder getBuilder(Value value) { + OpBuilder builder(value.getContext()); + builder.setInsertionPointToEnd(&block); + return builder; + } +}; + +} // namespace + //===----------------------------------------------------------------------===// // CopyOp 
//===----------------------------------------------------------------------===// @@ -1868,7 +1944,8 @@ struct FoldTensorCastOp; } // namespace -#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" +#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc" +#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" @@ -2032,7 +2109,8 @@ unsigned actual = body->getNumArguments(); unsigned expected = NamedStructuredOpType::getNumRegionArgs(); if (expected != actual) { - if (errorHandler) errorHandler(expected, actual); + if (errorHandler) + errorHandler(expected, actual); return; } diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s + +func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> { + %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>) + outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32> + return %0: tensor<16x32xf32> +} + +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32) +// CHECK-NEXT: %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32 +// CHECK-NEXT: %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32 +// CHECK-NEXT: linalg.yield %[[ADD]] : f32 +// CHECK-NEXT: -> tensor<16x32xf32> + +// ----- + +func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> { + %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>) + outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32> + return %0: tensor<16x32xi32> +} + +// CHECK: ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: 
i32, %[[C_ARG:.+]]: i32) +// CHECK-NEXT: %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32 +// CHECK-NEXT: %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32 +// CHECK-NEXT: linalg.yield %[[ADD]] : i32 +// CHECK-NEXT: -> tensor<16x32xi32> diff --git a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt --- a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt +++ b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt @@ -2,6 +2,13 @@ Core Support ) + +set(LLVM_OPTIONAL_SOURCES + mlir-linalg-ods-gen.cpp + mlir-linalg-ods-yaml-gen.cpp +) + +# Original mlir-linalg-ods-gen (to be replaced). add_llvm_tool(mlir-linalg-ods-gen mlir-linalg-ods-gen.cpp ) @@ -30,3 +37,35 @@ endif() endif() endif() + + +# New mlir-linalg-ods-yaml-gen. +add_llvm_tool(mlir-linalg-ods-yaml-gen + mlir-linalg-ods-yaml-gen.cpp +) +llvm_update_compile_flags(mlir-linalg-ods-yaml-gen) +target_link_libraries(mlir-linalg-ods-yaml-gen PRIVATE + MLIRIR + MLIRSupport + MLIRParser + ) + +set(MLIR_LINALG_ODS_YAML_GEN mlir-linalg-ods-yaml-gen CACHE + STRING "Native mlir-linalg-ods-yaml-gen executable. 
Saves building one when cross-compiling.") + +set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN} PARENT_SCOPE) +set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen PARENT_SCOPE) + +if(LLVM_USE_HOST_TOOLS) +if ("${MLIR_LINALG_ODS_YAML_GEN_EXE}" STREQUAL mlir-linalg-ods-yaml-gen) + build_native_tool(mlir-linalg-ods-yaml-gen MLIR_LINALG_ODS_YAML_GEN_EXE DEPENDS mlir-linalg-ods-yaml-gen) + set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN_EXE} PARENT_SCOPE) + + add_custom_target(mlir-linalg-ods-yaml-gen-host DEPENDS ${MLIR_LINALG_ODS_YAML_GEN_EXE}) + set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen-host PARENT_SCOPE) + + if(NOT LLVM_BUILD_UTILS) + set_target_properties(mlir-linalg-ods-yaml-gen PROPERTIES EXCLUDE_FROM_ALL ON) + endif() +endif() +endif() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp new file mode 100644 --- /dev/null +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -0,0 +1,878 @@ +//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements an ODS (and C++) generator from a YAML form +// derived from the mathematical expression of linalg named ops. Typically a +// math oriented DSL will be used to export the essential representation to +// this form, and maintaining the SOT at the math level (versus recreating it +// in MLIR) is deemed to have systemic value. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/AffineMap.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/Parser.h" +#include "mlir/Support/FileUtilities.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/Support/YAMLTraits.h" + +using namespace mlir; + +using llvm::yaml::Input; +using llvm::yaml::IO; +using llvm::yaml::MappingTraits; +using llvm::yaml::ScalarEnumerationTraits; +using llvm::yaml::ScalarTraits; + +#define DEBUG_TYPE "linalg-ods-gen" + +//===----------------------------------------------------------------------===// +// Mapping structs (correspond to data types in the YAML description). +// TODO: Since this is a schema/part of the contract, it should be moved to +// a real header. +//===----------------------------------------------------------------------===// + +namespace { + +struct LinalgYAMLContext { + MLIRContext *mlirContext; +}; + +struct LinalgOpMetadata { + std::string name; + std::string cppOpName; + Optional doc; +}; + +struct SerializedAffineMap { + AffineMapAttr affineMapAttr; + + AffineMap affineMap() { return affineMapAttr.getValue(); } +}; + +enum class LinalgTensorUsageDef { + input, + output, + temporary, +}; + +struct LinalgTensorDef { + std::string name; + LinalgTensorUsageDef usage; + SerializedAffineMap shape; +}; + +enum class LinalgIteratorTypeDef { + parallel, + reduction, +}; + +struct LinalgIndexingMapsConfig { + Optional> staticIndexingMaps; +}; + +struct ScalarExpression; + +struct ScalarApply { + std::string fnName; + // NOTE: Must be pure heap allocated container (not SmallVector) + // due to recursive data type. 
+ std::vector operands; +}; + +struct ScalarExpression { + Optional scalarArg; + Optional scalarApply; +}; + +struct ScalarAssign { + std::string arg; + ScalarExpression value; +}; + +struct LinalgStructuredOpConfig { + SmallVector args; + LinalgIndexingMapsConfig indexingMaps; + SmallVector iteratorTypes; + SmallVector assignments; +}; + +struct LinalgOpConfig { + Optional metadata; + Optional structuredOp; +}; + +} // namespace + +//===----------------------------------------------------------------------===// +// Mapping traits. +//===----------------------------------------------------------------------===// + +LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef); +LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap); +LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef); +LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign); +LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression); +LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig); + +namespace llvm { +namespace yaml { + +/// Top-level type containing op metadata and one of a concrete op type. +/// Currently, the only defined op type is `structured_op` (maps to +/// `LinalgStructuredOpConfig`). +template <> +struct MappingTraits { + static void mapping(IO &io, LinalgOpConfig &info) { + io.mapOptional("metadata", info.metadata); + io.mapOptional("structured_op", info.structuredOp); + } +}; + +/// A structured op models (at most) a single contraction by modeling +/// - A list of named arguments (`LinalgTensorDef`), which can be inputs, +/// outputs, or temporaries. +/// - List of indexing maps (see `LinalgIndexingMaps`). +/// - Iterator types (see `LinalgIteratorTypeDef`). +/// - List of scalar level assignment (see `ScalarAssign`). 
+template <> +struct MappingTraits { + static void mapping(IO &io, LinalgStructuredOpConfig &info) { + io.mapRequired("args", info.args); + io.mapRequired("indexing_maps", info.indexingMaps); + io.mapRequired("iterator_types", info.iteratorTypes); + io.mapRequired("assignments", info.assignments); + } +}; + +/// Maps a named tensor-argument to an operation, consisting of: +/// - `name`: Must be unique within the operation. +/// - `usage`: How the argument is used (input, output, etc). +/// - `shape`: An AffineMap from all op symbols to the specific shape +/// of this argument. Each shape must be normalized over the same list of +/// symbols and have no dimension inputs. +template <> +struct MappingTraits { + static void mapping(IO &io, LinalgTensorDef &info) { + io.mapRequired("name", info.name); + io.mapRequired("usage", info.usage); + io.mapRequired("shape", info.shape); + } +}; + +/// Usage enum for a named argument. +template <> +struct ScalarEnumerationTraits { + static void enumeration(IO &io, LinalgTensorUsageDef &value) { + io.enumCase(value, "input", LinalgTensorUsageDef::input); + io.enumCase(value, "output", LinalgTensorUsageDef::output); + io.enumCase(value, "temporary", LinalgTensorUsageDef::temporary); + } +}; + +/// Iterator type enum. +template <> +struct ScalarEnumerationTraits { + static void enumeration(IO &io, LinalgIteratorTypeDef &value) { + io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); + io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); + } +}; + +/// Metadata about the op (name, C++ name, and documentation). +template <> +struct MappingTraits { + static void mapping(IO &io, LinalgOpMetadata &info) { + io.mapRequired("name", info.name); + io.mapRequired("cpp_op_name", info.cppOpName); + io.mapOptional("doc", info.doc); + } +}; + +/// How the ops indexing maps are produced. 
Must be one of: +/// - static_indexing_maps: A static list of AffineMaps, possibly with +/// some symbols that bind to attributes of the op. Each indexing map must +/// be normalized over the same list of dimensions, and its symbols must +/// match the symbols for argument shapes. +template <> +struct MappingTraits { + static void mapping(IO &io, LinalgIndexingMapsConfig &info) { + io.mapOptional("static_indexing_maps", info.staticIndexingMaps); + } +}; + +/// Models an assignment to a named output. +/// - The `arg` name must match a named output or temporary. +/// - The `value` is a scalar expression for computing the value to +/// assign (see `ScalarExpression`). +template <> +struct MappingTraits { + static void mapping(IO &io, ScalarAssign &info) { + io.mapRequired("arg", info.arg); + io.mapRequired("value", info.value); + } +}; + +/// A scalar expression (RHS of an assignment). Must be one of: +/// - `scalar_arg`: Name of an argument to the op. +/// - `scalar_apply`: Result of evaluating a named function (see +/// `ScalarApply`). +template <> +struct MappingTraits { + static void mapping(IO &io, ScalarExpression &info) { + io.mapOptional("scalar_arg", info.scalarArg); + io.mapOptional("scalar_apply", info.scalarApply); + } +}; + +/// A scalar expression that evaluates a named function. +/// Functions are generally "math" level and type polymorphic. Builtin +/// functions include: +/// - `add(lhs, rhs)` +/// - `mul(lhs, rhs)` +template <> +struct MappingTraits { + static void mapping(IO &io, ScalarApply &info) { + io.mapRequired("fn_name", info.fnName); + io.mapRequired("operands", info.operands); + } +}; + +/// Helper mapping which accesses an AffineMapAttr as a serialized string of +/// the same. 
+template <> +struct ScalarTraits { + static void output(const SerializedAffineMap &value, void *rawYamlContext, + raw_ostream &out) { + assert(value.affineMapAttr); + value.affineMapAttr.print(out); + } + static StringRef input(StringRef scalar, void *rawYamlContext, + SerializedAffineMap &value) { + assert(rawYamlContext); + auto *yamlContext = static_cast(rawYamlContext); + if (auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext) + .dyn_cast_or_null()) + value.affineMapAttr = attr; + else if (!value.affineMapAttr || !value.affineMapAttr.isa()) + return "could not parse as an affine map attribute"; + return StringRef(); + } + static QuotingType mustQuote(StringRef) { return QuotingType::None; } +}; + +} // namespace yaml +} // namespace llvm + +namespace { + +//===----------------------------------------------------------------------===// +// Generation utilities +//===----------------------------------------------------------------------===// + +class GenerationContext { +public: + GenerationContext(MLIRContext *context, raw_ostream *odsOut, + raw_ostream *defnOut) + : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut), + defnOut(defnOut) {} + + MLIRContext *getContext() { return context; } + + void setLoc(Location loc) { this->loc = loc; } + Location getLoc() { return loc; } + + bool shouldGenerateOds() { return odsOut; } + bool shouldGenerateDefns() { return defnOut; } + + raw_ostream &odss() { + assert(odsOut && "ODS stream not defined"); + return *odsOut; + } + + raw_ostream &defns() { + assert(defnOut && "Definition stream not defined"); + return *defnOut; + } + +private: + MLIRContext *context; + Location loc; + raw_ostream *odsOut; + raw_ostream *defnOut; +}; + +} // namespace + +static std::string generateCppExpression(SerializedAffineMap self, + StringRef contextName) { + std::string printedStr; + llvm::raw_string_ostream printedSs(printedStr); + self.affineMapAttr.print(printedSs); + printedSs.flush(); + + static const char 
exprFormat[] = + R"FMT(mlir::parseAttribute("{0}", {1}).cast().getValue())FMT"; + return llvm::formatv(exprFormat, printedStr, contextName); +} + +template +static std::string interleaveToString(Container &container, + StringRef separator) { + std::string result; + llvm::raw_string_ostream ss(result); + llvm::interleave(container, ss, separator); + ss.flush(); + return result; +} + +static Optional +findTensorDefArgIndex(StringRef name, SmallVectorImpl &args) { + for (auto it : llvm::enumerate(args)) { + if (it.value().name == name) + return it.index(); + } + return None; +} + +static ScalarAssign * +findAssignment(StringRef name, SmallVectorImpl &assignments) { + for (auto &assign : assignments) { + if (assign.arg == name) + return &assign; + } + return nullptr; +} + +//===----------------------------------------------------------------------===// +// Templates +//===----------------------------------------------------------------------===// + +// A single line banner format. Parameters: +// {0}: Single line comment +static const char bannerFormat[] = R"FMT( +//===----------------------------------------------------------------------===// +// {0} +//===----------------------------------------------------------------------===// +)FMT"; + +//===----------------------------------------------------------------------===// +// Named generic op generation. +// These ops map at most a single contraction that complies with the limitations +// of a linalg.generic. +//===----------------------------------------------------------------------===// + +// Template for Linalg named ops' ODS definitions. 
Parameters: +// {0}: ODS/C++ op name +// {1}: assembly op mnemonic +// {2}: op interface list +// {3}: documentation (summary + description) +// {4}: op attribute list +// {5}: the number of arguments for the op region +// {6}: builder methods taking standalone attribute parameters +// {7}: additional methods for attributes used by indexing maps +static const char structuredOpOdsHeaderFormat[] = R"FMT( +//===----------------------------------------------------------------------===// +// Op definition for {0} +//===----------------------------------------------------------------------===// + +def {0} : LinalgStructuredBase_Op<"{1}", [ + AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp"> + /*extraInterfaces=*/{2}]> { + {3} + let arguments = (ins + Variadic:$inputs, + Variadic:$outputs{4} + ); + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilderDAG< + (ins "ValueRange":$inputs, "ValueRange":$outputs), + [{{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({{ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion<{0}>( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)/*, TODO: support captures*/); + }]>, + OpBuilderDAG< + (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, + "ValueRange":$outputs), + [{{ + $_state.addOperands(inputs); + $_state.addOperands(outputs); + $_state.addTypes(resultTensorTypes); + $_state.addAttribute( + "operand_segment_sizes", + $_builder.getI32VectorAttr({{ + static_cast(inputs.size()), + static_cast(outputs.size())})); + createAndFillStructuredOpRegion<{0}>( + $_builder, + $_state, + TypeRange(inputs), + TypeRange(outputs)/*, TODO: support captures*/); + }]>, + OpBuilderDAG< + (ins "TypeRange":$resultTensorTypes, 
"ValueRange":$operands, + CArg<"ArrayRef", "{{}">:$attributes), + [{{ + $_state.addOperands(operands); + $_state.addAttributes(attributes); + $_state.addTypes(resultTensorTypes); + (void)$_state.addRegion(); + }]> + {6} + ]; + let printer = [{{ return ::printNamedStructuredOp(p, *this); }]; + let parser = [{{ + return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/); + }]; + let hasFolder = 1; + let hasCanonicalizer = 1; + + let extraClassDeclaration = structuredOpsBaseDecls # [{{ + // Auto-generated. + ArrayAttr iterator_types(); + ArrayAttr indexing_maps(); + static void regionBuilder(Block &block, ValueRange captures); + static std::function getRegionBuilder() {{ + return regionBuilder; + } + + // Generic methods. + static unsigned getNumRegionArgs(); + std::string getLibraryCallName(); + {7} + }]; +} +)FMT"; + +// The iterator_types() method implementation. Parameters: +// {0}: Class name +// {1}: Comma interleaved iterator type names. +static const char structuredOpIteratorTypesFormat[] = + R"FMT( +ArrayAttr {0}::iterator_types() { + return Builder(getContext()).getStrArrayAttr(SmallVector{{ {1} }); +} +)FMT"; + +// Implementations of getCanonicalizationPatterns, fold and getEffects. 
+// Parameters: +// {0}: Class name +const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT( +void {0}::getCanonicalizationPatterns( + OwningRewritePatternList &results, + MLIRContext *context) {{ + results.insert(); + results.insert(); +} +LogicalResult {0}::fold(ArrayRef, + SmallVectorImpl &) {{ + return foldMemRefCast(*this); +} +void {0}::getEffects(SmallVectorImpl< + SideEffects::EffectInstance >&effects) {{ + getGenericEffectsImpl(effects, + getOperation()->getResults(), getInputBuffers(), getOutputBuffers()); +} +)FMT"; + +static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig, + GenerationContext &genContext) { + if (!genContext.shouldGenerateOds()) + return success(); + + raw_ostream &os = genContext.odss(); + + std::string interfaceNameList; + std::string attrList; + std::string attrMethods; + std::string attrBuilder; + + std::string doc; + if (opConfig.metadata->doc) { + const char *docFmt = R"FMT( + let summary = [{ {0} }]; + let description = [{ + {1} + }]; + )FMT"; + StringRef summary, description; + std::tie(summary, description) = + StringRef(*opConfig.metadata->doc).trim().split('\n'); + doc = llvm::formatv(docFmt, summary.trim(), description.trim()); + } + + os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppOpName, + opConfig.metadata->name, interfaceNameList, doc, attrList, + opConfig.structuredOp->args.size(), attrBuilder, + attrMethods); + + return success(); +} + +static LogicalResult +generateNamedGenericOpDefns(LinalgOpConfig &opConfig, + GenerationContext &genContext) { + if (!genContext.shouldGenerateDefns()) + return success(); + + raw_ostream &os = genContext.defns(); + StringRef className = opConfig.metadata->cppOpName; + + // Implementation banner. + std::string bannerComment = llvm::formatv("Implementation of {0}", className); + os << llvm::formatv(bannerFormat, bannerComment); + + // Reference iterators. 
+ { + std::string iteratorsStr; + llvm::raw_string_ostream ss(iteratorsStr); + llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss, + [&](LinalgIteratorTypeDef it) { + switch (it) { + case LinalgIteratorTypeDef::parallel: + ss << "getParallelIteratorTypeName()"; + break; + case LinalgIteratorTypeDef::reduction: + ss << "getReductionIteratorTypeName()"; + break; + } + }); + ss.flush(); + os << llvm::formatv(structuredOpIteratorTypesFormat, className, + iteratorsStr); + } + + // Static indexing maps. + if (auto &staticMaps = + opConfig.structuredOp->indexingMaps.staticIndexingMaps) { + if (staticMaps->empty()) + return emitError(genContext.getLoc()) << "op has no indexing maps"; + AffineMap firstMap = staticMaps->front().affineMap(); + + // Symbol bindings. + { + // For each symbol, generate a declaration for it, either with an + // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from + // an attribute). + // TODO: Implement attribute constants. + // TODO: Possibly lift into a top-level method. + static const char structuredOpSymbolBindingsFormat[] = R"FMT( +static SmallVector getSymbolBindings({0} self) { + MLIRContext *context = self.getContext(); + SmallVector exprs; +{1} + return exprs; +} +)FMT"; + + unsigned symbolCount = firstMap.getNumSymbols(); + SmallVector symbolBindings; + for (unsigned i = 0; i < symbolCount; ++i) { + // TODO: Switch and emit constants for attribute bound symbols. + symbolBindings.push_back(llvm::formatv( + " exprs.push_back(getAffineSymbolExpr({0}, context));", i)); + } + std::string symbolBindingsStr; + llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr); + llvm::interleave(symbolBindings, symbolBindingsSs, "\n"); + symbolBindingsSs.flush(); + + os << llvm::formatv(structuredOpSymbolBindingsFormat, className, + symbolBindingsStr); + } + + // Indexing maps. + { + // Parameters: + // {0}: Class name + // {1}: Comma-separated list of dimension variable names. 
+ // {2}: Statements + static const char structuredOpIndexingMapsFormat[] = R"FMT( +ArrayAttr {0}::indexing_maps() { + MLIRContext *context = getContext(); + auto symbolBindings = getSymbolBindings(*this); + SmallVector maps; + {2} + return Builder(context).getAffineMapArrayAttr(maps); +} +)FMT"; + + unsigned dimCount = firstMap.getNumDims(); + + // Generate a comma-separated list of dim identifiers to be passed to + // bindDims, ensuring tht AffineExpr identifiers are bound in the right + // order to the proper AffineDimExpr. + // This results in vars in scope like: d0, d1, d2... + SmallVector dimIndices; + for (unsigned i = 0; i < dimCount; ++i) + dimIndices.push_back(i); + std::string dimIdentsStr; + llvm::raw_string_ostream dimIdentsSs(dimIdentsStr); + llvm::interleaveComma(dimIndices, dimIdentsSs, + [&](unsigned i) { dimIdentsSs << "d" << i; }); + dimIdentsSs.flush(); + + // Statements to add and simplify each affine map. + SmallVector stmts; + for (auto &indexingMap : *staticMaps) { + // TODO: Assert that dim and symbol count match the first. + stmts.push_back( + llvm::formatv("maps.push_back({0});", + generateCppExpression(indexingMap, "context"))); + stmts.push_back(llvm::formatv( + "maps.back() = " + "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, " + "symbolBindings, {0}, 0));", + dimCount)); + } + + // TODO: This needs to be memoized and/or converted to non-parser based + // C++ codegen prior to real use. + os << llvm::formatv(structuredOpIndexingMapsFormat, className, + dimIdentsStr, interleaveToString(stmts, "\n ")); + } + } else { + return emitError(genContext.getLoc()) + << "generating code for non static indexing maps not currently " + "supported"; + } + + // getNumRegionArgs() + { + // Generates a getNumRegionArgs() method. 
Parameters: + // {0}: Class name + // {1}: Number of region args + static const char structuredOpGetNumRegionArgsFormat[] = R"FMT( +unsigned {0}::getNumRegionArgs() {{ return {1}; } +)FMT"; + os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className, + opConfig.structuredOp->args.size()); + } + + // getLibraryCallName() + { + // Generates a getLibraryCallName method. Parameters: + // {0}: Class name + static const char structuredOpGetLibraryCallFormat[] = R"FMT( +std::string {0}::getLibraryCallName() {{ + return generateLibraryCallName(getOperation()); +} +)FMT"; + os << llvm::formatv(structuredOpGetLibraryCallFormat, className); + } + + // regionBuilder() + { + // Generates a regionBuilder method. Parameters. + // {0}: Class name + // {1}: Statements + static const char structuredOpRegionBuilderFormat[] = R"FMT( +void {0}::regionBuilder(Block &block, ValueRange captures) {{ + RegionBuilderHelper helper(block); + SmallVector yields; + {1} + helper.yieldOutputs(yields); +} +)FMT"; + auto &args = opConfig.structuredOp->args; + auto &assignments = opConfig.structuredOp->assignments; + size_t generatedAssignmentCount = 0; + int localCounter = 0; + SmallVector stmts; + for (LinalgTensorDef &arg : args) { + if (arg.usage != LinalgTensorUsageDef::output && + arg.usage != LinalgTensorUsageDef::temporary) + continue; + + // Find the assignment that correlates with the argument. + ScalarAssign *assignment = findAssignment(arg.name, assignments); + if (!assignment) + return emitError(genContext.getLoc()) + << "no assignment found for output argument " << arg.name; + ++generatedAssignmentCount; + + // Recursively generate the expression. 
+ std::function(ScalarExpression &)> + generateExpression = + [&](ScalarExpression &expression) -> Optional { + if (expression.scalarArg) { + Optional argIndex = + findTensorDefArgIndex(*expression.scalarArg, args); + if (!argIndex) { + emitError(genContext.getLoc()) + << "scalar argument not defined on the op: " << arg.name; + return None; + } + return std::string( + llvm::formatv("block.getArgument({0})", *argIndex)); + } else if (expression.scalarApply) { + // Recursively generate operands. + SmallVector operandCppValues; + for (ScalarExpression &operand : expression.scalarApply->operands) { + auto operandCppValue = generateExpression(operand); + if (!operandCppValue) + return None; + operandCppValues.push_back(*operandCppValue); + } + std::string cppIdent = llvm::formatv("value{0}", ++localCounter); + stmts.push_back( + llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent, + expression.scalarApply->fnName, + interleaveToString(operandCppValues, ", "))); + return cppIdent; + } else { + emitError(genContext.getLoc()) << "unknown ScalarExpression type"; + return None; + } + }; + Optional cppValue = generateExpression(assignment->value); + if (!cppValue) + return failure(); + stmts.push_back(llvm::formatv("yields.push_back({0});", cppValue)); + } + + if (generatedAssignmentCount != assignments.size()) + return emitError(genContext.getLoc()) + << "mismatched number of assignments vs output arguments"; + + os << llvm::formatv(structuredOpRegionBuilderFormat, className, + interleaveToString(stmts, "\n ")); + } + + // Canonicalizers and folders. + os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className); + + return success(); +} + +static LogicalResult generateOp(LinalgOpConfig &opConfig, + GenerationContext &genContext) { + // Switch on op type being generated. 
+ if (opConfig.structuredOp) { + return success( + succeeded(generateNamedGenericOpOds(opConfig, genContext)) && + succeeded(generateNamedGenericOpDefns(opConfig, genContext))); + } else { + return emitError(genContext.getLoc()) << "unsupported operation type"; + } +} + +//===----------------------------------------------------------------------===// +// Command line options and main +//===----------------------------------------------------------------------===// + +static llvm::cl::opt + inputFilename(llvm::cl::Positional, llvm::cl::desc(""), + llvm::cl::init("-"), llvm::cl::value_desc("YAML filename")); + +static llvm::cl::opt + outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"), + llvm::cl::value_desc("filename"), llvm::cl::init("")); + +static llvm::cl::opt + outputCppImplFilename("o-impl", + llvm::cl::desc("C++ implementation file name"), + llvm::cl::value_desc("filename"), llvm::cl::init("")); + +int main(int argc, char **argv) { + llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML"); + + // Set up the input file. + std::string errorMessage; + std::unique_ptr file = + mlir::openInputFile(inputFilename, &errorMessage); + if (!file) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + + MLIRContext mlirContext; + LinalgYAMLContext yamlContext{&mlirContext}; + + std::vector opConfigs; + + // Parse input. + Input yin(file->getBuffer(), &yamlContext); + yin >> opConfigs; + + if (yin.error()) + return 1; + + // Open output files. 
+ std::unique_ptr outputOdsDecl; + if (!outputOdsDeclFilename.empty()) { + outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage); + if (!outputOdsDecl) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + } + + std::unique_ptr outputCppImpl; + if (!outputCppImplFilename.empty()) { + outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage); + if (!outputCppImpl) { + llvm::errs() << errorMessage << "\n"; + return 1; + } + } + + if (!outputOdsDecl && !outputCppImpl) { + llvm::errs() << "error: No output files specified\n"; + return 1; + } + + // Generate. + GenerationContext genContext(&mlirContext, + outputOdsDecl ? &outputOdsDecl->os() : nullptr, + outputCppImpl ? &outputCppImpl->os() : nullptr); + + for (auto &opConfig : opConfigs) { + if (!opConfig.metadata) { + emitError(genContext.getLoc()) + << "missing operation metadata on subsequent op"; + return 1; + } + + genContext.setLoc(NameLoc::get( + Identifier::get(opConfig.metadata->cppOpName, &mlirContext), + &mlirContext)); + if (failed(generateOp(opConfig, genContext))) { + return 1; + } + } + + if (outputOdsDecl) + outputOdsDecl->keep(); + if (outputCppImpl) + outputCppImpl->keep(); + + return 0; +}