diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -56,7 +56,8 @@ """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_signed( + U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n]) ``` Here we have a simple type polymorphic contraction that takes arguments `A` and @@ -160,7 +161,7 @@ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U, + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) ``` @@ -182,8 +183,8 @@ * `BinaryFn.add(a, b)` (also via overloading the binary `+` operator) * `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator) -* `BinaryFn.max(a, b)` -* `BinaryFn.min(a, b)` +* `BinaryFn.max_signed(a, b)` +* `BinaryFn.min_signed(a, b)` * `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator) * `BinaryFn.max_unsigned(a, b)` * `BinaryFn.min_unsigned(a, b)` @@ -198,8 +199,8 @@ * `ReduceFn.add` (also overloading the inplace `+=` on a LHS) * `ReduceFn.mul` -* `ReduceFn.max` -* `ReduceFn.min` +* `ReduceFn.max_signed` +* `ReduceFn.min_signed` * `ReduceFn.max_unsigned` * `ReduceFn.min_unsigned` @@ -208,11 +209,11 @@ Additionally, type conversion functions cast an operand to a target type: -* `TypeFn.cast(TypeVar, operand)` +* `TypeFn.cast_signed(TypeVar, operand)` * `TypeFn.cast_unsigned(TypeVar, operand)` As the integer types are signless, signedness is implement by different -functions that treat integers as signed (`TypeFn.cast`) or unsigned +functions that treat integers as signed (`TypeFn.cast_signed`) or unsigned (`TypeFn.cast_unsigned`) values. 
There are also special forms: @@ -235,12 +236,12 @@ rhs=TensorDef(T2), O=TensorDef(U, output=True), fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) ``` The `fun` and `cast` function attributes by default are aliases for their -default values `BinaryFn.add` and `TypeFn.cast`, respectively. When +default values `BinaryFn.add` and `TypeFn.cast_signed`, respectively. When instantiating the operation, the function attributes may be set to other functions using optional named arguments: @@ -265,26 +266,27 @@ computations with a type that is independent of the input and output types. For example, parts of floating point computation may require double precision arithmetic despite all inputs and outputs being single precision values. -Assignment expressions with no `TypeFn.cast` calls will generally require +Assignment expressions with no `TypeFn.cast_signed` calls will generally require uniform types throughout and will fail to verify if violated. The presence of a -`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric -type conversion between element types that can be derived from inputs and -outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar` -first argument are emitted as `type_fn` primitives in the YAML definition. +`TypeFn.cast_signed` or `TypeFn.cast_unsigned` allows for a limited form of +numeric type conversion between element types that can be derived from inputs +and outputs (and in the future, attributes). `TypeFn.cast_signed` calls with a +`TypeVar` first argument are emitted as `type_fn` primitives in the YAML +definition. Casting will perform `int<->float` and `index->int` type conversions and will perform any necessary extension or truncation within the type family. The integer types themselves are signless and signedness is implemented by -functions/operations. 
The `TypeFn.cast` function treats all integers as signed, -while `TypeFn.cast_unsigned` treats them as unsigned. +functions/operations. The `TypeFn.cast_signed` function treats all integers as +signed, while `TypeFn.cast_unsigned` treats them as unsigned. The following examples illustrate the lowering of signed and unsigned functions: -* cast(I32 -> I64) -> `arith.ExtSIOp` -* cast(F32 -> I32) -> `arith.FPToSIOp` +* cast_signed(I32 -> I64) -> `arith.ExtSIOp` +* cast_signed(F32 -> I32) -> `arith.FPToSIOp` * cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` * cast_unsigned(F32 -> I32) -> `arith.FPToUIOp` -* max -> `arith.MaxSIOp` +* max_signed -> `arith.MaxSIOp` * max_unsinged -> `arith.MaxUIOp` Not all functions are applicable for all numeric types, and on mismatch, op @@ -302,7 +304,7 @@ @linalg_structured_op def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): - O[None] = TypeFn.cast(U, value) + O[None] = TypeFn.cast_signed(U, value) ``` The operation sets the elements of the output tensor `O` to `value`. 
All diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -68,10 +68,10 @@ } def BinaryFn : I32EnumAttr<"BinaryFn", "", [ I32EnumAttrCase<"add", 0>, - I32EnumAttrCase<"mul", 1>, - I32EnumAttrCase<"max", 2>, - I32EnumAttrCase<"min", 3>, - I32EnumAttrCase<"sub", 4>, + I32EnumAttrCase<"sub", 1>, + I32EnumAttrCase<"mul", 2>, + I32EnumAttrCase<"max_signed", 3>, + I32EnumAttrCase<"min_signed", 4>, I32EnumAttrCase<"max_unsigned", 5>, I32EnumAttrCase<"min_unsigned", 6> ]> { @@ -79,7 +79,7 @@ let cppNamespace = "::mlir::linalg"; } def TypeFn : I32EnumAttr<"TypeFn", "", [ - I32EnumAttrCase<"cast", 0>, + I32EnumAttrCase<"cast_signed", 0>, I32EnumAttrCase<"cast_unsigned", 1> ]> { let genSpecializedAttr = 0; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -28,7 +28,7 @@ - !LinalgOperandDefConfig name: cast kind: type_fn_attr - default_fn: cast + default_fn: cast_signed indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<() -> ()> @@ -83,7 +83,7 @@ - !LinalgOperandDefConfig name: cast kind: type_fn_attr - default_fn: cast + default_fn: cast_signed indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<() -> ()> @@ -145,7 +145,7 @@ - !LinalgOperandDefConfig name: cast kind: type_fn_attr - default_fn: cast + default_fn: cast_signed indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)> @@ -324,7 +324,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -332,7 +332,7 @@ - !ScalarExpression 
scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -345,7 +345,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -353,7 +353,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -424,7 +424,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: AccumType operands: - !ScalarExpression @@ -432,7 +432,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: AccumType operands: - !ScalarExpression @@ -493,7 +493,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -501,7 +501,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -577,7 +577,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -585,7 +585,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -598,7 +598,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -606,7 +606,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -665,7 +665,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -673,7 +673,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -732,7 +732,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -740,7 +740,7 @@ - 
!ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -800,7 +800,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -808,7 +808,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -866,7 +866,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -874,7 +874,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -933,7 +933,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -941,7 +941,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1002,7 +1002,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1010,7 +1010,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1074,7 +1074,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1082,7 +1082,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1158,7 +1158,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1166,7 +1166,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1256,7 +1256,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ 
-1264,7 +1264,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1372,7 +1372,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1380,7 +1380,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1393,7 +1393,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1401,7 +1401,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1491,7 +1491,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1499,7 +1499,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1591,7 +1591,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1599,7 +1599,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1674,7 +1674,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1682,7 +1682,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1767,7 +1767,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1775,7 +1775,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1876,7 +1876,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U 
operands: - !ScalarExpression @@ -1884,7 +1884,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1897,7 +1897,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1905,7 +1905,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1991,7 +1991,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -1999,7 +1999,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2102,7 +2102,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2110,7 +2110,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2123,7 +2123,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2131,7 +2131,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2210,7 +2210,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2282,14 +2282,14 @@ value: !ScalarExpression scalar_fn: kind: binary - fn_name: max + fn_name: max_signed operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2440,14 +2440,14 @@ value: !ScalarExpression scalar_fn: kind: binary - fn_name: max + fn_name: max_signed operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed 
type_var: U operands: - !ScalarExpression @@ -2519,14 +2519,14 @@ value: !ScalarExpression scalar_fn: kind: binary - fn_name: min + fn_name: min_signed operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2690,7 +2690,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2768,14 +2768,14 @@ value: !ScalarExpression scalar_fn: kind: binary - fn_name: max + fn_name: max_signed operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2853,14 +2853,14 @@ value: !ScalarExpression scalar_fn: kind: binary - fn_name: min + fn_name: min_signed operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2897,7 +2897,7 @@ value: !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -2950,7 +2950,7 @@ value: !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: T operands: - !ScalarExpression @@ -2971,7 +2971,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: F64 operands: - !ScalarExpression @@ -2979,7 +2979,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: F64 operands: - !ScalarExpression @@ -3000,7 +3000,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: I32 operands: - !ScalarExpression @@ -3023,7 +3023,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: I32 operands: - !ScalarExpression @@ -3033,7 +3033,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: 
cast_signed type_var: I32 operands: - !ScalarExpression @@ -3041,7 +3041,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: I32 operands: - !ScalarExpression @@ -3049,7 +3049,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: I32 operands: - !ScalarExpression @@ -3057,7 +3057,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: I32 operands: - !ScalarExpression @@ -3079,7 +3079,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: F64 operands: - !ScalarExpression @@ -3130,7 +3130,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression @@ -3143,7 +3143,7 @@ - !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -160,22 +160,22 @@ if (allFloatingPoint) return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1); return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1); + case BinaryFn::sub: + if (allFloatingPoint) + return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1); + return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1); case BinaryFn::mul: if (allFloatingPoint) return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1); return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1); - case BinaryFn::max: + case BinaryFn::max_signed: if (allFloatingPoint) return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1); return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1); - case BinaryFn::min: + case BinaryFn::min_signed: if (allFloatingPoint) return builder.create<arith::MinFOp>(arg0.getLoc(), arg0, arg1); return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1); - case BinaryFn::sub: - if (allFloatingPoint) - return builder.create<arith::SubFOp>(arg0.getLoc(), arg0,
arg1); - return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1); case BinaryFn::max_unsigned: if (allFloatingPoint) return builder.create<arith::MaxFOp>(arg0.getLoc(), arg0, arg1); @@ -191,7 +191,7 @@ // Build the type functions defined by OpDSL. Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { switch (typeFn) { - case TypeFn::cast: + case TypeFn::cast_signed: return cast(toType, operand, false); case TypeFn::cast_unsigned: return cast(toType, operand, true); diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -305,10 +305,10 @@ - max_unsinged -> `arith.MaxUIOp` """ add = BinaryFnType("add") - mul = BinaryFnType("mul") - max = BinaryFnType("max") - min = BinaryFnType("min") sub = BinaryFnType("sub") + mul = BinaryFnType("mul") + max_signed = BinaryFnType("max_signed") + min_signed = BinaryFnType("min_signed") max_unsigned = BinaryFnType("max_unsigned") min_unsigned = BinaryFnType("min_unsigned") @@ -334,14 +334,14 @@ """Type conversion function namespace. As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast`) or unsigned + functions that treat integers as signed (`cast_signed`) or unsigned (`cast_unsigned`) values. 
Examples: - - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_signed(I32 -> I64) -> `arith.ExtSIOp` - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` """ - cast = TypeFnType("cast") + cast_signed = TypeFnType("cast_signed") cast_unsigned = TypeFnType("cast_unsigned") @@ -389,8 +389,8 @@ class ReduceFn: add = ReduceFnType(BinaryFn.add) mul = ReduceFnType(BinaryFn.mul) - max = ReduceFnType(BinaryFn.max) - min = ReduceFnType(BinaryFn.min) + max_signed = ReduceFnType(BinaryFn.max_signed) + min_signed = ReduceFnType(BinaryFn.min_signed) max_unsigned = ReduceFnType(BinaryFn.max_unsigned) min_unsigned = ReduceFnType(BinaryFn.min_unsigned) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -370,7 +370,7 @@ raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def _type_cast(self, type_var_name: str, operand: Value) -> Value: + def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, False) def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: @@ -407,7 +407,7 @@ return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") - def _binary_max(self, lhs: Value, rhs: Value) -> Value: + def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): @@ -422,7 +422,7 @@ raise NotImplementedError( "Unsupported 'max_unsigned' operands: {lhs}, {rhs}") - def _binary_min(self, lhs: Value, rhs: Value) -> Value: + def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or 
_is_index_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -11,7 +11,7 @@ I=TensorDef(T1), O=TensorDef(U, output=True), fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Applies the unary function fun elementwise. Numeric casting is performed on the input operand, promoting it to the same @@ -26,7 +26,7 @@ rhs=TensorDef(T2), O=TensorDef(U, output=True), fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Applies the binary function fun elementwise. Numeric casting is performed on the input operand, promoting it to the same @@ -40,7 +40,7 @@ A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -82,8 +82,9 @@ matmul. 
""" domain(D.m, D.n, D.k) - C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * ( - TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp)) + C[D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op @@ -103,8 +104,8 @@ """ domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast( - TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast( + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed( TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @@ -121,7 +122,8 @@ domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, - D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n]) + D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n]) @linalg_structured_op @@ -139,9 +141,9 @@ matmul. 
""" domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, - D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * ( - TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp)) + C[D.b, D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op @@ -156,7 +158,7 @@ """ domain(D.m, D.n) implements(ContractionOpInterface) - x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n]) + x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) @linalg_structured_op @@ -171,7 +173,7 @@ """ domain(D.n, D.m) implements(ContractionOpInterface) - x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n]) + x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) @linalg_structured_op @@ -186,7 +188,8 @@ """ domain(D.b, D.m, D.k) implements(ContractionOpInterface) - C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k]) + C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k]) @linalg_structured_op @@ -198,7 +201,7 @@ them to the same data type as the accumulator/output. 
""" implements(ContractionOpInterface) - C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m]) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) @linalg_structured_op @@ -213,7 +216,8 @@ """ implements(ConvolutionOpInterface) domain(D.ow, D.kw) - O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw]) + O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed( + U, K[D.kw]) @linalg_structured_op @@ -228,8 +232,8 @@ """ implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast( - U, K[D.kh, D.kw]) + O[D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw]) @linalg_structured_op @@ -244,9 +248,9 @@ """ implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) - O[D.od, D.oh, - D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + - D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw]) + O[D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed( + U, K[D.kd, D.kh, D.kw]) @linalg_structured_op @@ -264,8 +268,8 @@ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) O[D.n, D.ow, - D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f]) + D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, + D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) @linalg_structured_op @@ -287,9 +291,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) + D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op @@ -315,10 +319,11 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) 
O[D.n, D.oh, D.ow, - D.f] += (TypeFn.cast( + D.f] += (TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - TypeFn.cast(U, IZp)) * ( - TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp)) + TypeFn.cast_signed(U, IZp)) * ( + TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - + TypeFn.cast_signed(U, KZp)) @linalg_structured_op @@ -340,9 +345,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += TypeFn.cast( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw]) + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op @@ -360,9 +365,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast( + D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( U, K[D.kd, D.kh, D.kw, D.c, D.f]) @@ -382,8 +387,8 @@ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.ic, D.kw) O[D.n, D.ow, D.ic] += \ - TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - TypeFn.cast(U, K[D.kw, D.ic]) + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ + TypeFn.cast_signed(U, K[D.kw, D.ic]) @linalg_structured_op @@ -402,9 +407,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic]) + D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) @linalg_structured_op @@ -424,11 +429,11 @@ """ 
implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, - D.ic] += ((TypeFn.cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast(U, IZp)) * - (TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp))) + O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - + TypeFn.cast_signed(U, IZp)) * + (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - + TypeFn.cast_signed(U, KZp))) @linalg_structured_op @@ -446,9 +451,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) + D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op @@ -469,10 +474,11 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, - D.cm] += ((TypeFn.cast( + D.cm] += ((TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast(U, IZp)) * - (TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp))) + TypeFn.cast_signed(U, IZp)) * + (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - + TypeFn.cast_signed(U, KZp))) @linalg_structured_op @@ -490,7 +496,7 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -509,8 +515,8 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * 
S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -549,8 +555,8 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,])) @@ -570,8 +576,8 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -610,7 +616,7 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -630,8 +636,8 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -651,8 +657,8 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -665,7 +671,7 @@ accesses only and is thus rank polymorphic. Numeric casting is performed on the value operand, promoting it to the same data type as the output. 
""" - O[None] = TypeFn.cast(U, value) + O[None] = TypeFn.cast_signed(U, value) @linalg_structured_op @@ -685,15 +691,15 @@ the range of the generated random numbers. """ domain(D.m, D.n) - multiplier = TypeFn.cast(I32, const(1103515245)) - increment = TypeFn.cast(I32, const(12345)) - rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = TypeFn.cast(F64, const(2.3283064e-10)) - offset = TypeFn.cast(F64, const(2147483647)) + multiplier = TypeFn.cast_signed(I32, const(1103515245)) + increment = TypeFn.cast_signed(I32, const(12345)) + rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) + offset = TypeFn.cast_signed(F64, const(2147483647)) scaling = (max - min) * inv_range - O[D.m, D.n] = TypeFn.cast(T, - (offset + TypeFn.cast(F64, rand2)) * scaling + min) + O[D.m, D.n] = TypeFn.cast_signed( + T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min) @linalg_structured_op @@ -706,4 +712,4 @@ """ domain(D.m, D.n) O[D.m, D.n] = \ - UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n]))) + UnaryFn.log(TypeFn.cast_signed(U, const(1.0)) + UnaryFn.exp(TypeFn.cast_signed(U, I[D.m, D.n]))) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -3,7 +3,7 @@ # @linalg_structured_op # def test1(O=TensorDef(T, S.M, S.N, output=True), -# cast=TypeFnAttrDef(default=TypeFn.cast)): +# cast=TypeFnAttrDef(default=TypeFn.cast_signed)): # """Title. # Detailed description. 
@@ -28,7 +28,7 @@ - !LinalgOperandDefConfig name: cast kind: type_fn_attr - default_fn: cast + default_fn: cast_signed indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -70,7 +70,7 @@ # ODS: let arguments = # ODS-NEXT: Variadic:$inputs, # ODS-NEXT: Variadic:$outputs, -# ODS-NEXT: DefaultValuedAttr:$cast +# ODS-NEXT: DefaultValuedAttr:$cast # ODS: let builders = # ODS: (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, @@ -99,7 +99,7 @@ # IMPL-LABEL: void Test1Op::regionBuilder(ImplicitLocOpBuilder &b, # IMPL-NEXT: Block &block, ArrayRef attrs) -# IMPL: TypeFn castVal = TypeFn::cast; +# IMPL: TypeFn castVal = TypeFn::cast_signed; # IMPL-NEXT: auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) { # IMPL-NEXT: return attr.getName() == "cast"; }); # IMPL-NEXT: if (castIter != attrs.end()) { @@ -209,7 +209,7 @@ # Detailed description. # """ -# O[None] = TypeFn.cast(U, value) +# O[None] = TypeFn.cast_signed(U, value) --- !LinalgOpConfig metadata: !LinalgOpMetadata @@ -241,7 +241,7 @@ value: !ScalarExpression scalar_fn: kind: type - fn_name: cast + fn_name: cast_signed type_var: U operands: - !ScalarExpression diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -26,7 +26,7 @@ # CHECK: default_fn: exp # CHECK: name: cast # CHECK: kind: type_fn_attr -# CHECK: default_fn: cast +# CHECK: default_fn: cast_signed @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), @@ -34,7 +34,7 @@ C=TensorDef(U, S.M, S.N, output=True), bfn=BinaryFnAttrDef(default=BinaryFn.mul), ufn=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) diff --git 
a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -35,7 +35,7 @@ B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), mul=BinaryFnAttrDef(default=BinaryFn.mul), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) @@ -63,13 +63,13 @@ # CHECK: scalar_const: '3.1415926535897931 : f64' # CHECK: scalar_fn: # CHECK: kind: type -# CHECK: fn_name: cast +# CHECK: fn_name: cast_signed # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '42 : i64' # CHECK: scalar_fn: # CHECK: kind: type -# CHECK: fn_name: cast +# CHECK: fn_name: cast_signed # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_fn: @@ -81,9 +81,9 @@ def constants( O=TensorDef(T, S.M, S.K, output=True), exp=UnaryFnAttrDef(default=UnaryFn.exp)): - pi = TypeFn.cast(T, const(3.1415926535897931)) - cst42 = TypeFn.cast(T, const(42)) - cst1000 = TypeFn.cast(T, exp(const(1e+3))) + pi = TypeFn.cast_signed(T, const(3.1415926535897931)) + cst42 = TypeFn.cast_signed(T, const(42)) + cst1000 = TypeFn.cast_signed(T, exp(const(1e+3))) O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000 diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py @@ -19,9 +19,9 @@ strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 2])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c]) + D.c]) * 
TypeFn.cast_signed(U, K[D.kh, D.kw, D.c]) with Context() as ctx, Location.unknown(): diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_fill.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_fill.py @@ -13,7 +13,7 @@ @linalg_structured_op def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)): - O[None] = TypeFn.cast(U, value) + O[None] = TypeFn.cast_signed(U, value) with Context() as ctx, Location.unknown(): diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -25,7 +25,7 @@ A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): domain(D.m, D.n, D.k) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py @@ -21,22 +21,23 @@ max=ScalarDef(F64), seed=ScalarDef(I32), O=TensorDef(T, S.M, S.N, output=True)): - multiplier = TypeFn.cast(I32, const(1103515245)) - increment = TypeFn.cast(I32, const(12345)) - rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = TypeFn.cast(F64, const(2.3283064e-10)) - offset = TypeFn.cast(F64, const(2147483647)) + multiplier = TypeFn.cast_signed(I32, const(1103515245)) + increment = TypeFn.cast_signed(I32, const(12345)) + rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * 
multiplier + increment + inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) + offset = TypeFn.cast_signed(F64, const(2147483647)) scaling = (max - min) * inv_range - O[D.m, D.n] = TypeFn.cast(T, - (offset + TypeFn.cast(F64, rand2)) * scaling + min) + O[D.m, D.n] = TypeFn.cast_signed( + T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min) @linalg_structured_op def soft_plus_poly( I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)): O[D.m, D.n] = UnaryFn.log( - TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n]))) + TypeFn.cast_signed(U, const(1.0)) + + TypeFn.cast_signed(U, UnaryFn.exp(I[D.m, D.n]))) @linalg_structured_op(op_name="custom_op_name") diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -16,8 +16,8 @@ I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - reduce=BinaryFnAttrDef(default=BinaryFn.max), - cast=TypeFnAttrDef(default=TypeFn.cast), + reduce=BinaryFnAttrDef(default=BinaryFn.max_signed), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) @@ -99,7 +99,7 @@ input, shape, outs=[init_result], - reduce=BinaryFn.min, + reduce=BinaryFn.min_signed, strides=[2, 4], dilations=[1, 2]) @@ -131,7 +131,7 @@ input, shape, outs=[init_result], - reduce=BinaryFn.min, + reduce=BinaryFn.min_signed, strides=[2, 4], dilations=[1, 2]) diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/interfaces.py --- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py +++ b/mlir/test/python/dialects/linalg/opdsl/interfaces.py @@ -13,4 +13,5 @@ B=TensorDef(T, 
S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( + U, B[D.k, D.n]) diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py --- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py +++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py @@ -24,7 +24,8 @@ B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): domain(D.m, D.n, D.k) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( + U, B[D.k, D.n]) # Verifies that assignment to a scalar (represented as [None]) is represented @@ -42,7 +43,7 @@ # CHECK-NEXT: - reduction @linalg_structured_op def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)): - C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m]) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) # Verifies that the index_dims of shape-only operands translate to correct @@ -65,4 +66,4 @@ K=TensorDef(T, S.K, index_dims=[D.k]), O=TensorDef(U, S.O, output=True)): domain(D.o, D.k) - O[D.o] += TypeFn.cast(U, I[D.o * 2 + D.k]) + O[D.o] += TypeFn.cast_signed(U, I[D.o * 2 + D.k]) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -99,7 +99,7 @@ init_result = linalg.InitTensorOp([4, 8], f32) # Check for the named form with custom format # CHECK: linalg.elemwise_unary - # CHECK-SAME: cast = #linalg.type_fn + # CHECK-SAME: cast = #linalg.type_fn # CHECK-SAME: fun = #linalg.unary_fn # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) unary_result = 
linalg.elemwise_unary(lhs, outs=[init_result.result]) @@ -137,7 +137,7 @@ # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () - # CHECK-NEXT: cast = #linalg.type_fn + # CHECK-NEXT: cast = #linalg.type_fn # CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32> # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return linalg.matmul(lhs, rhs, outs=[init_result.result])