diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md --- a/mlir/docs/Traits.md +++ b/mlir/docs/Traits.md @@ -56,6 +56,42 @@ `verifyTrait` hook out-of-line as a free function when possible to avoid instantiating the implementation for every concrete operation type. +Operation traits may also provide a `foldTrait` hook that is called when +folding the concrete operation. The trait folders will currently only be +invoked if the concrete operation fold does not succeed (or is not implemented). +The following `foldTrait` signature must be used for single-result operations. + +```c++ +template <typename ConcreteType> +class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> { +public: + /// Override the 'foldTrait' hook to support trait based folding on the + /// concrete operation. + static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { + // ... + } +}; +``` + +Otherwise, the following signature must be used: + +```c++ +template <typename ConcreteType> +class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> { +public: + /// Override the 'foldTrait' hook to support trait based folding on the + /// concrete operation. + static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands, + SmallVectorImpl<OpFoldResult> &results) { + // ... + } +}; +``` + +Note: It is generally good practice to define the implementation of the +`foldTrait` hook out-of-line as a free function when possible to avoid +instantiating the implementation for every concrete operation type. + ### Parametric Traits The above demonstrates the definition of a simple self-contained trait. It is diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1709,6 +1709,8 @@ NativeOpTrait<"ResultsBroadcastableShape">; // X op Y == Y op X def Commutative : NativeOpTrait<"IsCommutative">; +// op op X == X +def Involution : NativeOpTrait<"IsInvolution">; // Op behaves like a constant. def ConstantLike : NativeOpTrait<"ConstantLike">; // Op behaves like a function.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -21,6 +21,7 @@ #include "mlir/IR/Operation.h" #include "llvm/Support/PointerLikeTypeTraits.h" + #include namespace mlir { @@ -277,7 +278,10 @@ /// AbstractOperation. static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { - return cast(op).fold(operands, results); + auto result = cast(op).fold(operands, results); + if (failed(result)) + result = ConcreteType::foldTraits(op, operands, results); + return result; } /// This hook implements a generalized folder for this operation. Operations @@ -326,6 +330,8 @@ static LogicalResult foldHook(Operation *op, ArrayRef operands, SmallVectorImpl &results) { auto result = cast(op).fold(operands); + if (!result) + result = ConcreteType::foldTraits(op, operands); if (!result) return failure(); @@ -370,9 +376,11 @@ // corresponding trait classes. This avoids them being template // instantiated/duplicated.
namespace impl { +OpFoldResult foldInvolution(Operation *op); LogicalResult verifyZeroOperands(Operation *op); LogicalResult verifyOneOperand(Operation *op); LogicalResult verifyNOperands(Operation *op, unsigned numOperands); +LogicalResult verifyIsInvolution(Operation *op); LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); LogicalResult verifyOperandsAreFloatLike(Operation *op); LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); @@ -426,6 +434,13 @@ static AbstractOperation::OperationProperties getTraitProperties() { return 0; } + static OpFoldResult foldTrait(Operation *op, ArrayRef operands) { + return {}; + } + static LogicalResult foldTrait(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + } }; //===----------------------------------------------------------------------===// @@ -974,6 +989,26 @@ } }; +/// This class adds the property that the operation is an involution. +/// This means a unary to unary operation "f" that satisfies f(f(x)) = x +template +class IsInvolution : public TraitBase { +public: + static LogicalResult verifyTrait(Operation *op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to produce one result"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to take one operand"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to preserve type"); + return impl::verifyIsInvolution(op); + } + + static OpFoldResult foldTrait(Operation *op, ArrayRef operands) { + return impl::foldInvolution(op); + } +}; + /// This class verifies that all operands of the specified op have a float type, /// a vector thereof, or a tensor thereof. template @@ -1311,6 +1346,19 @@ failed(cast(op).verify())); } + /// This is the hook that tries to fold the given operation according to its + /// traits. foldHook only calls it when the op's own fold() did not succeed; + /// it delegates to the foldTrait hooks of the op's Traits in list order.
+ static OpFoldResult foldTraits(Operation *op, ArrayRef operands) { + return BaseFolder...>::foldTraits(op, operands); + } + + static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return BaseFolder...>::foldTraits(op, operands, + results); + } + // Returns the properties of an operation by combining the properties of the // traits of the op. static AbstractOperation::OperationProperties getOperationProperties() { @@ -1363,6 +1411,42 @@ } }; + template + struct BaseFolder; + + template + struct BaseFolder { + static OpFoldResult foldTraits(Operation *op, + ArrayRef operands) { + // Return the first non-null result produced by a trait folder. + if (auto result = First::foldTrait(op, operands)) + return result; + // Otherwise, try the folders of the remaining traits. + return BaseFolder::foldTraits(op, operands); + } + + static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + auto result = First::foldTrait(op, operands, results); + // Stop at the first trait folder that succeeds; otherwise keep looking. + if (succeeded(result)) + return result; + return BaseFolder::foldTraits(op, operands, results); + } + }; + + template + struct BaseFolder { + static OpFoldResult foldTraits(Operation *op, + ArrayRef operands) { + return {}; + } + static LogicalResult foldTraits(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + return failure(); + } + }; + template struct BaseProperties { static AbstractOperation::OperationProperties getTraitProperties() { return 0; diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/FoldInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" #include using namespace mlir; @@ -679,6 +680,17 @@ // Op Trait implementations //===----------------------------------------------------------------------===// +OpFoldResult OpTrait::impl::foldInvolution(Operation *op) { + auto *argumentOp = op->getOperand(0).getDefiningOp(); + + if (argumentOp && 
op->getName() == argumentOp->getName()) { + // Replace the outer involution's output with the inner one's input. + return argumentOp->getOperand(0); + } + + return {}; +} + LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) { if (op->getNumOperands() != 0) return op->emitOpError() << "requires zero operands"; @@ -720,6 +732,13 @@ return type; } +LogicalResult OpTrait::impl::verifyIsInvolution(Operation *op) { + if (!MemoryEffectOpInterface::hasNoEffect(op)) { + return op->emitOpError() << "requires operation to have no side effects"; + } + return success(); +} + LogicalResult OpTrait::impl::verifyOperandsAreSignlessIntegerLike(Operation *op) { for (auto opType : op->getOperandTypes()) { diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -1,6 +1,7 @@ set(LLVM_OPTIONAL_SOURCES TestDialect.cpp TestPatterns.cpp + TestTraits.cpp ) set(LLVM_TARGET_DEFINITIONS TestInterfaces.td) @@ -23,6 +24,7 @@ add_mlir_library(MLIRTestDialect TestDialect.cpp TestPatterns.cpp + TestTraits.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -789,6 +789,22 @@ let results = (outs I32); } +def TestInvolutionTraitFoldOp + : TEST_Op<"op_involution_trait_fold", + [SameOperandsAndResultType, NoSideEffect, Involution]> { + let arguments = (ins I32:$op1); + let results = (outs I32); + let hasFolder = 1; +} + +def TestInvolutionOperationFoldOp + : TEST_Op<"op_involution_operation_fold", + [SameOperandsAndResultType, NoSideEffect, Involution]> { + let arguments = (ins I32:$op1); + let results = (outs I32); + let hasFolder = 1; +} + def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> { let arguments = (ins I32); let results = (outs I32); diff --git 
a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -0,0 +1,43 @@ +//===- TestTraits.cpp - Test trait folding --------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Trait Folder. +//===----------------------------------------------------------------------===// + +OpFoldResult TestInvolutionTraitFoldOp::fold(ArrayRef operands) { + // This failure should cause the trait fold to run instead. + return {}; +} + +OpFoldResult TestInvolutionOperationFoldOp::fold(ArrayRef operands) { + auto argumentOp = getOperand(); + // The success case should cause the trait fold to be suppressed. + return argumentOp.getDefiningOp() ? 
argumentOp : OpFoldResult{}; +} + +namespace { +struct TestTraitFolder : public PassWrapper { + void runOnFunction() override { + applyPatternsAndFoldGreedily(getFunction(), {}); + } +}; +} // end anonymous namespace + +namespace mlir { +void registerTraitsTestPass() { + PassRegistration("test-trait-folder", "Run trait folding"); +} +} // namespace mlir diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/mlir-tblgen/trait.mlir @@ -0,0 +1,44 @@ +// RUN: mlir-opt -test-trait-folder %s | FileCheck %s + +//===----------------------------------------------------------------------===// +// Test that involutions fold correctly +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @testSingleInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testSingleInvolution(%arg0 : i32) -> i32 { + // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_fold"([[ARG0]]) + %0 = "test.op_involution_trait_fold"(%arg0) : (i32) -> i32 + // CHECK: return [[INVOLUTION]] + return %0: i32 +} + +// CHECK-LABEL: func @testDoubleInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testDoubleInvolution(%arg0: i32) -> i32 { + %0 = "test.op_involution_trait_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_trait_fold"(%0) : (i32) -> i32 + // CHECK: return [[ARG0]] + return %1: i32 +} + +// CHECK-LABEL: func @testTripleInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testTripleInvolution(%arg0: i32) -> i32 { + // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_fold"([[ARG0]]) + %0 = "test.op_involution_trait_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_trait_fold"(%0) : (i32) -> i32 + %2 = "test.op_involution_trait_fold"(%1) : (i32) -> i32 + // CHECK: return [[INVOLUTION]] + return %2: i32 +} + +// CHECK-LABEL: func @testInhibitInvolution +// CHECK-SAME: ([[ARG0:%.+]]: i32) +func @testInhibitInvolution(%arg0: i32) -> i32 { + // CHECK: 
[[OP:%.+]] = "test.op_involution_operation_fold"([[ARG0]]) + %0 = "test.op_involution_operation_fold"(%arg0) : (i32) -> i32 + %1 = "test.op_involution_operation_fold"(%0) : (i32) -> i32 + // CHECK: return [[OP]] + return %1: i32 +} diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -77,6 +77,7 @@ void registerTestSpirvEntryPointABIPass(); void registerTestSCFUtilsPass(); void registerTestVectorConversions(); +void registerTraitsTestPass(); void registerVectorizerTestPass(); } // namespace mlir @@ -133,6 +134,7 @@ registerTestSpirvEntryPointABIPass(); registerTestSCFUtilsPass(); registerTestVectorConversions(); + registerTraitsTestPass(); registerVectorizerTestPass(); } #endif