diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -721,16 +721,34 @@ ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands(); SmallVector lhsOperandStorage, rhsOperandStorage; if (lhs->hasTrait()) { - lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end()); - llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - }); - lhsOperands = lhsOperandStorage; + auto sortValues = [](ValueRange values) { + SmallVector sortedValues = llvm::to_vector(values); + llvm::sort(sortedValues, [](Value a, Value b) { + auto aArg = a.dyn_cast(); + auto bArg = b.dyn_cast(); + + // Case 1. Both `a` and `b` are `BlockArgument`s. + if (aArg && bArg) { + if (aArg.getParentBlock() == bArg.getParentBlock()) + return aArg.getArgNumber() < bArg.getArgNumber(); + return aArg.getParentBlock() < bArg.getParentBlock(); + } - rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end()); - llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - }); + // Case 2. One of them is a `BlockArgument` and other is not. Treat + // `BlockArgument` as lesser. + if (aArg && !bArg) + return true; + if (bArg && !aArg) + return false; + + // Case 3. Neither is a `BlockArgument`. 
+ return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + }); + return sortedValues; + }; + lhsOperandStorage = sortValues(lhsOperands); + lhsOperands = lhsOperandStorage; + rhsOperandStorage = sortValues(rhsOperands); rhsOperands = rhsOperandStorage; } auto checkValueRangeMapping = diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -47,11 +47,70 @@ if (lhs == getTombstoneKey() || lhs == getEmptyKey() || rhs == getTombstoneKey() || rhs == getEmptyKey()) return false; + + // If op has no regions, operation equivalence w.r.t operands alone is + // enough. + if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) { + return OperationEquivalence::isEquivalentTo( + const_cast(lhsC), const_cast(rhsC), + OperationEquivalence::exactValueMatch, + OperationEquivalence::ignoreValueEquivalence, + OperationEquivalence::IgnoreLocations); + } + + // If lhs or rhs does not have a single region with a single block, they + // aren't CSEed for now. + if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 || + !llvm::hasSingleElement(lhs->getRegion(0)) || + !llvm::hasSingleElement(rhs->getRegion(0))) + return false; + + // Compare the two blocks. + Block &lhsBlock = lhs->getRegion(0).front(); + Block &rhsBlock = rhs->getRegion(0).front(); + + // Don't CSE if number of arguments differ. + if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments()) + return false; + + // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in + // `rhsBlock`. `Value`s from `lhsBlock` are the key. + DenseMap areEquivalentValues; + for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(), + rhs->getRegion(0).getArguments())) { + areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs); + } + + // Helper function to get the parent operation. 
+ auto getParent = [](Value v) -> Operation * { + if (auto blockArg = v.dyn_cast()) + return blockArg.getParentBlock()->getParentOp(); + return v.getDefiningOp()->getParentOp(); + }; + + // Callback to compare if operands of ops in the region of `lhs` and `rhs` + // are equivalent. + auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult { + if (lhsValue == rhsValue) + return success(); + if (areEquivalentValues.lookup(lhsValue) == rhsValue) + return success(); + return failure(); + }; + + // Callback to compare if results of ops in the region of `lhs` and `rhs` + // are equivalent. + auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult { + if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) { + auto insertion = areEquivalentValues.insert({lhsResult, rhsResult}); + return success(insertion.first->second == rhsResult); + } + return success(); + }; + return OperationEquivalence::isEquivalentTo( const_cast(lhsC), const_cast(rhsC), - /*mapOperands=*/OperationEquivalence::exactValueMatch, - /*mapResults=*/OperationEquivalence::ignoreValueEquivalence, - OperationEquivalence::IgnoreLocations); + mapOperands, mapResults, OperationEquivalence::IgnoreLocations); } }; } // namespace @@ -204,7 +263,8 @@ // Don't simplify operations with nested blocks. We don't currently model // equality comparisons correctly among other things. It is also unclear // whether we would want to CSE such operations. 
- if (op->getNumRegions() != 0) + if (!(op->getNumRegions() == 0 || + (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0))))) return failure(); // Some simple use case of operation with memory side-effect are dealt with diff --git a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir --- a/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir +++ b/mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir @@ -17,7 +17,6 @@ // CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>) // CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64> // CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref -// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[T6]] : memref<16xf64>) // CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>) // CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex> // CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]] diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -322,3 +322,127 @@ %3 = arith.muli %1, %2 : i32 return %3 : i32 } + +// Check that an operation with a single region can CSE. +func.func @cse_single_block_ops(%a : tensor, %b : tensor) + -> (tensor, tensor) { + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} +// CHECK-LABEL: func @cse_single_block_ops +// CHECK: %[[OP:.+]] = test.cse_of_single_block_op +// CHECK-NOT: test.cse_of_single_block_op +// CHECK: return %[[OP]], %[[OP]] + +// Operations with different number of bbArgs don't CSE. 
+func.func @no_cse_varied_bbargs(%a : tensor, %b : tensor) + -> (tensor, tensor) { + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} +// CHECK-LABEL: func @no_cse_varied_bbargs +// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op +// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op +// CHECK: return %[[OP0]], %[[OP1]] + +// Operations with different regions don't CSE +func.func @no_cse_region_difference_simple(%a : tensor, %b : tensor) + -> (tensor, tensor) { + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + test.region_yield %arg0 : f32 + } : tensor, tensor -> tensor + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + test.region_yield %arg1 : f32 + } : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} +// CHECK-LABEL: func @no_cse_region_difference_simple +// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op +// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op +// CHECK: return %[[OP0]], %[[OP1]] + +// Operation with identical region with multiple statements CSE. 
+func.func @cse_single_block_ops_identical_bodies(%a : tensor, %b : tensor, %c : f32, %d : i1) + -> (tensor, tensor) { + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %1, %2 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %1, %2 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} +// CHECK-LABEL: func @cse_single_block_ops_identical_bodies +// CHECK: %[[OP:.+]] = test.cse_of_single_block_op +// CHECK-NOT: test.cse_of_single_block_op +// CHECK: return %[[OP]], %[[OP]] + +// Operations with non-identical regions don't CSE. +func.func @no_cse_single_block_ops_different_bodies(%a : tensor, %b : tensor, %c : f32, %d : i1) + -> (tensor, tensor) { + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %1, %2 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.divf %arg0, %arg1 : f32 + %2 = arith.remf %arg0, %c : f32 + %3 = arith.select %d, %2, %1 : f32 + test.region_yield %3 : f32 + } : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} +// CHECK-LABEL: func @no_cse_single_block_ops_different_bodies +// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op +// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op +// CHECK: return %[[OP0]], %[[OP1]] + +// Account for commutative ops within regions during CSE. 
+func.func @cse_single_block_with_commutative_ops(%a : tensor, %b : tensor, %c : f32) + -> (tensor, tensor) { + %0 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.addf %arg0, %arg1 : f32 + %2 = arith.mulf %1, %c : f32 + test.region_yield %2 : f32 + } : tensor, tensor -> tensor + %1 = test.cse_of_single_block_op inputs(%a, %b) { + ^bb0(%arg0 : f32, %arg1 : f32): + %1 = arith.addf %arg1, %arg0 : f32 + %2 = arith.mulf %c, %1 : f32 + test.region_yield %2 : f32 + } : tensor, tensor -> tensor + return %0, %1 : tensor, tensor +} +// CHECK-LABEL: func @cse_single_block_with_commutative_ops +// CHECK: %[[OP:.+]] = test.cse_of_single_block_op +// CHECK-NOT: test.cse_of_single_block_op +// CHECK: return %[[OP]], %[[OP]] diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -670,8 +670,8 @@ // Produces an error value on the error path def TestInternalBranchOp : TEST_Op<"internal_br", - [DeclareOpInterfaceMethods, Terminator, - AttrSizedOperandSegments]> { + [DeclareOpInterfaceMethods, Terminator, + AttrSizedOperandSegments]> { let arguments = (ins Variadic:$successOperands, Variadic:$errorOperands); @@ -3045,4 +3045,19 @@ let regions = (region SizedRegion<1>:$body); } +//===---------------------------------------------------------------------===// +// Test CSE +//===---------------------------------------------------------------------===// + +def TestCSEOfSingleBlockOp : TEST_Op<"cse_of_single_block_op", + [SingleBlockImplicitTerminator<"RegionYieldOp">, Pure]> { + let arguments = (ins Variadic:$inputs); + let results = (outs Variadic:$outputs); + let regions = (region SizedRegion<1>:$region); + let assemblyFormat = [{ + attr-dict `inputs` `(` $inputs `)` + $region `:` type($inputs) `->` type($outputs) + }]; +} + #endif // TEST_OPS