diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1765,6 +1765,25 @@
 
 // Index cast is applicable from index to integer and backwards.
 bool IndexCastOp::areCastCompatible(Type a, Type b) {
+  // For shaped types the cast applies element-wise: both sides must be the
+  // same kind of ranked container (tensor/vector/memref) with identical
+  // shapes, and the element types must themselves be cast compatible.
+  auto aShaped = a.dyn_cast<ShapedType>();
+  auto bShaped = b.dyn_cast<ShapedType>();
+  if (aShaped && bShaped) {
+    // ShapedType::getShape() is only valid on ranked types, so reject
+    // unranked operands/results instead of querying their shape.
+    if (!aShaped.hasRank() || !bShaped.hasRank())
+      return false;
+
+    // Changing the container kind (e.g. tensor -> memref) is not an index
+    // cast, even when the shapes agree.
+    return a.getTypeID() == b.getTypeID() &&
+           aShaped.getShape() == bShaped.getShape() &&
+           areCastCompatible(aShaped.getElementType(),
+                             bShaped.getElementType());
+  }
+
   return (a.isIndex() && b.isSignlessInteger()) ||
          (a.isSignlessInteger() && b.isIndex());
 }
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -0,0 +1,35 @@
+// RUN: mlir-opt -split-input-file %s -verify-diagnostics
+
+// CHECK-LABEL: test_index_cast_shape_error
+func @test_index_cast_shape_error(%arg0 : tensor<index>) -> tensor<2xi64> {
+  // expected-error @+1 {{operand type 'tensor<index>' and result type 'tensor<2xi64>' are cast incompatible}}
+  %0 = index_cast %arg0 : tensor<index> to tensor<2xi64>
+  return %0 : tensor<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_index_cast_tensor_error
+func @test_index_cast_tensor_error(%arg0 : tensor<index>) -> i64 {
+  // expected-error @+1 {{operand type 'tensor<index>' and result type 'i64' are cast incompatible}}
+  %0 = index_cast %arg0 : tensor<index> to i64
+  return %0 : i64
+}
+
+// -----
+
+// CHECK-LABEL: test_index_cast_unranked_error
+func @test_index_cast_unranked_error(%arg0 : tensor<*xindex>) -> tensor<*xi64> {
+  // expected-error @+1 {{operand type 'tensor<*xindex>' and result type 'tensor<*xi64>' are cast incompatible}}
+  %0 = index_cast %arg0 : tensor<*xindex> to tensor<*xi64>
+  return %0 : tensor<*xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_index_cast_container_kind_error
+func @test_index_cast_container_kind_error(%arg0 : tensor<2xindex>) -> vector<2xi64> {
+  // expected-error @+1 {{operand type 'tensor<2xindex>' and result type 'vector<2xi64>' are cast incompatible}}
+  %0 = index_cast %arg0 : tensor<2xindex> to vector<2xi64>
+  return %0 : vector<2xi64>
+}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt -split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: test_index_cast
+func @test_index_cast(%arg0 : index) -> i64 {
+  %0 = index_cast %arg0 : index to i64
+  return %0 : i64
+}
+
+// CHECK-LABEL: test_index_cast_tensor
+func @test_index_cast_tensor(%arg0 : tensor<index>) -> tensor<i64> {
+  %0 = index_cast %arg0 : tensor<index> to tensor<i64>
+  return %0 : tensor<i64>
+}
+
+// CHECK-LABEL: test_index_cast_tensor_reverse
+func @test_index_cast_tensor_reverse(%arg0 : tensor<i64>) -> tensor<index> {
+  %0 = index_cast %arg0 : tensor<i64> to tensor<index>
+  return %0 : tensor<index>
+}
+