diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1484,6 +1484,25 @@
     return getResult();
   }
 
+  // Fold dim to the operand of dynamic_tensor_from_elements.
+  if (auto fromElements =
+          dyn_cast_or_null<DynamicTensorFromElementsOp>(definingOp)) {
+    auto resultType =
+        fromElements.getResult().getType().cast<RankedTensorType>();
+    // The case where the type encodes the size of the dimension is handled
+    // above.
+    assert(resultType.getShape()[index.getInt()] ==
+           RankedTensorType::kDynamicSize);
+
+    // Find the operand of the fromElements that corresponds to this index.
+    auto dynExtents = fromElements.dynamicExtents().begin();
+    for (auto dim : resultType.getShape().take_front(index.getInt()))
+      if (dim == RankedTensorType::kDynamicSize)
+        dynExtents++;
+
+    return Value{*dynExtents};
+  }
+
   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
   auto memrefType = argTy.dyn_cast<MemRefType>();
   if (!memrefType)
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -5,9 +5,9 @@
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<?xf32>) -> tensor<?xf32> {
 // CHECK: return %[[TENSOR]]
 func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
-    %0 = tensor_to_memref %arg0 : memref<?xf32>
-    %1 = tensor_load %0 : memref<?xf32>
-    return %1 : tensor<?xf32>
+  %0 = tensor_to_memref %arg0 : memref<?xf32>
+  %1 = tensor_load %0 : memref<?xf32>
+  return %1 : tensor<?xf32>
 }
 
 // Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
@@ -15,9 +15,9 @@
 // CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
 // CHECK: return %[[MEMREF]]
 func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
-    %0 = tensor_load %arg0 : memref<?xf32>
-    %1 = tensor_to_memref %0 : memref<?xf32>
-    return %1 : memref<?xf32>
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = tensor_to_memref %0 : memref<?xf32>
+  return %1 : memref<?xf32>
 }
 
 // Test case: If the memrefs are not the same type, don't fold them.
@@ -27,9 +27,9 @@
 // CHECK: %[[MEMREF_ADDRSPACE7:.*]] = tensor_to_memref %[[TENSOR]] : memref<?xf32, 7>
 // CHECK: return %[[MEMREF_ADDRSPACE7]]
 func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref<?xf32, 7> {
-    %0 = tensor_load %arg0 : memref<?xf32, 2>
-    %1 = tensor_to_memref %0 : memref<?xf32, 7>
-    return %1 : memref<?xf32, 7>
+  %0 = tensor_load %arg0 : memref<?xf32, 2>
+  %1 = tensor_to_memref %0 : memref<?xf32, 7>
+  return %1 : memref<?xf32, 7>
 }
 
 // Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
@@ -39,8 +39,23 @@
 // CHECK: %[[D:.*]] = dim %[[MEMREF]], %[[C0]]
 // CHECK: return %[[D]] : index
 func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
-    %c0 = constant 0 : index
-    %0 = tensor_load %arg0 : memref<?xf32>
-    %1 = dim %0, %c0 : tensor<?xf32>
-    return %1 : index
+  %c0 = constant 0 : index
+  %0 = tensor_load %arg0 : memref<?xf32>
+  %1 = dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}
+
+// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
+// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-NOT: dim
+// CHECK: return %[[IDX1]] : index
+func @dim_of_dynamic_tensor_from_elements(%arg0: index, %arg1: index) -> index {
+  %c3 = constant 3 : index
+  %0 = dynamic_tensor_from_elements %arg0, %arg1 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    yield %c3 : index
+  } : tensor<2x?x4x?x5xindex>
+  %1 = dim %0, %c3 : tensor<2x?x4x?x5xindex>
+  return %1 : index
 }
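
Note on the new fold in DimOp::fold: dynamic_tensor_from_elements carries one SSA operand per dynamic dimension only, so the fold has to map the queried dimension index onto a position in dynamicExtents() by counting the dynamic dimensions that precede it. Below is a minimal standalone C++ sketch of that index computation, not MLIR API code; it uses ordinary containers and a -1 sentinel standing in for RankedTensorType::kDynamicSize.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Sentinel standing in for RankedTensorType::kDynamicSize (assumption for
// this sketch; the real constant lives on RankedTensorType).
constexpr int64_t kDynamicSize = -1;

// Given a shape and the index of a dynamic dimension, return the position of
// that dimension within the list of dynamic-extent operands, which holds one
// value per dynamic dimension, in order.
static size_t dynamicExtentOperandIndex(const std::vector<int64_t> &shape,
                                        size_t dimIndex) {
  assert(shape[dimIndex] == kDynamicSize && "dimension must be dynamic");
  size_t operandIndex = 0;
  // Each dynamic dimension before `dimIndex` consumes one extent operand.
  for (size_t i = 0; i < dimIndex; ++i)
    if (shape[i] == kDynamicSize)
      ++operandIndex;
  return operandIndex;
}

int main() {
  // Shape of tensor<2x?x4x?x5xindex> from the new test case.
  std::vector<int64_t> shape = {2, kDynamicSize, 4, kDynamicSize, 5};
  // dim %0, %c1 would fold to the first dynamic extent (%arg0),
  // dim %0, %c3 to the second (%arg1).
  std::printf("dim 1 -> dynamic extent #%zu\n",
              dynamicExtentOperandIndex(shape, 1));
  std::printf("dim 3 -> dynamic extent #%zu\n",
              dynamicExtentOperandIndex(shape, 3));
  return 0;
}

Run against the shape from the test, dimension 3 is the second dynamic dimension and therefore maps to the second extent operand, which is why the FileCheck pattern expects %[[IDX1]] rather than %[[IDX0]].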