diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md
--- a/mlir/docs/Dialects/Linalg.md
+++ b/mlir/docs/Dialects/Linalg.md
@@ -662,6 +662,18 @@
}
```
+### YAML Based Named Structured Ops
+
+Linalg provides a declarative generation tool (`mlir-linalg-ods-yaml-gen`) to
+automatically produce named ops from a YAML-based op description format
+intended to capture the structure of the named ops and be generated from a
+higher level "mathy" DSL syntax. This facility is currently in flight and is
+intended to subsume the above when ready. See the C++ class to YAML mapping
+traits in `mlir-linalg-ods-yaml-gen.cpp` as the source of truth for the schema.
+
+Most of the above documentation roughly applies to this path and will be ported
+as migration continues.
+
## Open Issues and Design Alternatives
Multiple open issues and design alternatives are in flight and it is time to lay
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt
@@ -1,8 +1,8 @@
# Declare a function to generate ODS with mlir-linalg-ods-gen
-function(add_linalg_ods_gen tc_filename output_file)
+function(add_linalg_ods_tc_gen tc_filename output_file)
set(TC_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${tc_filename})
- set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.td)
- set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.cpp.inc)
+ set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.td)
+ set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.tcgen.cpp.inc)
set_source_files_properties(
${GEN_ODS_FILE}
PROPERTIES GENERATED TRUE)
@@ -20,17 +20,52 @@
${MLIR_LINALG_ODS_GEN_TARGET}
VERBATIM)
add_custom_target(
- MLIR${output_file}IncGen
+ MLIR${output_file}TcIncGen
DEPENDS
${MLIR_LINALG_ODS_GEN_EXE}
${MLIR_LINALG_ODS_GEN_TARGET}
${GEN_ODS_FILE} ${GEN_CPP_FILE})
endfunction()
-add_linalg_ods_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
+# Declare a function to generate ODS with mlir-linalg-ods-yaml-gen.
+# Arguments:
+#   yaml_ast_file: YAML op description, relative to the current source dir.
+#   output_file:   Base name for the generated files. Produces
+#                  <output_file>.yamlgen.td and <output_file>.yamlgen.cpp.inc
+#                  in the current binary dir, and a target named
+#                  MLIR<output_file>YamlIncGen that builds them.
+function(add_linalg_ods_yaml_gen yaml_ast_file output_file)
+  set(YAML_AST_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${yaml_ast_file})
+  set(GEN_ODS_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.td)
+  set(GEN_CPP_FILE ${CMAKE_CURRENT_BINARY_DIR}/${output_file}.yamlgen.cpp.inc)
+
+  # Both outputs are produced by the custom command below.
+  set_source_files_properties(
+    ${GEN_ODS_FILE} ${GEN_CPP_FILE}
+    PROPERTIES GENERATED TRUE)
+
+  add_custom_command(
+    OUTPUT ${GEN_ODS_FILE} ${GEN_CPP_FILE}
+    COMMAND ${MLIR_LINALG_ODS_YAML_GEN_EXE}
+            ${YAML_AST_SOURCE}
+            -o-ods-decl=${GEN_ODS_FILE}
+            -o-impl=${GEN_CPP_FILE}
+    MAIN_DEPENDENCY ${YAML_AST_SOURCE}
+    DEPENDS ${MLIR_LINALG_ODS_YAML_GEN_EXE} ${MLIR_LINALG_ODS_YAML_GEN_TARGET})
+
+  add_custom_target(
+    MLIR${output_file}YamlIncGen
+    DEPENDS
+      ${MLIR_LINALG_ODS_YAML_GEN_EXE}
+      ${MLIR_LINALG_ODS_YAML_GEN_TARGET}
+      ${GEN_ODS_FILE}
+      ${GEN_CPP_FILE})
+endfunction()
+
+# TODO: Delete tc generation and replace with the YAML variant once all ops are
+# ported.
+# Both generators use the same LinalgNamedStructuredOps output base name. Their
+# generated files stay distinct via the .tcgen / .yamlgen suffixes.
+add_linalg_ods_tc_gen(LinalgNamedStructuredOpsSpec.tc LinalgNamedStructuredOps)
+add_linalg_ods_yaml_gen(LinalgNamedStructuredOps.yaml LinalgNamedStructuredOps)
+
# Provide a short name for all external dependency that needs to
# include Linalg in ODS
-add_custom_target(LinalgOdsGen DEPENDS MLIRLinalgNamedStructuredOpsIncGen)
+# Aggregates the generated-file targets of both the TC and YAML generators.
+add_custom_target(LinalgOdsGen DEPENDS
+  MLIRLinalgNamedStructuredOpsTcIncGen
+  MLIRLinalgNamedStructuredOpsYamlIncGen)
add_dependencies(mlir-headers LinalgOdsGen)
add_mlir_dialect(LinalgOps linalg)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -0,0 +1,50 @@
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+ name: polymorphic_matmul
+ cpp_op_name: PolymorphicMatmulOp
+ doc: |-
+ Type polymorphic matrix multiplication.
+
+ This op is presently here to test a new path for generation and will replace
+ the existing 'matmul' op when ready. Do not use.
+structured_op: !LinalgStructuredOpConfig
+ args:
+ - !
+ name: A
+ usage: input
+ shape: affine_map<()[s0, s1, s2] -> (s0, s2)>
+ - !
+ name: B
+ usage: input
+ shape: affine_map<()[s0, s1, s2] -> (s2, s1)>
+ - !
+ name: C
+ usage: output
+ shape: affine_map<()[s0, s1, s2] -> (s0, s1)>
+ indexing_maps: !LinalgIndexingMapsConfig
+ static_indexing_maps:
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
+ iterator_types:
+ - parallel
+ - parallel
+ - reduction
+ assignments:
+ - !ScalarAssign
+ arg: C
+ value: !ScalarExpression
+ scalar_apply:
+ fn_name: add
+ operands:
+ - !ScalarExpression
+ scalar_arg: C
+ - !ScalarExpression
+ scalar_apply:
+ fn_name: mul
+ operands:
+ - !ScalarExpression
+ scalar_arg: A
+ - !ScalarExpression
+ scalar_arg: B
+
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -343,7 +343,7 @@
// parallelized across; i.e. [zs] in the TF notation above whose number
// match `xs` (i.e. 1 window loop per "image" dimension).
// This may evolve in the future.
- // Conditionally check nPar is large enough for cases of ill-formed op:
+ // Conditionally check nPar is large enough for cases of ill-formed op:
// this avoids overflows before hitting the verifier.
assert(nPar > getNumBatchDimensions() + getNumInputFeatureDimensions() &&
"expected at least one window dimension (i.e. memref ranks greater "
@@ -806,6 +806,7 @@
//===----------------------------------------------------------------------===//
// This file is auto-generated from a TC def specification.
-include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.td"
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.td"
+// This file is auto-generated from a YAML op specification.
+include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.td"
#endif // LINALG_STRUCTURED_OPS
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -14,6 +14,7 @@
LINK_LIBS PUBLIC
MLIRAffine
MLIRIR
+ MLIRParser
MLIRSideEffectInterfaces
MLIRViewLikeInterface
MLIRStandard
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Parser.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SetVector.h"
@@ -121,6 +122,81 @@
return success(folded);
}
+//===----------------------------------------------------------------------===//
+// Region builder helper.
+// TODO: Move this to a utility library.
+// The public methods on this class are referenced directly from generated code
+// and bind by name to math functions in the DSL as:
+// `applyfn__{fnName}`
+// Examples:
+// `applyfn__add`
+// `applyfn__mul`
+// The naming convention is intentional in order to match snake-cased DSL names.
+// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class.
+//
+// Implementations of the math functions must be polymorphic over numeric types,
+// internally performing necessary casts. If the function application makes no
+// sense, then the only recourse is to assert and return nullptr. This can be
+// extended later if it becomes possible to fail construction of the region. The
+// invariant should be enforced at a higher level.
+//
+// TODO: These helpers are currently type polymorphic over the class of integer
+// and floating point types, but they will not internally cast within bit
+// widths of a class (mixed precision such as i8->i32) or across classes
+// (i.e. mixed float and integer). Many such combinations are ambiguous or need
+// to be handled with care and work is being considered to extend the op
+// language to make such cases explicit. In the mean-time, violating this will
+// fail verification, which is deemed acceptable.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+class RegionBuilderHelper {
+public:
+ RegionBuilderHelper(Block &block) : block(block) {}
+
+ Value applyfn__add(Value lhs, Value rhs) {
+ OpBuilder builder = getBuilder(lhs);
+ if (isFloatingPoint(lhs))
+ return builder.create(lhs.getLoc(), lhs, rhs);
+ else if (isInteger(lhs))
+ return builder.create(lhs.getLoc(), lhs, rhs);
+ llvm_unreachable("unsupported non numeric type");
+ }
+
+ Value applyfn__mul(Value lhs, Value rhs) {
+ OpBuilder builder = getBuilder(lhs);
+ if (isFloatingPoint(lhs))
+ return builder.create(lhs.getLoc(), lhs, rhs);
+ else if (isInteger(lhs))
+ return builder.create(lhs.getLoc(), lhs, rhs);
+ llvm_unreachable("unsupported non numeric type");
+ }
+
+ void yieldOutputs(ValueRange values) {
+ assert(!values.empty() && "linalg ops must yield outputs");
+ if (values.empty())
+ return;
+ Value first = values.front();
+ OpBuilder builder = getBuilder(first);
+ builder.create(first.getLoc(), values);
+ }
+
+private:
+ Block █
+
+ bool isFloatingPoint(Value value) { return value.getType().isa(); }
+ bool isInteger(Value value) { return value.getType().isa(); }
+
+ OpBuilder getBuilder(Value value) {
+ OpBuilder builder(value.getContext());
+ builder.setInsertionPointToEnd(&block);
+ return builder;
+ }
+};
+
+} // namespace
+
//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
@@ -1868,7 +1944,8 @@
struct FoldTensorCastOp;
} // namespace
-#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.tcgen.cpp.inc"
+#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -2032,7 +2109,8 @@
unsigned actual = body->getNumArguments();
unsigned expected = NamedStructuredOpType::getNumRegionArgs();
if (expected != actual) {
- if (errorHandler) errorHandler(expected, actual);
+ if (errorHandler)
+ errorHandler(expected, actual);
return;
}
diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s -split-input-file -linalg-generalize-named-ops | FileCheck %s
+
+// Checks the region body produced for linalg.polymorphic_matmul by
+// -linalg-generalize-named-ops for float and integer element types.
+
+func @generalize_matmul_tensor_f32(%A : tensor<16x8xf32>, %B: tensor<8x32xf32>, %C: tensor<16x32xf32>) -> tensor<16x32xf32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xf32>, tensor<8x32xf32>)
+                                 outs(%C: tensor<16x32xf32>) -> tensor<16x32xf32>
+  return %0: tensor<16x32xf32>
+}
+
+// CHECK-LABEL: @generalize_matmul_tensor_f32
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: f32, %[[B_ARG:.+]]: f32, %[[C_ARG:.+]]: f32)
+// CHECK-NEXT:   %[[MUL:.+]] = mulf %[[A_ARG]], %[[B_ARG]] : f32
+// CHECK-NEXT:   %[[ADD:.+]] = addf %[[C_ARG]], %[[MUL]] : f32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<16x32xf32>
+
+// -----
+
+func @generalize_matmul_tensor_i32(%A : tensor<16x8xi32>, %B: tensor<8x32xi32>, %C: tensor<16x32xi32>) -> tensor<16x32xi32> {
+  %0 = linalg.polymorphic_matmul ins(%A, %B: tensor<16x8xi32>, tensor<8x32xi32>)
+                                 outs(%C: tensor<16x32xi32>) -> tensor<16x32xi32>
+  return %0: tensor<16x32xi32>
+}
+
+// CHECK-LABEL: @generalize_matmul_tensor_i32
+// CHECK:      ^{{.*}}(%[[A_ARG:.+]]: i32, %[[B_ARG:.+]]: i32, %[[C_ARG:.+]]: i32)
+// CHECK-NEXT:   %[[MUL:.+]] = muli %[[A_ARG]], %[[B_ARG]] : i32
+// CHECK-NEXT:   %[[ADD:.+]] = addi %[[C_ARG]], %[[MUL]] : i32
+// CHECK-NEXT:   linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<16x32xi32>
diff --git a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
--- a/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
+++ b/mlir/tools/mlir-linalg-ods-gen/CMakeLists.txt
@@ -2,6 +2,13 @@
Core
Support
)
+
+# Both generator tools are built from sources in this directory, and each
+# add_llvm_tool below names only one of them. List both as optional sources so
+# LLVM's per-target source-file check (llvm_check_source_file_list) does not
+# reject the file that the tool being defined does not use.
+set(LLVM_OPTIONAL_SOURCES
+  mlir-linalg-ods-gen.cpp
+  mlir-linalg-ods-yaml-gen.cpp
+)
+
+# Original mlir-linalg-ods-gen (to be replaced).
add_llvm_tool(mlir-linalg-ods-gen
mlir-linalg-ods-gen.cpp
)
@@ -30,3 +37,35 @@
endif()
endif()
endif()
+
+
+# New mlir-linalg-ods-yaml-gen.
+add_llvm_tool(mlir-linalg-ods-yaml-gen
+  mlir-linalg-ods-yaml-gen.cpp
+)
+llvm_update_compile_flags(mlir-linalg-ods-yaml-gen)
+target_link_libraries(mlir-linalg-ods-yaml-gen PRIVATE
+  MLIRIR
+  MLIRSupport
+  MLIRParser
+  )
+
+# Cache variable naming the generator executable; can be pointed at a prebuilt
+# native tool when cross-compiling.
+set(MLIR_LINALG_ODS_YAML_GEN mlir-linalg-ods-yaml-gen CACHE
+  STRING "Native mlir-linalg-ods-yaml-gen executable. Saves building one when cross-compiling.")
+
+# Executable to run and target to depend on. add_linalg_ods_yaml_gen consumes
+# these in the parent scope.
+set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN} PARENT_SCOPE)
+set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen PARENT_SCOPE)
+
+if(LLVM_USE_HOST_TOOLS)
+  # set(... PARENT_SCOPE) above does not define MLIR_LINALG_ODS_YAML_GEN_EXE in
+  # this scope, so test the cache variable to decide whether to build natively.
+  if ("${MLIR_LINALG_ODS_YAML_GEN}" STREQUAL mlir-linalg-ods-yaml-gen)
+    build_native_tool(mlir-linalg-ods-yaml-gen MLIR_LINALG_ODS_YAML_GEN_EXE DEPENDS mlir-linalg-ods-yaml-gen)
+    set(MLIR_LINALG_ODS_YAML_GEN_EXE ${MLIR_LINALG_ODS_YAML_GEN_EXE} PARENT_SCOPE)
+
+    add_custom_target(mlir-linalg-ods-yaml-gen-host DEPENDS ${MLIR_LINALG_ODS_YAML_GEN_EXE})
+    # A stray DEPENDS keyword here would be stored as an extra list element of
+    # the variable, so only the target name is assigned.
+    set(MLIR_LINALG_ODS_YAML_GEN_TARGET mlir-linalg-ods-yaml-gen-host PARENT_SCOPE)
+
+    if(NOT LLVM_BUILD_UTILS)
+      set_target_properties(mlir-linalg-ods-yaml-gen PROPERTIES EXCLUDE_FROM_ALL ON)
+    endif()
+  endif()
+endif()
diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp
@@ -0,0 +1,878 @@
+//===- mlir-linalg-ods-yaml-gen.cpp - Linalg ODS generation from yaml ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements an ODS (and C++) generator from a YAML form
+// derived from the mathematical expression of linalg named ops. Typically a
+// math oriented DSL will be used to export the essential representation to
+// this form, and maintaining the SOT at the math level (versus recreating it
+// in MLIR) is deemed to have systemic value.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Parser.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "llvm/Support/YAMLTraits.h"
+
+using namespace mlir;
+
+using llvm::yaml::Input;
+using llvm::yaml::IO;
+using llvm::yaml::MappingTraits;
+using llvm::yaml::ScalarEnumerationTraits;
+using llvm::yaml::ScalarTraits;
+
+#define DEBUG_TYPE "linalg-ods-gen"
+
+//===----------------------------------------------------------------------===//
+// Mapping structs (correspond to data types in the YAML description).
+// TODO: Since this is a schema/part of the contract, it should be moved to
+// a real header.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Context threaded through YAML I/O so that scalar traits can parse MLIR
+/// attributes.
+struct LinalgYAMLContext {
+  MLIRContext *mlirContext;
+};
+
+/// Op metadata: DSL name (also used as the assembly mnemonic), C++ class
+/// name, and optional documentation.
+struct LinalgOpMetadata {
+  std::string name;
+  std::string cppOpName;
+  Optional<std::string> doc;
+};
+
+/// An AffineMap carried through YAML in its printed attribute form.
+struct SerializedAffineMap {
+  AffineMapAttr affineMapAttr;
+
+  AffineMap affineMap() { return affineMapAttr.getValue(); }
+};
+
+/// How a named tensor argument is used by the op.
+enum class LinalgTensorUsageDef {
+  input,
+  output,
+  temporary,
+};
+
+/// A named tensor argument of the op and its shape map over the op symbols.
+struct LinalgTensorDef {
+  std::string name;
+  LinalgTensorUsageDef usage;
+  SerializedAffineMap shape;
+};
+
+/// Iterator type (parallel or reduction) of one op dimension.
+enum class LinalgIteratorTypeDef {
+  parallel,
+  reduction,
+};
+
+/// Source of the op's indexing maps; only a static list is modeled so far.
+struct LinalgIndexingMapsConfig {
+  Optional<SmallVector<SerializedAffineMap>> staticIndexingMaps;
+};
+
+struct ScalarExpression;
+
+/// Application of a named scalar function (e.g. `add`, `mul`) to operands.
+struct ScalarApply {
+  std::string fnName;
+  // NOTE: Must be pure heap allocated container (not SmallVector)
+  // due to recursive data type.
+  std::vector<ScalarExpression> operands;
+};
+
+/// A scalar expression; one of the members is expected to be set.
+struct ScalarExpression {
+  Optional<std::string> scalarArg;
+  Optional<ScalarApply> scalarApply;
+};
+
+/// Assigns the result of `value` to the argument named `arg`.
+struct ScalarAssign {
+  std::string arg;
+  ScalarExpression value;
+};
+
+/// Configuration of a structured op: its arguments, indexing maps, iterator
+/// types and scalar assignments.
+struct LinalgStructuredOpConfig {
+  SmallVector<LinalgTensorDef> args;
+  LinalgIndexingMapsConfig indexingMaps;
+  SmallVector<LinalgIteratorTypeDef> iteratorTypes;
+  SmallVector<ScalarAssign> assignments;
+};
+
+/// One YAML document: op metadata plus the op-kind specific configuration.
+struct LinalgOpConfig {
+  Optional<LinalgOpMetadata> metadata;
+  Optional<LinalgStructuredOpConfig> structuredOp;
+};
+
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Mapping traits.
+//===----------------------------------------------------------------------===//
+
+LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgTensorDef);
+LLVM_YAML_IS_SEQUENCE_VECTOR(SerializedAffineMap);
+LLVM_YAML_IS_SEQUENCE_VECTOR(LinalgIteratorTypeDef);
+LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarAssign);
+LLVM_YAML_IS_SEQUENCE_VECTOR(ScalarExpression);
+LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(LinalgOpConfig);
+
+namespace llvm {
+namespace yaml {
+
+/// Top-level type containing op metadata and one of a concrete op type.
+/// Currently, the only defined op type is `structured_op` (maps to
+/// `LinalgStructuredOpConfig`).
+template <>
+struct MappingTraits<LinalgOpConfig> {
+  static void mapping(IO &io, LinalgOpConfig &info) {
+    io.mapOptional("metadata", info.metadata);
+    io.mapOptional("structured_op", info.structuredOp);
+  }
+};
+
+/// A structured op models (at most) a single contraction by modeling
+///   - A list of named arguments (`LinalgTensorDef`), which can be inputs,
+///     outputs, or temporaries.
+///   - List of indexing maps (see `LinalgIndexingMapsConfig`).
+///   - Iterator types (see `LinalgIteratorTypeDef`).
+///   - List of scalar level assignment (see `ScalarAssign`).
+template <>
+struct MappingTraits<LinalgStructuredOpConfig> {
+  static void mapping(IO &io, LinalgStructuredOpConfig &info) {
+    io.mapRequired("args", info.args);
+    io.mapRequired("indexing_maps", info.indexingMaps);
+    io.mapRequired("iterator_types", info.iteratorTypes);
+    io.mapRequired("assignments", info.assignments);
+  }
+};
+
+/// Maps a named tensor-argument to an operation, consisting of:
+///   - `name`: Must be unique within the operation.
+///   - `usage`: How the argument is used (input, output, etc).
+///   - `shape`: An AffineMap from all op symbols to the specific shape
+///     of this argument. Each shape must be normalized over the same list of
+///     symbols and have no dimension inputs.
+template <>
+struct MappingTraits<LinalgTensorDef> {
+  static void mapping(IO &io, LinalgTensorDef &info) {
+    io.mapRequired("name", info.name);
+    io.mapRequired("usage", info.usage);
+    io.mapRequired("shape", info.shape);
+  }
+};
+
+/// Usage enum for a named argument.
+template <>
+struct ScalarEnumerationTraits<LinalgTensorUsageDef> {
+  static void enumeration(IO &io, LinalgTensorUsageDef &value) {
+    io.enumCase(value, "input", LinalgTensorUsageDef::input);
+    io.enumCase(value, "output", LinalgTensorUsageDef::output);
+    io.enumCase(value, "temporary", LinalgTensorUsageDef::temporary);
+  }
+};
+
+/// Iterator type enum.
+template <>
+struct ScalarEnumerationTraits<LinalgIteratorTypeDef> {
+  static void enumeration(IO &io, LinalgIteratorTypeDef &value) {
+    io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel);
+    io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction);
+  }
+};
+
+/// Metadata about the op (name, C++ name, and documentation).
+template <>
+struct MappingTraits<LinalgOpMetadata> {
+  static void mapping(IO &io, LinalgOpMetadata &info) {
+    io.mapRequired("name", info.name);
+    io.mapRequired("cpp_op_name", info.cppOpName);
+    io.mapOptional("doc", info.doc);
+  }
+};
+
+/// How the ops indexing maps are produced. Must be one of:
+///   - static_indexing_maps: A static list of AffineMaps, possibly with
+///     some symbols that bind to attributes of the op. Each indexing map must
+///     be normalized over the same list of dimensions, and its symbols must
+///     match the symbols for argument shapes.
+template <>
+struct MappingTraits<LinalgIndexingMapsConfig> {
+  static void mapping(IO &io, LinalgIndexingMapsConfig &info) {
+    io.mapOptional("static_indexing_maps", info.staticIndexingMaps);
+  }
+};
+
+/// Models an assignment to a named output.
+///   - The `arg` name must match a named output or temporary.
+///   - The `value` is a scalar expression for computing the value to
+///     assign (see `ScalarExpression`).
+template <>
+struct MappingTraits<ScalarAssign> {
+  static void mapping(IO &io, ScalarAssign &info) {
+    io.mapRequired("arg", info.arg);
+    io.mapRequired("value", info.value);
+  }
+};
+
+/// A scalar expression (RHS of an assignment). Must be one of:
+///   - `scalar_arg`: Name of an argument to the op.
+///   - `scalar_apply`: Result of evaluating a named function (see
+///      `ScalarApply`).
+template <>
+struct MappingTraits<ScalarExpression> {
+  static void mapping(IO &io, ScalarExpression &info) {
+    io.mapOptional("scalar_arg", info.scalarArg);
+    io.mapOptional("scalar_apply", info.scalarApply);
+  }
+};
+
+/// A scalar expression that evaluates a named function.
+/// Functions are generally "math" level and type polymorphic. Builtin
+/// functions include:
+///   - `add(lhs, rhs)`
+///   - `mul(lhs, rhs)`
+template <>
+struct MappingTraits<ScalarApply> {
+  static void mapping(IO &io, ScalarApply &info) {
+    io.mapRequired("fn_name", info.fnName);
+    io.mapRequired("operands", info.operands);
+  }
+};
+
+/// Helper mapping which accesses an AffineMapAttr as a serialized string of
+/// the same. Input requires the YAML context to be a `LinalgYAMLContext`.
+template <>
+struct ScalarTraits<SerializedAffineMap> {
+  static void output(const SerializedAffineMap &value, void *rawYamlContext,
+                     raw_ostream &out) {
+    assert(value.affineMapAttr);
+    value.affineMapAttr.print(out);
+  }
+  static StringRef input(StringRef scalar, void *rawYamlContext,
+                         SerializedAffineMap &value) {
+    assert(rawYamlContext);
+    auto *yamlContext = static_cast<LinalgYAMLContext *>(rawYamlContext);
+    auto attr = mlir::parseAttribute(scalar, yamlContext->mlirContext)
+                    .dyn_cast_or_null<AffineMapAttr>();
+    // Report every parse failure, even when `value` already holds a map:
+    // keeping a stale map would silently accept malformed input.
+    if (!attr)
+      return "could not parse as an affine map attribute";
+    value.affineMapAttr = attr;
+    return StringRef();
+  }
+  static QuotingType mustQuote(StringRef) { return QuotingType::None; }
+};
+
+} // namespace yaml
+} // namespace llvm
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Generation utilities
+//===----------------------------------------------------------------------===//
+
+/// State shared by the generators:
+///   - the MLIRContext,
+///   - the location used for diagnostics,
+///   - the optional output streams for ODS declarations and C++ definitions.
+/// A null stream means that kind of output was not requested.
+class GenerationContext {
+public:
+  GenerationContext(MLIRContext *context, raw_ostream *odsOut,
+                    raw_ostream *defnOut)
+      : context(context), loc(UnknownLoc::get(context)), odsOut(odsOut),
+        defnOut(defnOut) {}
+
+  MLIRContext *getContext() const { return context; }
+
+  void setLoc(Location loc) { this->loc = loc; }
+  Location getLoc() const { return loc; }
+
+  bool shouldGenerateOds() const { return odsOut != nullptr; }
+  bool shouldGenerateDefns() const { return defnOut != nullptr; }
+
+  /// Stream for ODS declarations; only valid if shouldGenerateOds().
+  raw_ostream &odss() const {
+    assert(odsOut && "ODS stream not defined");
+    return *odsOut;
+  }
+
+  /// Stream for C++ definitions; only valid if shouldGenerateDefns().
+  raw_ostream &defns() const {
+    assert(defnOut && "Definition stream not defined");
+    return *defnOut;
+  }
+
+private:
+  MLIRContext *context;
+  Location loc;
+  raw_ostream *odsOut;
+  raw_ostream *defnOut;
+};
+
+} // namespace
+
+/// Returns a C++ expression that rebuilds the AffineMap in `self` at runtime.
+/// It parses the map's printed form in the MLIRContext named by `contextName`.
+/// NOTE: the printed map is embedded unescaped in a C++ string literal; this
+/// assumes affine map syntax never contains `"` or `\`.
+static std::string generateCppExpression(SerializedAffineMap self,
+                                         StringRef contextName) {
+  std::string printedStr;
+  llvm::raw_string_ostream printedSs(printedStr);
+  self.affineMapAttr.print(printedSs);
+  printedSs.flush();
+
+  static const char exprFormat[] =
+      R"FMT(mlir::parseAttribute("{0}", {1}).cast<AffineMapAttr>().getValue())FMT";
+  return llvm::formatv(exprFormat, printedStr, contextName).str();
+}
+
+/// Streams each element of `container` and joins the results into one string,
+/// separated by `separator`.
+template <typename Container>
+static std::string interleaveToString(const Container &container,
+                                      StringRef separator) {
+  std::string result;
+  llvm::raw_string_ostream ss(result);
+  llvm::interleave(container, ss, separator);
+  ss.flush();
+  return result;
+}
+
+/// Returns the position of the argument named `name` in `args`, or None if no
+/// argument has that name.
+static Optional<int>
+findTensorDefArgIndex(StringRef name,
+                      const SmallVectorImpl<LinalgTensorDef> &args) {
+  for (auto it : llvm::enumerate(args)) {
+    if (it.value().name == name)
+      return static_cast<int>(it.index());
+  }
+  return None;
+}
+
+/// Returns the first assignment whose target is the argument named `name`, or
+/// nullptr if there is none. The result points into `assignments`.
+static ScalarAssign *
+findAssignment(StringRef name, SmallVectorImpl<ScalarAssign> &assignments) {
+  for (ScalarAssign &assign : assignments) {
+    if (assign.arg == name)
+      return &assign;
+  }
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Templates
+//===----------------------------------------------------------------------===//
+
+// A single line banner format, used to separate sections of generated C++.
+// Parameters:
+// {0}: Single line comment
+static const char bannerFormat[] = R"FMT(
+//===----------------------------------------------------------------------===//
+// {0}
+//===----------------------------------------------------------------------===//
+)FMT";
+
+//===----------------------------------------------------------------------===//
+// Named generic op generation.
+// These ops map at most a single contraction that complies with the limitations
+// of a linalg.generic.
+//===----------------------------------------------------------------------===//
+
+// Template for Linalg named ops' ODS definitions. Parameters:
+// {0}: ODS/C++ op name
+// {1}: assembly op mnemonic
+// {2}: op interface list
+// {3}: documentation (summary + description)
+// {4}: op attribute list
+// {5}: the number of arguments for the op region (currently not referenced by
+//      this template)
+// {6}: builder methods taking standalone attribute parameters
+// {7}: additional methods for attributes used by indexing maps
+// NOTE: this is an llvm::formatv format string; a literal `{` that directly
+// precedes a closing `}` must be written as `{{`.
+static const char structuredOpOdsHeaderFormat[] = R"FMT(
+//===----------------------------------------------------------------------===//
+// Op definition for {0}
+//===----------------------------------------------------------------------===//
+
+def {0} : LinalgStructuredBase_Op<"{1}", [
+  AttrSizedOperandSegments,
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  SingleBlockImplicitTerminator<"YieldOp">
+  /*extraInterfaces=*/{2}]> {
+    {3}
+    let arguments = (ins
+      Variadic<AnyShaped>:$inputs,
+      Variadic<AnyShaped>:$outputs{4}
+    );
+    let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
+    let regions = (region AnyRegion:$region);
+
+    let skipDefaultBuilders = 1;
+    let builders = [
+      OpBuilderDAG<
+      (ins "ValueRange":$inputs, "ValueRange":$outputs),
+      [{{
+        $_state.addOperands(inputs);
+        $_state.addOperands(outputs);
+        $_state.addAttribute(
+          "operand_segment_sizes",
+          $_builder.getI32VectorAttr({{
+            static_cast<int32_t>(inputs.size()),
+            static_cast<int32_t>(outputs.size())}));
+        createAndFillStructuredOpRegion<{0}>(
+          $_builder,
+          $_state,
+          TypeRange(inputs),
+          TypeRange(outputs)/*, TODO: support captures*/);
+      }]>,
+      OpBuilderDAG<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+           "ValueRange":$outputs),
+      [{{
+        $_state.addOperands(inputs);
+        $_state.addOperands(outputs);
+        $_state.addTypes(resultTensorTypes);
+        $_state.addAttribute(
+          "operand_segment_sizes",
+          $_builder.getI32VectorAttr({{
+            static_cast<int32_t>(inputs.size()),
+            static_cast<int32_t>(outputs.size())}));
+        createAndFillStructuredOpRegion<{0}>(
+          $_builder,
+          $_state,
+          TypeRange(inputs),
+          TypeRange(outputs)/*, TODO: support captures*/);
+      }]>,
+      OpBuilderDAG<
+      (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
+           CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
+      [{{
+        $_state.addOperands(operands);
+        $_state.addAttributes(attributes);
+        $_state.addTypes(resultTensorTypes);
+        (void)$_state.addRegion();
+      }]>
+      {6}
+    ];
+    let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
+    let parser = [{{
+      return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
+    }];
+    let hasFolder = 1;
+    let hasCanonicalizer = 1;
+
+    let extraClassDeclaration = structuredOpsBaseDecls # [{{
+      // Auto-generated.
+      ArrayAttr iterator_types();
+      ArrayAttr indexing_maps();
+      static void regionBuilder(Block &block, ValueRange captures);
+      static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
+        return regionBuilder;
+      }
+
+      // Generic methods.
+      static unsigned getNumRegionArgs();
+      std::string getLibraryCallName();
+      {7}
+    }];
+}
+)FMT";
+
+// The iterator_types() method implementation. Parameters:
+// {0}: Class name
+// {1}: Comma interleaved iterator type names.
+static const char structuredOpIteratorTypesFormat[] =
+    R"FMT(
+ArrayAttr {0}::iterator_types() {
+  return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef>{{ {1} });
+}
+)FMT";
+
+// Implementations of getCanonicalizationPatterns, fold and getEffects.
+// Parameters:
+// {0}: Class name
+static const char structuredOpCanonicalizersAndFoldersFormat[] = R"FMT(
+void {0}::getCanonicalizationPatterns(
+    OwningRewritePatternList &results,
+    MLIRContext *context) {{
+  results.insert<EraseDeadLinalgOp>();
+  results.insert<FoldTensorCastOp>();
+}
+LogicalResult {0}::fold(ArrayRef<Attribute>,
+                        SmallVectorImpl<OpFoldResult> &) {{
+  return foldMemRefCast(*this);
+}
+void {0}::getEffects(SmallVectorImpl<
+    SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
+  getGenericEffectsImpl(effects,
+    getOperation()->getResults(), getInputBuffers(), getOutputBuffers());
+}
+)FMT";
+
+/// Emits the ODS definition for the structured op described by `opConfig` to
+/// the ODS stream of `genContext`. If no ODS output was requested, it does
+/// nothing and succeeds. Fails if the op metadata or the structured op section
+/// is missing.
+static LogicalResult generateNamedGenericOpOds(LinalgOpConfig &opConfig,
+                                               GenerationContext &genContext) {
+  if (!genContext.shouldGenerateOds())
+    return success();
+  // Both sections are mapped as optional in the YAML schema, but this
+  // generator dereferences both.
+  if (!opConfig.metadata || !opConfig.structuredOp)
+    return emitError(genContext.getLoc())
+           << "expected op metadata and a structured op definition";
+
+  raw_ostream &os = genContext.odss();
+
+  // Placeholders for features the YAML format does not describe yet; they are
+  // substituted into the template as empty strings.
+  std::string interfaceNameList;
+  std::string attrList;
+  std::string attrMethods;
+  std::string attrBuilder;
+
+  // The first line of the doc string becomes the ODS summary and the remaining
+  // lines the description.
+  std::string doc;
+  if (opConfig.metadata->doc) {
+    static const char docFmt[] = R"FMT(
+  let summary = [{ {0} }];
+  let description = [{
+    {1}
+  }];
+  )FMT";
+    StringRef summary, description;
+    std::tie(summary, description) =
+        StringRef(*opConfig.metadata->doc).trim().split('\n');
+    doc = llvm::formatv(docFmt, summary.trim(), description.trim()).str();
+  }
+
+  os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppOpName,
+                      opConfig.metadata->name, interfaceNameList, doc, attrList,
+                      opConfig.structuredOp->args.size(), attrBuilder,
+                      attrMethods);
+
+  return success();
+}
+
+/// Emits the C++ definitions for one named structured op into the definitions
+/// stream of `genContext`.
+///
+/// For the class named by `opConfig.metadata->cppOpName` this writes, in order:
+///   - a banner comment and the iterator types (structuredOpIteratorTypesFormat),
+///   - a file-local `getSymbolBindings()` helper and `indexing_maps()`,
+///   - `getNumRegionArgs()`, `getLibraryCallName()` and `regionBuilder()`,
+///   - the canonicalizers/folders (structuredOpCanonicalizersAndFoldersFormat).
+/// Does nothing and succeeds when definitions are not requested.
+///
+/// Returns failure, after emitting an error at `genContext.getLoc()`, when:
+///   - the op has no indexing maps, or its indexing maps are not static;
+///   - the static maps disagree on their dimension/symbol counts;
+///   - an output/temporary argument has no assignment;
+///   - a scalar expression references an unknown argument;
+///   - the number of assignments differs from the number of output/temporary
+///     arguments.
+static LogicalResult
+generateNamedGenericOpDefns(LinalgOpConfig &opConfig,
+                            GenerationContext &genContext) {
+  if (!genContext.shouldGenerateDefns())
+    return success();
+
+  raw_ostream &os = genContext.defns();
+  StringRef className = opConfig.metadata->cppOpName;
+
+  // Implementation banner.
+  std::string bannerComment = llvm::formatv("Implementation of {0}", className);
+  os << llvm::formatv(bannerFormat, bannerComment);
+
+  // Reference iterators.
+  {
+    std::string iteratorsStr;
+    llvm::raw_string_ostream ss(iteratorsStr);
+    llvm::interleaveComma(opConfig.structuredOp->iteratorTypes, ss,
+                          [&](LinalgIteratorTypeDef it) {
+                            switch (it) {
+                            case LinalgIteratorTypeDef::parallel:
+                              ss << "getParallelIteratorTypeName()";
+                              break;
+                            case LinalgIteratorTypeDef::reduction:
+                              ss << "getReductionIteratorTypeName()";
+                              break;
+                            }
+                          });
+    ss.flush();
+    os << llvm::formatv(structuredOpIteratorTypesFormat, className,
+                        iteratorsStr);
+  }
+
+  // Static indexing maps.
+  if (auto &staticMaps =
+          opConfig.structuredOp->indexingMaps.staticIndexingMaps) {
+    if (staticMaps->empty())
+      return emitError(genContext.getLoc()) << "op has no indexing maps";
+    AffineMap firstMap = staticMaps->front().affineMap();
+
+    // Every generated map is rewritten below with the first map's dim count
+    // and the shared symbol bindings, so all maps must agree on both counts.
+    for (auto &indexingMap : *staticMaps) {
+      AffineMap map = indexingMap.affineMap();
+      if (map.getNumDims() != firstMap.getNumDims() ||
+          map.getNumSymbols() != firstMap.getNumSymbols())
+        return emitError(genContext.getLoc())
+               << "all indexing maps must have the same number of dimensions "
+                  "and symbols";
+    }
+
+    // Symbol bindings.
+    {
+      // For each symbol, generate a declaration for it, either with an
+      // AffineSymbolExpr or an AffineConstantExpr (if the symbol derives from
+      // an attribute).
+      // TODO: Implement attribute constants.
+      // TODO: Possibly lift into a top-level method.
+      // Parameters:
+      // {0}: Class name
+      // {1}: Statements pushing one AffineExpr per symbol
+      static const char structuredOpSymbolBindingsFormat[] = R"FMT(
+static SmallVector<AffineExpr> getSymbolBindings({0} self) {{
+  MLIRContext *context = self.getContext();
+  SmallVector<AffineExpr> exprs;
+{1}
+  return exprs;
+}
+)FMT";
+
+      unsigned symbolCount = firstMap.getNumSymbols();
+      SmallVector<std::string> symbolBindings;
+      for (unsigned i = 0; i < symbolCount; ++i) {
+        // TODO: Switch and emit constants for attribute bound symbols.
+        symbolBindings.push_back(llvm::formatv(
+            "  exprs.push_back(getAffineSymbolExpr({0}, context));", i));
+      }
+      std::string symbolBindingsStr;
+      llvm::raw_string_ostream symbolBindingsSs(symbolBindingsStr);
+      llvm::interleave(symbolBindings, symbolBindingsSs, "\n");
+      symbolBindingsSs.flush();
+
+      os << llvm::formatv(structuredOpSymbolBindingsFormat, className,
+                          symbolBindingsStr);
+    }
+
+    // Indexing maps.
+    {
+      // Parameters:
+      // {0}: Class name
+      // {1}: Comma-separated list of dimension variable names (currently not
+      //      referenced by the template).
+      // {2}: Statements
+      static const char structuredOpIndexingMapsFormat[] = R"FMT(
+ArrayAttr {0}::indexing_maps() {{
+  MLIRContext *context = getContext();
+  auto symbolBindings = getSymbolBindings(*this);
+  SmallVector<AffineMap> maps;
+  {2}
+  return Builder(context).getAffineMapArrayAttr(maps);
+}
+)FMT";
+
+      unsigned dimCount = firstMap.getNumDims();
+
+      // Generate a comma-separated list of dim identifiers to be passed to
+      // bindDims, ensuring that AffineExpr identifiers are bound in the right
+      // order to the proper AffineDimExpr.
+      // This results in vars in scope like: d0, d1, d2...
+      SmallVector<unsigned> dimIndices;
+      for (unsigned i = 0; i < dimCount; ++i)
+        dimIndices.push_back(i);
+      std::string dimIdentsStr;
+      llvm::raw_string_ostream dimIdentsSs(dimIdentsStr);
+      llvm::interleaveComma(dimIndices, dimIdentsSs,
+                            [&](unsigned i) { dimIdentsSs << "d" << i; });
+      dimIdentsSs.flush();
+
+      // Statements to add and simplify each affine map. The dim/symbol counts
+      // of every map were validated against `firstMap` above.
+      SmallVector<std::string> stmts;
+      for (auto &indexingMap : *staticMaps) {
+        stmts.push_back(
+            llvm::formatv("maps.push_back({0});",
+                          generateCppExpression(indexingMap, "context")));
+        stmts.push_back(llvm::formatv(
+            "maps.back() = "
+            "simplifyAffineMap(maps.back().replaceDimsAndSymbols({{}, "
+            "symbolBindings, {0}, 0));",
+            dimCount));
+      }
+
+      // TODO: This needs to be memoized and/or converted to non-parser based
+      // C++ codegen prior to real use.
+      os << llvm::formatv(structuredOpIndexingMapsFormat, className,
+                          dimIdentsStr, interleaveToString(stmts, "\n  "));
+    }
+  } else {
+    return emitError(genContext.getLoc())
+           << "generating code for non static indexing maps not currently "
+              "supported";
+  }
+
+  // getNumRegionArgs()
+  {
+    // Generates a getNumRegionArgs() method. Parameters:
+    // {0}: Class name
+    // {1}: Number of region args
+    static const char structuredOpGetNumRegionArgsFormat[] = R"FMT(
+unsigned {0}::getNumRegionArgs() {{ return {1}; }
+)FMT";
+    os << llvm::formatv(structuredOpGetNumRegionArgsFormat, className,
+                        opConfig.structuredOp->args.size());
+  }
+
+  // getLibraryCallName()
+  {
+    // Generates a getLibraryCallName method. Parameters:
+    // {0}: Class name
+    static const char structuredOpGetLibraryCallFormat[] = R"FMT(
+std::string {0}::getLibraryCallName() {{
+  return generateLibraryCallName(getOperation());
+}
+)FMT";
+    os << llvm::formatv(structuredOpGetLibraryCallFormat, className);
+  }
+
+  // regionBuilder()
+  {
+    // Generates a regionBuilder method. Parameters.
+    // {0}: Class name
+    // {1}: Statements
+    static const char structuredOpRegionBuilderFormat[] = R"FMT(
+void {0}::regionBuilder(Block &block, ValueRange captures) {{
+  RegionBuilderHelper helper(block);
+  SmallVector<Value> yields;
+  {1}
+  helper.yieldOutputs(yields);
+}
+)FMT";
+    auto &args = opConfig.structuredOp->args;
+    auto &assignments = opConfig.structuredOp->assignments;
+    size_t generatedAssignmentCount = 0;
+    // Suffix for generated local names (value1, value2, ...). It is shared
+    // across all outputs so names stay unique within the generated method.
+    int localCounter = 0;
+    SmallVector<std::string> stmts;
+    for (LinalgTensorDef &arg : args) {
+      // Only outputs and temporaries receive an assignment.
+      if (arg.usage != LinalgTensorUsageDef::output &&
+          arg.usage != LinalgTensorUsageDef::temporary)
+        continue;
+
+      // Find the assignment that correlates with the argument.
+      ScalarAssign *assignment = findAssignment(arg.name, assignments);
+      if (!assignment)
+        return emitError(genContext.getLoc())
+               << "no assignment found for output argument " << arg.name;
+      ++generatedAssignmentCount;
+
+      // Recursively generate the expression. Each scalar apply appends a
+      // `Value valueN = helper.applyfn__<fn>(...)` statement to `stmts`, and
+      // the returned string is the C++ expression holding the result. None
+      // means an error has already been emitted.
+      std::function<Optional<std::string>(ScalarExpression &)>
+          generateExpression =
+              [&](ScalarExpression &expression) -> Optional<std::string> {
+        if (expression.scalarArg) {
+          auto argIndex = findTensorDefArgIndex(*expression.scalarArg, args);
+          if (!argIndex) {
+            // Report the name that failed to resolve, not the output
+            // argument currently being assigned.
+            emitError(genContext.getLoc())
+                << "scalar argument not defined on the op: "
+                << *expression.scalarArg;
+            return None;
+          }
+          return std::string(
+              llvm::formatv("block.getArgument({0})", *argIndex));
+        }
+        if (expression.scalarApply) {
+          // Recursively generate operands.
+          SmallVector<std::string> operandCppValues;
+          for (ScalarExpression &operand : expression.scalarApply->operands) {
+            auto operandCppValue = generateExpression(operand);
+            if (!operandCppValue)
+              return None;
+            operandCppValues.push_back(*operandCppValue);
+          }
+          std::string cppIdent = llvm::formatv("value{0}", ++localCounter);
+          stmts.push_back(
+              llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent,
+                            expression.scalarApply->fnName,
+                            interleaveToString(operandCppValues, ", ")));
+          return cppIdent;
+        }
+        emitError(genContext.getLoc()) << "unknown ScalarExpression type";
+        return None;
+      };
+      Optional<std::string> cppValue = generateExpression(assignment->value);
+      if (!cppValue)
+        return failure();
+      // Format the contained string, not the Optional wrapper.
+      stmts.push_back(llvm::formatv("yields.push_back({0});", *cppValue));
+    }
+
+    if (generatedAssignmentCount != assignments.size())
+      return emitError(genContext.getLoc())
+             << "mismatched number of assignments vs output arguments";
+
+    os << llvm::formatv(structuredOpRegionBuilderFormat, className,
+                        interleaveToString(stmts, "\n  "));
+  }
+
+  // Canonicalizers and folders.
+  os << llvm::formatv(structuredOpCanonicalizersAndFoldersFormat, className);
+
+  return success();
+}
+
+/// Generates code for a single op configuration.
+///
+/// Only structured ops are supported. Their ODS declaration is generated
+/// first, and the C++ definitions are generated only if that succeeds. Any
+/// other kind of op is rejected with an error at `genContext.getLoc()`.
+static LogicalResult generateOp(LinalgOpConfig &opConfig,
+                                GenerationContext &genContext) {
+  if (!opConfig.structuredOp)
+    return emitError(genContext.getLoc()) << "unsupported operation type";
+
+  if (failed(generateNamedGenericOpOds(opConfig, genContext)))
+    return failure();
+  return generateNamedGenericOpDefns(opConfig, genContext);
+}
+
+//===----------------------------------------------------------------------===//
+// Command line options and main
+//===----------------------------------------------------------------------===//
+
+/// Positional input: the YAML file of op configurations ("-" reads stdin).
+static llvm::cl::opt<std::string>
+    inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
+                  llvm::cl::init("-"), llvm::cl::value_desc("YAML filename"));
+
+/// Destination for the generated ODS declarations; empty disables it.
+static llvm::cl::opt<std::string>
+    outputOdsDeclFilename("o-ods-decl", llvm::cl::desc("ODS output filename"),
+                          llvm::cl::value_desc("filename"), llvm::cl::init(""));
+
+/// Destination for the generated C++ definitions; empty disables it.
+static llvm::cl::opt<std::string>
+    outputCppImplFilename("o-impl",
+                          llvm::cl::desc("C++ implementation file name"),
+                          llvm::cl::value_desc("filename"), llvm::cl::init(""));
+
+/// Tool entry point.
+///
+/// Reads a YAML stream of op configurations from `inputFilename` ("-" by
+/// default). Each op is generated into the ODS declaration file (-o-ods-decl)
+/// and/or the C++ implementation file (-o-impl); at least one output must be
+/// requested. Returns 0 on success and 1 on the first error. keep() is called
+/// on the output files only after every op has been generated successfully.
+int main(int argc, char **argv) {
+  llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen from YAML");
+
+  // Reject invocations that would produce nothing before doing any work.
+  if (outputOdsDeclFilename.empty() && outputCppImplFilename.empty()) {
+    llvm::errs() << "error: No output files specified\n";
+    return 1;
+  }
+
+  // Set up the input file.
+  std::string errorMessage;
+  std::unique_ptr<llvm::MemoryBuffer> file =
+      mlir::openInputFile(inputFilename, &errorMessage);
+  if (!file) {
+    llvm::errs() << errorMessage << "\n";
+    return 1;
+  }
+
+  MLIRContext mlirContext;
+  LinalgYAMLContext yamlContext{&mlirContext};
+
+  std::vector<LinalgOpConfig> opConfigs;
+
+  // Parse input.
+  Input yin(file->getBuffer(), &yamlContext);
+  yin >> opConfigs;
+
+  if (yin.error())
+    return 1;
+
+  // Open only the requested output files.
+  std::unique_ptr<llvm::ToolOutputFile> outputOdsDecl;
+  if (!outputOdsDeclFilename.empty()) {
+    outputOdsDecl = openOutputFile(outputOdsDeclFilename, &errorMessage);
+    if (!outputOdsDecl) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+  }
+
+  std::unique_ptr<llvm::ToolOutputFile> outputCppImpl;
+  if (!outputCppImplFilename.empty()) {
+    outputCppImpl = openOutputFile(outputCppImplFilename, &errorMessage);
+    if (!outputCppImpl) {
+      llvm::errs() << errorMessage << "\n";
+      return 1;
+    }
+  }
+
+  // Generate. nullptr is passed for any output that was not requested.
+  GenerationContext genContext(&mlirContext,
+                               outputOdsDecl ? &outputOdsDecl->os() : nullptr,
+                               outputCppImpl ? &outputCppImpl->os() : nullptr);
+
+  for (auto &opConfig : opConfigs) {
+    if (!opConfig.metadata) {
+      emitError(genContext.getLoc())
+          << "missing operation metadata on subsequent op";
+      return 1;
+    }
+
+    // Attribute diagnostics for this op to its C++ class name.
+    genContext.setLoc(NameLoc::get(
+        Identifier::get(opConfig.metadata->cppOpName, &mlirContext),
+        &mlirContext));
+    if (failed(generateOp(opConfig, genContext)))
+      return 1;
+  }
+
+  if (outputOdsDecl)
+    outputOdsDecl->keep();
+  if (outputCppImpl)
+    outputCppImpl->keep();
+
+  return 0;
+}