diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -1,61 +1,37 @@ -// RUN: mlir-opt -legalize-std-for-spirv -convert-std-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt -legalize-std-for-spirv %s -o - | FileCheck %s -// TODO: For these examples running these passes separately produces -// the desired output. Adding all of patterns within a single pass does -// not seem to work. - -module attributes { - spv.target_env = #spv.target_env< - #spv.vce, - {max_compute_workgroup_invocations = 128 : i32, - max_compute_workgroup_size = dense<[128, 128, 64]> : vector<3xi32>}> -} { +module { //===----------------------------------------------------------------------===// // std.subview //===----------------------------------------------------------------------===// -// CHECK-LABEL: @fold_static_stride_subview_with_load -// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32 -func @fold_static_stride_subview_with_load(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index) { - // CHECK: [[C2:%.*]] = spv.constant 2 - // CHECK: [[C3:%.*]] = spv.constant 3 - // CHECK: [[T2:%.*]] = spv.IMul [[ARG3]], [[C2]] - // CHECK: [[T3:%.*]] = spv.IAdd [[ARG1]], [[T2]] - // CHECK: [[T4:%.*]] = spv.IMul [[ARG4]], [[C3]] - // CHECK: [[T5:%.*]] = spv.IAdd [[ARG2]], [[T4]] - // CHECK: [[C32:%.*]] = spv.constant 32 - // CHECK: [[T7:%.*]] = spv.IMul [[C32]], [[T3]] - // CHECK: [[C1:%.*]] = spv.constant 1 - // CHECK: [[T9:%.*]] = spv.IMul [[C1]], [[T5]] - // CHECK: [[T10:%.*]] = spv.IAdd [[T7]], [[T9]] - // CHECK: [[C0:%.*]] = spv.constant 0 - // CHECK: [[T12:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[C0]], [[T10]] - // CHECK: spv.Load "StorageBuffer" [[T12]] : f32 +// 
CHECK-LABEL: @fold_static_stride_subview +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref<12x32xf32> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]*]]: index +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]*]]: index +func @fold_static_stride_subview + (%arg0 : memref<12x32xf32>, %arg1 : index, + %arg2 : index, %arg3 : index, %arg4 : index) { + // CHECK: %[[C2:.*]] = constant 2 + // CHECK: %[[C3:.*]] = constant 3 + // CHECK: %[[T0:.*]] = muli %[[ARG3]], %[[C2]] + // CHECK: %[[T1:.*]] = addi %[[ARG1]], %[[T0]] + // CHECK: %[[T2:.*]] = muli %[[ARG4]], %[[C3]] + // CHECK: %[[T3:.*]] = addi %[[ARG2]], %[[T2]] + // CHECK: %[[LOADVAL:.*]] = load %[[ARG0]][%[[T1]], %[[T3]]] + // CHECK: %[[STOREVAL:.*]] = sqrt %[[LOADVAL]] + // CHECK: %[[T6:.*]] = muli %[[ARG3]], %[[C2]] + // CHECK: %[[T7:.*]] = addi %[[ARG1]], %[[T6]] + // CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[C3]] + // CHECK: %[[T9:.*]] = addi %[[ARG2]], %[[T8]] + // CHECK: store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]] %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> - return -} - -// CHECK-LABEL: @fold_static_stride_subview_with_store -// CHECK-SAME: [[ARG0:%.*]]: !spv.ptr [0]>, StorageBuffer>, [[ARG1:%.*]]: i32, [[ARG2:%.*]]: i32, [[ARG3:%.*]]: i32, [[ARG4:%.*]]: i32, [[ARG5:%.*]]: f32 -func @fold_static_stride_subview_with_store(%arg0 : memref<12x32xf32>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : f32) { - // CHECK: [[C2:%.*]] = spv.constant 2 - // CHECK: [[C3:%.*]] = spv.constant 3 - // CHECK: [[T2:%.*]] = spv.IMul [[ARG3]], [[C2]] - // CHECK: [[T3:%.*]] = spv.IAdd [[ARG1]], [[T2]] - // CHECK: [[T4:%.*]] = spv.IMul [[ARG4]], [[C3]] - // CHECK: [[T5:%.*]] = spv.IAdd [[ARG2]], [[T4]] - // CHECK: [[C32:%.*]] = spv.constant 32 - // CHECK: [[T7:%.*]] = spv.IMul [[C32]], [[T3]] - // CHECK: [[C1:%.*]] 
= spv.constant 1 - // CHECK: [[T9:%.*]] = spv.IMul [[C1]], [[T5]] - // CHECK: [[T10:%.*]] = spv.IAdd [[T7]], [[T9]] - // CHECK: [[C0:%.*]] = spv.constant 0 - // CHECK: [[T12:%.*]] = spv.AccessChain [[ARG0]]{{\[}}[[C0]], [[T10]] - // CHECK: spv.Store "StorageBuffer" [[T12]], [[ARG5]] : f32 - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> - store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> + %2 = sqrt %1 : f32 + store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return }