diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def --- a/mlir/include/mlir/IR/DialectSymbolRegistry.def +++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def @@ -25,7 +25,6 @@ DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect -DEFINE_SYM_KIND_RANGE(TEST) // Test dialect // The following ranges are reserved for experimenting with MLIR dialects in a // private context without having to register them here. diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -619,34 +619,37 @@ // If the data buffer is non-empty, we copy it into the allocator with a // 64-bit alignment. ArrayRef copy, data = key.data; - if (!data.empty()) { - int numEntries = key.isSplat ? 1 : data.size(); + if (data.empty()) { + return new (allocator.allocate()) + DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); + } - // Compute the amount data needed to store the ArrayRef and StringRef - // contents. - size_t dataSize = sizeof(ArrayRef) * numEntries; - for (int i = 0; i < numEntries; i++) { - dataSize += data[i].size(); - } + int numEntries = key.isSplat ? 1 : data.size(); - char *rawData = reinterpret_cast( - allocator.allocate(dataSize, alignof(uint64_t))); + // Compute the amount data needed to store the ArrayRef and StringRef + // contents. + size_t dataSize = sizeof(ArrayRef) * numEntries; + for (int i = 0; i < numEntries; i++) { + dataSize += data[i].size(); + } - // Setup the ArrayRef - auto mutable_copy = MutableArrayRef( - reinterpret_cast(rawData), numEntries); - auto stringData = rawData + numEntries * sizeof(StringRef); + char *rawData = reinterpret_cast( + allocator.allocate(dataSize, alignof(uint64_t))); - for (int i = 0; i < numEntries; i++) { - memcpy(stringData, data[i].data(), data[i].size()); - mutable_copy[i] = StringRef(stringData, data[i].size()); - stringData += data[i].size(); - } + // Setup the ArrayRef + auto mutableCopy = MutableArrayRef( + reinterpret_cast(rawData), numEntries); + auto stringData = rawData + numEntries * sizeof(StringRef); - copy = ArrayRef(reinterpret_cast(rawData), - numEntries); + for (int i = 0; i < numEntries; i++) { + memcpy(stringData, data[i].data(), data[i].size()); + mutableCopy[i] = StringRef(stringData, data[i].size()); + stringData += data[i].size(); } + copy = + ArrayRef(reinterpret_cast(rawData), numEntries); + return new (allocator.allocate()) DenseStringElementsAttributeStorage(key.type, copy, key.isSplat); } diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -2299,9 +2299,8 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) { consumeToken(Token::kw_dense); - if (parseToken(Token::less, "expected '<' after 'dense'")) { + if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - } // Parse the literal data. TensorLiteralParser literalParser(*this); diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -397,7 +397,7 @@ func @simple_scalar_example() { "test.string_elements_attr"() { // CHECK: dense<"example"> - scalar_string_attr = dense<"example"> : tensor<2x!test.custom_type> + scalar_string_attr = dense<"example"> : tensor<2x!unknown<"">> } : () -> () return } @@ -407,7 +407,7 @@ func @escape_string_example() { "test.string_elements_attr"() { // CHECK: dense<"new\0Aline"> - scalar_string_attr = dense<"new\nline"> : tensor<2x!test.custom_type> + scalar_string_attr = dense<"new\nline"> : tensor<2x!unknown<"">> } : () -> () return } @@ -417,7 +417,7 @@ func @simple_scalar_example() { "test.string_elements_attr"() { // CHECK: dense<["example1", "example2"]> - scalar_string_attr = dense<["example1", "example2"]> : tensor<2x!test.custom_type> + scalar_string_attr = dense<["example1", "example2"]> : tensor<2x!unknown<"">> } : () -> () return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -30,38 +30,8 @@ namespace mlir { -namespace TestTypes { -enum Kind { - FIRST_USED_TEST_TYPE = Type::FIRST_TEST_TYPE, - CustomTestType, - LAST_USED_TEST_TYPE -}; -} // namespace TestTypes - #include "TestOpsDialect.h.inc" -class TestType : public Type { -public: - using Type::Type; - - static bool classof(Type type) { - return type.getKind() >= TestTypes::FIRST_USED_TEST_TYPE && - type.getKind() <= TestTypes::LAST_USED_TEST_TYPE; - } -}; - -class CustomTestType : public Type::TypeBase { -public: - using Base::Base; - static CustomTestType get(MLIRContext *context) { - return Base::get(context, TestTypes::CustomTestType); - } - - static bool kindof(unsigned kind) { - return kind == TestTypes::CustomTestType; - } -}; - #define GET_OP_CLASSES #include "TestOps.h.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -8,7 +8,6 @@ #include "TestDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/IR/PatternMatch.h" @@ -130,8 +129,6 @@ TestDialect::TestDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { - addTypes(); - addOperations< #define GET_OP_LIST #include "TestOps.cpp.inc" @@ -166,27 +163,6 @@ return success(); } -Type TestDialect::parseType(DialectAsmParser &parser) const { - Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - llvm::StringRef spec = parser.getFullSymbolSpec(); - if (spec == "custom_type") { - return CustomTestType::get(getContext()); - } - emitError(loc, "unknown TestDialect type:") << spec; - - return Type(); -} - -void TestDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - case TestTypes::CustomTestType: - os << "custom_type"; - break; - default: - llvm_unreachable("unhandle test dialect type"); - } -} - //===----------------------------------------------------------------------===// // TestBranchOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -32,14 +32,6 @@ // Test Types //===----------------------------------------------------------------------===// -def Test_CustomTestType : DialectType< - Test_Dialect, - CPred<"$_self.isa()">, "custom_type"> { - let typeDescription = [{ - Custom type example. - }]; -} - def IntTypesOp : TEST_Op<"int_types"> { let results = (outs AnyI16:$any_i16,