diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1765,11 +1765,28 @@
 
 // Index cast is applicable from index to integer and backwards.
 bool IndexCastOp::areCastCompatible(Type a, Type b) {
+  // Shaped operands are cast elementwise: both shapes must be identical and
+  // the element types must themselves be index/integer cast compatible.
+  auto aShaped = a.dyn_cast<ShapedType>();
+  auto bShaped = b.dyn_cast<ShapedType>();
+  if (aShaped && bShaped) {
+    // ShapedType::getShape() is only valid on ranked types, so reject
+    // unranked operands here instead of asserting inside getShape().
+    return aShaped.hasRank() && bShaped.hasRank() &&
+           aShaped.getShape() == bShaped.getShape() &&
+           areCastCompatible(aShaped.getElementType(),
+                             bShaped.getElementType());
+  }
+
   return (a.isIndex() && b.isSignlessInteger()) ||
          (a.isSignlessInteger() && b.isIndex());
 }
 
 OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
+  // A cast whose result type equals its operand type is a no-op: fold to x.
+  if (getOperand().getType() == getType())
+    return getOperand();
+
   // Fold IndexCast(IndexCast(x)) -> x
   auto cast = getOperand().getDefiningOp<IndexCastOp>();
   if (cast && cast.getOperand().getType() == getType())