diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md
--- a/mlir/docs/Dialects/Linalg/OpDSL.md
+++ b/mlir/docs/Dialects/Linalg/OpDSL.md
@@ -194,6 +194,9 @@
 * `ReduceFn.add` (also overloading the inplace `+=` on a LHS)
 * `ReduceFn.mul`
 * `ReduceFn.max`
+* `ReduceFn.min`
+* `ReduceFn.max_unsigned`
+* `ReduceFn.min_unsigned`
 
 Type functions cast the `operand` to the target type `TypeVar` applying signed
 or unsigned semantics for integer operands:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -43,8 +43,8 @@
       if isinstance(expr, TensorUse):
         for ind in expr.indices:
           ind.visit_affine_exprs(visit_dim_def)
-      if isinstance(expr, ReduceApply):
-        for ind in expr.reduce.reduce_dims:
+      if isinstance(expr, TensorReduceFn):
+        for ind in expr.reduce_use.reduce_dims:
           ind.visit_affine_exprs(visit_dim_def)
 
     self.visit_tensor_exprs(visit_affine_exprs)
@@ -114,8 +114,8 @@
     assert name is not None, "TensorDef not attached"
     return name
 
-  def __iadd__(self, rhs: TensorExpression) -> TensorExpression:
-    return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs)
+  def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn":
+    return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs)
 
   def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]:
     """For implicit reductions, computes default reduction dims.
@@ -285,7 +285,7 @@
 
     # Find the lhs to reduction rhs.
     for assign, value in bindings:
-      if isinstance(value, ReduceApply):
+      if isinstance(value, TensorReduceFn):
         if value.lhs:
           raise ValueError(f"Reduction expression already assigns: {value}")
         value.lhs = assign
@@ -297,8 +297,8 @@
     """Gets the reduction dims for the comprehension or None."""
     result = set()
     for use in self.values:
-      if isinstance(use, ReduceApply):
-        result.add(use.reduce.reduce_dims)
+      if isinstance(use, TensorReduceFn):
+        result.add(use.reduce_use.reduce_dims)
       else:
         result.add(tuple())
     return result
@@ -342,10 +342,6 @@
   def __call__(self, *args) -> "TensorArithFn":
     return TensorArithFn(self, args)
 
-  def reduce(self, *reduce_dims: DimDef):
-    """Shortcut to create a Reduce operation from this function."""
-    return ReduceFnType(self, *reduce_dims)
-
   def __repr__(self):
     return f"{self.fn_name}"
 
@@ -362,33 +358,48 @@
   min = ArithFnType("min")
   max_unsigned = ArithFnType("max_unsigned")
   min_unsigned = ArithFnType("min_unsigned")
 
 
-class ReduceFnType:
-  """A reduction operator that reduces into its LHS from its RHS."""
-
-  def __init__(self, operator: ArithFnType, *reduce_dims: DimDef):
-    """Initializes the ReduceFn with an airthmetic function and dims."""
-    if not isinstance(operator, ArithFnType):
-      raise ValueError(f"Reduce expected a ArithFnType but got {operator}")
-    self.operator = operator
-    self.reduce_dims = tuple(reduce_dims)
+class ReduceFnUse:
+  """A reduction function assigned to reduction dimensions."""
+
+  def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef):
+    self.arith_fn = arith_fn
+    self.reduce_dims = reduce_dims
 
   def __call__(self, *args: TensorExpression):
-    return ReduceApply(self, args)
+    return TensorReduceFn(self, args)
 
   def __repr__(self):
-    return (f"reduce_{self.operator.fn_name}"
+    return (f"reduce_{self.arith_fn.fn_name}"
             f"({', '.join(repr(d) for d in self.reduce_dims)})")
 
 
+class ReduceFnType:
+  """Reduction operations.
+
+  All reduction operations reduce their RHS into their LHS.
+  """
+
+  def __init__(self, arith_fn: ArithFnType):
+    if not isinstance(arith_fn, ArithFnType):
+      raise ValueError(f"Reduce expected an ArithFnType but got {arith_fn}")
+    self.arith_fn = arith_fn
+
+  def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
+    return ReduceFnUse(self.arith_fn, *reduce_dims)
+
+  def __repr__(self):
+    return f"reduce_{self.arith_fn.fn_name}"
+
+
 class ReduceFn:
-  add = ArithFn.add.reduce
-  mul = ArithFn.mul.reduce
-  max = ArithFn.max.reduce
-  min = ArithFn.min.reduce
-  max_unsigned = ArithFn.max_unsigned.reduce
-  min_unsigned = ArithFn.min_unsigned.reduce
+  add = ReduceFnType(ArithFn.add)
+  mul = ReduceFnType(ArithFn.mul)
+  max = ReduceFnType(ArithFn.max)
+  min = ReduceFnType(ArithFn.min)
+  max_unsigned = ReduceFnType(ArithFn.max_unsigned)
+  min_unsigned = ReduceFnType(ArithFn.min_unsigned)
 
 
 class TensorArithFn(TensorExpression):
@@ -472,31 +483,31 @@
     return f"index({repr(self.dim)})"
 
 
-class ReduceApply(TensorExpression):
-  """Application of a reduction.
+class TensorReduceFn(TensorExpression):
+  """Application of a reduction operation.
 
-  This captures the lhs separately (initial value) separately from the rhs.
+  This captures the lhs (initial value) separately from the rhs.
   """
 
-  def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]):
-    self.reduce = reduce
+  def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]):
+    self.reduce_use = reduce_use
     self.lhs = None  # type: Optional[TensorUse]
     self.args = tuple(args)
 
   def to_scalar_expression(self) -> ScalarExpression:
     if self.lhs is None:
-      raise ValueError(f"Cannot scalarize a ReduceApply that has not been "
+      raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been "
                        f"bound to its lhs: {self}")
     full_args = [self.lhs.to_scalar_expression()
                 ] + [arg.to_scalar_expression() for arg in self.args]
-    return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr()
+    return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr()
 
   def visit_tensor_exprs(self, callback):
     for arg in self.args:
       arg.visit_tensor_exprs(callback)
 
   def __repr__(self):
-    return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})"
+    return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})"
 
 
 class OpInterfaceDef:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -479,7 +479,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
@@ -499,7 +499,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
@@ -519,7 +519,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw)
-  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW,]))
@@ -540,7 +540,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
@@ -560,7 +560,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
      TypeFn.cast_unsigned(
          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
               D.c]))
@@ -600,7 +600,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
                D.ow * S.SW + D.kw * S.DW, D.c]))
@@ -621,7 +621,7 @@
   """
   implements(ConvolutionOpInterface)
   domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c)
-  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)(
+  O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw](
       TypeFn.cast(
          U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH,
               D.ow * S.SW + D.kw * S.DW, D.c]))
diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
--- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
+++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py
@@ -19,7 +19,7 @@
                    strides=IndexAttrDef(S.SH, S.SW),
                    dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
@@ -32,7 +32,7 @@
                    strides=IndexAttrDef(S.SH, S.SW),
                    dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
@@ -45,7 +45,7 @@
                    strides=IndexAttrDef(S.SH, S.SW),
                    dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw](
       TypeFn.cast(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
@@ -58,7 +58,7 @@
                    strides=IndexAttrDef(S.SH, S.SW),
                    dilations=IndexAttrDef(S.DH, S.DW)):
   domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
-  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](
       TypeFn.cast_unsigned(
           U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                D.c]))
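
The patch replaces the call-based reduction spelling `ReduceFn.max(D.kh, D.kw)(...)` with the subscript spelling `ReduceFn.max[D.kh, D.kw](...)`: a `ReduceFnType` now binds only the arithmetic function, and subscripting it with reduction dimensions yields a `ReduceFnUse` that is applied to the right-hand-side expression. For illustration, a minimal standalone OpDSL definition in the new style might look as follows; the op name `reduce_max_2d` and its shapes are hypothetical and not part of this patch:

```python
# Illustrative only: a standalone OpDSL definition exercising the new
# subscript-based reduction syntax; not part of the patch above.
from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def reduce_max_2d(
    I=TensorDef(T1, S.N, S.H, S.W),
    O=TensorDef(U, S.N, output=True)):
  """Computes the maximum over the H and W dimensions of each batch entry.

  ReduceFn.max[D.h, D.w] yields a ReduceFnUse bound to the reduction
  dimensions D.h and D.w; applying it to the cast operand produces the
  TensorReduceFn expression assigned to O.
  """
  domain(D.n, D.h, D.w)
  O[D.n] = ReduceFn.max[D.h, D.w](TypeFn.cast(U, I[D.n, D.h, D.w]))
```

Note that subscripting with a single dimension, as in `ReduceFn.max[D.k]`, passes a bare `DimDef` rather than a tuple to `__getitem__`; every use in this patch binds at least two dimensions, which Python packs into a tuple before the `*reduce_dims` expansion in `ReduceFnUse`.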