Index: mlir/include/mlir/Dialect/Linalg/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/Linalg/CMakeLists.txt +++ mlir/include/mlir/Dialect/Linalg/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Frontend) add_subdirectory(IR) add_subdirectory(TransformOps) Index: mlir/include/mlir/Dialect/Linalg/Frontend/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/Frontend/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS LinalgFrontendInterfaces.td) +mlir_tablegen(LinalgFrontendOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgFrontendOpInterfaces.cpp.inc -gen-op-interface-defs) +add_public_tablegen_target(MLIRLinalgFrontendInterfacesIncGen) +add_dependencies(mlir-headers MLIRLinalgFrontendInterfacesIncGen) Index: mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h @@ -0,0 +1,24 @@ +//===- LinalgFrontendInterfaces.h - Linalg frontend op interfaces ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the frontend operation interfaces for Linalg +// operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_FRONTEND_LINALGFRONTENDINTERFACES_H_ +#define MLIR_DIALECT_LINALG_FRONTEND_LINALGFRONTENDINTERFACES_H_ + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" + +/// Include the generated interface declarations.
+#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendOpInterfaces.h.inc" + +#endif // MLIR_DIALECT_LINALG_FRONTEND_LINALGFRONTENDINTERFACES_H_ Index: mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.td @@ -0,0 +1,289 @@ +//===- LinalgFrontendInterfaces.td - Linalg Frontend Interfaces -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Linalg Frontend Interfaces +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_FRONTEND_LINALGFRONTENDINTERFACES +#define LINALG_FRONTEND_LINALGFRONTENDINTERFACES + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Linalg/IR/LinalgBase.td" + +class Linalg_FrontendInterfaceBase : OpInterface { + let cppNamespace = "::mlir::linalg"; +} + +class Linalg_InstantiateOperatorOpInterface< + string operatorCamelCase, + string attrName, + string defFloatOp = "", + string defIntOp = "", + string defBoolOp = "", + string defComplexOp = ""> : + Linalg_FrontendInterfaceBase< + "Instantiate" # operatorCamelCase # "OperatorOpInterface"> +{ + let description = [{ + Base class for all interfaces that allow for the instantiation of + an arithmetic operator via an attribute. + + `operatorCamelCase` is supposed to be the name of the arithmetic + operator in camel case, e.g., `Add` or `MaxUnsigned`. + + `attrName` specifies the name of the attribute that will be parsed + into the name of the operation implementing the operator and an + optional result type specified after a colon (e.g., `arith.addf:f32`).
+ + `defIntOp`, `defBoolOp`, `defFloatOp` and `defComplexOp` define + the default operations that get instantiated if no value is + provided in the attribute. All of `defIntOp`, `defBoolOp`, + `defFloatOp` and `defComplexOp` may be omitted if there are no + sensible default implementations. + }]; + + let extraClassDeclaration = [{ + static ::mlir::Value instantiateOperator( + ::mlir::OpBuilder & builder, + ::mlir::Location location, + ::mlir::NamedAttrList attrs, + ::mlir::ValueRange operands) + { + ::llvm::Optional<::mlir::NamedAttribute> attr = + attrs.getNamed("}] # attrName # [{"); + + // If attribute was not specified, fall back to default operation + if(!attr.has_value()) { + if(::mlir::Value opr = Instantiate}] # operatorCamelCase # [{OperatorOpInterface::instantiateDefaultOperator( + builder, location, operands)) { + return opr; + } else { + ::mlir::emitError( + location, + "Could not generate default operation implementing " + "operator }] # attrName # [{ for the given operands.
" + "Please specify an operation using the " + "attribute '}] # attrName # [{'."); + llvm_unreachable(""); + } + } + + ::mlir::StringAttr strAttr = + attr.value().getValue().dyn_cast<::mlir::StringAttr>(); + + if (!strAttr) { + ::mlir::emitError( + location, + "Attribute }] # attrName # [{ must be a string attribute."); + llvm_unreachable(""); + } + + // Extract operation name and result type (whitespace-trimmed) + std::pair<::llvm::StringRef, ::llvm::StringRef> spec = + strAttr.strref().split(":"); + ::llvm::StringRef opName = std::get<0>(spec).trim(); + ::llvm::StringRef resTypeName = std::get<1>(spec).trim(); + + ::mlir::Type resType; + + // Use type of LHS operand by default for the result type + if(resTypeName.empty()) { + if(operands.empty()) { + ::mlir::emitError(location, "Missing result type for }] # attrName # [{ operator."); + llvm_unreachable(""); + } + + resType = operands[0].getType(); + } else { + resType = ::mlir::parseType(resTypeName, builder.getContext()); + + if(!resType) { + ::mlir::emitError(location, "Could not parse type '") << resTypeName << "'"; + llvm_unreachable(""); + } + } + + ::mlir::OperationState state(location, opName, operands, resType, {}); + ::mlir::Operation *oprtr = builder.create(state); + return oprtr->getResult(0); + } + + static ::mlir::Value instantiateDefaultOperator( + ::mlir::OpBuilder & builder, + ::mlir::Location location, + ::mlir::ValueRange operands) + { + const char* opName = nullptr; + ::mlir::Type resType; + if (operands.empty()) return ::mlir::Value(); + }] # + !if(!not(!empty(defFloatOp)), [{ + if(::llvm::all_of(operands, [](::mlir::Value v) { + return v.getType().isa<::mlir::FloatType>(); + })) { + opName = "}] # defFloatOp # [{"; + resType = operands[0].getType(); + } + }], [{ }]) # + !if(!not(!empty(defBoolOp)), [{ + if(::llvm::all_of(operands, [](::mlir::Value v) { + return v.getType().isa<::mlir::IntegerType>() && + v.getType().getIntOrFloatBitWidth() == 1; + })) { + opName = "}] # defBoolOp # [{"; + resType = operands[0].getType(); + } + }], [{ }]) # +
!if(!not(!empty(defIntOp)), [{ + if(!opName && ::llvm::all_of(operands, [](::mlir::Value v) { + return v.getType().isa<::mlir::IntegerType>(); + })) { + opName = "}] # defIntOp # [{"; + resType = operands[0].getType(); + } + }], [{ }]) # + !if(!not(!empty(defComplexOp)), [{ + if(::llvm::all_of(operands, [](::mlir::Value v) { + return v.getType().isa<::mlir::ComplexType>(); + })) { + opName = "}] # defComplexOp # [{"; + resType = operands[0].getType(); + } + }], [{ }]) # + [{ + + if(opName != nullptr) { + ::mlir::OperationState state(location, opName, operands, resType, {}); + ::mlir::Operation *oprtr = builder.create(state); + return oprtr->getResult(0); + } + + return ::mlir::Value(); + } + + static bool isOperatorInstance(::mlir::Operation &owner, ::mlir::Operation &operation) { + ::mlir::StringAttr strAttr = + owner.getAttrOfType<::mlir::StringAttr>("}] # attrName # [{"); + + if(strAttr) { + // Attribute was specified on the owner: use operation from attribute + std::pair<::llvm::StringRef, ::llvm::StringRef> spec = + strAttr.strref().split(":"); + + ::llvm::StringRef opName = std::get<0>(spec).trim(); + + return (operation.getName().getStringRef() == opName); + } + + // Attribute was NOT specified: use default operations + }] # + !if(!not(!empty(defFloatOp)), [{ + if(operation.getName().getStringRef() == "}] # defFloatOp # [{") + return true; + }], [{ }]) # + !if(!not(!empty(defIntOp)), [{ + if(operation.getName().getStringRef() == "}] # defIntOp # [{") + return true; + }], [{ }]) # + !if(!not(!empty(defBoolOp)), [{ + if(operation.getName().getStringRef() == "}] # defBoolOp # [{") + return true; + }], [{ }]) # + !if(!not(!empty(defComplexOp)), [{ + if(operation.getName().getStringRef() == "}] # defComplexOp # [{") + return true; + }], [{ }]) # + [{ + + return false; + } + + /// Returns a string containing the name of the operator for this + /// instantiation interface + static ::mlir::StringRef getOperatorName() { + return "}] # operatorCamelCase # [{"; + } + }]; + + let methods = [ +
InterfaceMethod< + /*desc=*/"Checks if an operation is the " # attrName # + " operator for this operation", + /*retTy=*/"bool", + /*methodName=*/"is" # operatorCamelCase # "Operator", + /*args=*/(ins "::mlir::Operation &":$operation), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return Instantiate}] # operatorCamelCase # [{OperatorOpInterface::isOperatorInstance(*$_op.getOperation(), operation); + }] + >, + + StaticInterfaceMethod< + /*desc=*/"Instantiates the " # attrName # " operator in the absence of " # + "an attribute specifying it", + /*retTy=*/"::mlir::Value", + /*methodName=*/"instantiateDefault" # operatorCamelCase # "Operator", + /*args=*/(ins + "::mlir::OpBuilder &":$builder, + "::mlir::Location":$location, + "::mlir::ValueRange":$operands), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return Instantiate}] # operatorCamelCase # [{OperatorOpInterface::instantiateDefaultOperator(builder, location, operands); + }] + >, + + StaticInterfaceMethod< + /*desc=*/"Instantiates the " # attrName # " operator from the " # + "attribute \"" # attrName # "\"", + /*retTy=*/"::mlir::Value", + /*methodName=*/"instantiate" # operatorCamelCase # "Operator", + /*args=*/(ins + "::mlir::OpBuilder &":$builder, + "::mlir::Location":$location, + "::mlir::NamedAttrList":$attrs, + "::mlir::ValueRange":$operands), + /*methodBody=*/[{}], + /*defaultImplementation=*/[{ + return Instantiate}] # operatorCamelCase # [{OperatorOpInterface::instantiateOperator(builder, location, attrs, operands); + }] + > + ]; +} + +// Create one interface for each common arithmetic operator +def Linalg_InstantiateAddOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Add", "add", "arith.addf", "arith.addi", "arith.ori", "complex.add">; +def Linalg_InstantiateSubOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Sub", "sub", "arith.subf", "arith.subi", "arith.subi", "complex.sub">; +def Linalg_InstantiateMulOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Mul", "mul",
"arith.mulf", "arith.muli", "arith.andi", "complex.mul">; +def Linalg_InstantiateMaxSignedOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"MaxSigned", "max_signed", "arith.maxf", "arith.maxsi">; +def Linalg_InstantiateMaxUnsignedOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"MaxUnsigned", "max_unsigned", "arith.maxf", "arith.maxui", "arith.maxui">; +def Linalg_InstantiateMinSignedOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"MinSigned", "min_signed", "arith.minf", "arith.minsi">; +def Linalg_InstantiateMinUnsignedOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"MinUnsigned", "min_unsigned", "arith.minf", "arith.minui", "arith.minui">; + +def Linalg_InstantiateExpOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Exp", "exp", "math.exp">; +def Linalg_InstantiateLogOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Log", "log", "math.log">; +def Linalg_InstantiateAbsOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Abs", "abs", "math.absf", "math.absi">; +def Linalg_InstantiateCeilOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Ceil", "ceil", "math.ceil">; +def Linalg_InstantiateFloorOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Floor", "floor", "math.floor">; +def Linalg_InstantiateNegfOperatorOpInterface : + Linalg_InstantiateOperatorOpInterface<"Negf", "negf", "arith.negf">; + +#endif // LINALG_FRONTEND_LINALGFRONTENDINTERFACES Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h @@ -64,6 +64,7 @@ } // namespace linalg } // namespace mlir +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h" #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" /// Include the generated interface declarations.
Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -14,6 +14,7 @@ #ifndef LINALG_STRUCTURED_OPS #define LINALG_STRUCTURED_OPS +include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.td" include "mlir/Dialect/Linalg/IR/LinalgBase.td" include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" Index: mlir/lib/Dialect/Linalg/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Linalg/CMakeLists.txt +++ mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Analysis) +add_subdirectory(Frontend) add_subdirectory(IR) add_subdirectory(TransformOps) add_subdirectory(Transforms) Index: mlir/lib/Dialect/Linalg/Frontend/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/Frontend/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRLinalgFrontend + LinalgFrontendInterfaces.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg + + DEPENDS + MLIRLinalgFrontendInterfacesIncGen + + LINK_LIBS PUBLIC + MLIRIR +) Index: mlir/lib/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.cpp @@ -0,0 +1,11 @@ +//===- LinalgFrontendInterfaces.cpp - Linalg frontend interfaces ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h" + +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendOpInterfaces.cpp.inc" Index: mlir/lib/Dialect/Linalg/IR/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -21,6 +21,7 @@ MLIRDialectUtils MLIRInferTypeOpInterface MLIRIR + MLIRLinalgFrontend MLIRParser MLIRSideEffectInterfaces MLIRSparseTensorDialect Index: mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -58,16 +58,22 @@ }; } -/// Return the unique instance of OpType in `block` if it is indeed unique. -/// Return null if none or more than 1 instances exist. -template -static OpType getSingleOpOfType(Block &block) { - OpType res = nullptr; - block.walk([&](OpType op) { +/// Return the unique operation in `block` that satisfies `predicate` +/// if it is indeed unique. Return null if none or more than 1 +/// instances exist. +static Operation * +getSingleOpMatching(Block &block, function_ref<bool(Operation *)> predicate) { + Operation *res = nullptr; + + block.walk([&](Operation *op) { + if (!predicate(op)) + return WalkResult::advance(); + if (res) { res = nullptr; return WalkResult::interrupt(); } + res = op; return WalkResult::advance(); }); @@ -75,18 +81,29 @@ } /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))` -/// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent +/// on the field (+, *), where u1, u2, u3, u4 and u5 represent /// unary operations that may change the type.
-template -static bool isAddMul(Block &block) { +static bool hasAddMulRegion(linalg::LinalgOp linalgOp) { + Block &block = linalgOp->getRegion(0).front(); + if (block.getNumArguments() != 3) return false; Operation *yieldOp = block.getTerminator(); if (yieldOp->getNumOperands() != 1) return false; - AddOpType addOp = getSingleOpOfType(block); - MulOpType mulOp = getSingleOpOfType(block); + auto addOpIface = + dyn_cast(linalgOp.getOperation()); + auto mulOpIface = + dyn_cast(linalgOp.getOperation()); + + if (!addOpIface || !mulOpIface) + return false; + + Operation *addOp = getSingleOpMatching( + block, [&](Operation *op) { return addOpIface.isAddOperator(*op); }); + Operation *mulOp = getSingleOpMatching( + block, [&](Operation *op) { return mulOpIface.isMulOperator(*op); }); if (!addOp || !mulOp) return false; @@ -128,11 +145,7 @@ [](AffineMap m) { return !m.isProjectedPermutation(); })) return MatchContractionResult::NotProjectedPermutations; // TODO: more fields than add/mul. - if (!isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul( - linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front())) + if (!hasAddMulRegion(linalgOp)) return MatchContractionResult::NotAddMul; return MatchContractionResult::Success; } Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/Linalg/Frontend/LinalgFrontendInterfaces.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -278,20 +279,23 @@ // Helper build the unary, binary, and type conversion functions defined by the
DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that uses this class. // -// Implementations of the math functions must be polymorphic over numeric types, -// internally performing necessary casts. If the function application makes no -// sense, then the only recourse is to assert and return nullptr. This can be -// extended later if it becomes possible to fail construction of the region. The -// invariant should be enforced at a higher level. +// Implementations of the math functions are provided by operator +// instantiation interfaces, one per mathematical operator. These +// interfaces either instantiate an operation from the name of an +// operation provided as an attribute or a default operation for the +// given values. // -// TODO: These helpers are currently type polymorphic over the class of integer -// and floating point types, but they will not internally cast within bit -// widths of a class (mixed precision such as i8->i32) or across classes -// (i.e. mixed float and integer). Many such combinations are ambiguous or need -// to be handled with care and work is being considered to extend the op -// language to make such cases explicit. In the mean-time, violating this will -// fail verification, which is deemed acceptable. -//===----------------------------------------------------------------------===// +// TODO: The default instantiation scheme in the absence of an +// attribute specifying an operation implementing an operator is +// currently type polymorphic over the class of integer and floating +// point types, but does not internally cast within bit widths of a +// class (mixed precision such as i8->i32) or across classes +// (i.e. mixed float and integer). Many such combinations are +// ambiguous or need to be handled with care and work is being +// considered to extend the op language to make such cases +// explicit. In the mean-time, violating this will fail verification, +// which is deemed acceptable.
+//===----------------------------------------------------------------------===// namespace { @@ -301,88 +305,95 @@ : context(context), block(block) {} // Build the unary functions defined by OpDSL. - Value buildUnaryFn(UnaryFn unaryFn, Value arg) { - if (!isFloatingPoint(arg)) - llvm_unreachable("unsupported non numeric type"); + Value buildUnaryFn(ArrayRef attrs, UnaryFn unaryFn, + Value arg) { OpBuilder builder = getBuilder(); + Value ret; + switch (unaryFn) { case UnaryFn::exp: - return builder.create(arg.getLoc(), arg); + ret = InstantiateExpOperatorOpInterface::instantiateOperator( + builder, arg.getLoc(), attrs, arg); + break; case UnaryFn::log: - return builder.create(arg.getLoc(), arg); + ret = InstantiateLogOperatorOpInterface::instantiateOperator( + builder, arg.getLoc(), attrs, arg); + break; case UnaryFn::abs: - return builder.create(arg.getLoc(), arg); + ret = InstantiateAbsOperatorOpInterface::instantiateOperator( + builder, arg.getLoc(), attrs, arg); + break; case UnaryFn::ceil: - return builder.create(arg.getLoc(), arg); + ret = InstantiateCeilOperatorOpInterface::instantiateOperator( + builder, arg.getLoc(), attrs, arg); + break; case UnaryFn::floor: - return builder.create(arg.getLoc(), arg); + ret = InstantiateFloorOperatorOpInterface::instantiateOperator( + builder, arg.getLoc(), attrs, arg); + break; case UnaryFn::negf: - return builder.create(arg.getLoc(), arg); + ret = InstantiateNegfOperatorOpInterface::instantiateOperator( + builder, arg.getLoc(), attrs, arg); + break; } - llvm_unreachable("unsupported unary function"); + + if (!ret) { + emitError(arg.getLoc(), "Could not instantiate operator '") + << stringifyEnum(unaryFn) << "' for type '" << arg.getType() << "'"; + } + + return ret; } // Build the binary functions defined by OpDSL.
- Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) { - bool allComplex = isComplex(arg0) && isComplex(arg1); - bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); - bool allInteger = isInteger(arg0) && isInteger(arg1); - bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 && - arg1.getType().getIntOrFloatBitWidth() == 1; - if (!allComplex && !allFloatingPoint && !allInteger) - llvm_unreachable("unsupported non numeric type"); + Value buildBinaryFn(ArrayRef attrs, BinaryFn binaryFn, + Value arg0, Value arg1) { OpBuilder builder = getBuilder(); + Value ret; + switch (binaryFn) { case BinaryFn::add: - if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateAddOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; case BinaryFn::sub: - if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) - llvm_unreachable("unsupported operation: sub with bools"); - return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateSubOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; case BinaryFn::mul: - if (allComplex) - return builder.create(arg0.getLoc(), arg0, arg1); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - if (allBool) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateMulOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; case BinaryFn::max_signed: - assert(!allComplex); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); -
return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateMaxSignedOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; case BinaryFn::min_signed: - assert(!allComplex); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateMinSignedOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; case BinaryFn::max_unsigned: - assert(!allComplex); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateMaxUnsignedOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; case BinaryFn::min_unsigned: - assert(!allComplex); - if (allFloatingPoint) - return builder.create(arg0.getLoc(), arg0, arg1); - return builder.create(arg0.getLoc(), arg0, arg1); + ret = InstantiateMinUnsignedOperatorOpInterface::instantiateOperator( + builder, arg0.getLoc(), attrs, {arg0, arg1}); + break; } - llvm_unreachable("unsupported binary function"); + + if (!ret) { + emitError(arg0.getLoc(), "Could not instantiate operator '") + << stringifyEnum(binaryFn) << "' for types '" << arg0.getType() + << "' and '" << arg1.getType() << "'"; + } + + return ret; } // Build the type functions defined by OpDSL.
- Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { + Value buildTypeFn(ArrayRef attrs, TypeFn typeFn, Type toType, + Value operand) { switch (typeFn) { case TypeFn::cast_signed: return cast(toType, operand, false); Index: mlir/test/Dialect/Linalg/named-ops-custom-operators.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Linalg/named-ops-custom-operators.mlir @@ -0,0 +1,46 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s + +// Specify matrix multiplication on float, integer, index and complex +// values using attributes for the add and mul operator. + +// CHECK-LABEL: func @matmul_float +func.func @matmul_float(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %output: memref<2x2xf32>) -> () { + linalg.matmul { add = "arith.addf", mul = "arith.mulf" } + ins(%arg0, %arg1: memref<2x2xf32>, + memref<2x2xf32>) + outs(%output: memref<2x2xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @matmul_int +func.func @matmul_int(%arg0: memref<2x2xi32>, %arg1: memref<2x2xi32>, %output: memref<2x2xi32>) -> () { + linalg.matmul { add = "arith.addi", mul = "arith.muli" } + ins(%arg0, %arg1: memref<2x2xi32>, + memref<2x2xi32>) + outs(%output: memref<2x2xi32>) + return +} + +// ----- + +// CHECK-LABEL: func @matmul_index +func.func @matmul_index(%arg0: memref<2x2xindex>, %arg1: memref<2x2xindex>, %output: memref<2x2xindex>) -> () { + linalg.matmul { add = "arith.addi", mul = "arith.muli" } + ins(%arg0, %arg1: memref<2x2xindex>, + memref<2x2xindex>) + outs(%output: memref<2x2xindex>) + return +} + +// ----- + +// CHECK-LABEL: func @matmul_complex +func.func @matmul_complex(%arg0: memref<2x2xcomplex>, %arg1: memref<2x2xcomplex>, %output: memref<2x2xcomplex>) -> () { + linalg.matmul { add = "complex.add", mul = "complex.mul" } + ins(%arg0, %arg1: memref<2x2xcomplex>, + memref<2x2xcomplex>) + outs(%output: memref<2x2xcomplex>) + return +} Index:
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml =================================================================== --- mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -96,10 +96,10 @@ # IMPL-NEXT: } # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); -# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(attrs, castVal, block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); -# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]); -# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(attrs, castVal, block.getArgument(0).getType(), [[VAL2]]); +# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(attrs, BinaryFn::add, [[VAL1]], [[VAL3]]); # @linalg_structured_op @@ -318,8 +318,8 @@ # IMPL: UnaryFn unary_funVal = UnaryFn::exp # IMPL: BinaryFn binary_funVal = BinaryFn::add -# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0)) -# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0)) +# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(attrs, unary_funVal, block.getArgument(0)) +# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(attrs, binary_funVal, [[VAL0]], block.getArgument(0)) # IMPL-NEXT: yields.push_back([[VAL1]]) # @linalg_structured_op Index: mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp =================================================================== --- mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -28,6 +28,8 @@ #include
"llvm/Support/ToolOutputFile.h" #include "llvm/Support/YAMLTraits.h" +#include + using namespace mlir; using llvm::yaml::Input; @@ -649,6 +651,45 @@ } )FMT"; +/// Returns the operator instantiation interface name for an +/// arithmetic function. +static std::string +scalarFnToOperatorInstantiationInterface(const std::string &fnName) { + return "Linalg_Instantiate" + + llvm::convertToCamelFromSnakeCase(fnName, true) + + "OperatorOpInterface"; +} + +/// Returns a set with all operator instantiation interfaces required +/// for an operation. +static std::set +requiredOperatorInstantiationInterfaces(const LinalgOpConfig &opConfig) { + std::set interfaces; + + if (opConfig.structuredOp.has_value()) { + std::function collectInterfaces = + [&](const ScalarExpression &expr) { + if (expr.scalarFn.has_value()) { + if (expr.scalarFn->kind != ScalarFnKind::Type && + expr.scalarFn->fnName.has_value()) { + interfaces.emplace(scalarFnToOperatorInstantiationInterface( + *expr.scalarFn->fnName)); + } + + for (const ScalarExpression &subexpr : + expr.scalarFn->operands) + collectInterfaces(subexpr); + } + }; + + for (const ScalarAssign &assign : + opConfig.structuredOp->assignments) + collectInterfaces(assign.value); + } + + return interfaces; +} + // Implementations of fold and getEffects. // Parameters: // {0}: Class name @@ -706,6 +747,16 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); + std::set operatorIfaces = + requiredOperatorInstantiationInterfaces(opConfig); + + if (!operatorIfaces.empty()) { + if (!interfaceNameList.empty()) + interfaceNameList += ", "; + + interfaceNameList += interleaveToString(operatorIfaces, ", "); + } + std::string definitionList; for (const std::string &definition : opConfig.metadata->defines) { static const char definitionFmt[] = "let {0} = 1;\n"; @@ -1134,8 +1185,8 @@ // Call the function builder.
std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv( - "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName, - funcType, interleaveToString(operandCppValues, ", "))); + "Value {0} = helper.build{1}(attrs, {2}, {3});", cppIdent, + enumName, funcType, interleaveToString(operandCppValues, ", "))); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type";