diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -21,10 +21,9 @@ // one result, all of which must be complex numbers of the same type. class ComplexArithmeticOp traits = []> : Complex_Op, - ElementwiseMappable])> { + traits # [NoSideEffect, SameOperandsAndResultType, + DeclareOpInterfaceMethods, Elementwise, + Scalarizable, Vectorizable, Tensorizable]> { let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs Complex:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -19,7 +19,7 @@ class FloatUnaryOp traits = []> : MathOp, - ElementwiseMappable, + Elementwise, Scalarizable, Vectorizable, Tensorizable, SameOperandsAndResultType]> { let arguments = (ins FloatLike:$operand); @@ -31,7 +31,7 @@ class FloatBinaryOp traits = []> : MathOp, - ElementwiseMappable, + Elementwise, Scalarizable, Vectorizable, Tensorizable, SameOperandsAndResultType]> { let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); let results = (outs FloatLike:$result); diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -72,8 +72,8 @@ // Base class for arithmetic cast operations. class ArithmeticCastOp traits = []> : CastOp])> { + traits # [Elementwise, Vectorizable, Tensorizable, Scalarizable, + DeclareOpInterfaceMethods]> { } // Base class for unary ops. Requires single operand and result. 
Individual @@ -96,9 +96,8 @@ class FloatUnaryOp traits = []> : UnaryOpSameOperandAndResultType, - ElementwiseMappable])>, + traits # [DeclareOpInterfaceMethods, Elementwise, + Vectorizable, Tensorizable, Scalarizable]>, Arguments<(ins FloatLike:$operand)>; // Base class for standard arithmetic operations. Requires operands and @@ -106,10 +105,9 @@ // types. class ArithmeticOp traits = []> : Op, - ElementwiseMappable])> { + traits # [NoSideEffect, SameOperandsAndResultType, + DeclareOpInterfaceMethods, + Elementwise, Vectorizable, Tensorizable, Scalarizable]> { let results = (outs AnyType:$result); @@ -931,8 +929,8 @@ } def CmpFOp : Std_Op<"cmpf", - [NoSideEffect, SameTypeOperands, ElementwiseMappable, - DeclareOpInterfaceMethods, + [NoSideEffect, SameTypeOperands, Elementwise, Vectorizable, Tensorizable, + Scalarizable, DeclareOpInterfaceMethods, TypesMatchWith< "result type has i1 element type and same shape as operands", "lhs", "result", "getI1SameShape($_self)">]> { @@ -1016,8 +1014,8 @@ } def CmpIOp : Std_Op<"cmpi", - [NoSideEffect, SameTypeOperands, ElementwiseMappable, - DeclareOpInterfaceMethods, + [NoSideEffect, SameTypeOperands, Elementwise, Vectorizable, Tensorizable, + Scalarizable, DeclareOpInterfaceMethods, TypesMatchWith< "result type has i1 element type and same shape as operands", "lhs", "result", "getI1SameShape($_self)">]> { @@ -2160,8 +2158,9 @@ //===----------------------------------------------------------------------===// def SelectOp : Std_Op<"select", [NoSideEffect, - AllTypesMatch<["true_value", "false_value", "result"]>, - ElementwiseMappable, DeclareOpInterfaceMethods]> { + AllTypesMatch<["true_value", "false_value", "result"]>, Elementwise, + Vectorizable, Tensorizable, Scalarizable, + DeclareOpInterfaceMethods]> { let summary = "select operation"; let description = [{ The `select` operation chooses one value based on a binary condition @@ -2392,8 +2391,8 @@ //===----------------------------------------------------------------------===// 
def SignExtendIOp : Std_Op<"sexti", - [NoSideEffect, ElementwiseMappable, - DeclareOpInterfaceMethods]> { + [NoSideEffect, Elementwise, Vectorizable, Tensorizable, Scalarizable, + DeclareOpInterfaceMethods]> { let summary = "integer sign extension operation"; let description = [{ The integer sign extension operation takes an integer input of @@ -3220,8 +3219,8 @@ //===----------------------------------------------------------------------===// def TruncateIOp : Std_Op<"trunci", - [NoSideEffect, ElementwiseMappable, - DeclareOpInterfaceMethods,]> { + [NoSideEffect, Elementwise, Vectorizable, Tensorizable, Scalarizable, + DeclareOpInterfaceMethods,]> { let summary = "integer truncation operation"; let description = [{ The integer truncation operation takes an integer input of @@ -3463,8 +3462,8 @@ //===----------------------------------------------------------------------===// def ZeroExtendIOp : Std_Op<"zexti", - [NoSideEffect, ElementwiseMappable, - DeclareOpInterfaceMethods,]> { + [NoSideEffect, Elementwise, Vectorizable, Tensorizable, Scalarizable, + DeclareOpInterfaceMethods,]> { let summary = "integer zero extension operation"; let description = [{ The integer zero extension operation takes an integer input of diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1785,9 +1785,13 @@ // Op can be safely normalized in the presence of MemRefs with // non-identity maps. def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; -// Op can be systematically interconverted between scalar and vector/tensor -// form by mapping elementwise based on the type. -def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">; +// Op is elementwise on tensor/vector operands and results. +def Elementwise : NativeOpTrait<"Elementwise">; +// Elementwise op can be applied to scalars instead of tensor/vector operands.
+def Scalarizable : NativeOpTrait<"Scalarizable">; +// Elementwise op can be applied to all-tensor/vector operands. +def Vectorizable : NativeOpTrait<"Vectorizable">; +def Tensorizable : NativeOpTrait<"Tensorizable">; // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -282,7 +282,7 @@ LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyNoRegionArguments(Operation *op); -LogicalResult verifyElementwiseMappable(Operation *op); +LogicalResult verifyElementwise(Operation *op); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -1213,93 +1213,130 @@ struct MemRefsNormalizable : public TraitBase {}; -/// This trait tags scalar ops that also can be applied to vectors/tensors, with -/// their semantics on vectors/tensors being elementwise application. +/// This trait tags elementwise ops that operate on scalars, vectors, or +/// tensors. /// /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this -/// trait. In particular, broadcasting behavior is not allowed. This trait -/// describes a set of invariants that allow systematic -/// vectorization/tensorization, and the reverse, scalarization. The properties -/// needed for this also can be used to implement a number of -/// transformations/analyses/interfaces. +/// trait. In particular, broadcasting behavior is not allowed. /// -/// An `ElementwiseMappable` op must satisfy the following properties: +/// An `Elementwise` op must satisfy the following properties: /// -/// 1. If any result is a vector (resp. tensor), then at least one operand must -/// be a vector (resp. tensor). -/// 2.
If any operand is a vector (resp. tensor), then there must be at least -/// one result, and all results must be vectors (resp. tensors). -/// 3. The static types of all vector (resp. tensor) operands and results must -/// have the same shape. -/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands -/// must be the same, otherwise the op has undefined behavior. -/// 5. ("systematic scalarization" property) If an op has vector/tensor -/// operands/results, then the same op, with the operand/result types changed to -/// their corresponding element type, shall be a verifier-valid op. -/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as -/// applying the scalarized version of the op for each corresponding element of -/// the vector (resp. tensor) operands in parallel. -/// 7. ("systematic vectorization/tensorization" property) If an op has -/// scalar operands/results, the op shall remain verifier-valid if all scalar -/// operands are replaced with vectors/tensors of the same shape and -/// corresponding element types. +/// 1. If any result is a vector/tensor then at least one operand must also be a +/// vector/tensor. +/// 2. If any operand is a vector/tensor then there must be at least one result +/// and all results must be vectors/tensors. +/// 3. All operand and result vector/tensor types must be of the same shape. The +/// shape may be dynamic in which case the op's behavior is undefined for +/// non-matching shapes. +/// 4. The operation must be elementwise on its vector/tensor operands and +/// results. When applied to single-element vectors/tensors, the result must +/// be the same per element. /// -/// Together, these properties provide an easy way for scalar operations to -/// conveniently generalize their behavior to vectors/tensors, and systematize -/// conversion between these forms.
+/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new +/// interface `ElementwiseTypeInterface` that describes the container types for +/// which the operation is elementwise. /// -/// Examples: -/// ``` -/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 -/// // Applying the systematic vectorization/tensorization property, this op -/// // must also be valid: -/// %tensor = "std.addf"(%a_tensor, %b_tensor) -/// : (tensor, tensor) -> tensor) +/// Rationale: +/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases +/// of 0 non-scalar operands or 0 non-scalar results, which complicate a +/// generic definition of the iteration space. +/// - 3. guarantees that folding can be done across scalars/vectors/tensors with +/// the same pattern, as otherwise lots of special handling for type +/// mismatches would be needed. +/// - 3. (undefined behavior on shape mismatch) also guarantees that no error +/// handling is needed. Higher-level dialects should reify any needed guards +/// or error handling code before lowering to an `Elementwise` op. +template +struct Elementwise : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyElementwise(op); + } +}; + +/// This trait tags `Elementwise` operations that can be systematically +/// scalarized. All vector/tensor operands and results are then replaced by +/// scalars of the respective element type. Semantically, this is the operation +/// on a single element per vector/tensor. /// -/// // These properties generalize well to the cases of non-scalar operands.
-/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val) -/// : (i1, tensor, tensor) -> tensor -/// // Applying the systematic vectorization / tensorization property, this -/// // op must also be valid: -/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val) -/// : (tensor, tensor, tensor) -/// -> tensor -/// // Applying the systematic scalarization property, this op must also -/// // be valid. -/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar) -/// : (i1, f32, f32) -> f32 +/// Rationale: +/// Allows defining the vector/tensor semantics of elementwise operations based +/// on scalars. This provides a constructive procedure for IR transformations +/// to, e.g., create scalar loop bodies from tensor ops. +/// +/// Example: +/// ``` +/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val) +/// : (tensor, tensor, tensor) +/// -> tensor /// ``` +/// can be scalarized to /// -/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type -/// implementing a new "ElementwiseMappableTypeInterface" that describes types -/// for which it makes sense to apply a scalar function to each element. +/// ``` +/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar) +/// : (i1, f32, f32) -> f32 +/// ``` +template +struct Scalarizable : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + assert(op->hasTrait() && + "`Scalarizable` trait is only applicable to `Elementwise` ops."); + return success(); + } +}; + +/// These traits tag `Elementwise` operations that can be systematically +/// vectorized/tensorized. All scalar operands and results are then replaced by +/// tensors/vectors with the respective element type. Semantically, this is the +/// operation on multiple arguments simultaneously. /// /// Rationale: -/// - 1. and 2. guarantee a well-defined iteration space for 6.
-/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar -/// results, which complicate a generic definition of the iteration space. -/// - 3. guarantees that folding can be done across scalars/vectors/tensors -/// with the same pattern, as otherwise lots of special handling of type -/// mismatches would be needed. -/// - 4. guarantees that no error handling cases need to be considered. -/// - Higher-level dialects should reify any needed guards / error handling -/// code before lowering to an ElementwiseMappable op. -/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar -/// semantics and provide a constructive procedure for IR transformations -/// to e.g. create scalar loop bodies from tensor ops. -/// - 7. provides the reverse of 5., which when chained together allows -/// reasoning about the relationship between the tensor and vector case. -/// Additionally, it permits reasoning about promoting scalars to -/// vectors/tensors via broadcasting in cases like `%select_scalar_pred` -/// above. +/// Provides the reverse of `Scalarizable` which, when chained together, allows +/// reasoning about the relationship between the tensor and vector case. +/// Additionally, it permits reasoning about promoting scalars to +/// vectors/tensors via broadcasting in cases like `%scalar_pred` below.
+/// +/// Examples: +/// ``` +/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 +/// ``` +/// can be tensorized to +/// ``` +/// %tensor = "std.addf"(%a, %b) : (tensor, tensor) +/// -> tensor +/// ``` +/// +/// ``` +/// %scalar_pred = "std.select"(%pred, %true_val, %false_val) +/// : (i1, tensor, tensor) -> tensor +/// ``` +/// can be tensorized to +/// ``` +/// %tensor_pred = "std.select"(%pred, %true_val, %false_val) +/// : (tensor, tensor, tensor) +/// -> tensor +/// ``` +template +struct Vectorizable : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + assert(op->hasTrait() && + "`Vectorizable` trait is only applicable to `Elementwise` ops."); + return success(); + } +}; template -struct ElementwiseMappable - : public TraitBase { +struct Tensorizable : public TraitBase { static LogicalResult verifyTrait(Operation *op) { - return ::mlir::OpTrait::impl::verifyElementwiseMappable(op); + assert(op->hasTrait() && + "`Tensorizable` trait is only applicable to `Elementwise` ops."); + return success(); } }; +/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable` +/// provide an easy way for scalar operations to conveniently generalize their +/// behavior to vectors/tensors, and systematize conversion between these forms.
+bool hasElementwiseMappableTraits(Operation *op); + } // end namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -18,7 +18,7 @@ using namespace mlir; static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { - if (!op->hasTrait()) + if (!OpTrait::hasElementwiseMappableTraits(op)) return false; // TODO: The conversion pattern can be made to work for `any_of` here, but diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -205,7 +205,7 @@ return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)}; // 3. Only ElementwiseMappable are allowed in the generic vectorization. - if (!op->hasTrait()) + if (!OpTrait::hasElementwiseMappableTraits(op)) return VectorizationResult{VectorizationStatus::Failure, nullptr}; // 4. Generic vectorization path for ElementwiseMappable ops. 
@@ -323,7 +323,7 @@ return false; for (Operation &op : r.front()) { if (!(isa(op) || - op.hasTrait()) || + OpTrait::hasElementwiseMappableTraits(&op)) || llvm::any_of(op.getResultTypes(), [](Type type) { return !type.isIntOrIndexOrFloat(); })) return false; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1090,7 +1090,7 @@ return a.getShape() == b.getShape(); } -LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) { +LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { auto isMappableType = [](Type type) { return type.isa(); }; @@ -1132,6 +1132,11 @@ return success(); } +bool OpTrait::hasElementwiseMappableTraits(Operation *op) { + return op->hasTrait() && op->hasTrait() && + op->hasTrait() && op->hasTrait(); +} + //===----------------------------------------------------------------------===// // BinaryOp implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -379,7 +379,7 @@ } def ElementwiseMappableOp : TEST_Op<"elementwise_mappable", - [ElementwiseMappable]> { + [Elementwise, Scalarizable, Vectorizable, Tensorizable]> { let arguments = (ins Variadic); let results = (outs Variadic); }