diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h
--- a/mlir/include/mlir-c/Dialect/Transform.h
+++ b/mlir/include/mlir-c/Dialect/Transform.h
@@ -27,6 +27,16 @@
 
 MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx);
 
+//===---------------------------------------------------------------------===//
+// AnyValueType
+//===---------------------------------------------------------------------===//
+
+/// Returns true if the given type is a transform dialect AnyValueType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type);
+
+/// Returns the transform dialect AnyValueType in the given context.
+MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx);
+
 //===---------------------------------------------------------------------===//
 // OperationType
 //===---------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp
--- a/mlir/lib/Bindings/Python/DialectTransform.cpp
+++ b/mlir/lib/Bindings/Python/DialectTransform.cpp
@@ -31,6 +31,20 @@
       "Get an instance of AnyOpType in the given context.", py::arg("cls"),
       py::arg("context") = py::none());
 
+  //===-------------------------------------------------------------------===//
+  // AnyValueType
+  //===-------------------------------------------------------------------===//
+
+  auto anyValueType =
+      mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType);
+  anyValueType.def_classmethod(
+      "get",
+      [](py::object cls, MlirContext ctx) {
+        return cls(mlirTransformAnyValueTypeGet(ctx));
+      },
+      "Get an instance of AnyValueType in the given context.", py::arg("cls"),
+      py::arg("context") = py::none());
+
   //===-------------------------------------------------------------------===//
   // OperationType
   //===-------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp
--- a/mlir/lib/CAPI/Dialect/Transform.cpp
+++ b/mlir/lib/CAPI/Dialect/Transform.cpp
@@ -29,6 +29,18 @@
   return wrap(transform::AnyOpType::get(unwrap(ctx)));
 }
 
+//===---------------------------------------------------------------------===//
+// AnyValueType
+//===---------------------------------------------------------------------===//
+
+bool mlirTypeIsATransformAnyValueType(MlirType type) {
+  return isa<transform::AnyValueType>(unwrap(type));
+}
+
+MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) {
+  return wrap(transform::AnyValueType::get(unwrap(ctx)));
+}
+
 //===---------------------------------------------------------------------===//
 // OperationType
 //===---------------------------------------------------------------------===//