diff --git a/mlir/include/mlir-c/StandardAttributes.h b/mlir/include/mlir-c/StandardAttributes.h
--- a/mlir/include/mlir-c/StandardAttributes.h
+++ b/mlir/include/mlir-c/StandardAttributes.h
@@ -79,7 +79,7 @@
 /** Returns the dictionary attribute element with the given name or NULL if the
  * given name does not exist in the dictionary. */
 MLIR_CAPI_EXPORTED MlirAttribute
-mlirDictionaryAttrGetElementByName(MlirAttribute attr, const char *name);
+mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name);
 
 //===----------------------------------------------------------------------===//
 // Floating point attribute.
@@ -155,15 +155,13 @@
 /** Creates an opaque attribute in the given context associated with the dialect
  * identified by its namespace. The attribute contains opaque byte data of the
  * specified length (data need not be null-terminated). */
-MLIR_CAPI_EXPORTED MlirAttribute mlirOpaqueAttrGet(MlirContext ctx,
-                                                   const char *dialectNamespace,
-                                                   intptr_t dataLength,
-                                                   const char *data,
-                                                   MlirType type);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
+                  intptr_t dataLength, const char *data, MlirType type);
 
 /** Returns the namespace of the dialect with which the given opaque attribute
  * is associated. The namespace string is owned by the context. */
-MLIR_CAPI_EXPORTED const char *
+MLIR_CAPI_EXPORTED MlirStringRef
 mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr);
 
 /** Returns the raw data as a string reference. The data remains live as long as
@@ -178,17 +176,14 @@
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAString(MlirAttribute attr);
 
 /** Creates a string attribute in the given context containing the given string.
- * The string need not be null-terminated and its length must be specified. */
+ */
 MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrGet(MlirContext ctx,
-                                                   intptr_t length,
-                                                   const char *data);
+                                                   MlirStringRef str);
 
 /** Creates a string attribute in the given context containing the given string.
- * The string need not be null-terminated and its length must be specified.
  * Additionally, the attribute has the given type. */
 MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type,
-                                                        intptr_t length,
-                                                        const char *data);
+                                                        MlirStringRef str);
 
 /** Returns the attribute values as a string reference. The data remains live as
  * long as the context in which the attribute lives. */
@@ -203,10 +198,9 @@
 
 /** Creates a symbol reference attribute in the given context referencing a
  * symbol identified by the given string inside a list of nested references.
- * Each of the references in the list must not be nested. The string need not be
- * null-terminated and its length must be specified. */
+ * Each of the references in the list must not be nested. */
 MLIR_CAPI_EXPORTED MlirAttribute
-mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length, const char *symbol,
+mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
                      intptr_t numReferences, MlirAttribute const *references);
 
 /** Returns the string reference to the root referenced symbol. The data remains
@@ -236,11 +230,9 @@
 MLIR_CAPI_EXPORTED bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr);
 
 /** Creates a flat symbol reference attribute in the given context referencing a
- * symbol identified by the given string. The string need not be null-terminated
- * and its length must be specified. */
+ * symbol identified by the given string. */
 MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx,
-                                                          intptr_t length,
-                                                          const char *symbol);
+                                                          MlirStringRef symbol);
 
 /** Returns the referenced symbol as a string reference. The data remains live
  * as long as the context in which the attribute lives. */
@@ -349,11 +341,10 @@
     MlirType shapedType, intptr_t numElements, const double *elements);
 
 /** Creates a dense elements attribute with the given shaped type from string
- * elements. The strings need not be null-terminated and their lengths are
- * provided as a separate argument co-indexed with the strs argument. */
-MLIR_CAPI_EXPORTED MlirAttribute
-mlirDenseElementsAttrStringGet(MlirType shapedType, intptr_t numElements,
-                               intptr_t const *strLengths, const char **strs);
+ * elements. */
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrStringGet(
+    MlirType shapedType, intptr_t numElements, MlirStringRef const *strs);
+
 /** Creates a dense elements attribute that has the same data as the given dense
  * elements attribute and a different shaped type. The new type must have the
  * same total number of elements. */
diff --git a/mlir/lib/CAPI/IR/StandardAttributes.cpp b/mlir/lib/CAPI/IR/StandardAttributes.cpp
--- a/mlir/lib/CAPI/IR/StandardAttributes.cpp
+++ b/mlir/lib/CAPI/IR/StandardAttributes.cpp
@@ -86,8 +86,8 @@
 }
 
 MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr,
-                                                 const char *name) {
-  return wrap(unwrap(attr).cast<DictionaryAttr>().get(name));
+                                                 MlirStringRef name) {
+  return wrap(unwrap(attr).cast<DictionaryAttr>().get(unwrap(name)));
 }
 
 //===----------------------------------------------------------------------===//
@@ -160,16 +160,16 @@
   return unwrap(attr).isa<OpaqueAttr>();
 }
 
-MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, const char *dialectNamespace,
+MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace,
                                 intptr_t dataLength, const char *data,
                                 MlirType type) {
-  return wrap(OpaqueAttr::get(Identifier::get(dialectNamespace, unwrap(ctx)),
-                              StringRef(data, dataLength), unwrap(type),
-                              unwrap(ctx)));
+  return wrap(
+      OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)),
+                      StringRef(data, dataLength), unwrap(type), unwrap(ctx)));
 }
 
-const char *mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
-  return unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().c_str();
+MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) {
+  return wrap(unwrap(attr).cast<OpaqueAttr>().getDialectNamespace().strref());
 }
 
 MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) {
@@ -184,14 +184,12 @@
   return unwrap(attr).isa<StringAttr>();
 }
 
-MlirAttribute mlirStringAttrGet(MlirContext ctx, intptr_t length,
-                                const char *data) {
-  return wrap(StringAttr::get(StringRef(data, length), unwrap(ctx)));
+MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) {
+  return wrap(StringAttr::get(unwrap(str), unwrap(ctx)));
 }
 
-MlirAttribute mlirStringAttrTypedGet(MlirType type, intptr_t length,
-                                     const char *data) {
-  return wrap(StringAttr::get(StringRef(data, length), unwrap(type)));
+MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) {
+  return wrap(StringAttr::get(unwrap(str), unwrap(type)));
 }
 
 MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) {
@@ -206,14 +204,14 @@
   return unwrap(attr).isa<SymbolRefAttr>();
 }
 
-MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, intptr_t length,
-                                   const char *symbol, intptr_t numReferences,
+MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol,
+                                   intptr_t numReferences,
                                    MlirAttribute const *references) {
   SmallVector<FlatSymbolRefAttr, 4> refs;
   refs.reserve(numReferences);
   for (intptr_t i = 0; i < numReferences; ++i)
     refs.push_back(unwrap(references[i]).cast<FlatSymbolRefAttr>());
-  return wrap(SymbolRefAttr::get(StringRef(symbol, length), refs, unwrap(ctx)));
+  return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx)));
 }
 
 MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) {
@@ -242,9 +240,8 @@
   return unwrap(attr).isa<FlatSymbolRefAttr>();
 }
 
-MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, intptr_t length,
-                                       const char *symbol) {
-  return wrap(FlatSymbolRefAttr::get(StringRef(symbol, length), unwrap(ctx)));
+MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) {
+  return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx)));
 }
 
 MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) {
@@ -424,12 +421,11 @@
 
 MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType,
                                              intptr_t numElements,
-                                             intptr_t const *strLengths,
-                                             const char **strs) {
+                                             MlirStringRef const *strs) {
   SmallVector<StringRef, 8> values;
   values.reserve(numElements);
   for (intptr_t i = 0; i < numElements; ++i)
-    values.push_back(StringRef(strs[i], strLengths[i]));
+    values.push_back(unwrap(strs[i]));
 
   return wrap(
       DenseElementsAttr::get(unwrap(shapedType).cast<ShapedType>(), values));
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -732,6 +732,13 @@
   strncpy(userData, data, len);
 }
 
+bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
+  if (strlen(lhs) != rhs.length) {
+    return false;
+  }
+  return !strncmp(lhs, rhs.data, rhs.length);
+}
+
 int printStandardAttributes(MlirContext ctx) {
   MlirAttribute floating =
       mlirFloatAttrDoubleGet(ctx, mlirF64TypeGet(ctx), 2.0);
@@ -763,9 +770,10 @@
 
   const char data[] = "abcdefghijklmnopqestuvwxyz";
   MlirAttribute opaque =
-      mlirOpaqueAttrGet(ctx, "std", 3, data, mlirNoneTypeGet(ctx));
+      mlirOpaqueAttrGet(ctx, mlirStringRefCreateFromCString("std"), 3, data,
+                        mlirNoneTypeGet(ctx));
   if (!mlirAttributeIsAOpaque(opaque) ||
-      strcmp("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
+      !stringIsEqual("std", mlirOpaqueAttrGetDialectNamespace(opaque)))
     return 4;
 
   MlirStringRef opaqueData = mlirOpaqueAttrGetData(opaque);
@@ -775,7 +783,8 @@
   mlirAttributeDump(opaque);
   // CHECK: #std.abc
 
-  MlirAttribute string = mlirStringAttrGet(ctx, 2, data + 3);
+  MlirAttribute string =
+      mlirStringAttrGet(ctx, mlirStringRefCreate(data + 3, 2));
   if (!mlirAttributeIsAString(string))
     return 6;
 
@@ -786,7 +795,8 @@
   mlirAttributeDump(string);
   // CHECK: "de"
 
-  MlirAttribute flatSymbolRef = mlirFlatSymbolRefAttrGet(ctx, 3, data + 5);
+  MlirAttribute flatSymbolRef =
+      mlirFlatSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 5, 3));
   if (!mlirAttributeIsAFlatSymbolRef(flatSymbolRef))
     return 8;
 
@@ -799,7 +809,8 @@
   // CHECK: @fgh
 
   MlirAttribute symbols[] = {flatSymbolRef, flatSymbolRef};
-  MlirAttribute symbolRef = mlirSymbolRefAttrGet(ctx, 2, data + 8, 2, symbols);
+  MlirAttribute symbolRef =
+      mlirSymbolRefAttrGet(ctx, mlirStringRefCreate(data + 8, 2), 2, symbols);
   if (!mlirAttributeIsASymbolRef(symbolRef) ||
       mlirSymbolRefAttrGetNumNestedReferences(symbolRef) != 2 ||
       !mlirAttributeEqual(mlirSymbolRefAttrGetNestedReference(symbolRef, 0),