diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2934,24 +2934,24 @@ ``` }]; - let arguments = (ins Arg:$memref); + let arguments = (ins Arg:$memref); let results = (outs AnyTensor:$result); // TensorLoadOp is fully verified by traits. let verifier = ?; let builders = [OpBuilder< "OpBuilder &builder, OperationState &result, Value memref", [{ - auto memrefType = memref.getType().cast(); - auto resultType = RankedTensorType::get(memrefType.getShape(), - memrefType.getElementType()); result.addOperands(memref); - result.addTypes(resultType); + result.addTypes(getTensorTypeFromMemRefType(memref.getType())); }]>]; let extraClassDeclaration = [{ - /// The result of a tensor_load is always a tensor. - TensorType getType() { return getResult().getType().cast(); } + /// The result of a verified tensor_load is always a tensor; a null + /// TensorType is returned (rather than asserting) if it is not. + TensorType getType() { + return getResult().getType().dyn_cast<TensorType>(); + } }]; let assemblyFormat = "$memref attr-dict `:` type($memref)"; @@ -2981,9 +2981,8 @@ ``` }]; - let arguments = (ins AnyTensor:$tensor, - Arg:$memref); + let arguments = (ins AnyTensor:$tensor, Arg:$memref); // TensorStoreOp is fully verified by traits. 
let verifier = ?; diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -2985,6 +2985,20 @@ static Type getTensorTypeFromMemRefType(Type type) { if (auto memref = type.dyn_cast()) return RankedTensorType::get(memref.getShape(), memref.getElementType()); + if (auto memref = type.dyn_cast()) + return UnrankedTensorType::get(memref.getElementType()); + return NoneType::get(type.getContext()); +} + +/// Returns the memref type equivalent to the tensor type `type`, or NoneType +/// if `type` is not a tensor. Tensor types carry no memory space, so an +/// unranked tensor maps to an unranked memref in the default space (0). +static Type getMemRefTypeFromTensorType(Type type) { + if (auto tensor = type.dyn_cast()) + return MemRefType::get(tensor.getShape(), tensor.getElementType()); + if (auto tensor = type.dyn_cast()) + return UnrankedMemRefType::get(tensor.getElementType(), + /*memorySpace=*/0); return NoneType::get(type.getContext()); } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -813,6 +813,15 @@ return } +// CHECK-LABEL: func @unranked_tensor_load_store +func @unranked_tensor_load_store(%0 : memref<*xi32>) { + // CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF:.*]] : memref<*xi32> + %1 = tensor_load %0 : memref<*xi32> + // CHECK: tensor_store %[[TENSOR]], %[[MEMREF]] : memref<*xi32> + tensor_store %1, %0 : memref<*xi32> + return +} + // CHECK-LABEL: func @atomic_rmw // CHECK-SAME: ([[BUF:%.*]]: memref<10xf32>, [[VAL:%.*]]: f32, [[I:%.*]]: index) func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {