diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -1,10 +1,10 @@ -add_subdirectory(Affine) add_subdirectory(AMDGPU) +add_subdirectory(AMX) +add_subdirectory(Affine) add_subdirectory(Arithmetic) -add_subdirectory(Async) add_subdirectory(ArmNeon) add_subdirectory(ArmSVE) -add_subdirectory(AMX) +add_subdirectory(Async) add_subdirectory(Bufferization) add_subdirectory(Complex) add_subdirectory(ControlFlow) @@ -12,11 +12,11 @@ add_subdirectory(EmitC) add_subdirectory(Func) add_subdirectory(GPU) -add_subdirectory(Math) -add_subdirectory(Linalg) add_subdirectory(LLVMIR) -add_subdirectory(MemRef) +add_subdirectory(Linalg) add_subdirectory(MLProgram) +add_subdirectory(Math) +add_subdirectory(MemRef) add_subdirectory(NVGPU) add_subdirectory(OpenACC) add_subdirectory(OpenMP) @@ -24,11 +24,12 @@ add_subdirectory(PDLInterp) add_subdirectory(Quant) add_subdirectory(SCF) +add_subdirectory(SPIRV) add_subdirectory(Shape) add_subdirectory(SparseTensor) -add_subdirectory(SPIRV) add_subdirectory(Tensor) add_subdirectory(Tosa) add_subdirectory(Transform) +add_subdirectory(Utils) add_subdirectory(Vector) add_subdirectory(X86Vector) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Linalg/IR/CMakeLists.txt @@ -44,7 +44,7 @@ add_mlir_dialect(LinalgOps linalg) -set(LLVM_TARGET_DEFINITIONS LinalgOps.td) +set(LLVM_TARGET_DEFINITIONS LinalgEnums.td) mlir_tablegen(LinalgOpsEnums.h.inc -gen-enum-decls) mlir_tablegen(LinalgOpsEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRLinalgOpsEnumsIncGen) diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.h new file mode 100644 --- /dev/null +++ 
b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.h @@ -0,0 +1,31 @@ +//===- LinalgBase.h - Linalg base includes ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Includes shared between Linalg.h and LinalgInterfaces.h. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LINALG_IR_LINALGBASE_H_ +#define MLIR_DIALECT_LINALG_IR_LINALGBASE_H_ + +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/IR/BuiltinTypes.h" + +//===----------------------------------------------------------------------===// +// Linalg Enums +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc" + +//===----------------------------------------------------------------------===// +// Linalg Attributes +//===----------------------------------------------------------------------===// + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc" + +#endif // MLIR_DIALECT_LINALG_IR_LINALGBASE_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -13,6 +13,7 @@ #ifndef LINALG_BASE #define LINALG_BASE +include "mlir/Dialect/Linalg/IR/LinalgEnums.td" include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" @@ -62,37 +63,6 @@ } // Define the function attribute enums matching the OpDSL functions.
-def UnaryFn : I32EnumAttr<"UnaryFn", "", [ - I32EnumAttrCase<"exp", 0>, - I32EnumAttrCase<"log", 1>, - I32EnumAttrCase<"abs", 2>, - I32EnumAttrCase<"ceil", 3>, - I32EnumAttrCase<"floor", 4>, - I32EnumAttrCase<"negf", 5> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::linalg"; -} -def BinaryFn : I32EnumAttr<"BinaryFn", "", [ - I32EnumAttrCase<"add", 0>, - I32EnumAttrCase<"sub", 1>, - I32EnumAttrCase<"mul", 2>, - I32EnumAttrCase<"max_signed", 3>, - I32EnumAttrCase<"min_signed", 4>, - I32EnumAttrCase<"max_unsigned", 5>, - I32EnumAttrCase<"min_unsigned", 6> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::linalg"; -} -def TypeFn : I32EnumAttr<"TypeFn", "", [ - I32EnumAttrCase<"cast_signed", 0>, - I32EnumAttrCase<"cast_unsigned", 1> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::linalg"; -} - def UnaryFnAttr : EnumAttr { let assemblyFormat = "`<` $value `>`"; } diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td @@ -0,0 +1,50 @@ +//===- LinalgEnums.td - Linalg dialect enums ---------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for enums used in linear algebra operations. +// +//===----------------------------------------------------------------------===// + +#ifndef LINALG_ENUMS +#define LINALG_ENUMS + +include "mlir/IR/EnumAttr.td" + +// Define the function attribute enums matching the OpDSL functions.
+def UnaryFn : I32EnumAttr<"UnaryFn", "", [ + I32EnumAttrCase<"exp", 0>, + I32EnumAttrCase<"log", 1>, + I32EnumAttrCase<"abs", 2>, + I32EnumAttrCase<"ceil", 3>, + I32EnumAttrCase<"floor", 4>, + I32EnumAttrCase<"negf", 5> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} +def BinaryFn : I32EnumAttr<"BinaryFn", "", [ + I32EnumAttrCase<"add", 0>, + I32EnumAttrCase<"sub", 1>, + I32EnumAttrCase<"mul", 2>, + I32EnumAttrCase<"max_signed", 3>, + I32EnumAttrCase<"min_signed", 4>, + I32EnumAttrCase<"max_unsigned", 5>, + I32EnumAttrCase<"min_unsigned", 6> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} +def TypeFn : I32EnumAttr<"TypeFn", "", [ + I32EnumAttrCase<"cast_signed", 0>, + I32EnumAttrCase<"cast_unsigned", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::linalg"; +} + +#endif // LINALG_ENUMS diff --git a/mlir/include/mlir/Dialect/Utils/CMakeLists.txt b/mlir/include/mlir/Dialect/Utils/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Utils/CMakeLists.txt @@ -0,0 +1,5 @@ +set(LLVM_TARGET_DEFINITIONS StructuredOpsUtils.td) +mlir_tablegen(DialectUtilsEnums.h.inc -gen-enum-decls) +mlir_tablegen(DialectUtilsEnums.cpp.inc -gen-enum-defs) +add_public_tablegen_target(MLIRDialectUtilsIncGen) +add_dependencies(mlir-headers MLIRDialectUtilsIncGen) diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -23,6 +23,9 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" +// Pull in all enum type definitions and utility function declarations.
+#include "mlir/Dialect/Utils/DialectUtilsEnums.h.inc" + namespace mlir { class OpBuilder; diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.td @@ -0,0 +1,23 @@ +//===- StructuredOpsUtils.td - structured ops enums --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef STRUCTURED_OPS_UTILS +#define STRUCTURED_OPS_UTILS + +include "mlir/IR/OpBase.td" +include "mlir/IR/EnumAttr.td" + +def IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ + I32EnumAttrCase<"parallel", 0>, + I32EnumAttrCase<"reduction", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::utils"; +} + +#endif // STRUCTURED_OPS_UTILS diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h --- a/mlir/include/mlir/Interfaces/TilingInterface.h +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -14,6 +14,7 @@ #ifndef MLIR_INTERFACES_TILINGINTERFACE_H_ #define MLIR_INTERFACES_TILINGINTERFACE_H_ +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -41,13 +41,9 @@ >, InterfaceMethod< /*desc=*/[{ - Returns a list of `StringRef`s that describe the number of - loops and the iterator types of the operation.
The list is - expected to use - `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()` - from MLIR Structured Op Utils. + Returns a list of iterator types, one per loop of the operation. }], - /*retType=*/"SmallVector", + /*retType=*/"SmallVector", /*methodName=*/"getLoopIteratorTypes", /*args=*/(ins), /*methodBody=*/"", diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -638,7 +638,8 @@ auto iteratorTypesRange = linalgOp.iterator_types().getAsValueRange(); for (StringRef iteratorType : iteratorTypesRange) { - if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType)) + if (!llvm::is_contained(getAllIteratorTypeNames(), iteratorType) || + !utils::symbolizeIteratorType(iteratorType).has_value()) return op->emitOpError("unexpected iterator_type (") << iteratorType << ")"; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -90,11 +90,13 @@ } /// Return the loop iterator type.
- SmallVector getLoopIteratorTypes(Operation *op) const { + SmallVector getLoopIteratorTypes(Operation *op) const { LinalgOpTy concreteOp = cast(op); return llvm::to_vector( llvm::map_range(concreteOp.iterator_types(), [](Attribute strAttr) { - return strAttr.cast().getValue(); + return utils::symbolizeIteratorType( + strAttr.cast().getValue()) + .value(); })); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -36,10 +36,10 @@ return {initTensor}; } - SmallVector getLoopIteratorTypes(Operation *op) const { + SmallVector getLoopIteratorTypes(Operation *op) const { auto padOp = cast(op); - SmallVector iteratorTypes(padOp.getResultType().getRank(), - getParallelIteratorTypeName()); + SmallVector iteratorTypes( + padOp.getResultType().getRank(), utils::IteratorType::parallel); return iteratorTypes; } diff --git a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/StructuredOpsUtils.cpp @@ -10,6 +10,8 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc" + using namespace mlir; bool mlir::isRowMajorMatmul(ArrayAttr indexingMaps) { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -2298,6 +2298,33 @@ ], ) +td_library( + name = "DialectUtilsTdFiles", + srcs = [ + "include/mlir/Dialect/Utils/StructuredOpsUtils.td", + ], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + +gentbl_cc_library( + name = "DialectUtilsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + 
["-gen-enum-decls"], + "include/mlir/Dialect/Utils/DialectUtilsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/Utils/DialectUtilsEnums.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Utils/StructuredOpsUtils.td", + deps = [":DialectUtilsTdFiles"], +) + cc_library( name = "DialectUtils", srcs = glob([ @@ -2309,6 +2336,7 @@ ]), includes = ["include"], deps = [ + ":DialectUtilsIncGen", ":IR", ":Support", "//llvm:Support", @@ -7190,11 +7218,13 @@ name = "LinalgOpsTdFiles", srcs = [ "include/mlir/Dialect/Linalg/IR/LinalgBase.td", + "include/mlir/Dialect/Linalg/IR/LinalgEnums.td", "include/mlir/Dialect/Linalg/IR/LinalgOps.td", ], includes = ["include"], deps = [ ":ControlFlowInterfacesTdFiles", + ":DialectUtilsTdFiles", ":InferTypeOpInterfaceTdFiles", ":LoopLikeInterfaceTdFiles", ":OpBaseTdFiles", @@ -7242,14 +7272,6 @@ ], "include/mlir/Dialect/Linalg/IR/LinalgOpsDialect.cpp.inc", ), - ( - ["-gen-enum-decls"], - "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc", - ), - ( - ["-gen-enum-defs"], - "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc", - ), ( ["-gen-attrdef-decls"], "include/mlir/Dialect/Linalg/IR/LinalgOpsAttrDefs.h.inc", @@ -7264,6 +7286,24 @@ deps = [":LinalgOpsTdFiles"], ) +gentbl_cc_library( + name = "LinalgEnumsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-enum-decls"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.h.inc", + ), + ( + ["-gen-enum-defs"], + "include/mlir/Dialect/Linalg/IR/LinalgOpsEnums.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Linalg/IR/LinalgEnums.td", + deps = [":LinalgOpsTdFiles"], +) + gentbl_cc_library( name = "LinalgTransformOpsIncGen", strip_include_prefix = "include", @@ -7519,6 +7559,7 @@ ":FuncDialect", ":IR", ":InferTypeOpInterface", + ":LinalgEnumsIncGen", ":LinalgInterfacesIncGen", ":LinalgNamedStructuredOpsYamlIncGen", ":LinalgOpsIncGen", @@ -7709,6 +7750,7 @@ hdrs = 
["include/mlir/Interfaces/TilingInterface.h"], includes = ["include"], deps = [ + ":DialectUtils", ":IR", ":Support", ":TilingInterfaceIncGen",