diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -107,12 +107,12 @@ ## Index Attributes -Attributes are compile-time constant parameters only accessible in index +Index attributes are compile-time constant parameters only accessible in index expressions. They can be used to parameterize the access pattern of a structured operation, for example, by setting its strides. They cannot take part in the actual computation. -The following example demonstrates the use of attributes: +The following example demonstrates the use of index attributes: ```python @linalg_structured_op @@ -136,9 +136,9 @@ index expressions of the operation instance. If no strides are provided the `default` vector elements are used instead. -Attributes are currently limited to integer vectors and only accessible in index -expressions. An operation may have multiple attributes all of them placed at the -end of the parameter list after the output tensors. +Index attributes are currently limited to integer vectors and only accessible in +index expressions. An operation may have multiple attributes all of them placed +at the end of the parameter list after the output tensors. ## Shape-Only Tensors @@ -220,6 +220,43 @@ * `const(value)` returns a constant value. * `index(dim)` returns the iteration index in the given dimension `dim`. +## Function Attributes + +Function attributes are compile-time constant function parameters. They can be +used to parameterize the computation performed by a structured operation, for +example, to support signed and unsigned computation. + +The following example demonstrates the use of function attributes: + +```python +@linalg_structured_op +def elemwise_binary( + lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast)): + O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) +``` + +The `fun` and `cast` function attributes by default are aliases for their +default values `BinaryFn.add` and `TypeFn.cast`, respectively. When +instantiating the operation, the function attributes may be set to other +functions using optional named arguments: + +```python +elemwise_binary(lhs, rhs, outs=[out_tensor], + fun=BinaryFn.mul, cast=TypeFn.cast_unsigned) +``` + +In the example, the `fun` and `cast` arguments adapt the body of the operation +to implement multiplication and unsigned casts instead of addition and signed +casts. + +OpDSL supports unary, binary, and type conversion function attributes. An +operation can take multiple attributes of different kinds placed at the end of +the parameter list. + ## Types All types in assignment expressions are late bound based on actual input and diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -58,7 +58,26 @@ }]; } -// Define a TypeFn enum matching the OpDSL TypeFn class. +// Define the function attribute enums matching the OpDSL functions. 
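Editor's note on the documentation change above: to complement the `elemwise_binary` example, a minimal unary sketch (assuming the OpDSL prelude from `mlir.dialects.linalg.opdsl.lang`; `input_tensor` and `out_tensor` are placeholder values created elsewhere) shows a `UnaryFnAttrDef` being declared and then overridden at instantiation time:

```python
from mlir.dialects.linalg.opdsl.lang import *

@linalg_structured_op
def elemwise_unary(
    I=TensorDef(T1),
    O=TensorDef(U, output=True),
    fun=UnaryFnAttrDef(default=UnaryFn.exp),
    cast=TypeFnAttrDef(default=TypeFn.cast)):
  O[None] = fun(cast(U, I[None]))

# Instantiate with the defaults (exp and signed cast) ...
elemwise_unary(input_tensor, outs=[out_tensor])
# ... or override the unary function with an optional named argument.
elemwise_unary(input_tensor, outs=[out_tensor], fun=UnaryFn.log)
```

As with `elemwise_binary`, only the attribute changes; the indexing maps and iterator types of the operation stay the same.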
+def UnaryFn : I32EnumAttr<"UnaryFn", "", [ + I32EnumAttrCase<"exp", 0>, + I32EnumAttrCase<"log", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} +def BinaryFn : I32EnumAttr<"BinaryFn", "", [ + I32EnumAttrCase<"add", 0>, + I32EnumAttrCase<"mul", 1>, + I32EnumAttrCase<"max", 2>, + I32EnumAttrCase<"min", 3>, + I32EnumAttrCase<"sub", 4>, + I32EnumAttrCase<"max_unsigned", 5>, + I32EnumAttrCase<"min_unsigned", 6> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} def TypeFn : I32EnumAttr<"TypeFn", "", [ I32EnumAttrCase<"cast", 0>, I32EnumAttrCase<"cast_unsigned", 1> @@ -67,6 +86,12 @@ let cppNamespace = "::mlir::linalg"; } +def UnaryFnAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} +def BinaryFnAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} def TypeFnAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -1,6 +1,120 @@ ### AUTOGENERATED from core_named_ops.py ### To regenerate, run: bin/update_core_linalg_named_ops.sh --- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: elemwise_unary + cpp_class_name: ElemwiseUnaryOp + doc: |- + Applies the unary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: I + kind: input_tensor + type_var: T1 + shape_map: affine_map<() -> ()> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<() -> ()> + - !LinalgOperandDefConfig + name: fun + kind: unary_fn_attr + default_fn: exp + - !LinalgOperandDefConfig + name: cast + kind: type_fn_attr + default_fn: cast + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: unary + attr_name: fun + operands: + - !ScalarExpression + scalar_fn: + kind: type + attr_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: I +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: elemwise_binary + cpp_class_name: ElemwiseBinaryOp + doc: |- + Applies the binary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: lhs + kind: input_tensor + type_var: T1 + shape_map: affine_map<() -> ()> + - !LinalgOperandDefConfig + name: rhs + kind: input_tensor + type_var: T2 + shape_map: affine_map<() -> ()> + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<() -> ()> + - !LinalgOperandDefConfig + name: fun + kind: binary_fn_attr + default_fn: add + - !LinalgOperandDefConfig + name: cast + kind: type_fn_attr + default_fn: cast + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: binary + attr_name: fun + operands: + - !ScalarExpression + scalar_fn: + kind: type + attr_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: lhs + - !ScalarExpression + scalar_fn: + kind: type + attr_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: rhs +--- !LinalgOpConfig metadata: !LinalgOpMetadata name: matmul cpp_class_name: MatmulOp diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -146,17 +146,9 @@ //===----------------------------------------------------------------------===// // Region builder helper. // TODO: Move this to a utility library. -// The public methods on this class are referenced directly from generated code -// and bind by name to math functions in the DSL as: -// `unary__{fnName}` -// `binary__{fnName}` -// Examples: -// `binary__add` -// `binary__mul` -// `unary__exp` -// `unary__log` -// The naming convention is intentional in order to match snake-cased DSL names. -// See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. +// The public methods on this class are referenced directly from generated code. +// Helper build the unary, binary, and type conversion functions defined by the +// DSL. See mlir-linalg-ods-yaml-gen.cpp for the code that mates to this class. // // Implementations of the math functions must be polymorphic over numeric types, // internally performing necessary casts. If the function application makes no @@ -180,6 +172,104 @@ RegionBuilderHelper(MLIRContext *context, Block &block) : context(context), block(block) {} + // Build the unary functions defined by OpDSL. + Value buildUnaryFn(UnaryFn unaryFn, Value arg) { + if (!isFloatingPoint(arg)) + llvm_unreachable("unsupported non numeric type"); + OpBuilder builder = getBuilder(); + switch (unaryFn) { + case UnaryFn::exp: + return builder.create(arg.getLoc(), arg); + case UnaryFn::log: + return builder.create(arg.getLoc(), arg); + } + llvm_unreachable("unsupported unary function"); + } + + // Build the binary functions defined by OpDSL. 
+ Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) { + bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1); + bool allInteger = isInteger(arg0) && isInteger(arg1); + if (!allFloatingPoint && !allInteger) + llvm_unreachable("unsupported non numeric type"); + OpBuilder builder = getBuilder(); + switch (binaryFn) { + case BinaryFn::add: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::mul: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::max: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::min: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::sub: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::max_unsigned: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + case BinaryFn::min_unsigned: + if (allFloatingPoint) + return builder.create(arg0.getLoc(), arg0, arg1); + return builder.create(arg0.getLoc(), arg0, arg1); + } + llvm_unreachable("unsupported binary function"); + } + + // Build the type functions defined by OpDSL. + Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { + switch (typeFn) { + case TypeFn::cast: + return cast(toType, operand, false); + case TypeFn::cast_unsigned: + return cast(toType, operand, true); + } + llvm_unreachable("unsupported type conversion function"); + } + + void yieldOutputs(ValueRange values) { + assert(!values.empty() && "linalg ops must yield outputs"); + if (values.empty()) + return; + Value first = values.front(); + OpBuilder builder = getBuilder(); + builder.create(first.getLoc(), values); + } + + Value constant(const std::string &value) { + OpBuilder builder = getBuilder(); + Location loc = builder.getUnknownLoc(); + Attribute valueAttr = parseAttribute(value, builder.getContext()); + return builder.create(loc, valueAttr.getType(), + valueAttr); + } + + Value index(int64_t dim) { + OpBuilder builder = getBuilder(); + return builder.create(builder.getUnknownLoc(), dim); + } + + Type getIntegerType(unsigned width) { + return IntegerType::get(context, width); + } + + Type getFloat32Type() { return Float32Type::get(context); } + Type getFloat64Type() { return Float64Type::get(context); } + +private: + MLIRContext *context; + Block █ + // Generates operations to cast the given operand to a specified type. // If the cast cannot be performed, a warning will be issued and the // operand returned as-is (which will presumably yield a verification @@ -231,136 +321,6 @@ return operand; } - Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) { - switch (typeFn) { - case TypeFn::cast: - return cast(toType, operand, false); - case TypeFn::cast_unsigned: - return cast(toType, operand, true); - } - llvm_unreachable("unsupported type conversion function"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. 
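Note: the `buildUnaryFn`/`buildBinaryFn` helpers above dispatch on the element type, emitting `math.*`/`arith.*f` ops for floats and the `arith` integer variants otherwise. A hedged sketch of how this becomes observable from the Python bindings, reusing the `emit_generic` path exercised by the tests later in this patch (the function name `exp_add_on_buffers` is illustrative):

```python
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type, MemRefType
from mlir.dialects import builtin, linalg

with Context(), Location.unknown():
  module = Module.create()
  f32 = F32Type.get()
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(
        MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32))
    def exp_add_on_buffers(lhs, out):
      # For f32 operands the region builder selects math.exp and arith.addf;
      # integer operands would instead pick the arith integer variants.
      linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
      linalg.elemwise_binary(lhs, out, outs=[out], emit_generic=True)

  # Printing the module shows the linalg.generic bodies with the selected ops.
  print(module)
```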
- Value binary__add(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value unary__exp(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value unary__log(Value x) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(x)) - return builder.create(x.getLoc(), x); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value binary__sub(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value binary__mul(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value binary__max(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value binary__max_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. - Value binary__min(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - // NOLINTNEXTLINE(*-identifier-naming): externally called. 
- Value binary__min_unsigned(Value lhs, Value rhs) { - OpBuilder builder = getBuilder(); - if (isFloatingPoint(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - if (isInteger(lhs)) - return builder.create(lhs.getLoc(), lhs, rhs); - llvm_unreachable("unsupported non numeric type"); - } - - void yieldOutputs(ValueRange values) { - assert(!values.empty() && "linalg ops must yield outputs"); - if (values.empty()) - return; - Value first = values.front(); - OpBuilder builder = getBuilder(); - builder.create(first.getLoc(), values); - } - - Value constant(const std::string &value) { - OpBuilder builder = getBuilder(); - Location loc = builder.getUnknownLoc(); - Attribute valueAttr = parseAttribute(value, builder.getContext()); - return builder.create(loc, valueAttr.getType(), - valueAttr); - } - - Value index(int64_t dim) { - OpBuilder builder = getBuilder(); - return builder.create(builder.getUnknownLoc(), dim); - } - - Type getIntegerType(unsigned width) { - return IntegerType::get(context, width); - } - - Type getFloat32Type() { return Float32Type::get(context); } - - Type getFloat64Type() { return Float64Type::get(context); } - -private: - MLIRContext *context; - Block █ - bool isFloatingPoint(Value value) { return value.getType().isa(); } bool isInteger(Value value) { return value.getType().isa(); } diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -126,7 +126,7 @@ return rhs_dims - lhs_dims def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs) + return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): return (f"{self.operand_def.name}" @@ -183,8 +183,14 @@ f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name, - None, None, full_args).expr() + fn_name = None + attr_name = None + if self.reduce_use.binary_fn: + fn_name = self.reduce_use.binary_fn.fn_name + if self.reduce_use.binary_attr: + attr_name = self.reduce_use.binary_attr.operand_def.name + return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, + full_args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): for arg in self.args: @@ -257,8 +263,8 @@ def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, exp: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp]) + def __call__(self, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) def __repr__(self): return f"{self.fn_name}" @@ -345,16 +351,21 @@ A reduction use specifies the reduction function and dimensions. 
""" - def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef): + def __init__(self, binary_fn: Optional[BinaryFnType], + binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef): + if bool(binary_fn) + bool(binary_attr) != 1: + raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") self.binary_fn = binary_fn + self.binary_attr = binary_attr self.reduce_dims = reduce_dims def __call__(self, *args: TensorExpression) -> "TensorReduceFn": return TensorReduceFn(self, args) def __repr__(self): - return (f"reduce_{self.binary_fn.fn_name}" - f"({', '.join(repr(d) for d in self.reduce_dims)})") + fn = self.binary_fn if self.binary_fn else self.binary_attr + return ( + f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})") class ReduceFnType: @@ -369,10 +380,10 @@ self.binary_fn = binary_fn def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.binary_fn, *reduce_dims) + return ReduceFnUse(self.binary_fn, None, *reduce_dims) def __repr__(self): - return (f"reduce_{self.binary_fn.fn_name}") + return f"reduce_{repr(self.binary_fn)}" class ReduceFn: @@ -394,7 +405,9 @@ SCALAR = 1 OUTPUT_TENSOR = 2 INDEX_ATTR = 3 - TYPE_FN_ATTR = 4 + UNARY_FN_ATTR = 4 + BINARY_FN_ATTR = 5 + TYPE_FN_ATTR = 6 class OperandDef: @@ -441,6 +454,8 @@ def is_attribute(self) -> bool: return (self.kind == OperandKind.INDEX_ATTR or + self.kind == OperandKind.UNARY_FN_ATTR or + self.kind == OperandKind.BINARY_FN_ATTR or self.kind == OperandKind.TYPE_FN_ATTR) def __hash__(self): @@ -557,6 +572,49 @@ OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default) +class UnaryFnAttrDef: + """Unary function attribute definition. + + Unary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default unary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "UnaryFnType"): + if not isinstance(default, UnaryFnType): + raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name) + + def __call__(self, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) + + +class BinaryFnAttrDef: + """Binary function attribute definition. + + Binary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default binary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "BinaryFnType"): + if not isinstance(default, BinaryFnType): + raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name) + + def __call__(self, arg0: TensorExpression, + arg1: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, + [arg0, arg1]) + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) + + class TypeFnAttrDef: """Type conversion function attribute definition. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -309,8 +309,8 @@ def add_operand(self, operand_def: OperandDef): if operand_def in self.operands: return - if (operand_def.kind == OperandKind.SCALAR or - operand_def.kind == OperandKind.TYPE_FN_ATTR): + if not (operand_def.is_tensor() or + operand_def.kind == OperandKind.INDEX_ATTR): self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -130,7 +130,8 @@ for param_name, param in sig.parameters.items(): param_default = param.default if isinstance(param_default, - (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)): + (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef, + BinaryFnAttrDef, TypeFnAttrDef)): op_def.add_operand(param_name, param_default.operand_def) else: raise ValueError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -41,7 +41,7 @@ all_arg_defs = op_config.ordered_operands in_arg_defs = [ d for d in all_arg_defs - if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR + if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] ] out_arg_defs = [ d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR @@ -49,8 +49,11 @@ index_attr_arg_defs = [ d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR ] - type_fn_attr_arg_defs = [ - d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR + fn_attr_arg_defs = [ + d for d in all_arg_defs if d.kind in [ + OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR, + OperandKind.TYPE_FN_ATTR + ] ] # Verify outs is a sequence or a list of results. @@ -135,28 +138,38 @@ array = np.array(index_attr_vals, dtype=np.int64) index_attrs[index_attr.name] = DenseElementsAttr.get(array) - # Compute the type function attribute mapping. - type_fn_attr_mapping = {} - for type_fn_attr in type_fn_attr_arg_defs: - attr_val = type_fn_attr.operand_def.default_fn - if type_fn_attr.name in attrs: - type_fn = attrs.get(type_fn_attr.name) - if not isinstance(type_fn, TypeFnType): - raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type " - f"TypeFnType but got {type(attr_val)}") - attr_val = type_fn.fn_name - assert attr_val, "Type function attribute has no value" - type_fn_attr_mapping[type_fn_attr.name] = attr_val + # Compute the function attribute mapping. 
+ fn_attr_mapping = {} + for fn_attr in fn_attr_arg_defs: + attr_val = fn_attr.operand_def.default_fn + attr_kind = fn_attr.kind + if fn_attr.name in attrs: + fn = attrs.get(fn_attr.name) + if attr_kind == OperandKind.UNARY_FN_ATTR: + if not isinstance(fn, UnaryFnType): + raise ValueError(f"Attribute {fn_attr.name} needs to be of type " + f"UnaryFnType but got {type(attr_val)}") + elif attr_kind == OperandKind.BINARY_FN_ATTR: + if not isinstance(fn, BinaryFnType): + raise ValueError(f"Attribute {fn_attr.name} needs to be of type " + f"BinaryFnType but got {type(attr_val)}") + else: + if not isinstance(fn, TypeFnType): + raise ValueError(f"Attribute {fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}") + attr_val = fn.fn_name + assert attr_val, "Function attribute has no value" + fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, - type_fn_attr_mapping, block_arg_types) + fn_attr_mapping, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -193,7 +206,7 @@ block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): body_builder = _BodyBuilder(type_mapping, block_arg_mapping, - type_fn_attr_mapping) + fn_attr_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) @@ -208,7 +221,7 @@ op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -225,10 +238,12 @@ for name, value in index_attrs.items(): named_op.operation.attributes[name] = value - # Set the type function attributes. - for name, value in type_fn_attr_mapping.items(): + # Compute the function attributes by combining operand kind and function name. 
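Note: the loop that follows combines each function attribute's operand kind with its resolved function name to form the textual attribute parsed onto the named op. An illustrative standalone helper (not part of the patch) showing the same composition:

```python
def fn_attr_text(kind_name: str, fn_name: str) -> str:
  # "BINARY_FN_ATTR" -> "binary_fn", then e.g. "#linalg.binary_fn<mul>".
  assert kind_name.lower().endswith("_attr")
  enum_name = kind_name.lower()[:-len("_attr")]
  return f"#linalg.{enum_name}<{fn_name}>"

print(fn_attr_text("UNARY_FN_ATTR", "log"))   # #linalg.unary_fn<log>
print(fn_attr_text("BINARY_FN_ATTR", "mul"))  # #linalg.binary_fn<mul>
print(fn_attr_text("TYPE_FN_ATTR", "cast"))   # #linalg.type_fn<cast>
```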
+ for name, (fn_name, kind) in fn_attr_mapping.items(): + assert kind.name.lower().endswith("_attr") + enum_name = kind.name.lower()[:-5] named_op.operation.attributes[name] = Attribute.parse( - f"#linalg.type_fn<{value}>") + f"#linalg.{enum_name}<{fn_name}>") linalg.fill_builtin_region(named_op.operation) @@ -242,11 +257,11 @@ """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value], - type_fn_attr_mapping: Dict[str, str]): + block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str, + str]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping - self.type_fn_attr_mapping = type_fn_attr_mapping + self.fn_attr_mapping = fn_attr_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -270,21 +285,18 @@ dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE: + elif expr.scalar_fn: kind = expr.scalar_fn.kind.name.lower() - fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}") + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] + fn = self._get_function(f"_{kind}_{fn_name}") operand_values = [ self.expression(operand) for operand in expr.scalar_fn.operands ] + if expr.scalar_fn.kind == FunctionKind.TYPE: + operand_values = [expr.scalar_fn.type_var.name] + operand_values return fn(*operand_values) - elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE: - kind = expr.scalar_fn.kind.name.lower() - fn_name = expr.scalar_fn.fn_name - if expr.scalar_fn.attr_name: - fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name] - fn = self._get_function(f"_{kind}_{fn_name}") - operand_value = self.expression(expr.scalar_fn.operands[0]) - return fn(expr.scalar_fn.type_var.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") def yield_outputs(self, *output_names: str): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -6,6 +6,35 @@ Batch = S.Batch +@linalg_structured_op +def elemwise_unary( + I=TensorDef(T1), + O=TensorDef(U, output=True), + fun=UnaryFnAttrDef(default=UnaryFn.exp), + cast=TypeFnAttrDef(default=TypeFn.cast)): + """Applies the unary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, I[None])) + + +@linalg_structured_op +def elemwise_binary( + lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast)): + """Applies the binary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), diff --git a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir --- a/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir +++ b/mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir @@ -292,16 +292,48 @@ // ----- -func @generalize_soft_plus_2d_f32(%input: tensor<16x32xf32>, %output: tensor<16x32xf32>) -> tensor<16x32xf32> { - %0 = linalg.soft_plus_2d ins(%input: tensor<16x32xf32>) outs(%output: tensor<16x32xf32>) -> tensor<16x32xf32> - return %0: tensor<16x32xf32> +// Verifies the default value of the fun attribute is an exp op. +func @generalize_elemwise_exp(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = linalg.elemwise_unary ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32> + return %0: tensor<4x8xf32> } -// CHECK-LABEL: @generalize_soft_plus_2d_f32 -// CHECK: %[[C1:.+]] = arith.constant 1.000000e+00 : f32 -// CHECK: ^{{.*}}(%[[IN:.+]]: f32, %[[OUT:.+]]: f32 -// CHECK-NEXT: %[[EXP:.+]] = math.exp %[[IN]] : f32 -// CHECK-NEXT: %[[SUM:.+]] = arith.addf %[[EXP]], %[[C1]] : f32 -// CHECK-NEXT: %[[LOG:.+]] = math.log %[[SUM]] : f32 -// CHECK-NEXT: linalg.yield %[[LOG]] : f32 -// CHECK-NEXT: -> tensor<16x32xf32> +// CHECK-LABEL: @generalize_elemwise_exp +// CHECK: = math.exp + +// ----- + +// Verifies the fun attribute controls the unary function used. +func @generalize_elemwise_log(%lhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = linalg.elemwise_unary {fun = #linalg.unary_fn} + ins(%lhs: tensor<4x8xf32>) outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32> + return %0: tensor<4x8xf32> +} + +// CHECK-LABEL: @generalize_elemwise_log +// CHECK: = math.log + +// ----- + +// Verifies the default value of the fun attribute is an add op. +func @generalize_elemwise_add(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = linalg.elemwise_binary ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>) + outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32> + return %0: tensor<4x8xf32> +} + +// CHECK-LABEL: @generalize_elemwise_add +// CHECK: = arith.addf + +// ----- + +// Verifies the fun attribute controls the binary function used. 
+func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> { + %0 = linalg.elemwise_binary {fun = #linalg.binary_fn} + ins(%lhs, %rhs: tensor<4x8xf32>, tensor<4x8xf32>) + outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32> + return %0: tensor<4x8xf32> +} + +// CHECK-LABEL: @generalize_elemwise_mul +// CHECK: = arith.mulf diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -111,7 +111,7 @@ # IMPL-DAG: Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]); # IMPL-DAG: Value [[VAL2:[a-z0-9]+]] = helper.index(1); # IMPL-DAG: Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]); -# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.binary__add([[VAL1]], [[VAL3]]); +# IMPL-DAG: Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]); # @linalg_structured_op @@ -255,14 +255,15 @@ # IMPL-NEXT: AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context); # IMPL-NEXT: AffineMap tensorMap = AffineMap::getMultiDimIdentityMap( - # @linalg_structured_op -# def test4(O=TensorDef(T, S.M, S.N, output=True)): +# def test4(O=TensorDef(T, S.M, S.N, output=True), +# unary_fun=UnaryFnAttrDef(default=UnaryFn.exp), +# binary_fun=BinaryFnAttrDef(default=BinaryFn.add)): # """Title. # Detailed description. # """ -# O[D.m, D.n] = BinaryFn.add(UnaryFn.exp(O[D.m, D.n]), O[D.m, D.n]) +# O[D.m, D.n] = binary_fun(unary_fun(O[D.m, D.n]), O[D.m, D.n]) --- !LinalgOpConfig metadata: !LinalgOpMetadata @@ -279,6 +280,14 @@ kind: output_tensor type_var: T shape_map: affine_map<()[s0, s1] -> (s0, s1)> + - !LinalgOperandDefConfig + name: unary_fun + kind: unary_fn_attr + default_fn: exp + - !LinalgOperandDefConfig + name: binary_fun + kind: binary_fn_attr + default_fn: add indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1)[s0, s1] -> (d0, d1)> @@ -291,21 +300,36 @@ value: !ScalarExpression scalar_fn: kind: binary - fn_name: add + attr_name: binary_fun operands: - !ScalarExpression scalar_fn: kind: unary - fn_name: exp + attr_name: unary_fun operands: - !ScalarExpression scalar_arg: O - !ScalarExpression scalar_arg: O +# ODS-LABEL: def Test4Op : LinalgStructuredBase_Op<"test4" + +# ODS: let arguments = +# ODS-NEXT: Variadic:$inputs, +# ODS-NEXT: Variadic:$outputs, +# ODS-NEXT: DefaultValuedAttr:$unary_fun, +# ODS-NEXT: DefaultValuedAttr:$binary_fun + +# ODS: "Attribute":$unary_fun, "Attribute":$binary_fun, + +# ODS: $_state.addAttribute("unary_fun", unary_fun) +# ODS-NEXT: $_state.addAttribute("binary_fun", binary_fun) + # IMPL-LABEL: void Test4Op::regionBuilder(ImplicitLocOpBuilder &b, # IMPL-NEXT: Block &block, ArrayRef attrs) +# IMPL: UnaryFn unary_funVal = UnaryFn::exp +# IMPL: BinaryFn binary_funVal = BinaryFn::add -# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0)) -# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0)) +# IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0)) +# IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0)) # IMPL-NEXT: yields.push_back([[VAL1]]) diff --git a/mlir/test/python/dialects/linalg/opdsl/arguments.py 
b/mlir/test/python/dialects/linalg/opdsl/arguments.py --- a/mlir/test/python/dialects/linalg/opdsl/arguments.py +++ b/mlir/test/python/dialects/linalg/opdsl/arguments.py @@ -18,6 +18,12 @@ # CHECK: kind: output_tensor # CHECK: type_var: U # CHECK: shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)> +# CHECK: name: bfn +# CHECK: kind: binary_fn_attr +# CHECK: default_fn: mul +# CHECK: name: ufn +# CHECK: kind: unary_fn_attr +# CHECK: default_fn: exp # CHECK: name: cast # CHECK: kind: type_fn_attr # CHECK: default_fn: cast @@ -26,8 +32,10 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), + bfn=BinaryFnAttrDef(default=BinaryFn.mul), + ufn=UnaryFnAttrDef(default=UnaryFn.exp), cast=TypeFnAttrDef(default=TypeFn.cast)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += bfn(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) # CHECK: --- diff --git a/mlir/test/python/dialects/linalg/opdsl/assignments.py b/mlir/test/python/dialects/linalg/opdsl/assignments.py --- a/mlir/test/python/dialects/linalg/opdsl/assignments.py +++ b/mlir/test/python/dialects/linalg/opdsl/assignments.py @@ -10,10 +10,12 @@ # CHECK: arg: C # CHECK: value: # CHECK: scalar_fn: +# CHECK: kind: binary # CHECK: fn_name: add # CHECK: operands: # CHECK: scalar_fn: -# CHECK: fn_name: mul +# CHECK: kind: binary +# CHECK: attr_name: mul # CHECK: operands: # CHECK: scalar_fn: # CHECK: kind: type @@ -32,8 +34,9 @@ A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), + mul=BinaryFnAttrDef(default=BinaryFn.mul), cast=TypeFnAttrDef(default=TypeFn.cast)): - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += mul(cast(U, A[D.m, D.k]), cast(U, B[D.k, D.n])) # CHECK: --- @@ -69,14 +72,21 @@ # CHECK: fn_name: cast # CHECK: type_var: T # CHECK: operands: -# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' +# CHECK: scalar_fn: +# CHECK: kind: unary +# CHECK: attr_name: exp +# CHECK: operands: +# CHECK: scalar_const: '1.{{[0]*}}e+03 : f64' @linalg_structured_op -def constants(O=TensorDef(T, S.M, S.K, output=True)): +def constants( + O=TensorDef(T, S.M, S.K, output=True), + exp=UnaryFnAttrDef(default=UnaryFn.exp)): pi = TypeFn.cast(T, const(3.1415926535897931)) cst42 = TypeFn.cast(T, const(42)) - cst1000 = TypeFn.cast(T, const(1e+3)) + cst1000 = TypeFn.cast(T, exp(const(1e+3))) O[D.m, D.n] = UnaryFn.exp(pi) + cst42 - cst1000 + # CHECK: --- # CHECK-LABEL: indices # CHECK: assignments: diff --git a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py --- a/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py +++ b/mlir/test/python/dialects/linalg/opdsl/emit_pooling.py @@ -12,55 +12,18 @@ @linalg_structured_op -def pooling_max_poly( +def pooling_poly( I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + reduce=BinaryFnAttrDef(default=BinaryFn.max), + cast=TypeFnAttrDef(default=TypeFn.cast), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( - TypeFn.cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) - - -@linalg_structured_op -def pooling_max_unsigned_poly( - I=TensorDef(T1, S.N, S.H, S.W, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - 
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( - TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) - - -@linalg_structured_op -def pooling_min_poly( - I=TensorDef(T1, S.N, S.H, S.W, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( - TypeFn.cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) - - -@linalg_structured_op -def pooling_min_unsigned_poly( - I=TensorDef(T1, S.N, S.H, S.W, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( - TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + O[D.n, D.oh, D.ow, D.c] = reduce[D.kh, D.kw]( + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) with Context() as ctx, Location.unknown(): @@ -88,7 +51,7 @@ RankedTensorType.get((2, 2), f32), RankedTensorType.get((1, 2, 4, 1), i32)) def test_f32i32_max_pooling(input, shape, init_result): - return pooling_max_poly( + return pooling_poly( input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) # CHECK-LABEL: @test_f32i32_max_unsigned_pooling @@ -99,8 +62,14 @@ RankedTensorType.get((2, 2), f32), RankedTensorType.get((1, 2, 4, 1), i32)) def test_f32i32_max_unsigned_pooling(input, shape, init_result): - return pooling_max_unsigned_poly( - input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.max_unsigned, + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2]) # CHECK-LABEL: @test_f32f32_max_pooling # CHECK: linalg.generic @@ -115,7 +84,7 @@ RankedTensorType.get((2, 2), f32), RankedTensorType.get((1, 2, 4, 1), f32)) def test_f32f32_max_pooling(input, shape, init_result): - return pooling_max_poly( + return pooling_poly( input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) # CHECK-LABEL: @test_f32i32_min_pooling @@ -126,8 +95,13 @@ RankedTensorType.get((2, 2), f32), RankedTensorType.get((1, 2, 4, 1), i32)) def test_f32i32_min_pooling(input, shape, init_result): - return pooling_min_poly( - input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.min, + strides=[2, 4], + dilations=[1, 2]) # CHECK-LABEL: @test_f32i32_min_unsigned_pooling # CHECK: = arith.fptoui @@ -137,8 +111,14 @@ RankedTensorType.get((2, 2), f32), RankedTensorType.get((1, 2, 4, 1), i32)) def test_f32i32_min_unsigned_pooling(input, shape, init_result): - return pooling_min_unsigned_poly( - input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.min_unsigned, + cast=TypeFn.cast_unsigned, + strides=[2, 4], + dilations=[1, 2]) # CHECK-LABEL: @test_f32f32_min_pooling # CHECK: = arith.minf @@ -147,8 
+127,13 @@ RankedTensorType.get((2, 2), f32), RankedTensorType.get((1, 2, 4, 1), f32)) def test_f32f32_min_pooling(input, shape, init_result): - return pooling_min_poly( - input, shape, outs=[init_result], strides=[2, 4], dilations=[1, 2]) + return pooling_poly( + input, + shape, + outs=[init_result], + reduce=BinaryFn.min, + strides=[2, 4], + dilations=[1, 2]) print(module) diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py --- a/mlir/test/python/dialects/linalg/ops.py +++ b/mlir/test/python/dialects/linalg/ops.py @@ -94,20 +94,27 @@ with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( - RankedTensorType.get((4, 16), f32), RankedTensorType.get((16, 8), - f32)) + RankedTensorType.get((4, 8), f32), RankedTensorType.get((4, 8), f32)) def named_form(lhs, rhs): init_result = linalg.InitTensorOp([4, 8], f32) - # First check the named form with custom format - # CHECK: linalg.matmul - # CHECK: cast = #linalg.type_fn - # CHECK-NOT: linalg.memoized_indexing_maps - # CHECK-SAME: ins(%{{.*}} : tensor<4x16xf32>, tensor<16x8xf32>) - # CHECK-SAME: outs(%{{.*}} : tensor<4x8xf32>) - # CHECK-SAME: -> tensor<4x8xf32> - # CHECK-NEXT: return - return linalg.matmul( - lhs, rhs, outs=[init_result.result], cast=TypeFn.cast_unsigned) + # Check for the named form with custom format + # CHECK: linalg.elemwise_unary + # CHECK-SAME: cast = #linalg.type_fn + # CHECK-SAME: fun = #linalg.unary_fn + # CHECK-SAME: ins(%{{.*}} : tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) + unary_result = linalg.elemwise_unary(lhs, outs=[init_result.result]) + # CHECK: linalg.elemwise_binary + # CHECK-SAME: cast = #linalg.type_fn + # CHECK-SAME: fun = #linalg.binary_fn + # CHECK-SAME: ins(%{{.*}}, %{{.*}} : tensor<4x8xf32>, tensor<4x8xf32>) outs(%{{.*}} : tensor<4x8xf32>) + # CHECK: return + binary_result = linalg.elemwise_binary( + lhs, + rhs, + outs=[init_result.result], + fun=BinaryFn.mul, + cast=TypeFn.cast_unsigned) + return unary_result, binary_result print(module) @@ -130,7 +137,8 @@ # CHECK-NEXT: arith.mulf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: arith.addf{{.*}} (f32, f32) -> f32 # CHECK-NEXT: linalg.yield{{.*}} (f32) -> () - # CHECK-NEXT: operand_segment_sizes = dense<[2, 1]> : vector<2xi32> + # CHECK-NEXT: cast = #linalg.type_fn + # CHECK-SAME: operand_segment_sizes = dense<[2, 1]> : vector<2xi32> # CHECK-SAME: (tensor<4x16xf32>, tensor<16x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return linalg.matmul(lhs, rhs, outs=[init_result.result]) diff --git a/mlir/test/python/integration/dialects/linalg/opsrun.py b/mlir/test/python/integration/dialects/linalg/opsrun.py --- a/mlir/test/python/integration/dialects/linalg/opsrun.py +++ b/mlir/test/python/integration/dialects/linalg/opsrun.py @@ -19,6 +19,37 @@ sys.stderr.flush() +elemwise_boiler = """ +func @main() -> f32 attributes {llvm.emit_c_interface} { + %v0 = arith.constant 0.0 : f32 + %v1 = arith.constant 1.0 : f32 + %v2 = arith.constant 2.0 : f32 + + %lhs = memref.alloc() : memref<4x8xf32> + %rhs = memref.alloc() : memref<4x8xf32> + %O0 = memref.alloc() : memref<4x8xf32> + %O1 = memref.alloc() : memref<4x8xf32> + linalg.fill(%v1, %lhs) : f32, memref<4x8xf32> + linalg.fill(%v2, %rhs) : f32, memref<4x8xf32> + linalg.fill(%v0, %O0) : f32, memref<4x8xf32> + linalg.fill(%v0, %O1) : f32, memref<4x8xf32> + + call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) : + (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> () + call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) : + (memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> 
() + + %c0 = arith.constant 0 : index + %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32> + %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32> + + %0 = arith.addf %res0, %res1 : f32 + + // TODO: FFI-based solution to allow testing and printing with python code. + return %0 : f32 +} +""" + matmul_boiler = """ func @main() -> f32 attributes {llvm.emit_c_interface} { %v0 = arith.constant 0.0 : f32 @@ -166,13 +197,93 @@ pm = PassManager.parse( "builtin.func(convert-linalg-to-loops, lower-affine, " + - "convert-scf-to-cf, arith-expand, memref-expand), convert-vector-to-llvm," - + "convert-memref-to-llvm, convert-std-to-llvm," + + "convert-math-to-llvm, convert-scf-to-cf, arith-expand, memref-expand), " + + "convert-vector-to-llvm, convert-memref-to-llvm, convert-std-to-llvm," + "reconcile-unrealized-casts") pm.run(mod) return mod +def test_elemwise_builtin(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32)) + def elemwise_exp_add_on_buffers(lhs, rhs, out): + linalg.elemwise_unary(lhs, outs=[out]) + linalg.elemwise_binary(out, rhs, outs=[out]) + + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32)) + def elemwise_log_mul_on_buffers(lhs, rhs, out): + linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log) + linalg.elemwise_binary(out, rhs, outs=[out], fun=BinaryFn.mul) + + execution_engine = ExecutionEngine(transform(module, elemwise_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.) + execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846 + # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0 + # CHECK: RESULT: 4.71828 + + +test_elemwise_builtin() + + +def test_elemwise_generic(): + with Context() as ctx, Location.unknown(): + module = Module.create() + f32 = F32Type.get() + i8 = IntegerType.get_signless(8) + with InsertionPoint(module.body): + + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32)) + def elemwise_exp_add_on_buffers(lhs, rhs, out): + linalg.elemwise_unary(lhs, outs=[out], emit_generic=True) + linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True) + + @builtin.FuncOp.from_py_func( + MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32), + MemRefType.get((4, 8), f32)) + def elemwise_log_mul_on_buffers(lhs, rhs, out): + linalg.elemwise_unary( + lhs, outs=[out], fun=UnaryFn.log, emit_generic=True) + linalg.elemwise_binary( + out, rhs, outs=[out], fun=BinaryFn.mul, emit_generic=True) + + execution_engine = ExecutionEngine(transform(module, elemwise_boiler)) + + # TODO: FFI-based solution to allow testing and printing with python code. + # Prepare arguments: one result f32. + # Arguments must be passed as pointers. + c_float_p = ctypes.c_float * 1 + res = c_float_p(-1.) 
+ execution_engine.invoke("main", res) + + log("RESULT: ", res[0]) + # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846 + # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0 + # CHECK: RESULT: 4.71828 + + +test_elemwise_generic() + + def test_matmul_builtin(): with Context() as ctx, Location.unknown(): module = Module.create() diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -66,6 +66,8 @@ Scalar, OutputTensor, IndexAttr, + UnaryFnAttr, + BinaryFnAttr, TypeFnAttr }; @@ -208,6 +210,8 @@ io.enumCase(value, "scalar", LinalgOperandDefKind::Scalar); io.enumCase(value, "output_tensor", LinalgOperandDefKind::OutputTensor); io.enumCase(value, "index_attr", LinalgOperandDefKind::IndexAttr); + io.enumCase(value, "unary_fn_attr", LinalgOperandDefKind::UnaryFnAttr); + io.enumCase(value, "binary_fn_attr", LinalgOperandDefKind::BinaryFnAttr); io.enumCase(value, "type_fn_attr", LinalgOperandDefKind::TypeFnAttr); } }; @@ -430,6 +434,45 @@ return nullptr; } +// Return true if the operand is a function attribute. +static bool isFunctionAttribute(LinalgOperandDefKind kind) { + return kind == LinalgOperandDefKind::UnaryFnAttr || + kind == LinalgOperandDefKind::BinaryFnAttr || + kind == LinalgOperandDefKind::TypeFnAttr; +} + +// Return true if the operand is an attribute. +static bool isAttribute(LinalgOperandDefKind kind) { + return kind == LinalgOperandDefKind::IndexAttr || isFunctionAttribute(kind); +} + +// Get the enum name for the given operand kind. +std::string convertOperandKindToEnumName(LinalgOperandDefKind kind) { + switch (kind) { + case LinalgOperandDefKind::UnaryFnAttr: + return std::string("UnaryFn"); + case LinalgOperandDefKind::BinaryFnAttr: + return std::string("BinaryFn"); + case LinalgOperandDefKind::TypeFnAttr: + return std::string("TypeFn"); + default: + break; + } + llvm_unreachable("unsupported function attribute kind"); +} + +// Get the enum name for the given function kind. +std::string convertFunctionKindToEnumName(ScalarFnKind kind) { + switch (kind) { + case ScalarFnKind::Unary: + return std::string("UnaryFn"); + case ScalarFnKind::Binary: + return std::string("BinaryFn"); + case ScalarFnKind::Type: + return std::string("TypeFn"); + } +} + //===----------------------------------------------------------------------===// // Templates //===----------------------------------------------------------------------===// @@ -693,8 +736,7 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { - return arg.kind == LinalgOperandDefKind::IndexAttr || - arg.kind == LinalgOperandDefKind::TypeFnAttr; + return isAttribute(arg.kind); })) { SmallVector attrDefs; SmallVector attrParams; @@ -703,13 +745,14 @@ static const char paramFmt[] = "\"Attribute\":${0}"; static const char stmtFmt[] = "$_state.addAttribute(\"{0}\", {0});"; // Add the type conversion attributes to the op definition and builders. 
- if (arg.kind == LinalgOperandDefKind::TypeFnAttr) { + if (isFunctionAttribute(arg.kind)) { assert(arg.defaultFn.hasValue()); - static const char typeFmt[] = "TypeFn::{0}"; + std::string enumName = convertOperandKindToEnumName(arg.kind); + static const char typeFmt[] = "{0}::{1}"; static const char defFmt[] = "DefaultValuedAttr<{0}, \"{1}\">:${2}"; - attrDefs.push_back(llvm::formatv(defFmt, "TypeFnAttr", - llvm::formatv(typeFmt, arg.defaultFn), - arg.name)); + attrDefs.push_back(llvm::formatv( + defFmt, llvm::formatv("{0}Attr", enumName), + llvm::formatv(typeFmt, enumName, arg.defaultFn), arg.name)); attrParams.push_back(llvm::formatv(paramFmt, arg.name)); attrStmts.push_back(llvm::formatv(stmtFmt, arg.name)); } @@ -1000,21 +1043,24 @@ SmallVector attrs; SmallVector stmts; for (LinalgOperandDef &arg : args) { - if (arg.kind != LinalgOperandDefKind::TypeFnAttr) + if (!isFunctionAttribute(arg.kind)) continue; // Obtain the type function attribute values. Parameters. - // {0}: attribute name - // {1}: default type function name + // {0}: enum name + // {1}: attribute name + // {2}: default type function name static const char attrDef[] = R"FMT( -TypeFn {0}Val = TypeFn::{1}; -auto {0}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ - return attr.getName() == "{0}"; }); -if ({0}Iter != attrs.end()) {{ - if (auto attr = {0}Iter->getValue().dyn_cast()) - {0}Val = attr.getValue(); +{0} {1}Val = {0}::{2}; +auto {1}Iter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {{ + return attr.getName() == "{1}"; }); +if ({1}Iter != attrs.end()) {{ + if (auto attr = {1}Iter->getValue().dyn_cast<{0}Attr>()) + {1}Val = attr.getValue(); } )FMT"; - attrs.push_back(llvm::formatv(attrDef, arg.name, arg.defaultFn)); + std::string enumName = convertOperandKindToEnumName(arg.kind); + attrs.push_back( + llvm::formatv(attrDef, enumName, arg.name, arg.defaultFn)); } for (LinalgOperandDef &arg : args) { if (arg.kind != LinalgOperandDefKind::OutputTensor) @@ -1056,71 +1102,59 @@ cppIdent, *expression.index)); return cppIdent; } - if (expression.scalarFn && - expression.scalarFn->kind != ScalarFnKind::Type) { - // Apply function. - // Recursively generate operands. - SmallVector operandCppValues; - for (ScalarExpression &operand : expression.scalarFn->operands) { - auto operandCppValue = generateExpression(operand); - if (!operandCppValue) - return None; - operandCppValues.push_back(*operandCppValue); - } - - std::string prefix = expression.scalarFn->kind == ScalarFnKind::Unary - ? "unary" - : "binary"; - std::string cppIdent = llvm::formatv("value{0}", ++localCounter); - stmts.push_back( - llvm::formatv("Value {0} = helper.{1}__{2}({3});", cppIdent, - prefix, expression.scalarFn->fnName, - interleaveToString(operandCppValues, ", "))); - return cppIdent; - } - if (expression.scalarFn && - expression.scalarFn->kind == ScalarFnKind::Type) { - // Symbolic cast. - // Operands must be arity 1. - if (expression.scalarFn->operands.size() != 1) { - emitError(genContext.getLoc()) - << "type conversion operand arity must be 1"; - return None; + if (expression.scalarFn) { + std::string enumName = + convertFunctionKindToEnumName(expression.scalarFn->kind); + + // Get the function or attribute name. 
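Note: the `defFmt`/`typeFmt` strings above expand into one ODS `DefaultValuedAttr` declaration per function attribute. A hedged Python mirror of that formatting, for reference only (`ods_attr_decl` is not part of the generator):

```python
def ods_attr_decl(enum_name: str, default_fn: str, arg_name: str) -> str:
  # Mirrors defFmt/typeFmt: DefaultValuedAttr<<Enum>Attr, "<Enum>::<fn>">:$<name>
  return (f'DefaultValuedAttr<{enum_name}Attr, '
          f'"{enum_name}::{default_fn}">:${arg_name}')

print(ods_attr_decl("UnaryFn", "exp", "fun"))
# DefaultValuedAttr<UnaryFnAttr, "UnaryFn::exp">:$fun
```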
+ assert(expression.scalarFn->fnName || expression.scalarFn->attrName); + std::string funcType; + if (expression.scalarFn->fnName) { + funcType = llvm::formatv("{0}::{1}", enumName, + *expression.scalarFn->fnName); } - Optional operandCppValue = - generateExpression(expression.scalarFn->operands[0]); - if (!operandCppValue) - return None; - - assert(expression.scalarFn->typeVar.hasValue()); - Optional typeCppValue = - findTypeValue(expression.scalarFn->typeVar.getValue(), args); - if (!typeCppValue) { - emitError(genContext.getLoc()) - << "type variable " << expression.scalarFn->typeVar.getValue() - << ", used in a type conversion, must map to a predefined or " - << "an argument type but it does not"; - return None; - } - - // Use the function name or the attribute to build the type function. - std::string typeFunc = llvm::formatv( - "TypeFn::{0}", expression.scalarFn->fnName.getValueOr("")); if (expression.scalarFn->attrName) { if (llvm::none_of(args, [&](LinalgOperandDef &arg) { - return arg.kind == LinalgOperandDefKind::TypeFnAttr && + return isFunctionAttribute(arg.kind) && arg.name == expression.scalarFn->attrName.getValue(); })) { emitError(genContext.getLoc()) - << "missing type function attribute " + << "missing function attribute " << expression.scalarFn->attrName.getValue(); } - typeFunc = llvm::formatv("{0}Val", *expression.scalarFn->attrName); + funcType = llvm::formatv("{0}Val", *expression.scalarFn->attrName); + } + assert(!funcType.empty()); + + // Add the optional type parameter to the operands. + SmallVector operandCppValues; + if (expression.scalarFn->kind == ScalarFnKind::Type) { + assert(expression.scalarFn->typeVar.hasValue()); + Optional typeCppValue = + findTypeValue(expression.scalarFn->typeVar.getValue(), args); + if (!typeCppValue) { + emitError(genContext.getLoc()) + << "type variable " << expression.scalarFn->typeVar.getValue() + << ", used in a type conversion, must map to a predefined or " + << "an argument type but it does not"; + return None; + } + operandCppValues.push_back(typeCppValue.getValue()); } + + // Collect the scalar operands. + for (ScalarExpression &operand : expression.scalarFn->operands) { + auto operandCppValue = generateExpression(operand); + if (!operandCppValue) + return None; + operandCppValues.push_back(*operandCppValue); + } + + // Call the function builder. std::string cppIdent = llvm::formatv("value{0}", ++localCounter); stmts.push_back(llvm::formatv( - "Value {0} = helper.buildTypeFn({1}, {2}, {3});", cppIdent, - typeFunc, typeCppValue.getValue(), *operandCppValue)); + "Value {0} = helper.build{1}({2}, {3});", cppIdent, enumName, + funcType, interleaveToString(operandCppValues, ", "))); return cppIdent; } emitError(genContext.getLoc()) << "unknown ScalarExpression type";
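Closing note: both the Python emitter (`_BodyBuilder.expression`) and the C++ generator above resolve attribute-backed functions the same way, with the attribute name taking precedence over a fixed function name. An illustrative helper (not part of the patch) mirroring how the region builder callee is chosen:

```python
def region_builder_callee(enum_name: str, fn_name=None, attr_name=None) -> str:
  # A fixed function lowers to e.g. "BinaryFn::add"; an attribute-backed one
  # refers to the enum value extracted from attrs, e.g. "funVal".
  if attr_name is not None:
    return f"{attr_name}Val"
  return f"{enum_name}::{fn_name}"

print(region_builder_callee("BinaryFn", fn_name="add"))    # BinaryFn::add
print(region_builder_callee("BinaryFn", attr_name="fun"))  # funVal
```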