Restrict `std.constant` to flat symbol references to functions.

Summary of the hunks below:

* Ops.td: `ConstantOp` now takes `FlatSymbolRefAttr:$value` instead of
  `AnyAttr:$value`, drops its two custom `OpBuilder`s, and uses the
  declarative format "attr-dict $value `:` type(results)". The op
  description and examples no longer show the complex-constant case.
* Ops.cpp: the hand-written printer and parser are deleted. `verify` looks
  up the referenced function by name and reports either "reference to
  undefined function" or "reference to function with mismatched type".
  `fold` returns `getValueAttr()`, and the result-name hint is always "f".
  The constant-building helper near line 115 (its name is not visible in
  this hunk) creates a `ConstantOp` only when `isBuildableWith` accepts the
  value, and otherwise returns nullptr.
* StandardToLLVM.cpp: `matchAndRewrite(ConstantOp ...)` no longer branches
  on the attribute kind. It converts the result type, fails if the type is
  not LLVM-compatible, builds the replacement op from `op.getValue()`, and
  copies every attribute except "value". The removed comment calls this
  "addressof"; the op's template argument is missing in this copy. The
  nested-reference failure path and the `oneToOneRewrite` fallback to
  `LLVM::ConstantOp` are removed.
* AsyncParallelFor.cpp, ElementwiseOpFusion.cpp,
  PolynomialApproximation.cpp, VectorMultiDimReductionTransforms.cpp and
  TestPatterns.cpp: only the template argument of a `create` or
  `replaceOpWithNewOp` call changes. That argument is missing in this
  copy. Presumably non-function constants move to `arith::ConstantOp`
  (TODO confirm against the original commit).
* TranslateToCpp.cpp: reads the attribute through `getValueAttr()`.
* Tests: invalid.mlir now expects "invalid kind of attribute specified" for
  `constant "" : index` (presumably the generated parser's error). The
  `constant unit` case is dropped from core-ops.mlir.

NOTE(review): this copy of the patch is damaged and will not apply or
compile as-is.
  * Line breaks and indentation have been collapsed.
  * Text inside `<...>` has been stripped wherever it looked like a markup
    tag, e.g. `: complex` (presumably `complex<f32>`), `dyn_cast()`,
    `isa()`, and the removed/added pairs that now read identically.
  Regenerate the patch from the source commit rather than hand-repairing it.
NOTE(review): `verify` calls `lookupSymbol` directly on the result of
`op->getParentOfType...()` without checking it. Confirm the op is always
nested in that ancestor (the removed code had the same pattern).
NOTE(review): the new description sentence ("... to a `builtin.func`
operation") is missing its terminating period.

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -390,23 +390,16 @@ operation ::= ssa-id `=` `std.constant` attribute-value `:` type ``` - The `constant` operation produces an SSA value equal to some constant - specified by an attribute. This is the way that MLIR uses to form simple - integer and floating point constants, as well as more exotic things like - references to functions and tensor/vector constants. + The `constant` operation produces an SSA value from a symbol reference to a + `builtin.func` operation Example: ```mlir - // Complex constant - %1 = constant [1.0 : f32, 1.0 : f32] : complex - // Reference to function @myfn. %2 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> // Equivalent generic forms - %1 = "std.constant"() {value = [1.0 : f32, 1.0 : f32] : complex} - : () -> complex %2 = "std.constant"() {value = @myfn} : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) ``` @@ -417,15 +410,9 @@ ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). 
}]; - let arguments = (ins AnyAttr:$value); + let arguments = (ins FlatSymbolRefAttr:$value); let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "Attribute":$value), - [{ build($_builder, $_state, value.getType(), value); }]>, - OpBuilder<(ins "Attribute":$value, "Type":$type), - [{ build($_builder, $_state, type, value); }]>, - ]; + let assemblyFormat = "attr-dict $value `:` type(results)"; let extraClassDeclaration = [{ /// Returns true if a constant operation can be built with the given value diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -435,31 +435,19 @@ LogicalResult matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // If constant refers to a function, convert it to "addressof". - if (auto symbolRef = op.getValue().dyn_cast()) { - auto type = typeConverter->convertType(op.getResult().getType()); - if (!type || !LLVM::isCompatibleType(type)) - return rewriter.notifyMatchFailure(op, "failed to convert result type"); - - auto newOp = rewriter.create(op.getLoc(), type, - symbolRef.getValue()); - for (const NamedAttribute &attr : op->getAttrs()) { - if (attr.getName().strref() == "value") - continue; - newOp->setAttr(attr.getName(), attr.getValue()); - } - rewriter.replaceOp(op, newOp->getResults()); - return success(); + auto type = typeConverter->convertType(op.getResult().getType()); + if (!type || !LLVM::isCompatibleType(type)) + return rewriter.notifyMatchFailure(op, "failed to convert result type"); + + auto newOp = + rewriter.create(op.getLoc(), type, op.getValue()); + for (const NamedAttribute &attr : op->getAttrs()) { + if (attr.getName().strref() == "value") + continue; + newOp->setAttr(attr.getName(), attr.getValue()); } - - // Calling into other scopes (non-flat reference) is not 
supported in LLVM. - if (op.getValue().isa()) - return rewriter.notifyMatchFailure( - op, "referring to a symbol outside of the current module"); - - return LLVM::detail::oneToOneRewrite( - op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), - *getTypeConverter(), rewriter); + rewriter.replaceOp(op, newOp->getResults()); + return success(); } }; diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -291,7 +291,7 @@ return llvm::to_vector( llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { if (IntegerAttr attr = std::get<1>(tuple)) - return b.create(attr); + return b.create(attr); return std::get<0>(tuple); })); }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1576,7 +1576,7 @@ isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) : DenseElementsAttr::get(outputType, intOutputValues); - rewriter.replaceOpWithNewOp(genericOp, outputAttr); + rewriter.replaceOpWithNewOp(genericOp, outputAttr); return success(); } diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -145,7 +145,7 @@ // Stitch results together into one large vector. 
Type resultEltType = results[0].getType().cast().getElementType(); Type resultExpandedType = VectorType::get(expandedShape, resultEltType); - Value result = builder.create( + Value result = builder.create( resultExpandedType, builder.getZeroAttr(resultExpandedType)); for (int64_t i = 0; i < maxLinearIndex; ++i) diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -115,7 +115,10 @@ Location loc) { if (arith::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, value); - return builder.create(loc, type, value); + if (ConstantOp::isBuildableWith(value, type)) + return builder.create(loc, type, + value.cast()); + return nullptr; } //===----------------------------------------------------------------------===// @@ -562,97 +565,35 @@ // ConstantOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ConstantOp &op) { - p << " "; - p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"}); - - if (op->getAttrs().size() > 1) - p << ' '; - p << op.getValue(); - - // If the value is a symbol reference, print a trailing type. - if (op.getValue().isa()) - p << " : " << op.getType(); -} - -static ParseResult parseConstantOp(OpAsmParser &parser, - OperationState &result) { - Attribute valueAttr; - if (parser.parseOptionalAttrDict(result.attributes) || - parser.parseAttribute(valueAttr, "value", result.attributes)) - return failure(); - - // If the attribute is a symbol reference, then we expect a trailing type. - Type type; - if (!valueAttr.isa()) - type = valueAttr.getType(); - else if (parser.parseColonType(type)) - return failure(); - - // Add the attribute type to the list. - return parser.addTypeToList(type, result.types); -} - -/// The constant op requires an attribute, and furthermore requires that it -/// matches the return type. 
-static LogicalResult verify(ConstantOp &op) { - auto value = op.getValue(); - if (!value) - return op.emitOpError("requires a 'value' attribute"); - +static LogicalResult verify(ConstantOp op) { + StringRef fnName = op.getValue(); Type type = op.getType(); - if (!value.getType().isa() && type != value.getType()) - return op.emitOpError() << "requires attribute's type (" << value.getType() - << ") to match op's return type (" << type << ")"; - - if (type.isa()) { - auto fnAttr = value.dyn_cast(); - if (!fnAttr) - return op.emitOpError("requires 'value' to be a function reference"); - - // Try to find the referenced function. - auto fn = - op->getParentOfType().lookupSymbol(fnAttr.getValue()); - if (!fn) - return op.emitOpError() - << "reference to undefined function '" << fnAttr.getValue() << "'"; - - // Check that the referenced function has the correct type. - if (fn.getType() != type) - return op.emitOpError("reference to function with mismatched type"); - return success(); - } + // Try to find the referenced function. + auto fn = op->getParentOfType().lookupSymbol(fnName); + if (!fn) + return op.emitOpError() + << "reference to undefined function '" << fnName << "'"; - if (type.isa() && value.isa()) - return success(); + // Check that the referenced function has the correct type. + if (fn.getType() != type) + return op.emitOpError("reference to function with mismatched type"); - return op.emitOpError("unsupported 'value' attribute: ") << value; + return success(); } OpFoldResult ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); - return getValue(); + return getValueAttr(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { - Type type = getType(); - if (type.isa()) { - setNameFn(getResult(), "f"); - } else { - setNameFn(getResult(), "cst"); - } + setNameFn(getResult(), "f"); } -/// Returns true if a constant operation can be built with the given value and -/// result type. 
bool ConstantOp::isBuildableWith(Attribute value, Type type) { - // SymbolRefAttr can only be used with a function type. - if (value.isa()) - return type.isa(); - // Otherwise, this must be a UnitAttr. - return value.isa() && type.isa(); + return value.isa() && type.isa(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -307,7 +307,7 @@ return failure(); auto loc = multiReductionOp.getLoc(); - Value result = rewriter.create( + Value result = rewriter.create( loc, multiReductionOp.getDestType(), rewriter.getZeroAttr(multiReductionOp.getDestType())); int outerDim = multiReductionOp.getSourceVectorType().getShape()[0]; diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -232,7 +232,7 @@ static LogicalResult printOperation(CppEmitter &emitter, mlir::ConstantOp constantOp) { Operation *operation = constantOp.getOperation(); - Attribute value = constantOp.getValue(); + Attribute value = constantOp.getValueAttr(); return printConstantOp(emitter, operation, value); } diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir --- a/mlir/test/Dialect/Standard/invalid.mlir +++ b/mlir/test/Dialect/Standard/invalid.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt -split-input-file %s -verify-diagnostics func @unsupported_attribute() { - // expected-error @+1 {{unsupported 'value' attribute: "" : index}} + // expected-error @+1 {{invalid kind of attribute specified}} %0 = constant "" : index return } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- 
a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -99,9 +99,6 @@ // CHECK: %{{.*}} = arith.cmpf oeq, %{{.*}}, %{{.*}}: vector<4xf32> %70 = arith.cmpf oeq, %vcf32, %vcf32 : vector<4 x f32> - // CHECK: = constant unit - %73 = constant unit - // CHECK: arith.constant true %74 = arith.constant true diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -578,7 +578,7 @@ LogicalResult matchAndRewrite(ILLegalOpG op, PatternRewriter &rewriter) const final { IntegerAttr attr = rewriter.getI32IntegerAttr(0); - Value val = rewriter.create(op->getLoc(), attr); + Value val = rewriter.create(op->getLoc(), attr); rewriter.replaceOpWithNewOp(op, val); return success(); };