diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -67,6 +67,11 @@ let hasFolder = 1; } +// Base class for arithmetic cast operations. +class ArithmeticCastOp traits = []> : + CastOp { +} + // Base class for unary ops. Requires single operand and result. Individual // classes will have `operand` accessor. class UnaryOp traits = []> : @@ -88,7 +93,8 @@ class FloatUnaryOp traits = []> : UnaryOpSameOperandAndResultType])>, + [DeclareOpInterfaceMethods, + ElementwiseMappable])>, Arguments<(ins FloatLike:$operand)>; // Base class for standard arithmetic operations. Requires operands and @@ -96,7 +102,9 @@ // types. Individual classes will have `lhs` and `rhs` accessor to operands. class ArithmeticOp traits = []> : Op { + !listconcat(traits, [NoSideEffect, + SameOperandsAndResultType, + ElementwiseMappable])> { let results = (outs AnyType); @@ -1152,10 +1160,10 @@ } def CmpFOp : Std_Op<"cmpf", - [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, - TypesMatchWith< + [NoSideEffect, SameTypeOperands, + SameOperandsAndResultShape, TypesMatchWith< "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { + "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> { let summary = "floating-point comparison operation"; let description = [{ The `cmpf` operation compares its two operands according to the float @@ -1236,10 +1244,10 @@ } def CmpIOp : Std_Op<"cmpi", - [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, - TypesMatchWith< + [NoSideEffect, SameTypeOperands, + SameOperandsAndResultShape, TypesMatchWith< "result type has i1 element type and same shape as operands", - "lhs", "result", "getI1SameShape($_self)">]> { + "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> { let summary = "integer comparison 
operation"; let description = [{ The `cmpi` operation is a generic comparison for integer-like types. Its two @@ -1926,7 +1934,7 @@ // FPExtOp //===----------------------------------------------------------------------===// -def FPExtOp : CastOp<"fpext">, Arguments<(ins AnyType:$in)> { +def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point to wider floating-point"; let description = [{ Cast a floating-point value to a larger floating-point-typed value. @@ -1947,7 +1955,7 @@ // FPToSIOp //===----------------------------------------------------------------------===// -def FPToSIOp : CastOp<"fptosi">, Arguments<(ins AnyType:$in)> { +def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point type to integer type"; let description = [{ Cast from a value interpreted as floating-point to the nearest (rounding @@ -1967,7 +1975,7 @@ // FPToUIOp //===----------------------------------------------------------------------===// -def FPToUIOp : CastOp<"fptoui">, Arguments<(ins AnyType:$in)> { +def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point type to integer type"; let description = [{ Cast from a value interpreted as floating-point to the nearest (rounding @@ -1987,7 +1995,7 @@ // FPTruncOp //===----------------------------------------------------------------------===// -def FPTruncOp : CastOp<"fptrunc">, Arguments<(ins AnyType:$in)> { +def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ Truncate a floating-point value to a smaller floating-point-typed value. 
@@ -2039,7 +2047,7 @@ // IndexCastOp //===----------------------------------------------------------------------===// -def IndexCastOp : CastOp<"index_cast">, Arguments<(ins AnyType:$in)> { +def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> { let summary = "cast between index and integer types"; let description = [{ Casts between integer scalars and 'index' scalars. Index is an integer of @@ -2622,7 +2630,8 @@ //===----------------------------------------------------------------------===// def SelectOp : Std_Op<"select", [NoSideEffect, - AllTypesMatch<["true_value", "false_value", "result"]>]> { + AllTypesMatch<["true_value", "false_value", "result"]>, + ElementwiseMappable]> { let summary = "select operation"; let description = [{ The `select` operation chooses one value based on a binary condition @@ -2796,7 +2805,7 @@ //===----------------------------------------------------------------------===// def SignExtendIOp : Std_Op<"sexti", - [NoSideEffect, SameOperandsAndResultShape]> { + [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> { let summary = "integer sign extension operation"; let description = [{ The integer sign extension operation takes an integer input of @@ -2838,7 +2847,7 @@ // SIToFPOp //===----------------------------------------------------------------------===// -def SIToFPOp : CastOp<"sitofp">, Arguments<(ins AnyType:$in)> { +def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from integer type to floating-point"; let description = [{ Cast from a value interpreted as signed or vector of signed integers to the @@ -3637,7 +3646,9 @@ // TruncateIOp //===----------------------------------------------------------------------===// -def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape]> { +def TruncateIOp : Std_Op<"trunci", [NoSideEffect, + SameOperandsAndResultShape, + ElementwiseMappable]> { let summary = "integer truncation operation"; let 
description = [{ The integer truncation operation takes an integer input of @@ -3677,7 +3688,7 @@ // UIToFPOp //===----------------------------------------------------------------------===// -def UIToFPOp : CastOp<"uitofp">, Arguments<(ins AnyType:$in)> { +def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from unsigned integer type to floating-point"; let description = [{ Cast from a value interpreted as unsigned integer or vector of unsigned @@ -3904,7 +3915,9 @@ // ZeroExtendIOp //===----------------------------------------------------------------------===// -def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape]> { +def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, + SameOperandsAndResultShape, + ElementwiseMappable]> { let summary = "integer zero extension operation"; let description = [{ The integer zero extension operation takes an integer input of diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1750,6 +1750,9 @@ // Op can be safely normalized in the presence of MemRefs with // non-identity maps. def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">; +// Op can be systematically interconverted between scalar and vector/tensor +// form. +def ElementwiseMappable : NativeOpTrait<"ElementwiseMappable">; // Op's regions have a single block with the specified terminator. 
class SingleBlockImplicitTerminator diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -423,6 +423,7 @@ LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); LogicalResult verifyNoRegionArguments(Operation *op); +LogicalResult verifyElementwiseMappable(Operation *op); } // namespace impl /// Helper class for implementing traits. Clients are not expected to interact @@ -1304,6 +1305,93 @@ struct MemRefsNormalizable : public TraitBase {}; +/// This trait tags scalar ops that also can be applied to vectors/tensors, with +/// their semantics on vectors/tensors being elementwise application. +/// +/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this +/// trait. In particular, broadcasting behavior is not allowed. This trait +/// describes a set of invariants that allow systematic +/// vectorization/tensorization, and the reverse, scalarization. The properties +/// needed for this also can be used to implement a number of +/// transformations/analyses/interfaces. +/// +/// An `ElementwiseMappable` op must satisfy the following properties: +/// +/// 1. If any result is a vector (resp. tensor), then at least one operand must +/// be a vector (resp. tensor). +/// 2. If any operand is a vector (resp. tensor), then there must be at least +/// one result, and all results must be vectors (resp. tensors). +/// 3. The static types of all vector (resp. tensor) operands and results must +/// have the same shape. +/// 4. In the case of tensor operands, the dynamic shapes of all tensor operands +/// must be the same, otherwise the op has undefined behavior. +/// 5. 
("systematic scalarization" property) If an op has vector/tensor +/// operands/results, then the same op, with the operand/result types changed to +/// their corresponding element type, shall be a verifier-valid op. +/// 6. The semantics of the op on vectors (resp. tensors) shall be the same as +/// applying the scalarized version of the op for each corresponding element of +/// the vector (resp. tensor) operands in parallel. +/// 7. ("systematic vectorization/tensorization" property) If an op has +/// scalar operands/results, the op shall remain verifier-valid if all scalar +/// operands and results are replaced with vectors/tensors of the same shape and +/// corresponding element types. +/// +/// Together, these properties provide an easy way for scalar operations to +/// conveniently generalize their behavior to vectors/tensors, and systematize +/// conversion between these forms. +/// +/// Examples: +/// ``` +/// %scalar = "std.addf"(%a, %b) : (f32, f32) -> f32 +/// // Applying the systematic vectorization/tensorization property, this op +/// // must also be valid: +/// %tensor = "std.addf"(%a_tensor, %b_tensor) +/// : (tensor, tensor) -> tensor +/// +/// // These properties generalize well to the cases of non-scalar operands. +/// %select_scalar_pred = "std.select"(%pred, %true_val, %false_val) +/// : (i1, tensor, tensor) -> tensor +/// // Applying the systematic vectorization/tensorization property, this +/// // op must also be valid: +/// %select_tensor_pred = "std.select"(%pred_tensor, %true_val, %false_val) +/// : (tensor, tensor, tensor) +/// -> tensor +/// // Applying the systematic scalarization property, this op must also +/// // be valid.
+/// %select_scalar = "std.select"(%pred, %true_val_scalar, %false_val_scalar) +/// : (i1, f32, f32) -> f32 +/// ``` +/// +/// TODO: Avoid hardcoding vector/tensor, and generalize this to any type +/// implementing a new "ElementwiseMappableTypeInterface" that describes types +/// for which it makes sense to apply a scalar function to each element. +/// +/// Rationale: +/// - 1. and 2. guarantee a well-defined iteration space for 6. +/// - These also exclude the cases of 0 non-scalar operands or 0 non-scalar +/// results, which complicate a generic definition of the iteration space. +/// - 3. guarantees that folding can be done across scalars/vectors/tensors +/// with the same pattern, as otherwise lots of special handling of type +/// mismatches would be needed. +/// - 4. guarantees that no error handling cases need to be considered. +/// - Higher-level dialects should reify any needed guards / error handling +/// code before lowering to an ElementwiseMappable op. +/// - 5. and 6. allow defining the semantics on vectors/tensors via the scalar +/// semantics and provide a constructive procedure for IR transformations +/// to e.g. create scalar loop bodies from tensor ops. +/// - 7. provides the reverse of 5., which when chained together allows +/// reasoning about the relationship between the tensor and vector case. +/// Additionally, it permits reasoning about promoting scalars to +/// vectors/tensors via broadcasting in cases like `%select_scalar_pred` +/// above. 
+template +struct ElementwiseMappable + : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + return ::mlir::OpTrait::impl::verifyElementwiseMappable(op); + } +}; + } // end namespace OpTrait //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -1068,6 +1068,57 @@ return success(); } +/// Checks if two ShapedTypes are the same, ignoring the element type: their TypeIDs must match, and if `a` is unranked then `b` must be unranked too; otherwise their shapes must be equal. +static bool areSameShapedTypeIgnoringElementType(ShapedType a, ShapedType b) { + if (a.getTypeID() != b.getTypeID()) + return false; + if (!a.hasRank()) + return !b.hasRank(); + return a.getShape() == b.getShape(); +} + +LogicalResult OpTrait::impl::verifyElementwiseMappable(Operation *op) { // Checks structural properties 1-3 documented on the ElementwiseMappable trait. + auto isMappableType = [](Type type) { // "Mappable" means non-scalar; vector/tensor per the trait docs. + return type.isa() || type.isa(); + }; + auto resultMappableTypes = llvm::to_vector<1>( + llvm::make_filter_range(op->getResultTypes(), isMappableType)); + auto operandMappableTypes = llvm::to_vector<2>( + llvm::make_filter_range(op->getOperandTypes(), isMappableType)); + + // If the op only has scalar operand/result types, then we have nothing to + // check.
+ if (resultMappableTypes.empty() && operandMappableTypes.empty()) + return success(); + + if (!resultMappableTypes.empty() && operandMappableTypes.empty()) + return op->emitOpError("if a result is non-scalar, then at least one " + "operand must be non-scalar"); + + assert(!operandMappableTypes.empty()); // Implied by the two early returns above. + + if (resultMappableTypes.empty()) + return op->emitOpError("if an operand is non-scalar, then there must be at " + "least one non-scalar result"); + + if (resultMappableTypes.size() != op->getNumResults()) // At least one result is scalar. + return op->emitOpError( + "if an operand is non-scalar, then all results must be non-scalar"); + + auto mustMatchType = operandMappableTypes[0].cast(); // Reference type: the first non-scalar operand. + for (auto type : + llvm::concat(resultMappableTypes, operandMappableTypes)) { // Results are visited first, so a result mismatch is reported before an operand mismatch. + if (!areSameShapedTypeIgnoringElementType(type.cast(), + mustMatchType)) { + return op->emitOpError() << "all non-scalar operands/results must have " + "the same shape and base type: found " + << type << " and " << mustMatchType; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // BinaryOp implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: test_index_cast_shape_error func @test_index_cast_shape_error(%arg0 : tensor) -> tensor<2xi64> { - // expected-error @+1 {{requires the same shape for all operands and results}} + // expected-error @+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xi64>' and 'tensor'}} %0 = index_cast %arg0 : tensor to tensor<2xi64> return %0 : tensor<2xi64> } @@ -11,7 +11,7 @@ // CHECK-LABEL: test_index_cast_tensor_error func @test_index_cast_tensor_error(%arg0 : tensor) -> i64 { - // expected-error @+1 {{requires the same shape for all
operands and results}} + // expected-error @+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}} %0 = index_cast %arg0 : tensor to i64 return %0 : i64 } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -267,7 +267,7 @@ func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>): - // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'vector<42xi1>', but got 'vector<12xi1>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}} %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32> } @@ -275,7 +275,7 @@ func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) { ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>): - // expected-error@+1 {{expected condition type to have the same shape as the result type, expected 'tensor<42xi1>', but got 'tensor<12xi1>'}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}} %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> } @@ -685,7 +685,7 @@ // ----- func @fpext_vec(%arg0 : vector<2xf16>) { - // expected-error@+1 {{requires the same shape for all operands and results}} + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}} %0 = fpext %arg0 : vector<2xf16> to vector<3xf32> return } @@ -757,7 +757,7 @@ // ----- func @fptrunc_vec(%arg0 : vector<2xf16>) { - // expected-error@+1 {{requires the same shape for all operands and results}} + // expected-error@+1 {{all non-scalar operands/results must have 
the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}} %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32> return } diff --git a/mlir/test/IR/traits.mlir b/mlir/test/IR/traits.mlir --- a/mlir/test/IR/traits.mlir +++ b/mlir/test/IR/traits.mlir @@ -166,6 +166,84 @@ // ----- +func @failedElementwiseMappable_different_rankedness(%arg0: tensor, %arg1: tensor<*xf32>) { + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<*xf32>' and 'tensor'}} + %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor<*xf32>) -> tensor<*xf32> +} + +// ----- + +func @failedElementwiseMappable_different_rank(%arg0: tensor, %arg1: tensor) { + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor' and 'tensor'}} + %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor) -> tensor +} + +// ----- + +func @failedElementwiseMappable_different_shape(%arg0: tensor, %arg1: tensor<5xf32>) { + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<5xf32>' and 'tensor'}} + %0 = "test.elementwise_mappable"(%arg0, %arg1) : (tensor, tensor<5xf32>) -> tensor +} + +// ----- + +func @failedElementwiseMappable_different_base_type(%arg0: vector<2xf32>, %arg1: tensor<2xf32>) { + // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<2xf32>' and 'vector<2xf32>'}} + %0 = "test.elementwise_mappable"(%arg0, %arg1) : (vector<2xf32>, tensor<2xf32>) -> tensor<2xf32> + return +} + +// ----- + +func @failedElementwiseMappable_non_scalar_output(%arg0: vector<2xf32>) { + // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}} + %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> f32 + return +} + +// ----- + +func @failedElementwiseMappable_non_scalar_result_all_scalar_input(%arg0: f32) { + // expected-error@+1 
{{if a result is non-scalar, then at least one operand must be non-scalar}} + %0 = "test.elementwise_mappable"(%arg0) : (f32) -> tensor + return +} + +// ----- + +func @failedElementwiseMappable_mixed_scalar_non_scalar_results(%arg0: tensor<10xf32>) { + // expected-error@+1 {{if an operand is non-scalar, then all results must be non-scalar}} + %0, %1 = "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> (f32, tensor<10xf32>) + return +} + +// ----- + +func @failedElementwiseMappable_zero_results(%arg0: tensor<10xf32>) { + // expected-error@+1 {{if an operand is non-scalar, then there must be at least one non-scalar result}} + "test.elementwise_mappable"(%arg0) : (tensor<10xf32>) -> () + return +} + +// ----- + +func @failedElementwiseMappable_zero_operands() { + // expected-error@+1 {{if a result is non-scalar, then at least one operand must be non-scalar}} + "test.elementwise_mappable"() : () -> (tensor<6xf32>) + return +} + +// ----- + +func @succeededElementwiseMappable(%arg0: vector<2xf32>) { + // Check that varying element types are allowed. + // CHECK: test.elementwise_mappable + %0 = "test.elementwise_mappable"(%arg0) : (vector<2xf32>) -> vector<2xf16> + return +} + +// ----- + func @failedHasParent_wrong_parent() { "some.op"() ({ // expected-error@+1 {{'test.child' op expects parent op 'test.parent'}} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -370,6 +370,12 @@ let results = (outs Variadic); } +def ElementwiseMappableOp : TEST_Op<"elementwise_mappable", + [ElementwiseMappable]> { + let arguments = (ins Variadic); + let results = (outs Variadic); +} + def ArgAndResHaveFixedElementTypesOp : TEST_Op<"arg_and_res_have_fixed_element_types", [PredOpTrait<"fixed type combination",