diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2342,12 +2342,15 @@ : Availability { let interfaceName = name; - let queryFnRetType = scheme.returnType; + let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; let queryFnName = "getMinVersion"; - let mergeAction = "$overall = static_cast<" # scheme.returnType # ">(" - "std::max($overall, $instance))"; - let initializer = "static_cast<" # scheme.returnType # ">(uint32_t(0))"; + let mergeAction = "{ " + "if ($overall.hasValue()) { " + "$overall = static_cast<" # scheme.returnType # ">(" + "std::max(*$overall, $instance)); " + "} else { $overall = $instance; }}"; + let initializer = "::llvm::None"; let instanceType = scheme.cppNamespace # "::" # scheme.className; let instance = scheme.cppNamespace # "::" # scheme.className # "::" # @@ -2358,12 +2361,15 @@ : Availability { let interfaceName = name; - let queryFnRetType = scheme.returnType; + let queryFnRetType = "llvm::Optional<" # scheme.returnType # ">"; let queryFnName = "getMaxVersion"; - let mergeAction = "$overall = static_cast<" # scheme.returnType # ">(" - "std::min($overall, $instance))"; - let initializer = "static_cast<" # scheme.returnType # ">(~uint32_t(0))"; + let mergeAction = "{ " + "if ($overall.hasValue()) { " + "$overall = static_cast<" # scheme.returnType # ">(" + "std::min(*$overall, $instance)); " + "} else { $overall = $instance; }}"; + let initializer = "::llvm::None"; let instanceType = scheme.cppNamespace # "::" # scheme.className; let instance = scheme.cppNamespace # "::" # scheme.className # "::" # diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -841,22 +841,24 @@ // Make sure this op is available at the given version.
Ops not implementing // QueryMinVersionInterface/QueryMaxVersionInterface are available to all // SPIR-V versions. - if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) - if (minVersion.getMinVersion() > this->targetEnv.getVersion()) { + if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { + Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); + if (minVersion && *minVersion > this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring min version " - << spirv::stringifyVersion(minVersion.getMinVersion()) - << "\n"); + << spirv::stringifyVersion(*minVersion) << "\n"); return false; } - if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) - if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) { + } + if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { + Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); + if (maxVersion && *maxVersion < this->targetEnv.getVersion()) { LLVM_DEBUG(llvm::dbgs() << op->getName() << " illegal: requiring max version " - << spirv::stringifyVersion(maxVersion.getMaxVersion()) - << "\n"); + << spirv::stringifyVersion(*maxVersion) << "\n"); return false; } + } // Make sure this op's required extensions are allowed to use. Ops not // implementing QueryExtensionInterface do not require extensions to be diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -109,13 +109,17 @@ // requirements.
WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { // Op min version requirements - if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) { - deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); - if (deducedVersion > allowedVersion) { - return op->emitError("'") << op->getName() << "' requires min version " - << spirv::stringifyVersion(deducedVersion) - << " but target environment allows up to " - << spirv::stringifyVersion(allowedVersion); + if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { + Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); + if (minVersion) { + deducedVersion = std::max(deducedVersion, *minVersion); + if (deducedVersion > allowedVersion) { + return op->emitError("'") + << op->getName() << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + } } } diff --git a/mlir/test/IR/op-availability.mlir b/mlir/test/IR/op-availability.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/op-availability.mlir @@ -0,0 +1,78 @@ +// RUN: mlir-opt -mlir-disable-threading -test-op-availability %s | FileCheck %s + +// CHECK-LABEL: availability in @min_version +func @min_version() { + // CHECK: test.min_version_op min version: v1.1 + // CHECK: test.min_version_op has no max version requirements + "test.min_version_op"() : () -> () + return +} + +// CHECK-LABEL: availability in @max_version +func @max_version() { + // CHECK: test.max_version_op has no min version requirements + // CHECK: test.max_version_op max version: v1.2 + "test.max_version_op"() : () -> () + return +} + +// CHECK-LABEL: availability in @min_max_version +func @min_max_version() { + // CHECK: test.min_max_version_op min version: v1.0 + // CHECK: test.min_max_version_op max version: v1.3 + "test.min_max_version_op"() : () -> () + return +} + +// Enum case A (0): no version requirements +// Enum case B (1): <= v1.2 +// Enum case C (2): >= v1.1 + +// CHECK-LABEL:
availability in @one_versioned_attr_op +func @one_versioned_attr_op() { + // CHECK: test.one_versioned_attr_op min version: None + // CHECK: test.one_versioned_attr_op max version: None + "test.one_versioned_attr_op"() {attr = 0: i32} : () -> () + + // CHECK: test.one_versioned_attr_op min version: None + // CHECK: test.one_versioned_attr_op max version: v1.2 + "test.one_versioned_attr_op"() {attr = 1: i32} : () -> () + + // CHECK: test.one_versioned_attr_op min version: v1.1 + // CHECK: test.one_versioned_attr_op max version: None + "test.one_versioned_attr_op"() {attr = 2: i32} : () -> () + return +} + +// CHECK-LABEL: availability in @two_versioned_attr_op +func @two_versioned_attr_op() { + // CHECK: test.two_versioned_attr_op min version: None + // CHECK: test.two_versioned_attr_op max version: v1.2 + "test.two_versioned_attr_op"() {attr1 = 0: i32, attr2 = 1: i32} : () -> () + + // CHECK: test.two_versioned_attr_op min version: v1.1 + // CHECK: test.two_versioned_attr_op max version: None + "test.two_versioned_attr_op"() {attr1 = 0: i32, attr2 = 2: i32} : () -> () + + // CHECK: test.two_versioned_attr_op min version: v1.1 + // CHECK: test.two_versioned_attr_op max version: v1.2 + "test.two_versioned_attr_op"() {attr1 = 1: i32, attr2 = 2: i32} : () -> () + return +} + + +// CHECK-LABEL: availability in @mix_versioned_attr_op +func @mix_versioned_attr_op() { + // CHECK: test.mix_versioned_attr_op min version: v1.0 + // CHECK: test.mix_versioned_attr_op max version: v1.2 + "test.mix_versioned_attr_op"() {attr1 = 0: i32, attr2 = 1: i32} : () -> () + + // CHECK: test.mix_versioned_attr_op min version: v1.1 + // CHECK: test.mix_versioned_attr_op max version: v1.3 + "test.mix_versioned_attr_op"() {attr1 = 0: i32, attr2 = 2: i32} : () -> () + + // CHECK: test.mix_versioned_attr_op min version: v1.1 + // CHECK: test.mix_versioned_attr_op max version: v1.2 + "test.mix_versioned_attr_op"() {attr1 = 1: i32, attr2 = 2: i32} : () -> () + return +} diff --git
a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp --- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp +++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp @@ -43,13 +43,23 @@ auto opName = op->getName(); auto &os = llvm::outs(); - if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) - os << opName << " min version: " - << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; + if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { + Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); + os << opName << " min version: "; + if (minVersion) + os << spirv::stringifyVersion(*minVersion) << "\n"; + else + os << "None\n"; + } - if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) - os << opName << " max version: " - << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; + if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { + Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); + os << opName << " max version: "; + if (maxVersion) + os << spirv::stringifyVersion(*maxVersion) << "\n"; + else + os << "None\n"; + } if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) { os << opName << " extensions: ["; @@ -81,7 +91,7 @@ } namespace mlir { -void registerPrintOpAvailabilityPass() { +void registerPrintSpirvAvailabilityPass() { PassRegistration<PrintOpAvailability>(); } } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -34,6 +34,10 @@ mlir_tablegen(TestOpStructs.h.inc -gen-struct-attr-decls) mlir_tablegen(TestOpStructs.cpp.inc -gen-struct-attr-defs) mlir_tablegen(TestPatterns.inc -gen-rewriters) +mlir_tablegen(TestOpAvailability.h.inc -gen-avail-interface-decls) +mlir_tablegen(TestOpAvailability.cpp.inc -gen-avail-interface-defs) +mlir_tablegen(TestEnumAvailability.h.inc -gen-enum-avail-decls) +mlir_tablegen(TestEnumAvailability.cpp.inc -gen-enum-avail-defs) add_public_tablegen_target(MLIRTestOpsIncGen) # Exclude tests from libMLIR.so @@
-41,6 +45,7 @@ TestAttributes.cpp TestDialect.cpp TestInterfaces.cpp + TestOpAvailability.cpp TestPatterns.cpp TestTraits.cpp TestTypes.cpp diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -42,6 +42,15 @@ #include "TestOpStructs.h.inc" #include "TestOpsDialect.h.inc" +// This needs to be included after TestOpEnums.h.inc due to symbol dependency. +#include "TestEnumAvailability.h.inc" + +namespace mlir { +namespace test { +#include "TestOpAvailability.h.inc" +} // namespace test +} // namespace mlir + #define GET_OP_CLASSES #include "TestOps.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOpAvailability.cpp b/mlir/test/lib/Dialect/Test/TestOpAvailability.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOpAvailability.cpp @@ -0,0 +1,85 @@ +//===- TestOpAvailability.cpp - Pass to test op availability --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +#include "TestEnumAvailability.cpp.inc" + +namespace mlir { +namespace test { +#include "TestOpAvailability.cpp.inc" +} // namespace test +} // namespace mlir + +//===----------------------------------------------------------------------===// +// Printing op availability pass +//===----------------------------------------------------------------------===// + +namespace { +/// A pass for testing op availability in the test dialect.
+struct PrintOpAvailability + : public PassWrapper<PrintOpAvailability, FunctionPass> { + StringRef getArgument() const final { return "test-op-availability"; } + StringRef getDescription() const final { return "Test op availability"; } + + void runOnFunction() override; +}; +} // end anonymous namespace + +void PrintOpAvailability::runOnFunction() { + auto f = getFunction(); + llvm::outs() << "availability in @" << f.getName() << "\n"; + + Dialect *testDialect = getContext().getLoadedDialect("test"); + + f->walk([&](Operation *op) { + if (op->getDialect() != testDialect) + return WalkResult::advance(); + + auto opName = op->getName(); + auto &os = llvm::outs(); + + // Check that we can access TestQueryMinVersionInterface. + if (auto minVersionIfx = dyn_cast<mlir::test::TestQueryMinVersionInterface>(op)) { + auto minVersion = minVersionIfx.getMinVersion(); + os << opName << " min version: "; + if (minVersion.hasValue()) + os << stringifyVersion(minVersion.getValue()) << "\n"; + else + os << "None\n"; + } else { + os << opName << " has no min version requirements\n"; + } + + // Check that we can access TestQueryMaxVersionInterface.
+ if (auto maxVersionIfx = dyn_cast<mlir::test::TestQueryMaxVersionInterface>(op)) { + auto maxVersion = maxVersionIfx.getMaxVersion(); + os << opName << " max version: "; + if (maxVersion.hasValue()) + os << stringifyVersion(maxVersion.getValue()) << "\n"; + else + os << "None\n"; + } else { + os << opName << " has no max version requirements\n"; + } + + os.flush(); + + return WalkResult::advance(); + }); +} + +namespace mlir { +void registerPrintOpAvailabilityPass() { + PassRegistration<PrintOpAvailability>(); +} +} // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2163,4 +2163,101 @@ def : Pat<(OpCrashLong $_, $_, $_), (OpCrashShort)>; +//===----------------------------------------------------------------------===// +// Test Op Availability +//===----------------------------------------------------------------------===// + +def Test_V_1_0 : I32EnumAttrCase<"V_1_0", 0, "v1.0">; +def Test_V_1_1 : I32EnumAttrCase<"V_1_1", 1, "v1.1">; +def Test_V_1_2 : I32EnumAttrCase<"V_1_2", 2, "v1.2">; +def Test_V_1_3 : I32EnumAttrCase<"V_1_3", 3, "v1.3">; + +def Test_VersionAttr : I32EnumAttr<"Version", "valid version numbers", [ + Test_V_1_0, Test_V_1_1, Test_V_1_2, Test_V_1_3]>; + +class Test_MinVersion<I32EnumAttrCase min> : MinVersionBase< + "TestQueryMinVersionInterface", Test_VersionAttr, min> { + let interfaceDescription = [{ + Querying interface for minimal required version. + }]; + let interfaceNamespace = "::mlir::test"; +} + +class Test_MaxVersion<I32EnumAttrCase max> : MaxVersionBase< + "TestQueryMaxVersionInterface", Test_VersionAttr, max> { + let interfaceDescription = [{ + Querying interface for maximal supported version. + }]; + let interfaceNamespace = "::mlir::test"; +} + +// Test op with only min version availability. +def MinVersionOp : TEST_Op<"min_version_op"> { + let arguments = (ins); + let results = (outs); + + let availability = [Test_MinVersion<Test_V_1_1>]; +} + +// Test op with only max version availability.
+def MaxVersionOp : TEST_Op<"max_version_op"> { + let arguments = (ins); + let results = (outs); + + let availability = [Test_MaxVersion<Test_V_1_2>]; +} + +// Test op with both min and max version availability. +def MinMaxVersionOp : TEST_Op<"min_max_version_op"> { + let arguments = (ins); + let results = (outs); + + let availability = [ + Test_MinVersion<Test_V_1_0>, + Test_MaxVersion<Test_V_1_3> + ]; +} + +def VersionedEnumCaseA : I32EnumAttrCase<"A", 0> { + // No version requirements. +} +def VersionedEnumCaseB : I32EnumAttrCase<"B", 1> { + // Removed since v1.3. + let availability = [Test_MaxVersion<Test_V_1_2>]; +} +def VersionedEnumCaseC : I32EnumAttrCase<"C", 2> { + // Added since v1.1. + let availability = [Test_MinVersion<Test_V_1_1>]; +} + +def VersionedEnumAttr : I32EnumAttr<"VersionedEnum", "valid enum cases", [ + VersionedEnumCaseA, VersionedEnumCaseB, VersionedEnumCaseC]> { + let cppNamespace = "::mlir::test"; +} + +// Test op having one versioned enum attribute. +def OneVersionedAttrOp : TEST_Op<"one_versioned_attr_op"> { + let arguments = (ins VersionedEnumAttr:$attr); + let results = (outs); +} + +// Test op having two versioned enum attributes--the final version requirements +// need to consider both. +def TwoVersionedAttrOp : TEST_Op<"two_versioned_attr_op"> { + let arguments = (ins VersionedEnumAttr:$attr1, VersionedEnumAttr:$attr2); + let results = (outs); +} + +// Test op having availability on itself and two versioned enum attributes-- +// the final version requirements need to consider everything.
+def MixVersionedAttrOp : TEST_Op<"mix_versioned_attr_op"> { + let arguments = (ins VersionedEnumAttr:$attr1, VersionedEnumAttr:$attr2); + let results = (outs); + + let availability = [ + Test_MinVersion<Test_V_1_0>, + Test_MaxVersion<Test_V_1_3> + ]; +} + #endif // TEST_OPS diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -32,6 +32,7 @@ void registerConvertToTargetEnvPass(); void registerPassManagerTestPass(); void registerPrintOpAvailabilityPass(); +void registerPrintSpirvAvailabilityPass(); void registerShapeFunctionTestPasses(); void registerSideEffectTestPasses(); void registerSliceAnalysisTestPass(); @@ -113,6 +114,7 @@ registerConvertToTargetEnvPass(); registerPassManagerTestPass(); registerPrintOpAvailabilityPass(); + registerPrintSpirvAvailabilityPass(); registerShapeFunctionTestPasses(); registerSideEffectTestPasses(); registerSliceAnalysisTestPass(); diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -98,6 +98,22 @@ ["-gen-rewriters"], "lib/Dialect/Test/TestPatterns.inc", ), + ( + ["-gen-avail-interface-decls"], + "lib/Dialect/Test/TestOpAvailability.h.inc", + ), + ( + ["-gen-avail-interface-defs"], + "lib/Dialect/Test/TestOpAvailability.cpp.inc", + ), + ( + ["-gen-enum-avail-decls"], + "lib/Dialect/Test/TestEnumAvailability.h.inc", + ), + ( + ["-gen-enum-avail-defs"], + "lib/Dialect/Test/TestEnumAvailability.cpp.inc", + ), ], tblgen = "//mlir:mlir-tblgen", td_file = "lib/Dialect/Test/TestOps.td", @@ -200,6 +216,7 @@ "lib/Dialect/Test/TestAttributes.cpp", "lib/Dialect/Test/TestDialect.cpp", "lib/Dialect/Test/TestInterfaces.cpp", + "lib/Dialect/Test/TestOpAvailability.cpp", "lib/Dialect/Test/TestPatterns.cpp", "lib/Dialect/Test/TestTraits.cpp",
"lib/Dialect/Test/TestTypes.cpp",