diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -55,6 +55,7 @@ them to the same data type as the accumulator/output. """ domain(D.m, D.n, D.k) + defines(Canonicalizer) implements(ContractionOpInterface) C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) ``` @@ -77,6 +78,9 @@ Special identifying op interfaces can be declared for the op via `implements(interface1[, interface2...])`. +Extra ODS definitions (such as a canonicalizer hook) can be declared for the op via +`defines(definition1[, definition2...])`. Each definition sets the corresponding ODS flag on the generated op; for example, `defines(Canonicalizer)` emits `let hasCanonicalizer = 1;`, so the op must also provide a C++ `getCanonicalizationPatterns` implementation. + ## Parameters Structured operations take two types of runtime parameters namely scalars and diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2761,6 +2761,10 @@ Works for arbitrary ranked output tensors since the operation performs scalar accesses only and is thus rank polymorphic. Numeric casting is performed on the value operand, promoting it to the same data type as the output. + implements: + - LinalgFillOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -552,6 +552,10 @@ FoldInsertPadIntoFill>(context); } +// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp.
+void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) {} + //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -631,6 +631,16 @@ FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") +class OpDefinitionDef: + """An extra ODS definition flag that an op enables (e.g. hasCanonicalizer).""" + + def __init__(self, def_name: str): + self.def_name = def_name + + +Canonicalizer = OpDefinitionDef("hasCanonicalizer") + + class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" @@ -641,6 +651,7 @@ self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionDef] def to_yaml_custom_dict(self): d = dict( @@ -650,6 +661,8 @@ ) if self.implements: d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] return d diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -148,13 +148,21 @@ return DefinedOpCallable(op_name, op_def) +def domain(*dimensions: DimDef): + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) + + def implements(*interfaces: OpInterfaceDef): + if any(not isinstance(intr, OpInterfaceDef) 
for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}") current_op_def().metadata.implements.extend(interfaces) -def domain(*dimensions: DimDef): - if current_op_def().domain: - raise ValueError(f"Expected only one set of domain dimensions per operator") - if any(not isinstance(dim, DimDef) for dim in dimensions): - raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") - current_op_def().domain.extend(dimensions) +def defines(*definitions: OpDefinitionDef): + if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}") + current_op_def().metadata.defines.extend(definitions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -637,6 +637,7 @@ the value operand, promoting it to the same data type as the output. """ implements(FillOpInterface) + defines(Canonicalizer) O[None] = TypeFn.cast(U, value) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -309,3 +309,58 @@ # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.unary__exp(block.getArgument(0)) # IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.binary__add([[VAL0]], block.getArgument(0)) # IMPL-NEXT: yields.push_back([[VAL1]]) + +# @linalg_structured_op +# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)): +# """Title. + +# Detailed description. 
+# """ +# implements(FillOpInterface) +# defines(Canonicalizer) +# O[None] = TypeFn.cast(U, value) + +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: test5 + cpp_class_name: Test5Op + doc: |- + Title. + + Detailed description. + implements: + - LinalgFillOpInterface + defines: + - hasCanonicalizer +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: value + kind: scalar + type_var: T1 + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<() -> ()> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: type + fn_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: value + +# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5" +# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])> + +# ODS: let hasCanonicalizer = 1; diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py rename from mlir/test/python/dialects/linalg/opdsl/interfaces.py rename to mlir/test/python/dialects/linalg/opdsl/metadata.py --- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py +++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py @@ -7,10 +7,13 @@ # CHECK-LABEL: matmul # CHECK: implements: # CHECK-NEXT: - LinalgContractionOpInterface +# CHECK: defines: +# CHECK-NEXT: - hasCanonicalizer @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): implements(ContractionOpInterface) + defines(Canonicalizer) C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ 
b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -53,6 +53,7 @@ std::string cppClassName; Optional<std::string> doc; SmallVector<std::string> implements; + SmallVector<std::string> defines; }; struct SerializedAffineMap { @@ -229,6 +230,7 @@ io.mapRequired("cpp_class_name", info.cppClassName); io.mapOptional("doc", info.doc); io.mapOptional("implements", info.implements); + io.mapOptional("defines", info.defines); } }; @@ -455,7 +457,8 @@ // {3}: documentation (summary + description) // {4}: op attribute list // {5}: builder methods taking standalone attribute parameters -// {6}: additional methods for attributes used by indexing maps +// {6}: additional method definitions +// {7}: additional methods for attributes used by indexing maps static const char structuredOpOdsHeaderFormat[] = R"FMT( //===----------------------------------------------------------------------===// // Op definition for {0} @@ -529,6 +532,7 @@ ]; let hasCustomAssemblyFormat = 1; let hasFolder = 1; + {6} let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. @@ -545,7 +549,7 @@ // Generic methods. static unsigned getNumRegionArgs(); std::string getLibraryCallName(); - {6} + {7} }]; } )FMT"; @@ -692,6 +696,12 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); + std::string definitionList; + for (const std::string &definition : opConfig.metadata->defines) { + static const char definitionFmt[] = "let {0} = 1;\n"; + definitionList.append(llvm::formatv(definitionFmt, definition)); + } + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return arg.kind == LinalgOperandDefKind::IndexAttr || arg.kind == LinalgOperandDefKind::TypeFnAttr; @@ -750,7 +760,7 @@ os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName, opConfig.metadata->name, interfaceNameList, doc, attrList, attrBuilder, - attrMethods); + definitionList, attrMethods); return success(); }