diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h
--- a/mlir/include/mlir-c/BuiltinAttributes.h
+++ b/mlir/include/mlir-c/BuiltinAttributes.h
@@ -22,6 +22,9 @@
 extern "C" {
 #endif

+/// Returns an empty attribute.
+MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull();
+
 //===----------------------------------------------------------------------===//
 // Affine map attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -188,17 +188,20 @@
 /// Checks whether the given type is an unranked tensor type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type);

-/// Creates a tensor type of a fixed rank with the given shape and element type
-/// in the same context as the element type. The type is owned by the context.
+/// Creates a tensor type of a fixed rank with the given shape, element type,
+/// and optional encoding in the same context as the element type. The type is
+/// owned by the context. Pass mlirAttributeGetNull() as the encoding to create
+/// a tensor type without any specific encoding.
 MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank,
                                                     const int64_t *shape,
-                                                    MlirType elementType);
+                                                    MlirType elementType,
+                                                    MlirAttribute encoding);

 /// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on
 /// illegal arguments, emitting appropriate diagnostics.
-MLIR_CAPI_EXPORTED MlirType
-mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
-                               const int64_t *shape, MlirType elementType);
+MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked(
+    MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType,
+    MlirAttribute encoding);

 /// Creates an unranked tensor type with the given element type in the same
 /// context as the element type. The type is owned by the context.
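For context, a minimal sketch (not part of the patch) of how the extended C API is intended to be called; it assumes the existing mlir-c entry points `mlirF64TypeGet`, `mlirAttributeParseGet`, `mlirStringRefCreateFromCString`, and `mlirTypeDump`, and uses an opaque string attribute as a placeholder encoding:

```
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"

// Sketch only: shows the two ways the new `encoding` parameter is expected to
// be used. The quoted "sparse" string attribute is a placeholder encoding.
void buildTensorTypesViaCAPI(MlirContext ctx) {
  int64_t shape[] = {16, 32};
  MlirType f64 = mlirF64TypeGet(ctx);

  // A plain tensor type: pass the null attribute as the encoding.
  MlirType plain =
      mlirRankedTensorTypeGet(2, shape, f64, mlirAttributeGetNull());

  // A tensor type carrying an encoding attribute.
  MlirAttribute enc = mlirAttributeParseGet(
      ctx, mlirStringRefCreateFromCString("\"sparse\""));
  MlirType encoded = mlirRankedTensorTypeGet(2, shape, f64, enc);

  mlirTypeDump(plain);   // tensor<16x32xf64>
  mlirTypeDump(encoded); // tensor<16x32xf64, "sparse">
}
```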
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -636,9 +636,10 @@
    Syntax:

    ```
-    tensor-type ::= `tensor` `<` dimension-list type `>`
+    tensor-type ::= `tensor` `<` dimension-list type (`,` encoding)? `>`
    dimension-list ::= (dimension `x`)*
    dimension ::= `?` | decimal-literal
+    encoding ::= attribute-value
    ```

    Values with tensor type represents aggregate N-dimensional data values, and
@@ -654,6 +655,14 @@
    [`dim` operation](Dialects/Standard.md#dim-operation) returns the size of a
    dimension from a value of tensor type.

+    The `encoding` attribute provides additional information on the tensor.
+    An empty attribute denotes a straightforward tensor without any specific
+    structure. Particular properties, such as sparsity or other specific
+    characteristics of the tensor's data, can be encoded through this
+    attribute. The semantics are defined by a type and attribute interface
+    and must be respected by all passes that operate on tensor types.
+    TODO: provide these interfaces, and document them further.
+
    Note: hexadecimal integer literals are not allowed in tensor type
    declarations to avoid confusion between `0xf32` and `0 x f32`. Zero sizes
    are allowed in tensors and treated as other sizes, e.g.,
@@ -681,18 +690,24 @@
      // Zero-element tensor of f32 type (hexadecimal literals not allowed here).
      tensor<0xf32>
+
+      // Tensor with an encoding attribute (where #ENCODING is a named alias).
+      tensor<?x?xf64, #ENCODING>
    ```
  }];
  let parameters = (ins
    ArrayRefParameter<"int64_t">:$shape,
-    "Type":$elementType
+    "Type":$elementType,
+    "Attribute":$encoding
  );

  let builders = [
    TypeBuilderWithInferredContext<(ins
-      "ArrayRef<int64_t>":$shape, "Type":$elementType
+      "ArrayRef<int64_t>":$shape,
+      "Type":$elementType,
+      CArg<"Attribute", "{}">:$encoding
    ), [{
-      return $_get(elementType.getContext(), shape, elementType);
+      return $_get(elementType.getContext(), shape, elementType, encoding);
    }]>
  ];
  let skipDefaultBuilders = 1;
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp
--- a/mlir/lib/Bindings/Python/IRAttributes.cpp
+++ b/mlir/lib/Bindings/Python/IRAttributes.cpp
@@ -502,8 +502,9 @@
                        MlirType mlirElementType, py::buffer_info &arrayInfo) {
    SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
                                  arrayInfo.shape.begin() + arrayInfo.ndim);
-    auto shapedType =
-        mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
+    MlirAttribute encodingAttr = mlirAttributeGetNull();
+    auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
+                                              mlirElementType, encodingAttr);
    intptr_t numElements = arrayInfo.size;
    const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
    return ctor(shapedType, numElements, contents);
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -10,6 +10,7 @@

 #include "PybindUtils.h"

+#include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"

 namespace py = pybind11;
@@ -381,8 +382,9 @@
        "get",
        [](std::vector<int64_t> shape, PyType &elementType,
           DefaultingPyLocation loc) {
+          MlirAttribute encodingAttr = mlirAttributeGetNull();
          MlirType t = mlirRankedTensorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType);
+              loc, shape.size(), shape.data(), elementType, encodingAttr);
          // TODO: Rework error reporting once diagnostic engine is exposed
          // in C API.
          if (mlirTypeIsNull(t)) {
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp
@@ -15,6 +15,8 @@

 using namespace mlir;

+MlirAttribute mlirAttributeGetNull() { return {nullptr}; }
+
 //===----------------------------------------------------------------------===//
 // Affine map attribute.
 //===----------------------------------------------------------------------===//
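The TableGen builder above defaults the new `$encoding` parameter to `{}` (a null `Attribute`), so existing C++ call sites keep compiling unchanged. A hedged sketch of both builder forms on the core API side (the `"sparse"` `StringAttr` is a stand-in, not a dialect-defined encoding):

```
#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/raw_ostream.h"

// Sketch only: exercises the extended RankedTensorType builder.
void buildRankedTensorTypes(mlir::MLIRContext &ctx) {
  mlir::Builder b(&ctx);

  // Default builder form: the encoding defaults to a null Attribute ("{}"),
  // so existing call sites are unaffected.
  auto plain = mlir::RankedTensorType::get({16, 32}, b.getF64Type());
  assert(!plain.getEncoding() && "expected no encoding");

  // Builder form with an explicit encoding attribute.
  mlir::Attribute enc = b.getStringAttr("sparse");
  auto encoded = mlir::RankedTensorType::get({16, 32}, b.getF64Type(), enc);

  llvm::outs() << plain << "\n";   // tensor<16x32xf64>
  llvm::outs() << encoded << "\n"; // tensor<16x32xf64, "sparse">
}
```

Because the encoding is a uniquing parameter of the type, `plain` and `encoded` above are distinct types even though their shape and element type match.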
diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
--- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp
+++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp
@@ -191,18 +191,19 @@
 }

 MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape,
-                                 MlirType elementType) {
+                                 MlirType elementType, MlirAttribute encoding) {
   return wrap(RankedTensorType::get(
-      llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      llvm::makeArrayRef(shape, static_cast<size_t>(rank)), unwrap(elementType),
+      unwrap(encoding)));
 }

 MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank,
                                         const int64_t *shape,
-                                        MlirType elementType) {
+                                        MlirType elementType,
+                                        MlirAttribute encoding) {
   return wrap(RankedTensorType::getChecked(
       unwrap(loc), llvm::makeArrayRef(shape, static_cast<size_t>(rank)),
-      unwrap(elementType)));
+      unwrap(elementType), unwrap(encoding)));
 }

 MlirType mlirUnrankedTensorTypeGet(MlirType elementType) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1866,7 +1866,13 @@
          os << dim;
          os << 'x';
        }
-        os << tensorTy.getElementType() << '>';
+        os << tensorTy.getElementType();
+        // Only print the encoding attribute value if set.
+        if (tensorTy.getEncoding()) {
+          os << ", ";
+          printAttribute(tensorTy.getEncoding());
+        }
+        os << '>';
      })
      .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
        os << "tensor<*x";
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -441,10 +441,12 @@

 LogicalResult
 RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
-                         ArrayRef<int64_t> shape, Type elementType) {
+                         ArrayRef<int64_t> shape, Type elementType,
+                         Attribute encoding) {
   for (int64_t s : shape)
     if (s < -1)
       return emitError() << "invalid tensor dimension size";
+  // TODO: verify contents of encoding attribute.
   return checkTensorElementType(emitError, elementType);
 }
diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp
--- a/mlir/lib/Parser/TypeParser.cpp
+++ b/mlir/lib/Parser/TypeParser.cpp
@@ -409,14 +409,23 @@
   // Parse the element type.
   auto elementTypeLoc = getToken().getLoc();
   auto elementType = parseType();
+
+  // Parse an optional encoding attribute.
+  Attribute encoding;
+  if (consumeIf(Token::comma))
+    encoding = parseAttribute();
+
   if (!elementType ||
       parseToken(Token::greater, "expected '>' in tensor type"))
     return nullptr;
   if (!TensorType::isValidElementType(elementType))
     return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;

-  if (isUnranked)
+  if (isUnranked) {
+    if (encoding)
+      return emitError("cannot apply encoding to unranked tensor"), nullptr;
     return UnrankedTensorType::get(elementType);
-  return RankedTensorType::get(dimensions, elementType);
+  }
+  return RankedTensorType::get(dimensions, elementType, encoding);
 }

 /// Parse a tuple type.
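Taken together, the parser, printer, and verifier changes let the encoding round-trip through the textual format while rejecting it on unranked tensors. A sketch of that round trip, assuming the string-based `mlir::parseType` entry point from `mlir/Parser.h` is available at this revision:

```
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser.h"
#include "llvm/Support/raw_ostream.h"

// Sketch only: round-trips a tensor type with an encoding through the parser
// and printer changed above.
void roundTripEncoding() {
  mlir::MLIRContext ctx;

  // The optional trailing attribute ends up in the encoding field.
  mlir::Type ty = mlir::parseType("tensor<16x32xf64, \"sparse\">", &ctx);
  auto tensorTy = ty.dyn_cast_or_null<mlir::RankedTensorType>();
  if (!tensorTy)
    return;

  // The printer re-emits the encoding after the element type.
  llvm::outs() << tensorTy << "\n";               // tensor<16x32xf64, "sparse">
  llvm::outs() << tensorTy.getEncoding() << "\n"; // "sparse"

  // An encoding on an unranked tensor is rejected: parsing
  // "tensor<*xf64, \"sparse\">" reports
  // "cannot apply encoding to unranked tensor" and yields a null Type.
}
```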
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -438,8 +438,8 @@
   mlirOperationSetAttributeByName(
       operation, mlirStringRefCreateFromCString("elts"),
       mlirDenseElementsAttrInt32Get(
-          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4,
-          eltsData));
+          mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32),
+                                  mlirAttributeGetNull()), 4, eltsData));
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
   mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2);
   mlirOpPrintingFlagsPrintGenericOpForm(flags);
@@ -687,8 +687,8 @@
   // CHECK: vector<2x3xf32>

   // Ranked tensor type.
-  MlirType rankedTensor =
-      mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32);
+  MlirType rankedTensor = mlirRankedTensorTypeGet(
+      sizeof(shape) / sizeof(int64_t), shape, f32, mlirAttributeGetNull());
   if (!mlirTypeIsATensor(rankedTensor) ||
       !mlirTypeIsARankedTensor(rankedTensor))
     return 16;
@@ -889,24 +889,30 @@
   int64_t ints64[] = {0, 1};
   float floats[] = {0.0f, 1.0f};
   double doubles[] = {0.0, 1.0};
+  MlirAttribute encoding = mlirAttributeGetNull();
   MlirAttribute boolElements = mlirDenseElementsAttrBoolGet(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), encoding),
+      2, bools);
   MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2,
-      uints32);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeUnsignedGet(ctx, 32), encoding),
+      2, uints32);
   MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2,
-      ints32);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), encoding),
+      2, ints32);
   MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2,
-      uints64);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeUnsignedGet(ctx, 64), encoding),
+      2, uints64);
   MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2,
-      ints64);
+      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), encoding),
+      2, ints64);
   MlirAttribute floatElements = mlirDenseElementsAttrFloatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats);
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
+      2, floats);
   MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet(
-      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles);
+      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding),
+      2, doubles);

   if (!mlirAttributeIsADenseElements(boolElements) ||
       !mlirAttributeIsADenseElements(uint32Elements) ||
@@ -943,19 +949,24 @@
   // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64>

   MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeGet(ctx, 1), encoding), 1);
   MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeGet(ctx, 32), encoding), 1);
   MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeGet(ctx, 32), encoding), 1);
   MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeGet(ctx, 64), encoding), 1);
   MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1);
+      mlirRankedTensorTypeGet(2, shape,
+                              mlirIntegerTypeGet(ctx, 64), encoding), 1);
   MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f);
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding), 1.0f);
   MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet(
-      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0);
+      mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), encoding), 1.0);

   if (!mlirAttributeIsADenseElements(splatBool) ||
       !mlirDenseElementsAttrIsSplat(splatBool) ||
@@ -1024,13 +1035,14 @@
   int64_t indices[] = {4, 7};
   int64_t two = 2;
   MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get(
-      mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2,
-      indices);
+      mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), encoding),
+      2, indices);
   MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet(
-      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats);
+      mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), encoding),
+      2, floats);
   MlirAttribute sparseAttr = mlirSparseElementsAttribute(
-      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr,
-      valuesAttr);
+      mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), encoding),
+      indicesAttr, valuesAttr);
   mlirAttributeDump(sparseAttr);
   // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32>
diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir
--- a/mlir/test/IR/invalid.mlir
+++ b/mlir/test/IR/invalid.mlir
@@ -100,6 +100,12 @@

 // -----

+func @tensor_encoding_mismatch(%arg0: tensor<8xi32, "enc">) -> (tensor<8xi32>) { // expected-note {{prior use here}}
+  return %arg0: tensor<8xi32> // expected-error {{use of value '%arg0' expects different type than prior uses: 'tensor<8xi32>' vs 'tensor<8xi32, "enc">'}}
+}
+
+// -----
+
 func @bad_branch() {
 ^bb12:
   br ^missing  // expected-error {{reference to an undefined block}}
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -77,6 +77,9 @@
 func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>,
                       tensor<1x?x4x?x?xi32>, tensor<i8>)

+// CHECK: func private @tensor_encoding(tensor<16x32xf64, "sparse">)
+func private @tensor_encoding(tensor<16x32xf64, "sparse">)
+
 // CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>)
 func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)
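As the new invalid.mlir case shows, the encoding participates in type equality: `tensor<8xi32>` and `tensor<8xi32, "enc">` are distinct types, so producers and consumers must agree on the encoding they expect. Until the type and attribute interfaces promised by the BuiltinTypes.td TODO exist, code can only inspect the raw attribute; a hedged sketch (`isSparseEncoded` is a hypothetical helper, and the `"sparse"` string is a placeholder convention, neither being part of this patch):

```
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

// Sketch only: a hypothetical predicate a pass might use while the planned
// encoding interfaces do not exist yet.
static bool isSparseEncoded(mlir::RankedTensorType type) {
  // getEncoding() returns a null Attribute for a plain tensor type.
  if (auto enc = type.getEncoding().dyn_cast_or_null<mlir::StringAttr>())
    return enc.getValue() == "sparse";
  return false;
}
```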