diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -108,18 +108,24 @@
 /// operations equal. Returns the updated input1 and input2 for the original
 /// input. The caller is expected to use these to rewrite the original operator
 /// with the RESHAPE now in the graph.
-static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
-                                RankedTensorType outputType, Value input1,
-                                Value input2, Value &outInput1,
-                                Value &outInput2) {
+static LogicalResult reshapeLowerToHigher(PatternRewriter &rewriter,
+                                          Location loc,
+                                          RankedTensorType outputType,
+                                          Value input1, Value input2,
+                                          Value &outInput1, Value &outInput2) {
+  auto input1Ty = input1.getType().dyn_cast<RankedTensorType>();
+  auto input2Ty = input2.getType().dyn_cast<RankedTensorType>();
 
-  int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
-  int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
+  if (!input1Ty || !input2Ty)
+    return failure();
+
+  int64_t input1Rank = input1Ty.getRank();
+  int64_t input2Rank = input2Ty.getRank();
 
   Value higherTensorValue, lowerTensorValue;
-  // return if rank already match
+  // Cannot rewrite as it's already correct.
   if (input1Rank == input2Rank)
-    return 1;
+    return failure();
 
   if (input1Rank > input2Rank) {
     higherTensorValue = input1;
@@ -129,24 +135,27 @@
     lowerTensorValue = input1;
   }
 
-  ArrayRef<int64_t> outputRankShape = outputType.getShape();
   ArrayRef<int64_t> higherRankShape =
       higherTensorValue.getType().cast<RankedTensorType>().getShape();
   (void)higherRankShape;
   ArrayRef<int64_t> lowerRankShape =
       lowerTensorValue.getType().cast<RankedTensorType>().getShape();
 
-  // outputRank == higherRank == max(input1Rank, input2Rank)
-  assert(higherRankShape.size() == outputRankShape.size());
-
   SmallVector<int64_t> reshapeOutputShape;
 
-  computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
+  computeReshapeOutput(higherRankShape, lowerRankShape, reshapeOutputShape);
 
   auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
   auto reshapeOutputType = RankedTensorType::get(
      ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
 
+  // Verify the rank agrees with the output type if the output type is ranked.
+  if (outputType) {
+    if (outputType.getShape().size() != reshapeOutputShape.size() ||
+        outputType.getShape().size() != higherRankShape.size())
+      return failure();
+  }
+
   auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
       loc, reshapeOutputType, lowerTensorValue,
       rewriter.getI64ArrayAttr(reshapeOutputShape));
@@ -159,7 +168,7 @@
     outInput2 = higherTensorValue;
   }
 
-  return 0;
+  return success();
 }
 
 namespace {
@@ -173,11 +182,13 @@
     Value input1 = tosaBinaryOp.input1();
     Value input2 = tosaBinaryOp.input2();
     Value output = tosaBinaryOp.getResult();
-    auto outputType = output.getType().cast<RankedTensorType>();
+
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
@@ -200,11 +211,12 @@
     Value input2 = tosaBinaryOp.input2();
     int32_t shift = tosaBinaryOp.shift();
     Value output = tosaBinaryOp.getResult();
-    auto outputType = output.getType().cast<RankedTensorType>();
+    auto outputType = output.getType().dyn_cast<RankedTensorType>();
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
@@ -233,7 +245,8 @@
 
     Value outInput1, outInput2;
     if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
-                             input1, input2, outInput1, outInput2))
+                             input1, input2, outInput1, outInput2)
+            .failed())
       return failure();
 
     rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
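For context, below is a minimal sketch of the rewrite these patterns perform, written as TOSA IR in roughly the textual form of this revision. The function name, the chosen shapes, and the exact new_shape attribute syntax are illustrative assumptions, not copied from the in-tree lit tests; they only show how reshapeLowerToHigher rank-expands the lower-ranked operand before the binary op is recreated.

// Hypothetical input: tosa.add with a rank-2 and a rank-1 operand.
func @add_broadcast(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32> {
  %0 = "tosa.add"(%arg0, %arg1) : (tensor<2x3xf32>, tensor<3xf32>) -> tensor<2x3xf32>
  return %0 : tensor<2x3xf32>
}

// Expected shape of the result after -tosa-make-broadcastable: the lower-ranked
// operand is reshaped up to rank 2 (trailing dimensions kept, leading ones padded
// with 1), and the add is recreated on the rank-matched operands.
func @add_broadcast(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32> {
  %0 = "tosa.reshape"(%arg1) {new_shape = [1, 3]} : (tensor<3xf32>) -> tensor<1x3xf32>
  %1 = "tosa.add"(%arg0, %0) : (tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>
}

With the LogicalResult change above, the same patterns now simply fail to match (rather than assert) when an operand or the result is unranked, or when the computed reshape rank disagrees with a ranked output type.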