diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -178,15 +178,15 @@ // splitting the Standard dialect. let results = (outs /*SignlessIntegerOrFloatLike*/AnyType:$result); - let builders = [ - OpBuilder<(ins "Attribute":$value, "Type":$type), - [{ build($_builder, $_state, type, value); }]>, - ]; - let extraClassDeclaration = [{ /// Whether the constant op can be constructed with a particular value and /// type. static bool isBuildableWith(Attribute value, Type type); + + /// Build the constant op with `value` and `type` if possible, otherwise + /// returns null. + static ConstantOp materialize(OpBuilder &builder, Attribute value, + Type type, Location loc); }]; let hasFolder = 1; diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -225,7 +225,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<arith::ConstantOp>(loc, type, value); + return arith::ConstantOp::materialize(builder, value, type, loc); } /// A utility function to check if a value is defined at the top level of an diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -49,5 +49,5 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<arith::ConstantOp>(loc, value, type); + return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ 
-185,6 +185,13 @@ return value.isa<IntegerAttr, FloatAttr, ElementsAttr>(); } +ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, + Type type, Location loc) { + if (isBuildableWith(value, type)) + return builder.create<ConstantOp>(loc, cast<TypedAttr>(value)); + return nullptr; +} + OpFoldResult arith::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, diff --git a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp --- a/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp @@ -36,9 +36,7 @@ return builder.create<complex::ConstantOp>(loc, type, value.cast<ArrayAttr>()); } - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create<arith::ConstantOp>(loc, type, value); - return nullptr; + return arith::ConstantOp::materialize(builder, value, type, loc); } #define GET_ATTRDEF_CLASSES diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -2109,5 +2109,5 @@ Operation *LinalgDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<arith::ConstantOp>(loc, type, value); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1688,8 +1688,8 @@ } // Create a constant scalar value from the splat constant.
- Value scalarConstant = rewriter.create<arith::ConstantOp>( - def->getLoc(), constantAttr, constantAttr.getType()); + Value scalarConstant = + rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr); SmallVector<Value> outputOperands = genericOp.getOutputs(); auto fusedOp = rewriter.create<GenericOp>( diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -522,5 +522,5 @@ Operation *math::MathDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<arith::ConstantOp>(loc, value, type); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -74,9 +74,7 @@ Operation *MemRefDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create<arith::ConstantOp>(loc, value, type); - return nullptr; + return arith::ConstantOp::materialize(builder, value, type, loc); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -154,9 +154,7 @@ return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>()); if (type.isa<WitnessType>()) return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>()); - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create<arith::ConstantOp>(loc, type, value); - return nullptr; + return arith::ConstantOp::materialize(builder, value, type, loc); } LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ 
b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -38,8 +38,8 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - if (arith::ConstantOp::isBuildableWith(value, type)) - return builder.create<arith::ConstantOp>(loc, value, type); + if (auto op = arith::ConstantOp::materialize(builder, value, type, loc)) + return op; if (complex::ConstantOp::isBuildableWith(value, type)) return builder.create<complex::ConstantOp>(loc, type, value.cast<ArrayAttr>()); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -280,7 +280,7 @@ Operation *VectorDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create<arith::ConstantOp>(loc, value, type); + return arith::ConstantOp::materialize(builder, value, type, loc); } IntegerType vector::getVectorSubscriptType(Builder &builder) {