diff --git a/mlir/docs/Dialects/Linalg/OpDSL.md b/mlir/docs/Dialects/Linalg/OpDSL.md --- a/mlir/docs/Dialects/Linalg/OpDSL.md +++ b/mlir/docs/Dialects/Linalg/OpDSL.md @@ -55,6 +55,7 @@ them to the same data type as the accumulator/output. """ domain(D.m, D.n, D.k) + defines(Canonicalizer) implements(ContractionOpInterface) C[D.m, D.n] += TypeFn.cast_signed( U, A[D.m, D.k]) * TypeFn.cast_signed(U, B[D.k, D.n]) @@ -78,6 +79,9 @@ Special identifying op interfaces can be declared for the op via `implements(interface1[, interface2...])`. +Extra method definitions can be declared for the op via +`defines(definition1[, definition2...])`. + ## Parameters Structured operations take two types of runtime parameters namely scalars and diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -2877,6 +2877,8 @@ the value operand, promoting it to the same data type as the output. implements: - LinalgFillOpInterface + defines: + - hasCanonicalizer structured_op: !LinalgStructuredOpConfig args: - !LinalgOperandDefConfig diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -509,6 +509,10 @@ FoldInsertPadIntoFill>(context); } +// TODO: Add the FillOp patterns when transitioning to the OpDSL FillOp. 
+void FillTensorOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) {} + //===----------------------------------------------------------------------===// // GenericOps //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -689,6 +689,16 @@ FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") +class OpDefinitionDef: + """A method that an op implements.""" + + def __init__(self, def_name: str): + self.def_name = def_name + + +Canonicalizer = OpDefinitionDef("hasCanonicalizer") + + class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" @@ -699,6 +709,7 @@ self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionDef] def to_yaml_custom_dict(self): d = dict( @@ -708,6 +719,8 @@ ) if self.implements: d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] return d diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -149,13 +149,21 @@ return DefinedOpCallable(op_name, op_def) +def domain(*dimensions: DimDef): + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) + + def implements(*interfaces: OpInterfaceDef): + if any(not isinstance(intr, OpInterfaceDef) 
for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}") current_op_def().metadata.implements.extend(interfaces) -def domain(*dimensions: DimDef): - if current_op_def().domain: - raise ValueError(f"Expected only one set of domain dimensions per operator") - if any(not isinstance(dim, DimDef) for dim in dimensions): - raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") - current_op_def().domain.extend(dimensions) +def defines(*definitions: OpDefinitionDef): + if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}") + current_op_def().metadata.defines.extend(definitions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -672,6 +672,7 @@ the value operand, promoting it to the same data type as the output. """ implements(FillOpInterface) + defines(Canonicalizer) O[None] = TypeFn.cast_signed(U, value) diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-yaml-gen.yaml @@ -333,3 +333,58 @@ # IMPL: Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0)) # IMPL-NEXT: Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0)) # IMPL-NEXT: yields.push_back([[VAL1]]) + +# @linalg_structured_op +# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)): +# """Title. + +# Detailed description. 
+# """ +# implements(FillOpInterface) +# defines(Canonicalizer) +# O[None] = TypeFn.cast(U, value) + +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: test5 + cpp_class_name: Test5Op + doc: |- + Title. + + Detailed description. + implements: + - LinalgFillOpInterface + defines: + - hasCanonicalizer +structured_op: !LinalgStructuredOpConfig + args: + - !LinalgOperandDefConfig + name: value + kind: scalar + type_var: T1 + - !LinalgOperandDefConfig + name: O + kind: output_tensor + type_var: U + shape_map: affine_map<() -> ()> + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<() -> ()> + - affine_map<() -> ()> + iterator_types: [] + assignments: + - !ScalarAssign + arg: O + value: !ScalarExpression + scalar_fn: + kind: type + fn_name: cast + type_var: U + operands: + - !ScalarExpression + scalar_arg: value + +# ODS-LABEL: def Test5Op : LinalgStructuredBase_Op<"test5" +# ODS-NEXT: /*extraInterfaces=*/[LinalgFillOpInterface])> + +# ODS: let hasCanonicalizer = 1; diff --git a/mlir/test/python/dialects/linalg/opdsl/interfaces.py b/mlir/test/python/dialects/linalg/opdsl/metadata.py rename from mlir/test/python/dialects/linalg/opdsl/interfaces.py rename to mlir/test/python/dialects/linalg/opdsl/metadata.py --- a/mlir/test/python/dialects/linalg/opdsl/interfaces.py +++ b/mlir/test/python/dialects/linalg/opdsl/metadata.py @@ -7,11 +7,14 @@ # CHECK-LABEL: matmul # CHECK: implements: # CHECK-NEXT: - LinalgContractionOpInterface +# CHECK: defines: +# CHECK-NEXT: - hasCanonicalizer @linalg_structured_op def matmul( A=TensorDef(T, S.M, S.K), B=TensorDef(T, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): implements(ContractionOpInterface) + defines(Canonicalizer) C[D.m, D.n] += TypeFn.cast_signed(U, A[D.m, D.k]) * TypeFn.cast_signed( U, B[D.k, D.n]) diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp --- 
a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp @@ -53,6 +53,7 @@ std::string cppClassName; Optional doc; SmallVector implements; + SmallVector defines; }; struct SerializedAffineMap { @@ -233,6 +234,7 @@ io.mapRequired("cpp_class_name", info.cppClassName); io.mapOptional("doc", info.doc); io.mapOptional("implements", info.implements); + io.mapOptional("defines", info.defines); } }; @@ -499,7 +501,8 @@ // {3}: documentation (summary + description) // {4}: op attribute list // {5}: builder methods taking standalone attribute parameters -// {6}: additional methods for attributes used by indexing maps +// {6}: additional method definitions +// {7}: additional methods for attributes used by indexing maps static const char structuredOpOdsHeaderFormat[] = R"FMT( //===----------------------------------------------------------------------===// // Op definition for {0} @@ -573,6 +576,7 @@ ]; let hasCustomAssemblyFormat = 1; let hasFolder = 1; + {6} let extraClassDeclaration = structuredOpsBaseDecls # [{{ // Auto-generated. @@ -589,7 +593,7 @@ // Generic methods. static unsigned getNumRegionArgs(); std::string getLibraryCallName(); - {6} + {7} }]; } )FMT"; @@ -736,6 +740,12 @@ interfaceNameList = interleaveToString(opConfig.metadata->implements, ", "); + std::string definitionList; + for (const std::string &definition : opConfig.metadata->defines) { + static const char definitionFmt[] = "let {0} = 1;\n"; + definitionList.append(llvm::formatv(definitionFmt, definition)); + } + if (llvm::any_of(opConfig.structuredOp->args, [](LinalgOperandDef &arg) { return isAttribute(arg.kind); })) { @@ -794,7 +804,7 @@ os << llvm::formatv(structuredOpOdsHeaderFormat, opConfig.metadata->cppClassName, opConfig.metadata->name, interfaceNameList, doc, attrList, attrBuilder, - attrMethods); + definitionList, attrMethods); return success(); }