diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -350,8 +350,8 @@ if (quantAttr) { result.addAttribute("quantization_info", quantAttr); - auto inputType = a.getType().dyn_cast(); - assert(inputType && "Input must be a ranked tensor type!"); + auto inputType = a.getType().dyn_cast(); + assert(inputType && "Input must be a shaped tensor type!"); auto inputQType = inputType.getElementType() .dyn_cast(); @@ -359,17 +359,15 @@ unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); - auto outputShapedType = outputType.dyn_cast(); - assert(outputShapedType && "Output must be a ranked tensor type"); - - auto outputShape = outputShapedType.getShape(); + auto outputShapedType = outputType.dyn_cast(); + assert(outputShapedType && "Output must be a shaped type"); IntegerType accElementType; if (inputBits == 16) accElementType = builder.getIntegerType(48); else accElementType = builder.getI32Type(); - auto accType = RankedTensorType::get(outputShape, accElementType); + auto accType = outputShapedType.clone(accElementType); result.addTypes(accType); } else { result.addTypes(outputType); diff --git a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp --- a/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp @@ -102,8 +102,8 @@ mlir::tosa::buildConvOpQuantizationAttr(OpBuilder &builder, Value input, Value weight) { - auto inputType = input.getType().dyn_cast(); - auto weightType = weight.getType().dyn_cast(); + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); if (!inputType || !weightType) return nullptr; @@ -151,8 +151,8 @@ mlir::tosa::buildMatMulOpQuantizationAttr(OpBuilder &builder, Value a, Value b) { - auto aType = a.getType().dyn_cast(); - auto bType = b.getType().dyn_cast(); + auto aType = a.getType().dyn_cast(); + auto bType = b.getType().dyn_cast(); if (!aType || !bType) return nullptr; @@ -187,8 +187,8 @@ mlir::tosa::buildUnaryOpQuantizationAttr(OpBuilder &builder, Value input, Type outputRawType) { - auto inputType = input.getType().dyn_cast(); - auto outputType = outputRawType.dyn_cast(); + auto inputType = input.getType().dyn_cast(); + auto outputType = outputRawType.dyn_cast(); if (!inputType || !outputType) return nullptr; @@ -220,7 +220,7 @@ PadOpQuantizationAttr mlir::tosa::buildPadOpQuantizationAttr(OpBuilder &builder, Value input) { - auto inputType = input.getType().dyn_cast(); + auto inputType = input.getType().dyn_cast(); if (!inputType) return nullptr; @@ -245,8 +245,8 @@ Type mlir::tosa::buildConvOpResultTypeInfo(OpBuilder &builder, Type outputType, Value input, Value weight) { - auto inputType = input.getType().dyn_cast(); - auto weightType = weight.getType().dyn_cast(); + auto inputType = input.getType().dyn_cast(); + auto weightType = weight.getType().dyn_cast(); assert(inputType && weightType && "Could not extract input or weight tensors from Conv op"); @@ -260,18 +260,16 @@ unsigned inputBits = inputQType.getStorageTypeIntegralWidth(); unsigned weightBits = weightQType.getStorageTypeIntegralWidth(); - auto outputShapedType = outputType.dyn_cast(); + auto outputShapedType = outputType.dyn_cast(); assert(outputShapedType && "Could not extract output shape type from Conv op"); - auto outputShape = outputShapedType.getShape(); - IntegerType accElementType; if (inputBits == 16 && weightBits == 8) accElementType = builder.getIntegerType(48); else accElementType = builder.getI32Type(); - auto accType = RankedTensorType::get(outputShape, accElementType); + auto accType = outputShapedType.clone(accElementType); return accType; }