diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -773,15 +773,14 @@ "DenseElementsAttr":$values); let builders = [ AttrBuilderWithInferredContext<(ins "ShapedType":$type, - "DenseElementsAttr":$indices, + "DenseIntElementsAttr":$indices, "DenseElementsAttr":$values), [{ assert(indices.getType().getElementType().isInteger(64) && "expected sparse indices to be 64-bit integer values"); assert((type.isa()) && "type must be ranked tensor or vector"); assert(type.hasStaticShape() && "type must have static shape"); - return $_get(type.getContext(), type, - indices.cast(), values); + return $_get(type.getContext(), type, indices, values); }]>, ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -612,7 +612,7 @@ MlirAttribute denseValues) { return wrap( SparseElementsAttr::get(unwrap(shapedType).cast(), - unwrap(denseIndices).cast(), + unwrap(denseIndices).cast(), unwrap(denseValues).cast())); } diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -916,7 +916,8 @@ RankedTensorType::get({0, type.getRank()}, indiceEltType); ShapedType valuesType = RankedTensorType::get({0}, type.getElementType()); return getChecked( - loc, type, DenseElementsAttr::get(indicesType, ArrayRef()), + loc, type, + DenseIntElementsAttr::get(indicesType, ArrayRef()), DenseElementsAttr::get(valuesType, ArrayRef())); } @@ -955,7 +956,8 @@ // Otherwise, set the shape to the one parsed by the literal parser. indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType); } - auto indices = indiceParser.getAttr(indicesLoc, indicesType); + auto indices = indiceParser.getAttr(indicesLoc, indicesType) + .cast(); // If the values are a splat, set the shape explicitly based on the number of // indices. The number of indices is encoded in the first dimension of the