diff --git a/mlir/include/mlir/IR/BuiltinOps.td b/mlir/include/mlir/IR/BuiltinOps.td
--- a/mlir/include/mlir/IR/BuiltinOps.td
+++ b/mlir/include/mlir/IR/BuiltinOps.td
@@ -229,7 +229,8 @@
 //===----------------------------------------------------------------------===//
 
 def UnrealizedConversionCastOp : Builtin_Op<"unrealized_conversion_cast", [
-    DeclareOpInterfaceMethods<CastOpInterface>, NoSideEffect
+    DeclareOpInterfaceMethods<CastOpInterface>, FoldOnDialectConversion,
+    NoSideEffect
   ]> {
   let summary = "An unrealized conversion from one set of types to another";
   let description = [{
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -2044,6 +2044,8 @@
 def Involution : NativeOpTrait<"IsInvolution">;
 // Op behaves like a constant.
 def ConstantLike : NativeOpTrait<"ConstantLike">;
+// Op is folded during dialect conversion whenever possible, even if legal.
+def FoldOnDialectConversion : NativeOpTrait<"FoldOnDialectConversion">;
 // Op behaves like a function.
 def FunctionLike : NativeOpTrait<"FunctionLike">;
 // Op is isolated from above.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -1326,7 +1326,15 @@
   }
 };
 
-/// This trait tags `Elementwise` operatons that can be systematically
+/// A trait for operations that dialect conversion always attempts to fold,
+/// even if they have been marked legal. This suits operations that are meant
+/// to be folded by design and should not outlive the conversion pass whenever
+/// possible. If the fold fails, conversion of the operation proceeds as usual.
+template <typename ConcreteType>
+class FoldOnDialectConversion
+    : public TraitBase<ConcreteType, FoldOnDialectConversion> {};
+
+/// This trait tags `Elementwise` operations that can be systematically
 /// scalarized. All vector/tensor operands and results are then replaced by
 /// scalars of the respective element type. Semantically, this is the operation
 /// on a single element of the vector/tensor.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1849,6 +1849,28 @@
     }
   });
 
+  // Check to see if the operation is ignored and doesn't need to be converted.
+  if (rewriter.getImpl().isOpIgnored(op)) {
+    LLVM_DEBUG({
+      logSuccess(logger, "operation marked 'ignored' during conversion");
+      logger.startLine() << logLineComment;
+    });
+    return success();
+  }
+
+  // Ops with the FoldOnDialectConversion trait are folded whenever possible,
+  // even if the target marks them legal. This runs after the 'ignored' check
+  // so that ignored operations are never folded here.
+  if (op->hasTrait<OpTrait::FoldOnDialectConversion>() &&
+      succeeded(legalizeWithFold(op, rewriter))) {
+    LLVM_DEBUG({
+      logSuccess(logger,
+                 "operation with trait 'FoldOnDialectConversion' folded");
+      logger.startLine() << logLineComment;
+    });
+    return success();
+  }
+
   // Check if this operation is legal on the target.
   if (auto legalityInfo = target.isLegal(op)) {
     LLVM_DEBUG({
@@ -1867,15 +1889,6 @@
     return success();
   }
 
-  // Check to see if the operation is ignored and doesn't need to be converted.
-  if (rewriter.getImpl().isOpIgnored(op)) {
-    LLVM_DEBUG({
-      logSuccess(logger, "operation marked 'ignored' during conversion");
-      logger.startLine() << logLineComment;
-    });
-    return success();
-  }
-
   // If the operation isn't legal, try to fold it in-place.
   // TODO: Should we always try to do this, even if the op is
   // already legal?