diff --git a/mlir/include/mlir/Support/DebugStringHelper.h b/mlir/include/mlir/Support/DebugStringHelper.h
--- a/mlir/include/mlir/Support/DebugStringHelper.h
+++ b/mlir/include/mlir/Support/DebugStringHelper.h
@@ -28,7 +28,7 @@
 static std::string debugString(T &&op) {
   std::string instrStr;
   llvm::raw_string_ostream os(instrStr);
-  op.print(os);
+  os << op;
   return os.str();
 }
 
diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
--- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td
@@ -111,4 +111,36 @@
   let hasCustomAssemblyFormat = 1;
 }
 
+// Test simple extern 1D vector using ElementsAttrInterface.
+def TestExtern1DI64ElementsAttr : Test_Attr<"TestExtern1DI64Elements", [
+    ElementsAttrInterface
+  ]> {
+  let mnemonic = "e1di64_elements";
+  let parameters = (ins
+    AttributeSelfTypeParameter<"", "::mlir::ShapedType">:$type,
+
+    // Extern ArrayRef whose pointer and size, rather than its contents, are
+    // uniqued. The user must ensure the array outlives the MLIRContext.
+    AttrOrTypeParameter<"PtrI64Array", "pointer to I64 array">:$elements
+  );
+  let builders = [
+    AttrBuilder<(ins "llvm::ArrayRef<int64_t>":$value), [{
+      return get($_ctxt,
+                 RankedTensorType::get({static_cast<int64_t>(value.size())},
+                                       Builder($_ctxt).getI64Type()),
+                 {value.data(), value.size()});
+    }]>];
+  let extraClassDeclaration = [{
+    /// The set of data types that can be iterated by this attribute.
+    using ContiguousIterableTypesT = std::tuple<int64_t>;
+
+    /// Provide begin iterators for the various iterable types.
+    // * int64_t
+    auto value_begin_impl(OverloadToken<int64_t>) const {
+      return getElements().begin();
+    }
+  }];
+  let hasCustomAssemblyFormat = 1;
+}
+
 def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.h b/mlir/test/lib/Dialect/Test/TestAttributes.h
--- a/mlir/test/lib/Dialect/Test/TestAttributes.h
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.h
@@ -25,6 +25,19 @@
 #include "TestAttrInterfaces.h.inc"
 #include "TestOpEnums.h.inc"
 
+/// An ArrayRef<int64_t> that is uniqued by identity rather than by contents:
+/// both hashing (llvm::hash_value, defined in TestAttributes.cpp) and equality
+/// consider only the base pointer and size, so the two stay consistent.
+class PtrI64Array : public llvm::ArrayRef<int64_t> {
+public:
+  using llvm::ArrayRef<int64_t>::ArrayRef;
+
+  bool operator==(const PtrI64Array &other) const {
+    return data() == other.data() && size() == other.size();
+  }
+  bool operator!=(const PtrI64Array &other) const { return !(*this == other); }
+};
+
 #define GET_ATTRDEF_CLASSES
 #include "TestAttrDefs.h.inc"
 
diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
--- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp
@@ -14,6 +14,7 @@
 #include "TestAttributes.h"
 #include "TestDialect.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/Types.h"
 #include "mlir/Support/LogicalResult.h"
@@ -25,6 +26,13 @@
 using namespace mlir;
 using namespace test;
 
+namespace llvm {
+hash_code hash_value(PtrI64Array s) {
+  // Hash only the base pointer and size; data() is valid even when empty.
+  return hash_combine(s.data(), s.size());
+}
+} // namespace llvm
+
 //===----------------------------------------------------------------------===//
 // AttrWithSelfTypeParamAttr
 //===----------------------------------------------------------------------===//
@@ -127,6 +135,44 @@
   return success();
 }
 
+Attribute TestExtern1DI64ElementsAttr::parse(AsmParser &parser, Type type) {
+  SmallVector<int64_t> elements;
+  if (parser.parseLess() || parser.parseLSquare())
+    return Attribute();
+  // The element list may be empty, in which case parseOptionalInteger returns
+  // an empty result that must not be dereferenced. After a comma, another
+  // integer is mandatory.
+  int64_t intVal;
+  OptionalParseResult firstResult = parser.parseOptionalInteger(intVal);
+  if (firstResult.hasValue()) {
+    if (failed(*firstResult))
+      return Attribute();
+    elements.push_back(intVal);
+    while (succeeded(parser.parseOptionalComma())) {
+      if (parser.parseInteger(intVal))
+        return Attribute();
+      elements.push_back(intVal);
+    }
+  }
+
+  if (parser.parseRSquare() || parser.parseGreater())
+    return Attribute();
+  // Return a DenseIntElementsAttr instead. This means that roundtripping
+  // through textual form loses the extern'ness of the attribute. The
+  // provided `type` is not used; the tensor type is derived from the number
+  // of parsed elements, matching what the builder produces.
+  auto tensorType =
+      RankedTensorType::get({static_cast<int64_t>(elements.size())},
+                            parser.getBuilder().getI64Type());
+  return DenseIntElementsAttr::get(tensorType, elements);
+}
+
+void TestExtern1DI64ElementsAttr::print(mlir::AsmPrinter &printer) const {
+  printer << "<[";
+  llvm::interleaveComma(getElements(), printer);
+  printer << "]>";
+}
+
 LogicalResult
 TestAttrWithFormatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                int64_t one, std::string two, IntegerAttr three,
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -6,8 +6,12 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Support/DebugStringHelper.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -250,4 +254,58 @@
   EXPECT_TRUE(zeroStringValue.getType() == stringTy);
 }
 
+TEST(ExternConstantTest, Simple) {
+  MLIRContext context;
+  context.loadDialect<test::TestDialect, arith::ArithmeticDialect>();
+
+  SmallVector<int64_t> val = {10, 20};
+
+  // Create extern constant attribute.
+  auto attr = test::TestExtern1DI64ElementsAttr::get(&context, val);
+  for (auto it : llvm::zip(val, attr.getElements()))
+    EXPECT_EQ(std::get<0>(it), std::get<1>(it));
+  // Verify that the attribute is a view of `val`: an in-place update of `val`
+  // (same size, so no reallocation) is observed through the attribute.
+  val = {30, 40};
+  for (auto it : llvm::zip(val, attr.getElements()))
+    EXPECT_EQ(std::get<0>(it), std::get<1>(it));
+
+  // Print and parse.
+  auto parsedAttr = mlir::parseAttribute(debugString(attr), &context)
+                        .cast<DenseIntElementsAttr>();
+  // Verify that we get the same uniqued Attribute from parsing as if
+  // constructed as DenseIntElementsAttr.
+  auto denseAttr = DenseIntElementsAttr::get(
+      RankedTensorType::get({2}, IntegerType::get(&context, 64)), val);
+  EXPECT_EQ(denseAttr, parsedAttr);
+
+  // Verify that the parsed attribute is a DenseElementsAttr and doesn't share
+  // the same memory as val.
+  val = {50, 60};
+  for (auto it : llvm::zip(parsedAttr.getValues<int64_t>(), attr.getElements()))
+    EXPECT_NE(std::get<0>(it), std::get<1>(it));
+
+  // The extern constants use the pointer to the data as key and hence two
+  // different arrays with the same value would not be considered equal.
+  SmallVector<int64_t> val2 = val;
+  EXPECT_EQ(attr, test::TestExtern1DI64ElementsAttr::get(&context, val));
+  EXPECT_EQ(test::TestExtern1DI64ElementsAttr::get(&context, val),
+            test::TestExtern1DI64ElementsAttr::get(&context, val));
+  EXPECT_NE(attr, test::TestExtern1DI64ElementsAttr::get(&context, val2));
+
+  // An empty extern array is supported and round-trips through textual form.
+  SmallVector<int64_t> emptyVal;
+  auto emptyAttr = test::TestExtern1DI64ElementsAttr::get(&context, emptyVal);
+  EXPECT_TRUE(emptyAttr.getElements().empty());
+  auto parsedEmpty = mlir::parseAttribute(debugString(emptyAttr), &context)
+                         .cast<DenseIntElementsAttr>();
+  EXPECT_EQ(parsedEmpty.getNumElements(), 0);
+
+  // Verify building a tensor constant using extern constant.
+  OpBuilder b(&context);
+  auto constantOp = b.create<arith::ConstantOp>(b.getUnknownLoc(), attr);
+  ASSERT_TRUE(succeeded(constantOp.verify()));
+  ASSERT_TRUE(constantOp.getType().isa<RankedTensorType>());
+}
+
 } // namespace