diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -159,7 +159,9 @@ Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + if (ConstantOp::isBuildableWith(value, type)) + return builder.create(loc, type, value); + return nullptr; } void mlir::printDimAndSymbolList(Operation::operand_iterator begin, diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp --- a/mlir/lib/Transforms/SCCP.cpp +++ b/mlir/lib/Transforms/SCCP.cpp @@ -22,6 +22,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/Passes.h" +#include "llvm/ADT/SetVector.h" using namespace mlir; @@ -47,11 +48,12 @@ public: /// Initialize a lattice value with "Unknown". - LatticeValue() - : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {} + LatticeValue() : constantAndTag(nullptr, Kind::Unknown) {} /// Initialize a lattice value with a constant. LatticeValue(Attribute attr, Dialect *dialect) - : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {} + : constantAndTag(attr, Kind::Constant) { + constantDialects.insert(dialect); + } /// Returns true if this lattice value is unknown. bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; } @@ -59,7 +61,7 @@ /// Mark the lattice value as overdefined. void markOverdefined() { constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined); - constantDialect = nullptr; + constantDialects.clear(); } /// Returns true if the lattice is overdefined. @@ -70,7 +72,7 @@ /// Mark the lattice value as constant.
void markConstant(Attribute value, Dialect *dialect) { constantAndTag.setPointerAndInt(value, Kind::Constant); - constantDialect = dialect; + constantDialects.insert(dialect); } /// If this lattice is constant, return the constant. Returns nullptr @@ -79,9 +81,15 @@ /// If this lattice is constant, return the dialect to use when materializing /// the constant. - Dialect *getConstantDialect() const { + ArrayRef getConstantDialects() const { assert(getConstant() && "expected valid constant"); - return constantDialect; + return constantDialects.getArrayRef(); + } + + /// Add the given dialect to the set of constant source dialects. + void addConstantDialect(Dialect *dialect) { + assert(getConstant() && "expected valid constant"); + constantDialects.insert(dialect); } /// Merge in the value of the 'rhs' lattice into this one. Returns true if the @@ -93,7 +101,7 @@ // If we are unknown, just take the value of rhs. if (isUnknown()) { constantAndTag = rhs.constantAndTag; - constantDialect = rhs.constantDialect; + constantDialects.set_union(rhs.constantDialects); return true; } @@ -102,6 +110,7 @@ markOverdefined(); return true; } + constantDialects.set_union(rhs.constantDialects); return false; } @@ -110,10 +119,12 @@ /// kind. llvm::PointerIntPair constantAndTag; - /// The dialect the constant originated from. This is only valid if the + /// The dialects the constant originated from. This is only valid if the /// lattice is a constant. This is not used as part of the key, and is only /// needed to materialize the held constant if necessary. - Dialect *constantDialect; + llvm::SetVector, + SmallPtrSet> + constantDialects; }; /// This class contains various state used when computing the lattice of a @@ -278,9 +289,17 @@ /// Returns true if the given value was marked as overdefined. bool isOverdefined(Value value) const; - /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' - /// corresponds to the parent operation of 'to'.
- void meet(Operation *owner, LatticeValue &to, const LatticeValue &from); + /// Merge in the given lattice 'from' into the lattice 'to'. 'op' corresponds + /// to the source operation of either 'to' or 'from' that should be added back + /// to the worklist if the meet results in a change. + void meet(Operation *op, LatticeValue &to, const LatticeValue &from); + + /// Merge in the given lattice 'from' into the lattice 'to'. 'value' is the + /// value whose users should be visited if the meet changed 'to'. + /// 'dialect' corresponds to the source dialect of 'to' or 'from'. This + /// is added to 'to' if it resolves to a constant. + void meet(Value value, Dialect *dialect, LatticeValue &to, + const LatticeValue &from); /// The lattice for each SSA value. DenseMap latticeValues; @@ -457,15 +476,16 @@ return failure(); // Attempt to materialize a constant for the given value. - Dialect *dialect = it->second.getConstantDialect(); - Value constant = folder.getOrCreateConstant(builder, dialect, attr, - value.getType(), value.getLoc()); - if (!constant) - return failure(); - - value.replaceAllUsesWith(constant); - latticeValues.erase(it); - return success(); + // Try each source dialect in order until one can materialize the constant. + for (Dialect *dialect : it->second.getConstantDialects()) { + if (Value constant = folder.getOrCreateConstant( + builder, dialect, attr, value.getType(), value.getLoc())) { + value.replaceAllUsesWith(constant); + latticeValues.erase(it); + return success(); + } + } + return failure(); } void SCCPSolver::visitOperation(Operation *op) { @@ -537,6 +557,8 @@ Dialect *opDialect = op->getDialect(); for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { LatticeValue &resultLattice = latticeValues[op->getResult(i)]; + if (resultLattice.isOverdefined()) + continue; // Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = foldResults[i]; @@ -592,8 +614,8 @@ auto callableArgs = callableLatticeIt->second.getCallableArguments(); for (auto it : llvm::zip(callOperands, callableArgs)) { BlockArgument callableArg = std::get<1>(it); - if (latticeValues[callableArg].meet(latticeValues[std::get<0>(it)])) - visitUsers(callableArg); + meet(callableArg, op.getOperation()->getDialect(), + latticeValues[callableArg], latticeValues[std::get<0>(it)]); } // Merge in the lattice state for the callable results as well. @@ -703,8 +725,8 @@ OperandRange succOperands = getInputsForRegion(region->getRegionNumber()); for (auto it : llvm::zip(succArgs, succOperands)) { LatticeValue &argLattice = latticeValues[std::get<0>(it)]; - if (argLattice.meet(latticeValues[std::get<1>(it)])) - visitUsers(std::get<0>(it)); + meet(std::get<0>(it), parentOp->getDialect(), argLattice, + latticeValues[std::get<1>(it)]); } } } @@ -888,10 +910,29 @@ return it != latticeValues.end() && it->second.isOverdefined(); } -void SCCPSolver::meet(Operation *owner, LatticeValue &to, +void SCCPSolver::meet(Operation *op, LatticeValue &to, const LatticeValue &from) { + if (to.isOverdefined()) + return; if (to.meet(from)) - opWorklist.push_back(owner); + opWorklist.push_back(op); + + // If the value resolved to a constant, merge in the dialect from op. + if (to.getConstant()) + to.addConstantDialect(op->getDialect()); +} + +void SCCPSolver::meet(Value value, Dialect *dialect, LatticeValue &to, + const LatticeValue &from) { + bool changed = to.meet(from); + + // If the value resolved to a constant, merge in the dialect. + if (to.getConstant()) + to.addConstantDialect(dialect); + + // If 'to' changed, visit the users of the given value.
+ if (changed) + visitUsers(value); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Transforms/sccp-callgraph.mlir b/mlir/test/Transforms/sccp-callgraph.mlir --- a/mlir/test/Transforms/sccp-callgraph.mlir +++ b/mlir/test/Transforms/sccp-callgraph.mlir @@ -255,3 +255,28 @@ %res = call_indirect %fn() : () -> (i32) return %res : i32 } + +// ----- + +/// Check that calls with implicit conversion have constants properly +/// propagated. + +// CHECK-LABEL: func @conversion_callee +func @conversion_callee(%arg : i32) -> (i32, i32) attributes { sym_visibility = "private" } { + // CHECK-DAG: %[[CONV_CST:.*]] = "test.constant" + // CHECK-DAG: %[[CST:.*]] = constant + // CHECK: return %[[CONV_CST]], %[[CST]] + + %test_input = constant 0 : i32 + return %arg, %test_input : i32, i32 +} + +// CHECK-LABEL: func @call_with_conversion +func @call_with_conversion() -> (!foo.i16, !foo.i16) { + // CHECK: %[[CONV_CST:.*]] = "test.constant" + // CHECK: return %[[CONV_CST]], %[[CONV_CST]] + + %test_input = "test.constant"() {value = 0 : i32} : () -> (!foo.i16) + %res:2 = "test.conversion_call_op"(%test_input) { callee=@conversion_callee } : (!foo.i16) -> (!foo.i16, !foo.i16) + return %res#0, %res#1 : !foo.i16, !foo.i16 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -138,6 +138,11 @@ allowUnknownOperations(); } +Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value, + Type type, Location loc) { + return builder.create(loc, type, value); +} + LogicalResult TestDialect::verifyOperationAttribute(Operation *op, NamedAttribute namedAttr) { if (namedAttr.first == "test.invalid_attr") @@ -174,6 +179,14 @@ bool TestBranchOp::canEraseSuccessorOperand() { return true; } +//===----------------------------------------------------------------------===// +//
TestConstantOp +//===----------------------------------------------------------------------===// + +OpFoldResult TestConstantOp::fold(ArrayRef operands) { + return value(); +} + //===----------------------------------------------------------------------===// // Test IsolatedRegionOp - parse passthrough region arguments. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -21,6 +21,7 @@ def Test_Dialect : Dialect { let name = "test"; let cppNamespace = ""; + let hasConstantMaterializer = 1; let hasOperationAttrVerify = 1; let hasRegionArgAttrVerify = 1; let hasRegionResultAttrVerify = 1; @@ -29,6 +30,16 @@ class TEST_Op traits = []> : Op; +//===----------------------------------------------------------------------===// +// Test Constant +//===----------------------------------------------------------------------===// + +def TestConstantOp : TEST_Op<"constant", [ConstantLike]> { + let arguments = (ins AnyAttr:$value); + let results = (outs AnyType:$result); + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // Test Types //===----------------------------------------------------------------------===//