diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -503,41 +503,10 @@ // Generic Linalg ops. //===----------------------------------------------------------------------===// -class GenericOpBase : LinalgStructuredBase_Op, SingleBlockImplicitTerminator<"YieldOp">]> { - let arguments = (ins Variadic:$inputs, - Variadic:$outputs, - AffineMapArrayAttr:$indexing_maps, - ArrayAttr:$iterator_types, - OptionalAttr:$doc, - OptionalAttr:$library_call); - let results = (outs Variadic:$result_tensors); - let regions = (region AnyRegion:$region); - let extraClassDeclaration = structuredOpsBaseDecls # [{ - SmallVector linalgTraitAttrNames() { - return SmallVector{ - getDocAttrName(), - getIndexingMapsAttrName(), getLibraryCallAttrName(), - getIteratorTypesAttrName(), - }; - } - std::string getLibraryCallName() { - return library_call().hasValue() ? - library_call()->str() : "op_has_no_registered_library_name"; - } - - static std::function - getRegionBuilder() { - return nullptr; - } - }]; - let printer = [{ return ::print(p, *this); }]; - let parser = [{ return ::parseGenericOp(parser, result); }]; -} - -def GenericOp : GenericOpBase<"generic"> { let description = [{ Generic Linalg op form where the key properties of the computation are specified as attributes. In pretty form, a `linalg.generic` op is written @@ -636,6 +605,15 @@ ``` }]; + let arguments = (ins Variadic:$inputs, + Variadic:$outputs, + AffineMapArrayAttr:$indexing_maps, + ArrayAttr:$iterator_types, + OptionalAttr:$doc, + OptionalAttr:$library_call); + let results = (outs Variadic:$result_tensors); + let regions = (region AnyRegion:$region); + let builders = [ OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs, "ValueRange":$outputs, "ArrayRef":$indexingMaps, @@ -654,6 +632,29 @@ "ArrayRef":$indexingMaps, "ArrayRef":$iteratorTypes, CArg<"function_ref", "nullptr">)> ]; + + let extraClassDeclaration = structuredOpsBaseDecls # [{ + SmallVector linalgTraitAttrNames() { + return SmallVector{ + getDocAttrName(), + getIndexingMapsAttrName(), getLibraryCallAttrName(), + getIteratorTypesAttrName(), + }; + } + std::string getLibraryCallName() { + return library_call().hasValue() ? + library_call()->str() : "op_has_no_registered_library_name"; + } + + static std::function + getRegionBuilder() { + return nullptr; + } + }]; + + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parseGenericOp(parser, result); }]; + let verifier = [{ return ::verify(*this); }]; let hasFolder = 1;