diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -99,15 +99,6 @@ /// Return the number of elements present in the given shape. static int64_t getNumElements(ArrayRef shape); - - /// Returns the total amount of bits occupied by a value of this type. This - /// does not take into account any memory layout or widening constraints, - /// e.g. a vector<3xi57> may report to occupy 3x57=171 bit, even though in - /// practice it will likely be stored as in a 4xi64 vector register. Fails - /// with an assertion if the size cannot be computed statically, e.g. if the - /// type has a dynamic shape or if its elemental type does not have a known - /// bit width. - int64_t getSizeInBits() const; }]; let extraSharedClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -40,8 +40,11 @@ /// Returns the number of bits for the given scalar/vector type. static int getNumBits(Type type) { + // TODO: This does not take into account any memory layout or widening + // constraints. E.g., a vector<3xi57> may report to occupy 3x57=171 bit, even + // though in practice it will likely be stored as in a 4xi64 vector register. if (auto vectorType = type.dyn_cast()) - return vectorType.cast().getSizeInBits(); + return vectorType.getNumElements() * vectorType.getElementTypeBitWidth(); return type.getIntOrFloatBitWidth(); } diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -33,18 +33,3 @@ } return num; } - -int64_t ShapedType::getSizeInBits() const { - assert(hasStaticShape() && - "cannot get the bit size of an aggregate with a dynamic shape"); - - auto elementType = getElementType(); - if (elementType.isIntOrFloat()) - return elementType.getIntOrFloatBitWidth() * getNumElements(); - - if (auto complexType = elementType.dyn_cast()) { - elementType = complexType.getElementType(); - return elementType.getIntOrFloatBitWidth() * getNumElements() * 2; - } - return getNumElements() * elementType.cast().getSizeInBits(); -} diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -259,8 +259,8 @@ let results = (outs AnyTensor:$output); DerivedTypeAttr element_dtype = DerivedTypeAttr<"return getElementTypeOrSelf(getOutput().getType());">; - DerivedAttr size = DerivedAttr<"int", - "return getOutput().getType().cast().getSizeInBits();", + DerivedAttr num_elements = DerivedAttr<"int", + "return getOutput().getType().cast().getNumElements();", "$_builder.getI32IntegerAttr($_self)">; } diff --git a/mlir/test/mlir-tblgen/op-derived-attribute.mlir b/mlir/test/mlir-tblgen/op-derived-attribute.mlir --- a/mlir/test/mlir-tblgen/op-derived-attribute.mlir +++ b/mlir/test/mlir-tblgen/op-derived-attribute.mlir @@ -3,15 +3,15 @@ // CHECK-LABEL: verifyDerivedAttributes func.func @verifyDerivedAttributes() { // expected-remark @+2 {{element_dtype = f32}} - // expected-remark @+1 {{size = 320}} + // expected-remark @+1 {{num_elements = 10}} %0 = "test.derived_type_attr"() : () -> tensor<10xf32> // expected-remark @+2 {{element_dtype = i79}} - // expected-remark @+1 {{size = 948}} + // expected-remark @+1 {{num_elements = 12}} %1 = "test.derived_type_attr"() : () -> tensor<12xi79> // expected-remark @+2 {{element_dtype = complex}} - // expected-remark @+1 {{size = 768}} + // expected-remark @+1 {{num_elements = 12}} %2 = "test.derived_type_attr"() : () -> tensor<12xcomplex> return