Index: mlir/include/mlir/Dialect/Linalg/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(ComprehensiveBufferize) +add_subdirectory(Frontend) add_subdirectory(IR) set(LLVM_TARGET_DEFINITIONS Passes.td) Index: mlir/include/mlir/Dialect/Linalg/Frontend/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/Frontend/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS LinalgFrontendInterfaces.td) +mlir_tablegen(LinalgFrontendOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgFrontendOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRLinalgFrontendInterfacesIncGen) +add_dependencies(mlir-headers MLIRLinalgFrontendInterfacesIncGen) Index: mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h @@ -0,0 +1,23 @@ +//===- LinalgFrontendInterfaces.h - Linalg frontend op interfaces ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the frontend operation interfaces for Linalg +// operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_FRONTEND_LINALGFRONTENDINTERFACES_H_ +#define MLIR_DIALECT_LINALG_FRONTEND_LINALGFRONTENDINTERFACES_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/Parser.h" + +/// Include the generated interface declarations. 
+#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendOpInterfaces.h.inc" + +#endif // MLIR_DIALECT_LINALG_FRONTEND_LINALGFRONTENDINTERFACES_H_ Index: mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.td @@ -0,0 +1,223 @@ +//===- LinalgFrontendInterfaces.td - Linalg Frontend Interfaces -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Linalg Frontend Interfaces +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_FRONTEND_LINALGFRONTENDINTERFACES +#define LINALG_FRONTEND_LINALGFRONTENDINTERFACES + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" + +class Linalg_FrontendInterfaceBase : OpInterface { + let cppNamespace = "::mlir::linalg"; +} + +class Linalg_InstantiateOperatorOpInterface< + string operatorCamelCase, + string attrName, + string defFloatOp = "", + string defIntOp = ""> : + Linalg_FrontendInterfaceBase< + "Instantiate" # operatorCamelCase # "OperatorOpInterface"> +{ + let description = [{ + Base class for all interfaces that allow for the instantiation of + an arithmetic operator via an attribute. + + `operatorCamelCase` is supposed to be the name of the arithmetic + operator in camel case, e.g., `Add` or `MaxUnsigned`. + + `attrName` specifies the name of the attribute that will be parsed + into the name of the operation implementing the operator and an + optional result type specified after a colon. + + `defIntOp` and `defFloatOp` define the default operations that get + instantiated if no value is provided in the attribute. 
Both + `defIntOp` and `defFloatOp` may be omitted if there are no + sensible default implementations. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/"Checks if an operation is the " # attrName # + " operator for this operation", + /*retTy=*/"bool", + /*methodName=*/"is" # operatorCamelCase # "Operator", + /*args=*/(ins "::mlir::Operation &":$operation), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + ::mlir::StringAttr strAttr = + $_op->template getAttrOfType<::mlir::StringAttr>("}] # attrName # [{"); + + if(strAttr) { + // Attribute was specified: use operation from attribute + std::pair<::llvm::StringRef, ::llvm::StringRef> spec = + strAttr.strref().split(":"); + + ::llvm::StringRef opName = std::get<0>(spec); + + return (operation.getName().getStringRef() == opName); + } + + // Attribute was NOT specified: use default operations + }] # + !if(!not(!empty(defFloatOp)), [{ + if(operation.getName().getStringRef() == "}] # defFloatOp # [{") + return true; + }], [{ }]) # + !if(!not(!empty(defIntOp)), [{ + if(operation.getName().getStringRef() == "}] # defIntOp # [{") + return true; + }], [{ }]) # + [{ + + return false; + }] + >, + + StaticInterfaceMethod< + /*desc=*/"Instantiates the " # attrName # " operator in the absence of " # + "an attribute specifying it", + /*retTy=*/"::mlir::Value", + /*methodName=*/"instantiateDefault" # operatorCamelCase # "Operator", + /*args=*/(ins + "::mlir::OpBuilder &":$builder, + "::mlir::Location":$location, + "::mlir::ValueRange":$operands), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + const char* opName = nullptr; + ::mlir::Type resType; + + }] # + !if(!not(!empty(defFloatOp)), [{ + if(!operands.empty() && ::llvm::all_of(operands, [](::mlir::Value v) { + return v.getType().isa<::mlir::FloatType>(); + })) { + opName = "}] # defFloatOp # [{"; + resType = operands[0].getType(); + } + }], [{ }]) # + !if(!not(!empty(defIntOp)), [{ + if(!operands.empty() && ::llvm::all_of(operands, [](::mlir::Value v) { + return v.getType().isa<::mlir::IntegerType>(); + 
})) { + opName = "}] # defIntOp # [{"; + resType = operands[0].getType(); + } + }], [{ }]) # + [{ + + if(opName != nullptr) { + ::mlir::OperationState state(location, opName, operands, resType, {}); + ::mlir::Operation *oprtr = builder.createOperation(state); + return oprtr->getResult(0); + } + + return ::mlir::Value(); + }] + >, + + StaticInterfaceMethod< + /*desc=*/"Instantiates the " # attrName # " operator from the " # + "attribute \"" # attrName # "\"", + /*retTy=*/"::mlir::Value", + /*methodName=*/"instantiate" # operatorCamelCase # "Operator", + /*args=*/(ins + "::mlir::OpBuilder &":$builder, + "::mlir::Location":$location, + "::mlir::NamedAttrList":$attrs, + "::mlir::ValueRange":$operands), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + ::llvm::Optional<::mlir::NamedAttribute> attr = + attrs.getNamed("}] # attrName # [{"); + + // If attribute was not specified, fall back to default operation + if(!attr.hasValue()) { + if(::mlir::Value opr = instantiateDefault}] # operatorCamelCase # [{Operator( + builder, location, operands)) { + return opr; + } else { + ::mlir::emitError( + location, + "Could not generate default operation implementing " + "operator }] # attrName # [{, for the given operands. 
" + "Please specify an operation using the " + "attribute '}] # attrName # [{'."); + llvm_unreachable(""); + } + } + + ::mlir::StringAttr strAttr = + attr.getValue().getValue().dyn_cast<::mlir::StringAttr>(); + + if (!strAttr) { + ::mlir::emitError( + location, + "Attribute }] # attrName # [{ must be a string attribute."); + llvm_unreachable(""); + } + + // Extract operation name and result type from specification + std::pair<::llvm::StringRef, ::llvm::StringRef> spec = + strAttr.strref().split(":"); + ::llvm::StringRef opName = std::get<0>(spec); + ::llvm::StringRef resTypeName = std::get<1>(spec); + + ::mlir::Type resType; + + // Use type of LHS operand by default for the result type + if(resTypeName.empty()) { + if(operands.empty()) { + ::mlir::emitError(location, "Missing result type for }] # attrName # [{ operator."); + llvm_unreachable(""); + } + + resType = operands[0].getType(); + } else { + resType = ::mlir::parseType(resTypeName, builder.getContext()); + + if(!resType) { + ::mlir::emitError(location, "Could not parse type '") << resTypeName << "'"; + llvm_unreachable(""); + } + } + + ::mlir::OperationState state(location, opName, operands, resType, {}); + ::mlir::Operation *oprtr = builder.createOperation(state); + return oprtr->getResult(0); + }] + > + ]; +} + +// Create one interface for each common arithmetic operator +def Linalg_InstantiateAddOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Add", "add", "arith.addf", "arith.addi">; +def Linalg_InstantiateSubOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Sub", "sub", "arith.subf", "arith.subi">; +def Linalg_InstantiateMulOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Mul", "mul", "arith.mulf", "arith.muli">; +def Linalg_InstantiateMaxOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Max", "max", "arith.maxf", "arith.maxsi">; +def Linalg_InstantiateMaxUnsignedOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"MaxUnsigned", 
"max_unsigned", "arith.maxf", "arith.maxui">; +def Linalg_InstantiateMinOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Min", "min", "arith.minf", "arith.minsi">; +def Linalg_InstantiateMinUnsignedOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"MinUnsigned", "min_unsigned", "arith.minf", "arith.minui">; +def Linalg_InstantiateExpOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Exp", "exp", "math.exp">; +def Linalg_InstantiateLogOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Log", "log", "math.log">; + +#endif // LINALG_FRONTEND_LINALGFRONTENDINTERFACES Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -50,6 +50,7 @@ } // namespace linalg } // namespace mlir +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h" #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" /// Include the generated interface declarations. 
Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -14,6 +14,7 @@ #ifndef LINALG_STRUCTURED_OPS #define LINALG_STRUCTURED_OPS +include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.td" include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" Index: mlir/lib/Dialect/Linalg/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Linalg/CMakeLists.txt +++ mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(Analysis) add_subdirectory(ComprehensiveBufferize) +add_subdirectory(Frontend) add_subdirectory(IR) add_subdirectory(Transforms) add_subdirectory(Utils) Index: mlir/lib/Dialect/Linalg/Frontend/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/Frontend/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_dialect_library(MLIRLinalgFrontend + LinalgFrontendInterfaces.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg + + LINK_LIBS PUBLIC + MLIRIR + ) Index: mlir/lib/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.cpp @@ -0,0 +1,11 @@ +//===- LinalgFrontendInterfaces.cpp - -------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h" + +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendOpInterfaces.cpp.inc" Index: mlir/lib/Dialect/Linalg/IR/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -26,4 +26,5 @@ MLIRTensor MLIRTilingInterface MLIRViewLikeInterface + MLIRLinalgFrontend ) Index: mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -42,16 +42,22 @@ }; } -/// Return the unique instance of OpType in `block` if it is indeed unique. -/// Return null if none or more than 1 instances exist. -template -static OpType getSingleOpOfType(Block &block) { - OpType res = nullptr; - block.walk([&](OpType op) { +/// Return the unique operation in `block` that satisfies `predicate` +/// if it is indeed unique. Return null if none or more than 1 +/// instances exist. +static Operation * +getSingleOpMatching(Block &block, std::function predicate) { + Operation *res = nullptr; + + block.walk([&](Operation *op) { + if (!predicate(op)) + return WalkResult::advance(); + if (res) { res = nullptr; return WalkResult::interrupt(); } + res = op; return WalkResult::advance(); }); @@ -59,18 +65,29 @@ } /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))` -/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent +/// on the field (+, *), where u1, u2, u3, u4 and u5 represent /// unary operations that may change the type. 
-template -static bool isAddMul(Block &block) { +static bool hasAddMulRegion(linalg::LinalgOp &linalgOp) { + Block &block = linalgOp->getRegion(0).front(); + if (block.getNumArguments() != 3) return false; Operation *yieldOp = block.getTerminator(); if (yieldOp->getNumOperands() != 1) return false; - AddOpType addOp = getSingleOpOfType(block); - MulOpType mulOp = getSingleOpOfType(block); + auto addOpIface = + dyn_cast(linalgOp.getOperation()); + auto mulOpIface = + dyn_cast(linalgOp.getOperation()); + + if (!addOpIface || !mulOpIface) + return false; + + Operation *addOp = getSingleOpMatching( + block, [&](Operation *op) { return addOpIface.isAddOperator(*op); }); + Operation *mulOp = getSingleOpMatching( + block, [&](Operation *op) { return mulOpIface.isMulOperator(*op); }); if (!addOp || !mulOp) return false; @@ -112,8 +129,7 @@ [](AffineMap m) { return !m.isProjectedPermutation(); })) return MatchContractionResult::NotProjectedPermutations; // TODO: more fields than add/mul. - if (!isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front())) + if (!hasAddMulRegion(linalgOp)) return MatchContractionResult::NotAddMul; return MatchContractionResult::Success; } Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -241,92 +241,6 @@ return cast(toType, operand, true); } - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__add(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. 
- Value arithfn__exp(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__log(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__sub(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__mul(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__max(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__max_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. 
- Value arithfn__min(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__min_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - void yieldOutputs(ValueRange values) { assert(!values.empty() && "linalg ops must yield outputs"); if (values.empty()) Index: mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml =================================================================== --- mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -89,7 +89,8 @@ # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); # IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]); -# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]); +# IMPL-DAG: SmallVector [[VAL4:[a-z0-9]+]]_operands{[[VAL1]], [[VAL3]]}; +# IMPL-DAG: Value [[VAL4]] = Test1Op::instantiateAddOperator(b, b.getLoc(), attrs, [[VAL4]]_operands); # @linalg_structured_op Index: mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp =================================================================== --- mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -21,11 +21,13 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/StringRef.h" +#include 
"llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/YAMLTraits.h" +#include using namespace mlir; @@ -641,6 +643,54 @@ } )FMT"; +/// Returns the operator instantiation interface name for an +/// arithmetic function. +static std::string +arithFnToOperatorInstantiationInterface(const std::string &fnName) { + return "Linalg_Instantiate" + + llvm::convertToCamelFromSnakeCase(fnName, true) + + "OperatorOpInterface"; +} + +/// Returns the operator instantiation interface function name for an +/// arithmetic function. +static std::string +arithFnToOperatorInstantiationFunction(const std::string &fnName) { + return "instantiate" + llvm::convertToCamelFromSnakeCase(fnName, true) + + "Operator"; +} + +/// Returns a set with all operator instantiation interfaces required +/// for an operation. +static std::set +requiredOperatorInstantiationInterfaces(LinalgOpConfig &opConfig) { + std::set interfaces; + + if (opConfig.structuredOp.hasValue()) { + std::function collectInterfaces = + [&](const ScalarExpression &expr) { + if (expr.arithFn.hasValue()) { + interfaces.emplace(arithFnToOperatorInstantiationInterface( + expr.arithFn.getValue().fnName)); + for (const ScalarExpression &subexpr : + expr.arithFn.getValue().operands) + collectInterfaces(subexpr); + } + if (expr.typeFn.hasValue()) { + for (const ScalarExpression &subexpr : + expr.typeFn.getValue().operands) + collectInterfaces(subexpr); + } + }; + + for (const ScalarAssign &assign : + opConfig.structuredOp.getValue().assignments) + collectInterfaces(assign.value); + } + + return interfaces; +} + // Implementation of parse/print. 
// Parameters: // {0}: Class name @@ -681,6 +731,16 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); + std::set operatorIfaces = + requiredOperatorInstantiationInterfaces(opConfig); + + if (!operatorIfaces.empty()) { + if (!interfaceNameList.empty()) + interfaceNameList += ", "; + + interfaceNameList += interleaveToString(operatorIfaces, ", "); + } + // Assemble the attribute specific logic required for the op definition. if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return arg.usage == LinalgOperandDefUsage::IndexAttr; @@ -1020,10 +1080,16 @@ operandCppValues.push_back(*operandCppValue); } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); - stmts.push_back( - llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent, - expression.arithFn->fnName, - interleaveToString(operandCppValues, ", "))); + + stmts.push_back(llvm::formatv( + "SmallVector {0}_operands{{{1}};", cppIdent, + interleaveToString(operandCppValues, ", "))); + stmts.push_back(llvm::formatv( + "Value {0} = {1}::{2}(b, b.getLoc(), attrs, {0}_operands);", + cppIdent, className, + arithFnToOperatorInstantiationFunction( + expression.arithFn->fnName))); + return cppIdent; } if (expression.typeFn) {