diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -314,6 +314,8 @@
 // corresponding trait classes. This avoids them being template
 // instantiated/duplicated.
 namespace impl {
+LogicalResult foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                              SmallVectorImpl<OpFoldResult> &results);
 OpFoldResult foldIdempotent(Operation *op);
 OpFoldResult foldInvolution(Operation *op);
 LogicalResult verifyZeroOperands(Operation *op);
@@ -1148,7 +1150,13 @@
 
 /// This class adds property that the operation is commutative.
 template <typename ConcreteType>
-class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
+class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
+public:
+  static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<OpFoldResult> &results) {
+    return impl::foldCommutative(op, operands, results);
+  }
+};
 
 /// This class adds property that the operation is an involution.
 /// This means a unary to unary operation "f" that satisfies f(f(x)) = x
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IRMapping.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/IR/PatternMatch.h"
@@ -790,6 +791,24 @@
 // Op Trait implementations
 //===----------------------------------------------------------------------===//
 
+LogicalResult
+OpTrait::impl::foldCommutative(Operation *op, ArrayRef<Attribute> operands,
+                               SmallVectorImpl<OpFoldResult> &results) {
+  // Nothing to fold if there are not at least 2 operands.
+  if (op->getNumOperands() < 2)
+    return failure();
+  // Move all constant operands to the end.
+  OpOperand *operandsBegin = op->getOpOperands().begin();
+  auto isNonConstant = [&](OpOperand &o) {
+    return !static_cast<bool>(operands[std::distance(operandsBegin, &o)]);
+  };
+  auto *firstConstantIt = llvm::find_if_not(op->getOpOperands(), isNonConstant);
+  auto *newConstantIt = std::stable_partition(
+      firstConstantIt, op->getOpOperands().end(), isNonConstant);
+  // Return success if the op was modified.
+  return success(firstConstantIt != newConstantIt);
+}
+
 OpFoldResult OpTrait::impl::foldIdempotent(Operation *op) {
   if (op->getNumOperands() == 1) {
     auto *argumentOp = op->getOperand(0).getDefiningOp();
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -217,21 +217,6 @@
                                          SmallVectorImpl<Value> &results) {
   SmallVector<Attribute, 8> operandConstants;
 
-  // If this is a commutative operation, move constants to be trailing operands.
-  bool updatedOpOperands = false;
-  if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
-    auto isNonConstant = [&](OpOperand &o) {
-      return !matchPattern(o.get(), m_Constant());
-    };
-    auto *firstConstantIt =
-        llvm::find_if_not(op->getOpOperands(), isNonConstant);
-    auto *newConstantIt = std::stable_partition(
-        firstConstantIt, op->getOpOperands().end(), isNonConstant);
-
-    // Remember if we actually moved anything.
-    updatedOpOperands = firstConstantIt != newConstantIt;
-  }
-
   // Check to see if any operands to the operation is constant and whether
   // the operation knows how to constant fold itself.
   operandConstants.assign(op->getNumOperands(), Attribute());
@@ -244,7 +229,7 @@
   SmallVector<OpFoldResult, 8> foldResults;
   if (failed(op->fold(operandConstants, foldResults)) ||
       failed(processFoldResults(op, results, foldResults)))
-    return success(updatedOpOperands);
+    return failure();
   return success();
 }
 
diff --git a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
--- a/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
+++ b/mlir/test/Conversion/FuncToLLVM/calling-convention.mlir
@@ -127,7 +127,7 @@
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
   // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[CALL_RES]][0] : !llvm.struct<(i64, ptr)>
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
@@ -159,7 +159,7 @@
 
   // CHECK: %[[PTR_SIZE:.*]] = llvm.mlir.constant
   // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]]
-  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]]
+  // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[RANK]], %[[TWO]]
   // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]]
   // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]]
   // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]]
diff --git a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
--- a/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
+++ b/mlir/test/Conversion/TensorToLinalg/tensor-ops-to-linalg.mlir
@@ -27,7 +27,7 @@
 // CHECK-DAG:   %[[C1:.*]] = arith.constant 1 : index
 // CHECK:       %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32>
 // CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
-// CHECK:       %[[OUT_DIM2:.*]] = arith.addi %[[C2]], %[[OFFSET]] : index
+// CHECK:       %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index
 // CHECK-DAG:   %[[C3:.*]] = arith.constant 3 : index
 // CHECK:       %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32>
 // CHECK:       %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index