Index: mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -55,7 +55,9 @@ add_dependencies(mlir-headers MLIRLinalgStructuredOpsIncGen) set(LLVM_TARGET_DEFINITIONS LinalgInterfaces.td) -mlir_tablegen(LinalgInterfaces.h.inc -gen-op-interface-decls) -mlir_tablegen(LinalgInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(LinalgOpInterfaces.h.inc -gen-op-interface-decls) +mlir_tablegen(LinalgOpInterfaces.cpp.inc -gen-op-interface-defs) +mlir_tablegen(LinalgTypeInterfaces.h.inc -gen-type-interface-decls) +mlir_tablegen(LinalgTypeInterfaces.cpp.inc -gen-type-interface-defs) add_public_tablegen_target(MLIRLinalgInterfacesIncGen) add_dependencies(mlir-headers MLIRLinalgInterfacesIncGen) Index: mlir/include/mlir/Dialect/Linalg/IR/Linalg.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/Linalg.h +++ mlir/include/mlir/Dialect/Linalg/IR/Linalg.h @@ -105,10 +105,16 @@ #include "mlir/Dialect/Linalg/IR/LinalgOpsDialect.h.inc" //===----------------------------------------------------------------------===// -// Linalg Interfaces +// Linalg Op Interfaces //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h" + +//===----------------------------------------------------------------------===// +// Linalg Type Interfaces +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.h" //===----------------------------------------------------------------------===// // Linalg Dialect Operations Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td =================================================================== --- 
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -1047,4 +1047,164 @@ let verify = [{ return detail::verifyStructuredOpInterface($_op); }]; } +def LinalgAddOpInterface : OpInterface<"LinalgAddOpInterface"> { + let description = [{ + Operation interface that allows for the identification of + operations that qualify as an addition on two tensor elements. + }]; + + let cppNamespace = "::mlir::linalg"; +} + +def LinalgMulOpInterface : OpInterface<"LinalgMulOpInterface"> { + let description = [{ + Operation interface that allows for the identification of + operations that qualify as a multiplication of two tensor + elements. + }]; + let cppNamespace = "::mlir::linalg"; +} + +def LinalgArithmeticOperatorTypeInterface : TypeInterface<"LinalgArithmeticOperatorTypeInterface"> { + let cppNamespace = "::mlir::linalg"; + + let description = [{ + Type interface that allows for the instantiation of operations + implementing arithmetic operators for scalars. This allows Linalg + named operations to instantiate the correct operation for scalar + operators, e.g., to instantiate an operation implementing an + addition for a custom type when building the body of a matrix + multiplication. + + The set of interface methods covers all scalar operations required + by named operations. A type may provide implementations of only a + subset of the interface methods, in which case only those Linalg + named operations become available for tensors of that type that + use the implemented subset. Instantiation of other named + operations that require at least one arithmetic operator tensor + elements, for which no interface method has been provided, results + in an error at execution time. + + The absence of an implementation of an arithmetic operator can be + indicated in two ways. 
If there is no sensible + implementation of an operator for the type, the respective + interface method should return + `LinalgArithmeticOperatorInstantiationResult::undefined()`. If a + sensible definition may exist, but has not been implemented, the + method should return + `LinalgArithmeticOperatorInstantiationResult::notImplemented()`. This + is also the default return value for any interface method not + overridden by a type. + }]; + + let methods = [ + InterfaceMethod< + /*description=*/"Returns an operation implementing an addition.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createAdd", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation implementing a subtraction.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createSub", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation implementing a multiplication.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createMul", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation returning the maximum of two values.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createMax", + 
/*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation returning the unsigned maximum of two values.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createUnsignedMax", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation returning the minimum of two values.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createMin", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation returning the unsigned minimum of two values.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createUnsignedMin", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$lhs, + "::mlir::Value":$rhs), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation returning e raised to the power of x.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createExp", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$x), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return 
::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + >, + InterfaceMethod< + /*description=*/"Returns an operation calculating the e-based logarithm of x.", + /*retTy=*/"::mlir::linalg::LinalgArithmeticOperatorInstantiationResult", + /*methodName=*/"createLog", + /*args=*/(ins "::mlir::OpBuilder":$builder, + "::mlir::Value":$x), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::linalg::LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }] + > + ]; +} + #endif // LINALG_IR_LINALGINTERFACES Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h @@ -10,8 +10,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ -#define MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ +#ifndef MLIR_DIALECT_LINALG_IR_LINALGOPINTERFACES_H_ +#define MLIR_DIALECT_LINALG_IR_LINALGOPINTERFACES_H_ #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineMap.h" @@ -47,12 +47,15 @@ LogicalResult verifyStructuredOpInterface(Operation *op); } // namespace detail + +/// Attaches the linalg op interfaces to builtin ops +void attachDefaultOpInterfaceImplementations(MLIRContext &ctx); } // namespace linalg } // namespace mlir #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.h.inc" /// Include the generated interface declarations. 
-#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h.inc" +#include "mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h.inc" -#endif // MLIR_DIALECT_LINALG_IR_LINALGINTERFACES_H_ +#endif // MLIR_DIALECT_LINALG_IR_LINALGOPINTERFACES_H_ Index: mlir/include/mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.h @@ -0,0 +1,107 @@ +//===- LinalgTypeInterfaces.h - Linalg type interfaces --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the type interfaces used by Linalg operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_IR_LINALGTYPEINTERFACES_H_ +#define MLIR_DIALECT_LINALG_IR_LINALGTYPEINTERFACES_H_ + +#include "mlir/IR/Builders.h" +#include "llvm/ADT/Optional.h" + +namespace mlir { +namespace linalg { + +/// Result for an attempt to instantiate an arithmetic operator via +/// `LinalgArithmeticOperatorTypeInterface`. +/// +/// On successful instantiation, the result holds a value +/// corresponding to the result of an operation implementing the +/// operator. +/// +/// If the instantiation fails, the result indicates whether this is +/// because the operator is not implemented for the type (e.g., +/// because of an incomplete implementation of +/// `LinalgArithmeticOperatorTypeInterface`) or because there is no +/// sensible implementation of the operator for the type. 
+class LinalgArithmeticOperatorInstantiationResult { +protected: + enum class Status { + /// Arithmetic operator is implemented for the type + IMPLEMENTED, + + /// Arithmetic operator has not been implemented for the type + NOT_IMPLEMENTED, + + /// No sensible implementation of the arithmetic operator for the + /// type + UNDEFINED + }; + + LinalgArithmeticOperatorInstantiationResult(Status s, llvm::Optional v) + : value(v), status(s) {} + + llvm::Optional value; + Status status; + +public: + /// Constructor generating a result indicating a successful + /// instantiation. The result of the operation implementing the + /// operator is specified by `v`. Intentionally implicit so that + /// interface methods can return the created `Value` directly. + LinalgArithmeticOperatorInstantiationResult(Value v) + : value(v), status(Status::IMPLEMENTED) {} + + /// Creates a result indicating that the arithmetic operator is not + /// implemented for the type (e.g., due to an incomplete + /// implementation of `LinalgArithmeticOperatorTypeInterface`). + static LinalgArithmeticOperatorInstantiationResult notImplemented() { + LinalgArithmeticOperatorInstantiationResult res(Status::NOT_IMPLEMENTED, + llvm::None); + + return res; + } + + /// Creates a result indicating that there is no sensible + /// implementation of the arithmetic operator for the type. + static LinalgArithmeticOperatorInstantiationResult undefined() { + LinalgArithmeticOperatorInstantiationResult res(Status::UNDEFINED, + llvm::None); + + return res; + } + + /// Returns the result of the operation implementing the arithmetic + /// operator. If the instantiation was a failure, the instance of + /// `Optional` is empty. + llvm::Optional getValue() const { return value; } + + /// Returns true if the arithmetic operator is implemented and + /// instantiation was a success. + bool isImplemented() const { return status == Status::IMPLEMENTED; } + + /// Returns true if the instantiation was a failure because the + /// arithmetic operator is not implemented for the type. 
+ bool isNotImplemented() const { return status == Status::NOT_IMPLEMENTED; } + + /// Returns true if the instantiation was a failure because no + /// sensible implementation of the arithmetic operator exists for the + /// type. + bool isUndefined() const { return status == Status::UNDEFINED; } +}; + +/// Attaches the linalg type interfaces to builtin types +void attachDefaultTypeInterfaceImplementations(MLIRContext &ctx); +} // namespace linalg +} // namespace mlir + +/// Include the generated type interface declarations. +#include "mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.h.inc" + +#endif // MLIR_DIALECT_LINALG_IR_LINALGTYPEINTERFACES_H_ Index: mlir/lib/Dialect/Linalg/IR/CMakeLists.txt =================================================================== --- mlir/lib/Dialect/Linalg/IR/CMakeLists.txt +++ mlir/lib/Dialect/Linalg/IR/CMakeLists.txt @@ -1,5 +1,8 @@ add_mlir_dialect_library(MLIRLinalg - LinalgInterfaces.cpp + LinalgOpInterfaces.cpp + LinalgTypeInterfaces.cpp + LinalgDefaultTypeInterfaceImplementations.cpp + LinalgDefaultOpInterfaceImplementations.cpp LinalgOps.cpp LinalgDialect.cpp Index: mlir/lib/Dialect/Linalg/IR/LinalgDefaultOpInterfaceImplementations.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/IR/LinalgDefaultOpInterfaceImplementations.cpp @@ -0,0 +1,46 @@ +//===- LinalgDefaultOpInterfaceImplementations.cpp ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides implementations of the linalg op interfaces for +// builtin ops. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h" + +namespace mlir { +namespace linalg { +/// Implementation of LinalgAddOpInterface for arith::AddFOp +struct AddFOpAddOpInterface + : public LinalgAddOpInterface::ExternalModel {}; + +/// Implementation of LinalgAddOpInterface for arith::AddIOp +struct AddIOpAddOpInterface + : public LinalgAddOpInterface::ExternalModel {}; + +/// Implementation of LinalgMulOpInterface for arith::MulFOp +struct MulFOpMulOpInterface + : public LinalgMulOpInterface::ExternalModel {}; + +/// Implementation of LinalgMulOpInterface for arith::MulIOp +struct MulIOpMulOpInterface + : public LinalgMulOpInterface::ExternalModel {}; + +void attachDefaultOpInterfaceImplementations(MLIRContext &ctx) { + arith::AddFOp::attachInterface(ctx); + arith::AddIOp::attachInterface(ctx); + arith::MulFOp::attachInterface(ctx); + arith::MulIOp::attachInterface(ctx); +} +} // namespace linalg +} // namespace mlir Index: mlir/lib/Dialect/Linalg/IR/LinalgDefaultTypeInterfaceImplementations.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/IR/LinalgDefaultTypeInterfaceImplementations.cpp @@ -0,0 +1,139 @@ +//===- LinalgDefaultTypeInterfaceImplementations.cpp ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file provides implementations of the linalg type interfaces +// for builtin types. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/BuiltinTypes.h" + +namespace mlir { +namespace linalg { +/// Implementation of ArithmeticOperatorTypeInterface for floats +template +struct FloatArithmeticOperatorTypeInterface + : public LinalgArithmeticOperatorTypeInterface::ExternalModel< + FloatArithmeticOperatorTypeInterface, FloatTy> { + LinalgArithmeticOperatorInstantiationResult + createAdd(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createSub(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createMul(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createMax(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createUnsignedMax(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createMin(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createUnsignedMin(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createExp(Type &t, OpBuilder &builder, Value &x) const { + return 
builder.create(x.getLoc(), x).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createLog(Type &t, OpBuilder &builder, Value &x) const { + return builder.create(x.getLoc(), x).getResult(); + } +}; + +/// Implementation of ArithmeticOperatorTypeInterface for integers +struct IntegerArithmeticOperatorTypeInterface + : public LinalgArithmeticOperatorTypeInterface::ExternalModel< + IntegerArithmeticOperatorTypeInterface, IntegerType> { + LinalgArithmeticOperatorInstantiationResult + createAdd(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createSub(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createMul(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createMax(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createUnsignedMax(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createMin(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createUnsignedMin(Type &t, OpBuilder &builder, Value &lhs, Value &rhs) const { + return builder.create(lhs.getLoc(), lhs, rhs).getResult(); + } + + LinalgArithmeticOperatorInstantiationResult + createExp(Type &t, OpBuilder &builder, Value &x) const { + return LinalgArithmeticOperatorInstantiationResult::notImplemented(); + } + + LinalgArithmeticOperatorInstantiationResult 
+ createLog(Type &t, OpBuilder &builder, Value &x) const { + return LinalgArithmeticOperatorInstantiationResult::notImplemented(); + } +}; + +void attachDefaultTypeInterfaceImplementations(MLIRContext &ctx) { + BFloat16Type::attachInterface< + FloatArithmeticOperatorTypeInterface>(ctx); + Float16Type::attachInterface< + FloatArithmeticOperatorTypeInterface>(ctx); + Float32Type::attachInterface< + FloatArithmeticOperatorTypeInterface>(ctx); + Float64Type::attachInterface< + FloatArithmeticOperatorTypeInterface>(ctx); + Float80Type::attachInterface< + FloatArithmeticOperatorTypeInterface>(ctx); + Float128Type::attachInterface< + FloatArithmeticOperatorTypeInterface>(ctx); + + IntegerType::attachInterface(ctx); +} +} // namespace linalg +} // namespace mlir Index: mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -121,6 +121,8 @@ >(namedStructuredOpRegionBuilders); addInterfaces(); + attachDefaultTypeInterfaceImplementations(*getContext()); + attachDefaultOpInterfaceImplementations(*getContext()); } LogicalResult LinalgDialect::verifyOperationAttribute(Operation *op, Index: mlir/lib/Dialect/Linalg/IR/LinalgOpInterfaces.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgOpInterfaces.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgOpInterfaces.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Linalg/IR/LinalgOpInterfaces.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" @@ -21,7 +21,7 @@ using namespace mlir::linalg; /// Include the definitions of the copy operation interface. 
-#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.cpp.inc" +#include "mlir/Dialect/Linalg/IR/LinalgOpInterfaces.cpp.inc" //===----------------------------------------------------------------------===// // ContractionOpInterface implementation @@ -42,16 +42,20 @@ }; } -/// Return the unique instance of OpType in `block` if it is indeed unique. -/// Return null if none or more than 1 instances exist. -template -static OpType getSingleOpOfType(Block &block) { - OpType res = nullptr; - block.walk([&](OpType op) { - if (res) { - res = nullptr; +/// Return the unique operation in `block` implementing the interface +/// `OpIface` if it is indeed unique. Return null if none or more +/// than 1 instances exist. +template +static Operation *getSingleOpImplementingInterface(Block &block) { + Operation *res = nullptr; + + block.walk([&](Operation *op) { + if (!dyn_cast(op)) + return WalkResult::advance(); + + if (res) return WalkResult::interrupt(); - } + res = op; return WalkResult::advance(); }); @@ -61,7 +65,6 @@ /// Detect whether res is any permutation of `u5(u1(c) + u2(u3(a) * u4(b)))` /// on the field (AddOpType, MulOpType), where u1, u2, u3, u4 and u5 represent /// unary operations that may change the type. -template static bool isAddMul(Block &block) { if (block.getNumArguments() != 3) return false; @@ -69,8 +72,11 @@ if (yieldOp->getNumOperands() != 1) return false; - AddOpType addOp = getSingleOpOfType(block); - MulOpType mulOp = getSingleOpOfType(block); + Operation *addOp = + getSingleOpImplementingInterface(block); + Operation *mulOp = + getSingleOpImplementingInterface(block); + if (!addOp || !mulOp) return false; @@ -111,10 +117,11 @@ if (llvm::any_of(mapRange, [](AffineMap m) { return !m.isProjectedPermutation(); })) return MatchContractionResult::NotProjectedPermutations; + // TODO: more fields than add/mul. 
- if (!isAddMul(linalgOp->getRegion(0).front()) && - !isAddMul(linalgOp->getRegion(0).front())) + if (!isAddMul(linalgOp->getRegion(0).front())) return MatchContractionResult::NotAddMul; + return MatchContractionResult::Success; } Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp =================================================================== --- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -242,88 +242,54 @@ // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__add(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator(BinaryArithmeticOperator::ADD, lhs, + rhs); } // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__exp(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); + return createUnaryArithmeticOperator(UnaryArithmeticOperator::EXP, x); } // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__log(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); + return createUnaryArithmeticOperator(UnaryArithmeticOperator::LOG, x); } // NOLINTNEXTLINE(*-identifier-naming): externally called. 
Value arithfn__sub(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator(BinaryArithmeticOperator::SUB, lhs, + rhs); } // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__mul(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator(BinaryArithmeticOperator::MUL, lhs, + rhs); } // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__max(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator(BinaryArithmeticOperator::MAX, lhs, + rhs); } // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__max_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator( + BinaryArithmeticOperator::UNSIGNED_MAX, lhs, rhs); } // NOLINTNEXTLINE(*-identifier-naming): externally called. 
Value arithfn__min(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator(BinaryArithmeticOperator::MIN, lhs, + rhs); } // NOLINTNEXTLINE(*-identifier-naming): externally called. Value arithfn__min_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); + return createBinaryArithmeticOperator( + BinaryArithmeticOperator::UNSIGNED_MIN, lhs, rhs); } void yieldOutputs(ValueRange values) { @@ -360,16 +326,160 @@ MLIRContext *context; Block █ - bool isFloatingPoint(Value value) { return value.getType().isa(); } - bool isInteger(Value value) { return value.getType().isa(); } - OpBuilder getBuilder() { OpBuilder builder(context); builder.setInsertionPointToEnd(&block); return builder; } -}; + /// All binary arithmetic operators that can be instantiated through + /// `LinalgArithmeticOperatorTypeInterface` + enum class BinaryArithmeticOperator { + ADD, + SUB, + MUL, + MAX, + UNSIGNED_MAX, + MIN, + UNSIGNED_MIN + }; + + /// Mapping from binary arithmetic operators to human-readable names, + /// e.g., for error messages + static const char * + binaryArithmeticOperatorName(BinaryArithmeticOperator oper) { + switch (oper) { + case BinaryArithmeticOperator::ADD: + return "Addition"; + case BinaryArithmeticOperator::SUB: + return "Subtraction"; + case BinaryArithmeticOperator::MUL: + return "Multiplication"; + case BinaryArithmeticOperator::MAX: + return "Maximum"; + case BinaryArithmeticOperator::UNSIGNED_MAX: + return "Unsigned maximum"; + case BinaryArithmeticOperator::MIN: + return "Minimum"; + case 
BinaryArithmeticOperator::UNSIGNED_MIN: + return "Unsigned minimum"; + } + + llvm_unreachable("Unknown binary arithmetic operator"); + } + + /// All unary arithmetic operators that can be instantiated through + /// `LinalgArithmeticOperatorTypeInterface` + enum class UnaryArithmeticOperator { EXP, LOG }; + + /// Mapping from unary arithmetic operators to human-readable names, + /// e.g., for error messages. Every enumerator is handled in the switch. + static const char *unaryArithmeticOperatorName(UnaryArithmeticOperator oper) { + switch (oper) { + case UnaryArithmeticOperator::EXP: + return "Exponentiation of e"; + case UnaryArithmeticOperator::LOG: + return "Natural logarithm"; + } + + llvm_unreachable("Unknown unary arithmetic operator"); + } + + /// Helper method that attempts to create an operation with the + /// human-readable name `operatorName` for the type `operandType` + /// through invocation of `instFun` with a suitable builder and an + /// instance of `LinalgArithmeticOperatorTypeInterface` to be used + /// for instantiation. 
+ Value createArithmeticOperator( + StringRef operatorName, Type operandType, + function_ref + instFun) { + std::string errBuf; + llvm::raw_string_ostream errSos(errBuf); + + LinalgArithmeticOperatorTypeInterface aoty = + operandType.dyn_cast(); + + if (!aoty) { + errSos << "Unsupported numeric type: "; + operandType.print(errSos); + errSos << " does not implement LinalgArithmeticOperatorTypeInterface"; + llvm::report_fatal_error(StringRef(errSos.str())); + } + + OpBuilder builder = getBuilder(); + LinalgArithmeticOperatorInstantiationResult res = instFun(builder, aoty); + + if (!res.isImplemented()) { + errSos << operatorName << " is "; + + if (res.isNotImplemented()) + errSos << "not implemented"; + else if (res.isUndefined()) + errSos << "undefined"; + + errSos << " for type "; + operandType.print(errSos); + // report_fatal_error (not llvm_unreachable) also aborts in release builds. + llvm::report_fatal_error(StringRef(errSos.str())); + } + + return res.getValue().getValue(); + } + + /// Creates a single instance of the binary arithmetic operator + /// specified by `oper` through invocation of the respective method + /// of `LinalgArithmeticOperatorTypeInterface` on the type of `lhs`. 
+ Value createBinaryArithmeticOperator(BinaryArithmeticOperator oper, Value lhs, + Value rhs) { + auto instFun = [&](OpBuilder &builder, + LinalgArithmeticOperatorTypeInterface &aoty) { + switch (oper) { + case BinaryArithmeticOperator::ADD: + return aoty.createAdd(builder, lhs, rhs); + case BinaryArithmeticOperator::SUB: + return aoty.createSub(builder, lhs, rhs); + case BinaryArithmeticOperator::MUL: + return aoty.createMul(builder, lhs, rhs); + case BinaryArithmeticOperator::MAX: + return aoty.createMax(builder, lhs, rhs); + case BinaryArithmeticOperator::UNSIGNED_MAX: + return aoty.createUnsignedMax(builder, lhs, rhs); + case BinaryArithmeticOperator::MIN: + return aoty.createMin(builder, lhs, rhs); + case BinaryArithmeticOperator::UNSIGNED_MIN: + return aoty.createUnsignedMin(builder, lhs, rhs); + } + + return LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }; + + const char *operName = binaryArithmeticOperatorName(oper); + return createArithmeticOperator(operName, lhs.getType(), instFun); + } + + /// Creates a single instance of the unary arithmetic operator + /// specified by `oper` by invoking + /// `LinalgArithmeticOperatorTypeInterface` on the type of `x`. 
+ Value createUnaryArithmeticOperator(UnaryArithmeticOperator oper, Value x) { + auto instFun = [&](OpBuilder &builder, + LinalgArithmeticOperatorTypeInterface &aoty) { + switch (oper) { + case UnaryArithmeticOperator::EXP: + return aoty.createExp(builder, x); + case UnaryArithmeticOperator::LOG: + return aoty.createLog(builder, x); + } + + return LinalgArithmeticOperatorInstantiationResult::notImplemented(); + }; + + const char *operName = unaryArithmeticOperatorName(oper); + + return createArithmeticOperator(operName, x.getType(), instFun); + } +}; } // namespace //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Linalg/IR/LinalgTypeInterfaces.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/Linalg/IR/LinalgTypeInterfaces.cpp @@ -0,0 +1,11 @@ +//===- LinalgTypeInterfaces.cpp - Linalg type interfaces implementation ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.h" + +#include "mlir/Dialect/Linalg/IR/LinalgTypeInterfaces.cpp.inc"