diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp
--- a/mlir/lib/AsmParser/AttributeParser.cpp
+++ b/mlir/lib/AsmParser/AttributeParser.cpp
@@ -1174,17 +1174,23 @@
   SMLoc loc = getToken().getLoc();
 
   auto emitWrongTokenError = [&] {
-    emitError(loc, "expected a non-negative 64-bit signed integer or '?'");
+    emitError(loc, "expected a 64-bit signed integer or '?'");
     return llvm::None;
   };
 
+  bool negative = consumeIf(Token::minus);
+
   if (getToken().is(Token::integer)) {
     Optional<uint64_t> value = getToken().getUInt64IntegerValue();
     if (!value ||
         *value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
       return emitWrongTokenError();
     consumeToken();
-    return static_cast<int64_t>(*value);
+    auto result = static_cast<int64_t>(*value);
+    if (negative)
+      result = -result;
+
+    return result;
   }
 
   return emitWrongTokenError();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -269,14 +269,9 @@
 LogicalResult
 StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           int64_t offset, ArrayRef<int64_t> strides) {
-  if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset)
-    return emitError() << "offset must be non-negative or dynamic";
+  if (llvm::any_of(strides, [&](int64_t stride) { return stride == 0; }))
+    return emitError() << "strides must not be zero";
 
-  if (llvm::any_of(strides, [&](int64_t stride) {
-        return stride <= 0 && stride != ShapedType::kDynamicStrideOrOffset;
-      })) {
-    return emitError() << "strides must be positive or dynamic";
-  }
   return success();
 }
 
diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
--- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
+++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -540,6 +540,31 @@
 
 // -----
 
+// CHECK-LABEL: func @subview_negative_stride
+// CHECK-SAME: (%[[ARG:.*]]: memref<7xf32>)
+func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
+  // CHECK: %[[ORIG:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<7xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[NEW1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR1:.*]] = llvm.extractvalue %[[ORIG]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR2:.*]] = llvm.bitcast %[[PTR1]] : !llvm.ptr<f32> to !llvm.ptr<f32>
+  // CHECK: %[[NEW2:.*]] = llvm.insertvalue %[[PTR2]], %[[NEW1]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR3:.*]] = llvm.extractvalue %[[ORIG]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[PTR4:.*]] = llvm.bitcast %[[PTR3]] : !llvm.ptr<f32> to !llvm.ptr<f32>
+  // CHECK: %[[NEW3:.*]] = llvm.insertvalue %[[PTR4]], %[[NEW2]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(6 : index) : i64
+  // CHECK: %[[NEW4:.*]] = llvm.insertvalue %[[OFFSET]], %[[NEW3]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[SIZE:.*]] = llvm.mlir.constant(7 : i64) : i64
+  // CHECK: %[[STRIDE:.*]] = llvm.mlir.constant(-1 : i64) : i64
+  // CHECK: %[[NEW5:.*]] = llvm.insertvalue %[[SIZE]], %[[NEW4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[NEW6:.*]] = llvm.insertvalue %[[STRIDE]], %[[NEW5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[NEW6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> to memref<7xf32, strided<[-1], offset: 6>>
+  // CHECK: return %[[RES]] : memref<7xf32, strided<[-1], offset: 6>>
+  %0 = memref.subview %arg0[6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
+  return %0 : memref<7xf32, strided<[-1], offset: 6>>
+}
+
+// -----
+
 // CHECK-LABEL: func @assume_alignment
 func.func @assume_alignment(%0 : memref<4x4xf16>) {
   // CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr<f16>, ptr<f16>, i64, array<2 x i64>, array<2 x i64>)>
diff --git a/mlir/test/Dialect/Builtin/types.mlir b/mlir/test/Dialect/Builtin/types.mlir
--- a/mlir/test/Dialect/Builtin/types.mlir
+++ b/mlir/test/Dialect/Builtin/types.mlir
@@ -16,3 +16,7 @@
 func.func private @f7() -> memref<?x?xf32, strided<[?, ?], offset: ?>>
 // CHECK: memref<?x?xf32, strided<[?, ?]>>
 func.func private @f8() -> memref<?x?xf32, strided<[?, ?]>>
+// CHECK: memref<?x?xf32, strided<[-1, -2], offset: -3>>
+func.func private @f9() -> memref<?x?xf32, strided<[-1, -2], offset: -3>>
+// CHECK: memref<?x?xf32, strided<[-1, -2], offset: ?>>
+func.func private @f10() -> memref<?x?xf32, strided<[-1, -2], offset: ?>>
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -127,6 +127,43 @@
 // CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
 // CHECK-SAME: : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+
+// -----
+
+func.func @subview_negative_stride1(%arg0 : memref<?xf32>) -> memref<?xf32, strided<[?], offset: ?>>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant -1 : index
+  %1 = memref.dim %arg0, %c0 : memref<?xf32>
+  %2 = arith.addi %1, %c1 : index
+  %3 = memref.subview %arg0[%2] [%1] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
+  return %3 : memref<?xf32, strided<[?], offset: ?>>
+}
+// CHECK: func @subview_negative_stride1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+// CHECK: %[[C1:.*]] = arith.constant 0
+// CHECK: %[[C2:.*]] = arith.constant -1
+// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<?xf32>
+// CHECK: %[[DIM2:.*]] = arith.addi %[[DIM1]], %[[C2]] : index
+// CHECK: %[[RES1:.*]] = memref.subview %[[ARG0]][%[[DIM2]]] [%[[DIM1]]] [-1] : memref<?xf32> to memref<?xf32, strided<[-1], offset: ?>>
+// CHECK: %[[RES2:.*]] = memref.cast %[[RES1]] : memref<?xf32, strided<[-1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK: return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>
+
+// -----
+
+func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, strided<[?], offset: ?>>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant -1 : index
+  %1 = memref.dim %arg0, %c0 : memref<7xf32>
+  %2 = arith.addi %1, %c1 : index
+  %3 = memref.subview %arg0[%2] [%1] [%c1] : memref<7xf32> to memref<?xf32, strided<[?], offset: ?>>
+  return %3 : memref<?xf32, strided<[?], offset: ?>>
+}
+// CHECK: func @subview_negative_stride2
+// CHECK-SAME: (%[[ARG0:.*]]: memref<7xf32>)
+// CHECK: %[[RES1:.*]] = memref.subview %[[ARG0]][6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
+// CHECK: %[[RES2:.*]] = memref.cast %[[RES1]] : memref<7xf32, strided<[-1], offset: 6>> to memref<?xf32, strided<[?], offset: ?>>
+// CHECK: return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>
 
 // -----
diff --git a/mlir/test/IR/invalid-builtin-types.mlir b/mlir/test/IR/invalid-builtin-types.mlir
--- a/mlir/test/IR/invalid-builtin-types.mlir
+++ b/mlir/test/IR/invalid-builtin-types.mlir
@@ -74,7 +74,7 @@
 
 // -----
 
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+// expected-error @below {{expected a 64-bit signed integer or '?'}}
 func.func private @memref_unfinished_stride_list() -> memref<?xf32, strided<[>>
 
 // -----
@@ -89,7 +89,7 @@
 
 // -----
 
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
+// expected-error @below {{expected a 64-bit signed integer or '?'}}
 func.func private @memref_missing_offset_value() -> memref<?xf32, strided<[], offset: >>
 
 // -----
 
@@ -99,21 +99,11 @@
 
 // -----
 
-// expected-error @below {{strides must be positive or dynamic}}
+// expected-error @below {{strides must not be zero}}
 func.func private @memref_zero_stride() -> memref<?xf32, strided<[0]>>
 
 // -----
 
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
-func.func private @memref_negative_stride() -> memref<?xf32, strided<[-1]>>
-
-// -----
-
-// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
-func.func private @memref_negative_offset() -> memref<?xf32, strided<[1], offset: -1>>
-
-// -----
-
 // expected-error @below {{expected the number of strides to match the rank}}
 func.func private @memref_strided_rank_mismatch() -> memref<?xf32, strided<[1, 1]>>
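
Note on semantics: a strided layout maps an index tuple (i_0, ..., i_{n-1}) to the element at linear position offset + sum_k(i_k * stride_k) in the underlying buffer, so a negative stride, now accepted by the parser and verifier above, simply walks a dimension backwards; only zero strides remain invalid. Below is a minimal sketch of IR that becomes legal with this change (the function name and shapes are illustrative, not taken from the patch):

  func.func @flip_rows(%m : memref<4x8xf32>) -> memref<4x8xf32, strided<[-8, 1], offset: 24>> {
    // Start at row 3 and step by -1 along the outer dimension: the canonical
    // strides [8, 1] of memref<4x8xf32> become [-8, 1], and the static offset
    // is 3 * 8 = 24, so view element (i, j) reads source element (3 - i, j).
    %v = memref.subview %m[3, 0] [4, 8] [-1, 1]
        : memref<4x8xf32> to memref<4x8xf32, strided<[-8, 1], offset: 24>>
    return %v : memref<4x8xf32, strided<[-8, 1], offset: 24>>
  }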