diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -621,8 +621,7 @@ /// dynamically legal on the target. using DynamicLegalityCallbackFn = std::function; - ConversionTarget(MLIRContext &ctx) - : unknownOpsDynamicallyLegal(false), ctx(ctx) {} + ConversionTarget(MLIRContext &ctx) : ctx(ctx) {} virtual ~ConversionTarget() = default; //===--------------------------------------------------------------------===// @@ -739,18 +738,11 @@ setDialectAction(dialectNames, LegalizationAction::Dynamic); } template - void addDynamicallyLegalDialect( - Optional callback = llvm::None) { + void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback = {}) { SmallVector dialectNames({Args::getDialectNamespace()...}); setDialectAction(dialectNames, LegalizationAction::Dynamic); if (callback) - setLegalityCallback(dialectNames, *callback); - } - template - void addDynamicallyLegalDialect(DynamicLegalityCallbackFn callback) { - SmallVector dialectNames({Args::getDialectNamespace()...}); - setDialectAction(dialectNames, LegalizationAction::Dynamic); - setLegalityCallback(dialectNames, callback); + setLegalityCallback(dialectNames, callback); } /// Register unknown operations as dynamically legal. For operations(and @@ -758,10 +750,11 @@ /// dynamically legal and invoke the given callback if valid or /// 'isDynamicallyLegal'. void markUnknownOpDynamicallyLegal(const DynamicLegalityCallbackFn &fn) { - unknownOpsDynamicallyLegal = true; - unknownLegalityFn = fn; + setLegalityCallback(fn); + } + void markUnknownOpDynamicallyLegal() { + setLegalityCallback([this](Operation *op) { return isDynamicallyLegal(op); }); } - void markUnknownOpDynamicallyLegal() { unknownOpsDynamicallyLegal = true; } /// Register the operations of the given dialects as illegal, i.e. /// operations of this dialect are not supported by the target. 
@@ -805,6 +798,9 @@ void setLegalityCallback(ArrayRef dialects, const DynamicLegalityCallbackFn &callback); + /// Set the dynamic legality callback for unknown ops; must be non-null. + void setLegalityCallback(const DynamicLegalityCallbackFn &callback); + /// Set the recursive legality callback for the given operation and mark the /// operation as recursively legal. void markOpRecursivelyLegal(OperationName name, @@ -819,7 +815,7 @@ bool isRecursivelyLegal; /// The legality callback if this operation is dynamically legal. - Optional legalityFn; + DynamicLegalityCallbackFn legalityFn; }; /// Get the legalization information for the given operation. @@ -841,11 +837,7 @@ llvm::StringMap dialectLegalityFns; /// An optional legality callback for unknown operations. - Optional unknownLegalityFn; - - /// Flag indicating if unknown operations should be treated as dynamically - /// legal. - bool unknownOpsDynamicallyLegal; + DynamicLegalityCallbackFn unknownLegalityFn; /// The current context this target applies to. MLIRContext &ctx; diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -2672,7 +2672,7 @@ /// Register a legality action for the given operation. void ConversionTarget::setOpAction(OperationName op, LegalizationAction action) { - legalOperations[op] = {action, /*isRecursivelyLegal=*/false, llvm::None}; + legalOperations[op] = {action, /*isRecursivelyLegal=*/false, nullptr}; } /// Register a legality action for the given dialects. @@ -2703,8 +2703,7 @@ // Handle dynamic legality either with the provided legality function, or // the default hook on the derived instance. if (info->action == LegalizationAction::Dynamic) - return info->legalityFn ? (*info->legalityFn)(op) - : isDynamicallyLegal(op); + return info->legalityFn ? 
info->legalityFn(op) : isDynamicallyLegal(op); // Otherwise, the operation is only legal if it was marked 'Legal'. return info->action == LegalizationAction::Legal; @@ -2758,6 +2757,13 @@ dialectLegalityFns[dialect] = callback; } +/// Set the dynamic legality callback for ops with no op or dialect action. +void ConversionTarget::setLegalityCallback( + const DynamicLegalityCallbackFn &callback) { + assert(callback && "expected valid legality callback"); + unknownLegalityFn = callback; +} + /// Get the legalization information for the given operation. auto ConversionTarget::getOpInfo(OperationName op) const -> Optional { @@ -2768,7 +2774,7 @@ // Check for info for the parent dialect. auto dialectIt = legalDialects.find(op.getDialectNamespace()); if (dialectIt != legalDialects.end()) { - Optional callback; + DynamicLegalityCallbackFn callback; auto dialectFn = dialectLegalityFns.find(op.getDialectNamespace()); if (dialectFn != dialectLegalityFns.end()) callback = dialectFn->second; @@ -2776,7 +2782,7 @@ callback}; } // Otherwise, check if we mark unknown operations as dynamic. - if (unknownOpsDynamicallyLegal) + if (unknownLegalityFn) return LegalizationInfo{LegalizationAction::Dynamic, /*isRecursivelyLegal=*/false, unknownLegalityFn}; return llvm::None;