diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -178,17 +178,17 @@ Reduction dimensions are inferred to be any dimensions on the RHS that are not on the LHS. -A number of arithmetic functions are supported: - -* `ArithFn.add(a, b)` (also via overloading the binary `+` operator) -* `ArithFn.exp(a)` -* `ArithFn.log(a)` -* `ArithFn.mul(a, b)` (also via overloading the binary `*` operator) -* `ArithFn.max(a, b)` -* `ArithFn.min(a, b)` -* `ArithFn.sub(a, b)` (also via overloading the binary `-` operator) -* `ArithFn.max_unsigned(a, b)` -* `ArithFn.min_unsigned(a, b)` +A number of unary and binary arithmetic functions are supported: + +* `BinaryFn.add(a, b)` (also via overloading the binary `+` operator) +* `BinaryFn.mul(a, b)` (also via overloading the binary `*` operator) +* `BinaryFn.max(a, b)` +* `BinaryFn.min(a, b)` +* `BinaryFn.sub(a, b)` (also via overloading the binary `-` operator) +* `BinaryFn.max_unsigned(a, b)` +* `BinaryFn.min_unsigned(a, b)` +* `UnaryFn.exp(a)` +* `UnaryFn.log(a)` As the integer types are signless, signedness is implement by different functions that treat integers as signed or unsigned values. diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -46,14 +46,14 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -114,14 +114,14 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -192,19 +192,19 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -225,7 +225,7 @@ scalar_arg: AZp - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -297,14 +297,14 @@ arg: accum value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: accum - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -366,14 +366,14 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -445,19 +445,19 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -478,7 +478,7 @@ scalar_arg: AZp - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -538,14 +538,14 @@ arg: x value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: x - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -605,14 +605,14 @@ arg: x value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: x - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -673,14 +673,14 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -739,14 +739,14 @@ arg: C value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -806,14 +806,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -875,14 +875,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -947,14 +947,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1031,14 +1031,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1129,14 +1129,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1240,19 +1240,19 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -1273,7 +1273,7 @@ scalar_arg: IZp - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -1364,14 +1364,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1464,14 +1464,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1547,14 +1547,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1640,14 +1640,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1744,19 +1744,19 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -1777,7 +1777,7 @@ scalar_arg: IZp - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -1864,14 +1864,14 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression @@ -1970,19 +1970,19 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -2003,7 +2003,7 @@ scalar_arg: IZp - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -2088,7 +2088,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -2167,7 +2167,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: max operands: - !ScalarExpression @@ -2246,7 +2246,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: max_unsigned operands: - !ScalarExpression @@ -2325,7 +2325,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: max operands: - !ScalarExpression @@ -2404,7 +2404,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: min operands: - !ScalarExpression @@ -2483,7 +2483,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: min_unsigned operands: - !ScalarExpression @@ -2568,7 +2568,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -2653,7 +2653,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: max operands: - !ScalarExpression @@ -2738,7 +2738,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: min operands: - !ScalarExpression @@ -2841,17 +2841,17 @@ operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -2870,17 +2870,17 @@ operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -2893,17 +2893,17 @@ scalar_index: 1 - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -2950,12 +2950,12 @@ scalar_const: '12345 : i64' - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: mul operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: sub operands: - !ScalarExpression @@ -3005,12 +3005,12 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: unary fn_name: log operands: - !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -3023,7 +3023,7 @@ scalar_const: '1.000000e+00 : f64' - !ScalarExpression scalar_fn: - kind: arith + kind: unary fn_name: exp operands: - !ScalarExpression diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -146,13 +146,13 @@ // Region builder helper. // TODO: Move this to a utility library. // The public methods on this class are referenced directly from generated code -// and bind by name to math and type conversion functions in the DSL as: -// `arithfn__{fnName}` -// `typefn__{fnName}` +// and bind by name to math functions in the DSL as: +// `unary__{fnName}` +// `binary__{fnName}` // Examples: -// `arithfn__add` -// `arithfn__mul` -// `typefn__cast` +// `binary__add` +// `binary__mul` +// `unary__exp` // The naming convention is intentional in order to match snake-cased DSL names. // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. // @@ -240,7 +240,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__add(Value lhs, Value rhs) { + Value binary__add(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -250,7 +250,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__exp(Value x) { + Value unary__exp(Value x) { OpBuilder builder = getBuilder(); if (isFloatingPoint(x)) return builder.create(x.getLoc(), x); @@ -258,7 +258,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__log(Value x) { + Value unary__log(Value x) { OpBuilder builder = getBuilder(); if (isFloatingPoint(x)) return builder.create(x.getLoc(), x); @@ -266,7 +266,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__sub(Value lhs, Value rhs) { + Value binary__sub(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -276,7 +276,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__mul(Value lhs, Value rhs) { + Value binary__mul(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -286,7 +286,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__max(Value lhs, Value rhs) { + Value binary__max(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -296,7 +296,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__max_unsigned(Value lhs, Value rhs) { + Value binary__max_unsigned(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -306,7 +306,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__min(Value lhs, Value rhs) { + Value binary__min(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -316,7 +316,7 @@ } // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value arithfn__min_unsigned(Value lhs, Value rhs) { + Value binary__min_unsigned(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -77,13 +77,13 @@ self.visit_tensor_exprs(visit_scalar_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return ArithFn.add(self, rhs) + return BinaryFn.add(self, rhs) def __mul__(self, rhs) -> "TensorExpression": - return ArithFn.mul(self, rhs) + return BinaryFn.mul(self, rhs) def __sub__(self, rhs) -> "TensorExpression": - return ArithFn.sub(self, rhs) + return BinaryFn.sub(self, rhs) def __hash__(self): return hash(id(self)) @@ -126,7 +126,7 @@ return rhs_dims - lhs_dims def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) + return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): return (f"{self.operand_def.name}" @@ -183,8 +183,8 @@ f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None, - None, full_args).expr() + return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name, + None, None, full_args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): for arg in self.args: @@ -242,61 +242,54 @@ class FunctionKind(Enum): - ARITH = 0 - TYPE = 1 + UNARY = 0 + BINARY = 1 + TYPE = 2 -class TypeFnType: - """Type conversion function. +class UnaryFnType: + """Unary function. - A type conversion function takes a target type and a tensor expression and - returns the casted tensor expression. + A unary function takes one tensor expression and returns the + function evaluation result. """ def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + def __call__(self, exp: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp]) def __repr__(self): return f"{self.fn_name}" -class TypeFn: - """Type conversion function namespace. - - As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast`) or unsigned - (`cast_unsigned`) values. - - Examples: - - cast(I32 -> I64) -> `arith.ExtSIOp` - - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` - """ - cast = TypeFnType("cast") - cast_unsigned = TypeFnType("cast_unsigned") +class UnaryFn: + """Unary function namespace.""" + exp = UnaryFnType("exp") + log = UnaryFnType("log") -class ArithFnType: - """Arithmetic function. +class BinaryFnType: + """Binary function. - An arithmetic function takes one ore more tensor expressions and returns the + A binary function takes two tensor expressions and returns the function evaluation result. """ def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, *args) -> "TensorFn": - return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args) + def __call__(self, arg0: TensorExpression, + arg1: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) def __repr__(self): return f"{self.fn_name}" -class ArithFn: - """Arithmetic function namespace. +class BinaryFn: + """Binary function namespace. As the integer types are signless, signedness is implement by different functions that treat integers as signed or unsigned values. @@ -305,15 +298,45 @@ - max -> `arith.MaxSIOp` - max_unsinged -> `arith.MaxUIOp` """ - add = ArithFnType("add") - exp = ArithFnType("exp") - log = ArithFnType("log") - mul = ArithFnType("mul") - max = ArithFnType("max") - min = ArithFnType("min") - sub = ArithFnType("sub") - max_unsigned = ArithFnType("max_unsigned") - min_unsigned = ArithFnType("min_unsigned") + add = BinaryFnType("add") + mul = BinaryFnType("mul") + max = BinaryFnType("max") + min = BinaryFnType("min") + sub = BinaryFnType("sub") + max_unsigned = BinaryFnType("max_unsigned") + min_unsigned = BinaryFnType("min_unsigned") + + +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + cast = TypeFnType("cast") + cast_unsigned = TypeFnType("cast_unsigned") class ReduceFnUse: @@ -322,43 +345,43 @@ A reduction use specifies the reduction function and dimensions. """ - def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): - self.arith_fn = arith_fn + def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef): + self.binary_fn = binary_fn self.reduce_dims = reduce_dims - def __call__(self, *args: TensorExpression): + def __call__(self, *args: TensorExpression) -> "TensorReduceFn": return TensorReduceFn(self, args) def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}" + return (f"reduce_{self.binary_fn.fn_name}" f"({', '.join(repr(d) for d in self.reduce_dims)})") class ReduceFnType: """Reduction function. - An arithmetic function that reduces its RHS into its LHS. + A binary function that reduces its RHS into its LHS. """ - def __init__(self, arith_fn: ArithFnType): - if not isinstance(arith_fn, ArithFnType): - raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") - self.arith_fn = arith_fn + def __init__(self, binary_fn: BinaryFnType): + if not isinstance(binary_fn, BinaryFnType): + raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") + self.binary_fn = binary_fn def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.arith_fn, *reduce_dims) + return ReduceFnUse(self.binary_fn, *reduce_dims) def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}") + return (f"reduce_{self.binary_fn.fn_name}") class ReduceFn: - add = ReduceFnType(ArithFn.add) - mul = ReduceFnType(ArithFn.mul) - max = ReduceFnType(ArithFn.max) - min = ReduceFnType(ArithFn.min) - max_unsigned = ReduceFnType(ArithFn.max_unsigned) - min_unsigned = ReduceFnType(ArithFn.min_unsigned) + add = ReduceFnType(BinaryFn.add) + mul = ReduceFnType(BinaryFn.mul) + max = ReduceFnType(BinaryFn.max) + min = ReduceFnType(BinaryFn.min) + max_unsigned = ReduceFnType(BinaryFn.max_unsigned) + min_unsigned = ReduceFnType(BinaryFn.min_unsigned) ############################################################################### diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -270,17 +270,19 @@ dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH: - fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}") + elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE: + kind = expr.scalar_fn.kind.name.lower() + fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}") operand_values = [ self.expression(operand) for operand in expr.scalar_fn.operands ] return fn(*operand_values) - elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE: + elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE: + kind = expr.scalar_fn.kind.name.lower() fn_name = expr.scalar_fn.fn_name if expr.scalar_fn.attr_name: fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name] - fn = self._get_function(f"_typefn_{fn_name}") + fn = self._get_function(f"_{kind}_{fn_name}") operand_value = self.expression(expr.scalar_fn.operands[0]) return fn(expr.scalar_fn.type_var.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") @@ -356,65 +358,65 @@ raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def _typefn_cast(self, type_var_name: str, operand: Value) -> Value: + def _type_cast(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, False) - def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, True) - def _arithfn_add(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.AddFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'add' operand: {lhs}") - - def _arithfn_exp(self, x: Value) -> Value: + def _unary_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") - def _arithfn_log(self, x: Value) -> Value: + def _unary_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") - def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value: + def _binary_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.AddFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.AddIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operand: {lhs}") + + def _binary_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") - def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value: + def _binary_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") - def _arithfn_max(self, lhs: Value, rhs: Value) -> Value: + def _binary_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") - def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") - def _arithfn_min(self, lhs: Value, rhs: Value) -> Value: + def _binary_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") - def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -677,4 +677,4 @@ """ domain(D.m, D.n) O[D.m, D.n] = \ - ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n]))) + UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n]))) diff --git a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py --- a/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py +++ b/mlir/test/Integration/Dialect/SparseTensor/taco/tools/mlir_pytaco.py @@ -1505,12 +1505,12 @@ input_accesses.append(expr) -def _op_to_callable(op: _BinaryOp) -> lang.ArithFnType: +def _op_to_callable(op: _BinaryOp) -> lang.BinaryFnType: """Returns the linalg dialect function object for the given operation.""" op_to_callable = { - operator.add: lang.ArithFn.add, - operator.sub: lang.ArithFn.sub, - operator.mul: lang.ArithFn.mul, + operator.add: lang.BinaryFn.add, + operator.sub: lang.BinaryFn.sub, + operator.mul: lang.BinaryFn.mul, } return op_to_callable[op] diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -40,7 +40,7 @@ arg: O value: !ScalarExpression scalar_fn: - kind: arith + kind: binary fn_name: add operands: - !ScalarExpression @@ -111,7 +111,7 @@ # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); # IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]); -# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]); +# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]); # @linalg_structured_op diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -42,11 +42,11 @@ # CHECK: - # CHECK: arg: O # CHECK: scalar_fn: -# CHECK: kind: arith +# CHECK: kind: binary # CHECK: fn_name: sub # CHECK: operands: # CHECK: scalar_fn: -# CHECK: kind: arith +# CHECK: kind: binary # CHECK: fn_name: add # CHECK: operands: # CHECK: scalar_fn: @@ -80,7 +80,7 @@ # CHECK: - # CHECK: arg: O # CHECK: scalar_fn: -# CHECK: kind: arith +# CHECK: kind: binary # CHECK: fn_name: add # CHECK: operands: # CHECK: scalar_index: 1 diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py @@ -35,8 +35,8 @@ @linalg_structured_op def soft_plus_poly( I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)): - O[D.m, D.n] = ArithFn.log( - TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n]))) + O[D.m, D.n] = UnaryFn.log( + TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, UnaryFn.exp(I[D.m, D.n]))) @linalg_structured_op(op_name="custom_op_name") diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -90,7 +90,7 @@ struct ScalarExpression; -enum class ScalarFnKind { Arith, Type }; +enum class ScalarFnKind { Unary, Binary, Type }; struct ScalarFn { ScalarFnKind kind; @@ -275,7 +275,8 @@ template <> struct ScalarEnumerationTraits { static void enumeration(IO &io, ScalarFnKind &value) { - io.enumCase(value, "arith", ScalarFnKind::Arith); + io.enumCase(value, "unary", ScalarFnKind::Unary); + io.enumCase(value, "binary", ScalarFnKind::Binary); io.enumCase(value, "type", ScalarFnKind::Type); } }; @@ -1056,7 +1057,7 @@ return cppIdent; } if (expression.scalarFn && - expression.scalarFn->kind == ScalarFnKind::Arith) { + expression.scalarFn->kind != ScalarFnKind::Type) { // Apply function. // Recursively generate operands. SmallVector operandCppValues; @@ -1066,10 +1067,14 @@ return None; operandCppValues.push_back(*operandCppValue); } + + std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary + ? "unary" + : "binary"; std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back( - llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent, - expression.scalarFn->fnName, + llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent, + prefix, expression.scalarFn->fnName, interleaveToString(operandCppValues, ", "))); return cppIdent; }