diff --git a/mlir/docs/Traits.md b/mlir/docs/Traits.md
--- a/mlir/docs/Traits.md
+++ b/mlir/docs/Traits.md
@@ -56,6 +56,26 @@
 `verifyTrait` hook out-of-line as a free function when possible to avoid
 instantiating the implementation for every concrete operation type.
 
+Operation traits may also provide a `foldTrait` hook that is called when
+folding the concrete operation. The trait folders will currently only be
+invoked if the concrete operation fold does not succeed (or is not implemented).
+
+```c++
+template <typename ConcreteType>
+class MyTrait : public OpTrait::TraitBase<ConcreteType, MyTrait> {
+public:
+  /// Override the 'foldTrait' hook to support trait-based folding on the
+  /// concrete operation.
+  static OpFoldResult foldTrait(Operation *op) {
+    // ...
+  }
+};
+```
+
+Note: It is generally good practice to define the implementation of the
+`foldTrait` hook out-of-line as a free function when possible to avoid
+instantiating the implementation for every concrete operation type.
+
 ### Parametric Traits
 
 The above demonstrates the definition of a simple self-contained trait. It is
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1709,6 +1709,8 @@
   NativeOpTrait<"ResultsBroadcastableShape">;
 // X op Y == Y op X
 def Commutative : NativeOpTrait<"IsCommutative">;
+// op op X == X
+def Involution : NativeOpTrait<"IsInvolution">;
 // Op behaves like a constant.
 def ConstantLike : NativeOpTrait<"ConstantLike">;
 // Op behaves like a function.
@@ -1717,6 +1719,10 @@
 def IsolatedFromAbove : NativeOpTrait<"IsIsolatedFromAbove">;
 // Op results are float or vectors/tensors thereof.
 def ResultsAreFloatLike : NativeOpTrait<"ResultsAreFloatLike">;
+// Op has a single result.
+def OneResult : NativeOpTrait<"OneResult">;
+// Op has a single operand.
+def OneOperand : NativeOpTrait<"OneOperand">;
 // Op has the same operand type.
 def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
 // Op has same shape for all operands.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -326,6 +326,8 @@
   static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                 SmallVectorImpl<OpFoldResult> &results) {
     auto result = cast<ConcreteType>(op).fold(operands);
+    if (!result)
+      result = ConcreteType::foldTraits(op);
     if (!result)
       return failure();
 
@@ -370,6 +372,7 @@
 // corresponding trait classes. This avoids them being template
 // instantiated/duplicated.
 namespace impl {
+OpFoldResult foldInvolution(Operation *op);
 LogicalResult verifyZeroOperands(Operation *op);
 LogicalResult verifyOneOperand(Operation *op);
 LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
@@ -426,6 +429,7 @@
   static AbstractOperation::OperationProperties getTraitProperties() {
     return 0;
   }
+  static OpFoldResult foldTrait(Operation *op) { return {}; }
 };
 
 //===----------------------------------------------------------------------===//
@@ -974,6 +978,26 @@
   }
 };
 
+/// This class adds the property that the operation is an involution.
+/// This means a unary to unary operation "f" that satisfies f(f(x)) = x.
+template <typename ConcreteType>
+class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
+public:
+  static LogicalResult verifyTrait(Operation *op) {
+    static_assert(ConcreteType::template hasTrait<OneResult>(),
+                  "expected operation to produce one result");
+    static_assert(ConcreteType::template hasTrait<OneOperand>(),
+                  "expected operation to take one operand");
+    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
+                  "expected operation to preserve type");
+    return success();
+  }
+
+  static OpFoldResult foldTrait(Operation *op) {
+    return impl::foldInvolution(op);
+  }
+};
+
 /// This class verifies that all operands of the specified op have a float type,
 /// a vector thereof, or a tensor thereof.
 template <typename ConcreteType>
@@ -1311,6 +1335,13 @@
         failed(cast<ConcreteType>(op).verify()));
   }
 
+  /// This is the hook that tries to fold the given operation according to its
+  /// traits. It is invoked when the operation's own fold() yields no result,
+  /// and returns the result of the first trait whose foldTrait hook succeeds.
+  static OpFoldResult foldTraits(Operation *op) {
+    return BaseFolder<Traits<ConcreteType>...>::foldTraits(op);
+  }
+
   // Returns the properties of an operation by combining the properties of the
   // traits of the op.
   static AbstractOperation::OperationProperties getOperationProperties() {
@@ -1363,6 +1394,23 @@
     }
   };
 
+  /// Folds an operation with the first trait in `Types` whose foldTrait hook
+  /// returns a non-null result; yields a null result if no trait folds it.
+  template <typename... Types> struct BaseFolder;
+
+  template <typename First, typename... Rest>
+  struct BaseFolder<First, Rest...> {
+    static OpFoldResult foldTraits(Operation *op) {
+      if (auto result = First::foldTrait(op))
+        return result;
+      return BaseFolder<Rest...>::foldTraits(op);
+    }
+  };
+
+  template <typename...> struct BaseFolder {
+    static OpFoldResult foldTraits(Operation *op) { return {}; }
+  };
+
   template <typename...> struct BaseProperties {
     static AbstractOperation::OperationProperties getTraitProperties() {
       return 0;
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -679,6 +679,17 @@
 // Op Trait implementations
 //===----------------------------------------------------------------------===//
 
+OpFoldResult OpTrait::impl::foldInvolution(Operation *op) {
+  // For an involution f, f(f(x)) == x: fold when the operand is itself
+  // produced by the same kind of operation.
+  Operation *argumentOp = op->getOperand(0).getDefiningOp();
+  if (argumentOp && op->getName() == argumentOp->getName()) {
+    // Replace the outer involution's output with the inner's input.
+    return argumentOp->getOperand(0);
+  }
+  return {};
+}
+
 LogicalResult OpTrait::impl::verifyZeroOperands(Operation *op) {
   if (op->getNumOperands() != 0)
     return op->emitOpError() << "requires zero operands";
diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Test/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt
@@ -1,6 +1,7 @@
 set(LLVM_OPTIONAL_SOURCES
   TestDialect.cpp
   TestPatterns.cpp
+  TestTraits.cpp
 )
 
 set(LLVM_TARGET_DEFINITIONS TestInterfaces.td)
@@ -23,6 +24,7 @@
 add_mlir_library(MLIRTestDialect
   TestDialect.cpp
   TestPatterns.cpp
+  TestTraits.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -789,6 +789,20 @@
   let results = (outs I32);
 }
 
+def TestInvolutionTraitFoldOp : TEST_Op<"op_involution_trait_fold",
+    [OneResult, OneOperand, SameOperandsAndResultType, Involution]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
+def TestInvolutionOperationFoldOp : TEST_Op<"op_involution_operation_fold",
+    [OneResult, OneOperand, SameOperandsAndResultType, Involution]> {
+  let arguments = (ins I32:$op1);
+  let results = (outs I32);
+  let hasFolder = 1;
+}
+
 def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
   let arguments = (ins I32);
   let results = (outs I32);
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -0,0 +1,44 @@
+//===- TestTraits.cpp - Test trait folding --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/FoldUtils.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Trait Folder.
+//===----------------------------------------------------------------------===//
+
+OpFoldResult TestInvolutionTraitFoldOp::fold(ArrayRef<Attribute> operands) {
+  // This failure should cause the trait fold to run instead.
+  return {};
+}
+
+OpFoldResult TestInvolutionOperationFoldOp::fold(ArrayRef<Attribute> operands) {
+  Value argument = getOperand();
+  // This success case should cause the trait fold to be suppressed.
+  return argument.getDefiningOp() ? argument : OpFoldResult{};
+}
+
+namespace {
+
+struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
+  void runOnFunction() override {
+    applyPatternsAndFoldGreedily(getFunction(), {});
+  }
+};
+} // end anonymous namespace
+
+namespace mlir {
+void registerTraitsTestPass() {
+  PassRegistration<TestTraitFolder>("test-trait-folder", "Run trait folding");
+}
+} // namespace mlir
diff --git a/mlir/test/mlir-tblgen/trait.mlir b/mlir/test/mlir-tblgen/trait.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-tblgen/trait.mlir
@@ -0,0 +1,44 @@
+// RUN: mlir-opt -test-trait-folder %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// Test that involutions fold correctly
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @testSingleInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testSingleInvolution(%arg0: i32) -> i32 {
+  // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_fold"([[ARG0]])
+  %0 = "test.op_involution_trait_fold"(%arg0) : (i32) -> i32
+  // CHECK: return [[INVOLUTION]]
+  return %0: i32
+}
+
+// CHECK-LABEL: func @testDoubleInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testDoubleInvolution(%arg0: i32) -> i32 {
+  %0 = "test.op_involution_trait_fold"(%arg0) : (i32) -> i32
+  %1 = "test.op_involution_trait_fold"(%0) : (i32) -> i32
+  return %1: i32
+  // CHECK: return [[ARG0]]
+}
+
+// CHECK-LABEL: func @testTripleInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testTripleInvolution(%arg0: i32) -> i32 {
+  // CHECK: [[INVOLUTION:%.+]] = "test.op_involution_trait_fold"([[ARG0]])
+  %0 = "test.op_involution_trait_fold"(%arg0) : (i32) -> i32
+  %1 = "test.op_involution_trait_fold"(%0) : (i32) -> i32
+  %2 = "test.op_involution_trait_fold"(%1) : (i32) -> i32
+  // CHECK: return [[INVOLUTION]]
+  return %2: i32
+}
+
+// CHECK-LABEL: func @testInhibitInvolution
+// CHECK-SAME: ([[ARG0:%.+]]: i32)
+func @testInhibitInvolution(%arg0: i32) -> i32 {
+  // CHECK: [[OP:%.+]] = "test.op_involution_operation_fold"([[ARG0]])
+  %0 = "test.op_involution_operation_fold"(%arg0) : (i32) -> i32
+  %1 = "test.op_involution_operation_fold"(%0) : (i32) -> i32
+  return %1: i32
+  // CHECK: return [[OP]]
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@
 void registerTestSpirvEntryPointABIPass();
 void registerTestSCFUtilsPass();
 void registerTestVectorConversions();
+void registerTraitsTestPass();
 void registerVectorizerTestPass();
 } // namespace mlir
 
@@ -133,6 +134,7 @@
   registerTestSpirvEntryPointABIPass();
   registerTestSCFUtilsPass();
   registerTestVectorConversions();
+  registerTraitsTestPass();
   registerVectorizerTestPass();
 }
 #endif