diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -720,16 +720,30 @@ ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands(); SmallVector lhsOperandStorage, rhsOperandStorage; if (lhs->hasTrait()) { - lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end()); - llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - }); + auto sortValues = [](ValueRange values) { + SmallVector blockArgs; + SmallVector sortedValues; + for (auto value : values) { + if (auto argValue = value.dyn_cast()) + blockArgs.push_back(argValue); + else + sortedValues.push_back(value); + } + llvm::sort(blockArgs, [](BlockArgument a, BlockArgument b) { + if (a.getParentBlock() == b.getParentBlock()) { + return a.getArgNumber() < b.getArgNumber(); + } + return a.getParentBlock() < b.getParentBlock(); + }); + llvm::sort(sortedValues, [](Value a, Value b) { + return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + }); + sortedValues.append(blockArgs.begin(), blockArgs.end()); + return sortedValues; + }; + lhsOperandStorage = sortValues(lhsOperands); lhsOperands = lhsOperandStorage; - - rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end()); - llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - }); + rhsOperandStorage = sortValues(rhsOperands); rhsOperands = rhsOperandStorage; } auto checkValueRangeMapping = diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -14,6 +14,7 @@ #include "mlir/Transforms/Passes.h" #include "mlir/IR/Dominance.h" +#include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "llvm/ADT/DenseMapInfo.h" @@ -47,11 +48,71 @@ if (lhs == getTombstoneKey() || 
lhs == getEmptyKey() || rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; + + // Check for equivalence of ops if they have a single region with a single + // block. If op has no regions, nothing else to do. + if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) { + return OperationEquivalence::isEquivalentTo( + const_cast(lhsC), const_cast(rhsC), + OperationEquivalence::exactValueMatch, + OperationEquivalence::ignoreValueEquivalence, + OperationEquivalence::IgnoreLocations); + } + + // If lhs or rhs does not have a single region with a single block, they + // arent CSEed for now. + if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 || + !llvm::hasSingleElement(lhs->getRegion(0)) || + !llvm::hasSingleElement(rhs->getRegion(0))) + return false; + + // Compare the two blocks. + Block &lhsBlock = lhs->getRegion(0).front(); + Block &rhsBlock = rhs->getRegion(0).front(); + + // If number of arguments differ, not CSEed + if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) + return false; + + // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in + // `rhsBlock`. `Value`s from `lhsBlock` are the key. + llvm::DenseMap areEquivalentValues; + for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(), + rhs->getRegion(0).getArguments())) { + areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs); + } + + // Helper function to get the parent operation. + auto getParent = [](Value v) -> Operation * { + if (auto blockArg = v.dyn_cast()) + return blockArg.getParentBlock()->getParentOp(); + return v.getDefiningOp()->getParentOp(); + }; + + // Callback to compare if operands of ops in the region of `lhs` and `rhs` + // are equivalent. 
+ auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult { + if (lhsValue == rhsValue) + return success(); + if (getParent(lhsValue) == lhs && getParent(rhsValue) == rhs && + areEquivalentValues.lookup(lhsValue) == rhsValue) + return success(); + return failure(); + }; + + // Callback to compare if results of ops in the region of `lhs` and `rhs` + // are equivalent. + auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult { + if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) { + auto insertion = areEquivalentValues.insert({lhsResult, rhsResult}); + return success(insertion.first->second == rhsResult); + } + return success(); + }; + return OperationEquivalence::isEquivalentTo( const_cast(lhsC), const_cast(rhsC), - /*mapOperands=*/OperationEquivalence::exactValueMatch, - /*mapResults=*/OperationEquivalence::ignoreValueEquivalence, - OperationEquivalence::IgnoreLocations); + mapOperands, mapResults, OperationEquivalence::IgnoreLocations); } }; } // namespace @@ -204,7 +265,8 @@ // Don't simplify operations with nested blocks. We don't currently model // equality comparisons correctly among other things. It is also unclear // whether we would want to CSE such operations. - if (op->getNumRegions() != 0) + if (!(op->getNumRegions() == 0 || + (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0))))) return failure(); // Some simple use case of operation with memory side-effect are dealt with diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -322,3 +322,127 @@ %3 = arith.muli %1, %2 : i32 return %3 : i32 } + +// Check that an operation with a single region can CSE. 
+// CHECK-LABEL: func @cse_single_block_ops +func.func @cse_single_block_ops(%a : tensor, %b : tensor) + -> (tensor, tensor) { + // CHECK: %[[OP:.+]] = test.cse_of_single_block_op + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + // CHECK-NOT: test.cse_of_single_block_op + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + // CHECK: return %[[OP]], %[[OP]] + return %0, %1 : tensor, tensor +} + +// Operations with a different number of block arguments don't CSE. +// CHECK-LABEL: func @no_cse_varied_bbargs +func.func @no_cse_varied_bbargs(%a : tensor, %b : tensor) + -> (tensor, tensor) { + // CHECK: %[[OP0:.+]] = test.cse_of_single_block_op + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + // CHECK: return %[[OP0]], %[[OP1]] + return %0, %1 : tensor, tensor +} + +// Operations whose regions differ don't CSE. +// CHECK-LABEL: func @no_cse_region_difference_simple +func.func @no_cse_region_difference_simple(%a : tensor, %b : tensor) + -> (tensor, tensor) { + // CHECK: %[[OP0:.+]] = test.cse_of_single_block_op + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + test.region_yield %arg1 : f32 + } : tensor, tensor -> tensor + // CHECK: return %[[OP0]], %[[OP1]] + return %0, %1 : tensor, tensor +} + +// Operations with identical multi-statement regions CSE. 
+// CHECK-LABEL: func @cse_single_block_ops_identical_bodies +func.func @cse_single_block_ops_identical_bodies(%a : tensor, %b : tensor, %c : f32, %d : i1) + -> (tensor, tensor) { + // CHECK: %[[OP:.+]] = test.cse_of_single_block_op + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %1, %2 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + // CHECK-NOT: test.cse_of_single_block_op + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %1, %2 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + // CHECK: return %[[OP]], %[[OP]] + return %0, %1 : tensor, tensor +} + +// Operations with non-identical regions don't CSE. +// CHECK-LABEL: func @no_cse_single_block_ops_different_bodies +func.func @no_cse_single_block_ops_different_bodies(%a : tensor, %b : tensor, %c : f32, %d : i1) + -> (tensor, tensor) { + // CHECK: %[[OP0:.+]] = test.cse_of_single_block_op + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %1, %2 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %2, %1 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + // CHECK: return %[[OP0]], %[[OP1]] + return %0, %1 : tensor, tensor +} + +// Account for commutative ops within regions during CSE. 
+// CHECK-LABEL: func @cse_single_block_with_commutative_ops +func.func @cse_single_block_with_commutative_ops(%a : tensor, %b : tensor, %c : f32) + -> (tensor, tensor) { + // CHECK: %[[OP:.+]] = test.cse_of_single_block_op + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.addf %arg0, %arg1 : f32 + %2 = arith.mulf %1, %c : f32 + test.region_yield %2 : f32 + } : tensor, tensor -> tensor + // CHECK-NOT: test.cse_of_single_block_op + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.addf %arg1, %arg0 : f32 + %2 = arith.mulf %c, %1 : f32 + test.region_yield %2 : f32 + } : tensor, tensor -> tensor + // CHECK: return %[[OP]], %[[OP]] + return %0, %1 : tensor, tensor +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -707,8 +707,8 @@ // Produces an error value on the error path def TestInternalBranchOp : TEST_Op<"internal_br", - [DeclareOpInterfaceMethods, Terminator, - AttrSizedOperandSegments]> { + [DeclareOpInterfaceMethods, Terminator, + AttrSizedOperandSegments]> { let arguments = (ins Variadic:$successOperands, Variadic:$errorOperands); @@ -3014,4 +3014,20 @@ let assemblyFormat = "attr-dict $value"; } + +//===---------------------------------------------------------------------===// +// Test CSE of ops with a single-block region +//===---------------------------------------------------------------------===// + +def TestCseOfSingleBlockOp : TEST_Op<"cse_of_single_block_op", + [SingleBlockImplicitTerminator<"RegionYieldOp">, NoSideEffect]> { + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); + let regions = (region SizedRegion<1>:$region); + let assemblyFormat = [{ + attr-dict `inputs` `(` $inputs `)` + $region `:` type($inputs) `->` type($outputs) + }]; +} + #endif // TEST_OPS