diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -177,14 +177,17 @@ Reduction dimensions are inferred to be any dimensions on the RHS that are not on the LHS. -A number of arithmetic primitive functions are supported: - -* `PrimFn.add(a, b)` (also via overloading the binary `+` operator) -* `PrimFn.exp(a)` -* `PrimFn.log(a)` -* `PrimFn.mul(a, b)` (also via overloading the binary `*` operator) -* `PrimFn.max(a, b)` -* `PrimFn.sub(a, b)` (also via overloading the binary `-` operator) +A number of arithmetic functions are supported: + +* `ArithFn.add(a, b)` (also via overloading the binary `+` operator) +* `ArithFn.exp(a)` +* `ArithFn.log(a)` +* `ArithFn.mul(a, b)` (also via overloading the binary `*` operator) +* `ArithFn.max(a, b)` +* `ArithFn.min(a, b)` +* `ArithFn.sub(a, b)` (also via overloading the binary `-` operator) +* `ArithFn.max_unsigned(a, b)` +* `ArithFn.min_unsigned(a, b)` Reduction functions can appear as the outer-most function on the RHS: diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -41,13 +41,13 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -105,13 +105,13 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -179,17 +179,17 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - 
!ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -207,7 +207,7 @@ - !ScalarExpression scalar_arg: AZp - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -276,13 +276,13 @@ - !ScalarAssign arg: accum value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: accum - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -341,13 +341,13 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -416,17 +416,17 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -444,7 +444,7 @@ - !ScalarExpression scalar_arg: AZp - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -501,13 +501,13 @@ - !ScalarAssign arg: x value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: x - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -564,13 +564,13 @@ - !ScalarAssign arg: x value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: x - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -628,13 +628,13 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: 
fn_name: mul operands: - !ScalarExpression @@ -690,13 +690,13 @@ - !ScalarAssign arg: C value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: C - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -753,13 +753,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -818,13 +818,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -886,13 +886,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -964,13 +964,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1054,13 +1054,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1157,17 +1157,17 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -1185,7 +1185,7 @@ - !ScalarExpression scalar_arg: IZp - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ 
-1269,13 +1269,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1359,13 +1359,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1436,13 +1436,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1519,13 +1519,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1613,17 +1613,17 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -1641,7 +1641,7 @@ - !ScalarExpression scalar_arg: IZp - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -1721,13 +1721,13 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression @@ -1819,17 +1819,17 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression scalar_arg: O - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub 
operands: - !ScalarExpression @@ -1847,7 +1847,7 @@ - !ScalarExpression scalar_arg: IZp - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -1923,7 +1923,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression @@ -1994,7 +1994,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: max operands: - !ScalarExpression @@ -2065,7 +2065,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: max_unsigned operands: - !ScalarExpression @@ -2136,7 +2136,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: max operands: - !ScalarExpression @@ -2207,7 +2207,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: min operands: - !ScalarExpression @@ -2278,7 +2278,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: min_unsigned operands: - !ScalarExpression @@ -2355,7 +2355,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression @@ -2432,7 +2432,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: max operands: - !ScalarExpression @@ -2509,7 +2509,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: min operands: - !ScalarExpression @@ -2572,15 +2572,15 @@ type_var: T operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression @@ -2596,15 +2596,15 @@ type_var: F64 operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add 
operands: - !ScalarExpression @@ -2615,15 +2615,15 @@ - !ScalarExpression scalar_index: 1 - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression @@ -2664,11 +2664,11 @@ - !ScalarExpression scalar_const: '12345 : i64' - !ScalarExpression - scalar_apply: + arith_fn: fn_name: mul operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: sub operands: - !ScalarExpression @@ -2716,11 +2716,11 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: log operands: - !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression @@ -2731,7 +2731,7 @@ - !ScalarExpression scalar_const: '1.000000e+00 : f64' - !ScalarExpression - scalar_apply: + arith_fn: fn_name: exp operands: - !ScalarExpression diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -159,11 +159,11 @@ // TODO: Move this to a utility library. // The public methods on this class are referenced directly from generated code // and bind by name to math and type conversion functions in the DSL as: -// `applyfn__{fnName}` +// `arithfn__{fnName}` // `typefn__{fnName}` // Examples: -// `applyfn__add` -// `applyfn__mul` +// `arithfn__add` +// `arithfn__mul` // `typefn__cast` // The naming convention is intentional in order to match snake-cased DSL names. // See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. 
@@ -249,7 +249,7 @@ return cast(toType, operand, true); } - Value applyfn__add(Value lhs, Value rhs) { + Value arithfn__add(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -258,21 +258,21 @@ llvm_unreachable("unsupported non numeric type"); } - Value applyfn__exp(Value x) { + Value arithfn__exp(Value x) { OpBuilder builder = getBuilder(); if (isFloatingPoint(x)) return builder.create(x.getLoc(), x); llvm_unreachable("unsupported non numeric type"); } - Value applyfn__log(Value x) { + Value arithfn__log(Value x) { OpBuilder builder = getBuilder(); if (isFloatingPoint(x)) return builder.create(x.getLoc(), x); llvm_unreachable("unsupported non numeric type"); } - Value applyfn__sub(Value lhs, Value rhs) { + Value arithfn__sub(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -281,7 +281,7 @@ llvm_unreachable("unsupported non numeric type"); } - Value applyfn__mul(Value lhs, Value rhs) { + Value arithfn__mul(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -290,7 +290,7 @@ llvm_unreachable("unsupported non numeric type"); } - Value applyfn__max(Value lhs, Value rhs) { + Value arithfn__max(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -299,7 +299,7 @@ llvm_unreachable("unsupported non numeric type"); } - Value applyfn__max_unsigned(Value lhs, Value rhs) { + Value arithfn__max_unsigned(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); @@ -308,7 +308,7 @@ llvm_unreachable("unsupported non numeric type"); } - Value applyfn__min(Value lhs, Value rhs) { + Value arithfn__min(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) 
return builder.create(lhs.getLoc(), lhs, rhs); @@ -317,7 +317,7 @@ llvm_unreachable("unsupported non numeric type"); } - Value applyfn__min_unsigned(Value lhs, Value rhs) { + Value arithfn__min_unsigned(Value lhs, Value rhs) { OpBuilder builder = getBuilder(); if (isFloatingPoint(lhs)) return builder.create(lhs.getLoc(), lhs, rhs); diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -77,13 +77,13 @@ self.visit_tensor_exprs(visit_scalar_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return PrimFn.add(self, rhs) + return ArithFn.add(self, rhs) def __mul__(self, rhs) -> "TensorExpression": - return PrimFn.mul(self, rhs) + return ArithFn.mul(self, rhs) def __sub__(self, rhs) -> "TensorExpression": - return PrimFn.sub(self, rhs) + return ArithFn.sub(self, rhs) def __hash__(self): return hash(id(self)) @@ -333,42 +333,42 @@ cast_unsigned = TypeFnType("cast_unsigned") -class PrimFnType: - """Primitive operations.""" +class ArithFnType: + """Arithmetic operations.""" - def __init__(self, prim_name: str): - self.prim_name = prim_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, *args): - return PrimApply(self, args) + def __call__(self, *args) -> "TensorArithFn": + return TensorArithFn(self, args) def reduce(self, *reduce_dims: DimDef): - """Shortcut to create a Reduce operation from this primitive.""" + """Shortcut to create a Reduce operation from this function.""" return ReduceFnType(self, *reduce_dims) def __repr__(self): - return f"{self.prim_name}" + return f"{self.fn_name}" -class PrimFn: - add = PrimFnType("add") - exp = PrimFnType("exp") - log = PrimFnType("log") - mul = PrimFnType("mul") - max = PrimFnType("max") - min = PrimFnType("min") - sub = PrimFnType("sub") - max_unsigned = 
PrimFnType("max_unsigned") - min_unsigned = PrimFnType("min_unsigned") +class ArithFn: + add = ArithFnType("add") + exp = ArithFnType("exp") + log = ArithFnType("log") + mul = ArithFnType("mul") + max = ArithFnType("max") + min = ArithFnType("min") + sub = ArithFnType("sub") + max_unsigned = ArithFnType("max_unsigned") + min_unsigned = ArithFnType("min_unsigned") class ReduceFnType: """A reduction operator that reduces into its LHS from its RHS.""" - def __init__(self, operator: PrimFnType, *reduce_dims: DimDef): - """Initializes the ReduceFn with a primitive function and dims.""" - if not isinstance(operator, PrimFnType): - raise ValueError(f"Reduce expected a Prim operator but got {operator}") + def __init__(self, operator: ArithFnType, *reduce_dims: DimDef): + """Initializes the ReduceFn with an arithmetic function and dims.""" + if not isinstance(operator, ArithFnType): + raise ValueError(f"Reduce expected an ArithFnType but got {operator}") self.operator = operator self.reduce_dims = tuple(reduce_dims) @@ -376,28 +376,28 @@ return ReduceApply(self, args) def __repr__(self): - return (f"reduce_{self.operator.prim_name}" + return (f"reduce_{self.operator.fn_name}" f"({', '.join(repr(d) for d in self.reduce_dims)})") class ReduceFn: - add = PrimFn.add.reduce - mul = PrimFn.mul.reduce - max = PrimFn.max.reduce - min = PrimFn.min.reduce - max_unsigned = PrimFn.max_unsigned.reduce - min_unsigned = PrimFn.min_unsigned.reduce + add = ArithFn.add.reduce + mul = ArithFn.mul.reduce + max = ArithFn.max.reduce + min = ArithFn.min.reduce + max_unsigned = ArithFn.max_unsigned.reduce + min_unsigned = ArithFn.min_unsigned.reduce -class PrimApply(TensorExpression): - """Application of a primitive.""" +class TensorArithFn(TensorExpression): + """Application of an arithmetic function.""" - def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]): - self.prim = prim + def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]): + self.arith_fn = 
arith_fn self.args = tuple(args) def to_scalar_expression(self) -> ScalarExpression: - return ScalarApplyFn(self.prim.prim_name, + return ScalarArithFn(self.arith_fn.fn_name, *[arg.to_scalar_expression() for arg in self.args ]).expr() @@ -407,7 +407,7 @@ arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" + return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" class TensorTypeFn(TensorExpression): @@ -489,7 +489,7 @@ f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr() + return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr() def visit_tensor_exprs(self, callback): for arg in self.args: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -223,10 +223,10 @@ dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.scalar_apply: - fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}") + elif expr.arith_fn: + fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}") operand_values = [ - self.expression(operand) for operand in expr.scalar_apply.operands + self.expression(operand) for operand in expr.arith_fn.operands ] return fn(*operand_values) elif expr.type_fn: @@ -312,59 +312,59 @@ def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, True) - def _eval_add(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or 
_is_index_type(lhs.type): return arith.AddIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") - def _eval_exp(self, x: Value) -> Value: + def _arithfn_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") - def _eval_log(self, x: Value) -> Value: + def _arithfn_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") - def _eval_sub(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") - def _eval_mul(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") - def _eval_max(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") - def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") - def _eval_min(self, lhs: 
Value, rhs: Value) -> Value: + def _arithfn_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") - def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -20,7 +20,7 @@ __all__ = [ "ScalarAssign", - "ScalarApplyFn", + "ScalarArithFn", "ScalarTypeFn", "ScalarArg", "ScalarConst", @@ -29,18 +29,18 @@ ] -class ScalarApplyFn: - """A type of ScalarExpression that applies a named function to operands.""" +class ScalarArithFn: + """A type of ScalarExpression that applies an arithmetic function.""" def __init__(self, fn_name: str, *operands: "ScalarExpression"): self.fn_name = fn_name self.operands = operands def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_apply=self) + return ScalarExpression(arith_fn=self) def __repr__(self): - return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})" + return f"ScalarArithFn<{self.fn_name}>({', '.join(repr(o) for o in self.operands)})" class ScalarTypeFn: @@ -102,7 +102,7 @@ """An expression on scalar values.
Can be one of: - - ScalarApplyFn + - ScalarArithFn - ScalarTypeFn - ScalarArg - ScalarConst @@ -112,27 +112,27 @@ yaml_tag = "!ScalarExpression" def __init__(self, - scalar_apply: Optional[ScalarApplyFn] = None, + arith_fn: Optional[ScalarArithFn] = None, type_fn: Optional[ScalarTypeFn] = None, scalar_arg: Optional[ScalarArg] = None, scalar_const: Optional[ScalarConst] = None, scalar_index: Optional[ScalarIndex] = None): - if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) + - bool(scalar_const) + bool(scalar_index)) != 1: - raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', " + if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) + + bool(scalar_index)) != 1: + raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', " "'scalar_const', 'scalar_index', must be specified") - self.scalar_apply = scalar_apply + self.arith_fn = arith_fn self.type_fn = type_fn self.scalar_arg = scalar_arg self.scalar_const = scalar_const self.scalar_index = scalar_index def to_yaml_custom_dict(self): - if self.scalar_apply: + if self.arith_fn: return dict( - scalar_apply=dict( - fn_name=self.scalar_apply.fn_name, - operands=list(self.scalar_apply.operands), + arith_fn=dict( + fn_name=self.arith_fn.fn_name, + operands=list(self.arith_fn.operands), )) if self.type_fn: # Note that even though operands must be arity 1, we write it the diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -665,4 +665,4 @@ """ domain(D.m, D.n) O[D.m, D.n] = \ - PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n]))) + ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n]))) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml 
b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -34,7 +34,7 @@ - !ScalarAssign arg: O value: !ScalarExpression - scalar_apply: + arith_fn: fn_name: add operands: - !ScalarExpression @@ -89,7 +89,7 @@ # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.typefn__cast(block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); # IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.typefn__cast_unsigned(block.getArgument(0).getType(), [[VAL2]]); -# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.applyfn__add([[VAL1]], [[VAL3]]); +# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.arithfn__add([[VAL1]], [[VAL3]]); # @linalg_structured_op diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -9,10 +9,10 @@ # CHECK: - # CHECK: arg: C # CHECK: value: -# CHECK: scalar_apply: +# CHECK: arith_fn: # CHECK: fn_name: add # CHECK: operands: -# CHECK: scalar_apply: +# CHECK: arith_fn: # CHECK: fn_name: mul # CHECK: operands: # CHECK: type_fn: @@ -36,10 +36,10 @@ # CHECK: assignments: # CHECK: - # CHECK: arg: O -# CHECK: scalar_apply: +# CHECK: arith_fn: # CHECK: fn_name: sub # CHECK: operands: -# CHECK: scalar_apply: +# CHECK: arith_fn: # CHECK: fn_name: add # CHECK: operands: # CHECK: type_fn: @@ -67,7 +67,7 @@ # CHECK: assignments: # CHECK: - # CHECK: arg: O -# CHECK: scalar_apply: +# CHECK: arith_fn: # CHECK: fn_name: add # CHECK: operands: # CHECK: scalar_index: 1 diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_misc.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_misc.py @@ -35,8 +35,8 @@ 
@linalg_structured_op def soft_plus_poly( I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)): - O[D.m, D.n] = PrimFn.log( - TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, PrimFn.exp(I[D.m, D.n]))) + O[D.m, D.n] = ArithFn.log( + TypeFn.cast(U, const(1.0)) + TypeFn.cast(U, ArithFn.exp(I[D.m, D.n]))) @linalg_structured_op(op_name="custom_op_name") diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -83,7 +83,7 @@ struct ScalarExpression; -struct ScalarApply { +struct ScalarArithFn { std::string fnName; // NOTE: Must be pure heap allocated container (not SmallVector) // due to recursive data type. @@ -102,7 +102,7 @@ Optional arg; Optional constant; Optional index; - Optional apply; + Optional arithFn; Optional typeFn; }; @@ -238,16 +238,17 @@ }; /// A scalar expression (RHS of an assignment). Must be one of: -/// - `scalar_arg`: Name of an argument to the op. -/// - `scalar_apply`: Result of evaluating a named function (see -/// `ScalarApply`). +/// - `scalar_arg`: An operation argument. +/// - `scalar_const`: A constant definition. +/// - `scalar_index`: An iteration index. +/// - `arith_fn`: A named arithmetic function (see `ScalarArithFn`). /// - `type_fn`: A named type conversion function (see `ScalarTypeFn`). 
template <> struct MappingTraits { static void mapping(IO &io, ScalarExpression &info) { io.mapOptional("scalar_arg", info.arg); io.mapOptional("scalar_const", info.constant); io.mapOptional("scalar_index", info.index); - io.mapOptional("scalar_apply", info.apply); + io.mapOptional("arith_fn", info.arithFn); io.mapOptional("type_fn", info.typeFn); } }; @@ -257,8 +258,9 @@ /// functions include: /// - `add(lhs, rhs)` /// - `mul(lhs, rhs)` -template <> struct MappingTraits { - static void mapping(IO &io, ScalarApply &info) { +template <> +struct MappingTraits { + static void mapping(IO &io, ScalarArithFn &info) { io.mapRequired("fn_name", info.fnName); io.mapRequired("operands", info.operands); } @@ -934,11 +936,11 @@ cppIdent, *expression.index)); return cppIdent; } - if (expression.apply) { + if (expression.arithFn) { // Apply function. // Recursively generate operands. SmallVector operandCppValues; - for (ScalarExpression &operand : expression.apply->operands) { + for (ScalarExpression &operand : expression.arithFn->operands) { auto operandCppValue = generateExpression(operand); if (!operandCppValue) return None; @@ -946,8 +948,8 @@ } std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back( - llvm::formatv("Value {0} = helper.applyfn__{1}({2});", cppIdent, - expression.apply->fnName, + llvm::formatv("Value {0} = helper.arithfn__{1}({2});", cppIdent, + expression.arithFn->fnName, interleaveToString(operandCppValues, ", "))); return cppIdent; }