diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -188,17 +188,19 @@ /// Checks whether the given type is an unranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type); -/// Creates a tensor type of a fixed rank with the given shape and element type -/// in the same context as the element type. The type is owned by the context. +/// Creates a tensor type of a fixed rank with the given shape, element type +/// and format in the same context as the element type. The type is owned by +/// the context. MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, - MlirType elementType); + MlirType elementType, + MlirAttribute format); /// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on /// illegal arguments, emitting appropriate diagnostics. -MLIR_CAPI_EXPORTED MlirType -mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, - const int64_t *shape, MlirType elementType); +MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked( + MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, + MlirAttribute format); /// Creates an unranked tensor type with the given element type in the same /// context as the element type. The type is owned by the context. diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -636,9 +636,10 @@ Syntax: ``` - tensor-type ::= `tensor` `<` dimension-list type `>` + tensor-type ::= `tensor` `<` dimension-list type (`,` format)? `>` dimension-list ::= (dimension `x`)* dimension ::= `?` | decimal-literal + format ::= attribute-value ``` Values with tensor type represents aggregate N-dimensional data values, and @@ -654,6 +655,13 @@ [`dim` operation](Dialects/Standard.md#dim-operation) returns the size of a dimension from a value of tensor type. + The format attribute provides additional information on data residing + in the tensor. An empty attribute denotes straightforward, dense tensor + data without any special structure. More specialized formats can provide + information on the structure and sparsity of the data that resides in + the tensor, used by the compiler to optimize access to this tensor. + TODO: provide and document attribute interface for this + Note: hexadecimal integer literals are not allowed in tensor type declarations to avoid confusion between `0xf32` and `0 x f32`. Zero sizes are allowed in tensors and treated as other sizes, e.g., @@ -685,14 +693,17 @@ }]; let parameters = (ins ArrayRefParameter<"int64_t">:$shape, - "Type":$elementType + "Type":$elementType, + "Attribute":$format ); let builders = [ TypeBuilderWithInferredContext<(ins - "ArrayRef":$shape, "Type":$elementType + "ArrayRef":$shape, + "Type":$elementType, + CArg<"Attribute", "{}">:$format ), [{ - return $_get(elementType.getContext(), shape, elementType); + return $_get(elementType.getContext(), shape, elementType, format); }]> ]; let skipDefaultBuilders = 1; diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -502,8 +502,10 @@ MlirType mlirElementType, py::buffer_info &arrayInfo) { SmallVector shape(arrayInfo.shape.begin(), arrayInfo.shape.begin() + arrayInfo.ndim); - auto shapedType = - mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); + // TODO: pass in format + MlirAttribute formatAttr = {}; + auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), + mlirElementType, formatAttr); intptr_t numElements = arrayInfo.size; const ElementTy *contents = static_cast(arrayInfo.ptr); return ctor(shapedType, numElements, contents); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -381,8 +381,10 @@ "get", [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { + // TODO: pass in format + MlirAttribute formatAttr = {}; MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType); + loc, shape.size(), shape.data(), elementType, formatAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -191,18 +191,19 @@ } MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, - MlirType elementType) { + MlirType elementType, MlirAttribute format) { return wrap(RankedTensorType::get( - llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + unwrap(format))); } MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, - MlirType elementType) { + MlirType elementType, + MlirAttribute format) { return wrap(RankedTensorType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + unwrap(elementType), unwrap(format))); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1866,7 +1866,13 @@ os << dim; os << 'x'; } - os << tensorTy.getElementType() << '>'; + os << tensorTy.getElementType(); + // Only print the format attribute if it is the non-default one. + if (tensorTy.getFormat()) { + os << ", "; + printAttribute(tensorTy.getFormat()); + } + os << '>'; }) .Case([&](UnrankedTensorType tensorTy) { os << "tensor<*x"; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -428,10 +428,12 @@ LogicalResult RankedTensorType::verify(function_ref emitError, - ArrayRef shape, Type elementType) { + ArrayRef shape, Type elementType, + Attribute format) { for (int64_t s : shape) if (s < -1) return emitError() << "invalid tensor dimension size"; + // TODO: verify format attribute. return checkTensorElementType(emitError, elementType); } diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -409,14 +409,23 @@ // Parse the element type. auto elementTypeLoc = getToken().getLoc(); auto elementType = parseType(); + + // Parse an optional format attribute. + Attribute format; + if (consumeIf(Token::comma)) + format = parseAttribute(); + if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) return nullptr; if (!TensorType::isValidElementType(elementType)) return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; - if (isUnranked) + if (isUnranked) { + if (format) + return emitError("cannot define format for unranked tensor"), nullptr; return UnrankedTensorType::get(elementType); - return RankedTensorType::get(dimensions, elementType); + } + return RankedTensorType::get(dimensions, elementType, format); } /// Parse a tuple type. diff --git a/mlir/test/Bindings/Python/ir_types.py b/mlir/test/Bindings/Python/ir_types.py --- a/mlir/test/Bindings/Python/ir_types.py +++ b/mlir/test/Bindings/Python/ir_types.py @@ -281,6 +281,8 @@ print("ranked tensor type:", RankedTensorType.get(shape, f32)) + # TODO: check format + none = NoneType.get() try: tensor_invalid = RankedTensorType.get(shape, none) diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c --- a/mlir/test/CAPI/ir.c +++ b/mlir/test/CAPI/ir.c @@ -435,11 +435,10 @@ // Add a large attribute to verify printing flags. int64_t eltsShape[] = {4}; int32_t eltsData[] = {1, 2, 3, 4}; + MlirAttribute format = {0}; mlirOperationSetAttributeByName( operation, mlirStringRefCreateFromCString("elts"), - mlirDenseElementsAttrInt32Get( - mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32)), 4, - eltsData)); + mlirDenseElementsAttrInt32Get(mlirRankedTensorTypeGet(1, eltsShape, mlirIntegerTypeGet(ctx, 32), format), 4, eltsData)); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); mlirOpPrintingFlagsElideLargeElementsAttrs(flags, 2); mlirOpPrintingFlagsPrintGenericOpForm(flags); @@ -687,8 +686,8 @@ // CHECK: vector<2x3xf32> // Ranked tensor type. - MlirType rankedTensor = - mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32); + MlirAttribute format = {0}; + MlirType rankedTensor = mlirRankedTensorTypeGet(sizeof(shape) / sizeof(int64_t), shape, f32, format); if (!mlirTypeIsATensor(rankedTensor) || !mlirTypeIsARankedTensor(rankedTensor)) return 16; @@ -889,24 +888,25 @@ int64_t ints64[] = {0, 1}; float floats[] = {0.0f, 1.0f}; double doubles[] = {0.0, 1.0}; + MlirAttribute format = {0}; MlirAttribute boolElements = mlirDenseElementsAttrBoolGet( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 2, bools); + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), format), 2, bools); MlirAttribute uint32Elements = mlirDenseElementsAttrUInt32Get( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32)), 2, + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 32), format), 2, uints32); MlirAttribute int32Elements = mlirDenseElementsAttrInt32Get( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 2, + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), format), 2, ints32); MlirAttribute uint64Elements = mlirDenseElementsAttrUInt64Get( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64)), 2, + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeUnsignedGet(ctx, 64), format), 2, uints64); MlirAttribute int64Elements = mlirDenseElementsAttrInt64Get( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 2, + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), format), 2, ints64); MlirAttribute floatElements = mlirDenseElementsAttrFloatGet( - mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 2, floats); + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), format), 2, floats); MlirAttribute doubleElements = mlirDenseElementsAttrDoubleGet( - mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 2, doubles); + mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), format), 2, doubles); if (!mlirAttributeIsADenseElements(boolElements) || !mlirAttributeIsADenseElements(uint32Elements) || @@ -943,19 +943,19 @@ // CHECK: dense<{{\[}}[0.000000e+00, 1.000000e+00]]> : tensor<1x2xf64> MlirAttribute splatBool = mlirDenseElementsAttrBoolSplatGet( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1)), 1); + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 1), format), 1); MlirAttribute splatUInt32 = mlirDenseElementsAttrUInt32SplatGet( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1); + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), format), 1); MlirAttribute splatInt32 = mlirDenseElementsAttrInt32SplatGet( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32)), 1); + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 32), format), 1); MlirAttribute splatUInt64 = mlirDenseElementsAttrUInt64SplatGet( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1); + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), format), 1); MlirAttribute splatInt64 = mlirDenseElementsAttrInt64SplatGet( - mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64)), 1); + mlirRankedTensorTypeGet(2, shape, mlirIntegerTypeGet(ctx, 64), format), 1); MlirAttribute splatFloat = mlirDenseElementsAttrFloatSplatGet( - mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), 1.0f); + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), format), 1.0f); MlirAttribute splatDouble = mlirDenseElementsAttrDoubleSplatGet( - mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx)), 1.0); + mlirRankedTensorTypeGet(2, shape, mlirF64TypeGet(ctx), format), 1.0); if (!mlirAttributeIsADenseElements(splatBool) || !mlirDenseElementsAttrIsSplat(splatBool) || @@ -1024,12 +1024,12 @@ int64_t indices[] = {4, 7}; int64_t two = 2; MlirAttribute indicesAttr = mlirDenseElementsAttrInt64Get( - mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64)), 2, + mlirRankedTensorTypeGet(1, &two, mlirIntegerTypeGet(ctx, 64), format), 2, indices); MlirAttribute valuesAttr = mlirDenseElementsAttrFloatGet( - mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx)), 2, floats); + mlirRankedTensorTypeGet(1, &two, mlirF32TypeGet(ctx), format), 2, floats); MlirAttribute sparseAttr = mlirSparseElementsAttribute( - mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx)), indicesAttr, + mlirRankedTensorTypeGet(2, shape, mlirF32TypeGet(ctx), format), indicesAttr, valuesAttr); mlirAttributeDump(sparseAttr); // CHECK: sparse<[4, 7], [0.000000e+00, 1.000000e+00]> : tensor<1x2xf32> diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -104,6 +104,12 @@ // ----- +func @tensorformatmismatch(%arg0: tensor<8xi32, "format">) -> (tensor<8xi32>) { // expected-note {{prior use here}} + return %arg0: tensor<8xi32> // expected-error {{use of value '%arg0' expects different type than prior uses: 'tensor<8xi32>' vs 'tensor<8xi32, "format">'}} +} + +// ----- + func @bad_branch() { ^bb12: br ^missing // expected-error {{reference to an undefined block}} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -77,6 +77,9 @@ func private @tensors(tensor<* x f32>, tensor<* x vector<2x4xf32>>, tensor<1x?x4x?x?xi32>, tensor) +// CHECK: func private @tensorformat(tensor<16x32xf64, "sparse">) +func private @tensorformat(tensor<16x32xf64, "sparse">) + // CHECK: func private @memrefs(memref<1x?x4x?x?xi32, #map{{[0-9]+}}>, memref<8xi8>) func private @memrefs(memref<1x?x4x?x?xi32, #map0>, memref<8xi8, #map1, #map1>)