diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -33,7 +33,8 @@
   }];
   let cppNamespace = "::mlir::linalg";
   let dependentDialects = [
-    "AffineDialect", "StandardOpsDialect", "tensor::TensorDialect"
+    "AffineDialect", "memref::MemRefDialect", "StandardOpsDialect",
+    "tensor::TensorDialect"
   ];
   let hasCanonicalizer = 1;
   let hasOperationAttrVerify = 1;
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGOPS_H_
 
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/AffineExpr.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -10,6 +10,7 @@
 #define MLIR_DIALECT_LINALG_LINALGTYPES_H_
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -10,6 +10,7 @@
 #define DIALECT_LINALG_TRANSFORMS_TRANSFORMS_H_
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/Utils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Identifier.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_
 #define MLIR_DIALECT_MEMREF_IR_MEMREF_H_
 
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/CallInterfaces.h"
 #include "mlir/Interfaces/CastInterfaces.h"
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td
@@ -19,6 +19,7 @@
     manipulation ops, which are not strongly associated with any particular
    other dialect or domain abstraction.
   }];
+  let dependentDialects = ["tensor::TensorDialect"];
   let hasConstantMaterializer = 1;
 }
 
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
@@ -14,7 +14,6 @@
 #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H
 
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -27,9 +27,6 @@
   let name = "std";
   let cppNamespace = "::mlir";
   let hasConstantMaterializer = 1;
-  // TODO: This dependency is needed to handle memref ops in the
-  // canonicalize pass and should be resolved.
-  let dependentDialects = ["memref::MemRefDialect"];
 }
 
 // Base class for Standard dialect ops.
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -17,6 +17,7 @@
 #include "../PassDetail.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/LoopUtils.h"
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
"mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/StandardOps/Utils/Utils.h" diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -8,6 +8,7 @@ #include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Passes.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SCF/Transforms.h" diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp @@ -8,6 +8,7 @@ #include "mlir/Transforms/Bufferize.h" #include "PassDetail.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Shape/Transforms/Passes.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp @@ -43,6 +43,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -217,3 +217,177 @@ return } +// ----- + +// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs)) +// -> tensor.extract(%v, %idx) +// CHECK-LABEL: func @load_from_buffer_cast( +// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index +// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor +// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]] +// CHECK-NOT: memref.load +// CHECK: return %[[RES]] : f32 +func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor) -> f32 { + %0 = memref.buffer_cast %arg2 : memref + %1 = memref.load %0[%arg0, %arg1] : memref + return %1 : f32 +} + +// ----- + + +// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m). 
+// CHECK-LABEL: func @dim_of_tensor_load(
+// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
+// CHECK: %[[C0:.*]] = constant 0
+// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
+// CHECK: return %[[D]] : index
+func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
+  %c0 = constant 0 : index
+  %0 = memref.tensor_load %arg0 : memref<?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
+// CHECK-LABEL: func @dim_of_tensor.generate(
+// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
+// CHECK-NOT: memref.dim
+// CHECK: return %[[IDX1]] : index
+func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
+  %c3 = constant 3 : index
+  %0 = tensor.generate %arg0, %arg1 {
+  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
+    tensor.yield %c3 : index
+  } : tensor<2x?x4x?x5xindex>
+  %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
+// CHECK-LABEL: func @dim_of_alloca(
+// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
+// CHECK-NEXT: return %[[SIZE]] : index
+func @dim_of_alloca(%size: index) -> index {
+  %0 = memref.alloca(%size) : memref<?xindex>
+  %c0 = constant 0 : index
+  %1 = memref.dim %0, %c0 : memref<?xindex>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
+// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
+// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
+// CHECK-NEXT: return %[[RANK]] : index
+func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
+  %0 = rank %arg0 : memref<*xf32>
+  %1 = memref.alloca(%0) : memref<?xindex>
+  %c0 = constant 0 : index
+  %2 = memref.dim %1, %c0 : memref<?xindex>
+  return %2 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: memref.store
+// CHECK-NOT: memref.dim
+// CHECK: return %[[DIM]] : index
+func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
+  // Update the shape to test that the load ends up in the right place.
+  memref.store %c3, %arg1[%c3] : memref<?xindex>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
+// CHECK-LABEL: func @dim_of_memref_reshape_i32(
+// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
+// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
+// CHECK-NEXT: %[[IDX:.*]] = constant 3
+// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
+// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
+// CHECK-NOT: memref.dim
+// CHECK: return %[[CAST]] : index
+func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
+    -> index {
+  %c3 = constant 3 : index
+  %0 = memref.reshape %arg0(%arg1)
+      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
+  %1 = memref.dim %0, %c3 : memref<*xf32>
+  return %1 : index
+}
+
+// -----
+
+// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
+// CHECK-LABEL: func @fold_dim_of_tensor.cast
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[C4:.+]] = constant 4 : index
+// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-NEXT: return %[[C4]], %[[T0]]
+func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
+  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
+  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
+  return %1, %2: index, index
+}
+
+// -----
+
+// CHECK-LABEL: func @tensor_cast_to_memref
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
+func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
+  return %1 : memref<?x?x16x32xi8>
+}
+
+// -----
+
+// TODO: Move this test to Tensor/canonicalize.mlir.
+func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
+    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c8 = constant 8 : index
+  %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
+  %1 = tensor.extract %arg1[] : tensor<i32>
+  %2 = tensor.generate %arg2, %c8 {
+  ^bb0(%arg4: index, %arg5: index):
+    tensor.yield %1 : i32
+  } : tensor<?x?xi32>
+  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
+  return %3 : tensor<?x?xi32>
+}
+// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
+// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
+// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
+// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
+// CHECK: return %[[CAST]]
diff --git a/mlir/test/Dialect/Standard/canonicalize.mlir b/mlir/test/Dialect/Standard/canonicalize.mlir
--- a/mlir/test/Dialect/Standard/canonicalize.mlir
+++ b/mlir/test/Dialect/Standard/canonicalize.mlir
@@ -1,53 +1,5 @@
 // RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
 
-// Test case: Basic folding of memref.dim(memref.tensor_load(m)) -> memref.dim(m).
-// CHECK-LABEL: func @dim_of_tensor_load(
-// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
-// CHECK: %[[C0:.*]] = constant 0
-// CHECK: %[[D:.*]] = memref.dim %[[MEMREF]], %[[C0]]
-// CHECK: return %[[D]] : index
-func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
-  %c0 = constant 0 : index
-  %0 = memref.tensor_load %arg0 : memref<?xf32>
-  %1 = memref.dim %0, %c0 : tensor<?xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.load(memref.buffer_cast(%v, %idxs))
-// -> tensor.extract(%v, %idx)
-// CHECK-LABEL: func @load_from_buffer_cast(
-// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
-// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
-// CHECK: %[[RES:.*]] = tensor.extract %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
-// CHECK-NOT: memref.load
-// CHECK: return %[[RES]] : f32
-func @load_from_buffer_cast(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
-  %0 = memref.buffer_cast %arg2 : memref<?x?xf32>
-  %1 = memref.load %0[%arg0, %arg1] : memref<?x?xf32>
-  return %1 : f32
-}
-
-// -----
-
-// Test case: Folding of memref.dim(tensor.generate %idx) -> %idx
-// CHECK-LABEL: func @dim_of_tensor.generate(
-// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
-// CHECK-NOT: memref.dim
-// CHECK: return %[[IDX1]] : index
-func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
-  %c3 = constant 3 : index
-  %0 = tensor.generate %arg0, %arg1 {
-  ^bb0(%arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
-    tensor.yield %c3 : index
-  } : tensor<2x?x4x?x5xindex>
-  %1 = memref.dim %0, %c3 : tensor<2x?x4x?x5xindex>
-  return %1 : index
-}
-
-// -----
-
 // Test case: Folding of comparisons with equal operands.
 // CHECK-LABEL: @cmpi_equal_operands
 // CHECK-DAG: %[[T:.*]] = constant true
@@ -72,108 +24,6 @@
 // -----
 
-// Test case: Folding of memref.dim(memref.alloca(%size), %idx) -> %size
-// CHECK-LABEL: func @dim_of_alloca(
-// CHECK-SAME: %[[SIZE:[0-9a-z]+]]: index
-// CHECK-NEXT: return %[[SIZE]] : index
-func @dim_of_alloca(%size: index) -> index {
-  %0 = memref.alloca(%size) : memref<?xindex>
-  %c0 = constant 0 : index
-  %1 = memref.dim %0, %c0 : memref<?xindex>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.alloca(rank(%v)), %idx) -> rank(%v)
-// CHECK-LABEL: func @dim_of_alloca_with_dynamic_size(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>
-// CHECK-NEXT: %[[RANK:.*]] = rank %[[MEM]] : memref<*xf32>
-// CHECK-NEXT: return %[[RANK]] : index
-func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
-  %0 = rank %arg0 : memref<*xf32>
-  %1 = memref.alloca(%0) : memref<?xindex>
-  %c0 = constant 0 : index
-  %2 = memref.dim %1, %c0 : memref<?xindex>
-  return %2 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xindex>
-// CHECK-NEXT: %[[IDX:.*]] = constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: memref.store
-// CHECK-NOT: memref.dim
-// CHECK: return %[[DIM]] : index
-func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
-    -> index {
-  %c3 = constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
-  // Update the shape to test that he load ends up in the right place.
-  memref.store %c3, %arg1[%c3] : memref<?xindex>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
-// CHECK-LABEL: func @dim_of_memref_reshape_i32(
-// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
-// CHECK-SAME: %[[SHP:[0-9a-z]+]]: memref<?xi32>
-// CHECK-NEXT: %[[IDX:.*]] = constant 3
-// CHECK-NEXT: %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
-// CHECK-NEXT: %[[CAST:.*]] = index_cast %[[DIM]]
-// CHECK-NOT: memref.dim
-// CHECK: return %[[CAST]] : index
-func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
-    -> index {
-  %c3 = constant 3 : index
-  %0 = memref.reshape %arg0(%arg1)
-      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
-  %1 = memref.dim %0, %c3 : memref<*xf32>
-  return %1 : index
-}
-
-// -----
-
-// Test case: Folding memref.dim(tensor.cast %0, %idx) -> memref.dim %0, %idx
-// CHECK-LABEL: func @fold_dim_of_tensor.cast
-// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
-// CHECK-DAG: %[[C1:.+]] = constant 1 : index
-// CHECK-DAG: %[[C4:.+]] = constant 4 : index
-// CHECK: %[[T0:.+]] = memref.dim %[[ARG0]], %[[C1]]
-// CHECK-NEXT: return %[[C4]], %[[T0]]
-func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %0 = tensor.cast %arg0 : tensor<4x?xf32> to tensor<?x?xf32>
-  %1 = memref.dim %0, %c0 : tensor<?x?xf32>
-  %2 = memref.dim %0, %c1 : tensor<?x?xf32>
-  return %1, %2: index, index
-}
-
-// -----
-
-// CHECK-LABEL: func @tensor_cast_to_memref
-// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
-// CHECK: %[[M:.+]] = memref.buffer_cast %[[ARG0]] : memref<4x6x16x32xi8>
-// CHECK: %[[M1:.+]] = memref.cast %[[M]] : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
-// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
-func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
-  memref<?x?x16x32xi8> {
-  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = memref.buffer_cast %0 : memref<?x?x16x32xi8>
-  return %1 : memref<?x?x16x32xi8>
-}
-
-// -----
-
 func @subtensor_canonicalize(%arg0 : tensor<?x?x?xf32>, %arg1 : index, %arg2 : index) -> tensor<?x?x?xf32> {
@@ -345,29 +195,6 @@
 // -----
 
-func @subtensor_insert_propagate_dest_cast(%arg0 : tensor<2x?xi32>, %arg1 : tensor<i32>,
-    %arg2 : index, %arg3 : index) -> tensor<?x?xi32> {
-  %c0 = constant 0 : index
-  %c1 = constant 1 : index
-  %c2 = constant 2 : index
-  %c8 = constant 8 : index
-  %0 = memref.dim %arg0, %c1 : tensor<2x?xi32>
-  %1 = tensor.extract %arg1[] : tensor<i32>
-  %2 = tensor.generate %arg2, %c8 {
-  ^bb0(%arg4: index, %arg5: index):
-    tensor.yield %1 : i32
-  } : tensor<?x?xi32>
-  %3 = subtensor_insert %arg0 into %2[%c0, %arg3] [%c2, %0] [%c1, %c1] : tensor<2x?xi32> into tensor<?x?xi32>
-  return %3 : tensor<?x?xi32>
-}
-// CHECK-LABEL: func @subtensor_insert_propagate_dest_cast
-// CHECK: %[[UPDATED:.+]] = subtensor_insert %{{.+}} into %{{.+}}[0, %{{.+}}] [2, %{{.+}}] [1, 1]
-// CHECK-SAME: tensor<2x?xi32> into tensor<?x8xi32>
-// CHECK: %[[CAST:.+]] = tensor.cast %[[UPDATED]]
-// CHECK: return %[[CAST]]
-
-// -----
-
 func @subtensor_insert_output_dest_canonicalize(%arg0 : tensor<2x3xi32>, %arg1 : tensor<i32>) -> tensor<3x9xi32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index