diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -56,7 +56,7 @@ """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) ``` Here we have a simple type polymorphic contraction that takes arguments `A` and @@ -159,8 +159,8 @@ O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=IndexAttrDef(S.SH, S.SW), dilations=IndexAttrDef(S.DH, S.DW)): - O[D.n, D.oh, D.ow, D.c] += \ - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast(U, + I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) ``` The pooling operation does not access the shape-only tensor `K`. Instead, the @@ -192,10 +192,18 @@ * `ReduceFn.mul` * `ReduceFn.max` +Additionally, type conversion functions cast an operand to a target type: + +* `TypeFn.cast(TypeVar, operand)` +* `TypeFn.cast_unsigned(TypeVar, operand)` + +As the integer types are signless, signedness is implemented by different +functions that treat integers as signed (`TypeFn.cast`) or unsigned +(`TypeFn.cast_unsigned`) values. + There are also special forms: -* `cast(TypeVar, operand)` casts the `operand` to the target type `TypeVar`. -* `const(TypeVar, value)` returns a constant value of type `TypeVar`. +* `const(value)` returns a constant value. * `index(dim)` returns the iteration index in the given dimension `dim`. ## Types @@ -206,18 +214,25 @@ computations with a type that is independent of the input and output types. For example, parts of floating point computation may require double precision arithmetic despite all inputs and outputs being single precision values. -Assignment expressions with no `cast` calls will generally require uniform types -throughout and will fail to verify if violated. The presence of a `cast` allows -for a limited form of numeric type conversion between element types that can be -derived from inputs and outputs (and in the future, attributes). `cast` calls -with a `TypeVar` first argument are emitted as `symbolic_cast` primitives in the -YAML definition. +Assignment expressions with no `TypeFn.cast` calls will generally require +uniform types throughout and will fail to verify if violated. The presence of a +`TypeFn.cast` or `TypeFn.cast_unsigned` allows for a limited form of numeric +type conversion between element types that can be derived from inputs and +outputs (and in the future, attributes). `TypeFn.cast` calls with a `TypeVar` +first argument are emitted as `type_fn` primitives in the YAML definition. Casting will perform `int<->float` and `index->int` type conversions and will -perform any necessary extension or truncation within type family. Note that -presently, any integer type is assumed to be signed for the purpose of -determining how to extend or truncate. Supporting unsigned integer types is left -for future work. +perform any necessary extension or truncation within the type family. The +integer types themselves are signless and signedness is implemented by +functions/operations. The `TypeFn.cast` function treats all integers as signed, +while `TypeFn.cast_unsigned` treats them as unsigned.
+ +The following examples illustrate the lowering of signed and unsigned functions: + +* cast(I32 -> I64) -> `arith.ExtSIOp` +* cast(F32 -> I32) -> `arith.FPToSIOp` +* cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` +* cast_unsigned(F32 -> I32) -> `arith.FPToUIOp` Not all functions are applicable for all numeric types, and on mismatch, op verification will fail. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -51,19 +51,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matmul_unsigned @@ -115,19 +115,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: true - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: B - is_unsigned_cast: true --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_matmul @@ -193,37 +193,37 @@ fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: mmt4d @@ -286,19 +286,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: lhs - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: AccumType operands: - !ScalarExpression scalar_arg: rhs - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matmul @@ -351,19 +351,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: quantized_batch_matmul @@ -430,37 +430,37 @@ fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: AZp - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - 
!ScalarExpression scalar_arg: B - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: BZp - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: matvec @@ -511,19 +511,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: vecmat @@ -574,19 +574,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: y - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: batch_matvec @@ -638,19 +638,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: dot @@ -700,19 +700,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: A - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: B - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d @@ -763,19 +763,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d @@ -828,19 +828,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d @@ -896,19 +896,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_1d_nwc_wcf @@ -974,19 +974,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf @@ -1064,19 +1064,19 @@ fn_name: mul operands: - !ScalarExpression - 
symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nhwc_hwcf_q @@ -1171,37 +1171,37 @@ fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_2d_nchw_fchw @@ -1279,19 +1279,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: conv_3d_ndhwc_dhwcf @@ -1369,19 +1369,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_1d_nwc_wc @@ -1446,19 +1446,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc @@ -1529,19 +1529,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwc_q @@ -1627,37 +1627,37 @@ fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm 
@@ -1731,19 +1731,19 @@ fn_name: mul operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: depthwise_conv_2d_nhwc_hwcm_q @@ -1833,37 +1833,37 @@ fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: IZp - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: sub operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: K - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: KZp - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_sum @@ -1929,12 +1929,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max @@ -2000,12 +2000,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_max_unsigned @@ -2071,12 +2071,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: true --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nchw_max @@ -2142,12 +2142,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min @@ -2213,12 +2213,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_nhwc_min_unsigned @@ -2284,12 +2284,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast_unsigned type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: true --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_sum @@ -2361,12 +2361,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_max @@ -2438,12 +2438,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: pooling_ndhwc_min @@ -2515,12 +2515,12 @@ - !ScalarExpression scalar_arg: O - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: 
U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: fill_rng_2d @@ -2567,7 +2567,8 @@ - !ScalarAssign arg: O value: !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: T operands: - !ScalarExpression @@ -2583,14 +2584,15 @@ fn_name: add operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2147483647 : i64' - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: F64 operands: - !ScalarExpression @@ -2606,12 +2608,12 @@ fn_name: add operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 1 - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: add @@ -2625,43 +2627,42 @@ fn_name: add operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_index: 0 - is_unsigned_cast: false - !ScalarExpression scalar_arg: seed - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '1103515245 : i64' - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: I32 operands: - !ScalarExpression scalar_const: '12345 : i64' - is_unsigned_cast: false - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: mul @@ -2675,15 +2676,14 @@ - !ScalarExpression scalar_arg: min - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: F64 operands: - !ScalarExpression scalar_const: '2.3283063999999999E-10 : f64' - is_unsigned_cast: false - !ScalarExpression scalar_arg: min - is_unsigned_cast: false --- !LinalgOpConfig metadata: !LinalgOpMetadata name: soft_plus_2d @@ -2724,20 +2724,20 @@ fn_name: add operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_const: '1.000000e+00 : f64' - is_unsigned_cast: false - !ScalarExpression scalar_apply: fn_name: exp operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: U operands: - !ScalarExpression scalar_arg: I - is_unsigned_cast: false diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -148,11 +148,13 @@ // Region builder helper. // TODO: Move this to a utility library. // The public methods on this class are referenced directly from generated code -// and bind by name to math functions in the DSL as: +// and bind by name to math and type conversion functions in the DSL as: // `applyfn__{fnName}` +// `typefn__{fnName}` // Examples: // `applyfn__add` // `applyfn__mul` +// `typefn__cast` // The naming convention is intentional in order to match snake-cased DSL names. // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. 
// @@ -229,6 +231,14 @@ return operand; } + Value typefn__cast(Type toType, Value operand) { + return cast(toType, operand, false); + } + + Value typefn__cast_unsigned(Type toType, Value operand) { + return cast(toType, operand, true); + } + Value applyfn__add(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -314,6 +314,39 @@ return f"{defs_repr} = {values_repr}" +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, + arg: TensorExpression) -> "TensorTypeFn": + return TensorTypeFn(self, type_var, arg) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implemented by different cast + functions that treat integers as signed (`cast`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + cast = TypeFnType("cast") + cast_unsigned = TypeFnType("cast_unsigned") + + class PrimFnType: """Primitive operations.""" @@ -391,6 +424,26 @@ return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" +class TensorTypeFn(TensorExpression): + """Application of a type conversion function.""" + + def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression): + self.type_fn = type_fn + self.type_var = type_var + self.arg = arg + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + self.arg.to_scalar_expression()).expr() + + def visit_tensor_exprs(self, callback): + super().visit_tensor_exprs(callback) + self.arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" + + class const(TensorExpression): """Returns the given constant floating point or integer value.""" @@ -433,36 +486,6 @@ return f"index({repr(self.dim)})" -class cast(TensorExpression): - """Casts the element type to a type (typically symbolic TypeVar).""" - - def __init__(self, to_type: TypeVar, operand: TensorExpression): - self.to_type = to_type - self.operand = operand - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), - False).expr() - - def visit_tensor_exprs(self, callback): - super().visit_tensor_exprs(callback) - self.operand.visit_tensor_exprs(callback) - - def __repr__(self): - return f"cast({self.to_type}, {repr(self.operand)})" - - -class cast_unsigned(cast): - """Casts the element type to an unsigned type (typically symbolic TypeVar).""" - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), - True).expr() - - def __repr__(self): - return f"cast_unsigned({self.to_type}, {repr(self.operand)})" - - class ReduceApply(TensorExpression): """Application of a reduction.
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List, Sequence, Tuple, Union +from typing import Callable, Dict, List, Sequence, Tuple, Union from .....ir import * from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region @@ -25,6 +25,7 @@ ValueList = Union[Sequence[Value], OpResultList] + def isa(cls: Type, ty: Type): try: cls(ty) @@ -223,24 +224,38 @@ IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result elif expr.scalar_apply: - try: - fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") - except AttributeError: - raise ValueError( - f"Function '{expr.scalar_apply.fn_name}' is not a known " - "scalar body function") + fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}") operand_values = [ self.expression(operand) for operand in expr.scalar_apply.operands ] return fn(*operand_values) - elif expr.symbolic_cast: - operand_value = self.expression(expr.symbolic_cast.operand) - return self.cast(expr.symbolic_cast.to_type.name, operand_value, - expr.symbolic_cast.is_unsigned_cast) + elif expr.type_fn: + fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}") + operand = self.expression(expr.type_fn.operand) + return fn(expr.type_fn.type_var.name, operand) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - def cast(self, type_var_name: str, operand: Value, - is_unsigned_cast: bool) -> Value: + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError(f"Body assignments do not assign all outputs: " + f"missing '{n}'") + linalg.YieldOp(output_values) + + def _get_function(self, fn_name: str) -> Callable: + try: + fn = getattr(self, f"{fn_name}") + except AttributeError: + raise ValueError(f"Function '{fn_name}' is not a known function") + return fn + + def _cast(self, + type_var_name: str, + operand: Value, + is_unsigned_cast: bool = False) -> Value: try: to_type = self.type_mapping[type_var_name] except KeyError: @@ -291,15 +306,11 @@ raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def yield_outputs(self, *output_names: str): - output_values = [] - for n in output_names: - try: - output_values.append(self.yield_mapping[n]) - except KeyError: - raise ValueError(f"Body assignments do not assign all outputs: " - f"missing '{n}'") - linalg.YieldOp(output_values) + def _typefn_cast(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, False) + + def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, True) def _eval_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -21,11 +21,11 @@ __all__ = [ "ScalarAssign", "ScalarApplyFn", + "ScalarTypeFn", "ScalarArg", "ScalarConst", "ScalarIndex", 
"ScalarExpression", - "ScalarSymbolicCast", ] @@ -43,6 +43,22 @@ return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})" +class ScalarTypeFn: + """A type of ScalarExpression that applies a type conversion function.""" + + def __init__(self, fn_name: str, type_var: TypeVar, + operand: "ScalarExpression"): + self.fn_name = fn_name + self.type_var = type_var + self.operand = operand + + def expr(self) -> "ScalarExpression": + return ScalarExpression(type_fn=self) + + def __repr__(self): + return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})" + + class ScalarArg: """A type of ScalarExpression that references a named argument.""" @@ -82,27 +98,12 @@ return f"(ScalarIndex({self.dim})" -class ScalarSymbolicCast: - """A type of ScalarExpression that symbolically casts an operand to a TypeVar.""" - - def __init__(self, to_type: TypeVar, operand: "ScalarExpression", - is_unsigned_cast: bool): - self.to_type = to_type - self.operand = operand - self.is_unsigned_cast = is_unsigned_cast - - def expr(self) -> "ScalarExpression": - return ScalarExpression(symbolic_cast=self) - - def __repr__(self): - return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})" - - class ScalarExpression(YAMLObject): """An expression on scalar values. Can be one of: - ScalarApplyFn + - ScalarTypeFn - ScalarArg - ScalarConst - ScalarIndex @@ -112,19 +113,19 @@ def __init__(self, scalar_apply: Optional[ScalarApplyFn] = None, + type_fn: Optional[ScalarTypeFn] = None, scalar_arg: Optional[ScalarArg] = None, scalar_const: Optional[ScalarConst] = None, - scalar_index: Optional[ScalarIndex] = None, - symbolic_cast: Optional[ScalarSymbolicCast] = None): - if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_const) + - bool(scalar_index) + bool(symbolic_cast)) != 1: - raise ValueError("One of 'scalar_apply', 'scalar_arg', 'scalar_const', " - "'scalar_index', 'symbolic_cast' must be specified") + scalar_index: Optional[ScalarIndex] = None): + if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) + + bool(scalar_const) + bool(scalar_index)) != 1: + raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', " + "'scalar_const', 'scalar_index', must be specified") self.scalar_apply = scalar_apply + self.type_fn = type_fn self.scalar_arg = scalar_arg self.scalar_const = scalar_const self.scalar_index = scalar_index - self.symbolic_cast = symbolic_cast def to_yaml_custom_dict(self): if self.scalar_apply: @@ -133,21 +134,22 @@ fn_name=self.scalar_apply.fn_name, operands=list(self.scalar_apply.operands), )) + if self.type_fn: + # Note that even though operands must be arity 1, we write it the + # same way as for apply because it allows handling code to be more + # generic vs having a special form. + return dict( + type_fn=dict( + fn_name=self.type_fn.fn_name, + type_var=self.type_fn.type_var.name, + operands=[self.type_fn.operand], + )) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: return dict(scalar_const=self.scalar_const.value) elif self.scalar_index: return dict(scalar_index=self.scalar_index.dim) - elif self.symbolic_cast: - # Note that even though operands must be arity 1, we write it the - # same way as for apply because it allows handling code to be more - # generic vs having a special form. 
- return dict( - symbolic_cast=dict( - type_var=self.symbolic_cast.to_type.name, - operands=[self.symbolic_cast.operand], - is_unsigned_cast=self.symbolic_cast.is_unsigned_cast)) else: raise ValueError(f"Unexpected ScalarExpression type: {self}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -18,7 +18,7 @@ """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) @linalg_structured_op @@ -33,7 +33,8 @@ """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( + U, B[D.k, D.n]) @linalg_structured_op @@ -51,8 +52,8 @@ matmul. """ domain(D.m, D.n, D.k) - C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * ( - cast(U, B[D.k, D.n]) - cast(U, BZp)) + C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * ( + TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp)) @linalg_structured_op @@ -72,9 +73,9 @@ """ domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, - D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast( - TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast( + TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @linalg_structured_op @@ -89,7 +90,8 @@ """ domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) + C[D.b, D.m, + D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n]) @linalg_structured_op @@ -107,8 +109,9 @@ matmul. """ domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * ( - cast(U, B[D.b, D.k, D.n]) - cast(U, BZp)) + C[D.b, D.m, + D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * ( + TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp)) @linalg_structured_op @@ -123,7 +126,7 @@ """ domain(D.m, D.n) implements(ContractionOpInterface) - x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n]) + x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n]) @linalg_structured_op @@ -138,7 +141,7 @@ """ domain(D.n, D.m) implements(ContractionOpInterface) - x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) + x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n]) @linalg_structured_op @@ -153,7 +156,7 @@ """ domain(D.b, D.m, D.k) implements(ContractionOpInterface) - C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k]) + C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k]) @linalg_structured_op @@ -165,7 +168,7 @@ them to the same data type as the accumulator/output. 
""" implements(ContractionOpInterface) - C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m]) @linalg_structured_op @@ -180,7 +183,7 @@ """ implements(ConvolutionOpInterface) domain(D.ow, D.kw) - O[D.ow] += cast(U, I[D.ow + D.kw]) * cast(U, K[D.kw]) + O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw]) @linalg_structured_op @@ -195,7 +198,8 @@ """ implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += cast(U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) + O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast( + U, K[D.kh, D.kw]) @linalg_structured_op @@ -211,8 +215,8 @@ implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) O[D.od, D.oh, - D.ow] += cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast( - U, K[D.kd, D.kh, D.kw]) + D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + + D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw]) @linalg_structured_op @@ -229,8 +233,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) - O[D.n, D.ow, D.f] += cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * cast( - U, K[D.kw, D.c, D.f]) + O[D.n, D.ow, + D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, + D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f]) @linalg_structured_op @@ -252,9 +257,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += cast( + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * cast(U, K[D.kh, D.kw, D.c, D.f]) + D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op @@ -280,10 +285,10 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, - D.f] += (cast( + D.f] += (TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - cast(U, IZp)) * ( - cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) + TypeFn.cast(U, IZp)) * ( + TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp)) @linalg_structured_op @@ -305,9 +310,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += cast( + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op @@ -325,9 +330,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += cast( + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) * cast( + D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast( U, K[D.kd, D.kh, D.kw, D.c, D.f]) @@ -347,8 +352,8 @@ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.ic, D.kw) O[D.n, D.ow, D.ic] += \ - cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - cast(U, K[D.kw, D.ic]) + TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ + TypeFn.cast(U, K[D.kw, D.ic]) @linalg_structured_op @@ -367,9 +372,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += cast( + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * cast(U, K[D.kh, D.kw, D.ic]) + D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, 
D.ic]) @linalg_structured_op @@ -389,10 +394,11 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += ( - (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) - cast(U, IZp)) * - (cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp))) + O[D.n, D.oh, D.ow, + D.ic] += ((TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - + TypeFn.cast(U, IZp)) * + (TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp))) @linalg_structured_op @@ -410,9 +416,9 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) + D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op @@ -432,10 +438,11 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += ( - (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) - cast(U, IZp)) * - (cast(U, K[D.kh, D.kw, D.ic, D.cm]) - cast(U, KZp))) + O[D.n, D.oh, D.ow, D.ic, + D.cm] += ((TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - + TypeFn.cast(U, IZp)) * + (TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp))) @linalg_structured_op @@ -453,7 +460,7 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] += cast( + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -473,8 +480,8 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -493,7 +500,7 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)( - cast_unsigned( + TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -513,8 +520,9 @@ implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( - cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW,])) + TypeFn.cast( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW,])) @linalg_structured_op @@ -533,8 +541,8 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -553,7 +561,7 @@ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( - cast_unsigned( + TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -572,7 +580,7 @@ """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.c] += cast( + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + 
D.kw * S.DW, D.c]) @@ -593,7 +601,7 @@ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( - cast( + TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -614,7 +622,7 @@ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( - cast( + TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -636,14 +644,15 @@ the range of the generated random numbers. """ domain(D.m, D.n) - multiplier = cast(I32, const(1103515245)) - increment = cast(I32, const(12345)) - rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = cast(F64, const(2.3283064e-10)) - offset = cast(F64, const(2147483647)) + multiplier = TypeFn.cast(I32, const(1103515245)) + increment = TypeFn.cast(I32, const(12345)) + rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast(F64, const(2.3283064e-10)) + offset = TypeFn.cast(F64, const(2147483647)) scaling = (max - min) * inv_range - O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) + O[D.m, D.n] = TypeFn.cast(T, + (offset + TypeFn.cast(F64, rand2)) * scaling + min) @linalg_structured_op @@ -656,4 +665,4 @@ """ domain(D.m, D.n) O[D.m, D.n] = \ - PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n]))) + PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n]))) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -38,19 +38,19 @@ fn_name: add operands: - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast type_var: T operands: - !ScalarExpression scalar_const: '42 : i64' - is_unsigned_cast: false - !ScalarExpression - symbolic_cast: + type_fn: + fn_name: cast_unsigned type_var: T operands: - !ScalarExpression scalar_index: 1 - is_unsigned_cast: true # ODS-LABEL: def Test1Op : LinalgStructuredBase_Op<"test1" @@ -86,9 +86,9 @@ # IMPL-LABEL: void Test1Op::regionBuilder( # IMPL: ImplicitLocOpBuilder &b, Block &block) # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64"); -# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL0]], false); +# IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); -# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.cast(block.getArgument(0).getType(), [[VAL2]], true); +# IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]); # IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]); diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -23,7 +23,7 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, 
B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) # CHECK: --- diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -15,11 +15,11 @@ # CHECK: scalar_apply: # CHECK: fn_name: mul # CHECK: operands: -# CHECK: symbolic_cast: +# CHECK: type_fn: # CHECK: type_var: U # CHECK: operands: # CHECK: scalar_arg: A -# CHECK: symbolic_cast: +# CHECK: type_fn: # CHECK: type_var: U # CHECK: operands: # CHECK: scalar_arg: B @@ -28,7 +28,7 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) # CHECK: --- @@ -42,23 +42,23 @@ # CHECK: scalar_apply: # CHECK: fn_name: add # CHECK: operands: -# CHECK: symbolic_cast: +# CHECK: type_fn: # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '3.1415926535897931 : f64' -# CHECK: symbolic_cast: +# CHECK: type_fn: # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '42 : i64' -# CHECK: symbolic_cast: +# CHECK: type_fn: # CHECK: type_var: T # CHECK: operands: # CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' @linalg_structured_op def constants(O=TensorDef(T, S.M, S.K, output=True)): - pi = cast(T, const(3.1415926535897931)) - cst42 = cast(T, const(42)) - cst1000 = cast(T, const(1e+3)) + pi = TypeFn.cast(T, const(3.1415926535897931)) + cst42 = TypeFn.cast(T, const(42)) + cst1000 = TypeFn.cast(T, const(1e+3)) O[D.m, D.n] = pi + cst42 - cst1000 diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_convolution.py @@ -19,9 +19,9 @@ strides=IndexAttrDef(S.SH, S.SW), dilations=IndexAttrDef(S.DH, S.DW)): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] += cast( + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * cast(U, K[D.kh, D.kw, D.c]) + D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c]) with Context() as ctx, Location.unknown(): diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_matmul.py @@ -26,7 +26,7 @@ B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): domain(D.m, D.n, D.k) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) @linalg_structured_op @@ -35,7 +35,8 @@ B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): domain(D.m, D.n, D.k) - C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( + U, B[D.k, D.n]) with Context() as ctx, Location.unknown(): diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py @@ -14,27 +14,29 @@ # - exponential functions # - custom op names. 
+ @linalg_structured_op def fill_rng_poly( min=ScalarDef(F64), max=ScalarDef(F64), seed=ScalarDef(I32), O=TensorDef(T, S.M, S.N, output=True)): - multiplier = cast(I32, const(1103515245)) - increment = cast(I32, const(12345)) - rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = cast(F64, const(2.3283064e-10)) - offset = cast(F64, const(2147483647)) + multiplier = TypeFn.cast(I32, const(1103515245)) + increment = TypeFn.cast(I32, const(12345)) + rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast(F64, const(2.3283064e-10)) + offset = TypeFn.cast(F64, const(2147483647)) scaling = (max - min) * inv_range - O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) + O[D.m, D.n] = TypeFn.cast(T, + (offset + TypeFn.cast(F64, rand2)) * scaling + min) @linalg_structured_op def soft_plus_poly( I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)): - O[D.m, D.n] = \ - PrimFn.log(cast(U, const(1.0)) + cast(U, PrimFn.exp(I[D.m, D.n]))) + O[D.m, D.n] = PrimFn.log( + TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, PrimFn.exp(I[D.m, D.n]))) @linalg_structured_op(op_name="custom_op_name") diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -20,8 +20,8 @@ dilations=IndexAttrDef(S.DH, S.DW)): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -33,7 +33,7 @@ dilations=IndexAttrDef(S.DH, S.DW)): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)( - cast_unsigned( + TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -46,8 +46,8 @@ dilations=IndexAttrDef(S.DH, S.DW)): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -59,7 +59,7 @@ dilations=IndexAttrDef(S.DH, S.DW)): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( - cast_unsigned( + TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/interfaces.py --- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py +++ b/mlir/test/python/dialects/linalg/opdsl/interfaces.py @@ -13,4 +13,4 @@ B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) diff --git a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py --- a/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py +++ b/mlir/test/python/dialects/linalg/opdsl/shape_maps_iteration.py 
@@ -24,7 +24,7 @@ B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): domain(D.m, D.n, D.k) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) # Verifies that assignment to a scalar (represented as [None]) is represented @@ -42,7 +42,7 @@ # CHECK-NEXT: - reduction @linalg_structured_op def dot(A=TensorDef(T, S.M), B=TensorDef(T, S.M), C=TensorDef(U, output=True)): - C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m]) # Verifies that the index_dims of shape-only operands translate to correct @@ -65,4 +65,4 @@ K=TensorDef(T, S.K, index_dims=[D.k]), O=TensorDef(U, S.O, output=True)): domain(D.o, D.k) - O[D.o] += cast(U, I[D.o * 2 + D.k]) + O[D.o] += TypeFn.cast(U, I[D.o * 2 + D.k]) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -90,12 +90,12 @@ std::vector operands; }; -struct ScalarSymbolicCast { +struct ScalarTypeFn { + std::string fnName; std::string typeVar; // NOTE: This must be of arity 1, but to break the self-referential cycle, // we use a heap allocated vector. std::vector operands; - bool isUnsignedCast; }; struct ScalarExpression { @@ -103,7 +103,7 @@ Optional constant; Optional index; Optional apply; - Optional symbolicCast; + Optional typeFn; }; struct ScalarAssign { @@ -142,7 +142,8 @@ /// Top-level type containing op metadata and one of a concrete op type. /// Currently, the only defined op type is `structured_op` (maps to /// `LinalgStructuredOpConfig`). -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, LinalgOpConfig &info) { io.mapOptional("metadata", info.metadata); io.mapOptional("structured_op", info.structuredOp); @@ -155,7 +156,8 @@ /// - List of indexing maps (see `LinalgIndexingMaps`). /// - Iterator types (see `LinalgIteratorTypeDef`). /// - List of scalar level assignment (see `ScalarAssign`). -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, LinalgStructuredOpConfig &info) { io.mapRequired("args", info.args); io.mapRequired("indexing_maps", info.indexingMaps); @@ -178,7 +180,8 @@ /// attribute symbols. During op creation these symbols are replaced by the /// corresponding `name` attribute values. Only attribute arguments have /// an `attribute_map`. -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, LinalgOperandDef &info) { io.mapRequired("name", info.name); io.mapRequired("usage", info.usage); @@ -189,7 +192,8 @@ }; /// Usage enum for a named argument. -template <> struct ScalarEnumerationTraits { +template <> +struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgOperandDefUsage &value) { io.enumCase(value, "InputOperand", LinalgOperandDefUsage::input); io.enumCase(value, "OutputOperand", LinalgOperandDefUsage::output); @@ -198,7 +202,8 @@ }; /// Iterator type enum. 
-template <> struct ScalarEnumerationTraits { +template <> +struct ScalarEnumerationTraits { static void enumeration(IO &io, LinalgIteratorTypeDef &value) { io.enumCase(value, "parallel", LinalgIteratorTypeDef::parallel); io.enumCase(value, "reduction", LinalgIteratorTypeDef::reduction); @@ -206,7 +211,8 @@ }; /// Metadata about the op (name, C++ name, and documentation). -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, LinalgOpMetadata &info) { io.mapRequired("name", info.name); io.mapRequired("cpp_class_name", info.cppClassName); @@ -220,7 +226,8 @@ /// some symbols that bind to attributes of the op. Each indexing map must /// be normalized over the same list of dimensions, and its symbols must /// match the symbols for argument shapes. -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, LinalgIndexingMapsConfig &info) { io.mapOptional("static_indexing_maps", info.staticIndexingMaps); } @@ -230,7 +237,8 @@ /// - The `arg` name must match a named output. /// - The `value` is a scalar expression for computing the value to /// assign (see `ScalarExpression`). -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, ScalarAssign &info) { io.mapRequired("arg", info.arg); io.mapRequired("value", info.value); @@ -241,14 +249,15 @@ /// - `scalar_arg`: Name of an argument to the op. /// - `scalar_apply`: Result of evaluating a named function (see /// `ScalarApply`). -/// - `symbolic_cast`: Cast to a symbolic TypeVar bound elsewhere. -template <> struct MappingTraits { +/// - `type_fn`: A named type conversion function (see `ScalarTypeFn`). +template <> +struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); io.mapOptional("scalar_const", info.constant); io.mapOptional("scalar_index", info.index); io.mapOptional("scalar_apply", info.apply); - io.mapOptional("symbolic_cast", info.symbolicCast); + io.mapOptional("type_fn", info.typeFn); } }; @@ -257,24 +266,27 @@ /// functions include: /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` -template <> struct MappingTraits { +template <> +struct MappingTraits { static void mapping(IO &io, ScalarApply &info) { io.mapRequired("fn_name", info.fnName); io.mapRequired("operands", info.operands); } }; -template <> struct MappingTraits { - static void mapping(IO &io, ScalarSymbolicCast &info) { +template <> +struct MappingTraits { + static void mapping(IO &io, ScalarTypeFn &info) { + io.mapRequired("fn_name", info.fnName); io.mapRequired("type_var", info.typeVar); io.mapRequired("operands", info.operands); - io.mapRequired("is_unsigned_cast", info.isUnsignedCast); } }; /// Helper mapping which accesses an AffineMapAttr as a serialized string of /// the same. -template <> struct ScalarTraits { +template <> +struct ScalarTraits { static void output(const SerializedAffineMap &value, void *rawYamlContext, raw_ostream &out) { assert(value.affineMapAttr); @@ -950,33 +962,33 @@ interleaveToString(operandCppValues, ", "))); return cppIdent; } - if (expression.symbolicCast) { + if (expression.typeFn) { // Symbolic cast. // Operands must be arity 1. 
- if (expression.symbolicCast->operands.size() != 1) { + if (expression.typeFn->operands.size() != 1) { emitError(genContext.getLoc()) - << "symbolic_cast operand arity must be 1"; + << "type conversion operand arity must be 1"; return None; } Optional operandCppValue = - generateExpression(expression.symbolicCast->operands[0]); + generateExpression(expression.typeFn->operands[0]); if (!operandCppValue) return None; Optional typeCppValue = - findTypeValue(expression.symbolicCast->typeVar, args); + findTypeValue(expression.typeFn->typeVar, args); if (!typeCppValue) { emitError(genContext.getLoc()) - << "type variable " << expression.symbolicCast->typeVar - << ", used in a symbolic cast must map to a predefined or " + << "type variable " << expression.typeFn->typeVar + << ", used in a type conversion, must map to a predefined or " << "an argument type but it does not"; return None; } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back( - llvm::formatv("Value {0} = helper.cast({1}, {2}, {3});", cppIdent, - typeCppValue.getValue(), *operandCppValue, - expression.symbolicCast->isUnsignedCast)); + llvm::formatv("Value {0} = helper.typefn__{1}({2}, {3});", + cppIdent, expression.typeFn->fnName, + typeCppValue.getValue(), *operandCppValue)); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type";
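
For readers skimming the patch, here is a minimal sketch (not part of the change itself) of how an op definition reads against the renamed type conversion functions. The op name `matmul_mixed_cast_demo` is hypothetical and only for illustration; the wildcard import mirrors the style of the opdsl tests touched above.

```python
# Illustrative only: a toy structured op written against the updated DSL, where
# casts go through the TypeFn namespace instead of the removed free-standing
# `cast`/`cast_unsigned` expressions.
from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def matmul_mixed_cast_demo(
    A=TensorDef(T, S.M, S.K),
    B=TensorDef(T, S.K, S.N),
    C=TensorDef(U, S.M, S.N, output=True)):
  domain(D.m, D.n, D.k)
  # TypeFn.cast treats signless integers as signed (e.g. i32 -> i64 lowers to
  # arith.ExtSIOp); TypeFn.cast_unsigned treats them as unsigned (arith.ExtUIOp).
  C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast_unsigned(
      U, B[D.k, D.n])
```

In the generated YAML, each such cast serializes to the `type_fn` node shown throughout the hunks above, carrying the `fn_name` (`cast` or `cast_unsigned`) and the target `type_var` in place of the former `symbolic_cast`/`is_unsigned_cast` pair.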