diff --git a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td --- a/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td +++ b/mlir/include/mlir/Dialect/Complex/IR/ComplexOps.td @@ -20,11 +20,9 @@ // floating-point element type. These operations take two operands and return // one result, all of which must be complex numbers of the same type. class ComplexArithmeticOp traits = []> : - Complex_Op, - ElementwiseMappable])> { + Complex_Op] # + ElementwiseMappable.traits> { let arguments = (ins Complex:$lhs, Complex:$rhs); let results = (outs Complex:$result); let assemblyFormat = "$lhs `,` $rhs attr-dict `:` type($result)"; diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -17,10 +17,9 @@ : Op; class FloatUnaryOp traits = []> : - MathOp, - ElementwiseMappable, - SameOperandsAndResultType]> { + MathOp, + SameOperandsAndResultType] # ElementwiseMappable.traits> { let arguments = (ins FloatLike:$operand); let results = (outs FloatLike:$result); @@ -29,10 +28,9 @@ } class FloatBinaryOp traits = []> : - MathOp, - ElementwiseMappable, - SameOperandsAndResultType]> { + MathOp, + SameOperandsAndResultType] # ElementwiseMappable.traits> { let arguments = (ins FloatLike:$lhs, FloatLike:$rhs); let results = (outs FloatLike:$result); diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -71,9 +71,9 @@ // Base class for arithmetic cast operations. class ArithmeticCastOp traits = []> : - CastOp])> { + CastOp] # + ElementwiseMappable.traits> { } // Base class for unary ops. Requires single operand and result. 
Individual @@ -95,21 +95,18 @@ } class FloatUnaryOp traits = []> : - UnaryOpSameOperandAndResultType, - ElementwiseMappable])>, - Arguments<(ins FloatLike:$operand)>; + UnaryOpSameOperandAndResultType] # + ElementwiseMappable.traits>, Arguments<(ins FloatLike:$operand)>; // Base class for standard arithmetic operations. Requires operands and // results to be of the same type, but does not constrain them to specific // types. class ArithmeticOp traits = []> : - Op, - ElementwiseMappable])> { + Op] # + ElementwiseMappable.traits> { let results = (outs AnyType:$result); @@ -930,12 +927,10 @@ let cppNamespace = "::mlir"; } -def CmpFOp : Std_Op<"cmpf", - [NoSideEffect, SameTypeOperands, ElementwiseMappable, - DeclareOpInterfaceMethods, - TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { +def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, + DeclareOpInterfaceMethods, TypesMatchWith< + "result type has i1 element type and same shape as operands", + "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { let summary = "floating-point comparison operation"; let description = [{ The `cmpf` operation compares its two operands according to the float @@ -1015,12 +1010,10 @@ let cppNamespace = "::mlir"; } -def CmpIOp : Std_Op<"cmpi", - [NoSideEffect, SameTypeOperands, ElementwiseMappable, - DeclareOpInterfaceMethods, - TypesMatchWith< - "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { +def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, + DeclareOpInterfaceMethods, TypesMatchWith< + "result type has i1 element type and same shape as operands", + "lhs", "result", "getI1SameShape($_self)">] # ElementwiseMappable.traits> { let summary = "integer comparison operation"; let description = [{ The `cmpi` operation is a generic comparison for integer-like types. 
Its two @@ -2160,8 +2153,9 @@ //===----------------------------------------------------------------------===// def SelectOp : Std_Op<"select", [NoSideEffect, - AllTypesMatch<["true_value", "false_value", "result"]>, - ElementwiseMappable, DeclareOpInterfaceMethods]> { + AllTypesMatch<["true_value", "false_value", "result"]>, + DeclareOpInterfaceMethods] # + ElementwiseMappable.traits> { let summary = "select operation"; let description = [{ The `select` operation chooses one value based on a binary condition @@ -2391,9 +2385,9 @@ // SignExtendIOp //===----------------------------------------------------------------------===// -def SignExtendIOp : Std_Op<"sexti", - [NoSideEffect, ElementwiseMappable, - DeclareOpInterfaceMethods]> { +def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, + DeclareOpInterfaceMethods] # + ElementwiseMappable.traits> { let summary = "integer sign extension operation"; let description = [{ The integer sign extension operation takes an integer input of @@ -3219,9 +3213,9 @@ // TruncateIOp //===----------------------------------------------------------------------===// -def TruncateIOp : Std_Op<"trunci", - [NoSideEffect, ElementwiseMappable, - DeclareOpInterfaceMethods,]> { +def TruncateIOp : Std_Op<"trunci", [NoSideEffect, + DeclareOpInterfaceMethods] # + ElementwiseMappable.traits> { let summary = "integer truncation operation"; let description = [{ The integer truncation operation takes an integer input of @@ -3462,9 +3456,9 @@ // ZeroExtendIOp //===----------------------------------------------------------------------===// -def ZeroExtendIOp : Std_Op<"zexti", - [NoSideEffect, ElementwiseMappable, - DeclareOpInterfaceMethods,]> { +def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, + DeclareOpInterfaceMethods] # + ElementwiseMappable.traits> { let summary = "integer zero extension operation"; let description = [{ The integer zero extension operation takes an integer input of diff --git a/mlir/include/mlir/IR/OpBase.td 
b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1785,9 +1785,24 @@ // Op can be safely normalized in the presence of MemRefs with // non-identity maps. def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; -// Op can be systematically interconverted between scalar and vector/tensor -// form by mapping elementwise based on the type. -def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">; +// Op is elementwise on tensor/vector operands and results. +def Elementwise : NativeOpTrait<"Elementwise">; +// Elementwise op can be applied to scalars instead of tensor/vector operands. +def Scalarizable : NativeOpTrait<"Scalarizable">; +// Elementwise op can be applied to all-vector or all-tensor operands. +def Vectorizable : NativeOpTrait<"Vectorizable">; +def Tensorizable : NativeOpTrait<"Tensorizable">; + +// Group together `Elementwise`, `Scalarizable`, `Vectorizable`, and +// `Tensorizable` for convenience. +def ElementwiseMappable { + list traits = [ + Elementwise, + Scalarizable, + Vectorizable, + Tensorizable, + ]; +} // Op's regions have a single block with the specified terminator. class SingleBlockImplicitTerminator diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -282,7 +282,7 @@ LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyNoRegionArguments(Operation *op); -LogicalResult verifyElementwiseMappable(Operation *op); +LogicalResult verifyElementwise(Operation *op); } // namespace impl /// Helper class for implementing traits.
Clients are not expected to interact @@ -1213,93 +1213,133 @@ struct MemRefsNormalizable : public TraitBase {}; -/// This trait tags scalar ops that also can be applied to vectors/tensors, with -/// their semantics on vectors/tensors being elementwise application. +/// This trait tags element-wise ops that operate on scalars, vectors, or +/// tensors. /// /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this -/// trait. In particular, broadcasting behavior is not allowed. This trait -/// describes a set of invariants that allow systematic -/// vectorization/tensorization, and the reverse, scalarization. The properties -/// needed for this also can be used to implement a number of -/// transformations/analyses/interfaces. +/// trait. In particular, broadcasting behavior is not allowed. /// -/// An `ElementwiseMappable` op must satisfy the following properties: +/// An `Elementwise` op must satisfy the following properties: /// -/// 1. If any result is a vector (resp. tensor), then at least one operand must -/// be a vector (resp. tensor). -/// 2. If any operand is a vector (resp. tensor), then there must be at least -/// one result, and all results must be vectors (resp. tensors). -/// 3. The static types of all vector (resp. tensor) operands and results must -/// have the same shape. -/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands -/// must be the same, otherwise the op has undefined behavior. -/// 5. ("systematic scalarization" property) If an op has vector/tensor -/// operands/results, then the same op, with the operand/result types changed to -/// their corresponding element type, shall be a verifier-valid op. -/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as -/// applying the scalarized version of the op for each corresponding element of -/// the vector (resp. tensor) operands in parallel. -/// 7. 
("systematic vectorization/tensorization" property) If an op has -/// scalar operands/results, the op shall remain verifier-valid if all scalar -/// operands are replaced with vectors/tensors of the same shape and -/// corresponding element types. +/// 1. If any result is a vector/tensor then at least one operand must also be a +/// vector/tensor. +/// 2. If any operand is a vector/tensor then there must be at least one result +/// and all results must be vectors/tensors. +/// 3. All operand and result vector/tensor types must be of the same shape. The +/// shape may be dynamic in which case the op's behaviour is undefined for +/// non-matching shapes. +/// 4. The operation must be elementwise on its vector/tensor operands and +/// results. When applied to single-element vectors/tensors, the result must +/// be the same per element. /// -/// Together, these properties provide an easy way for scalar operations to -/// conveniently generalize their behavior to vectors/tensors, and systematize -/// conversion between these forms. +/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new +/// interface `ElementwiseTypeInterface` that describes the container types for +/// which the operation is elementwise. /// -/// Examples: -/// ``` -/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 -/// // Applying the systematic vectorization/tensorization property, this op -/// // must also be valid: -/// %tensor = "std.addf"(%a_tensor, %b_tensor) -/// : (tensor, tensor) -> tensor) +/// Rationale: +/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases +/// of 0 non-scalar operands or 0 non-scalar results, which complicate a +/// generic definition of the iteration space. +/// - 3. guarantees that folding can be done across scalars/vectors/tensors with +/// the same pattern, as otherwise lots of special handling for type +/// mismatches would be needed. +/// - 3. also guarantees that no error handling is needed.
Higher-level dialects +/// should reify any needed guards or error handling code before lowering to +/// an `Elementwise` op. +template +struct Elementwise : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyElementwise(op); + } +}; + +/// This trait tags `Elementwise` operations that can be systematically +/// scalarized. All vector/tensor operands and results are then replaced by +/// scalars of the respective element type. Semantically, this is the operation +/// on a single element per vector/tensor. /// -/// // These properties generalize well to the cases of non-scalar operands. -/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val) -/// : (i1, tensor, tensor) -> tensor -/// // Applying the systematic vectorization / tensorization property, this -/// // op must also be valid: -/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val) -/// : (tensor, tensor, tensor) -/// -> tensor -/// // Applying the systematic scalarization property, this op must also -/// // be valid. -/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar) -/// : (i1, f32, f32) -> f32 +/// Rationale: +/// Allows defining the vector/tensor semantics of elementwise operations based +/// on scalars. This provides a constructive procedure for IR transformations +/// to, e.g., create scalar loop bodies from tensor ops. +/// +/// Example: +/// ``` +/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val) +/// : (tensor, tensor, tensor) +/// -> tensor /// ``` +/// can be scalarized to /// -/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type -/// implementing a new "ElementwiseMappableTypeInterface" that describes types -/// for which it makes sense to apply a scalar function to each element.
+/// ``` +/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar) +/// : (i1, f32, f32) -> f32 +/// ``` +template +struct Scalarizable : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + static_assert( + ConcreteType::template hasTrait(), + "`Scalarizable` trait is only applicable to `Elementwise` ops."); + return success(); + } +}; + +/// These traits tag `Elementwise` operations that can be systematically +/// vectorized/tensorized. All scalar operands and results are then replaced by +/// tensors/vectors with the respective element type. Semantically, this is the +/// operation on multiple arguments simultaneously. /// /// Rationale: -/// - 1. and 2. guarantee a well-defined iteration space for 6. -/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar -/// results, which complicate a generic definition of the iteration space. -/// - 3. guarantees that folding can be done across scalars/vectors/tensors -/// with the same pattern, as otherwise lots of special handling of type -/// mismatches would be needed. -/// - 4. guarantees that no error handling cases need to be considered. -/// - Higher-level dialects should reify any needed guards / error handling -/// code before lowering to an ElementwiseMappable op. -/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar -/// semantics and provide a constructive procedure for IR transformations -/// to e.g. create scalar loop bodies from tensor ops. -/// - 7. provides the reverse of 5., which when chained together allows -/// reasoning about the relationship between the tensor and vector case. -/// Additionally, it permits reasoning about promoting scalars to -/// vectors/tensors via broadcasting in cases like `%select_scalar_pred` -/// above. +/// Provides the reverse of `Scalarizable` which, when chained together, allows +/// reasoning about the relationship between the tensor and vector case.
+/// Additionally, it permits reasoning about promoting scalars to +/// vectors/tensors via broadcasting in cases like `%scalar_pred` below. +/// +/// Examples: +/// ``` +/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 +/// ``` +/// can be tensorized to +/// ``` +/// %tensor = "std.addf"(%a, %b) : (tensor, tensor) +/// -> tensor +/// ``` +/// +/// ``` +/// %scalar_pred = "std.select"(%pred, %true_val, %false_val) +/// : (i1, tensor, tensor) -> tensor +/// ``` +/// can be tensorized to +/// ``` +/// %tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val) +/// : (tensor, tensor, tensor) +/// -> tensor +/// ``` +template +struct Vectorizable : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + static_assert( + ConcreteType::template hasTrait(), + "`Vectorizable` trait is only applicable to `Elementwise` ops."); + return success(); + } +}; template -struct ElementwiseMappable - : public TraitBase { +struct Tensorizable : public TraitBase { static LogicalResult verifyTrait(Operation *op) { - return ::mlir::OpTrait::impl::verifyElementwiseMappable(op); + static_assert( + ConcreteType::template hasTrait(), + "`Tensorizable` trait is only applicable to `Elementwise` ops."); + return success(); } }; +/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable` +/// provide an easy way for scalar operations to conveniently generalize their +/// behavior to vectors/tensors, and systematize conversion between these forms.
+bool hasElementwiseMappableTraits(Operation *op); + } // end namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -18,7 +18,7 @@ using namespace mlir; static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { - if (!op->hasTrait()) + if (!OpTrait::hasElementwiseMappableTraits(op)) return false; // TODO: The conversion pattern can be made to work for `any_of` here, but diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -205,7 +205,7 @@ return VectorizationResult{VectorizationStatus::NewOp, builder.clone(*op)}; // 3. Only ElementwiseMappable are allowed in the generic vectorization. - if (!op->hasTrait()) + if (!OpTrait::hasElementwiseMappableTraits(op)) return VectorizationResult{VectorizationStatus::Failure, nullptr}; // 4. Generic vectorization path for ElementwiseMappable ops. 
@@ -323,7 +323,7 @@ return false; for (Operation &op : r.front()) { if (!(isa(op) || - op.hasTrait()) || + OpTrait::hasElementwiseMappableTraits(&op)) || llvm::any_of(op.getResultTypes(), [](Type type) { return !type.isIntOrIndexOrFloat(); })) return false; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1085,7 +1085,7 @@ return a.getShape() == b.getShape(); } -LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) { +LogicalResult OpTrait::impl::verifyElementwise(Operation *op) { auto isMappableType = [](Type type) { return type.isa(); }; @@ -1127,6 +1127,11 @@ return success(); } +bool OpTrait::hasElementwiseMappableTraits(Operation *op) { + return op->hasTrait() && op->hasTrait() && + op->hasTrait() && op->hasTrait(); +} + //===----------------------------------------------------------------------===// // BinaryOp implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -356,7 +356,8 @@ let results = (outs AnyType); } -def SameOperandAndResultElementTypeOp : TEST_Op<"same_operand_and_result_element_type", +def SameOperandAndResultElementTypeOp : + TEST_Op<"same_operand_and_result_element_type", [SameOperandsAndResultElementType]> { let arguments = (ins Variadic); let results = (outs Variadic); @@ -379,7 +380,7 @@ } def ElementwiseMappableOp : TEST_Op<"elementwise_mappable", - [ElementwiseMappable]> { + ElementwiseMappable.traits> { let arguments = (ins Variadic); let results = (outs Variadic); }