diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp @@ -91,11 +91,19 @@ Location loc = op.getLoc(); Value stride = rewriter.create(loc, 1); for (int i = rank - 1; i >= 0; --i) { - Value index = rewriter.create(loc, i); - Value size = rewriter.create(loc, op.shape(), index); - if (!size.getType().isa()) - size = rewriter.create(loc, size, rewriter.getIndexType()); - sizes[i] = size; + Value size; + // Load dynamic sizes from the shape input, use constants for static dims. + if (op.getType().isDynamicDim(i)) { + Value index = rewriter.create(loc, i); + size = rewriter.create(loc, op.shape(), index); + if (!size.getType().isa()) + size = + rewriter.create(loc, size, rewriter.getIndexType()); + sizes[i] = size; + } else { + sizes[i] = rewriter.getIndexAttr(op.getType().getDimSize(i)); + size = rewriter.create(loc, sizes[i].get()); + } strides[i] = stride; if (i > 0) stride = rewriter.create(loc, stride, size); diff --git a/mlir/test/Dialect/Standard/expand-ops.mlir b/mlir/test/Dialect/Standard/expand-ops.mlir --- a/mlir/test/Dialect/Standard/expand-ops.mlir +++ b/mlir/test/Dialect/Standard/expand-ops.mlir @@ -84,19 +84,17 @@ // CHECK-LABEL: func @memref_reshape( func @memref_reshape(%input: memref<*xf32>, - %shape: memref<3xi32>) -> memref { + %shape: memref<3xi32>) -> memref { %result = memref.reshape %input(%shape) - : (memref<*xf32>, memref<3xi32>) -> memref - return %result : memref + : (memref<*xf32>, memref<3xi32>) -> memref + return %result : memref } // CHECK-SAME: [[SRC:%.*]]: memref<*xf32>, -// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref { +// CHECK-SAME: [[SHAPE:%.*]]: memref<3xi32>) -> memref { // CHECK: [[C1:%.*]] = constant 1 : index -// CHECK: [[C2:%.*]] = constant 2 : index -// CHECK: [[DIM_2:%.*]] = memref.load [[SHAPE]]{{\[}}[[C2]]] : memref<3xi32> -// CHECK: [[SIZE_2:%.*]] = index_cast [[DIM_2]] : i32 to index -// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[SIZE_2]] : index +// CHECK: [[C8:%.*]] = constant 8 : index +// CHECK: [[STRIDE_1:%.*]] = muli [[C1]], [[C8]] : index // CHECK: [[C1_:%.*]] = constant 1 : index // CHECK: [[DIM_1:%.*]] = memref.load [[SHAPE]]{{\[}}[[C1_]]] : memref<3xi32> @@ -108,6 +106,6 @@ // CHECK: [[SIZE_0:%.*]] = index_cast [[DIM_0]] : i32 to index // CHECK: [[RESULT:%.*]] = memref.reinterpret_cast [[SRC]] -// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], [[SIZE_2]]], +// CHECK-SAME: to offset: [0], sizes: {{\[}}[[SIZE_0]], [[SIZE_1]], 8], // CHECK-SAME: strides: {{\[}}[[STRIDE_0]], [[STRIDE_1]], [[C1]]] -// CHECK-SAME: : memref<*xf32> to memref +// CHECK-SAME: : memref<*xf32> to memref