diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -56,6 +56,47 @@ `verifyTrait` hook out-of-line as a free function when possible to avoid instantiating the implementation for every concrete operation type. +Operation traits may also provide a `foldTrait` hook that is called when +folding the concrete operation. The trait folders will only be invoked if +the concrete operation fold is either not implemented, fails, or performs +an in-place fold. + +The following signature of `foldTrait` will be called if it is implemented +and the op has a single result. + +```c++ +template <typename ConcreteType> +class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> { +public: + /// Override the 'foldTrait' hook to support trait-based folding on the + /// concrete operation. + static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { + // ... + } +}; +``` + +Otherwise, if the operation has a single result and the above signature is +not implemented, or the operation has multiple results, then the following signature +will be used (if implemented): + +```c++ +template <typename ConcreteType> +class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> { +public: + /// Override the 'foldTrait' hook to support trait-based folding on the + /// concrete operation. + static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands, + SmallVectorImpl<OpFoldResult> &results) { + // ... + } +}; +``` + +Note: It is generally good practice to define the implementation of the +`foldTrait` hook out-of-line as a free function when possible to avoid +instantiating the implementation for every concrete operation type. + ### Parametric Traits The above demonstrates the definition of a simple self-contained trait.
It is diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1709,6 +1709,8 @@ NativeOpTrait<"ResultsBroadcastableShape">; // X op Y == Y op X def Commutative : NativeOpTrait<"IsCommutative">; +// op op X == X +def Involution : NativeOpTrait<"IsInvolution">; // Op behaves like a constant. def ConstantLike : NativeOpTrait<"ConstantLike">; // Op behaves like a function. diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -21,6 +21,7 @@ #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" + #include namespace mlir { @@ -277,7 +278,16 @@ /// AbstractOperation. static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - return cast(op).fold(operands, results); + auto operation_fold_result = cast(op).fold(operands, results); + // Failure to fold or in place fold both mean we can continue folding. + if (failed(operation_fold_result) || results.empty()) { + auto trait_fold_result = ConcreteType::foldTraits(op, operands, results); + // Only return the trait fold result if it is a success since + // operation_fold_result might have been a success originally. + if (succeeded(trait_fold_result)) + return trait_fold_result; + } + return operation_fold_result; } /// This hook implements a generalized folder for this operation. Operations @@ -326,6 +336,15 @@ static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { auto result = cast(op).fold(operands); + // Failure to fold or in place fold both mean we can continue folding. + if (!result || result.template dyn_cast() == op->getResult(0)) { + // Only consider the trait fold result if it is a success since + // the operation fold might have been a success originally.
+ auto trait_fold_result = ConcreteType::foldTraits(op, operands); + if (trait_fold_result) + result = trait_fold_result; + } + if (!result) return failure(); @@ -370,9 +389,11 @@ // corresponding trait classes. This avoids them being template // instantiated/duplicated. namespace impl { +OpFoldResult foldInvolution(Operation *op); LogicalResult verifyZeroOperands(Operation *op); LogicalResult verifyOneOperand(Operation *op); LogicalResult verifyNOperands(Operation *op, unsigned numOperands); +LogicalResult verifyIsInvolution(Operation *op); LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); LogicalResult verifyOperandsAreFloatLike(Operation *op); LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); @@ -426,6 +447,23 @@ static AbstractOperation::OperationProperties getTraitProperties() { return 0; } + + static OpFoldResult foldTrait(Operation *op, ArrayRef operands) { + SmallVector results; + + if (failed(foldTrait(op, operands, results))) + return {}; + else if (results.empty()) + return op->getResult(0); + assert(results.size() == 1); + + return results[0]; + } + + static LogicalResult foldTrait(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + } }; //===----------------------------------------------------------------------===// @@ -974,6 +1012,26 @@ } }; +/// This class adds the property that the operation is an involution.
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = x. +template +class IsInvolution : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take one operand"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to preserve type"); + return impl::verifyIsInvolution(op); + } + + static OpFoldResult foldTrait(Operation *op, ArrayRef operands) { + return impl::foldInvolution(op); + } +}; + /// This class verifies that all operands of the specified op have a float type, /// a vector thereof, or a tensor thereof. template @@ -1311,6 +1369,19 @@ failed(cast(op).verify())); } + /// This is the hook that tries to fold the given operation according to its + /// traits. It tries each trait's foldTrait hook in the order the traits are + /// listed, stopping at the first trait that performs an out-of-place fold. + static OpFoldResult foldTraits(Operation *op, ArrayRef operands) { + return BaseFolder...>::foldTraits(op, operands); + } + + static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return BaseFolder...>::foldTraits(op, operands, + results); + } + // Returns the properties of an operation by combining the properties of the // traits of the op. static AbstractOperation::OperationProperties getOperationProperties() { @@ -1363,6 +1434,53 @@ } }; + template + struct BaseFolder; + + template + struct BaseFolder { + static OpFoldResult foldTraits(Operation *op, + ArrayRef operands) { + auto result = First::foldTrait(op, operands); + // Failure to fold or in place fold both mean we can continue folding. + if (!result || result.template dyn_cast() == op->getResult(0)) { + // Only consider the remaining traits' fold result if it is a success + // since this trait's fold might have been an in-place success.
+ auto result_remaining = BaseFolder::foldTraits(op, operands); + if (result_remaining) + result = result_remaining; + } + + return result; + } + + static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + auto result = First::foldTrait(op, operands, results); + // Failure to fold or in place fold both mean we can continue folding. + if (failed(result) || results.empty()) { + auto result_remaining = + BaseFolder::foldTraits(op, operands, results); + if (succeeded(result_remaining)) + result = result_remaining; + } + + return result; + } + }; + + template + struct BaseFolder { + static OpFoldResult foldTraits(Operation *op, + ArrayRef operands) { + return {}; + } + static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + } + }; + template struct BaseProperties { static AbstractOperation::OperationProperties getTraitProperties() { return 0; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FoldInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include using namespace mlir; @@ -679,6 +680,17 @@ // Op Trait implementations //===----------------------------------------------------------------------===// +OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { + auto *argumentOp = op->getOperand(0).getDefiningOp(); + + if (argumentOp && op->getName() == argumentOp->getName()) { + // Replace the outer involution's output with the inner one's input.
+ return argumentOp->getOperand(0); + } + + return {}; +} + LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { if (op->getNumOperands() != 0) return op->emitOpError() << "requires zero operands"; @@ -720,6 +732,12 @@ return type; } +LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { + if (!MemoryEffectOpInterface::hasNoEffect(op)) + return op->emitOpError() << "requires operation to have no side effects"; + return success(); +} + LogicalResult OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { for (auto opType : op->getOperandTypes()) { diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_OPTIONAL_SOURCES TestDialect.cpp TestPatterns.cpp + TestTraits.cpp ) set(LLVM_TARGET_DEFINITIONS TestInterfaces.td) @@ -23,6 +24,7 @@ add_mlir_library(MLIRTestDialect TestDialect.cpp TestPatterns.cpp + TestTraits.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -789,6 +789,29 @@ let results = (outs I32); } +def TestInvolutionTraitNoOperationFolderOp + : TEST_Op<"op_involution_trait_no_operation_fold", + [SameOperandsAndResultType, NoSideEffect, Involution]> { + let arguments = (ins I32:$op1); + let results = (outs I32); +} + +def TestInvolutionTraitFailingOperationFolderOp + : TEST_Op<"op_involution_trait_failing_operation_fold", + [SameOperandsAndResultType, NoSideEffect, Involution]> { + let arguments = (ins I32:$op1); + let results = (outs I32); + let hasFolder = 1; +} + +def TestInvolutionTraitSuccesfulOperationFolderOp + : TEST_Op<"op_involution_trait_succesful_operation_fold", + [SameOperandsAndResultType, NoSideEffect, Involution]> { + let arguments = (ins I32:$op1); + let results = (outs I32); + let hasFolder
= 1; +} + def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> { let arguments = (ins I32); let results = (outs I32); diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -0,0 +1,45 @@ +//===- TestTraits.cpp - Test trait folding --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Trait Folder. +//===----------------------------------------------------------------------===// + +OpFoldResult TestInvolutionTraitFailingOperationFolderOp::fold( + ArrayRef operands) { + // This failure should cause the trait fold to run instead. + return {}; +} + +OpFoldResult TestInvolutionTraitSuccesfulOperationFolderOp::fold( + ArrayRef operands) { + auto argument_op = getOperand(); + // The success case should cause the trait fold to be suppressed. + return argument_op.getDefiningOp() ?
argument_op : OpFoldResult{}; +} + namespace { +struct TestTraitFolder : public PassWrapper { + void runOnFunction() override { + applyPatternsAndFoldGreedily(getFunction(), {}); + } +}; +} // end anonymous namespace + namespace mlir { +void registerTraitsTestPass() { + PassRegistration("test-trait-folder", "Run trait folding"); +} +} // namespace mlir diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/trait.mlir @@ -0,0 +1,61 @@ +// RUN: mlir-opt -test-trait-folder %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test that involutions fold correctly +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @testSingleInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testSingleInvolution(%arg0 : i32) -> i32 { + // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_no_operation_fold"([[ARG0]]) + %0 = "test.op_involution_trait_no_operation_fold"(%arg0) : (i32) -> i32 + // CHECK: return [[INVOLUTION]] + return %0: i32 +} + +// CHECK-LABEL: func @testDoubleInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testDoubleInvolution(%arg0: i32) -> i32 { + %0 = "test.op_involution_trait_no_operation_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_trait_no_operation_fold"(%0) : (i32) -> i32 + // CHECK: return [[ARG0]] + return %1: i32 +} + +// CHECK-LABEL: func @testTripleInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testTripleInvolution(%arg0: i32) -> i32 { + // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_no_operation_fold"([[ARG0]]) + %0 = "test.op_involution_trait_no_operation_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_trait_no_operation_fold"(%0) : (i32) -> i32 + %2 = "test.op_involution_trait_no_operation_fold"(%1) : (i32) -> i32 + // CHECK: return [[INVOLUTION]] + return %2: i32 +} +
+//===----------------------------------------------------------------------===// +// Test that involution fold occurs if operation fold fails +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @testFailingOperationFolder +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testFailingOperationFolder(%arg0: i32) -> i32 { + %0 = "test.op_involution_trait_failing_operation_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_trait_failing_operation_fold"(%0) : (i32) -> i32 + // CHECK: return [[ARG0]] + return %1: i32 +} + +//===----------------------------------------------------------------------===// +// Test that involution fold does not occur if operation fold succeeds +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @testInhibitInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testInhibitInvolution(%arg0: i32) -> i32 { + // CHECK: [[OP:%.+]] = "test.op_involution_trait_succesful_operation_fold"([[ARG0]]) + %0 = "test.op_involution_trait_succesful_operation_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_trait_succesful_operation_fold"(%0) : (i32) -> i32 + // CHECK: return [[OP]] + return %1: i32 +} diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -77,6 +77,7 @@ void registerTestSpirvEntryPointABIPass(); void registerTestSCFUtilsPass(); void registerTestVectorConversions(); +void registerTraitsTestPass(); void registerVectorizerTestPass(); } // namespace mlir @@ -133,6 +134,7 @@ registerTestSpirvEntryPointABIPass(); registerTestSCFUtilsPass(); registerTestVectorConversions(); + registerTraitsTestPass(); registerVectorizerTestPass(); } #endif